Skip to content

Commit

Permalink
Modifications for Object Detection SSD Python sample (#3976)
Browse files Browse the repository at this point in the history
* Modifications for Object Detection SSD Sample Python

* Fixes for tests

* Changing the way output checks are processed
  • Loading branch information
Maksim Makridin authored Feb 5, 2021
1 parent 47127fb commit e47186a
Showing 1 changed file with 18 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,7 @@ def main():
# -----------------------------------------------------------------------------------------------------

# --------------------------- 3. Read and preprocess input --------------------------------------------

print("inputs number: " + str(len(net.input_info.keys())))
assert len(net.input_info.keys()) == 1, 'Sample supports networks with one input'

for input_key in net.input_info:
print("input shape: " + str(net.input_info[input_key].input_data.shape))
print("input key: " + input_key)
if len(net.input_info[input_key].input_data.layout) == 4:
n, c, h, w = net.input_info[input_key].input_data.shape

Expand All @@ -96,7 +90,6 @@ def main():
image = cv2.resize(image, (w, h))
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
images[i] = image

# -----------------------------------------------------------------------------------------------------

# --------------------------- 4. Configure input & output ---------------------------------------------
Expand All @@ -122,23 +115,30 @@ def main():
data[input_name] = images

if input_info_name != "":
infos = np.ndarray(shape=(n, c), dtype=float)
detection_size = net.input_info[input_info_name].input_data.shape[1]
infos = np.ndarray(shape=(n, detection_size), dtype=float)
for i in range(n):
infos[i, 0] = h
infos[i, 1] = w
infos[i, 2] = 1.0
for j in range(2, detection_size):
infos[i, j] = 1.0
data[input_info_name] = infos

# --------------------------- Prepare output blobs ----------------------------------------------------
log.info('Preparing output blobs')

output_name, output_info = "", None
func = ng.function_from_cnn(net)
ops = func.get_ordered_ops()
output_name, output_info = "", net.outputs[next(iter(net.outputs.keys()))]
output_ops = {op.friendly_name : op for op in ops \
if op.friendly_name in net.outputs and op.get_type_name() == "DetectionOutput"}
if len(output_ops) != 0:
output_name, output_info = output_ops.popitem()
if func:
ops = func.get_ordered_ops()
for op in ops:
if op.friendly_name in net.outputs and op.get_type_name() == "DetectionOutput":
output_name = op.friendly_name
output_info = net.outputs[output_name]
break
else:
output_name = list(net.outputs.keys())[0]
output_info = net.outputs[output_name]

if output_name == "":
log.error("Can't find a DetectionOutput layer in the topology")
Expand Down Expand Up @@ -189,12 +189,12 @@ def main():
else:
print()

tmp_image = cv2.imread(args.input)
for imid in classes:
tmp_image = cv2.imread(args.input)
for box in boxes[imid]:
cv2.rectangle(tmp_image, (box[0], box[1]), (box[2], box[3]), (232, 35, 244), 2)
cv2.imwrite("out.bmp", tmp_image)
log.info("Image out.bmp created!")
cv2.imwrite("out.bmp", tmp_image)
log.info("Image out.bmp created!")
# -----------------------------------------------------------------------------------------------------

log.info("Execution successful\n")
Expand Down

0 comments on commit e47186a

Please sign in to comment.