diff --git a/model-optimizer/extensions/middle/EltwiseInputReshape.py b/model-optimizer/extensions/middle/EltwiseInputReshape.py index bd834f46ea2f32..7ddbfe1b6dad73 100644 --- a/model-optimizer/extensions/middle/EltwiseInputReshape.py +++ b/model-optimizer/extensions/middle/EltwiseInputReshape.py @@ -69,7 +69,6 @@ def find_and_replace_pattern(self, graph: Graph): class EltwiseInputReshape(MiddleReplacementPattern): # This pass should be called directly from pipeline before layout change and other permutations enabled = False - force_shape_inference = True def find_and_replace_pattern(self, graph: Graph): # Generate a map for producers of eltwise nodes with non-normalized shapes @@ -97,11 +96,21 @@ def find_and_replace_pattern(self, graph: Graph): for unsqueeze_dims in mapping[producer_port].keys(): unsqueeze_name = producer_node.soft_get('name', producer_node.id) + '/EltwiseReshape' unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(list(unsqueeze_dims))}, - {'name': unsqueeze_name, - 'override_output_shape': True}) + {'name': unsqueeze_name}) unsqueeze_node.in_port(0).connect(producer_port) # Insert Reshape with determined output shape between the current producer and eltwise node for consumer_port in mapping[producer_port][unsqueeze_dims]: consumer_port.connect(unsqueeze_node.out_port(0)) + + # The shape and value adjustments must be explicitly done within the transformation + # since the transformation is called from Fusing transformation that excludes + # automatic call of shape inference pass + producer_port_value = producer_port.data.get_value() + producer_port_shape = producer_port.data.get_shape() + new_shape = np.insert(producer_port_shape, np.zeros_like(unsqueeze_dims), 1) + if producer_port_value is not None: + unsqueeze_node.out_port(0).data.set_value(np.reshape(producer_port_value, new_shape)) + else: + unsqueeze_node.out_port(0).data.set_shape(new_shape) diff --git a/model-optimizer/extensions/middle/EltwiseInputReshape_test.py b/model-optimizer/extensions/middle/EltwiseInputReshape_test.py index 1da0c56bdbfabb..8b8a0451b52274 100644 --- a/model-optimizer/extensions/middle/EltwiseInputReshape_test.py +++ b/model-optimizer/extensions/middle/EltwiseInputReshape_test.py @@ -20,7 +20,6 @@ from extensions.middle.EltwiseInputReshape import EltwiseInputReshape from mo.front.common.partial_infer.utils import int64_array -from mo.middle.passes.eliminate import shape_inference from mo.middle.passes.eliminate_test import build_graph from mo.utils.ir_engine.compare_graphs import compare_graphs @@ -86,7 +85,6 @@ def test1_not_constant(self): 'placeholder_3_data': {'shape': np.array([64, 1])}, 'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])} }, nodes_with_edges_only=True) - shape_inference(graph) graph_ref = build_graph(nodes_attributes, [ @@ -120,7 +118,6 @@ def test1_not_constant(self): pattern = EltwiseInputReshape() pattern.find_and_replace_pattern(graph) - shape_inference(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True) self.assertTrue(flag, resp) @@ -216,7 +213,6 @@ def test_mega_hardcore(self): pattern = EltwiseInputReshape() pattern.find_and_replace_pattern(graph) - shape_inference(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_4', check_op_attrs=True) self.assertTrue(flag, resp) @@ -278,7 +274,6 @@ def test2_not_constant(self): pattern = EltwiseInputReshape() pattern.find_and_replace_pattern(graph) - shape_inference(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True) self.assertTrue(flag, resp) @@ -334,7 +329,6 @@ def test3_not_constant(self): pattern = EltwiseInputReshape() pattern.find_and_replace_pattern(graph) - shape_inference(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True) self.assertTrue(flag, resp) @@ -397,7 +391,6 @@ def test4_constant(self): pattern = EltwiseInputReshape() pattern.find_and_replace_pattern(graph) - shape_inference(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True) self.assertTrue(flag, resp) @@ -451,7 +444,6 @@ def test5_constant(self): pattern = EltwiseInputReshape() pattern.find_and_replace_pattern(graph) - shape_inference(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True) self.assertTrue(flag, resp) @@ -495,7 +487,6 @@ def test6_not_constant(self): pattern = EltwiseInputReshape() pattern.find_and_replace_pattern(graph) - shape_inference(graph) (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True) self.assertTrue(flag, resp)