diff --git a/model-optimizer/extensions/back/ClampNormalizer.py b/model-optimizer/extensions/back/ClampNormalizer.py index 52fe3c73535c7b..a702b9544750f9 100644 --- a/model-optimizer/extensions/back/ClampNormalizer.py +++ b/model-optimizer/extensions/back/ClampNormalizer.py @@ -63,7 +63,7 @@ def replace_pattern(self, graph: Graph, match: dict): clamp.out_port(0).get_connection().set_source(min_node.out_port(0)) clamp.in_port(2).get_connection().set_destination(min_node.in_port(1)) assert min_node is not None or max_node is not None, 'Clamp node should have either min or max input used' - rename_node(max_node if min_node is None else min_node, name) + rename_node(min_node if min_node is not None else max_node, name) else: a_clamp = AttributedClamp(graph, {'name': name, 'min': min_value, 'max': max_value}).create_node() rename_node(a_clamp, name) diff --git a/model-optimizer/extensions/back/ClampNormalizer_test.py b/model-optimizer/extensions/back/ClampNormalizer_test.py index eb5b07e7a53cf0..e1b8c2ec9b03e6 100644 --- a/model-optimizer/extensions/back/ClampNormalizer_test.py +++ b/model-optimizer/extensions/back/ClampNormalizer_test.py @@ -73,7 +73,7 @@ def test_all_dynamic_inputs(self): (flag, resp) = compare_graphs(graph, ref_graph, 'result') self.assertTrue(flag, resp) - def test_no_2nd_input(self): + def test_no_max_input(self): nodes = { **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}), **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}), @@ -95,7 +95,7 @@ def test_no_2nd_input(self): (flag, resp) = compare_graphs(graph, ref_graph, 'result') self.assertTrue(flag, resp) - def test_no_1st_input(self): + def test_no_min_input(self): nodes = { **regular_op_with_shaped_data('placeholder', [1, 3, 20, 20], {'type': 'Parameter'}), **regular_op_with_shaped_data('a_clamp', [1, 3, 20, 20], {'type': None, 'op': 'Clamp'}), diff --git a/model-optimizer/mo/ops/clamp.py b/model-optimizer/mo/ops/clamp.py index c22f59f5bffea8..b2942bbaba457f 100644 --- a/model-optimizer/mo/ops/clamp.py +++ b/model-optimizer/mo/ops/clamp.py @@ -71,14 +71,13 @@ def __init__(self, graph: Graph, attrs: dict): @staticmethod def infer(node): name = node.soft_get('name', node.id) - connected_in_ports = [port.idx for port in node.in_ports().values() if not port.disconnected()] - - assert len(connected_in_ports) == 3 and sorted(connected_in_ports) == [0, 1, 2], \ - 'Clamp should have exactly three inputs, but it has {}'.format(len(connected_in_ports)) + min_input_connected = node.has_port('in', 1) and not node.in_port(1).disconnected() + max_input_connected = node.has_port('in', 2) and not node.in_port(2).disconnected() input_value = node.in_port(0).data.get_value() - min_value = node.in_port(1).data.get_value() - max_value = node.in_port(2).data.get_value() + min_value = node.in_port(1).data.get_value() if min_input_connected else np.finfo(np.float32).min + max_value = node.in_port(2).data.get_value() if max_input_connected else np.finfo(np.float32).max + if input_value is not None and min_value is not None and max_value is not None: assert np.all(max_value >= min_value), \ 'Clamp max_value=={} is less than min_value=={} for node `{}`'.format(max_value, min_value, name)