博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
教你如何结合WebRTC与TensorFlow实现图像检测
阅读量:2089 次
发布时间:2019-04-29

本文共 28301 字,大约阅读时间需要 94 分钟。

原文作者:Chad Hart

原文地址:https://webrtchacks.com/webrtc-cv-tensorflow/

摘要:本文作者介绍了结合WebRTC与TensorFlow实现图像检测的具体过程,不论对于TensorFlow的使用者,还是WebRTC的开发者来讲都有参考意义。

TensorFlow是目前最流行的机器学习框架之一。TensorFlow的一大优势是,它的很多库都有人积极进行维护和更新。而我最喜欢的其中一个库就是TensorFlow对象检测API。可以对一张图形上的多个对象进行分类,并提供它们的具体位置。该API在将近1000个对象类上进行了预先训练,可提供各种经过预先训练的模型,让你可以在速度与准确性之间权衡取舍。

有这些模型的指引固然很好,但所有这些模型都要使用图像才能发挥作用,而这些图像则需要你自行添加到一个文件夹中。我其实很想将其与实时的WebRTC流配合到一起,通过网络实现实时的计算机视觉。由于未能找到这方面的任何例子或指南,我决定写这篇博文来介绍具体的实现方法。对于使用RTC的人,可以将本文作为一篇快速指南加以参考,了解如何使用TensorFlow来处理WebRTC流。对于使用TensorFlow的人士,则可以将本文作为一份快速简介,了解如何向自己的项目中添加WebRTC。使用WebRTC的人需要对Python比较熟悉。而使用TensorFlow的人则需要熟悉网络交互和一些JavaScript。

本文不适合作为WebRTC或TensorFlow的入门指南使用。如需这样的指南,应参考等,网上的相关介绍与指南数不胜数。

利用Tensor Flow和WebRTC检测猫咪

直接告诉我如何实现吧

如果你来这里只是为了快速找到一些参考信息,或者懒得读详细的文字介绍,按照下面的方法即可快速着手。首先安装。加载一个命令提示窗口,接着键入下面的命令:

docker run -it -p 5000:5000 chadhart/tensorflow-object-detection:runserver

然后在浏览器地址栏中键入并转到http://localhost:5000/local,接受摄像头权限请求,你应该会看到类似下面的界面:

点击[阅读原文]观看视频

基本架构

我们首先建立一个基本架构,用以在本地将一个本地网络摄像头流从WebRTC的发送到一个Python服务器,这要用到网络服务器和(Object Detection API)。具体的设置大致如下图所示

图:为搭配使用WebRTC与TensorFlow对象检测API而建立的基本架构


Flask将提供html和JavaScript文件供浏览器呈现。getUserMedia.js负责抓取本地视频流。接下来,objDetect.js会使用方法向TensorFlow对象检测API发送图像,该API则返回它所看到的对象(它称之为“类”)及对象在图像中的位置。我们会将这些详细信息封装到一个JSON对象中,然后将该对象发回给objDetect.js,这样我们就能将我们所看到的对象的方框和标签显示出来。

配置

设置和前提条件

在开始之前,我们需要先对Tensorflow和对象检测API进行一些设置。

使用Docker轻松完成设置

我在OSX、Windows 10和Raspbian已经设置过好几次(过程可不简单)。各种版本依赖关系错综复杂,把这些关系理顺并非易事,特别是当你只是想看看一些前期工作是否行得通时,你可能会感到气馁。我推荐使用Docker来避免这些棘手问题。你将需要学习Docker,这也是非学不可的东西,与其试着构建合适的版本,倒不如花些时间学习它来得更为高效。TensorFlow项目维护了一些官方的Docker映像,比如

如果你使用Docker,我们就可以使用我为这篇博文创建的映像。在命令行中,请运行以下命令:

git clone https://github.com/webrtcHacks/tfObjWebrtc.gitcd tfObjWebrtcdocker run -it -p 5000:5000 --name tf-webrtchacks -v $(pwd):/code chadhart/tensorflow-object-detection:webrtchacks

请注意,docker run中的$(pwd)仅适用于Linux和Windows Powershell。在Windows 10命令行中,请使用%cd%。

看到这里,你应该已经进入了Docker容器。现在,请运行:

python setup.py install

这样,就会使用最新的TensorFlow Docker映像,并将Docker主机上的端口5000连接到端口5000,将容器命名为tf-webrtchacks,将一个本地目录映射到容器中的一个新/code目录,将该目录设为默认目录(我们接下来将在该目录中操作),然后运行bash以便进行命令行交互。完成这些准备工作后,我们才能开始。

如果你才刚开始接触TensorFlow,可能需要先按照中的说明运行初始Jupyter notebook,然后再回来执行上述命令。

另一种麻烦的实现方法

如果你打算从头开始,则需要安装TensorFlow,它自身有很多依赖项,比如Python。TensorFlow项目针对各种平台都提供了指南,具体请访问。对象检测API也有自己的,以及一些额外的依赖项。完成这些准备工作后,请运行下面的命令:

git clone https://github.com/webrtcHacks/tfObjWebrtc.gitcd tfObjWebrtcpython setup.py install

这样,就应该安装好了所有的Python依赖项,将相应的Tensorflow对象检测API文件都复制了过来,并安装了Protobufs。如果这一步行不通,我建议检查setup.py,然后手动在其中运行命令,以解决存在的任何问题。

第1部分——确保Tensorflow正常工作

为确保TensorFlow对象检测API正常工作,我们首先从用于演示对象检测的官方版经调整后的版本着手。我将此文件保存为object_detection_tutorial.py

如果你剪切并粘贴该notebook的每个部分,得到的结果应如下所示:

# IMPORTSimport numpy as npimport osimport six.moves.urllib as urllibimport sysimport tarfileimport tensorflow as tfimport zipfilefrom collections import defaultdictfrom io import StringIO# from matplotlib import pyplot as plt ### CWHfrom PIL import Imageif tf.__version__ != '1.4.0':  raise ImportError('Please upgrade your tensorflow installation to v1.4.0!')# ENV SETUP  ### CWH: remove matplot display and manually add paths to references'''# This is needed to display the images.%matplotlib inline# This is needed since the notebook is stored in the object_detection folder.sys.path.append("..")'''# Object detection importsfrom object_detection.utils import label_map_util    ### CWH: Add object_detection path#from object_detection.utils import visualization_utils as vis_util ### CWH: used for visualization# Model Preparation# What model to download.MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17'MODEL_FILE = MODEL_NAME + '.tar.gz'DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'# Path to frozen detection graph. This is the actual model that is used for the object detection.PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'# List of the strings that is used to add correct label for each box.PATH_TO_LABELS = os.path.join('object_detection/data', 'mscoco_label_map.pbtxt') ### CWH: Add object_detection pathNUM_CLASSES = 90# Download Modelopener = urllib.request.URLopener()opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)tar_file = tarfile.open(MODEL_FILE)for file in tar_file.getmembers():  file_name = os.path.basename(file.name)  if 'frozen_inference_graph.pb' in file_name:    tar_file.extract(file, os.getcwd())# Load a (frozen) Tensorflow model into memory.detection_graph = tf.Graph()with detection_graph.as_default():  od_graph_def = tf.GraphDef()  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:    serialized_graph = fid.read()    od_graph_def.ParseFromString(serialized_graph)    tf.import_graph_def(od_graph_def, name='')# Loading label maplabel_map = label_map_util.load_labelmap(PATH_TO_LABELS)categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)category_index = label_map_util.create_category_index(categories)# Helper codedef load_image_into_numpy_array(image):  (im_width, im_height) = image.size  return np.array(image.getdata()).reshape(      (im_height, im_width, 3)).astype(np.uint8)# Detection# For the sake of simplicity we will use only 2 images:# image1.jpg# image2.jpg# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.PATH_TO_TEST_IMAGES_DIR = 'object_detection/test_images' #cwhTEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]# Size, in inches, of the output images.IMAGE_SIZE = (12, 8)with detection_graph.as_default():  with tf.Session(graph=detection_graph) as sess:    # Definite input and output Tensors for detection_graph    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')    # Each box represents a part of the image where a particular object was detected.    detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')    # Each score represent how level of confidence for each of the objects.    # Score is shown on the result image, together with the class label.    detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')    detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')    num_detections = detection_graph.get_tensor_by_name('num_detections:0')    for image_path in TEST_IMAGE_PATHS:      image = Image.open(image_path)      # the array based representation of the image will be used later in order to prepare the      # result image with boxes and labels on it.      image_np = load_image_into_numpy_array(image)      # Expand dimensions since the model expects images to have shape: [1, None, None, 3]      image_np_expanded = np.expand_dims(image_np, axis=0)      # Actual detection.      (boxes, scores, classes, num) = sess.run(          [detection_boxes, detection_scores, detection_classes, num_detections],          feed_dict={image_tensor: image_np_expanded})      ### CWH: below is used for visualizing with Matplot      '''      # Visualization of the results of a detection.      vis_util.visualize_boxes_and_labels_on_image_array(          image_np,          np.squeeze(boxes),          np.squeeze(classes).astype(np.int32),          np.squeeze(scores),          category_index,          use_normalized_coordinates=True,          line_thickness=8)      plt.figure(figsize=IMAGE_SIZE)      plt.imshow(image_np)        '''

在这里我就不再赘述实际TensorFlow代码的作用了,这方面的信息可在演示中找到。我将重点介绍我们对代码所做的修改。

我注释了几个小节:

  1. 更改了一些位置引用

  2. 删除了对Python matplot的所有引用。Python matplot用于在GUI环境中以可视化方式呈现输出结果。在我的Docker环境中没有设置它——根据你采用的具体运行方式,可以酌情决定是否保留这些引用。

对象检测API的输出结果

正如第111行所示,对象检测API输出4种对象:

  1. 类——一个由对象名组成的数组

  2. 分值——一个由置信度分值组成的数组

  3. 方框——检测到的每个对象所在的位置

  4. 数量——检测到的对象总数

类、分值和方框都是相互并列、大小相等的数组,因此classes[n]与scores[n]和boxes[n]都是一一对应的。

由于我删去了可视化功能,我们需要通过某种方式来查看结果,所以我们要把下面的命令添加到文件末尾:

### CWH: Print the object details to the console instead of visualizing them with the code aboveclasses = np.squeeze(classes).astype(np.int32)scores = np.squeeze(scores)boxes = np.squeeze(boxes)threshold = 0.50  #CWH: set a minimum score threshold of 50%obj_above_thresh = sum(n > threshold for n in scores)print("detected %s objects in %s above a %s score" % ( obj_above_thresh, image_path, threshold))for c in range(0, len(classes)):  if scores[c] > threshold:      class_name = category_index[classes[c]]['name']      print(" object %s is a %s - score: %s, location: %s" % (c, class_name, scores[c], boxes[c]))

第一个np.squeeze部分只是将多维数组输出缩减成一维,这与原来的可视化代码一样。我认为这是TensorFlow的一个副产品,因为它通常会输出多维数组。

接着我们要为它输出的分值设置一个阈值。好像TensorFlow默认会返回100个对象。其中很多对象嵌套在置信度更高的对象内或与这些对象重叠。在选择阈值方面我还没有发现任何最佳做法,不过对于这些示例图像来说,50%似乎是合适的。

最后,我们需要循环遍历这些数组,直接输出那些超过阈值的分值。

如果运行下面的命令:

python object_detection_tutorial.py
应该会获得下面的输出:

detected 2 objects in object_detection/test_images/image1.jpg above a 0.5 score object 0 is a dog - score: 0.940691, location: [ 0.03908405  0.01921503  0.87210345  0.31577349] object 1 is a dog - score: 0.934503, location: [ 0.10951501  0.40283561  0.92464608  0.97304785]detected 10 objects in object_detection/test_images/image2.jpg above a 0.5 score object 0 is a person - score: 0.916878, location: [ 0.55387682  0.39422381  0.59312469  0.40913767] object 1 is a kite - score: 0.829445, location: [ 0.38294643  0.34582412  0.40220094  0.35902989] object 2 is a person - score: 0.778505, location: [ 0.57416666  0.057667    0.62335181  0.07475379] object 3 is a kite - score: 0.769985, location: [ 0.07991442  0.4374091   0.16590245  0.50060284] object 4 is a kite - score: 0.755539, location: [ 0.26564282  0.20112294  0.30753511  0.22309387] object 5 is a person - score: 0.634234, location: [ 0.68338078  0.07842994  0.84058815  0.11782578] object 6 is a kite - score: 0.607407, location: [ 0.38510025  0.43172216  0.40073246  0.44773054] object 7 is a person - score: 0.589102, location: [ 0.76061964  0.15739655  0.93692541  0.20186904] object 8 is a person - score: 0.512377, location: [ 0.54281253  0.25604743  0.56234604  0.26740867] object 9 is a person - score: 0.501464, location: [ 0.58708113  0.02699314  0.62043804  0.04133803]

第2部分——打造一项对象API网络服务

在这一部分,我们将对教程代码作一些改动,以将其作为一项网络服务加以运行。我在Python方面的经验颇为有限(主要在Raspberry Pi项目中使用过),所以如有不对的地方,请添加备注或提交,以便我可以修正。

2.1将演示代码转变成一项服务

至此我们已经让TensorFlow Object API能够正常工作了,接下来我们就将它封装成一个可以调用的函数。我将演示代码复制到了一个名为的新python文件中。你可以看到,我删除了很多没有用到或注释掉的行,以及用于将详细信息输出到控制台的部分(暂时删除)。

由于我们要将这些信息输出到网络上,因此最好将我们的输出结果封装成一个JSON对象。为此,请务必向你导入的内容中添加一个importjso语句,然后再添加下面的命令:

# added to put object in JSONclass Object(object):    def __init__(self):        self.name="Tensor Flow Object API Service 0.0.1"    def toJSON(self):        return json.dumps(self.__dict__)
接下来,我们要重复利用之前的代码创建一个get_objects函数:

def get_objects(image, threshold=0.5):    image_np = load_image_into_numpy_array(image)    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]    image_np_expanded = np.expand_dims(image_np, axis=0)    # Actual detection.    (boxes, scores, classes, num) = sess.run(        [detection_boxes, detection_scores, detection_classes, num_detections],        feed_dict={image_tensor: image_np_expanded})    classes = np.squeeze(classes).astype(np.int32)    scores = np.squeeze(scores)    boxes = np.squeeze(boxes)obj_above_thresh = sum(n > threshold for n in scores)    obj_above_thresh = sum(n > threshold for n in scores)    print("detected %s objects in image above a %s score" % (obj_above_thresh, threshold))

在此函数中我们添加了一个图像输入参数和一个默认为0.5的threshold值。其余内容都是在演示代码的基础上重构的。

现在我们再向此函数添加一些代码,以查询具体的值并将它们输出到一个JSON对象中:

output = []    #Add some metadata to the output    item = Object()    item.numObjects = obj_above_thresh    item.threshold = threshold    output.append(item)    for c in range(0, len(classes)):        class_name = category_index[classes[c]]['name']        if scores[c] >= threshold:      # only return confidences equal or greater than the threshold            print(" object %s - score: %s, coordinates: %s" % (class_name, scores[c], boxes[c]))            item = Object()            item.name = 'Object'            item.class_name = class_name            item.score = float(scores[c])            item.y = float(boxes[c][0])            item.x = float(boxes[c][1])            item.height = float(boxes[c][2])            item.width = float(boxes[c][3])            output.append(item)    outputJson = json.dumps([ob.__dict__ for ob in output])    return outputJson

这一次我们是使用Object类来创建一些初始元数据并将这些元数据添加到output列表中。然后我们使用循环向此列表中添加Object数据。最后,将此列表转换成JSON并予以返回。

之后,我们来创建一个测试文件(这里要提醒自己:先做测试),以检查它是否调用了

import scan_imageimport osfrom PIL import Image# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.PATH_TO_TEST_IMAGES_DIR = 'object_detection/test_images' #cwhTEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]for image_path in TEST_IMAGE_PATHS:    image = Image.open(image_path)    response = object_detection_api.get_objects(image)    print("returned JSON: \n%s" % response)

至此万事俱备,接下来就是运行了。

python object_detection_test.py
除了前面的控制台输出之外,你应该还会看到一个JSON字符串:

returned JSON: [{
"threshold": 0.5, "name": "webrtcHacks Sample Tensor Flow Object API Service 0.0.1", "numObjects": 10}, {
"name": "Object", "class_name": "person", "height": 0.5931246876716614, "width": 0.40913766622543335, "score": 0.916878342628479, "y": 0.5538768172264099, "x": 0.39422380924224854}, {
"name": "Object", "class_name": "kite", "height": 0.40220093727111816, "width": 0.3590298891067505, "score": 0.8294452428817749, "y": 0.3829464316368103, "x": 0.34582412242889404}, {
"name": "Object", "class_name": "person", "height": 0.6233518123626709, "width": 0.0747537910938263, "score": 0.7785054445266724, "y": 0.5741666555404663, "x": 0.057666998356580734}, {
"name": "Object", "class_name": "kite", "height": 0.16590245068073273, "width": 0.5006028413772583, "score": 0.7699846625328064, "y": 0.07991442084312439, "x": 0.43740910291671753}, {
"name": "Object", "class_name": "kite", "height": 0.3075351119041443, "width": 0.22309386730194092, "score": 0.7555386424064636, "y": 0.26564282178878784, "x": 0.2011229395866394}, {
"name": "Object", "class_name": "person", "height": 0.8405881524085999, "width": 0.11782577633857727, "score": 0.6342343688011169, "y": 0.6833807826042175, "x": 0.0784299373626709}, {
"name": "Object", "class_name": "kite", "height": 0.40073245763778687, "width": 0.44773054122924805, "score": 0.6074065566062927, "y": 0.38510024547576904, "x": 0.43172216415405273}, {
"name": "Object", "class_name": "person", "height": 0.9369254112243652, "width": 0.20186904072761536, "score": 0.5891017317771912, "y": 0.7606196403503418, "x": 0.15739655494689941}, {
"name": "Object", "class_name": "person", "height": 0.5623460412025452, "width": 0.26740866899490356, "score": 0.5123767852783203, "y": 0.5428125262260437, "x": 0.25604742765426636}, {
"name": "Object", "class_name": "person", "height": 0.6204380393028259, "width": 0.04133802652359009, "score": 0.5014638304710388, "y": 0.5870811343193054, "x": 0.026993142440915108}]

2.2添加一个网络服务器

我们已经有了函数——接下来我们就用它来打造一项网络服务。

先使用测试用的路由(Route)运行

我们有了一个可以轻松添加到网络服务的良好API。我发现使用是最简单的测试方法。我们来创建一个,然后执行一次快速测试:

import object_detection_apiimport osfrom PIL import Imagefrom flask import Flask, request, Responseapp = Flask(__name__)@app.route('/')def index():    return Response('Tensor Flow object detection')@app.route('/test')def test():    PATH_TO_TEST_IMAGES_DIR = 'object_detection/test_images'  # cwh    TEST_IMAGE_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3)]    image = Image.open(TEST_IMAGE_PATHS[0])    objects = object_detection_api.get_objects(image)    return objectsif __name__ == '__main__':    app.run(debug=True, host='0.0.0.0')

现在,运行该服务器:

python server.py

确保该服务正常工作

然后调用该网络服务。就我自己的情况而言,我只是从主机运行了下面的命令(因为我的Docker实例现在正在前台运行该服务器):

curl http://localhost:5000/test | python -m json.tool

 json.tool将帮助你为输出结果设置格式。你应该会看到下面的结果:

% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current                                 Dload  Upload   Total   Spent    Left  Speed100   467  100   467    0     0    300      0  0:00:01  0:00:01 --:--:--   300[    {        "name": "webrtcHacks Sample Tensor Flow Object API Service 0.0.1",        "numObjects": 2,        "threshold": 0.5    },    {        "class_name": "dog",        "height": 0.8721034526824951,        "name": "Object",        "score": 0.9406907558441162,        "width": 0.31577348709106445,        "x": 0.01921503245830536,        "y": 0.039084047079086304    },    {        "class_name": "dog",        "height": 0.9246460795402527,        "name": "Object",        "score": 0.9345026612281799,        "width": 0.9730478525161743,        "x": 0.4028356075286865,        "y": 0.10951501131057739    }]

好了,接下来我们就要接受一个包含一个图片文件及其他一些参数的POST,使用真实路由运行了。为此,需要在/test路由函数下添加一个新的/image路由:

@app.route('/image', methods=['POST'])def image():    try:        image_file = request.files['image']  # get the image        # Set an image confidence threshold value to limit returned data        threshold = request.form.get('threshold')        if threshold is None:            threshold = 0.5        else:            threshold = float(threshold)        # finally run the image through tensor flow object detection`        image_object = Image.open(image_file)        objects = object_detection_api.get_objects(image_object, threshold)        return objects    except Exception as e:        print('POST /image error: %e' % e)        return e

这样就会从一个采用表单编码方式的POST中获取图片,并且可以选择指定一个阈值,然后将该图片传递给我们的object_detection_api。

我们来测试一下:

curl -F "image=@./object_detection/test_images/image1.jpg" http://localhost:5000/image | python -m json.tool

这时看到的结果应该与上面使用/test路径时相同。继续测试,可以指定你任选的其他本地图像的路径。

让该服务在localhost以外的位置也能正常工作

如果你打算在localhost上运行浏览器,可能就不需要再做些什么。但如果是真实的服务,甚至是在需要运行很多测试的情况下,这就不太现实了。如果要跨网络运行网络服务,或者使用其他资源运行网络服务,都需要用到CORS。幸好,在路由前添加以下代码就可以轻松解决这一问题:

# for CORS@app.after_requestdef after_request(response):    response.headers.add('Access-Control-Allow-Origin', '*')    response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization')    response.headers.add('Access-Control-Allow-Methods', 'GET,POST') # Put any other methods you need here    return response

让该服务支持安全源

最佳做法是搭配HTTPS使用WebRTC,因为Chrome和Safari等浏览器若不作专门配置,则仅支持安全源(不过Chrome可以很好地支持localhost,你也可以将Safari设为允许在非安全网站上捕获信息——的调试工具部分了解详情)。为此,你需要获取一些SSL证书或生成一些自托管证书。我将我自己的证书放在了ssl/目录中,然后将最后一行app.run更改为: 

app.run(debug=True, host='0.0.0.0', ssl_context=('ssl/server.crt', 'ssl/server.key'))
如果你使用的是自签名证书,你在使用CURL进行测试时可能需要添加--insecure选项:
curl -F "image=@./object_detection/test_images/image2.jpg" --insecure https://localhost:5000/image | python -m json.tool

严格来讲并非一定要生成你自己的证书,而且这会增加一定的工作量,所以在server.py最底部,我依然让SSL版本保持被注释掉的状态。

如果是要投入生产环境中使用的应用程序,你可能需要使用nginx之类的代理向外发送HTTPS,同时在内部依然使用HTTP(此外还要做很多其他方面的改进)。

添加一些路由以便提供我们的网页

在开始介绍浏览器端的工作之前,我们先为后面需要用到的一些路由生成存根。为此,请将下面的代码放在index()路由后面:

@app.route('/local')def local():    return Response(open('./static/local.html').read(), mimetype="text/html")@app.route('/video')def remote():    return Response(open('./static/video.html').read(), mimetype="text/html")
Python方面的工作到此就结束了。接下来我们将用到JavaScript,并且需要编写一些HTML。

第3部分——浏览器端

开始之前,请先在项目的根目录位置创建一个static目录。我们将从这里提供HTML和JavaScript。

现在,我们首先使用WebRTC的getUserMedia抓取一个本地摄像头Feed。从这里,我们需要将该Feed的快照发送到刚刚创建的对象检测Web API,获取结果,然后使用canvas实时地在视频上显示这些结果。

HTML

我们先创建文件:

    
Tensor Flow Object Detection from getUserMedia

此网页的作用如下:

▪使用WebRTC 代码填充(polyfill)

▪设置一些样式,以便

       ▪ 将各个元素一个个叠加起来

       ▪ 将视频放在底部,这样我们就能使用canvas在它上面绘图

▪为我们的getUserMedia流创建一个视频元素

▪链接到一个调用getUserMedia的JavaScript文件

▪链接到一个将与我们的对象检测API交互并在我们的视频上绘制方框的JavaScript文件

获取摄像头流

现在,在静态目录中创建一个文件,并将下面的代码添加到该文件中:

//Get camera videoconst constraints = {    audio: false,    video: {        width: {
min: 640, ideal: 1280, max: 1920}, height: {
min: 480, ideal: 720, max: 1080} }};navigator.mediaDevices.getUserMedia(constraints) .then(stream => { document.getElementById("myVideo").srcObject = stream; console.log("Got local user video"); }) .catch(err => { console.log('navigator.getUserMedia error: ', err) });

在这里你会看到我们首先设置了一些约束条件。对于我自己的情况,我需要一段1280×720视频,但要求范围在640×480与1920×1080之间。然后,我们使用这些约束条件执行getUserMedia,并将所生成的流分配给我们在HTML中创建的视频对象。

对象检测API的客户端版本

TensorFlow对象检测API教程包含了可执行以下操作的代码:获取现有图像,将其发送给实际API进行“推断”(对象检测),然后为它所看到的对象显示方框和类名。要想在浏览器中模拟这一功能,我们需要:

  1. 抓取图像——我们会创建一个canvas来完成这一步

  2. 将这些图像发送给API——为此,我们会将文件作为XMLHttpRequest中form-body的一部分进行传递

  3. 再使用一个canvas将结果绘制在我们的实时流上

要完成所有这些步骤,需要在静态文件夹中创建一个文件。

初始化和设置

我们需要先定义一些参数:

//Parametersconst s = document.getElementById('objDetect');const sourceVideo = s.getAttribute("data-source");  //the source video to useconst uploadWidth = s.getAttribute("data-uploadWidth") || 640; //the width of the upload fileconst mirror = s.getAttribute("data-mirror") || false; //mirror the boundary boxesconst scoreThreshold = s.getAttribute("data-scoreThreshold") || 0.5;

会注意到,我将其中一些参数作为data-元素添加到了自己的HTML代码中。我最终是要在多个不同的项目中使用这段代码,并且希望重用相同的代码库,而如此添加参数就可以轻松做到这一点。待具体使用这些参数时,我会一一解释。

设置视频和canvas元素

我们需要一个变量来表示我们的视频元素,需要一些起始事件,还需要创建上面提到的2个canvas。

//Video element selectorv = document.getElementById(sourceVideo);//for starting eventslet isPlaying = false,    gotMetadata = false;//Canvas setup//create a canvas to grab an image for uploadlet imageCanvas = document.createElement('canvas');let imageCtx = imageCanvas.getContext("2d");//create a canvas for drawing object boundarieslet drawCanvas = document.createElement('canvas');document.body.appendChild(drawCanvas);let drawCtx = drawCanvas.getContext("2d");

drawCanvas用于显示我们的方框和标签。imageCanvas用于向我们的对象检测API上传数据。我们需要向可见HTML添加drawCanvas,这样我们就能在绘制对象框时看到它。接下来,需要跳转到底部,逐函数向上编写。

启动该程序

1.触发视频事件

我们来启动该程序。首先要触发一些视频事件:

//Starting events//check if metadata is ready - we need the video sizev.onloadedmetadata = () => {    console.log("video metadata ready");    gotMetadata = true;    if (isPlaying)        startObjectDetection();};//see if the video has started playingv.onplaying = () => {    console.log("video playing");    isPlaying = true;    if (gotMetadata) {        startObjectDetection();    }};

先查找视频的onplay事件和loadedmetadata事件——如果没有视频,图像处理也就无从谈起。我们需要用到元数据来设置我们的绘图canvas尺寸,使其与下一部分中的视频尺寸相符。

2.启动主对象检测子例程

//Start object detectionfunction startObjectDetection() {    console.log("starting object detection");    //Set canvas sizes base don input video    drawCanvas.width = v.videoWidth;    drawCanvas.height = v.videoHeight;    imageCanvas.width = uploadWidth;    imageCanvas.height = uploadWidth * (v.videoHeight / v.videoWidth);    //Some styles for the drawcanvas    drawCtx.lineWidth = "4";    drawCtx.strokeStyle = "cyan";    drawCtx.font = "20px Verdana";    drawCtx.fillStyle = "cyan";

虽然drawCanvas必须与视频元素大小相同,但imageCanvas绝不会显示出来,只会发送到我们的API。可以使用文件开头的uploadWidth参数减小此大小,以帮助降低所需的带宽量和服务器上的处理需求。需要注意的是,减小图片可能会影响识别准确度,特别是图片缩减得过小的时候。

至此我们还需要为drawCanvas设置一些样式。我选择的是cyan,但你可以任选颜色。只是要确保所选的颜色与视频Feed对比明显,从而提供很好的可见度。

3.toBlob conversion

//Save and send the first image    imageCtx.drawImage(v, 0, 0, v.videoWidth, v.videoHeight, 0, 0, uploadWidth, uploadWidth * (v.videoHeight / v.videoWidth));    imageCanvas.toBlob(postFile, 'image/jpeg');}
设置好canvas大小后,我们需要确定如何发送图像。一开始我采取的是较为复杂的方法,结果看到Fippo的grab()函数在最后一个事件处,所以我又改用了简单的toBlob方法。待图片转换为blob(二进制大对象)后,我们就会将它发送到我们要创建的下一个函数,即postFile。

有一点需要注意——Edge似乎不支持HTMLCanvasElement.toBlob方法。好像可以改用推荐的polyfill或改用msToBlob,但这两个我都还没有机会试过。

将图像发送至对象检测API

//Add file blob to a form and postfunction postFile(file) {    //Set options as form data    let formdata = new FormData();    formdata.append("image", file);    formdata.append("threshold", scoreThreshold);    let xhr = new XMLHttpRequest();    xhr.open('POST', window.location.origin + '/image', true);    xhr.onload = function () {        if (this.status === 200) {            let objects = JSON.parse(this.response);            //console.log(objects);            //draw the boxes            drawBoxes(objects);            //Send the next image            imageCanvas.toBlob(postFile, 'image/jpeg');        }        else{            console.error(xhr);        }    };    xhr.send(formdata);}

我们的postFile接受图像blob作为实参。要发送此数据,我们需要使用XHR将其作为表单数据通过POST方法发布。不要忘了,我们的对象检测API还接受一个可选的阈值,所以在这里我们也可以加入此阈值。为便于调整,同时避免操作此库,你可以在我们在开头设置的data-标记中加入此参数及其他一些参数。

我们设置好表单后,需要使用XHR来发送它并等待响应。获取到返回的对象后,我们就可以绘制它们(见下一个函数)。这样就大功告成了。由于我们想要持续不断地执行上述操作,因此我们需要在获取到上一API调用返回的响应后,立即继续抓取新图像并再次发送。

绘制方框和类标签

接下来我们需要使用一个函数来绘制对象API输出,以便我们可以实际查看一下检测到的是什么:

function drawBoxes(objects) {    //clear the previous drawings    drawCtx.clearRect(0, 0, drawCanvas.width, drawCanvas.height);    //filter out objects that contain a class_name and then draw boxes and labels on each    objects.filter(object => object.class_name).forEach(object => {        let x = object.x * drawCanvas.width;        let y = object.y * drawCanvas.height;        let width = (object.width * drawCanvas.width) - x;        let height = (object.height * drawCanvas.height) - y;        //flip the x axis if local video is mirrored        if (mirror){            x = drawCanvas.width - (x + width)        }        drawCtx.fillText(object.class_name + " - " + Math.round(object.score * 100, 1) + "%", x + 5, y + 20);        drawCtx.strokeRect(x, y, width, height);    });}

由于我们希望每次都使用一个干净的绘图板来绘制矩形,我们首先要使用clearRect来清空canvas。然后,直接使用class_name对项目进行过滤,然后对剩余的每个项目执行绘图操作。

在objects对象中传递的坐标是以百分比为单位表示的图像大小。要在canvas上使用它们,我们需要将它们转换成以像素数表示的尺寸。我们还要检查是否启用了镜像参数。如果已启用,我们需要翻转x轴,以便与视频流翻转后的镜像视图相匹配。最后,我们需要编写对象class_name并绘制矩形。

让我们试一下吧!

现在,打开你最喜欢的WebRTC浏览器,在地址栏中输入网址。如果你是在同一台计算机上运行,网址将为http://localhost:5000/local(如果设置了证书,则为https://localhost:5000/local)。

关于优化

上述设置将通过服务器运行尽可能多的帧。除非为Tensorflow设置了GPU优化,否则这会消耗大量的CPU资源(例如,我自己的情况是消耗了一整个核心),即便不作任何改动也是如此。更高效的做法是,限制调用该API的频率,仅在视频流中有新活动时才调用该API。为此,我在一个新的文件中对objDetect.js做了一些修改。

修改前后内容大致相同,我只不过添加了2个新函数。首先,不再是每次都抓取图像,而是使用一个新函数sendImageFromCanvas(),仅当图片在指定的帧率内发生了变化时,该函数才发送图片。帧率用一个新的updateInterval参数表示,限定了可以调用该API的最大间隔。为此,我们需要使用新的canvas和内容。

这段代码很简单:

//Check if the image has changed & enough time has passeed sending it to the APIfunction sendImageFromCanvas() {    imageCtx.drawImage(v, 0, 0, v.videoWidth, v.videoHeight, 0, 0, uploadWidth, uploadWidth * (v.videoHeight / v.videoWidth));    let imageChanged = imageChange(imageCtx, imageChangeThreshold);    let enoughTime = (new Date() - lastFrameTime) > updateInterval;    if (imageChanged && enoughTime) {        imageCanvas.toBlob(postFile, 'image/jpeg');        lastFrameTime = new Date();    }    else {        setTimeout(sendImageFromCanvas, updateInterval);    }}
imageChangeThreshold是一个百分比,表示有改动的像素所占的百分比。我们获得此百分比后将其传递给imageChange函数,此函数返回True或False,表示是否超出了阈值。下面显示的就是这个函数:

//Function to measure the chagne in an imagefunction imageChange(sourceCtx, changeThreshold) {    let changedPixels = 0;    const threshold = changeThreshold * sourceCtx.canvas.width * sourceCtx.canvas.height;   //the number of pixes that change change    let currentFrame = sourceCtx.getImageData(0, 0, sourceCtx.canvas.width, sourceCtx.canvas.height).data;    //handle the first frame    if (lastFrameData === null) {        lastFrameData = currentFrame;        return true;    }    //look for the number of pixels that changed    for (let i = 0; i < currentFrame.length; i += 4) {        let lastPixelValue = lastFrameData[i] + lastFrameData[i + 1] + lastFrameData[i + 2];        let currentPixelValue = currentFrame[i] + currentFrame[i + 1] + currentFrame[i + 2];        //see if the change in the current and last pixel is greater than 10; 0 was too sensitive        if (Math.abs(lastPixelValue - currentPixelValue) > (10)) {            changedPixels++        }    }    //console.log("current frame hits: " + hits);    lastFrameData = currentFrame;    return (changedPixels > threshold);}

上面的这个函数其实是经过大幅改进后的版本,之前的版本是我很久以前编写的,用于在中检测婴儿动作。它首先测量每个像素的RGB颜色值。如果这些值与该像素的总体颜色值相比绝对差值超过10,则将该像素视为已改动。10只是随意定的一个值,但在我的测试中似乎是很合适的值。如果有改动的像素数超出threshold,该函数就会返回True。

对此稍微深入研究后,我发现其他一些算法通常转换成灰度值,因为颜色并不能很好地反映动作。应用高斯模糊也可以消除编码差异。提出了一个很好的建议,即借鉴在检测视频活动时使用的算法()。后续还会分享更多技巧。

适合任何视频元素

这段代码实际上对任何<video>元素都适用,包括使用WebRTC peerConnection连接的远程对等方的视频。我并不想把这篇博文/代码写得再长一些、复杂一些,不过我确实在静态文件夹中包含了一个文件,作为演示之用:

    
Tensor Flow Object Detection from a video

不妨使用你自己的视频试一下。只是提醒一下,如果你使用的是托管在另一台服务器上的视频,要注意CORS问题。

一不小心写成了长篇

完成上面这一切耗费了很多时间,不过现在我希望开始着手有趣的部分了:尝试不同的型号和训练我自己的分类器。已发布的对象检测API是针对静态图像设计的。那么针对视频和对象跟踪进行了调整的模型又是什么样的呢?很值得一试。

此外,这里还有太多的优化工作需要做。我在运行此服务时并没有配备GPU,如果配备的话,性能会大为不同。如果帧的数量不多,支持1个客户端需要大约1个核心,而我使用的是最快但准确性最低的模型。在这方面有很大的性能提升空间。如果观察此服务在GPU云端网络中性能如何,想必也很有趣。

你可能感兴趣的文章
梯度消失问题与如何选择激活函数
查看>>
为什么在优化算法中使用指数加权平均
查看>>
Java集合详解1:一文读懂ArrayList,Vector与Stack使用方法和实现原理
查看>>
Java集合详解2:一文读懂Queue和LinkedList
查看>>
Java集合详解4:一文读懂HashMap和HashTable的区别以及常见面试题
查看>>
Java集合详解5:深入理解LinkedHashMap和LRU缓存
查看>>
Java集合详解6:这次,从头到尾带你解读Java中的红黑树
查看>>
Java并发指南2:深入理解Java内存模型JMM
查看>>
Java并发指南5:JMM中的final关键字解析
查看>>
Java并发指南6:Java内存模型JMM总结
查看>>
Java网络编程和NIO详解6:Linux epoll实现原理详解
查看>>
Java网络编程和NIO详解7:浅谈 Linux 中NIO Selector 的实现原理
查看>>
Java网络编程与NIO详解8:浅析mmap和Direct Buffer
查看>>
Java网络编程与NIO详解10:深度解读Tomcat中的NIO模型
查看>>
Java网络编程与NIO详解11:Tomcat中的Connector源码分析(NIO)
查看>>
深入理解JVM虚拟机1:JVM内存的结构与消失的永久代
查看>>
深入理解JVM虚拟机3:垃圾回收器详解
查看>>
深入理解JVM虚拟机4:Java class介绍与解析实践
查看>>
深入理解JVM虚拟机5:虚拟机字节码执行引擎
查看>>
深入理解JVM虚拟机6:深入理解JVM类加载机制
查看>>