diff --git a/model-optimizer/extensions/middle/FusedBatchNormTraining.py b/model-optimizer/extensions/middle/FusedBatchNormTraining.py index e693c2d0b451e9..56761ccd052e3d 100644 --- a/model-optimizer/extensions/middle/FusedBatchNormTraining.py +++ b/model-optimizer/extensions/middle/FusedBatchNormTraining.py @@ -58,9 +58,9 @@ def replace_pattern(self, graph: Graph, match: dict): input_rank = len(node.in_port(0).data.get_shape()) rng = create_op_with_const_inputs(graph, Range, - {0: int64_array(2), 1: int64_array(input_rank), 2: int64_array(1)}, + {0: int64_array(1), 1: int64_array(input_rank - 1), 2: int64_array(1)}, {'name': node_name + '/Range', 'output_type': np.int64}) - mvn = MVN(graph, {'name': node_name + '/mvn_', 'eps': node.soft_get('eps', 1e-6), 'eps_mode': 'outside_sqrt', + mvn = MVN(graph, {'name': node_name + '/mvn_', 'eps': node.soft_get('eps', 1e-6), 'eps_mode': 'inside_sqrt', 'normalize_variance': 1, 'override_output_shape': True}).create_node() node.in_port(0).get_connection().insert_node(mvn) mvn.in_port(1).connect(rng.out_port(0)) diff --git a/model-optimizer/unit_tests/extensions/middle/FusedBatchNormTraining_test.py b/model-optimizer/unit_tests/extensions/middle/FusedBatchNormTraining_test.py index d7a4c4a3e398ec..76475b2eb746b5 100644 --- a/model-optimizer/unit_tests/extensions/middle/FusedBatchNormTraining_test.py +++ b/model-optimizer/unit_tests/extensions/middle/FusedBatchNormTraining_test.py @@ -52,16 +52,16 @@ 'reshape_to_orig': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'}, 'reshape_to_orig_data': {'value': None, 'shape': None, 'kind': 'data'}, - 'start': {'kind': 'op', 'op': 'Const'}, - 'start_data': {'value': None, 'shape': None, 'kind': 'data'}, - 'stop': {'kind': 'op', 'op': 'Const'}, - 'stop_data': {'value': None, 'shape': None, 'kind': 'data'}, - 'step': {'kind': 'op', 'op': 'Const'}, - 'step_data': {'value': None, 'shape': None, 'kind': 'data'}, + 'start': {'kind': 'op', 'op': 'Const', 'value': int64_array(1)}, + 'start_data': {'value': None, 'shape': None, 'kind': 'data', 'value': int64_array(1)}, + 'stop': {'kind': 'op', 'op': 'Const', 'value': int64_array(3)}, + 'stop_data': {'value': None, 'shape': None, 'kind': 'data', 'value': int64_array(3)}, + 'step': {'kind': 'op', 'op': 'Const', 'value': int64_array(1)}, + 'step_data': {'value': None, 'shape': None, 'kind': 'data', 'value': int64_array(1)}, 'mvn_axes': {'kind': 'op', 'op': 'Range'}, 'mvn_axes_data': {'value': None, 'shape': None, 'kind': 'data'}, - 'mvn': {'type': 'MVN', 'value': None, 'kind': 'op', 'op': 'MVN', 'eps': 1e-3}, + 'mvn': {'type': 'MVN', 'value': None, 'kind': 'op', 'op': 'MVN', 'eps': 1e-3, 'eps_mode': 'inside_sqrt'}, 'mvn_data': {'value': None, 'shape': None, 'kind': 'data'}, 'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},