Skip to content

Commit

Permalink
[MO] Fix ONNX Clamp-11 shape infer with no min/max inputs (#2603)
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir authored Oct 12, 2020
1 parent ef2aa3a commit 9a9b231
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion model-optimizer/extensions/back/ClampNormalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def replace_pattern(self, graph: Graph, match: dict):
clamp.out_port(0).get_connection().set_source(min_node.out_port(0))
clamp.in_port(2).get_connection().set_destination(min_node.in_port(1))
assert min_node is not None or max_node is not None, 'Clamp node should have either min or max input used'
rename_node(max_node if min_node is None else min_node, name)
rename_node(min_node if min_node is not None else max_node, name)
else:
a_clamp = AttributedClamp(graph, {'name': name, 'min': min_value, 'max': max_value}).create_node()
rename_node(a_clamp, name)
Expand Down
4 changes: 2 additions & 2 deletions model-optimizer/extensions/back/ClampNormalizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_all_dynamic_inputs(self):
(flag, resp) = compare_graphs(graph, ref_graph, 'result')
self.assertTrue(flag, resp)

def test_no_2nd_input(self):
def test_no_max_input(self):
nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
**regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
Expand All @@ -95,7 +95,7 @@ def test_no_2nd_input(self):
(flag, resp) = compare_graphs(graph, ref_graph, 'result')
self.assertTrue(flag, resp)

def test_no_1st_input(self):
def test_no_min_input(self):
nodes = {
**regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}),
**regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}),
Expand Down
11 changes: 5 additions & 6 deletions model-optimizer/mo/ops/clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,13 @@ def __init__(self, graph: Graph, attrs: dict):
@staticmethod
def infer(node):
name = node.soft_get('name', node.id)
connected_in_ports = [port.idx for port in node.in_ports().values() if not port.disconnected()]

assert len(connected_in_ports) == 3 and sorted(connected_in_ports) == [0, 1, 2], \
'Clamp should have exactly three inputs, but it has {}'.format(len(connected_in_ports))
min_input_connected = node.has_port('in', 1) and not node.in_port(1).disconnected()
max_input_connected = node.has_port('in', 2) and not node.in_port(2).disconnected()

input_value = node.in_port(0).data.get_value()
min_value = node.in_port(1).data.get_value()
max_value = node.in_port(2).data.get_value()
min_value = node.in_port(1).data.get_value() if min_input_connected else np.finfo(np.float32).min
max_value = node.in_port(2).data.get_value() if max_input_connected else np.finfo(np.float32).max

if input_value is not None and min_value is not None and max_value is not None:
assert np.all(max_value >= min_value), \
'Clamp max_value=={} is less than min_value=={} for node `{}`'.format(max_value, min_value, name)
Expand Down

0 comments on commit 9a9b231

Please sign in to comment.