From 4832f19416e2354afab68f46793155285c04a388 Mon Sep 17 00:00:00 2001 From: Ben Hoff Date: Thu, 11 Jul 2019 09:26:12 -0400 Subject: [PATCH] added in handeling for openvino 2019 --- cvat/apps/auto_annotation/model_loader.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/cvat/apps/auto_annotation/model_loader.py b/cvat/apps/auto_annotation/model_loader.py index a933190ec955..09b67f417f0b 100644 --- a/cvat/apps/auto_annotation/model_loader.py +++ b/cvat/apps/auto_annotation/model_loader.py @@ -7,6 +7,7 @@ import cv2 import os import subprocess +import numpy as np from cvat.apps.auto_annotation.inference_engine import make_plugin, make_network @@ -28,9 +29,18 @@ def __init__(self, model, weights): raise Exception("Following layers are not supported by the plugin for specified device {}:\n {}". format(plugin.device, ", ".join(not_supported_layers))) - self._input_blob_name = next(iter(network.inputs)) + iter_inputs = iter(network.inputs) + self._input_blob_name = next(iter_inputs) self._output_blob_name = next(iter(network.outputs)) + self._require_image_info = False + + # NOTE: handeling for the inclusion of `image_info` in OpenVino2019 + if 'image_info' in network.inputs: + self._require_image_info = True + if self._input_blob_name == 'image_info': + self._input_blob_name = next(iter_inputs) + self._net = plugin.load(network=network, num_requests=2) input_type = network.inputs[self._input_blob_name] self._input_layout = input_type if isinstance(input_type, list) else input_type.shape @@ -39,7 +49,16 @@ def infer(self, image): _, _, h, w = self._input_layout in_frame = image if image.shape[:-1] == (h, w) else cv2.resize(image, (w, h)) in_frame = in_frame.transpose((2, 0, 1)) # Change data layout from HWC to CHW - results = self._net.infer(inputs={self._input_blob_name: in_frame}) + inputs = {self._input_blob_name: in_frame} + if self._require_image_info: + info = np.zeros([1, 3]) + info[0, 0] = h + info[0, 1] = w + # frame number + info[0, 2] = 1 + inputs['image_info'] = info + + results = self._net.infer(inputs) if len(results) == 1: return results[self._output_blob_name].copy() else: