Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MO fusing activations #1942

Merged
merged 15 commits into from
Aug 27, 2020
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ extensions/front/freeze_placeholder_value.py
extensions/front/GeLUMerger_Erf.py
extensions/front/GeLUMerger_Tanh.py
extensions/front/global_pooling_to_reduce.py
extensions/front/HSwish_fusion.py
extensions/front/image_scaler.py
extensions/front/input_cut.py
extensions/front/instance_normalization.py
Expand All @@ -150,6 +151,7 @@ extensions/front/LayerNorm.py
extensions/front/Log1p.py
extensions/front/LogSoftmax.py
extensions/front/MatMul_normalizer.py
extensions/front/Mish_fusion.py
extensions/front/MoveEmbeddedInputsToInputs.py
extensions/front/mxnet/__init__.py
extensions/front/mxnet/activation.py
Expand Down Expand Up @@ -326,11 +328,13 @@ extensions/front/reshape_dim_normalizer.py
extensions/front/restore_ports.py
extensions/front/scatter_normalizer.py
extensions/front/softmax.py
extensions/front/Softplus_fusion.py
extensions/front/softsign_replacer.py
extensions/front/split_normalizer.py
extensions/front/SqueezeNormalize.py
extensions/front/standalone_const_eraser.py
extensions/front/sub.py
extensions/front/Swish_fusion.py
extensions/front/tf/__init__.py
extensions/front/tf/activation_ext.py
extensions/front/tf/argmax_ext.py
Expand Down
126 changes: 126 additions & 0 deletions model-optimizer/extensions/front/HSwish_fusing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

from extensions.front.HSwish_fusion import HSwishWithClamp, HSwishWithMinMax
from mo.front.common.partial_infer.utils import float_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, const, regular_op, result, build_graph_with_edge_attrs

# Reference graph shared by the tests of this file: the whole matched sub-graph is expected to collapse into a
# single node of type 'HSwish' which takes over the name of the last Mul ('final_mul').
ref_nodes = {
    **regular_op('input', {'type': 'Parameter'}),
    **regular_op('hswish', {'type': 'HSwish', 'name': 'final_mul'}),
    **result('result'),
}
# the reference graph is a linear chain: input -> hswish -> result
ref_edges = list(zip(('input', 'hswish'), ('hswish', 'result')))


class HSwishWithClampTest(unittest.TestCase):
    """
    Checks of the HSwishWithClamp transformation on a graph built as ((input + 3) -> Clamp) * input * (1 / 6).
    """
    nodes = {
        **regular_op('input', {'type': 'Parameter'}),
        **regular_op('add', {'op': 'Add'}),
        **regular_op('relu6', {'op': 'Clamp'}),
        **regular_op('mul', {'op': 'Mul'}),
        **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
        **const('const_0', float_array([0.0])),
        **const('const_3', float_array([3.0])),
        **const('const_6', float_array([6.0])),
        **const('const_1_6', float_array([1.0 / 6.0])),
        **result('result'),
    }

    # every edge leaves output port 0 of the producer; only the consumer input port differs
    edges = [
        (producer, consumer, {'in': in_port, 'out': 0})
        for producer, consumer, in_port in (
            ('input', 'mul', 0),
            ('input', 'add', 0),
            ('const_3', 'add', 1),
            ('add', 'relu6', 0),
            ('const_0', 'relu6', 1),
            ('const_6', 'relu6', 2),
            ('relu6', 'mul', 1),
            ('mul', 'mul_2', 0),
            ('const_1_6', 'mul_2', 1),
            ('mul_2', 'result', 0),
        )
    ]

    def _fuse_and_compare(self, graph, graph_ref):
        """
        Runs HSwishWithClamp on 'graph' at the front stage and asserts that the result equals 'graph_ref'.
        """
        graph.stage = 'front'

        HSwishWithClamp().find_and_replace_pattern(graph)

        is_equal, message = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(is_equal, message)

    def test_hswish_with_clamp(self):
        # the pattern constants are exact, so the graph must be reduced to the reference one
        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
        self._fuse_and_compare(graph, build_graph(ref_nodes, ref_edges))

    def test_hswish_with_clamp_wrong_constant(self):
        # 'const_0' is outside of the tolerance used by the pattern, so the graph must stay as it is
        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}})
        self._fuse_and_compare(graph, graph.copy())


class HSwishWithMinMaxTest(unittest.TestCase):
    """
    Checks of the HSwishWithMinMax transformation on a graph built as
    Minimum(Maximum(input + 3, 0), 6) * input * (1 / 6).
    """
    nodes = {
        **regular_op('input', {'type': 'Parameter'}),
        **regular_op('add', {'op': 'Add'}),
        **regular_op('max', {'op': 'Maximum'}),
        **regular_op('min', {'op': 'Minimum'}),
        **regular_op('mul', {'op': 'Mul'}),
        **regular_op('mul_2', {'op': 'Mul', 'name': 'final_mul'}),
        **const('const_0', float_array([0.0])),
        **const('const_3', float_array([3.0])),
        **const('const_6', float_array([6.0])),
        **const('const_1_6', float_array([1.0 / 6.0])),
        **result('result'),
    }

    # every edge leaves output port 0 of the producer; only the consumer input port differs
    edges = [
        (producer, consumer, {'in': in_port, 'out': 0})
        for producer, consumer, in_port in (
            ('input', 'mul', 1),
            ('input', 'add', 0),
            ('const_3', 'add', 1),
            ('add', 'max', 0),
            ('const_0', 'max', 1),
            ('max', 'min', 0),
            ('const_6', 'min', 1),
            ('min', 'mul', 0),
            ('mul', 'mul_2', 0),
            ('const_1_6', 'mul_2', 1),
            ('mul_2', 'result', 0),
        )
    ]

    def _fuse_and_compare(self, graph, graph_ref):
        """
        Runs HSwishWithMinMax on 'graph' at the front stage and asserts that the result equals 'graph_ref'.
        """
        graph.stage = 'front'

        HSwishWithMinMax().find_and_replace_pattern(graph)

        is_equal, message = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(is_equal, message)

    def test_hswish_with_min_max(self):
        # the pattern constants are exact, so the graph must be reduced to the reference one
        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
        self._fuse_and_compare(graph, build_graph(ref_nodes, ref_edges))

    def test_hswish_with_min_max_wrong_constant(self):
        # 'const_0' is outside of the tolerance used by the pattern, so the graph must stay as it is
        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {'const_0': {'value': float_array([0.00001])}})
        self._fuse_and_compare(graph, graph.copy())
115 changes: 115 additions & 0 deletions model-optimizer/extensions/front/HSwish_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np

from extensions.front.AttributedClampNormalizer import AttributedClampNormalizer
from extensions.ops.activation_ops import HSwish
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph, rename_nodes


def replace_with_hswish(graph: Graph, match: [dict, SubgraphMatch]):
    """
    Replaces the matched sub-graph with a single HSwish node which takes over the name of the last Mul.

    The callers' patterns are defined on nodes, so a producer with several outputs may feed 'add' and 'mul' from
    different output ports and still be matched. Such a sub-graph is not an HSwish of a single tensor, so the
    function leaves the graph untouched unless 'mul' consumes exactly the tensor consumed by 'add'.

    :param graph: the graph to operate on
    :param match: the matched sub-graph; the nodes 'add', 'mul' and 'mul_2' are read from it
    :return: None
    """
    add = match['add']
    mul = match['mul']
    mul_2 = match['mul_2']

    # determine the input port of Add which gets the 'input' node output
    input_port = int(add.in_port(0).get_connection().get_source().node.soft_get('op') == 'Const')
    input_source = add.in_port(input_port).get_source()

    # check that one of the Mul inputs is produced by the same node and the same output port as the Add input
    mul_sources = [mul.in_port(idx).get_source() for idx in (0, 1)]
    if not any(source is not None and
               source.node.id == input_source.node.id and
               source.idx == input_source.idx
               for source in mul_sources):
        return

    mul_2_name = mul_2.soft_get('name', mul_2.id)

    hswish = HSwish(graph, {}).create_node()
    hswish.in_port(0).connect(input_source)
    mul_2.out_port(0).get_connection().set_source(hswish.out_port(0))

    # the new node must be found by the original name of the sub-graph output, so the old Mul is renamed first
    rename_nodes([(mul_2, mul_2_name + '/TBR'), (hswish, mul_2_name)])


class HSwishWithClamp(FrontReplacementSubgraph):
    """
    The transformation looks for the pattern with ReLU6 (Clamp) defining the HSwish function.

    The matched sub-graph is Mul(Mul(input, Clamp(Add(input, 3), 0, 6)), 1 / 6) where the constants 0 and 6 are
    connected to the input ports 1 and 2 of the Clamp node correspondingly.
    """
    enabled = True

    @staticmethod
    def _const_close_to(expected: float):
        """
        Creates a predicate for the 'value' attribute of a Const node of the pattern.

        The predicate returns a plain bool: True when the value is defined, is not empty and all its elements are
        equal to 'expected' with the absolute tolerance 1e-6. An element-wise result (as returned by np.isclose)
        cannot be used here because its truth value is ambiguous for a value with more than one element.

        :param expected: the number all elements of the constant must be close to
        :return: a function of one argument suitable for the 'value' attribute of the pattern node
        """
        def predicate(value):
            return value is not None and np.size(value) > 0 and np.allclose(value, expected, atol=1e-6)

        return predicate

    def run_after(self):
        # the pattern refers to the Clamp limits as inputs of the node, so it is applied after this transformation
        return [AttributedClampNormalizer]

    def pattern(self):
        return dict(
            nodes=[
                ('input', dict()),
                ('add', dict(op='Add')),
                ('const_0', dict(op='Const', value=self._const_close_to(0.0))),
                ('const_3', dict(op='Const', value=self._const_close_to(3.0))),
                ('const_6', dict(op='Const', value=self._const_close_to(6.0))),
                ('const_1_6', dict(op='Const', value=self._const_close_to(1 / 6.0))),
                ('clamp', dict(op='Clamp')),
                ('mul', dict(op='Mul')),
                ('mul_2', dict(op='Mul')),
            ],
            edges=[
                ('input', 'add', {}),
                ('input', 'mul', {}),
                ('const_3', 'add', {}),
                ('add', 'clamp', {'in': 0}),
                ('const_0', 'clamp', {'in': 1}),
                ('const_6', 'clamp', {'in': 2}),
                ('clamp', 'mul', {}),
                ('mul', 'mul_2', {}),
                ('const_1_6', 'mul_2', {}),
            ])

    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        replace_with_hswish(graph, match)


class HSwishWithMinMax(FrontReplacementSubgraph):
    """
    The transformation looks for the pattern with Min/Max defining the HSwish function.

    The matched sub-graph is Mul(Mul(input, Minimum(Maximum(Add(input, 3), 0), 6)), 1 / 6).
    """
    enabled = True

    @staticmethod
    def _const_close_to(expected: float):
        """
        Creates a predicate for the 'value' attribute of a Const node of the pattern.

        The predicate returns a plain bool: True when the value is defined, is not empty and all its elements are
        equal to 'expected' with the absolute tolerance 1e-6. An element-wise result (as returned by np.isclose)
        cannot be used here because its truth value is ambiguous for a value with more than one element.

        :param expected: the number all elements of the constant must be close to
        :return: a function of one argument suitable for the 'value' attribute of the pattern node
        """
        def predicate(value):
            return value is not None and np.size(value) > 0 and np.allclose(value, expected, atol=1e-6)

        return predicate

    def run_after(self):
        return [AttributedClampNormalizer]

    def pattern(self):
        return dict(
            nodes=[
                ('input', dict()),
                ('add', dict(op='Add')),
                ('const_0', dict(op='Const', value=self._const_close_to(0.0))),
                ('const_3', dict(op='Const', value=self._const_close_to(3.0))),
                ('const_6', dict(op='Const', value=self._const_close_to(6.0))),
                ('const_1_6', dict(op='Const', value=self._const_close_to(1 / 6.0))),
                ('max', dict(op='Maximum')),
                ('min', dict(op='Minimum')),
                ('mul', dict(op='Mul')),
                ('mul_2', dict(op='Mul')),
            ],
            edges=[
                ('input', 'add', {'out': 0}),
                ('input', 'mul', {'out': 0}),
                ('const_3', 'add', {}),
                ('add', 'max', {}),
                ('const_0', 'max', {}),
                ('max', 'min', {}),
                ('const_6', 'min', {}),
                ('min', 'mul', {}),
                ('mul', 'mul_2', {}),
                ('const_1_6', 'mul_2', {}),
            ])

    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        replace_with_hswish(graph, match)
59 changes: 59 additions & 0 deletions model-optimizer/extensions/front/Mish_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from extensions.front.Softplus_fusion import SoftplusFusion
from extensions.ops.activation_ops import Mish
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph, rename_nodes


class MishFusion(FrontReplacementSubgraph):
    """
    The transformation looks for the pattern with Softplus defining the Mish function: Mish(x) = x * tanh(Softplus(x)).
    """
    enabled = True

    def run_after(self):
        # the pattern contains the Softplus operation, so the transformation producing it must be applied first
        return [SoftplusFusion]

    def pattern(self):
        return dict(
            nodes=[
                ('input', dict()),
                ('mul', dict(op='Mul')),
                ('tanh', dict(op='Tanh')),
                ('softplus', dict(op='Softplus')),
            ],
            edges=[
                ('input', 'mul', {}),
                ('input', 'softplus', {}),
                ('softplus', 'tanh', {}),
                ('tanh', 'mul', {}),
            ])

    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        """
        Replaces the matched sub-graph with a single Mish node which takes over the name of the Mul node.

        The graph is left untouched when Mul does not consume exactly the tensor consumed by Softplus: the pattern
        is defined on nodes, so a producer with several outputs may feed them from different output ports.

        :param graph: the graph to operate on
        :param match: the matched sub-graph; the nodes 'mul' and 'softplus' are read from it
        :return: None
        """
        mul = match['mul']
        softplus = match['softplus']

        # The tensor the function is applied to is taken from the Softplus input. Looking for a Mul input which is
        # not produced by a Tanh operation picks the wrong port when the producer of 'x' is a Tanh itself.
        input_source = softplus.in_port(0).get_source()

        # check that one of the Mul inputs is produced by the same node and the same output port
        mul_sources = [mul.in_port(idx).get_source() for idx in (0, 1)]
        if not any(source is not None and
                   source.node.id == input_source.node.id and
                   source.idx == input_source.idx
                   for source in mul_sources):
            return

        mul_name = mul.soft_get('name', mul.id)

        mish = Mish(graph, {}).create_node()
        mish.in_port(0).connect(input_source)
        mul.out_port(0).get_connection().set_source(mish.out_port(0))

        # the new node must be found by the original name of the sub-graph output, so the old Mul is renamed first
        rename_nodes([(mul, mul_name + '/TBR'), (mish, mul_name)])
54 changes: 54 additions & 0 deletions model-optimizer/extensions/front/Mish_fusion_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Copyright (C) 2018-2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

from extensions.front.Mish_fusion import MishFusion
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, regular_op, result, build_graph_with_edge_attrs

# Reference graph of the test: the whole matched sub-graph is expected to collapse into a single node of type
# 'Mish' which takes over the name of the Mul node ('final_mul').
ref_nodes = {
    **regular_op('input', {'type': 'Parameter'}),
    **regular_op('mish', {'type': 'Mish', 'name': 'final_mul'}),
    **result('result'),
}
# the reference graph is a linear chain: input -> mish -> result
ref_edges = list(zip(('input', 'mish'), ('mish', 'result')))


class MishFusionTest(unittest.TestCase):
    """
    Checks of the MishFusion transformation on a graph built as input * Tanh(Softplus(input)).
    """
    nodes = {
        **regular_op('input', {'type': 'Parameter'}),
        **regular_op('softplus', {'op': 'Softplus'}),
        **regular_op('tanh', {'op': 'Tanh'}),
        **regular_op('mul', {'op': 'Mul', 'name': 'final_mul'}),
        **result('result'),
    }

    # every edge leaves output port 0 of the producer; only the consumer input port differs
    edges = [
        (producer, consumer, {'in': in_port, 'out': 0})
        for producer, consumer, in_port in (
            ('input', 'softplus', 0),
            ('input', 'mul', 0),
            ('softplus', 'tanh', 0),
            ('tanh', 'mul', 1),
            ('mul', 'result', 0),
        )
    ]

    def test_mish_fusion_test(self):
        # the graph must be reduced to the reference one with a single Mish node
        fused_graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})
        expected_graph = build_graph(ref_nodes, ref_edges)
        fused_graph.stage = 'front'

        MishFusion().find_and_replace_pattern(fused_graph)

        is_equal, message = compare_graphs(fused_graph, expected_graph, 'result')
        self.assertTrue(is_equal, message)
Loading