Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support for opset3 EmbeddingBag ops #546

Merged
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
90bf6c6
[MO] Implement EmbeddingBag_3
mvafin Apr 24, 2020
bb897c4
Transform dynamic sub-graph of Wide and Deep into EmbeddingSegmentsSum
rkazants Apr 28, 2020
3a86aff
Fix EmbeddingBag shape infer
mvafin Apr 30, 2020
40295be
Fix EmbeddingSegmentsSum transformation for Wide and Deep
rkazants May 7, 2020
725a1bb
Fix EmbeddingSegmentSum replacer after ports swap
rkazants May 13, 2020
f4bcd88
Update package_BOM.txt
rkazants May 13, 2020
05b1faf
Add unit tests for EmbeddingXXX shape infer
mvafin May 15, 2020
973e9b5
Fix ATen resolver
mvafin May 15, 2020
4bf92b2
Remove deleted files from BOM
mvafin May 19, 2020
75b5ec4
Add opset version to embedding_bag
mvafin May 20, 2020
e9cf0fa
Use base class for EmbeddingBag
mvafin May 20, 2020
adab635
Fix per_sample_weights case
mvafin May 21, 2020
af14345
Fix EmbeddingSegmentsSum transformation
rkazants May 21, 2020
0f202f9
Fix EmbeddingBag checks
mvafin May 22, 2020
f71fc27
Fix ATen front transformation and merge conflicts
mvafin May 25, 2020
c6ceb77
Fix BOM
mvafin May 26, 2020
ab67e0e
Work around limitation for I64 input of W&D model
rkazants May 31, 2020
3b9ae65
Cleanup where operation to fix affect of WhereDecomposition transform
rkazants Jun 1, 2020
80960f9
Fix BOM
mvafin Jun 1, 2020
2d6de5d
Correct EmbeddingSegmentSum transform for Wide and Deep
rkazants Jun 3, 2020
c03fc6b
Update BOM with RemoveConstToResult transform
rkazants Jun 3, 2020
3e6c42b
Merge remote-tracking branch 'remotes/upstream/master' into feature/e…
mvafin Jun 3, 2020
284fde6
Merge remote-tracking branch 'remotes/upstream/master' into feature/e…
mvafin Jun 3, 2020
fbf745f
Add more comments for RemoveConstToResult transformation
rkazants Jun 4, 2020
1ba3ccc
Remove useless logging in EmbeddingSegmentsSum transformation
rkazants Jun 4, 2020
47fc039
Small fixes
mvafin Jun 5, 2020
9dcee86
Merge remote-tracking branch 'remotes/upstream/master' into feature/e…
mvafin Jun 5, 2020
4da1a30
Move EmbeddingBag resolving back to front phase
mvafin Jun 5, 2020
a296137
Improve error messages
mvafin Jun 5, 2020
8df85b4
Fix typo in unittests
mvafin Jun 5, 2020
71773aa
Reimplement sparse_reshape middle transform
rkazants Jun 7, 2020
42b6441
Clean-up graph after sparse_reshape and ConstToResult transformation
rkazants Jun 8, 2020
9b5e8de
Fix clean-up for transformations
rkazants Jun 8, 2020
954bc6f
Fix clean-up for transformation #2
rkazants Jun 8, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ extensions/front/tf/cumsum_ext.py
extensions/front/tf/deconv_ext.py
extensions/front/tf/depth_to_space.py
extensions/front/tf/elementwise_ext.py
extensions/front/tf/embedding_segments_sum.py
extensions/front/tf/expand_dims_ext.py
extensions/front/tf/extract_image_patches_ext.py
extensions/front/tf/fake_const_ext.py
Expand Down Expand Up @@ -438,7 +439,6 @@ extensions/front/tf/sparse_segment_mean_ext.py
extensions/front/tf/sparse_segment_sqrtn_ext.py
extensions/front/tf/sparse_segment_sum_ext.py
extensions/front/tf/sparse_to_dense_ext.py
extensions/front/tf/sparse_weighted_sum.py
extensions/front/tf/split_ext.py
extensions/front/tf/ssd_support.json
extensions/front/tf/ssd_support_api_v1.14.json
Expand Down Expand Up @@ -519,7 +519,6 @@ extensions/middle/DilatedConvolution.py
extensions/middle/EltwiseChecker.py
extensions/middle/EltwiseInputNormalization.py
extensions/middle/EltwiseInputReshape.py
extensions/middle/EmbeddingBagResolver.py
extensions/middle/FakeSplitOutputs.py
extensions/middle/FusedBatchNormNonConstant.py
extensions/middle/FusedBatchNormTraining.py
Expand Down Expand Up @@ -683,7 +682,6 @@ extensions/ops/sparse_segment_mean.py
extensions/ops/sparse_segment_sqrtn.py
extensions/ops/sparse_segment_sum.py
extensions/ops/sparse_to_dense.py
extensions/ops/sparse_weighted_sum.py
extensions/ops/spatial_transformer.py
extensions/ops/splice.py
extensions/ops/split.py
Expand Down
37 changes: 37 additions & 0 deletions model-optimizer/extensions/back/SpecialNodesFinalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,43 @@ def find_and_replace_pattern(self, graph: Graph):
graph.remove_node(node.id)


class RemoveConstToResult(BackReplacementPattern):
    """
    Transformation looks for a sub-graph "Const->Result" and removes the Result node.
    Currently IE is unable to handle such a graph, so this transformation removes the Result node to work
    around the limitation. For instance, this case appears in the Wide and Deep model.
    """
    enabled = True

    @staticmethod
    def pattern():
        """
        Describe the matched sub-graph: a Const operation, its output data node and a Result operation
        that consumes this data node.

        :return: dictionary with the 'nodes' and 'edges' of the pattern
        """
        return dict(
            nodes=[
                ('const_node', {'type': 'Const', 'kind': 'op'}),
                ('const_data', {'kind': 'data'}),
                ('result_node', {'type': 'Result', 'kind': 'op'}),
            ],
            edges=[
                ('const_node', 'const_data'),
                ('const_data', 'result_node')
            ]
        )

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        """
        Remove the matched Result node and, when nothing else consumes the constant, the Const operation
        together with its data node.

        :param graph: graph to modify
        :param match: matched nodes keyed by the names used in pattern()
        :return: None
        """
        const_node = match['const_node']
        const_data_node = match['const_data']
        result_node = match['result_node']
        nodes_to_remove = [result_node.id]

        # if the Result node is the only consumer of the const data, remove the whole sub-graph
        if len(const_data_node.out_nodes()) == 1:
            nodes_to_remove.append(const_node.id)
            nodes_to_remove.append(const_data_node.id)

        # remove_node() accepts a single node id, so a list of ids must go through the bulk removal method
        graph.remove_nodes_from(nodes_to_remove)


class NormalizeTI(BackReplacementPattern):
"""
This transformation is used while generating IR of lower than 10 version
Expand Down
62 changes: 56 additions & 6 deletions model-optimizer/extensions/front/ATenToEmbeddingBag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,18 @@
limitations under the License.
"""

from extensions.ops.embedding_bag import EmbeddingBag
from extensions.ops.embedding_bag import EmbeddingBagOffsetsSum, EmbeddingBagPackedSum
from extensions.ops.rank import Rank
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_node
from mo.ops.broadcast import Broadcast
from mo.ops.concat import Concat
from mo.ops.shape import Shape
from mo.ops.unsqueeze import Unsqueeze
from mo.utils.shape import node_to_get_shape_value_of_indices, get_canonical_axis_index_node, \
get_shape_values_by_indices_node


class AtenToEmbeddingBag(FrontReplacementPattern):
Expand All @@ -27,11 +36,52 @@ class AtenToEmbeddingBag(FrontReplacementPattern):

def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
node_name = node.name
rename_node(node, node_name + '/Old')
embedding_bag = EmbeddingBag(graph, {'name': node_name, 'mode': node.mode,
'scale_grad_by_freq': node.scale_grad_by_freq}).create_node()
assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
'mode is supported for node {}.'.format(node.id)
node_name = node.soft_get('name', node.id)
rename_node(node, node_name + '/TBR')
is_packed = False
if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
is_packed = True
embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
else:
embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
rename_node(embedding_bag, node_name)
node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
if is_packed:
node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
else:
# connect per_sample_weights
node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()

weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)

weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
{'name': node_name + '/Broadcast'})
zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))

default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
{'name': node_name + '/Unsqueeze'})
default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

# expand embedding table with zeros
weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
'name': node_name + '/Concat'}).create_node()
embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
weights_concat.out_port(0).connect(embedding_bag.in_port(0))

# point default index to expanded part of embedding table
weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
164 changes: 132 additions & 32 deletions model-optimizer/extensions/front/ATenToEmbeddingBag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,147 @@

import unittest

import numpy as np

from extensions.front.ATenToEmbeddingBag import AtenToEmbeddingBag
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph

nodes_attributes = {
'weights_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'indices_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'offsets_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'aten': {'type': None, 'kind': 'op', 'op': 'ATen', 'mode': 0, 'operator': 'embedding_bag', 'name': 'my_aten',
'scale_grad_by_freq': 0},
'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'},

# new EmbeddingBag layer
'emb_bag': {'type': None, 'kind': 'op', 'op': 'EmbeddingBag', 'mode': 0, 'scale_grad_by_freq': 0},
}
from mo.utils.unittest.graph import build_graph, result, \
regular_op, const


class AtenToEmbeddingBagTest(unittest.TestCase):
def test(self):
graph = build_graph(nodes_attributes,
[('weights_inp', 'aten', {'in': 0, 'out': 0}),
('indices_inp', 'aten', {'in': 1, 'out': 0}),
('offsets_inp', 'aten', {'in': 2, 'out': 0}),
('aten', 'result', {'in': 0, 'out': 0}),
],
{}, nodes_with_edges_only=True)

graph_ref = build_graph(nodes_attributes,
[('weights_inp', 'emb_bag', {'in': 0, 'out': 0}),
('indices_inp', 'emb_bag', {'in': 1, 'out': 0}),
('offsets_inp', 'emb_bag', {'in': 2, 'out': 0}),
('emb_bag', 'result', {'in': 0, 'out': 0}),
],
{}, nodes_with_edges_only=True)
nodes = {
**const('weights_inp', np.random.randn(100, 2)),
**regular_op('indices_inp', {'type': 'Parameter'}),
**regular_op('offsets_inp', {'type': 'Parameter'}),
**regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
'name': 'my_aten'}),

**regular_op('emb_bag', {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op', 'op': 'EmbeddingBagOffsetsSum'}),
**result('result'),
}
edges = [('weights_inp', 'aten'),
('indices_inp', 'aten'),
('offsets_inp', 'aten'),
('aten', 'result'),
]
graph = build_graph(nodes, edges)

graph.graph['layout'] = 'NCHW'
graph.stage = 'front'

edges_ref = [('weights_inp', 'emb_bag'),
('indices_inp', 'emb_bag'),
('offsets_inp', 'emb_bag'),
('emb_bag', 'result'),
]

graph_ref = build_graph(nodes, edges_ref)

AtenToEmbeddingBag().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

def test_packed(self):
nodes = {
**const('weights_inp', np.random.randn(100, 4)),
**regular_op('indices_inp', {'type': 'Parameter'}),
**regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
'name': 'my_aten'}),

**regular_op('emb_bag', {'type': 'EmbeddingBagPackedSum', 'kind': 'op',
'op': 'EmbeddingBagPackedSum'}),
**result('result'),
}
edges = [('weights_inp', 'aten'),
('indices_inp', 'aten'),
('aten', 'result'),
]
graph = build_graph(nodes, edges)

graph.graph['layout'] = 'NCHW'
graph.stage = 'front'

replacer = AtenToEmbeddingBag()
replacer.find_and_replace_pattern(graph)
edges_ref = [('weights_inp', 'emb_bag'),
('indices_inp', 'emb_bag'),
('emb_bag', 'result'),
]

graph_ref = build_graph(nodes, edges_ref)

AtenToEmbeddingBag().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)

def test_per_sample_weights(self):
nodes = {
**const('weights_inp', np.random.randn(100, 2)),
**regular_op('indices_inp', {'type': 'Parameter'}),
**regular_op('offsets_inp', {'type': 'Parameter'}),
**regular_op('per_sample_weights', {'type': 'Parameter'}),
**regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0,
'name': 'my_aten'}),

**regular_op('emb_bag', {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op',
'op': 'EmbeddingBagOffsetsSum'}),
**regular_op('WeightsRank', {'type': None, 'kind': 'op', 'op': 'Rank'}),
**regular_op('WeightsRank/axis', {'type': 'Add', 'kind': 'op', 'op': 'Add'}),
**regular_op('gather1', {'type': 'Gather', 'kind': 'op', 'op': 'Gather'}),
**regular_op('gather2', {'type': 'Gather', 'kind': 'op', 'op': 'Gather'}),
**regular_op('WeightsShape', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}),
**regular_op('Broadcast', {'type': 'Broadcast', 'kind': 'op', 'op': 'Broadcast'}),
**regular_op('Unsqueeze', {'type': 'Unsqueeze', 'kind': 'op', 'op': 'Unsqueeze'}),
**const('WeightsShape/Axis', int64_array(0)),
**const('zero1', int64_array(0)),
**const('zero2', int64_array(0)),
**const('Unsqueeze/value', int64_array(0)),
**const('Broadcast/value', int64_array(0)),
**const('neg', int64_array(-1)),
**regular_op('Concat', {'type': 'Concat', 'kind': 'op', 'op': 'Concat'}),
**result('result'),
}
edges = [('weights_inp', 'aten'),
('indices_inp', 'aten'),
('offsets_inp', 'aten'),
('per_sample_weights', 'aten'),
('aten', 'result'),
]
graph = build_graph(nodes, edges, nodes_with_edges_only=True)

graph.graph['layout'] = 'NCHW'
graph.stage = 'front'

edges_ref = [('weights_inp', 'Concat', {'in': 0, 'out': 0}),
('weights_inp', 'WeightsShape', {'in': 0, 'out': 0}),
('weights_inp', 'WeightsRank', {'in': 0, 'out': 0}),
('WeightsRank', 'WeightsRank/axis'),
('neg', 'WeightsRank/axis'),
('WeightsShape', 'gather1', {'in': 0, 'out': 0}),
('WeightsRank/axis', 'gather1'),
('WeightsShape/Axis', 'gather1'),
('WeightsShape', 'gather2', {'in': 0, 'out': 0}),
('zero1', 'gather2'),
('zero2', 'gather2'),
('Broadcast/value', 'Broadcast'),
('gather1', 'Broadcast'),
('Broadcast', 'Unsqueeze'),
('Unsqueeze/value', 'Unsqueeze'),
('Unsqueeze', 'Concat'),
('Concat', 'emb_bag'),
('indices_inp', 'emb_bag'),
('offsets_inp', 'emb_bag'),
('gather2', 'emb_bag'),
('per_sample_weights', 'emb_bag'),
('emb_bag', 'result'),
]

graph_ref = build_graph(nodes, edges_ref, nodes_with_edges_only=True)

AtenToEmbeddingBag().find_and_replace_pattern(graph)

(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'result')
self.assertTrue(flag, resp)
self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='EmbeddingBag')[0]]['name'] == 'my_aten')
3 changes: 1 addition & 2 deletions model-optimizer/extensions/front/onnx/aten_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class ATenFrontExtractor(FrontExtractorOp):
def extract(cls, node):
mode = onnx_attr(node, 'mode', 'i', default=1)
operator = onnx_attr(node, 'operator', 's').decode()
scale_grad_by_freq = onnx_attr(node, 'scale_grad_by_freq', 'i', default=0)
mvafin marked this conversation as resolved.
Show resolved Hide resolved

ATen.update_node_stat(node, {'operator': operator, 'mode': mode, 'scale_grad_by_freq': scale_grad_by_freq})
ATen.update_node_stat(node, {'operator': operator, 'mode': mode})
return cls.enabled
4 changes: 2 additions & 2 deletions model-optimizer/extensions/front/tf/WhereDecomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ class WhereDecomposition(FrontReplacementOp):
enabled = True

def run_after(self):
from extensions.front.tf.sparse_weighted_sum import ExperimentalSparseWeightedSumFrontReplacer
from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2
from extensions.front.TransposeOrderNormalizer import TransposeOrderNormalizer
return [ExperimentalSparseWeightedSumFrontReplacer, TransposeOrderNormalizer]
return [EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2, TransposeOrderNormalizer]

def replace_op(self, graph: Graph, node: Node):
node_name = node.soft_get('name', node.id)
Expand Down
Loading