Skip to content

Commit

Permalink
Fixes for Mask-RCNN conversion (openvinotoolkit#654)
Browse files Browse the repository at this point in the history
* Fixed ONNX Mask-RCNN conversion

* Fixed validate_and_infer_types for NMS ops: added a check for the number of connected inputs

* Updated NMS ops to properly handle optional input with index 2

* Fixed typo in the implementation
  • Loading branch information
lazarevevgeny authored May 28, 2020
1 parent ec5c9db commit 0efe474
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 37 deletions.
12 changes: 11 additions & 1 deletion model-optimizer/extensions/front/onnx/mask_rcnn_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import numpy as np

from extensions.front.onnx.softmaxONNX_to_softmax import SoftmaxONNXFrontReplacer
from extensions.ops.Cast import Cast
from extensions.ops.detectionoutput_onnx import ExperimentalDetectronDetectionOutput
from extensions.ops.parameter import Parameter
from extensions.ops.roifeatureextractor_onnx import ExperimentalDetectronROIFeatureExtractor
Expand All @@ -29,7 +31,7 @@
input_fpn_heads = ('486', '454', '422', '390')


class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral):
class ONNXMaskRCNNTransformation(FrontReplacementFromConfigFileGeneral):
"""
This transformation performs 3 actions:
1. Replaces a sub-graph calculating ROIAlign over FPN heads with a single ExperimentalDetectronROIFeatureExtractor
Expand All @@ -42,6 +44,11 @@ class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral)
"""
replacement_id = 'ONNXMaskRCNNReplacement'

def run_before(self):
# The node "2774" used in this transformation has op SoftMaxONNX, so this transformation must run before
# SoftmaxONNXFrontReplacer, which replaces operations of op SoftMaxONNX
return [SoftmaxONNXFrontReplacer]

def transform_graph(self, graph: Graph, replacement_descriptions: dict):
insert_ExperimentalDetectronROIFeatureExtractor2(graph)
insert_do(graph, replacement_descriptions)
Expand Down Expand Up @@ -80,6 +87,9 @@ def insert_do(graph: Graph, replacement_descriptions):
old_do_output_nodes = [Node(graph, node_id) for node_id in do_outputs]
for old_node, new_port in zip(old_do_output_nodes, do_output_ports):
old_node.out_port(0).get_connection().set_source(new_port)
# The consumer of the second output port of ExperimentalDetectronDetectionOutput is a Mul node whose second
# input is of type int64, so a Cast must be inserted to make the data types match
do_node.out_port(1).get_connection().insert_node(Cast(graph, {'dst_type': np.int64}).create_node())


def insert_ExperimentalDetectronROIFeatureExtractor1(graph: Graph):
Expand Down
106 changes: 70 additions & 36 deletions ngraph/src/ngraph/op/non_max_suppression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ shared_ptr<Node>
{
check_new_args_count(this, new_args);
NODE_VALIDATION_CHECK(
this, new_args.size() >= 3 && new_args.size() <= 5, "Number of inputs must be 3, 4 or 5");
this, new_args.size() >= 2 && new_args.size() <= 5, "Number of inputs must be 2, 3, 4 or 5");
if (new_args.size() == 5)
{
return make_shared<op::v1::NonMaxSuppression>(new_args.at(0),
Expand All @@ -83,7 +83,7 @@ shared_ptr<Node>
m_box_encoding,
m_sort_result_descending);
}
else
else if (new_args.size() == 3)
{
return make_shared<op::v1::NonMaxSuppression>(
new_args.at(0),
Expand All @@ -94,6 +94,17 @@ shared_ptr<Node>
m_box_encoding,
m_sort_result_descending);
}
else
{
return make_shared<op::v1::NonMaxSuppression>(
new_args.at(0),
new_args.at(1),
op::Constant::create(element::i32, Shape{}, {0}),
op::Constant::create(element::f32, Shape{}, {.0f}),
op::Constant::create(element::f32, Shape{}, {.0f}),
m_box_encoding,
m_sort_result_descending);
}
}

bool ngraph::op::v1::NonMaxSuppression::visit_attributes(AttributeVisitor& visitor)
Expand Down Expand Up @@ -133,24 +144,30 @@ void op::v1::NonMaxSuppression::validate_and_infer_types()
"Expected a 3D tensor for the 'scores' input. Got: ",
scores_ps);

const auto max_boxes_ps = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps);
if (get_inputs().size() >= 3) {
const auto max_boxes_ps = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps);
}

const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
if (get_inputs().size() >= 4) {
const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
}

const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
if (get_inputs().size() >= 5) {
const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
}

const auto num_batches_boxes = boxes_ps[0];
const auto num_batches_scores = scores_ps[0];
Expand Down Expand Up @@ -268,7 +285,7 @@ shared_ptr<Node>
{
check_new_args_count(this, new_args);
NODE_VALIDATION_CHECK(
this, new_args.size() >= 3 && new_args.size() <= 5, "Number of inputs must be 3, 4 or 5");
this, new_args.size() >= 2 && new_args.size() <= 5, "Number of inputs must be 2, 3, 4 or 5");
if (new_args.size() == 5)
{
return make_shared<op::v3::NonMaxSuppression>(new_args.at(0),
Expand All @@ -292,7 +309,7 @@ shared_ptr<Node>
m_sort_result_descending,
m_output_type);
}
else
else if (new_args.size() == 3)
{
return make_shared<op::v3::NonMaxSuppression>(
new_args.at(0),
Expand All @@ -301,6 +318,17 @@ shared_ptr<Node>
op::Constant::create(element::f32, Shape{}, {.0f}),
op::Constant::create(element::f32, Shape{}, {.0f}),
m_box_encoding,
m_sort_result_descending);
}
else
{
return make_shared<op::v3::NonMaxSuppression>(
new_args.at(0),
new_args.at(1),
op::Constant::create(element::i32, Shape{}, {0}),
op::Constant::create(element::f32, Shape{}, {.0f}),
op::Constant::create(element::f32, Shape{}, {.0f}),
m_box_encoding,
m_sort_result_descending,
m_output_type);
}
Expand Down Expand Up @@ -343,24 +371,30 @@ void op::v3::NonMaxSuppression::validate_and_infer_types()
"Expected a 3D tensor for the 'scores' input. Got: ",
scores_ps);

const auto max_boxes_ps = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps);
if (get_inputs().size() >= 3) {
const auto max_boxes_ps = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps);
}

const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
if (get_inputs().size() >= 4) {
const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
}

const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
if (get_inputs().size() >= 5) {
const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
}

const auto num_batches_boxes = boxes_ps[0];
const auto num_batches_scores = scores_ps[0];
Expand Down

0 comments on commit 0efe474

Please sign in to comment.