Implement support for opset3 EmbeddingBag ops #546

Merged
Changes from 27 commits
Commits
34 commits
90bf6c6
[MO] Implement EmbeddingBag_3
mvafin Apr 24, 2020
bb897c4
Transform dynamic sub-graph of Wide and Deep into EmbeddingSegmentsSum
rkazants Apr 28, 2020
3a86aff
Fix EmbeddingBag shape infer
mvafin Apr 30, 2020
40295be
Fix EmbeddingSegmentsSum transformation for Wide and Deep
rkazants May 7, 2020
725a1bb
Fix EmbeddingSegmentSum replacer after ports swap
rkazants May 13, 2020
f4bcd88
Update package_BOM.txt
rkazants May 13, 2020
05b1faf
Add unit tests for EmbeddingXXX shape infer
mvafin May 15, 2020
973e9b5
Fix ATen resolver
mvafin May 15, 2020
4bf92b2
Remove deleted files from BOM
mvafin May 19, 2020
75b5ec4
Add opset version to embedding_bag
mvafin May 20, 2020
e9cf0fa
Use base class for EmbeddingBag
mvafin May 20, 2020
adab635
Fix per_sample_weights case
mvafin May 21, 2020
af14345
Fix EmbeddingSegmentsSum transformation
rkazants May 21, 2020
0f202f9
Fix EmbeddingBag checks
mvafin May 22, 2020
f71fc27
Fix ATen front transformation and merge conflicts
mvafin May 25, 2020
c6ceb77
Fix BOM
mvafin May 26, 2020
ab67e0e
Work around limitation for I64 input of W&D model
rkazants May 31, 2020
3b9ae65
Cleanup where operation to fix affect of WhereDecomposition transform
rkazants Jun 1, 2020
80960f9
Fix BOM
mvafin Jun 1, 2020
2d6de5d
Correct EmbeddingSegmentSum transform for Wide and Deep
rkazants Jun 3, 2020
c03fc6b
Update BOM with RemoveConstToResult transform
rkazants Jun 3, 2020
3e6c42b
Merge remote-tracking branch 'remotes/upstream/master' into feature/e…
mvafin Jun 3, 2020
284fde6
Merge remote-tracking branch 'remotes/upstream/master' into feature/e…
mvafin Jun 3, 2020
fbf745f
Add more comments for RemoveConstToResult transformation
rkazants Jun 4, 2020
1ba3ccc
Remove useless logging in EmbeddingSegmentsSum transformation
rkazants Jun 4, 2020
47fc039
Small fixes
mvafin Jun 5, 2020
9dcee86
Merge remote-tracking branch 'remotes/upstream/master' into feature/e…
mvafin Jun 5, 2020
4da1a30
Move EmbeddingBag resolving back to front phase
mvafin Jun 5, 2020
a296137
Improve error messages
mvafin Jun 5, 2020
8df85b4
Fix typo in unittests
mvafin Jun 5, 2020
71773aa
Reimplement sparse_reshape middle transform
rkazants Jun 7, 2020
42b6441
Clean-up graph after sparse_reshape and ConstToResult transformation
rkazants Jun 8, 2020
9b5e8de
Fix clean-up for transformations
rkazants Jun 8, 2020
954bc6f
Fix clean-up for transformation #2
rkazants Jun 8, 2020
4 changes: 2 additions & 2 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ extensions/back/ReadValueAssignToMemory.py
extensions/back/ReduceToPooling.py
extensions/back/ReduceTransposeDimensions.py
extensions/back/remove_last_softmax_pattern.py
extensions/back/RemoveConstToResult.py
extensions/back/RemoveUselessConvert.py
extensions/back/Reshape0DToSqueeze.py
extensions/back/ReshapeMutation.py
Expand Down Expand Up @@ -364,6 +365,7 @@ extensions/front/tf/cumsum_ext.py
extensions/front/tf/deconv_ext.py
extensions/front/tf/depth_to_space.py
extensions/front/tf/elementwise_ext.py
extensions/front/tf/embedding_segments_sum.py
extensions/front/tf/expand_dims_ext.py
extensions/front/tf/extract_image_patches_ext.py
extensions/front/tf/fake_const_ext.py
Expand Down Expand Up @@ -438,7 +440,6 @@ extensions/front/tf/sparse_segment_mean_ext.py
extensions/front/tf/sparse_segment_sqrtn_ext.py
extensions/front/tf/sparse_segment_sum_ext.py
extensions/front/tf/sparse_to_dense_ext.py
extensions/front/tf/sparse_weighted_sum.py
extensions/front/tf/split_ext.py
extensions/front/tf/ssd_support.json
extensions/front/tf/ssd_support_api_v1.14.json
Expand Down Expand Up @@ -683,7 +684,6 @@ extensions/ops/sparse_segment_mean.py
extensions/ops/sparse_segment_sqrtn.py
extensions/ops/sparse_segment_sum.py
extensions/ops/sparse_to_dense.py
extensions/ops/sparse_weighted_sum.py
extensions/ops/spatial_transformer.py
extensions/ops/splice.py
extensions/ops/split.py
Expand Down
45 changes: 45 additions & 0 deletions model-optimizer/extensions/back/RemoveConstToResult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
Copyright (C) 2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph


class RemoveConstToResult(BackReplacementPattern):
"""
The transformation looks for a sub-graph "Const->Result" and removes the Result node.
Currently IE is unable to handle such a graph, so this transformation removes the Result node to work around this case.
For instance, this case appears in the Wide and Deep model.
"""
enabled = True

@staticmethod
def pattern():
return dict(
nodes=[
('const_node', {'type': 'Const', 'kind': 'op'}),
('const_data', {'kind': 'data'}),
('result_node', {'type': 'Result', 'kind': 'op'}),
],
edges=[
('const_node', 'const_data'),
('const_data', 'result_node')
]
)

@staticmethod
def replace_pattern(graph: Graph, match: dict):
result_node = match['result_node']
graph.remove_node(result_node.id)
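The effect of the RemoveConstToResult transformation above can be sketched with a minimal stand-in graph in plain Python. This is an illustration only, not the actual `mo.graph` API; the node attributes and adjacency representation here are assumptions chosen for clarity:

```python
# Toy stand-in for the Const->Result removal: drop every Result node
# that is fed, via a data node, by a Const producer.
def remove_const_to_result(nodes, edges):
    to_remove = set()
    for src, dst in edges:
        # src is a data node produced by a Const, dst is a Result op
        if nodes.get(src, {}).get('producer') == 'Const' and \
           nodes.get(dst, {}).get('type') == 'Result':
            to_remove.add(dst)
    nodes = {n: a for n, a in nodes.items() if n not in to_remove}
    edges = [(s, d) for s, d in edges if d not in to_remove]
    return nodes, edges


nodes = {
    'const_node': {'type': 'Const'},
    'const_data': {'producer': 'Const'},
    'result_node': {'type': 'Result'},
}
edges = [('const_node', 'const_data'), ('const_data', 'result_node')]
nodes, edges = remove_const_to_result(nodes, edges)
print('result_node' in nodes)  # False
```

Note that only the Result node is removed; the Const and its data node stay in place, matching the `replace_pattern` body above, which deletes `result_node` alone.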
15 changes: 9 additions & 6 deletions model-optimizer/extensions/front/ATenToEmbeddingBag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""

from extensions.ops.embedding_bag import EmbeddingBag
from extensions.ops.embedding_bag import ATenEmbeddingBag
from mo.front.common.replacement import FrontReplacementPattern
from mo.graph.graph import Graph, rename_node

Expand All @@ -27,11 +27,14 @@ class AtenToEmbeddingBag(FrontReplacementPattern):

def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
node_name = node.name
rename_node(node, node_name + '/Old')
embedding_bag = EmbeddingBag(graph, {'name': node_name, 'mode': node.mode,
'scale_grad_by_freq': node.scale_grad_by_freq}).create_node()
node_name = node.soft_get('name', node.id)
rename_node(node, node_name + '/TBR')
embedding_bag = ATenEmbeddingBag(graph, {'name': node_name, 'mode': node.soft_get('mode', 1)}).create_node()
rename_node(embedding_bag, node_name)
node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
if node.is_in_port_connected(2):
node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
if node.is_in_port_connected(3):
node.in_port(3).get_connection().set_destination(embedding_bag.in_port(3))
node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
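For reference, the offsets-based EmbeddingBag that the ATen op maps to can be approximated in NumPy. This is a sketch assuming the sum mode (`mode=0`, as in PyTorch's `embedding_bag`); the function name and signature are illustrative, not part of the PR:

```python
import numpy as np

def embedding_bag_sum(weights, indices, offsets):
    """Sum embedding rows per bag; offsets mark where each bag starts."""
    bounds = list(offsets) + [len(indices)]
    return np.stack([
        weights[indices[start:end]].sum(axis=0)
        for start, end in zip(bounds[:-1], bounds[1:])
    ])


table = np.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
# Bag 0 gathers rows 0 and 2, bag 1 gathers row 1.
out = embedding_bag_sum(table, np.array([0, 2, 1]), [0, 2])
print(out)  # [[6. 8.]
            #  [3. 4.]]
```

The optional third and fourth inputs being reconnected above (`per_sample_weights` and friends) are omitted here for brevity.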
7 changes: 3 additions & 4 deletions model-optimizer/extensions/front/ATenToEmbeddingBag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@
'weights_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'indices_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'offsets_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'aten': {'type': None, 'kind': 'op', 'op': 'ATen', 'mode': 0, 'operator': 'embedding_bag', 'name': 'my_aten',
'scale_grad_by_freq': 0},
'aten': {'type': None, 'kind': 'op', 'op': 'ATen', 'mode': 0, 'operator': 'embedding_bag', 'name': 'my_aten'},
'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'},

# new EmbeddingBag layer
'emb_bag': {'type': None, 'kind': 'op', 'op': 'EmbeddingBag', 'mode': 0, 'scale_grad_by_freq': 0},
'emb_bag': {'type': None, 'kind': 'op', 'op': 'ATenEmbeddingBag', 'mode': 0},
}


Expand Down Expand Up @@ -59,4 +58,4 @@ def test(self):

(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)
self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='EmbeddingBag')[0]]['name'] == 'my_aten')
self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='ATenEmbeddingBag')[0]]['name'] == 'my_aten')
3 changes: 1 addition & 2 deletions model-optimizer/extensions/front/onnx/aten_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class ATenFrontExtractor(FrontExtractorOp):
def extract(cls, node):
mode = onnx_attr(node, 'mode', 'i', default=1)
operator = onnx_attr(node, 'operator', 's').decode()
scale_grad_by_freq = onnx_attr(node, 'scale_grad_by_freq', 'i', default=0)

ATen.update_node_stat(node, {'operator': operator, 'mode': mode, 'scale_grad_by_freq': scale_grad_by_freq})
ATen.update_node_stat(node, {'operator': operator, 'mode': mode})
return cls.enabled
4 changes: 2 additions & 2 deletions model-optimizer/extensions/front/tf/WhereDecomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ class WhereDecomposition(FrontReplacementOp):
enabled = True

def run_after(self):
from extensions.front.tf.sparse_weighted_sum import ExperimentalSparseWeightedSumFrontReplacer
from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2
from extensions.front.TransposeOrderNormalizer import TransposeOrderNormalizer
return [ExperimentalSparseWeightedSumFrontReplacer, TransposeOrderNormalizer]
return [EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2, TransposeOrderNormalizer]

def replace_op(self, graph: Graph, node: Node):
node_name = node.soft_get('name', node.id)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2020 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -15,13 +15,19 @@
"""

import logging as log
import numpy as np

from extensions.ops.sparse_weighted_sum import ExperimentalSparseWeightedSum
from extensions.ops.Cast import Cast
from extensions.ops.embedding_bag import EmbeddingSegmentsSum
from extensions.ops.split import Split
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.graph.graph import Graph
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.ops.squeeze import Squeeze


class ExperimentalSparseWeightedSumFrontReplacer(FrontReplacementSubgraph):
class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
"""
The transformation looks for a pattern (sub-graph) that extracts embedding vectors from the parameters table
for object feature values and sums up these embedding vectors for every object.
Expand All @@ -30,7 +36,6 @@ class ExperimentalSparseWeightedSumFrontReplacer(FrontReplacementSubgraph):
enabled = True

def pattern(self):
log.debug('Enabled ExperimentalSparseWeightedSum replacement')
return dict(
nodes=[
('identity_spw', dict(op='Identity')),
Expand Down Expand Up @@ -76,18 +81,45 @@ def replace_sub_graph(self, graph: Graph, match: dict):
gather0_2 = match['gather0_2']
greaterequal0 = match['greaterequal0']
sparse_fill_empty_rows = match['sparse_fill_empty_rows']
where0 = match['where0']
gather = match['gather']
select = match['select']

log.debug('Found ExperimentalSparseWeightedSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op, sparse_fill_empty_rows.name))

sparse_weighted_sum = ExperimentalSparseWeightedSum(graph, {'name': sparse_fill_empty_rows.name + '/ExperimentalSparseWeightedSum_'}).create_node()
gather0_1.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(0))
greaterequal0.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(1))
identity_spw.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(2))
gather.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(3))
sparse_fill_empty_rows.in_port(3).get_connection().set_destination(sparse_weighted_sum.in_port(4))
where0 = match['where0']
output_node_name = select.soft_get('name', select.id)

log.debug('Found EmbeddingSegmentsSum pattern after {} with name {}'.format(sparse_fill_empty_rows.op,
sparse_fill_empty_rows.name))

split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)}, {'num_splits': 2})
squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})
cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds', 'dst_type': np.int32}).create_node()
cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber', 'dst_type': np.int32}).create_node()
embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_sum, output_node_name)])

# connect parameters table
gather.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(0))
# connect indices values
greaterequal0.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(1))
# split and connect segment ids
gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
# TODO: remove casting once we start to support I64 model input
cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
embedding_segments_sum.in_port(2).connect(cast_segment_ids.out_port(0))
# split and connect number of segments
identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
# TODO: remove casting once we start to support I64 model input
cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
embedding_segments_sum.in_port(3).connect(cast_num_segments.out_port(0))
# connect default value
# TODO: remove casting once we start to support I64 model input
sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
embedding_segments_sum.in_port(4).connect(cast_default_value.out_port(0))
# no input port for per_sample_weight

identity_spw.in_port(0).disconnect()
gather0_1.in_port(0).disconnect()
Expand All @@ -96,11 +128,11 @@ def replace_sub_graph(self, graph: Graph, match: dict):
sparse_fill_empty_rows.in_port(2).disconnect()
gather.in_port(0).disconnect()

select.out_port(0).get_connection().set_source(sparse_weighted_sum.out_port(0))
select.out_port(0).get_connection().set_source(embedding_segments_sum.out_port(0))
graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])


class ExperimentalSparseWeightedSumFrontReplacer2(FrontReplacementSubgraph):
class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
"""
The transformation looks for a pattern (sub-graph) that extracts embedding vectors from the parameters table
for object feature values and sums up these embedding vectors for every object.
Expand All @@ -109,7 +141,6 @@ class ExperimentalSparseWeightedSumFrontReplacer2(FrontReplacementSubgraph):
enabled = True

def pattern(self):
log.debug('Enabled ExperimentalSparseWeightedSum2 replacement')
return dict(
nodes=[
('identity_spw', dict(op='Identity')),
Expand Down Expand Up @@ -162,15 +193,46 @@ def replace_sub_graph(self, graph: Graph, match: dict):
gather = match['gather']
select = match['select']
where0 = match['where0']

log.debug('Found ExperimentalSparseWeightedSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op, sparse_fill_empty_rows.name))

sparse_weighted_sum = ExperimentalSparseWeightedSum(graph, {'name': sparse_fill_empty_rows.name + '/ExperimentalSparseWeightedSum_'}).create_node()
gather0_1.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(0))
greaterequal0.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(1))
identity_spw.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(2))
gather.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(3))
sparse_fill_empty_rows.in_port(3).get_connection().set_destination(sparse_weighted_sum.in_port(4))
output_node_name = select.soft_get('name', select.id)

log.debug('Found EmbeddingSegmentsSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op,
sparse_fill_empty_rows.name))

split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)},
{'num_splits': 2,
'name': output_node_name + '/SplitForIndices'})
squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)},
{'num_splits': 2,
'name': output_node_name + '/SplitForDenseShape'})
squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})
cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds', 'dst_type': np.int32}).create_node()
cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber', 'dst_type': np.int32}).create_node()
embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_sum, output_node_name)])

# connect parameters table
gather.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(0))
# connect indices values
greaterequal0.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(1))
# split and connect segment ids
gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
# TODO: remove casting once we start to support I64 model input
cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
embedding_segments_sum.in_port(2).connect(cast_segment_ids.out_port(0))
# split and connect number of segments
identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
# TODO: remove casting once we start to support I64 model input
cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
embedding_segments_sum.in_port(3).connect(cast_num_segments.out_port(0))
# connect default value
# TODO: remove casting once we start to support I64 model input
sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
embedding_segments_sum.in_port(4).connect(cast_default_value.out_port(0))
# no input port for per_sample_weight

identity_spw.in_port(0).disconnect()
gather0_1.in_port(0).disconnect()
Expand All @@ -179,5 +241,5 @@ def replace_sub_graph(self, graph: Graph, match: dict):
sparse_fill_empty_rows.in_port(2).disconnect()
gather.in_port(0).disconnect()

select.out_port(0).get_connection().set_source(sparse_weighted_sum.out_port(0))
select.out_port(0).get_connection().set_source(embedding_segments_sum.out_port(0))
graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])
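The EmbeddingSegmentsSum node that both replacers wire up can be approximated in NumPy as follows. This is a sketch, not the OpenVINO implementation; in particular, filling empty segments with the embedding at `default_index` is an assumption inferred from the default-value input (port 4) connected above:

```python
import numpy as np

def embedding_segments_sum(emb_table, indices, segment_ids,
                           num_segments, default_index=None):
    """Sum embedding rows grouped by segment id; empty segments get
    the embedding at default_index (or zeros if it is None)."""
    dim = emb_table.shape[1]
    out = np.zeros((num_segments, dim), dtype=emb_table.dtype)
    seen = np.zeros(num_segments, dtype=bool)
    for idx, seg in zip(indices, segment_ids):
        out[seg] += emb_table[idx]
        seen[seg] = True
    if default_index is not None:
        out[~seen] = emb_table[default_index]
    return out


table = np.array([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
# Segment 0 sums rows 0 and 2; segment 1 is empty; segment 2 takes row 1.
out = embedding_segments_sum(table, [0, 2, 1], [0, 0, 2],
                             num_segments=3, default_index=3)
print(out)  # [[6. 8.]
            #  [7. 8.]
            #  [3. 4.]]
```

This also shows why the transformation splits the sparse indices tensor: the first column of the COO indices supplies `segment_ids`, and the first element of the dense shape supplies `num_segments`, each cast to int32 per the TODO comments above.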