Skip to content

Commit

Permalink
MO fusing activations (#1942)
Browse files Browse the repository at this point in the history
* Added HSwish operation

* Added HSwish fusing transformation

* Fixed BOM

* Added unit test for HSwish fusing transformation

* Fixed unit tests for transformations using 'build_graph_with_edge_attrs' function to build the graph

* Added fusion transformation for Swish operation

* Added fusing transformation for Softplus operation

* Added fusion transformation for Mish operation

* Added check for the node name in the unit tests

* Fixed Mish fusion pattern

* Updated Mish fusion transformation. Added unit test

* Updated HSwish fusing transformation

* Updated Swish fusion transformation and tests

* Fixed unit tests
  • Loading branch information
lazarevevgeny authored Aug 27, 2020
1 parent 0182b97 commit a4d90a0
Show file tree
Hide file tree
Showing 13 changed files with 844 additions and 9 deletions.
4 changes: 4 additions & 0 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ extensions/front/freeze_placeholder_value.py
extensions/front/GeLUMerger_Erf.py
extensions/front/GeLUMerger_Tanh.py
extensions/front/global_pooling_to_reduce.py
extensions/front/HSwish_fusion.py
extensions/front/image_scaler.py
extensions/front/input_cut.py
extensions/front/instance_normalization.py
Expand All @@ -150,6 +151,7 @@ extensions/front/LayerNorm.py
extensions/front/Log1p.py
extensions/front/LogSoftmax.py
extensions/front/MatMul_normalizer.py
extensions/front/Mish_fusion.py
extensions/front/MoveEmbeddedInputsToInputs.py
extensions/front/mxnet/__init__.py
extensions/front/mxnet/activation.py
Expand Down Expand Up @@ -326,11 +328,13 @@ extensions/front/reshape_dim_normalizer.py
extensions/front/restore_ports.py
extensions/front/scatter_normalizer.py
extensions/front/softmax.py
extensions/front/Softplus_fusion.py
extensions/front/softsign_replacer.py
extensions/front/split_normalizer.py
extensions/front/SqueezeNormalize.py
extensions/front/standalone_const_eraser.py
extensions/front/sub.py
extensions/front/Swish_fusion.py
extensions/front/tf/__init__.py
extensions/front/tf/activation_ext.py
extensions/front/tf/argmax_ext.py
Expand Down
196 changes: 196 additions & 0 deletions model-optimizer/extensions/front/HSwish_fusing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

from extensions.front.HSwish_fusion import HSwishWithClamp, HSwishWithMinMax
from mo.front.common.partial_infer.utils import float_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs

ref_nodes = {**regular_op('input', {'type': 'Parameter'}),
**regular_op('hswish', {'type': 'HSwish', 'name': 'final_mul'}),
**result('result')
}
ref_edges = [('input', 'hswish'), ('hswish', 'result')]


class HSwishWithClampTest(unittest.TestCase):
nodes = {
**regular_op('input', {'type': 'Parameter'}),
**regular_op('add', {'op': 'Add'}),
**regular_op('relu6', {'op': 'Clamp'}),
**regular_op('mul', {'op': 'Mul'}),
**regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
**const('const_0', float_array([0.0])),
**const('const_3', float_array([3.0])),
**const('const_6', float_array([6.0])),
**const('const_1_6', float_array([1.0 / 6.0])),
**result('result'),
}

edges = [('input', 'mul', {'in': 0, 'out': 0}),
('input', 'add', {'in': 0, 'out': 0}),
('const_3', 'add', {'in': 1, 'out': 0}),
('add', 'relu6', {'in': 0, 'out': 0}),
('const_0', 'relu6', {'in': 1, 'out': 0}),
('const_6', 'relu6', {'in': 2, 'out': 0}),
('relu6', 'mul', {'in': 1, 'out': 0}),
('mul', 'mul_2', {'in': 0, 'out': 0}),
('const_1_6', 'mul_2', {'in': 1, 'out': 0}),
('mul_2', 'result', {'in': 0, 'out': 0})]

def test_hswish_with_clamp(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})

graph_ref = build_graph(ref_nodes, ref_edges)
graph.stage = 'front'

HSwishWithClamp().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
graph.get_op_nodes(name='final_mul')[0].op == 'HSwish')

def test_hswish_with_clamp_wrong_constant(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}})

graph_ref = graph.copy()
graph.stage = 'front'

HSwishWithClamp().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

def test_hswish_with_clamp_different_tensors(self):
graph = build_graph_with_edge_attrs({
**regular_op('input', {'type': 'Parameter'}),
**regular_op('input_2', {'type': 'Parameter'}),
**regular_op('add', {'op': 'Add'}),
**regular_op('relu6', {'op': 'Clamp'}),
**regular_op('mul', {'op': 'Mul'}),
**regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
**const('const_0', float_array([0.0])),
**const('const_3', float_array([3.0])),
**const('const_6', float_array([6.0])),
**const('const_1_6', float_array([1.0 / 6.0])),
**result('result'),
}, [('input', 'mul', {'in': 0, 'out': 0}),
('input_2', 'add', {'in': 0, 'out': 0}),
('const_3', 'add', {'in': 1, 'out': 0}),
('add', 'relu6', {'in': 0, 'out': 0}),
('const_0', 'relu6', {'in': 1, 'out': 0}),
('const_6', 'relu6', {'in': 2, 'out': 0}),
('relu6', 'mul', {'in': 1, 'out': 0}),
('mul', 'mul_2', {'in': 0, 'out': 0}),
('const_1_6', 'mul_2', {'in': 1, 'out': 0}),
('mul_2', 'result', {'in': 0, 'out': 0})])

graph_ref = graph.copy()
graph.stage = 'front'

HSwishWithClamp().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)


class HSwishWithMinMaxTest(unittest.TestCase):
nodes = {
**regular_op('input', {'type': 'Parameter'}),
**regular_op('add', {'op': 'Add'}),
**regular_op('max', {'op': 'Maximum'}),
**regular_op('min', {'op': 'Minimum'}),
**regular_op('mul', {'op': 'Mul'}),
**regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
**const('const_0', float_array([0.0])),
**const('const_3', float_array([3.0])),
**const('const_6', float_array([6.0])),
**const('const_1_6', float_array([1.0 / 6.0])),
**result('result'),
}

edges = [('input', 'mul', {'in': 1, 'out': 0}),
('input', 'add', {'in': 0, 'out': 0}),
('const_3', 'add', {'in': 1, 'out': 0}),
('add', 'max', {'in': 0, 'out': 0}),
('const_0', 'max', {'in': 1, 'out': 0}),
('max', 'min', {'in': 0, 'out': 0}),
('const_6', 'min', {'in': 1, 'out': 0}),
('min', 'mul', {'in': 0, 'out': 0}),
('mul', 'mul_2', {'in': 0, 'out': 0}),
('const_1_6', 'mul_2', {'in': 1, 'out': 0}),
('mul_2', 'result', {'in': 0, 'out': 0})]

def test_hswish_with_min_max(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})

graph_ref = build_graph(ref_nodes, ref_edges)
graph.stage = 'front'

HSwishWithMinMax().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(len(graph.get_op_nodes(name='final_mul')) == 1 and
graph.get_op_nodes(name='final_mul')[0].op == 'HSwish')

def test_hswish_with_min_max_wrong_constant(self):
graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}})

graph_ref = graph.copy()
graph.stage = 'front'

HSwishWithMinMax().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

def test_hswish_with_min_max_different_tensors(self):
graph = build_graph_with_edge_attrs({
**regular_op('input', {'type': 'Parameter'}),
**regular_op('input_2', {'type': 'Parameter'}),
**regular_op('add', {'op': 'Add'}),
**regular_op('max', {'op': 'Maximum'}),
**regular_op('min', {'op': 'Minimum'}),
**regular_op('mul', {'op': 'Mul'}),
**regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
**const('const_0', float_array([0.0])),
**const('const_3', float_array([3.0])),
**const('const_6', float_array([6.0])),
**const('const_1_6', float_array([1.0 / 6.0])),
**result('result'),
}, [('input_2', 'mul', {'in': 1, 'out': 0}),
('input', 'add', {'in': 0, 'out': 0}),
('const_3', 'add', {'in': 1, 'out': 0}),
('add', 'max', {'in': 0, 'out': 0}),
('const_0', 'max', {'in': 1, 'out': 0}),
('max', 'min', {'in': 0, 'out': 0}),
('const_6', 'min', {'in': 1, 'out': 0}),
('min', 'mul', {'in': 0, 'out': 0}),
('mul', 'mul_2', {'in': 0, 'out': 0}),
('const_1_6', 'mul_2', {'in': 1, 'out': 0}),
('mul_2', 'result', {'in': 0, 'out': 0})])

graph_ref = graph.copy()
graph.stage = 'front'

HSwishWithMinMax().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
123 changes: 123 additions & 0 deletions model-optimizer/extensions/front/HSwish_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np

from extensions.front.AttributedClampNormalizer import AttributedClampNormalizer
from extensions.ops.activation_ops import HSwish
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph, rename_nodes


def replace_with_hswish(graph: Graph, match: [dict, SubgraphMatch]):
add = match['add']
mul = match['mul']
mul_2 = match['mul_2']

# determine the input port of Add and Mul which gets the 'input' node output
add_input_port_idx = int(add.in_port(0).get_connection().get_source().node.soft_get('op') == 'Const')
mul_input_port_idx = int(mul.in_port(0).get_connection().get_source().node.soft_get('op') in ['Clamp', 'Minimum'])

# check that the same tensor provided as input to Add and Mul
if add.in_port(add_input_port_idx).get_source() != mul.in_port(mul_input_port_idx).get_source():
return
mul_2_name = mul_2.soft_get('name', mul_2.id)

hswish = HSwish(graph, {}).create_node()
hswish.in_port(0).connect(add.in_port(add_input_port_idx).get_source())
mul_2.out_port(0).get_connection().set_source(hswish.out_port(0))

rename_nodes([(mul_2, mul_2_name + '/TBR'), (hswish, mul_2_name)])


class HSwishWithClamp(FrontReplacementSubgraph):
"""
The transformation looks for the pattern with ReLU6 (Clamp) defining the HSwish function:
HSwish(x) = x * Relu6(x + 3) / 6.0.
"""
enabled = True

def run_after(self):
return [AttributedClampNormalizer]

def pattern(self):
return dict(
nodes=[
('input', dict()),
('add', dict(op='Add')),
('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 0.0, atol=1e-6))),
('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 3.0, atol=1e-6))),
('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 6.0, atol=1e-6))),
('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1 / 6.0, atol=1e-6))),
('clamp', dict(op='Clamp')),
('mul', dict(op='Mul')),
('mul_2', dict(op='Mul')),
],
edges=[
('input', 'add', {}),
('input', 'mul', {}),
('const_3', 'add', {}),
('add', 'clamp', {'in': 0}),
('const_0', 'clamp', {'in': 1}),
('const_6', 'clamp', {'in': 2}),
('clamp', 'mul', {}),
('mul', 'mul_2', {}),
('const_1_6', 'mul_2', {}),
])

def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
replace_with_hswish(graph, match)


class HSwishWithMinMax(FrontReplacementSubgraph):
"""
The transformation looks for the pattern with Min/Max defining the HSwish function:
HSwish(x) = x * Min(Max(x + 3, 0), 6) / 6.0.
"""
enabled = True

def run_after(self):
return [AttributedClampNormalizer]

def pattern(self):
return dict(
nodes=[
('input', dict()),
('add', dict(op='Add')),
('const_0', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 0.0, atol=1e-6))),
('const_3', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 3.0, atol=1e-6))),
('const_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 6.0, atol=1e-6))),
('const_1_6', dict(op='Const', value=lambda v: v is not None and np.isclose(v, 1 / 6.0, atol=1e-6))),
('max', dict(op='Maximum')),
('min', dict(op='Minimum')),
('mul', dict(op='Mul')),
('mul_2', dict(op='Mul')),
],
edges=[
('input', 'add', {'out': 0}),
('input', 'mul', {'out': 0}),
('const_3', 'add', {}),
('add', 'max', {}),
('const_0', 'max', {}),
('max', 'min', {}),
('const_6', 'min', {}),
('min', 'mul', {}),
('mul', 'mul_2', {}),
('const_1_6', 'mul_2', {}),
])

def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
replace_with_hswish(graph, match)
Loading

0 comments on commit a4d90a0

Please sign in to comment.