Fix const node non-deterministic names (part 1) (openvinotoolkit#996)
* Update node names
Anton Chetverikov authored Jun 26, 2020
1 parent 0cdc549 commit 5aa9ffb
Showing 15 changed files with 117 additions and 96 deletions.
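As a rough illustration of the problem this commit fixes (a sketch, not code from this repository): graph nodes created without an explicit `name` attribute fall back to an auto-generated name whose numeric suffix depends on global creation order, so two otherwise identical conversions can produce differently named Const nodes. The changed files below switch to names derived from the node being transformed, which are stable across runs.

```python
# Illustrative sketch only (not Model Optimizer code). Both helper names are
# hypothetical; they mimic the two naming behaviours being compared here.
import itertools

_auto_id = itertools.count()

def auto_name(op_type):
    # Order-dependent fallback, e.g. 'Const_0', 'Const_1', ...
    # The suffix changes whenever unrelated nodes happen to be created first.
    return '{}_{}'.format(op_type, next(_auto_id))

def derived_name(parent_name, suffix):
    # The scheme this commit applies, e.g. 'conv1/Transpose': deterministic
    # because it depends only on the node being transformed.
    return '{}/{}'.format(parent_name, suffix)

print(auto_name('Const'))                  # varies with creation order
print(derived_name('conv1', 'Transpose'))  # always 'conv1/Transpose'
```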
@@ -15,16 +15,15 @@
 """
 from collections import deque
 
-import numpy as np
-
 from extensions.front.MatMul_normalizer import FullyConnectedDecomposer
 from extensions.front.kaldi.add_reshape_around_convolution import ReplaceConvolutionReshape
 from extensions.middle.TensorIteratorMerge import op_type
 from extensions.ops.activation_ops import activation_ops
 from extensions.ops.transpose import Transpose
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Node, Graph
-from mo.ops.const import Const
 
 
 class ReplaceConvolutionTranspose(FrontReplacementSubgraph):
@@ -61,10 +60,9 @@ def replace_sub_graph(self, graph: Graph, match: dict):
         convolution_nodes = [node for node in nodes_with_weights if Node(graph, node).op == 'Convolution']
         for convolution_node in convolution_nodes:
             target_node = self.search_target_node(Node(graph, convolution_node))
-            order_const = Const(graph, dict(value=np.array([0, 3, 2, 1]))).create_node()
-            permute_node = Transpose(graph, dict(name=target_node.name + '/Transpose')).create_node()
+            permute_node = create_op_with_const_inputs(graph, Transpose, {1: int64_array([0, 3, 2, 1])},
+                                                       {'name': target_node.name + '/Transpose'})
             target_node.insert_node_after(permute_node, 0)
-            order_const.out_port(0).connect(permute_node.in_port(1))
 
     def run_after(self):
         from extensions.front.flatten_to_reshape import FlattenToReshape
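The `create_op_with_const_inputs` helper that replaces the manual `Const(...).create_node()` + `connect()` pattern in this and the following files is not itself shown in this diff. Below is a hedged sketch of its behaviour, inferred only from the call sites in this commit; the real implementation lives in `mo/front/tf/graph_utils.py` and may differ in details such as the exact Const naming.

```python
# Hedged sketch inferred from call sites in this diff; NOT the real
# mo.front.tf.graph_utils.create_op_with_const_inputs.
from mo.ops.const import Const

def create_op_with_const_inputs_sketch(graph, op_cls, port_value_dict, attrs, input_node=None):
    # Build the op node itself from the supplied attributes.
    node = op_cls(graph, attrs).create_node()
    for port, value in port_value_dict.items():
        # One named Const per listed input port; deriving the name from the
        # consuming node keeps it deterministic, which is the point of this commit.
        const = Const(graph, {'value': value,
                              'name': '{}/{}'.format(node.soft_get('name', node.id), port)}).create_node()
        const.out_port(0).connect(node.in_port(port))
    if input_node is not None:
        # Optional data-producing node wired to input port 0, as in the
        # VariadicSplit call sites in this commit.
        input_node.out_port(0).connect(node.in_port(0))
    return node
```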
@@ -39,11 +39,12 @@ def test_simple_convolution(self):
             ('conv', 'reshape_conv'),
             ('reshape_conv', 'scale_shift'),
         ])
+        graph.stage = 'front'
         ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
         conv_node = Node(graph, graph.nodes['conv']['name'])
         permute = conv_node.out_node()
         self.assertEqual(permute.op, 'Transpose')
-        self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1])))
+        self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1])))
 
     def test_conv_pool(self):
         graph = build_graph(self.nodes_attributes, [
@@ -53,11 +54,12 @@
             ('pool', 'reshape_after_pool'),
             ('reshape_after_pool', 'fc'),
         ])
+        graph.stage = 'front'
         ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
         pool_node = Node(graph, graph.nodes['pool']['name'])
         permute = pool_node.out_node()
         self.assertEqual(permute.op, 'Transpose')
-        self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1])))
+        self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1])))
 
     def test_conv_act_pool(self):
         graph = build_graph(self.nodes_attributes, [
@@ -68,8 +70,9 @@
             ('pool', 'reshape_after_pool'),
             ('reshape_after_pool', 'fc'),
         ])
+        graph.stage = 'front'
         ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
         pool_node = Node(graph, graph.nodes['pool']['name'])
         permute = pool_node.out_node()
         self.assertEqual(permute.op, 'Transpose')
-        self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1])))
+        self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1])))
@@ -53,35 +53,38 @@ def pattern(self):
     @staticmethod
     def replace_pattern(graph: Graph, match: dict):
         node = match['conv']
+        node_name = node.soft_get('name', node.id)
 
         # create Reshape before convolution
         # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
-        shape = Shape(graph, {}).create_node()
+        shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
         shape.in_port(0).connect(node.in_port(0).get_source())
 
-        split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])}, {'out_ports_count': 2}, shape)
-        conv_patch_stride = Const(graph, {'value': int64_array([node.patch_stride])}).create_node()
-        pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]))
+        split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
+                                            {'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape)
+
+        pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]), {'name': node_name + '/patch_stride/inverse'})
+        conv_patch_stride = Const(graph, {'value': int64_array([node.patch_stride]),
+                                          'name': node_name + '/patch_stride/'}).create_node()
         pow_node.in_port(0).connect(conv_patch_stride.out_port(0))
 
-        mul = Mul(graph, {}).create_node()
+        mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node()
         mul.in_port(0).connect(split.out_port(1))
         mul.in_port(1).connect(pow_node.out_port(0))
 
-        const_1 = Const(graph, {'value': int64_array([1])}).create_node()
+        concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])},
+                                             {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
 
-        concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
         concat.in_port(0).connect(split.out_port(0))
         concat.in_port(1).connect(mul.out_port(0))
-        concat.in_port(2).connect(const_1.out_port(0))
         concat.in_port(3).connect(conv_patch_stride.out_port(0))
 
-        reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node()
+        reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
         reshape_in.in_port(1).connect(concat.out_port(0))
 
         # create Reshape after Convolution
         reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
-                                                       {'name': node.name + '/Reshape/'})
+                                                       {'name': node_name + '/reshape_out'})
 
         # connect input_reshape_node
         source = node.in_port(0).get_source()
@@ -49,38 +49,41 @@ def pattern(self):
     @staticmethod
     def replace_pattern(graph: Graph, match: dict):
         node = match['pool']
+        node_name = node.soft_get('name', node.id)
 
         if node.pool_step is None:
             node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])
 
         # create Reshape before convolution
         # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
-        shape = Shape(graph, {}).create_node()
+        shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
        shape.in_port(0).connect(node.in_port(0).get_source())
 
-        split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])}, {'out_ports_count': 2}, shape)
-        node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride])}).create_node()
-        pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]))
+        split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
+                                            {'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape)
+
+        pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]), {'name': node_name + '/pool_stride/inverse'})
+        node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride]),
+                                         'name': node_name + '/pool_stride/'}).create_node()
         pow_node.in_port(0).connect(node_pool_stride.out_port(0))
 
-        mul = Mul(graph, {}).create_node()
+        mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node()
         mul.in_port(0).connect(split.out_port(1))
         mul.in_port(1).connect(pow_node.out_port(0))
 
-        const_1 = Const(graph, {'value': int64_array([1])}).create_node()
+        concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])},
+                                             {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
 
-        concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
         concat.in_port(0).connect(split.out_port(0))
         concat.in_port(3).connect(mul.out_port(0))
-        concat.in_port(2).connect(const_1.out_port(0))
         concat.in_port(1).connect(node_pool_stride.out_port(0))
 
-        reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node()
+        reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
         reshape_in.in_port(1).connect(concat.out_port(0))
 
         # create Reshape after Convolution
         reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
-                                                       {'name': node.name + '/Reshape/'})
+                                                       {'name': node_name + '/reshape_out'})
 
         # connect input_reshape_node
         source = node.in_port(0).get_source()
44 changes: 22 additions & 22 deletions model-optimizer/extensions/front/kaldi/replace_lstm_nonlinearity.py
@@ -16,13 +16,13 @@
 import numpy as np
 
 from extensions.ops.activation_ops import Sigmoid, Tanh
+from extensions.ops.elementwise import Add, Mul
 from extensions.ops.split import Split
 from mo.front.caffe.extractors.utils import input_as_const
 from mo.front.common.replacement import FrontReplacementOp
+from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Node, Graph
 from mo.ops.concat import Concat
-from mo.ops.const import Const
-from extensions.ops.elementwise import Add, Mul
 from mo.ops.scale_shift import ScaleShiftOp
 
 
@@ -40,80 +40,80 @@ def run_before(self):
 
     def replace_op(self, graph: Graph, node: Node):
         # split input to (i_part, f_part, c_part, o_part, ct_1)
-        split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
-        split_node = Split(graph, {'name': 'Split_lstm_input_',
-                                   'num_splits': 5}).create_node()
+        node_name = node.soft_get('name', node.id)
+        split_node = create_op_with_const_inputs(graph, Split, {1: np.int64(1)},
+                                                 {'name': node_name + '/split_lstm_input',
+                                                  'num_splits': 5})
         node.in_port(0).get_connection().set_destination(split_node.in_port(0))
-        split_node.in_port(1).connect(split_node_axis.out_port(0))
 
         # i_t = Sigmoid(i_part + w_ic*ct_1)
-        i_scale_attrs = {'name': 'i_scaleshift',
+        i_scale_attrs = {'name': node_name + '/i_scaleshift',
                          'bias_term': False}
         i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
         input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
         split_node.out_port(4).connect(i_scale.in_port(0))
 
-        sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
+        sum_i_c = Add(graph, {'name': node_name + '/sum_i_c_'}).create_node()
         split_node.out_port(0).connect(sum_i_c.in_port(0))
         i_scale.out_port(0).connect(sum_i_c.in_port(1))
 
-        i_sigmoid = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node()
+        i_sigmoid = Sigmoid(graph, {'name': node_name + '/i_sigmoid'}).create_node()
         sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))
 
         # f_t = Sigmoid(f_part + w_fc*ct_1)
-        f_scale_attrs = {'name': 'f_scaleshift',
+        f_scale_attrs = {'name': node_name + '/f_scaleshift',
                          'bias_term': False}
         f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
         input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
         split_node.out_port(4).connect(f_scale.in_port(0))
 
-        sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
+        sum_f_c = Add(graph, {'name': node_name + '/sum_f_c_'}).create_node()
         split_node.out_port(1).connect(sum_f_c.in_port(0))
         f_scale.out_port(0).connect(sum_f_c.in_port(1))
 
-        f_sigmoid = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node()
+        f_sigmoid = Sigmoid(graph, {'name': node_name + '/f_sigmoid'}).create_node()
         sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))
 
         # c_t = f_t*ct_1 + i_t * tanh(c_part)
-        c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
+        c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node()
         split_node.out_port(2).connect(c_tanh.in_port(0))
 
-        prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
+        prod_i_c_tanh = Mul(graph, {'name': node_name + '/prod_i_c_tanh_'}).create_node()
         i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
         c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))
 
-        prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
+        prod_f_ct_1 = Mul(graph, {'name': node_name + '/prod_f_ct_1_'}).create_node()
         f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
         split_node.out_port(4).connect(prod_f_ct_1.in_port(1))
 
-        sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
+        sum_f_i = Add(graph, {'name': node_name + '/sum_f_i_'}).create_node()
         prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
         prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))
 
         # o_t = Sigmoid(o_part + w_oc*c_t)
-        o_scale_attrs = {'name': 'o_scaleshift',
+        o_scale_attrs = {'name': node_name + '/o_scaleshift',
                          'bias_term': False}
         o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
         input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
         sum_f_i.out_port(0).connect(o_scale.in_port(0))
 
-        sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
+        sum_o_c = Add(graph, {'name': node_name + '/sum_o_c_'}).create_node()
         split_node.out_port(3).connect(sum_o_c.in_port(0))
         o_scale.out_port(0).connect(sum_o_c.in_port(1))
 
-        o_sigmoid = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node()
+        o_sigmoid = Sigmoid(graph, {'name': node_name + '/o_sigmoid'}).create_node()
         sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))
 
         # m_t = o_t * Tanh(c_t)
-        c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
+        c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node()
         sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))
 
-        prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
+        prod_o_c_t_tanh = Mul(graph, {'name': node_name + '/prod_o_c_t_tanh_'}).create_node()
         o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
         c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))
 
         # add concat to create 1 output
-        concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
+        concat = Concat(graph, {'name': node_name + '/concat_c_m'}).create_node()
         concat.add_sequence_of_ports('in', range(2))
         sum_f_i.out_port(0).connect(concat.in_port(0))
         prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))
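For reference, the cell computation that this replacement expands into primitive ops, transcribed from the comments in the code above:

$$i_t = \sigma(i_{\text{part}} + w_{ic} \cdot c_{t-1}), \qquad f_t = \sigma(f_{\text{part}} + w_{fc} \cdot c_{t-1})$$
$$c_t = f_t \cdot c_{t-1} + i_t \cdot \tanh(c_{\text{part}}), \qquad o_t = \sigma(o_{\text{part}} + w_{oc} \cdot c_t), \qquad m_t = o_t \cdot \tanh(c_t)$$

The Concat at the end packs $c_t$ and $m_t$ back into the single output the original node had.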
10 changes: 5 additions & 5 deletions model-optimizer/extensions/front/mxnet/gather.py
@@ -16,8 +16,8 @@
 from extensions.ops.gather import Gather
 from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementOp
+from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Graph
-from mo.ops.const import Const
 
 
 class GatherFrontReplacer(FrontReplacementOp):
@@ -26,10 +26,10 @@ class GatherFrontReplacer(FrontReplacementOp):
 
     def replace_sub_graph(self, graph: Graph, match: dict):
         node = match['op']
-        gather_node = Gather(graph, dict(name=node.id + '/embedding_',
-                                         symbol_dict={'name': node.id + '/embedding_'})).create_node()
-        axis_const = Const(graph, {'value': int64_array(0)}).create_node()
+
+        gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)},
+                                                  {'name': node.soft_get('name', node.id) + '/embedding_'})
+
         node.in_port(0).get_connection().set_destination(gather_node.in_port(1))
         node.in_port(1).get_connection().set_destination(gather_node.in_port(0))
-        axis_const.out_port(0).connect(gather_node.in_port(2))
         node.out_port(0).get_connection().set_source(gather_node.out_port(0))
11 changes: 6 additions & 5 deletions model-optimizer/extensions/front/onnx/flattenONNX_to_reshape.py
@@ -51,19 +51,21 @@
         assert node.has_valid('axis'), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name)
         axis = node.axis
 
+        reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
+
         if axis == 0:
-            dim = Const(graph, {'value': int64_array([1, -1])}).create_node()
+            dim = Const(graph, {'value': int64_array([1, -1]), 'name': reshape_node.name + '/shape'}).create_node()
         elif axis == 1:
-            dim = Const(graph, {'value': int64_array([0, -1])}).create_node()
+            dim = Const(graph, {'value': int64_array([0, -1]), 'name': reshape_node.name + '/shape'}).create_node()
         else:
             shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
 
             idxs = list(range(axis)) if axis > 0 else list(range(axis, 0))
 
             axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)
             first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]),
-                                                          {'keep_dims': True})
-            second_dims = Const(graph, {'value': int64_array([-1])}).create_node()
+                                                          {'name': name + '/first_dims', 'keep_dims': True})
+            second_dims = Const(graph, {'value': int64_array([-1]), 'name': name + '/second_dims'}).create_node()
 
             node.in_port(0).get_source().connect(shape.in_port(0))
             axis_shape_portion.out_port(0).connect(first_dims.in_port(0))
@@ -72,7 +74,6 @@ def replace_sub_graph(self, graph: Graph, match: dict):
 
             dim = new_shape_node_from_shape_nodes(order_of_dims)
 
-        reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
         reshape_node.in_port(1).connect(dim.out_port(0))
 
         node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
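A small standalone check (assuming standard ONNX Flatten semantics, which is what this replacement implements) of the 2-D output shapes the three branches above construct:

```python
# Standalone sanity check assuming ONNX Flatten semantics: the output is 2-D,
# dims before `axis` collapse into the first dimension and the rest into the
# second (-1 lets Reshape infer it). Negative axes are not exercised here.
import numpy as np

def flattened_shape(in_shape, axis):
    first = int(np.prod(in_shape[:axis], dtype=np.int64))  # empty product -> 1
    return (first, -1)

assert flattened_shape((2, 3, 4), 0) == (1, -1)  # axis == 0 branch: const [1, -1]
assert flattened_shape((2, 3, 4), 1) == (2, -1)  # axis == 1 branch: [0, -1] keeps dim 0
assert flattened_shape((2, 3, 4), 2) == (6, -1)  # else branch: ReduceProd over dims [:axis]
```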
10 changes: 4 additions & 6 deletions model-optimizer/extensions/front/onnx/hard_sigmoid_ext.py
@@ -20,7 +20,7 @@
 from mo.front.common.replacement import FrontReplacementOp
 from mo.front.onnx.extractors.utils import onnx_attr
 from mo.graph.graph import Node, Graph
-from mo.ops.const import Const
+from mo.front.tf.graph_utils import create_op_with_const_inputs
 
 
 class HardSigmoidFrontExtractor(FrontReplacementOp):
@@ -30,11 +30,9 @@ class HardSigmoidFrontExtractor(FrontReplacementOp):
     def replace_op(self, graph: Graph, node: Node):
         alpha = onnx_attr(node, 'alpha', 'f', default=0.2)
         beta = onnx_attr(node, 'beta', 'f', default=0.5)
-        alpha_node = Const(graph, {'value': np.array(alpha)}).create_node()
-        beta_node = Const(graph, {'value': np.array(beta)}).create_node()
-        hard_sigmoid = HardSigmoid(graph, {'name': node.name + '/HardSigmoid_'}).create_node()
+
+        hard_sigmoid = create_op_with_const_inputs(graph, HardSigmoid, {1: np.array(alpha), 2: np.array(beta)},
+                                                   {'name': node.name + '/HardSigmoid_'})
 
         node.in_port(0).get_connection().set_destination(hard_sigmoid.in_port(0))
-        alpha_node.out_port(0).connect(hard_sigmoid.in_port(1))
-        beta_node.out_port(0).connect(hard_sigmoid.in_port(2))
         return [hard_sigmoid.id]
(Diffs for the remaining 7 of the 15 changed files were not loaded.)
