Skip to content

Commit

Permalink
Implement support for opset3 EmbeddingBag ops (#546)
Browse files Browse the repository at this point in the history
* [MO] Implement EmbeddingBag_3

* Transform dynamic sub-graph of Wide and Deep into EmbeddingSegmentsSum

- Expressed SparseWeightedSum sub-graph through EmbeddingSegmentsSum
- Removed experimental SparseWeightedSum layer
- Implemented tests for the transformation

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix EmbeddingBag shape infer

* Fix EmbeddingSegmentsSum transformation for Wide and Deep

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix EmbeddingSegmentSum replacer after ports swap

Signed-off-by: Roman Kazantsev <[email protected]>

* Update package_BOM.txt

Signed-off-by: Roman Kazantsev <[email protected]>

* Add unit tests for EmbeddingXXX shape infer

* Fix ATen resolver

* Remove deleted files from BOM

* Add opset version to embedding_bag

* Use base class for EmbeddingBag

* Fix per_sample_weights case

* Fix EmbeddingSegmentsSum transformation

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix EmbeddingBag checks

* Fix ATen front transformation and merge conflicts

* Fix BOM

* Work around limitation for I64 input of W&D model

Signed-off-by: Roman Kazantsev <[email protected]>

* Cleanup where operation to fix affect of WhereDecomposition transform

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix BOM

* Correct EmbeddingSegmentSum transform for Wide and Deep

Add casting segment ids to i32 and remove ConstToResult sub-graph.

Signed-off-by: Roman Kazantsev <[email protected]>

* Update BOM with RemoveConstToResult transform

Signed-off-by: Roman Kazantsev <[email protected]>

* Add more comments for RemoveConstToResult transformation

Signed-off-by: Roman Kazantsev <[email protected]>

* Remove useless logging in EmbeddingSegmentsSum transformation

Signed-off-by: Roman Kazantsev <[email protected]>

* Small fixes

* Move EmbeddingBag resolving back to front phase

* Improve error messages

* Fix typo in unittests

* Reimplement sparse_reshape middle transform

Avoid deprecated API.

Signed-off-by: Roman Kazantsev <[email protected]>

* Clean-up graph after sparse_reshape and ConstToResult transformation

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix clean-up for transformations

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix clean-up for transformation #2

Signed-off-by: Roman Kazantsev <[email protected]>

Co-authored-by: Roman Kazantsev <[email protected]>
  • Loading branch information
mvafin and rkazants authored Jun 8, 2020
1 parent d155483 commit f1811ad
Show file tree
Hide file tree
Showing 18 changed files with 735 additions and 654 deletions.
4 changes: 1 addition & 3 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ extensions/front/tf/cumsum_ext.py
extensions/front/tf/deconv_ext.py
extensions/front/tf/depth_to_space.py
extensions/front/tf/elementwise_ext.py
extensions/front/tf/embedding_segments_sum.py
extensions/front/tf/expand_dims_ext.py
extensions/front/tf/extract_image_patches_ext.py
extensions/front/tf/fake_const_ext.py
Expand Down Expand Up @@ -441,7 +442,6 @@ extensions/front/tf/sparse_segment_mean_ext.py
extensions/front/tf/sparse_segment_sqrtn_ext.py
extensions/front/tf/sparse_segment_sum_ext.py
extensions/front/tf/sparse_to_dense_ext.py
extensions/front/tf/sparse_weighted_sum.py
extensions/front/tf/split_ext.py
extensions/front/tf/ssd_support.json
extensions/front/tf/ssd_support_api_v1.14.json
Expand Down Expand Up @@ -522,7 +522,6 @@ extensions/middle/DilatedConvolution.py
extensions/middle/EltwiseChecker.py
extensions/middle/EltwiseInputNormalization.py
extensions/middle/EltwiseInputReshape.py
extensions/middle/EmbeddingBagResolver.py
extensions/middle/FakeSplitOutputs.py
extensions/middle/FusedBatchNormNonConstant.py
extensions/middle/FusedBatchNormTraining.py
Expand Down Expand Up @@ -686,7 +685,6 @@ extensions/ops/sparse_segment_mean.py
extensions/ops/sparse_segment_sqrtn.py
extensions/ops/sparse_segment_sum.py
extensions/ops/sparse_to_dense.py
extensions/ops/sparse_weighted_sum.py
extensions/ops/spatial_transformer.py
extensions/ops/splice.py
extensions/ops/split.py
Expand Down
37 changes: 37 additions & 0 deletions model-optimizer/extensions/back/SpecialNodesFinalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,43 @@ def find_and_replace_pattern(self, graph: Graph):
graph.remove_node(node.id)


class RemoveConstToResult(BackReplacementPattern):
"""
Transformation looks for a sub-graph "Const->Result" and removes Result node.
Currently IE is unable to handle such graph so this transformation removes to work around this case.
For instance, this case appears for Wide and Deep model.
"""
enabled = True

@staticmethod
def pattern():
return dict(
nodes=[
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data'}),
('result_node', {'type': 'Result', 'kind': 'op'}),
],
edges=[
('const_node', 'const_data'),
('const_data', 'result_node')
]
)

@staticmethod
def replace_pattern(graph: Graph, match: dict):
const_node = match['const_node']
const_data_node = match['const_data']
result_node = match['result_node']
nodes_to_remove = [result_node.id]

# in case only const data consumer that is the result node, remove the whole sub-graph
if len(const_node.out_port(0).get_destinations()) == 1:
nodes_to_remove.append(const_node.id)
nodes_to_remove.append(const_data_node.id)

graph.remove_node(nodes_to_remove)


class NormalizeTI(BackReplacementPattern):
"""
This transformation is used while generating IR of lower than 10 version
Expand Down
62 changes: 56 additions & 6 deletions model-optimizer/extensions/front/ATenToEmbeddingBag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,18 @@
limitations under the License.
"""

from extensions.ops.embedding_bag import EmbeddingBag
from extensions.ops.embedding_bag import EmbeddingBagOffsetsSum, EmbeddingBagPackedSum
from extensions.ops.rank import Rank
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_node
from mo.ops.broadcast import Broadcast
from mo.ops.concat import Concat
from mo.ops.shape import Shape
from mo.ops.unsqueeze import Unsqueeze
from mo.utils.shape import node_to_get_shape_value_of_indices, get_canonical_axis_index_node, \
get_shape_values_by_indices_node


class AtenToEmbeddingBag(FrontReplacementPattern):
Expand All @@ -27,11 +36,52 @@ class AtenToEmbeddingBag(FrontReplacementPattern):

def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
node_name = node.name
rename_node(node, node_name + '/Old')
embedding_bag = EmbeddingBag(graph, {'name': node_name, 'mode': node.mode,
'scale_grad_by_freq': node.scale_grad_by_freq}).create_node()
assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
'mode is supported for node {}.'.format(node.id)
node_name = node.soft_get('name', node.id)
rename_node(node, node_name + '/TBR')
is_packed = False
if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
is_packed = True
embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
else:
embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
rename_node(embedding_bag, node_name)
node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
if is_packed:
node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
else:
# connect per_sample_weights
node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()

weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)

weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
{'name': node_name + '/Broadcast'})
zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))

default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
{'name': node_name + '/Unsqueeze'})
default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

# expand embedding table with zeros
weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
'name': node_name + '/Concat'}).create_node()
embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
weights_concat.out_port(0).connect(embedding_bag.in_port(0))

# point default index to expanded part of embedding table
weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
164 changes: 132 additions & 32 deletions model-optimizer/extensions/front/ATenToEmbeddingBag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,147 @@

import unittest

import numpy as np

from extensions.front.ATenToEmbeddingBag import AtenToEmbeddingBag
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph

nodes_attributes = {
'weights_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'indices_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'offsets_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'aten': {'type': None, 'kind': 'op', 'op': 'ATen', 'mode': 0, 'operator': 'embedding_bag', 'name': 'my_aten',
'scale_grad_by_freq': 0},
'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'},

# new EmbeddingBag layer
'emb_bag': {'type': None, 'kind': 'op', 'op': 'EmbeddingBag', 'mode': 0, 'scale_grad_by_freq': 0},
}
from mo.utils.unittest.graph import build_graph, result, \
regular_op, const


class AtenToEmbeddingBagTest(unittest.TestCase):
def test(self):
graph = build_graph(nodes_attributes,
[('weights_inp', 'aten', {'in': 0, 'out': 0}),
('indices_inp', 'aten', {'in': 1, 'out': 0}),
('offsets_inp', 'aten', {'in': 2, 'out': 0}),
('aten', 'result', {'in': 0, 'out': 0}),
],
{}, nodes_with_edges_only=True)

graph_ref = build_graph(nodes_attributes,
[('weights_inp', 'emb_bag', {'in': 0, 'out': 0}),
('indices_inp', 'emb_bag', {'in': 1, 'out': 0}),
('offsets_inp', 'emb_bag', {'in': 2, 'out': 0}),
('emb_bag', 'result', {'in': 0, 'out': 0}),
],
{}, nodes_with_edges_only=True)
nodes = {
**const('weights_inp', np.random.randn(100, 2)),
**regular_op('indices_inp', {'type': 'Parameter'}),
**regular_op('offsets_inp', {'type': 'Parameter'}),
**regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
'name': 'my_aten'}),

**regular_op('emb_bag', {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op', 'op': 'EmbeddingBagOffsetsSum'}),
**result('result'),
}
edges = [('weights_inp', 'aten'),
('indices_inp', 'aten'),
('offsets_inp', 'aten'),
('aten', 'result'),
]
graph = build_graph(nodes, edges)

graph.graph['layout'] = 'NCHW'
graph.stage = 'front'

edges_ref = [('weights_inp', 'emb_bag'),
('indices_inp', 'emb_bag'),
('offsets_inp', 'emb_bag'),
('emb_bag', 'result'),
]

graph_ref = build_graph(nodes, edges_ref)

AtenToEmbeddingBag().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

def test_packed(self):
nodes = {
**const('weights_inp', np.random.randn(100, 4)),
**regular_op('indices_inp', {'type': 'Parameter'}),
**regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
'name': 'my_aten'}),

**regular_op('emb_bag', {'type': 'EmbeddingBagPackedSum', 'kind': 'op',
'op': 'EmbeddingBagPackedSum'}),
**result('result'),
}
edges = [('weights_inp', 'aten'),
('indices_inp', 'aten'),
('aten', 'result'),
]
graph = build_graph(nodes, edges)

graph.graph['layout'] = 'NCHW'
graph.stage = 'front'

replacer = AtenToEmbeddingBag()
replacer.find_and_replace_pattern(graph)
edges_ref = [('weights_inp', 'emb_bag'),
('indices_inp', 'emb_bag'),
('emb_bag', 'result'),
]

graph_ref = build_graph(nodes, edges_ref)

AtenToEmbeddingBag().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

def test_per_sample_weights(self):
nodes = {
**const('weights_inp', np.random.randn(100, 2)),
**regular_op('indices_inp', {'type': 'Parameter'}),
**regular_op('offsets_inp', {'type': 'Parameter'}),
**regular_op('per_sample_weights', {'type': 'Parameter'}),
**regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
'name': 'my_aten'}),

**regular_op('emb_bag', {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op',
'op': 'EmbeddingBagOffsetsSum'}),
**regular_op('WeightsRank', {'type': None, 'kind': 'op', 'op': 'Rank'}),
**regular_op('WeightsRank/axis', {'type': 'Add', 'kind': 'op', 'op': 'Add'}),
**regular_op('gather1', {'type': 'Gather', 'kind': 'op', 'op': 'Gather'}),
**regular_op('gather2', {'type': 'Gather', 'kind': 'op', 'op': 'Gather'}),
**regular_op('WeightsShape', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}),
**regular_op('Broadcast', {'type': 'Broadcast', 'kind': 'op', 'op': 'Broadcast'}),
**regular_op('Unsqueeze', {'type': 'Unsqueeze', 'kind': 'op', 'op': 'Unsqueeze'}),
**const('WeightsShape/Axis', int64_array(0)),
**const('zero1', int64_array(0)),
**const('zero2', int64_array(0)),
**const('Unsqueeze/value', int64_array(0)),
**const('Broadcast/value', int64_array(0)),
**const('neg', int64_array(-1)),
**regular_op('Concat', {'type': 'Concat', 'kind': 'op', 'op': 'Concat'}),
**result('result'),
}
edges = [('weights_inp', 'aten'),
('indices_inp', 'aten'),
('offsets_inp', 'aten'),
('per_sample_weights', 'aten'),
('aten', 'result'),
]
graph = build_graph(nodes, edges, nodes_with_edges_only=True)

graph.graph['layout'] = 'NCHW'
graph.stage = 'front'

edges_ref = [('weights_inp', 'Concat', {'in': 0, 'out': 0}),
('weights_inp', 'WeightsShape', {'in': 0, 'out': 0}),
('weights_inp', 'WeightsRank', {'in': 0, 'out': 0}),
('WeightsRank', 'WeightsRank/axis'),
('neg', 'WeightsRank/axis'),
('WeightsShape', 'gather1', {'in': 0, 'out': 0}),
('WeightsRank/axis', 'gather1'),
('WeightsShape/Axis', 'gather1'),
('WeightsShape', 'gather2', {'in': 0, 'out': 0}),
('zero1', 'gather2'),
('zero2', 'gather2'),
('Broadcast/value', 'Broadcast'),
('gather1', 'Broadcast'),
('Broadcast', 'Unsqueeze'),
('Unsqueeze/value', 'Unsqueeze'),
('Unsqueeze', 'Concat'),
('Concat', 'emb_bag'),
('indices_inp', 'emb_bag'),
('offsets_inp', 'emb_bag'),
('gather2', 'emb_bag'),
('per_sample_weights', 'emb_bag'),
('emb_bag', 'result'),
]

graph_ref = build_graph(nodes, edges_ref, nodes_with_edges_only=True)

AtenToEmbeddingBag().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='EmbeddingBag')[0]]['name'] == 'my_aten')
3 changes: 1 addition & 2 deletions model-optimizer/extensions/front/onnx/aten_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class ATenFrontExtractor(FrontExtractorOp):
def extract(cls, node):
mode = onnx_attr(node, 'mode', 'i', default=1)
operator = onnx_attr(node, 'operator', 's').decode()
scale_grad_by_freq = onnx_attr(node, 'scale_grad_by_freq', 'i', default=0)

ATen.update_node_stat(node, {'operator': operator, 'mode': mode, 'scale_grad_by_freq': scale_grad_by_freq})
ATen.update_node_stat(node, {'operator': operator, 'mode': mode})
return cls.enabled
4 changes: 2 additions & 2 deletions model-optimizer/extensions/front/tf/WhereDecomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ class WhereDecomposition(FrontReplacementOp):
enabled = True

def run_after(self):
from extensions.front.tf.sparse_weighted_sum import ExperimentalSparseWeightedSumFrontReplacer
from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2
from extensions.front.TransposeOrderNormalizer import TransposeOrderNormalizer
return [ExperimentalSparseWeightedSumFrontReplacer, TransposeOrderNormalizer]
return [EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2, TransposeOrderNormalizer]

def replace_op(self, graph: Graph, node: Node):
node_name = node.soft_get('name', node.id)
Expand Down
Loading

0 comments on commit f1811ad

Please sign in to comment.