Skip to content

Commit

Permalink
added in handeling for openvino 2019 (#545)
Browse files Browse the repository at this point in the history
  • Loading branch information
benhoff authored and nmanovic committed Jul 11, 2019
1 parent 6e1e063 commit 3ae8a72
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions cvat/apps/auto_annotation/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import cv2
import os
import subprocess
import numpy as np

from cvat.apps.auto_annotation.inference_engine import make_plugin, make_network

Expand All @@ -28,9 +29,18 @@ def __init__(self, model, weights):
raise Exception("Following layers are not supported by the plugin for specified device {}:\n {}".
format(plugin.device, ", ".join(not_supported_layers)))

self._input_blob_name = next(iter(network.inputs))
iter_inputs = iter(network.inputs)
self._input_blob_name = next(iter_inputs)
self._output_blob_name = next(iter(network.outputs))

self._require_image_info = False

# NOTE: handeling for the inclusion of `image_info` in OpenVino2019
if 'image_info' in network.inputs:
self._require_image_info = True
if self._input_blob_name == 'image_info':
self._input_blob_name = next(iter_inputs)

self._net = plugin.load(network=network, num_requests=2)
input_type = network.inputs[self._input_blob_name]
self._input_layout = input_type if isinstance(input_type, list) else input_type.shape
Expand All @@ -39,7 +49,16 @@ def infer(self, image):
_, _, h, w = self._input_layout
in_frame = image if image.shape[:-1] == (h, w) else cv2.resize(image, (w, h))
in_frame = in_frame.transpose((2, 0, 1)) # Change data layout from HWC to CHW
results = self._net.infer(inputs={self._input_blob_name: in_frame})
inputs = {self._input_blob_name: in_frame}
if self._require_image_info:
info = np.zeros([1, 3])
info[0, 0] = h
info[0, 1] = w
# frame number
info[0, 2] = 1
inputs['image_info'] = info

results = self._net.infer(inputs)
if len(results) == 1:
return results[self._output_blob_name].copy()
else:
Expand Down

0 comments on commit 3ae8a72

Please sign in to comment.