Skip to content

Commit

Permalink
Reverse input channels fusion (#4276)
Browse files Browse the repository at this point in the history
* Side fix found while working on Windows machine.

* Fix for non-fused Reverse Input Channels subgraph
  • Loading branch information
Evgenya Stepyreva authored Feb 11, 2021
1 parent 4e3d7d2 commit 08ac8d9
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 31 deletions.
17 changes: 10 additions & 7 deletions model-optimizer/extensions/back/ReverseInputChannels.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ class ReverseChannelsPropagationDown(BackReplacementPattern):
'BatchNormalization': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'FakeQuantize': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Multiply': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Divide': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Add': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Subtract': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Pow': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),
'Convert': lambda node, rc: ReverseChannelsPropagationDown.pass_rc_through_eltwise(node, rc),

Expand Down Expand Up @@ -198,9 +200,8 @@ def pass_rc_through_eltwise(node, reverse_channels):
continue
shape = port.data.get_shape()
non_one_dims = np.where(shape != 1)[0]
if len(non_one_dims) == 0:
# shape contains only ones - nothing to flip for this input
continue
if shape[reverse_channels.axis] == 1:
continue # nothing to flip for this input
if len(non_one_dims) == 1 and shape[non_one_dims.item()] == reverse_channels.order.size:
new_axis = non_one_dims.item()
elif np.array_equal(before_shape, shape):
Expand Down Expand Up @@ -238,7 +239,8 @@ def pass_rc_through_shape(node, reverse_channels):
"""
stops propagation of RIC through shape taking operations, due to RIC does not change shape
"""
reverse_channels.out_port(0).get_connection().set_source(reverse_channels.in_port(0).get_connection().get_source())
reverse_channels.out_port(0).get_connection().set_source(
reverse_channels.in_port(0).get_connection().get_source())
return False

@staticmethod
Expand Down Expand Up @@ -269,7 +271,9 @@ class ReverseChannelsPropagationUp(BackReplacementPattern):
'BatchNormalization': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'FakeQuantize': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Multiply': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Divide': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Add': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Subtract': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Pow': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
'Convert': lambda node, rc: ReverseChannelsPropagationUp.lift_up_through_eltwise(node, rc),
}
Expand Down Expand Up @@ -300,9 +304,8 @@ def lift_up_through_eltwise(node: Node, reverse_channels: Node):
shape = port.data.get_shape()

non_one_dims = np.where(shape != 1)[0]
if len(non_one_dims) == 0:
# shape contains only ones - nothing to flip for this input
continue
if shape[reverse_channels.axis] == 1:
continue # nothing to flip for this input
if len(non_one_dims) == 1 and shape[non_one_dims.item()] == reverse_channels.order.size:
axis = non_one_dims.item()
elif np.array_equal(before_shape, shape):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node, rename_nodes
from mo.middle.pattern_match import check_value
from mo.ops.broadcast import Broadcast


Expand Down Expand Up @@ -51,7 +52,7 @@ def pattern(**kwargs):
('shape', dict(op='ShapeOf')),
('random_uniform', dict(op='RandomUniform')),
('mul', dict(op='Mul')),
('add_const', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 0.0, atol=0))),
('add_const', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 0.0, atol=0)))),
('add', dict(op='Add')),
('add2', dict(op='Add')),
('floor', dict(op='Floor')),
Expand Down
37 changes: 23 additions & 14 deletions model-optimizer/extensions/front/HSigmoid_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph, rename_nodes
from mo.middle.pattern_match import check_value
from mo.utils.graph import Node


Expand Down Expand Up @@ -50,10 +51,11 @@ def pattern(self):
nodes=[
('input', dict()),
('add', dict(op='Add')),
('const_0', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 0.0, atol=1e-6))),
('const_3', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 3.0, atol=1e-6))),
('const_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 6.0, atol=1e-6))),
('const_1_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 1.0 / 6.0, atol=1e-6))),
('const_0', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 0.0, atol=1e-6)))),
('const_3', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 3.0, atol=1e-6)))),
('const_6', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 6.0, atol=1e-6)))),
('const_1_6',
dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 1.0 / 6.0, atol=1e-6)))),
('clamp', dict(op='Clamp')),
('mul_2', dict(op='Mul')),
],
Expand Down Expand Up @@ -86,10 +88,11 @@ def pattern(self):
nodes=[
('input', dict()),
('add', dict(op='Add')),
('const_0', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 0.0, atol=1e-6))),
('const_3', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 3.0, atol=1e-6))),
('const_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 6.0, atol=1e-6))),
('const_1_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 1.0 / 6.0, atol=1e-6))),
('const_0', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 0.0, atol=1e-6)))),
('const_3', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 3.0, atol=1e-6)))),
('const_6', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 6.0, atol=1e-6)))),
('const_1_6',
dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 1.0 / 6.0, atol=1e-6)))),
('max', dict(op='Maximum')),
('min', dict(op='Minimum')),
('mul_2', dict(op='Mul')),
Expand Down Expand Up @@ -123,12 +126,15 @@ def pattern(self):
return dict(
nodes=[
('input', dict()),
('add_const', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 3.0, atol=1e-6))),
('add_const',
dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 3.0, atol=1e-6)))),
('add', dict(op='Add')),
('relu', dict(op='ReLU')),
('min_const', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 6.0, atol=1e-6))),
('min_const',
dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 6.0, atol=1e-6)))),
('min', dict(op='Minimum')),
('div_const', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 6.0, atol=1e-6))),
('div_const',
dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 6.0, atol=1e-6)))),
('div', dict(op='Div')),
],
edges=[
Expand Down Expand Up @@ -159,12 +165,15 @@ def pattern(self):
return dict(
nodes=[
('input', dict()),
('add_const', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 3.0, atol=1e-6))),
('add_const',
dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 3.0, atol=1e-6)))),
('add', dict(op='Add')),
('relu', dict(op='ReLU')),
('min_const', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 6.0, atol=1e-6))),
('min_const',
dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 6.0, atol=1e-6)))),
('min', dict(op='Minimum')),
('mul_const', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 1.0/6.0, atol=1e-6))),
('mul_const',
dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 1.0 / 6.0, atol=1e-6)))),
('mul', dict(op='Mul')),
],
edges=[
Expand Down
19 changes: 11 additions & 8 deletions model-optimizer/extensions/front/HSwish_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph, rename_nodes
from mo.middle.pattern_match import check_value


def replace_with_hswish(graph: Graph, match: [dict, SubgraphMatch]):
Expand Down Expand Up @@ -58,10 +59,11 @@ def pattern(self):
nodes=[
('input', dict()),
('add', dict(op='Add')),
('const_0', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 0.0, atol=1e-6))),
('const_3', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 3.0, atol=1e-6))),
('const_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 6.0, atol=1e-6))),
('const_1_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 1 / 6.0, atol=1e-6))),
('const_0', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 0.0, atol=1e-6)))),
('const_3', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 3.0, atol=1e-6)))),
('const_6', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 6.0, atol=1e-6)))),
('const_1_6',
dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 1.0 / 6.0, atol=1e-6)))),
('clamp', dict(op='Clamp')),
('mul', dict(op='Mul')),
('mul_2', dict(op='Mul')),
Expand Down Expand Up @@ -97,10 +99,11 @@ def pattern(self):
nodes=[
('input', dict()),
('add', dict(op='Add')),
('const_0', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 0.0, atol=1e-6))),
('const_3', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 3.0, atol=1e-6))),
('const_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 6.0, atol=1e-6))),
('const_1_6', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 1 / 6.0, atol=1e-6))),
('const_0', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 0.0, atol=1e-6)))),
('const_3', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 3.0, atol=1e-6)))),
('const_6', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 6.0, atol=1e-6)))),
('const_1_6',
dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 1.0 / 6.0, atol=1e-6)))),
('max', dict(op='Maximum')),
('min', dict(op='Minimum')),
('mul', dict(op='Mul')),
Expand Down
3 changes: 2 additions & 1 deletion model-optimizer/extensions/front/Softplus_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph, rename_nodes
from mo.middle.pattern_match import check_value


class SoftplusFusion(FrontReplacementSubgraph):
Expand All @@ -32,7 +33,7 @@ def pattern(self):
nodes=[
('exp', dict(op='Exp')),
('add', dict(op='Add')),
('const_1', dict(op='Const', value=lambda v: v is not None and np.allclose(v, 1.0, atol=1e-6))),
('const_1', dict(op='Const', value=lambda v: check_value(v, lambda x: np.allclose(x, 1.0, atol=1e-6)))),
('ln', dict(op='Log')),
],
edges=[
Expand Down
5 changes: 5 additions & 0 deletions model-optimizer/mo/middle/pattern_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging as log

import numpy as np
from networkx.algorithms import isomorphism as ism

from mo.graph.graph import Node, dict_includes, Graph
Expand Down Expand Up @@ -163,3 +164,7 @@ def find_isomorphisms(graph: Graph, nodes: list, edges: list):
match = {k: Node(graph, match[k]) for k in match.keys()}
result.append(match)
return result


def check_value(v: np.ndarray, check: callable):
    """Safely apply *check* to a Const value during pattern matching.

    Returns a truthy result only when *v* is present (not None; the annotation
    notwithstanding, callers may pass None for value-less nodes), every element
    of *v* is real-valued, and the user-supplied predicate *check* accepts it.
    Short-circuits, so *check* is never called on None or complex data.
    """
    if v is None:
        return False
    # np.isreal guards against complex values reaching predicates like
    # np.allclose, which would otherwise raise or mis-compare.
    return np.all(np.isreal(v)) and check(v)

0 comments on commit 08ac8d9

Please sign in to comment.