Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for Mask-RCNN conversion #654

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion model-optimizer/extensions/front/onnx/mask_rcnn_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import numpy as np

from extensions.front.onnx.softmaxONNX_to_softmax import SoftmaxONNXFrontReplacer
from extensions.ops.Cast import Cast
from extensions.ops.detectionoutput_onnx import ExperimentalDetectronDetectionOutput
from extensions.ops.parameter import Parameter
from extensions.ops.roifeatureextractor_onnx import ExperimentalDetectronROIFeatureExtractor
Expand All @@ -29,7 +31,7 @@
input_fpn_heads = ('486', '454', '422', '390')


class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral):
class ONNXMaskRCNNTransformation(FrontReplacementFromConfigFileGeneral):
"""
This transformation performs 3 actions:
1. Replaces a sub-graph calculating ROIAlign over FPN heads with a single ExperimentalDetectronROIFeatureExtractor
Expand All @@ -42,6 +44,11 @@ class ObjectDetectionAPIOutputReplacement(FrontReplacementFromConfigFileGeneral)
"""
replacement_id = 'ONNXMaskRCNNReplacement'

def run_before(self):
# the node "2774" which is used in this transformation is of op SoftMaxONNX. But operations of op SoftMaxONNX
# will be replaced with a transformation SoftmaxONNXFrontReplacer
return [SoftmaxONNXFrontReplacer]

def transform_graph(self, graph: Graph, replacement_descriptions: dict):
insert_ExperimentalDetectronROIFeatureExtractor2(graph)
insert_do(graph, replacement_descriptions)
Expand Down Expand Up @@ -80,6 +87,9 @@ def insert_do(graph: Graph, replacement_descriptions):
old_do_output_nodes = [Node(graph, node_id) for node_id in do_outputs]
for old_node, new_port in zip(old_do_output_nodes, do_output_ports):
old_node.out_port(0).get_connection().set_source(new_port)
# the consumer of the second output port of the ExperimentalDetectronDetectionOutput is the Mul node which second
# input is of type int64 so it is necessary to insert Cast to have data types match
do_node.out_port(1).get_connection().insert_node(Cast(graph, {'dst_type': np.int64}).create_node())


def insert_ExperimentalDetectronROIFeatureExtractor1(graph: Graph):
Expand Down
52 changes: 30 additions & 22 deletions ngraph/src/ngraph/op/non_max_suppression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,22 @@ void op::v1::NonMaxSuppression::validate_and_infer_types()
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps);

const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
if (get_inputs().size() >= 4) {
const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
}

const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
if (get_inputs().size() >= 5) {
const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
}

const auto num_batches_boxes = boxes_ps[0];
const auto num_batches_scores = scores_ps[0];
Expand Down Expand Up @@ -349,18 +353,22 @@ void op::v3::NonMaxSuppression::validate_and_infer_types()
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps);

const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
if (get_inputs().size() >= 4) {
GlebKazantaev marked this conversation as resolved.
Show resolved Hide resolved
const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
}

const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
if (get_inputs().size() >= 5) {
const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
}

const auto num_batches_boxes = boxes_ps[0];
const auto num_batches_scores = scores_ps[0];
Expand Down