From dd1ea51e686a8363c015e92f62daf1dab6b3c74a Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 11:17:21 +0300 Subject: [PATCH 01/14] Added HSwish operation --- model-optimizer/extensions/ops/activation_ops.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/model-optimizer/extensions/ops/activation_ops.py b/model-optimizer/extensions/ops/activation_ops.py index c6c3f3e186daf9..162ebf70e702e3 100644 --- a/model-optimizer/extensions/ops/activation_ops.py +++ b/model-optimizer/extensions/ops/activation_ops.py @@ -115,8 +115,6 @@ class Atanh(Activation): class ReLU6(AttributedClamp): - op = 'ReLU6' - def __init__(self, graph: Graph, attrs: dict): relu6_attrs = {'min': 0, 'max': 6} relu6_attrs.update(attrs) @@ -244,6 +242,12 @@ class Mish(Activation): operation = staticmethod(lambda x: x * np.tanh(np.ln(np.exp(x) + 1.0))) +class HSwish(Activation): + op = 'HSwish' + version = 'opset4' + operation = staticmethod(lambda x: x * np.minimum(np.maximum(x + 3.0, 0.0), 6.0) / 6.0) + + class Swish(Op): op = 'Swish' From 5140dfadbb3a32030ba6f96bac80df3de1abd8a5 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 11:17:41 +0300 Subject: [PATCH 02/14] Added HSwish fusing transformation --- .../extensions/front/HSwish_fusion.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 model-optimizer/extensions/front/HSwish_fusion.py diff --git a/model-optimizer/extensions/front/HSwish_fusion.py b/model-optimizer/extensions/front/HSwish_fusion.py new file mode 100644 index 00000000000000..e99eedb948ebc4 --- /dev/null +++ b/model-optimizer/extensions/front/HSwish_fusion.py @@ -0,0 +1,71 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import numpy as np + +from extensions.front.AttributedClampNormalizer import AttributedClampNormalizer +from extensions.ops.activation_ops import HSwish +from mo.front.common.replacement import FrontReplacementSubgraph +from mo.front.subgraph_matcher import SubgraphMatch +from mo.graph.graph import Graph, rename_nodes + + +class HSwishWithClamp(FrontReplacementSubgraph): + """ + The transformation looks for the pattern with ReLU6 (Clamp) defining the HSwish function. + """ + enabled = True + + def run_after(self): + return [AttributedClampNormalizer] + + def pattern(self): + return dict( + nodes=[ + ('input', dict()), + ('add', dict(op='Add')), + ('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [0.0], atol=1e-6))), + ('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [3.0], atol=1e-6))), + ('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [6.0], atol=1e-6))), + ('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [1 / 6.0], atol=1e-6))), + ('clamp', dict(op='Clamp')), + ('mul', dict(op='Mul')), + ('mul_2', dict(op='Mul')), + ], + edges=[ + ('input', 'add', {'out': 0}), + ('input', 'mul', {'out': 0}), + ('const_3', 'add', {}), + ('add', 'clamp', {'in': 0}), + ('clamp', 'mul', {}), + ('const_0', 'clamp', {'in': 1}), + ('const_6', 'clamp', {'in': 2}), + ('mul', 'mul_2', {}), + ('const_1_6', 'mul_2', {}), + ]) + + def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): + add = match['add'] + mul_2 = match['mul_2'] + + # determine the input port of Add which gets the 'input' node output + input_port = int(add.in_port(0).get_connection().get_source().node.soft_get('op') == 'Const') + mul_2_name = mul_2.soft_get('name', mul_2.id) + + hswish = HSwish(graph, {}).create_node() + hswish.in_port(0).connect(add.in_port(input_port).get_source()) + mul_2.out_port(0).get_connection().set_source(hswish.out_port(0)) + + rename_nodes([(mul_2, mul_2_name + '/TBR'), (hswish, mul_2_name)]) From 8d5554da35709856a18e4ba7fabd3d35fa47c2f3 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 11:40:33 +0300 Subject: [PATCH 03/14] Fixed BOM --- model-optimizer/automation/package_BOM.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index bd74604d7db1e8..0c7c57999eba04 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -127,6 +127,7 @@ extensions/front/freeze_placeholder_value.py extensions/front/GeLUMerger_Erf.py extensions/front/GeLUMerger_Tanh.py extensions/front/global_pooling_to_reduce.py +extensions/front/HSwish_fusing.py extensions/front/image_scaler.py extensions/front/input_cut.py extensions/front/instance_normalization.py From f64586ac50f527f67fc06370831a1e2a05f54660 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 14:54:17 +0300 Subject: [PATCH 04/14] Added unit test for HSwish fusing transformation --- .../extensions/front/HSwish_fusing_test.py | 126 ++++++++++++++++++ .../extensions/front/HSwish_fusion.py | 68 ++++++++-- model-optimizer/mo/utils/unittest/graph.py | 12 ++ 3 files changed, 194 insertions(+), 12 deletions(-) create mode 100644 model-optimizer/extensions/front/HSwish_fusing_test.py diff --git a/model-optimizer/extensions/front/HSwish_fusing_test.py b/model-optimizer/extensions/front/HSwish_fusing_test.py new file mode 100644 index 00000000000000..99a31bf8d2f04c --- /dev/null +++ b/model-optimizer/extensions/front/HSwish_fusing_test.py @@ -0,0 +1,126 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest + +from extensions.front.HSwish_fusion import HSwishWithClamp, HSwishWithMinMax +from mo.front.common.partial_infer.utils import int64_array, float_array +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs + +ref_nodes = {**regular_op('input', {'type': 'Parameter'}), + **regular_op('hswish', {'type': 'HSwish', 'name': 'final_mul'}), + **result('result') + } +ref_edges = [('input', 'hswish'), ('hswish', 'result')] + + +class HSwishWithClampTest(unittest.TestCase): + nodes = { + **regular_op('input', {'type': 'Parameter'}), + **regular_op('add', {'op': 'Add'}), + **regular_op('relu6', {'op': 'Clamp'}), + **regular_op('mul', {'op': 'Mul'}), + **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}), + **const('const_0', float_array([0.0])), + **const('const_3', float_array([3.0])), + **const('const_6', float_array([6.0])), + **const('const_1_6', float_array([1.0 / 6.0])), + **result('result'), + } + + edges = [('input', 'mul', {'in': 0, 'out': 0}), + ('input', 'add', {'in': 0, 'out': 0}), + ('const_3', 'add', {'in': 1, 'out': 0}), + ('add', 'relu6', {'in': 0, 'out': 0}), + ('const_0', 'relu6', {'in': 1, 'out': 0}), + ('const_6', 'relu6', {'in': 2, 'out': 0}), + ('relu6', 'mul', {'in': 1, 'out': 0}), + ('mul', 'mul_2', {'in': 0, 'out': 0}), + ('const_1_6', 'mul_2', {'in': 1, 'out': 0}), + ('mul_2', 'result', {'in': 0, 'out': 0})] + + def test_hswish_with_clamp(self): + graph = build_graph_with_edge_attrs(self.nodes, self.edges, {}) + + graph_ref = build_graph(ref_nodes, ref_edges) + graph.stage = 'front' + + HSwishWithClamp().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + def test_hswish_with_clamp_wrong_constant(self): + graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}}) + + graph_ref = graph.copy() + graph.stage = 'front' + + HSwishWithClamp().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + +class HSwishWithMinMaxTest(unittest.TestCase): + nodes = { + **regular_op('input', {'type': 'Parameter'}), + **regular_op('add', {'op': 'Add'}), + **regular_op('max', {'op': 'Maximum'}), + **regular_op('min', {'op': 'Minimum'}), + **regular_op('mul', {'op': 'Mul'}), + **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}), + **const('const_0', float_array([0.0])), + **const('const_3', float_array([3.0])), + **const('const_6', float_array([6.0])), + **const('const_1_6', float_array([1.0 / 6.0])), + **result('result'), + } + + edges = [('input', 'mul', {'in': 1, 'out': 0}), + ('input', 'add', {'in': 0, 'out': 0}), + ('const_3', 'add', {'in': 1, 'out': 0}), + ('add', 'max', {'in': 0, 'out': 0}), + ('const_0', 'max', {'in': 1, 'out': 0}), + ('max', 'min', {'in': 0, 'out': 0}), + ('const_6', 'min', {'in': 1, 'out': 0}), + ('min', 'mul', {'in': 0, 'out': 0}), + ('mul', 'mul_2', {'in': 0, 'out': 0}), + ('const_1_6', 'mul_2', {'in': 1, 'out': 0}), + ('mul_2', 'result', {'in': 0, 'out': 0})] + + def test_hswish_with_min_max(self): + graph = build_graph_with_edge_attrs(self.nodes, self.edges, {}) + + graph_ref = build_graph(ref_nodes, ref_edges) + graph.stage = 'front' + + HSwishWithMinMax().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + def test_hswish_with_min_max_wrong_constant(self): + graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}}) + + graph_ref = graph.copy() + graph.stage = 'front' + + HSwishWithMinMax().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/front/HSwish_fusion.py b/model-optimizer/extensions/front/HSwish_fusion.py index e99eedb948ebc4..862b9aa85c55b9 100644 --- a/model-optimizer/extensions/front/HSwish_fusion.py +++ b/model-optimizer/extensions/front/HSwish_fusion.py @@ -22,6 +22,21 @@ from mo.graph.graph import Graph, rename_nodes +def replace_with_hswish(graph: Graph, match: [dict, SubgraphMatch]): + add = match['add'] + mul_2 = match['mul_2'] + + # determine the input port of Add which gets the 'input' node output + input_port = int(add.in_port(0).get_connection().get_source().node.soft_get('op') == 'Const') + mul_2_name = mul_2.soft_get('name', mul_2.id) + + hswish = HSwish(graph, {}).create_node() + hswish.in_port(0).connect(add.in_port(input_port).get_source()) + mul_2.out_port(0).get_connection().set_source(hswish.out_port(0)) + + rename_nodes([(mul_2, mul_2_name + '/TBR'), (hswish, mul_2_name)]) + + class HSwishWithClamp(FrontReplacementSubgraph): """ The transformation looks for the pattern with ReLU6 (Clamp) defining the HSwish function. @@ -45,27 +60,56 @@ def pattern(self): ('mul_2', dict(op='Mul')), ], edges=[ - ('input', 'add', {'out': 0}), - ('input', 'mul', {'out': 0}), + ('input', 'add', {}), + ('input', 'mul', {}), ('const_3', 'add', {}), ('add', 'clamp', {'in': 0}), - ('clamp', 'mul', {}), ('const_0', 'clamp', {'in': 1}), ('const_6', 'clamp', {'in': 2}), + ('clamp', 'mul', {}), ('mul', 'mul_2', {}), ('const_1_6', 'mul_2', {}), ]) def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): - add = match['add'] - mul_2 = match['mul_2'] + replace_with_hswish(graph, match) - # determine the input port of Add which gets the 'input' node output - input_port = int(add.in_port(0).get_connection().get_source().node.soft_get('op') == 'Const') - mul_2_name = mul_2.soft_get('name', mul_2.id) - hswish = HSwish(graph, {}).create_node() - hswish.in_port(0).connect(add.in_port(input_port).get_source()) - mul_2.out_port(0).get_connection().set_source(hswish.out_port(0)) +class HSwishWithMinMax(FrontReplacementSubgraph): + """ + The transformation looks for the pattern with Min/Max defining the HSwish function. + """ + enabled = True + + def run_after(self): + return [AttributedClampNormalizer] + + def pattern(self): + return dict( + nodes=[ + ('input', dict()), + ('add', dict(op='Add')), + ('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [0.0], atol=1e-6))), + ('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [3.0], atol=1e-6))), + ('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [6.0], atol=1e-6))), + ('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [1 / 6.0], atol=1e-6))), + ('max', dict(op='Maximum')), + ('min', dict(op='Minimum')), + ('mul', dict(op='Mul')), + ('mul_2', dict(op='Mul')), + ], + edges=[ + ('input', 'add', {'out': 0}), + ('input', 'mul', {'out': 0}), + ('const_3', 'add', {}), + ('add', 'max', {}), + ('const_0', 'max', {}), + ('max', 'min', {}), + ('const_6', 'min', {}), + ('min', 'mul', {}), + ('mul', 'mul_2', {}), + ('const_1_6', 'mul_2', {}), + ]) - rename_nodes([(mul_2, mul_2_name + '/TBR'), (hswish, mul_2_name)]) + def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): + replace_with_hswish(graph, match) diff --git a/model-optimizer/mo/utils/unittest/graph.py b/model-optimizer/mo/utils/unittest/graph.py index ae4edb599c91fd..b7af97f18ed644 100644 --- a/model-optimizer/mo/utils/unittest/graph.py +++ b/model-optimizer/mo/utils/unittest/graph.py @@ -236,6 +236,18 @@ def build_graph_with_edge_attrs(nodes_attrs: dict, edges: list, update_attribute assert (node_name in graph.nodes()) for attr, value in new_attrs.items(): graph.node[node_name][attr] = value + + for node in graph.get_op_nodes(): + # Add in_ports attribute + in_edges = node.in_edges() + for attr in in_edges.values(): + node.add_input_port(idx=attr['in']) + + # Add out_ports attribute + out_edges = node.out_edges() + for attr in out_edges.values(): + node.add_output_port(idx=attr['out']) + graph.graph['cmd_params'] = cli return graph From 0ee215732aedff129045f79f94a72569070c0c2f Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 14:54:34 +0300 Subject: [PATCH 05/14] Fixed unit tests for transformations using 'build_graph_with_edge_attrs' function to build the graph --- model-optimizer/extensions/front/caffe/axpy_test.py | 8 ++++---- model-optimizer/extensions/front/tf/fifo_replacer_test.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/model-optimizer/extensions/front/caffe/axpy_test.py b/model-optimizer/extensions/front/caffe/axpy_test.py index 9f26c9459448c0..fc462b60d6b78f 100644 --- a/model-optimizer/extensions/front/caffe/axpy_test.py +++ b/model-optimizer/extensions/front/caffe/axpy_test.py @@ -29,10 +29,10 @@ def test_axpy(self): 'axpy': {'type': 'Axpy', 'kind': 'op', 'op': 'Axpy'}, 'node_4': {'kind': 'op', 'type': 'Identity', 'op': 'Parameter'}} edges = [ - ('node_1', 'axpy', {'in': 0}), - ('node_2', 'axpy', {'in': 1}), - ('node_3', 'axpy', {'in': 2}), - ('axpy', 'node_4', {'in': 0})] + ('node_1', 'axpy', {'in': 0, 'out': 0}), + ('node_2', 'axpy', {'in': 1, 'out': 0}), + ('node_3', 'axpy', {'in': 2, 'out': 0}), + ('axpy', 'node_4', {'in': 0, 'out': 0})] graph = build_graph_with_edge_attrs(nodes, edges) node = Node(graph, 'axpy') replacer = AxpyToSSandAdd() diff --git a/model-optimizer/extensions/front/tf/fifo_replacer_test.py b/model-optimizer/extensions/front/tf/fifo_replacer_test.py index fbe6c3357ffbc0..e4a6099c320b24 100644 --- a/model-optimizer/extensions/front/tf/fifo_replacer_test.py +++ b/model-optimizer/extensions/front/tf/fifo_replacer_test.py @@ -61,9 +61,9 @@ def test_fifo_with_out_label_batch(self): 'image_batch': {'op': 'Identity', 'data_type': np.float32, 'kind': 'op'}, } edges_no_label = [ - ('placeholder', 'batch_join', {'out': 0}), - ('batch_join/fifo_queue', 'batch_join', {'out': 0}), - ('batch_join', 'image_batch', {'out': 0}) + ('placeholder', 'batch_join', {'out': 0, 'in': 0}), + ('batch_join/fifo_queue', 'batch_join', {'out': 0, 'in': 1}), + ('batch_join', 'image_batch', {'out': 0, 'in': 0}) ] graph = build_graph_with_edge_attrs(nodes_no_label, edges_no_label) From a6efda43908947c8db76b4220f5c1d6140a71096 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 16:01:23 +0300 Subject: [PATCH 06/14] Added fusion transformation for Swish operation --- model-optimizer/automation/package_BOM.txt | 3 +- .../extensions/front/HSwish_fusion.py | 16 ++-- .../extensions/front/Swish_fusion.py | 96 +++++++++++++++++++ .../extensions/front/Swish_fusion_test.py | 84 ++++++++++++++++ 4 files changed, 190 insertions(+), 9 deletions(-) create mode 100644 model-optimizer/extensions/front/Swish_fusion.py create mode 100644 model-optimizer/extensions/front/Swish_fusion_test.py diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 0c7c57999eba04..377691d9fc09a6 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -127,7 +127,7 @@ extensions/front/freeze_placeholder_value.py extensions/front/GeLUMerger_Erf.py extensions/front/GeLUMerger_Tanh.py extensions/front/global_pooling_to_reduce.py -extensions/front/HSwish_fusing.py +extensions/front/HSwish_fusion.py extensions/front/image_scaler.py extensions/front/input_cut.py extensions/front/instance_normalization.py @@ -332,6 +332,7 @@ extensions/front/split_normalizer.py extensions/front/SqueezeNormalize.py extensions/front/standalone_const_eraser.py extensions/front/sub.py +extensions/front/Swish_fusion.py extensions/front/tf/__init__.py extensions/front/tf/activation_ext.py extensions/front/tf/argmax_ext.py diff --git a/model-optimizer/extensions/front/HSwish_fusion.py b/model-optimizer/extensions/front/HSwish_fusion.py index 862b9aa85c55b9..c6c3205531ab66 100644 --- a/model-optimizer/extensions/front/HSwish_fusion.py +++ b/model-optimizer/extensions/front/HSwish_fusion.py @@ -51,10 +51,10 @@ def pattern(self): nodes=[ ('input', dict()), ('add', dict(op='Add')), - ('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [0.0], atol=1e-6))), - ('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [3.0], atol=1e-6))), - ('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [6.0], atol=1e-6))), - ('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [1 / 6.0], atol=1e-6))), + ('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 0.0, atol=1e-6))), + ('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 3.0, atol=1e-6))), + ('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 6.0, atol=1e-6))), + ('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1 / 6.0, atol=1e-6))), ('clamp', dict(op='Clamp')), ('mul', dict(op='Mul')), ('mul_2', dict(op='Mul')), @@ -89,10 +89,10 @@ def pattern(self): nodes=[ ('input', dict()), ('add', dict(op='Add')), - ('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [0.0], atol=1e-6))), - ('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [3.0], atol=1e-6))), - ('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [6.0], atol=1e-6))), - ('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, [1 / 6.0], atol=1e-6))), + ('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 0.0, atol=1e-6))), + ('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 3.0, atol=1e-6))), + ('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 6.0, atol=1e-6))), + ('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1 / 6.0, atol=1e-6))), ('max', dict(op='Maximum')), ('min', dict(op='Minimum')), ('mul', dict(op='Mul')), diff --git a/model-optimizer/extensions/front/Swish_fusion.py b/model-optimizer/extensions/front/Swish_fusion.py new file mode 100644 index 00000000000000..0dc6146311a1c4 --- /dev/null +++ b/model-optimizer/extensions/front/Swish_fusion.py @@ -0,0 +1,96 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import numpy as np + +from extensions.ops.activation_ops import Swish +from mo.front.common.replacement import FrontReplacementSubgraph +from mo.front.subgraph_matcher import SubgraphMatch +from mo.graph.graph import Graph, rename_nodes + + +class SwishWithSigmoidWithoutBeta(FrontReplacementSubgraph): + """ + The transformation looks for the pattern with Sigmoid defining the Swish function: Swish(x) = x * Sigmoid(x) + """ + enabled = True + + def pattern(self): + return dict( + nodes=[ + ('input', dict()), + ('sigmoid', dict(op='Sigmoid')), + ('mul', dict(op='Mul')), + ], + edges=[ + ('input', 'sigmoid', {}), + ('input', 'mul', {}), + ('sigmoid', 'mul', {}), + ]) + + def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): + sigmoid = match['sigmoid'] + mul = match['mul'] + + mul_name = mul.soft_get('name', mul.id) + + swish = Swish(graph, {}).create_node() + swish.in_port(0).connect(sigmoid.in_port(0).get_source()) + mul.out_port(0).get_connection().set_source(swish.out_port(0)) + + rename_nodes([(mul, mul_name + '/TBR'), (swish, mul_name)]) + + +class SwishWithSigmoidWithBeta(FrontReplacementSubgraph): + """ + The transformation looks for the pattern with Sigmoid defining the Swish function: Swish(x) = x * Sigmoid(x * beta) + """ + enabled = True + + def pattern(self): + return dict( + nodes=[ + ('input', dict()), + ('sigmoid', dict(op='Sigmoid')), + ('beta', dict()), + ('mul_beta', dict(op='Mul')), + ('mul', dict(op='Mul')), + ], + edges=[ + ('input', 'mul_beta', {}), + ('input', 'mul', {}), + ('beta', 'mul_beta', {}), + ('mul_beta', 'sigmoid', {}), + ('sigmoid', 'mul', {}), + ]) + + def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): + beta = match['beta'] + mul = match['mul'] + mul_beta = match['mul_beta'] + + # determine the input port of Mul which gets the 'input' node output + mul_beta_input_port = int(mul_beta.in_port(0).get_connection().get_source().node.id == beta.id) + mul_name = mul.soft_get('name', mul.id) + + swish = Swish(graph, {}).create_node() + swish.in_port(0).connect(mul_beta.in_port(mul_beta_input_port).get_source()) + + # connect Beta value + swish.in_port(1).connect(mul_beta.in_port(1 - mul_beta_input_port).get_source()) + + mul.out_port(0).get_connection().set_source(swish.out_port(0)) + + rename_nodes([(mul, mul_name + '/TBR'), (swish, mul_name)]) diff --git a/model-optimizer/extensions/front/Swish_fusion_test.py b/model-optimizer/extensions/front/Swish_fusion_test.py new file mode 100644 index 00000000000000..5f3e9d58e1bde2 --- /dev/null +++ b/model-optimizer/extensions/front/Swish_fusion_test.py @@ -0,0 +1,84 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest + +from extensions.front.Swish_fusion import SwishWithSigmoidWithoutBeta, SwishWithSigmoidWithBeta +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs + +ref_nodes = {**regular_op('input', {'type': 'Parameter'}), + **regular_op('swish', {'type': 'Swish', 'name': 'final_mul'}), + **result('result') + } +ref_edges = [('input', 'swish'), ('swish', 'result')] + + +class SwishWithSigmoidWithoutBetaTest(unittest.TestCase): + nodes = { + **regular_op('input', {'type': 'Parameter'}), + **regular_op('sigmoid', {'op': 'Sigmoid'}), + **regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}), + **result('result'), + } + + edges = [('input', 'mul', {'in': 0, 'out': 0}), + ('input', 'sigmoid', {'in': 0, 'out': 0}), + ('sigmoid', 'mul', {'in': 1, 'out': 0}), + ('mul', 'result', {'in': 0, 'out': 0})] + + def test_swish_with_sigmoid_without_beta_test(self): + graph = build_graph_with_edge_attrs(self.nodes, self.edges, {}) + + graph_ref = build_graph(ref_nodes, ref_edges) + graph.stage = 'front' + + SwishWithSigmoidWithoutBeta().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + +class SwishWithSigmoidWithBetaTest(unittest.TestCase): + nodes = { + **regular_op('input', {'type': 'Parameter'}), + **regular_op('beta', {'type': 'Parameter'}), + **regular_op('mul_beta', {'op': 'Mul'}), + **regular_op('sigmoid', {'op': 'Sigmoid'}), + **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}), + **result('result'), + } + + edges = [('input', 'mul_beta', {'in': 0, 'out': 0}), + ('input', 'mul_2', {'in': 0, 'out': 0}), + ('beta', 'mul_beta', {'in': 1, 'out': 0}), + ('mul_beta', 'sigmoid', {'in': 0, 'out': 0}), + ('sigmoid', 'mul_2', {'in': 1, 'out': 0}), + ('mul_2', 'result', {'in': 0, 'out': 0})] + + def test_swish_with_sigmoid_with_beta_test(self): + graph = build_graph_with_edge_attrs(self.nodes, self.edges, {}) + + new_ref_nodes = ref_nodes.copy() + new_ref_nodes.update(**regular_op('beta', {'type': 'Parameter'})) + + graph_ref = build_graph(new_ref_nodes, ref_edges + [('beta', 'swish')]) + graph.stage = 'front' + + SwishWithSigmoidWithBeta().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) From 1c22db101afd8892505bf39a26802e5d9fbbf832 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 16:23:45 +0300 Subject: [PATCH 07/14] Added fusing transformation for Softplus operation --- model-optimizer/automation/package_BOM.txt | 1 + .../extensions/front/HSwish_fusing_test.py | 2 +- .../extensions/front/Softplus_fusion.py | 54 +++++++++++++++ .../extensions/front/Softplus_fusion_test.py | 68 +++++++++++++++++++ .../extensions/front/Swish_fusion.py | 1 - .../extensions/front/Swish_fusion_test.py | 2 +- 6 files changed, 125 insertions(+), 3 deletions(-) create mode 100644 model-optimizer/extensions/front/Softplus_fusion.py create mode 100644 model-optimizer/extensions/front/Softplus_fusion_test.py diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 377691d9fc09a6..2291f901611534 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -327,6 +327,7 @@ extensions/front/reshape_dim_normalizer.py extensions/front/restore_ports.py extensions/front/scatter_normalizer.py extensions/front/softmax.py +extensions/front/SoftPlus_fusion.py extensions/front/softsign_replacer.py extensions/front/split_normalizer.py extensions/front/SqueezeNormalize.py diff --git a/model-optimizer/extensions/front/HSwish_fusing_test.py b/model-optimizer/extensions/front/HSwish_fusing_test.py index 99a31bf8d2f04c..b26bee0106d1c3 100644 --- a/model-optimizer/extensions/front/HSwish_fusing_test.py +++ b/model-optimizer/extensions/front/HSwish_fusing_test.py @@ -17,7 +17,7 @@ import unittest from extensions.front.HSwish_fusion import HSwishWithClamp, HSwishWithMinMax -from mo.front.common.partial_infer.utils import int64_array, float_array +from mo.front.common.partial_infer.utils import float_array from mo.utils.ir_engine.compare_graphs import compare_graphs from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs diff --git a/model-optimizer/extensions/front/Softplus_fusion.py b/model-optimizer/extensions/front/Softplus_fusion.py new file mode 100644 index 00000000000000..1a70e5d27af29c --- /dev/null +++ b/model-optimizer/extensions/front/Softplus_fusion.py @@ -0,0 +1,54 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import numpy as np + +from extensions.ops.activation_ops import SoftPlus +from mo.front.common.replacement import FrontReplacementSubgraph +from mo.front.subgraph_matcher import SubgraphMatch +from mo.graph.graph import Graph, rename_nodes + + +class SoftplusFusion(FrontReplacementSubgraph): + """ + The transformation looks for the pattern for the Softplus function: Softplus(x) = ln(1 + e^x) + """ + enabled = True + + def pattern(self): + return dict( + nodes=[ + ('exp', dict(op='Exp')), + ('add', dict(op='Add')), + ('const_1', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1.0, atol=1e-6))), + ('ln', dict(op='Log')), + ], + edges=[ + ('exp', 'add', {}), + ('const_1', 'add', {}), + ('add', 'ln', {}), + ]) + + def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): + ln = match['ln'] + exp = match['exp'] + + ln_name = ln.soft_get('name', ln.id) + + softplus = SoftPlus(graph, {}).create_node() + softplus.in_port(0).connect(exp.in_port(0).get_source()) + ln.out_port(0).get_connection().set_source(softplus.out_port(0)) + + rename_nodes([(ln, ln_name + '/TBR'), (softplus, ln_name)]) diff --git a/model-optimizer/extensions/front/Softplus_fusion_test.py b/model-optimizer/extensions/front/Softplus_fusion_test.py new file mode 100644 index 00000000000000..6c498b37a4a0d4 --- /dev/null +++ b/model-optimizer/extensions/front/Softplus_fusion_test.py @@ -0,0 +1,68 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest + +from extensions.front.Softplus_fusion import SoftplusFusion +from mo.front.common.partial_infer.utils import float_array +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs + +ref_nodes = {**regular_op('input', {'type': 'Parameter'}), + **regular_op('softplus', {'type': 'SoftPlus', 'name': 'final_mul'}), + **result('result') + } +ref_edges = [('input', 'softplus'), ('softplus', 'result')] + + +class SoftplusFusionTest(unittest.TestCase): + nodes = { + **regular_op('input', {'type': 'Parameter'}), + **regular_op('exp', {'op': 'Exp'}), + **const('const_1', float_array([1.0])), + **regular_op('add', {'op': 'Add'}), + **regular_op('ln', {'op': 'Log', 'name': 'final_log'}), + **result('result'), + } + + edges = [('input', 'exp', {'in': 0, 'out': 0}), + ('const_1', 'add', {'in': 0, 'out': 0}), + ('exp', 'add', {'in': 1, 'out': 0}), + ('add', 'ln', {'in': 0, 'out': 0}), + ('ln', 'result', {'in': 0, 'out': 0})] + + def test_softplus_fusion_test(self): + graph = build_graph_with_edge_attrs(self.nodes, self.edges, {}) + + graph_ref = build_graph(ref_nodes, ref_edges) + graph.stage = 'front' + + SoftplusFusion().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + def test_softplus_fusion_test_wrong_const(self): + graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_1': {'value': float_array([0.9999])}}) + + graph_ref = graph.copy() + graph.stage = 'front' + + SoftplusFusion().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + diff --git a/model-optimizer/extensions/front/Swish_fusion.py b/model-optimizer/extensions/front/Swish_fusion.py index 0dc6146311a1c4..fafd21de8a3dcf 100644 --- a/model-optimizer/extensions/front/Swish_fusion.py +++ b/model-optimizer/extensions/front/Swish_fusion.py @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. """ -import numpy as np from extensions.ops.activation_ops import Swish from mo.front.common.replacement import FrontReplacementSubgraph diff --git a/model-optimizer/extensions/front/Swish_fusion_test.py b/model-optimizer/extensions/front/Swish_fusion_test.py index 5f3e9d58e1bde2..fa716c5dd2073b 100644 --- a/model-optimizer/extensions/front/Swish_fusion_test.py +++ b/model-optimizer/extensions/front/Swish_fusion_test.py @@ -18,7 +18,7 @@ from extensions.front.Swish_fusion import SwishWithSigmoidWithoutBeta, SwishWithSigmoidWithBeta from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs +from mo.utils.unittest.graph import build_graph, regular_op, result, build_graph_with_edge_attrs ref_nodes = {**regular_op('input', {'type': 'Parameter'}), **regular_op('swish', {'type': 'Swish', 'name': 'final_mul'}), From 1fbd886f7b72cc69d1ab8278195c88b78950343a Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 16:47:36 +0300 Subject: [PATCH 08/14] Added fusion transformation for Mish operation --- model-optimizer/automation/package_BOM.txt | 3 +- .../extensions/front/Mish_fusion.py | 59 +++++++++++++++++++ .../extensions/front/Mish_fusion_test.py | 54 +++++++++++++++++ 3 files changed, 115 insertions(+), 1 deletion(-) create mode 100644 model-optimizer/extensions/front/Mish_fusion.py create mode 100644 model-optimizer/extensions/front/Mish_fusion_test.py diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 2291f901611534..4167d680d551f5 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -151,6 +151,7 @@ extensions/front/LayerNorm.py extensions/front/Log1p.py extensions/front/LogSoftmax.py extensions/front/MatMul_normalizer.py +extensions/front/Mish_fusion.py extensions/front/MoveEmbeddedInputsToInputs.py extensions/front/mxnet/__init__.py extensions/front/mxnet/activation.py @@ -327,7 +328,7 @@ extensions/front/reshape_dim_normalizer.py extensions/front/restore_ports.py extensions/front/scatter_normalizer.py extensions/front/softmax.py -extensions/front/SoftPlus_fusion.py +extensions/front/Softplus_fusion.py extensions/front/softsign_replacer.py extensions/front/split_normalizer.py extensions/front/SqueezeNormalize.py diff --git a/model-optimizer/extensions/front/Mish_fusion.py b/model-optimizer/extensions/front/Mish_fusion.py new file mode 100644 index 00000000000000..032bdb8e40899e --- /dev/null +++ b/model-optimizer/extensions/front/Mish_fusion.py @@ -0,0 +1,59 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from extensions.front.Softplus_fusion import SoftplusFusion +from extensions.ops.activation_ops import Mish +from mo.front.common.replacement import FrontReplacementSubgraph +from mo.front.subgraph_matcher import SubgraphMatch +from mo.graph.graph import Graph, rename_nodes + + +class MishFusion(FrontReplacementSubgraph): + """ + The transformation looks for the pattern with Softplus defining the Mish function: Mish(x) = x * tanh(Softplus(x)). + """ + enabled = True + + def run_after(self): + return [SoftplusFusion] + + def pattern(self): + return dict( + nodes=[ + ('input', dict()), + ('mul', dict(op='Mul')), + ('tanh', dict(op='Tanh')), + ('softplus', dict(op='Softplus')), + ], + edges=[ + ('input', 'mul', {}), + ('input', 'softplus', {}), + ('softplus', 'tanh', {}), + ('tanh', 'mul'), + ]) + + def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): + mul = match['mul'] + + # determine the input port of Mul which gets the 'input' node output + input_port = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Tanh') + mul_name = mul.soft_get('name', mul.id) + + mish = Mish(graph, {}).create_node() + mish.in_port(0).connect(mul.in_port(input_port).get_source()) + mul.out_port(0).get_connection().set_source(mish.out_port(0)) + + rename_nodes([(mul, mul_name + '/TBR'), (mish, mul_name)]) diff --git a/model-optimizer/extensions/front/Mish_fusion_test.py b/model-optimizer/extensions/front/Mish_fusion_test.py new file mode 100644 index 00000000000000..c07fb6608620b3 --- /dev/null +++ b/model-optimizer/extensions/front/Mish_fusion_test.py @@ -0,0 +1,54 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest + +from extensions.front.Mish_fusion import MishFusion +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph, regular_op, result, build_graph_with_edge_attrs + +ref_nodes = {**regular_op('input', {'type': 'Parameter'}), + **regular_op('mish', {'type': 'Mish', 'name': 'final_mul'}), + **result('result') + } +ref_edges = [('input', 'mish'), ('mish', 'result')] + + +class MishFusionTest(unittest.TestCase): + nodes = { + **regular_op('input', {'type': 'Parameter'}), + **regular_op('softplus', {'op': 'Softplus'}), + **regular_op('tanh', {'op': 'Tanh'}), + **regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}), + **result('result'), + } + + edges = [('input', 'softplus', {'in': 0, 'out': 0}), + ('input', 'mul', {'in': 0, 'out': 0}), + ('softplus', 'tanh', {'in': 0, 'out': 0}), + ('tanh', 'mul', {'in': 1, 'out': 0}), + ('mul', 'result', {'in': 0, 'out': 0})] + + def test_mish_fusion_test(self): + graph = build_graph_with_edge_attrs(self.nodes, self.edges, {}) + + graph_ref = build_graph(ref_nodes, ref_edges) + graph.stage = 'front' + + MishFusion().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) From b86f847b2a92d8179ea572c6f9721880dde80c79 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 19:36:59 +0300 Subject: [PATCH 09/14] Added check for the node name in the unit tests --- model-optimizer/extensions/front/HSwish_fusing_test.py | 2 ++ model-optimizer/extensions/front/Mish_fusion_test.py | 1 + model-optimizer/extensions/front/Softplus_fusion_test.py | 3 ++- model-optimizer/extensions/front/Swish_fusion_test.py | 2 ++ 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/model-optimizer/extensions/front/HSwish_fusing_test.py b/model-optimizer/extensions/front/HSwish_fusing_test.py index b26bee0106d1c3..fa9cfaea53f887 100644 --- a/model-optimizer/extensions/front/HSwish_fusing_test.py +++ b/model-optimizer/extensions/front/HSwish_fusing_test.py @@ -63,6 +63,7 @@ def test_hswish_with_clamp(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) + self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'HSwish') def test_hswish_with_clamp_wrong_constant(self): graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}}) @@ -113,6 +114,7 @@ def test_hswish_with_min_max(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) + self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'HSwish') def test_hswish_with_min_max_wrong_constant(self): graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}}) diff --git a/model-optimizer/extensions/front/Mish_fusion_test.py b/model-optimizer/extensions/front/Mish_fusion_test.py index c07fb6608620b3..0082f5ff2a7cde 100644 --- a/model-optimizer/extensions/front/Mish_fusion_test.py +++ b/model-optimizer/extensions/front/Mish_fusion_test.py @@ -52,3 +52,4 @@ def test_mish_fusion_test(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) + self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'Mish') diff --git a/model-optimizer/extensions/front/Softplus_fusion_test.py b/model-optimizer/extensions/front/Softplus_fusion_test.py index 6c498b37a4a0d4..3b0a81d86c68bd 100644 --- a/model-optimizer/extensions/front/Softplus_fusion_test.py +++ b/model-optimizer/extensions/front/Softplus_fusion_test.py @@ -22,7 +22,7 @@ from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs ref_nodes = {**regular_op('input', {'type': 'Parameter'}), - **regular_op('softplus', {'type': 'SoftPlus', 'name': 'final_mul'}), + **regular_op('softplus', {'type': 'SoftPlus', 'name': 'final_log'}), **result('result') } ref_edges = [('input', 'softplus'), ('softplus', 'result')] @@ -54,6 +54,7 @@ def test_softplus_fusion_test(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) + self.assertTrue(graph.get_op_nodes(name='final_log')[0].op == 'SoftPlus') def test_softplus_fusion_test_wrong_const(self): graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_1': {'value': float_array([0.9999])}}) diff --git a/model-optimizer/extensions/front/Swish_fusion_test.py b/model-optimizer/extensions/front/Swish_fusion_test.py index fa716c5dd2073b..98f4ce2cdf9fa5 100644 --- a/model-optimizer/extensions/front/Swish_fusion_test.py +++ b/model-optimizer/extensions/front/Swish_fusion_test.py @@ -50,6 +50,7 @@ def test_swish_with_sigmoid_without_beta_test(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) + self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'Swish') class SwishWithSigmoidWithBetaTest(unittest.TestCase): @@ -82,3 +83,4 @@ def test_swish_with_sigmoid_with_beta_test(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) + self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'Swish') From 9919a2e3a542f98b2c9513b3f56f803197e63746 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 19:39:53 +0300 Subject: [PATCH 10/14] Fixed Mish fusion pattern --- model-optimizer/extensions/front/Mish_fusion.py | 4 ++-- model-optimizer/extensions/front/Mish_fusion_test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/model-optimizer/extensions/front/Mish_fusion.py b/model-optimizer/extensions/front/Mish_fusion.py index 032bdb8e40899e..e18eba641a92a6 100644 --- a/model-optimizer/extensions/front/Mish_fusion.py +++ b/model-optimizer/extensions/front/Mish_fusion.py @@ -23,7 +23,7 @@ class MishFusion(FrontReplacementSubgraph): """ - The transformation looks for the pattern with Softplus defining the Mish function: Mish(x) = x * tanh(Softplus(x)). + The transformation looks for the pattern with Softplus defining the Mish function: Mish(x) = x * tanh(SoftPlus(x)). """ enabled = True @@ -36,7 +36,7 @@ def pattern(self): ('input', dict()), ('mul', dict(op='Mul')), ('tanh', dict(op='Tanh')), - ('softplus', dict(op='Softplus')), + ('softplus', dict(op='SoftPlus')), ], edges=[ ('input', 'mul', {}), diff --git a/model-optimizer/extensions/front/Mish_fusion_test.py b/model-optimizer/extensions/front/Mish_fusion_test.py index 0082f5ff2a7cde..b7049ad2148ebc 100644 --- a/model-optimizer/extensions/front/Mish_fusion_test.py +++ b/model-optimizer/extensions/front/Mish_fusion_test.py @@ -30,7 +30,7 @@ class MishFusionTest(unittest.TestCase): nodes = { **regular_op('input', {'type': 'Parameter'}), - **regular_op('softplus', {'op': 'Softplus'}), + **regular_op('softplus', {'op': 'SoftPlus'}), **regular_op('tanh', {'op': 'Tanh'}), **regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}), **result('result'), From 621c36f885016f1d12e4e2ec78d0418345464e3a Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 19:47:25 +0300 Subject: [PATCH 11/14] Updated Mish fusion transformation. Added unit test --- .../extensions/front/Mish_fusion.py | 16 +++++++----- .../extensions/front/Mish_fusion_test.py | 25 ++++++++++++++++++- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/model-optimizer/extensions/front/Mish_fusion.py b/model-optimizer/extensions/front/Mish_fusion.py index e18eba641a92a6..dd990a67a87158 100644 --- a/model-optimizer/extensions/front/Mish_fusion.py +++ b/model-optimizer/extensions/front/Mish_fusion.py @@ -33,27 +33,31 @@ def run_after(self): def pattern(self): return dict( nodes=[ - ('input', dict()), ('mul', dict(op='Mul')), ('tanh', dict(op='Tanh')), ('softplus', dict(op='SoftPlus')), ], edges=[ - ('input', 'mul', {}), - ('input', 'softplus', {}), +# ('input', 'mul', {}), +# ('input', 'softplus', {}), ('softplus', 'tanh', {}), ('tanh', 'mul'), ]) def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): mul = match['mul'] + mul_name = mul.soft_get('name', mul.id) + softplus = match['softplus'] # determine the input port of Mul which gets the 'input' node output - input_port = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Tanh') - mul_name = mul.soft_get('name', mul.id) + input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Tanh') + + # check that the same tensor provided as input to Mul and SoftPlus + if mul.in_port(input_port_idx).get_source() != softplus.in_port(0).get_source(): + return mish = Mish(graph, {}).create_node() - mish.in_port(0).connect(mul.in_port(input_port).get_source()) + mish.in_port(0).connect(mul.in_port(input_port_idx).get_source()) mul.out_port(0).get_connection().set_source(mish.out_port(0)) rename_nodes([(mul, mul_name + '/TBR'), (mish, mul_name)]) diff --git a/model-optimizer/extensions/front/Mish_fusion_test.py b/model-optimizer/extensions/front/Mish_fusion_test.py index b7049ad2148ebc..8e840d18540941 100644 --- a/model-optimizer/extensions/front/Mish_fusion_test.py +++ b/model-optimizer/extensions/front/Mish_fusion_test.py @@ -42,7 +42,7 @@ class MishFusionTest(unittest.TestCase): ('tanh', 'mul', {'in': 1, 'out': 0}), ('mul', 'result', {'in': 0, 'out': 0})] - def test_mish_fusion_test(self): + def test_mish_fusion(self): graph = build_graph_with_edge_attrs(self.nodes, self.edges, {}) graph_ref = build_graph(ref_nodes, ref_edges) @@ -53,3 +53,26 @@ def test_mish_fusion_test(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'Mish') + + def test_mish_fusion_different_source(self): + # check case when different tensors goes to Mul and SoftPlus + graph = build_graph_with_edge_attrs({ + **regular_op('input', {'type': 'Parameter'}), + **regular_op('input_2', {'type': 'Parameter'}), + **regular_op('softplus', {'op': 'SoftPlus'}), + **regular_op('tanh', {'op': 'Tanh'}), + **regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}), + **result('result'), + }, [('input', 'softplus', {'in': 0, 'out': 0}), + ('input_2', 'mul', {'in': 0, 'out': 0}), + ('softplus', 'tanh', {'in': 0, 'out': 0}), + ('tanh', 'mul', {'in': 1, 'out': 0}), + ('mul', 'result', {'in': 0, 'out': 0})], {}) + + graph_ref = graph.copy() + graph.stage = 'front' + + MishFusion().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) From 1a4df7c98290f8eb68f4c29049438715b53e3443 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 20:03:56 +0300 Subject: [PATCH 12/14] Updated HSwish fusing transformation --- .../extensions/front/HSwish_fusing_test.py | 66 +++++++++++++++++++ .../extensions/front/HSwish_fusion.py | 18 +++-- 2 files changed, 79 insertions(+), 5 deletions(-) diff --git a/model-optimizer/extensions/front/HSwish_fusing_test.py b/model-optimizer/extensions/front/HSwish_fusing_test.py index fa9cfaea53f887..34ba941c7bb63e 100644 --- a/model-optimizer/extensions/front/HSwish_fusing_test.py +++ b/model-optimizer/extensions/front/HSwish_fusing_test.py @@ -76,6 +76,38 @@ def test_hswish_with_clamp_wrong_constant(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) + def test_hswish_with_clamp_different_tensors(self): + graph = build_graph_with_edge_attrs({ + **regular_op('input', {'type': 'Parameter'}), + **regular_op('input_2', {'type': 'Parameter'}), + **regular_op('add', {'op': 'Add'}), + **regular_op('relu6', {'op': 'Clamp'}), + **regular_op('mul', {'op': 'Mul'}), + **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}), + **const('const_0', float_array([0.0])), + **const('const_3', float_array([3.0])), + **const('const_6', float_array([6.0])), + **const('const_1_6', float_array([1.0 / 6.0])), + **result('result'), + }, [('input', 'mul', {'in': 0, 'out': 0}), + ('input_2', 'add', {'in': 0, 'out': 0}), + ('const_3', 'add', {'in': 1, 'out': 0}), + ('add', 'relu6', {'in': 0, 'out': 0}), + ('const_0', 'relu6', {'in': 1, 'out': 0}), + ('const_6', 'relu6', {'in': 2, 'out': 0}), + ('relu6', 'mul', {'in': 1, 'out': 0}), + ('mul', 'mul_2', {'in': 0, 'out': 0}), + ('const_1_6', 'mul_2', {'in': 1, 'out': 0}), + ('mul_2', 'result', {'in': 0, 'out': 0})]) + + graph_ref = graph.copy() + graph.stage = 'front' + + HSwishWithClamp().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + class HSwishWithMinMaxTest(unittest.TestCase): nodes = { @@ -126,3 +158,37 @@ def test_hswish_with_min_max_wrong_constant(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) + + def test_hswish_with_min_max_different_tensors(self): + graph = build_graph_with_edge_attrs({ + **regular_op('input', {'type': 'Parameter'}), + **regular_op('input_2', {'type': 'Parameter'}), + **regular_op('add', {'op': 'Add'}), + **regular_op('max', {'op': 'Maximum'}), + **regular_op('min', {'op': 'Minimum'}), + **regular_op('mul', {'op': 'Mul'}), + **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}), + **const('const_0', float_array([0.0])), + **const('const_3', float_array([3.0])), + **const('const_6', float_array([6.0])), + **const('const_1_6', float_array([1.0 / 6.0])), + **result('result'), + }, [('input_2', 'mul', {'in': 1, 'out': 0}), + ('input', 'add', {'in': 0, 'out': 0}), + ('const_3', 'add', {'in': 1, 'out': 0}), + ('add', 'max', {'in': 0, 'out': 0}), + ('const_0', 'max', {'in': 1, 'out': 0}), + ('max', 'min', {'in': 0, 'out': 0}), + ('const_6', 'min', {'in': 1, 'out': 0}), + ('min', 'mul', {'in': 0, 'out': 0}), + ('mul', 'mul_2', {'in': 0, 'out': 0}), + ('const_1_6', 'mul_2', {'in': 1, 'out': 0}), + ('mul_2', 'result', {'in': 0, 'out': 0})]) + + graph_ref = graph.copy() + graph.stage = 'front' + + HSwishWithMinMax().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/front/HSwish_fusion.py b/model-optimizer/extensions/front/HSwish_fusion.py index c6c3205531ab66..81486925e4486d 100644 --- a/model-optimizer/extensions/front/HSwish_fusion.py +++ b/model-optimizer/extensions/front/HSwish_fusion.py @@ -24,14 +24,20 @@ def replace_with_hswish(graph: Graph, match: [dict, SubgraphMatch]): add = match['add'] + mul = match['mul'] mul_2 = match['mul_2'] - # determine the input port of Add which gets the 'input' node output - input_port = int(add.in_port(0).get_connection().get_source().node.soft_get('op') == 'Const') + # determine the input port of Add and Mul which gets the 'input' node output + add_input_port_idx = int(add.in_port(0).get_connection().get_source().node.soft_get('op') == 'Const') + mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') in ['Clamp', 'Minimum']) + + # check that the same tensor provided as input to Add and Mul + if add.in_port(add_input_port_idx).get_source() != mul.in_port(mul_input_port_idx).get_source(): + return mul_2_name = mul_2.soft_get('name', mul_2.id) hswish = HSwish(graph, {}).create_node() - hswish.in_port(0).connect(add.in_port(input_port).get_source()) + hswish.in_port(0).connect(add.in_port(add_input_port_idx).get_source()) mul_2.out_port(0).get_connection().set_source(hswish.out_port(0)) rename_nodes([(mul_2, mul_2_name + '/TBR'), (hswish, mul_2_name)]) @@ -39,7 +45,8 @@ def replace_with_hswish(graph: Graph, match: [dict, SubgraphMatch]): class HSwishWithClamp(FrontReplacementSubgraph): """ - The transformation looks for the pattern with ReLU6 (Clamp) defining the HSwish function. + The transformation looks for the pattern with ReLU6 (Clamp) defining the HSwish function: + HSwish(x) = x * Relu6(x + 3) / 6.0. """ enabled = True @@ -77,7 +84,8 @@ def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): class HSwishWithMinMax(FrontReplacementSubgraph): """ - The transformation looks for the pattern with Min/Max defining the HSwish function. + The transformation looks for the pattern with Min/Max defining the HSwish function: + HSwish(x) = x * Min(Max(x + 3, 0), 6) / 6.0. """ enabled = True From c109aaf899dafa83270828e8e3d3d432a545eb1b Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 20:14:29 +0300 Subject: [PATCH 13/14] Updated Swish fusion transformation and tests --- .../extensions/front/Mish_fusion.py | 4 +- .../extensions/front/Swish_fusion.py | 28 +++++++----- .../extensions/front/Swish_fusion_test.py | 44 +++++++++++++++++++ 3 files changed, 63 insertions(+), 13 deletions(-) diff --git a/model-optimizer/extensions/front/Mish_fusion.py b/model-optimizer/extensions/front/Mish_fusion.py index dd990a67a87158..5d0bfa2e638746 100644 --- a/model-optimizer/extensions/front/Mish_fusion.py +++ b/model-optimizer/extensions/front/Mish_fusion.py @@ -38,9 +38,7 @@ def pattern(self): ('softplus', dict(op='SoftPlus')), ], edges=[ -# ('input', 'mul', {}), -# ('input', 'softplus', {}), - ('softplus', 'tanh', {}), + ('softplus', 'tanh'), ('tanh', 'mul'), ]) diff --git a/model-optimizer/extensions/front/Swish_fusion.py b/model-optimizer/extensions/front/Swish_fusion.py index fafd21de8a3dcf..e8deba0f605806 100644 --- a/model-optimizer/extensions/front/Swish_fusion.py +++ b/model-optimizer/extensions/front/Swish_fusion.py @@ -29,22 +29,25 @@ class SwishWithSigmoidWithoutBeta(FrontReplacementSubgraph): def pattern(self): return dict( nodes=[ - ('input', dict()), ('sigmoid', dict(op='Sigmoid')), ('mul', dict(op='Mul')), ], edges=[ - ('input', 'sigmoid', {}), - ('input', 'mul', {}), ('sigmoid', 'mul', {}), ]) def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): sigmoid = match['sigmoid'] mul = match['mul'] - mul_name = mul.soft_get('name', mul.id) + # determine the input port of Mul which gets the 'input' node output + mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Sigmoid') + + # check that the same tensor provided as input to Mul and Sigmoid + if mul.in_port(mul_input_port_idx).get_source() != sigmoid.in_port(0).get_source(): + return + swish = Swish(graph, {}).create_node() swish.in_port(0).connect(sigmoid.in_port(0).get_source()) mul.out_port(0).get_connection().set_source(swish.out_port(0)) @@ -79,16 +82,21 @@ def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): beta = match['beta'] mul = match['mul'] mul_beta = match['mul_beta'] - - # determine the input port of Mul which gets the 'input' node output - mul_beta_input_port = int(mul_beta.in_port(0).get_connection().get_source().node.id == beta.id) mul_name = mul.soft_get('name', mul.id) + # determine the input port of Muls which get the 'input' node output + mul_beta_input_port_idx = int(mul_beta.in_port(0).get_connection().get_source().node.id == beta.id) + mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') == 'Sigmoid') + + # check that the same tensor provided as input to Mul and MulBeta + if mul.in_port(mul_input_port_idx).get_source() != mul_beta.in_port(mul_beta_input_port_idx).get_source(): + return + swish = Swish(graph, {}).create_node() - swish.in_port(0).connect(mul_beta.in_port(mul_beta_input_port).get_source()) + swish.in_port(0).connect(mul_beta.in_port(mul_beta_input_port_idx).get_source()) - # connect Beta value - swish.in_port(1).connect(mul_beta.in_port(1 - mul_beta_input_port).get_source()) + # connect Beta valueо + swish.in_port(1).connect(mul_beta.in_port(1 - mul_beta_input_port_idx).get_source()) mul.out_port(0).get_connection().set_source(swish.out_port(0)) diff --git a/model-optimizer/extensions/front/Swish_fusion_test.py b/model-optimizer/extensions/front/Swish_fusion_test.py index 98f4ce2cdf9fa5..c775726eefbb3c 100644 --- a/model-optimizer/extensions/front/Swish_fusion_test.py +++ b/model-optimizer/extensions/front/Swish_fusion_test.py @@ -52,6 +52,26 @@ def test_swish_with_sigmoid_without_beta_test(self): self.assertTrue(flag, resp) self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'Swish') + def test_swish_with_sigmoid_without_beta_different_tensors(self): + graph = build_graph_with_edge_attrs({ + **regular_op('input', {'type': 'Parameter'}), + **regular_op('input_2', {'type': 'Parameter'}), + **regular_op('sigmoid', {'op': 'Sigmoid'}), + **regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}), + **result('result'), + }, [('input_2', 'mul', {'in': 0, 'out': 0}), + ('input', 'sigmoid', {'in': 0, 'out': 0}), + ('sigmoid', 'mul', {'in': 1, 'out': 0}), + ('mul', 'result', {'in': 0, 'out': 0})], {}) + + graph_ref = graph.copy() + graph.stage = 'front' + + SwishWithSigmoidWithoutBeta().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + class SwishWithSigmoidWithBetaTest(unittest.TestCase): nodes = { @@ -84,3 +104,27 @@ def test_swish_with_sigmoid_with_beta_test(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'Swish') + + def test_swish_with_sigmoid_with_beta_different_tensors(self): + graph = build_graph_with_edge_attrs({ + **regular_op('input', {'type': 'Parameter'}), + **regular_op('input_2', {'type': 'Parameter'}), + **regular_op('beta', {'type': 'Parameter'}), + **regular_op('mul_beta', {'op': 'Mul'}), + **regular_op('sigmoid', {'op': 'Sigmoid'}), + **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}), + **result('result'), + }, [('input', 'mul_beta', {'in': 0, 'out': 0}), + ('input_2', 'mul_2', {'in': 0, 'out': 0}), + ('beta', 'mul_beta', {'in': 1, 'out': 0}), + ('mul_beta', 'sigmoid', {'in': 0, 'out': 0}), + ('sigmoid', 'mul_2', {'in': 1, 'out': 0}), + ('mul_2', 'result', {'in': 0, 'out': 0})], {}) + + graph_ref = graph.copy() + graph.stage = 'front' + + SwishWithSigmoidWithBeta().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) From 5effdc7fc5658234e524348e0ad12c0889763257 Mon Sep 17 00:00:00 2001 From: Evgeny Lazarev Date: Wed, 26 Aug 2020 22:20:57 +0300 Subject: [PATCH 14/14] Fixed unit tests --- model-optimizer/extensions/front/HSwish_fusing_test.py | 6 ++++-- model-optimizer/extensions/front/Mish_fusion_test.py | 3 ++- model-optimizer/extensions/front/Softplus_fusion_test.py | 3 ++- model-optimizer/extensions/front/Swish_fusion.py | 5 +---- model-optimizer/extensions/front/Swish_fusion_test.py | 6 ++++-- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/model-optimizer/extensions/front/HSwish_fusing_test.py b/model-optimizer/extensions/front/HSwish_fusing_test.py index 34ba941c7bb63e..b7cedba4a16d6f 100644 --- a/model-optimizer/extensions/front/HSwish_fusing_test.py +++ b/model-optimizer/extensions/front/HSwish_fusing_test.py @@ -63,7 +63,8 @@ def test_hswish_with_clamp(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) - self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'HSwish') + self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and + graph.get_op_nodes(name='final_mul')[0].op == 'HSwish') def test_hswish_with_clamp_wrong_constant(self): graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}}) @@ -146,7 +147,8 @@ def test_hswish_with_min_max(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) - self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'HSwish') + self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and + graph.get_op_nodes(name='final_mul')[0].op == 'HSwish') def test_hswish_with_min_max_wrong_constant(self): graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}}) diff --git a/model-optimizer/extensions/front/Mish_fusion_test.py b/model-optimizer/extensions/front/Mish_fusion_test.py index 8e840d18540941..c1d97c431a4f2d 100644 --- a/model-optimizer/extensions/front/Mish_fusion_test.py +++ b/model-optimizer/extensions/front/Mish_fusion_test.py @@ -52,7 +52,8 @@ def test_mish_fusion(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) - self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'Mish') + self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and + graph.get_op_nodes(name='final_mul')[0].op == 'Mish') def test_mish_fusion_different_source(self): # check case when different tensors goes to Mul and SoftPlus diff --git a/model-optimizer/extensions/front/Softplus_fusion_test.py b/model-optimizer/extensions/front/Softplus_fusion_test.py index 3b0a81d86c68bd..eba085e93352a6 100644 --- a/model-optimizer/extensions/front/Softplus_fusion_test.py +++ b/model-optimizer/extensions/front/Softplus_fusion_test.py @@ -54,7 +54,8 @@ def test_softplus_fusion_test(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) - self.assertTrue(graph.get_op_nodes(name='final_log')[0].op == 'SoftPlus') + self.assertTrue(len(graph.get_op_nodes(name='final_log')) == 1 and + graph.get_op_nodes(name='final_log')[0].op == 'SoftPlus') def test_softplus_fusion_test_wrong_const(self): graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_1': {'value': float_array([0.9999])}}) diff --git a/model-optimizer/extensions/front/Swish_fusion.py b/model-optimizer/extensions/front/Swish_fusion.py index e8deba0f605806..bd47af787b664b 100644 --- a/model-optimizer/extensions/front/Swish_fusion.py +++ b/model-optimizer/extensions/front/Swish_fusion.py @@ -64,15 +64,12 @@ class SwishWithSigmoidWithBeta(FrontReplacementSubgraph): def pattern(self): return dict( nodes=[ - ('input', dict()), ('sigmoid', dict(op='Sigmoid')), ('beta', dict()), ('mul_beta', dict(op='Mul')), ('mul', dict(op='Mul')), ], edges=[ - ('input', 'mul_beta', {}), - ('input', 'mul', {}), ('beta', 'mul_beta', {}), ('mul_beta', 'sigmoid', {}), ('sigmoid', 'mul', {}), @@ -95,7 +92,7 @@ def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): swish = Swish(graph, {}).create_node() swish.in_port(0).connect(mul_beta.in_port(mul_beta_input_port_idx).get_source()) - # connect Beta valueо + # connect Beta value swish.in_port(1).connect(mul_beta.in_port(1 - mul_beta_input_port_idx).get_source()) mul.out_port(0).get_connection().set_source(swish.out_port(0)) diff --git a/model-optimizer/extensions/front/Swish_fusion_test.py b/model-optimizer/extensions/front/Swish_fusion_test.py index c775726eefbb3c..08144c8d006608 100644 --- a/model-optimizer/extensions/front/Swish_fusion_test.py +++ b/model-optimizer/extensions/front/Swish_fusion_test.py @@ -50,7 +50,8 @@ def test_swish_with_sigmoid_without_beta_test(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) - self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'Swish') + self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and + graph.get_op_nodes(name='final_mul')[0].op == 'Swish') def test_swish_with_sigmoid_without_beta_different_tensors(self): graph = build_graph_with_edge_attrs({ @@ -103,7 +104,8 @@ def test_swish_with_sigmoid_with_beta_test(self): (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) - self.assertTrue(graph.get_op_nodes(name='final_mul')[0].op == 'Swish') + self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and + graph.get_op_nodes(name='final_mul')[0].op == 'Swish') def test_swish_with_sigmoid_with_beta_different_tensors(self): graph = build_graph_with_edge_attrs({