From f1811ad0602fc283d7750a4910d84f14824ee449 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Mon, 8 Jun 2020 18:06:40 +0300 Subject: [PATCH] Implement support for opset3 EmbeddingBag ops (#546) * [MO] Implement EmbeddingBag_3 * Transform dynamic sub-graph of Wide and Deep into EmbeddingSegmentsSum - Expressed SparseWeightedSum sub-graph through EmbeddingSegmentsSum - Removed experimental SparseWeightedSum layer - Implemented tests for the transformation Signed-off-by: Roman Kazantsev * Fix EmbeddingBag shape infer * Fix EmbeddingSegmentsSum transformation for Wide and Deep Signed-off-by: Roman Kazantsev * Fix EmbeddingSegmentSum replacer after ports swap Signed-off-by: Roman Kazantsev * Update package_BOM.txt Signed-off-by: Roman Kazantsev * Add unit tests for EmbeddingXXX shape infer * Fix ATen resolver * Remove deleted files from BOM * Add opset version to embedding_bag * Use base class for EmbeddingBag * Fix per_sample_weights case * Fix EmbeddingSegmentsSum transformation Signed-off-by: Roman Kazantsev * Fix EmbeddingBag checks * Fix ATen front transformation and merge conflicts * Fix BOM * Work around limitation for I64 input of W&D model Signed-off-by: Roman Kazantsev * Cleanup where operation to fix affect of WhereDecomposition transform Signed-off-by: Roman Kazantsev * Fix BOM * Correct EmbeddingSegmentSum transform for Wide and Deep Add casting segment ids to i32 and remove ConstToResult sub-graph. Signed-off-by: Roman Kazantsev * Update BOM with RemoveConstToResult transform Signed-off-by: Roman Kazantsev * Add more comments for RemoveConstToResult transformation Signed-off-by: Roman Kazantsev * Remove useless logging in EmbeddingSegmentsSum transformation Signed-off-by: Roman Kazantsev * Small fixes * Move EmbeddingBag resolving back to front phase * Improve error messages * Fix typo in unittests * Reimplement sparse_reshape middle transform Avoid deprecated API. Signed-off-by: Roman Kazantsev * Clean-up graph after sparse_reshape and ConstToResult transformation Signed-off-by: Roman Kazantsev * Fix clean-up for transformations Signed-off-by: Roman Kazantsev * Fix clean-up for transformation #2 Signed-off-by: Roman Kazantsev Co-authored-by: Roman Kazantsev --- model-optimizer/automation/package_BOM.txt | 4 +- .../back/SpecialNodesFinalization.py | 37 +++ .../extensions/front/ATenToEmbeddingBag.py | 62 ++++- .../front/ATenToEmbeddingBag_test.py | 164 ++++++++++--- .../extensions/front/onnx/aten_ext.py | 3 +- .../extensions/front/tf/WhereDecomposition.py | 4 +- ...ghted_sum.py => embedding_segments_sum.py} | 118 +++++++--- .../front/tf/embedding_segments_sum_test.py | 218 ++++++++++++++++++ .../front/tf/sparse_weighted_sum_test.py | 173 -------------- .../extensions/middle/EmbeddingBagResolver.py | 111 --------- .../middle/EmbeddingBagResolver_test.py | 127 ---------- .../extensions/middle/sparse_reshape.py | 44 ++-- model-optimizer/extensions/ops/aten.py | 2 +- .../extensions/ops/embedding_bag.py | 104 ++++++--- .../extensions/ops/embedding_bag_test.py | 89 +++++++ .../extensions/ops/sparse_reshape.py | 5 +- .../extensions/ops/sparse_weighted_sum.py | 57 ----- .../ops/sparse_weighted_sum_test.py | 67 ------ 18 files changed, 735 insertions(+), 654 deletions(-) rename model-optimizer/extensions/front/tf/{sparse_weighted_sum.py => embedding_segments_sum.py} (56%) create mode 100644 model-optimizer/extensions/front/tf/embedding_segments_sum_test.py delete mode 100644 model-optimizer/extensions/front/tf/sparse_weighted_sum_test.py delete mode 100644 model-optimizer/extensions/middle/EmbeddingBagResolver.py delete mode 100644 model-optimizer/extensions/middle/EmbeddingBagResolver_test.py create mode 100644 model-optimizer/extensions/ops/embedding_bag_test.py delete mode 100644 model-optimizer/extensions/ops/sparse_weighted_sum.py delete mode 100644 model-optimizer/extensions/ops/sparse_weighted_sum_test.py diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 969713807628de..94602d6cc51950 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -367,6 +367,7 @@ extensions/front/tf/cumsum_ext.py extensions/front/tf/deconv_ext.py extensions/front/tf/depth_to_space.py extensions/front/tf/elementwise_ext.py +extensions/front/tf/embedding_segments_sum.py extensions/front/tf/expand_dims_ext.py extensions/front/tf/extract_image_patches_ext.py extensions/front/tf/fake_const_ext.py @@ -441,7 +442,6 @@ extensions/front/tf/sparse_segment_mean_ext.py extensions/front/tf/sparse_segment_sqrtn_ext.py extensions/front/tf/sparse_segment_sum_ext.py extensions/front/tf/sparse_to_dense_ext.py -extensions/front/tf/sparse_weighted_sum.py extensions/front/tf/split_ext.py extensions/front/tf/ssd_support.json extensions/front/tf/ssd_support_api_v1.14.json @@ -522,7 +522,6 @@ extensions/middle/DilatedConvolution.py extensions/middle/EltwiseChecker.py extensions/middle/EltwiseInputNormalization.py extensions/middle/EltwiseInputReshape.py -extensions/middle/EmbeddingBagResolver.py extensions/middle/FakeSplitOutputs.py extensions/middle/FusedBatchNormNonConstant.py extensions/middle/FusedBatchNormTraining.py @@ -686,7 +685,6 @@ extensions/ops/sparse_segment_mean.py extensions/ops/sparse_segment_sqrtn.py extensions/ops/sparse_segment_sum.py extensions/ops/sparse_to_dense.py -extensions/ops/sparse_weighted_sum.py extensions/ops/spatial_transformer.py extensions/ops/splice.py extensions/ops/split.py diff --git a/model-optimizer/extensions/back/SpecialNodesFinalization.py b/model-optimizer/extensions/back/SpecialNodesFinalization.py index 09422d3e8a9a73..020ef5f28865d4 100644 --- a/model-optimizer/extensions/back/SpecialNodesFinalization.py +++ b/model-optimizer/extensions/back/SpecialNodesFinalization.py @@ -130,6 +130,43 @@ def find_and_replace_pattern(self, graph: Graph): graph.remove_node(node.id) +class RemoveConstToResult(BackReplacementPattern): + """ + Transformation looks for a sub-graph "Const->Result" and removes Result node. + Currently IE is unable to handle such graph so this transformation removes to work around this case. + For instance, this case appears for Wide and Deep model. + """ + enabled = True + + @staticmethod + def pattern(): + return dict( + nodes=[ + ('const_node', {'type': 'Const', 'kind': 'op'}), + ('const_data', {'kind': 'data'}), + ('result_node', {'type': 'Result', 'kind': 'op'}), + ], + edges=[ + ('const_node', 'const_data'), + ('const_data', 'result_node') + ] + ) + + @staticmethod + def replace_pattern(graph: Graph, match: dict): + const_node = match['const_node'] + const_data_node = match['const_data'] + result_node = match['result_node'] + nodes_to_remove = [result_node.id] + + # in case only const data consumer that is the result node, remove the whole sub-graph + if len(const_node.out_port(0).get_destinations()) == 1: + nodes_to_remove.append(const_node.id) + nodes_to_remove.append(const_data_node.id) + + graph.remove_node(nodes_to_remove) + + class NormalizeTI(BackReplacementPattern): """ This transformation is used while generating IR of lower than 10 version diff --git a/model-optimizer/extensions/front/ATenToEmbeddingBag.py b/model-optimizer/extensions/front/ATenToEmbeddingBag.py index 5c0cc434add802..8da70dc5c0d164 100644 --- a/model-optimizer/extensions/front/ATenToEmbeddingBag.py +++ b/model-optimizer/extensions/front/ATenToEmbeddingBag.py @@ -14,9 +14,18 @@ limitations under the License. """ -from extensions.ops.embedding_bag import EmbeddingBag +from extensions.ops.embedding_bag import EmbeddingBagOffsetsSum, EmbeddingBagPackedSum +from extensions.ops.rank import Rank +from mo.front.common.partial_infer.utils import int64_array from mo.front.common.replacement import FrontReplacementPattern +from mo.front.tf.graph_utils import create_op_with_const_inputs from mo.graph.graph import Graph, rename_node +from mo.ops.broadcast import Broadcast +from mo.ops.concat import Concat +from mo.ops.shape import Shape +from mo.ops.unsqueeze import Unsqueeze +from mo.utils.shape import node_to_get_shape_value_of_indices, get_canonical_axis_index_node, \ + get_shape_values_by_indices_node class AtenToEmbeddingBag(FrontReplacementPattern): @@ -27,11 +36,52 @@ class AtenToEmbeddingBag(FrontReplacementPattern): def find_and_replace_pattern(self, graph: Graph): for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'): - node_name = node.name - rename_node(node, node_name + '/Old') - embedding_bag = EmbeddingBag(graph, {'name': node_name, 'mode': node.mode, - 'scale_grad_by_freq': node.scale_grad_by_freq}).create_node() + assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \ + 'mode is supported for node {}.'.format(node.id) + node_name = node.soft_get('name', node.id) + rename_node(node, node_name + '/TBR') + is_packed = False + if len(node.in_ports()) < 3 or node.in_port(2).disconnected(): + is_packed = True + embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node() + else: + embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node() + node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2)) + rename_node(embedding_bag, node_name) node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0)) node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1)) - node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2)) node.out_port(0).get_connection().set_source(embedding_bag.out_port(0)) + if len(node.in_ports()) == 4 and not node.in_port(3).disconnected(): + if is_packed: + node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2)) + else: + # connect per_sample_weights + node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4)) + + weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node() + + weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node() + last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1) + weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node) + + weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0]) + + zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])}, + {'name': node_name + '/Broadcast'}) + zero_col_node.in_port(1).connect(weights_last_dim.out_port(0)) + + default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)}, + {'name': node_name + '/Unsqueeze'}) + default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0)) + + # expand embedding table with zeros + weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2, + 'name': node_name + '/Concat'}).create_node() + embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0)) + weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0)) + weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0)) + weights_concat.in_port(1).connect(default_embeddings_node.out_port(0)) + weights_concat.out_port(0).connect(embedding_bag.in_port(0)) + + # point default index to expanded part of embedding table + weights_first_dim.out_port(0).connect(embedding_bag.in_port(3)) diff --git a/model-optimizer/extensions/front/ATenToEmbeddingBag_test.py b/model-optimizer/extensions/front/ATenToEmbeddingBag_test.py index 3873eee29f3f07..bade2da38dbc9d 100644 --- a/model-optimizer/extensions/front/ATenToEmbeddingBag_test.py +++ b/model-optimizer/extensions/front/ATenToEmbeddingBag_test.py @@ -16,47 +16,147 @@ import unittest +import numpy as np + from extensions.front.ATenToEmbeddingBag import AtenToEmbeddingBag +from mo.front.common.partial_infer.utils import int64_array from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.unittest.graph import build_graph - -nodes_attributes = { - 'weights_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'indices_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'offsets_inp': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'aten': {'type': None, 'kind': 'op', 'op': 'ATen', 'mode': 0, 'operator': 'embedding_bag', 'name': 'my_aten', - 'scale_grad_by_freq': 0}, - 'result': {'type': 'Result', 'value': None, 'kind': 'op', 'op': 'Result'}, - - # new EmbeddingBag layer - 'emb_bag': {'type': None, 'kind': 'op', 'op': 'EmbeddingBag', 'mode': 0, 'scale_grad_by_freq': 0}, -} +from mo.utils.unittest.graph import build_graph, result, \ + regular_op, const class AtenToEmbeddingBagTest(unittest.TestCase): def test(self): - graph = build_graph(nodes_attributes, - [('weights_inp', 'aten', {'in': 0, 'out': 0}), - ('indices_inp', 'aten', {'in': 1, 'out': 0}), - ('offsets_inp', 'aten', {'in': 2, 'out': 0}), - ('aten', 'result', {'in': 0, 'out': 0}), - ], - {}, nodes_with_edges_only=True) - - graph_ref = build_graph(nodes_attributes, - [('weights_inp', 'emb_bag', {'in': 0, 'out': 0}), - ('indices_inp', 'emb_bag', {'in': 1, 'out': 0}), - ('offsets_inp', 'emb_bag', {'in': 2, 'out': 0}), - ('emb_bag', 'result', {'in': 0, 'out': 0}), - ], - {}, nodes_with_edges_only=True) + nodes = { + **const('weights_inp', np.random.randn(100, 2)), + **regular_op('indices_inp', {'type': 'Parameter'}), + **regular_op('offsets_inp', {'type': 'Parameter'}), + **regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0, + 'name': 'my_aten'}), + + **regular_op('emb_bag', {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op', 'op': 'EmbeddingBagOffsetsSum'}), + **result('result'), + } + edges = [('weights_inp', 'aten'), + ('indices_inp', 'aten'), + ('offsets_inp', 'aten'), + ('aten', 'result'), + ] + graph = build_graph(nodes, edges) + + graph.graph['layout'] = 'NCHW' + graph.stage = 'front' + + edges_ref = [('weights_inp', 'emb_bag'), + ('indices_inp', 'emb_bag'), + ('offsets_inp', 'emb_bag'), + ('emb_bag', 'result'), + ] + + graph_ref = build_graph(nodes, edges_ref) + + AtenToEmbeddingBag().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + def test_packed(self): + nodes = { + **const('weights_inp', np.random.randn(100, 4)), + **regular_op('indices_inp', {'type': 'Parameter'}), + **regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0, + 'name': 'my_aten'}), + + **regular_op('emb_bag', {'type': 'EmbeddingBagPackedSum', 'kind': 'op', + 'op': 'EmbeddingBagPackedSum'}), + **result('result'), + } + edges = [('weights_inp', 'aten'), + ('indices_inp', 'aten'), + ('aten', 'result'), + ] + graph = build_graph(nodes, edges) graph.graph['layout'] = 'NCHW' graph.stage = 'front' - replacer = AtenToEmbeddingBag() - replacer.find_and_replace_pattern(graph) + edges_ref = [('weights_inp', 'emb_bag'), + ('indices_inp', 'emb_bag'), + ('emb_bag', 'result'), + ] + + graph_ref = build_graph(nodes, edges_ref) + + AtenToEmbeddingBag().find_and_replace_pattern(graph) + + (flag, resp) = compare_graphs(graph, graph_ref, 'result') + self.assertTrue(flag, resp) + + def test_per_sample_weights(self): + nodes = { + **const('weights_inp', np.random.randn(100, 2)), + **regular_op('indices_inp', {'type': 'Parameter'}), + **regular_op('offsets_inp', {'type': 'Parameter'}), + **regular_op('per_sample_weights', {'type': 'Parameter'}), + **regular_op('aten', {'type': None, 'kind': 'op', 'op': 'ATen', 'operator': 'embedding_bag', 'mode': 0, + 'name': 'my_aten'}), + + **regular_op('emb_bag', {'type': 'EmbeddingBagOffsetsSum', 'kind': 'op', + 'op': 'EmbeddingBagOffsetsSum'}), + **regular_op('WeightsRank', {'type': None, 'kind': 'op', 'op': 'Rank'}), + **regular_op('WeightsRank/axis', {'type': 'Add', 'kind': 'op', 'op': 'Add'}), + **regular_op('gather1', {'type': 'Gather', 'kind': 'op', 'op': 'Gather'}), + **regular_op('gather2', {'type': 'Gather', 'kind': 'op', 'op': 'Gather'}), + **regular_op('WeightsShape', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}), + **regular_op('Broadcast', {'type': 'Broadcast', 'kind': 'op', 'op': 'Broadcast'}), + **regular_op('Unsqueeze', {'type': 'Unsqueeze', 'kind': 'op', 'op': 'Unsqueeze'}), + **const('WeightsShape/Axis', int64_array(0)), + **const('zero1', int64_array(0)), + **const('zero2', int64_array(0)), + **const('Unsqueeze/value', int64_array(0)), + **const('Broadcast/value', int64_array(0)), + **const('neg', int64_array(-1)), + **regular_op('Concat', {'type': 'Concat', 'kind': 'op', 'op': 'Concat'}), + **result('result'), + } + edges = [('weights_inp', 'aten'), + ('indices_inp', 'aten'), + ('offsets_inp', 'aten'), + ('per_sample_weights', 'aten'), + ('aten', 'result'), + ] + graph = build_graph(nodes, edges, nodes_with_edges_only=True) + + graph.graph['layout'] = 'NCHW' + graph.stage = 'front' + + edges_ref = [('weights_inp', 'Concat', {'in': 0, 'out': 0}), + ('weights_inp', 'WeightsShape', {'in': 0, 'out': 0}), + ('weights_inp', 'WeightsRank', {'in': 0, 'out': 0}), + ('WeightsRank', 'WeightsRank/axis'), + ('neg', 'WeightsRank/axis'), + ('WeightsShape', 'gather1', {'in': 0, 'out': 0}), + ('WeightsRank/axis', 'gather1'), + ('WeightsShape/Axis', 'gather1'), + ('WeightsShape', 'gather2', {'in': 0, 'out': 0}), + ('zero1', 'gather2'), + ('zero2', 'gather2'), + ('Broadcast/value', 'Broadcast'), + ('gather1', 'Broadcast'), + ('Broadcast', 'Unsqueeze'), + ('Unsqueeze/value', 'Unsqueeze'), + ('Unsqueeze', 'Concat'), + ('Concat', 'emb_bag'), + ('indices_inp', 'emb_bag'), + ('offsets_inp', 'emb_bag'), + ('gather2', 'emb_bag'), + ('per_sample_weights', 'emb_bag'), + ('emb_bag', 'result'), + ] + + graph_ref = build_graph(nodes, edges_ref, nodes_with_edges_only=True) + + AtenToEmbeddingBag().find_and_replace_pattern(graph) - (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True) + (flag, resp) = compare_graphs(graph, graph_ref, 'result') self.assertTrue(flag, resp) - self.assertTrue(graph.node[graph.get_nodes_with_attributes(op='EmbeddingBag')[0]]['name'] == 'my_aten') diff --git a/model-optimizer/extensions/front/onnx/aten_ext.py b/model-optimizer/extensions/front/onnx/aten_ext.py index f55b341a16e0a1..502830f406b68a 100644 --- a/model-optimizer/extensions/front/onnx/aten_ext.py +++ b/model-optimizer/extensions/front/onnx/aten_ext.py @@ -26,7 +26,6 @@ class ATenFrontExtractor(FrontExtractorOp): def extract(cls, node): mode = onnx_attr(node, 'mode', 'i', default=1) operator = onnx_attr(node, 'operator', 's').decode() - scale_grad_by_freq = onnx_attr(node, 'scale_grad_by_freq', 'i', default=0) - ATen.update_node_stat(node, {'operator': operator, 'mode': mode, 'scale_grad_by_freq': scale_grad_by_freq}) + ATen.update_node_stat(node, {'operator': operator, 'mode': mode}) return cls.enabled diff --git a/model-optimizer/extensions/front/tf/WhereDecomposition.py b/model-optimizer/extensions/front/tf/WhereDecomposition.py index 656c015b0cf386..a8e513be399f5c 100644 --- a/model-optimizer/extensions/front/tf/WhereDecomposition.py +++ b/model-optimizer/extensions/front/tf/WhereDecomposition.py @@ -33,9 +33,9 @@ class WhereDecomposition(FrontReplacementOp): enabled = True def run_after(self): - from extensions.front.tf.sparse_weighted_sum import ExperimentalSparseWeightedSumFrontReplacer + from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2 from extensions.front.TransposeOrderNormalizer import TransposeOrderNormalizer - return [ExperimentalSparseWeightedSumFrontReplacer, TransposeOrderNormalizer] + return [EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2, TransposeOrderNormalizer] def replace_op(self, graph: Graph, node: Node): node_name = node.soft_get('name', node.id) diff --git a/model-optimizer/extensions/front/tf/sparse_weighted_sum.py b/model-optimizer/extensions/front/tf/embedding_segments_sum.py similarity index 56% rename from model-optimizer/extensions/front/tf/sparse_weighted_sum.py rename to model-optimizer/extensions/front/tf/embedding_segments_sum.py index 7ea0bb0d28c2cc..b033ae7cd2cbbe 100644 --- a/model-optimizer/extensions/front/tf/sparse_weighted_sum.py +++ b/model-optimizer/extensions/front/tf/embedding_segments_sum.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2018-2020 Intel Corporation + Copyright (C) 2020 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,13 +15,19 @@ """ import logging as log +import numpy as np -from extensions.ops.sparse_weighted_sum import ExperimentalSparseWeightedSum +from extensions.ops.Cast import Cast +from extensions.ops.embedding_bag import EmbeddingSegmentsSum +from extensions.ops.split import Split +from mo.front.common.partial_infer.utils import int64_array from mo.front.common.replacement import FrontReplacementSubgraph -from mo.graph.graph import Graph +from mo.front.tf.graph_utils import create_op_with_const_inputs +from mo.graph.graph import Graph, rename_nodes +from mo.ops.squeeze import Squeeze -class ExperimentalSparseWeightedSumFrontReplacer(FrontReplacementSubgraph): +class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph): """ The transformation looks for pattern (sub-graph) that performs extraction of embedding vectors from the parameters table for object feature values and sum up these embedding vectors for every object. @@ -30,7 +36,6 @@ class ExperimentalSparseWeightedSumFrontReplacer(FrontReplacementSubgraph): enabled = True def pattern(self): - log.debug('Enabled ExperimentalSparseWeightedSum replacement') return dict( nodes=[ ('identity_spw', dict(op='Identity')), @@ -76,18 +81,45 @@ def replace_sub_graph(self, graph: Graph, match: dict): gather0_2 = match['gather0_2'] greaterequal0 = match['greaterequal0'] sparse_fill_empty_rows = match['sparse_fill_empty_rows'] - where0 = match['where0'] gather = match['gather'] select = match['select'] - - log.debug('Found ExperimentalSparseWeightedSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op, sparse_fill_empty_rows.name)) - - sparse_weighted_sum = ExperimentalSparseWeightedSum(graph, {'name': sparse_fill_empty_rows.name + '/ExperimentalSparseWeightedSum_'}).create_node() - gather0_1.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(0)) - greaterequal0.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(1)) - identity_spw.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(2)) - gather.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(3)) - sparse_fill_empty_rows.in_port(3).get_connection().set_destination(sparse_weighted_sum.in_port(4)) + where0 = match['where0'] + output_node_name = select.soft_get('name', select.id) + + log.debug('Found EmbeddingSegmentsSum pattern after {} with name {}'.format(sparse_fill_empty_rows.op, + sparse_fill_empty_rows.name)) + + split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)}, {'num_splits': 2}) + squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])}) + split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2}) + squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])}) + cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds', 'dst_type': np.int32}).create_node() + cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node() + cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber', 'dst_type': np.int32}).create_node() + embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node() + rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_sum, output_node_name)]) + + # connect parameters table + gather.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(0)) + # connect indices values + greaterequal0.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(1)) + # split and connect segment ids + gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0)) + squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0)) + # TODO: remove casting once we start to support I64 model input + cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0)) + embedding_segments_sum.in_port(2).connect(cast_segment_ids.out_port(0)) + # split and connect number of segments + identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0)) + squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0)) + # TODO: remove casting once we start to support I64 model input + cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0)) + embedding_segments_sum.in_port(3).connect(cast_num_segments.out_port(0)) + # connect default value + # TODO: remove casting once we start to support I64 model input + sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0)) + embedding_segments_sum.in_port(4).connect(cast_default_value.out_port(0)) + # no input port for per_sample_weight identity_spw.in_port(0).disconnect() gather0_1.in_port(0).disconnect() @@ -96,11 +128,11 @@ def replace_sub_graph(self, graph: Graph, match: dict): sparse_fill_empty_rows.in_port(2).disconnect() gather.in_port(0).disconnect() - select.out_port(0).get_connection().set_source(sparse_weighted_sum.out_port(0)) + select.out_port(0).get_connection().set_source(embedding_segments_sum.out_port(0)) graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id]) -class ExperimentalSparseWeightedSumFrontReplacer2(FrontReplacementSubgraph): +class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph): """ The transformation looks for pattern (sub-graph) that performs extraction of embedding vectors from the parameters table for object feature values and sum up these embedding vectors for every object. @@ -109,7 +141,6 @@ class ExperimentalSparseWeightedSumFrontReplacer2(FrontReplacementSubgraph): enabled = True def pattern(self): - log.debug('Enabled ExperimentalSparseWeightedSum2 replacement') return dict( nodes=[ ('identity_spw', dict(op='Identity')), @@ -162,15 +193,46 @@ def replace_sub_graph(self, graph: Graph, match: dict): gather = match['gather'] select = match['select'] where0 = match['where0'] - - log.debug('Found ExperimentalSparseWeightedSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op, sparse_fill_empty_rows.name)) - - sparse_weighted_sum = ExperimentalSparseWeightedSum(graph, {'name': sparse_fill_empty_rows.name + '/ExperimentalSparseWeightedSum_'}).create_node() - gather0_1.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(0)) - greaterequal0.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(1)) - identity_spw.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(2)) - gather.in_port(0).get_connection().set_destination(sparse_weighted_sum.in_port(3)) - sparse_fill_empty_rows.in_port(3).get_connection().set_destination(sparse_weighted_sum.in_port(4)) + output_node_name = select.soft_get('name', select.id) + + log.debug('Found EmbeddingSegmentsSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op, + sparse_fill_empty_rows.name)) + + split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)}, + {'num_splits': 2, + 'name': output_node_name + '/SplitForIndices'}) + squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])}) + split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, + {'num_splits': 2, + 'name': output_node_name + '/SplitForDenseShape'}) + squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])}) + cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds', 'dst_type': np.int32}).create_node() + cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node() + cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber', 'dst_type': np.int32}).create_node() + embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node() + rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_sum, output_node_name)]) + + # connect parameters table + gather.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(0)) + # connect indices values + greaterequal0.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(1)) + # split and connect segment ids + gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0)) + squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0)) + # TODO: remove casting once we start to support I64 model input + cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0)) + embedding_segments_sum.in_port(2).connect(cast_segment_ids.out_port(0)) + # split and connect number of segments + identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0)) + squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0)) + # TODO: remove casting once we start to support I64 model input + cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0)) + embedding_segments_sum.in_port(3).connect(cast_num_segments.out_port(0)) + # connect default value + # TODO: remove casting once we start to support I64 model input + sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0)) + embedding_segments_sum.in_port(4).connect(cast_default_value.out_port(0)) + # no input port for per_sample_weight identity_spw.in_port(0).disconnect() gather0_1.in_port(0).disconnect() @@ -179,5 +241,5 @@ def replace_sub_graph(self, graph: Graph, match: dict): sparse_fill_empty_rows.in_port(2).disconnect() gather.in_port(0).disconnect() - select.out_port(0).get_connection().set_source(sparse_weighted_sum.out_port(0)) + select.out_port(0).get_connection().set_source(embedding_segments_sum.out_port(0)) graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id]) diff --git a/model-optimizer/extensions/front/tf/embedding_segments_sum_test.py b/model-optimizer/extensions/front/tf/embedding_segments_sum_test.py new file mode 100644 index 00000000000000..daed2c76b9e8a0 --- /dev/null +++ b/model-optimizer/extensions/front/tf/embedding_segments_sum_test.py @@ -0,0 +1,218 @@ +""" + Copyright (C) 2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest + +from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2 +from mo.front.common.partial_infer.utils import int64_array +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph, const + + +class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase): + def test1(self): + nodes_attributes = { + 'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + + 'identity_spw': {'kind': 'op', 'op': 'Identity'}, + 'gather0_1': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, + 'gather0_2': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, + 'reshape0': {'kind': 'op', 'op': 'Reshape'}, + 'where0': {'kind': 'op', 'op': 'Where'}, + 'greaterequal0': {'kind': 'op', 'op': 'GreaterEqual'}, + 'sparse_fill_empty_rows': {'kind': 'op', 'op': 'SparseFillEmptyRows'}, + 'unique': {'kind': 'op', 'op': 'Unique'}, + 'strided_slice': {'kind': 'op', 'op': 'StridedSlice'}, + 'cast': {'kind': 'op', 'op': 'Cast'}, + 'gather': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, + 'sparse_segment_sum': {'kind': 'op', 'op': 'SparseSegmentSum'}, + 'reshape': {'kind': 'op', 'op': 'Reshape'}, + 'tile': {'kind': 'op', 'op': 'Tile', 'type': 'Tile'}, + 'select': {'kind': 'op', 'op': 'Select'}, + + 'split_for_indices': {'kind': 'op', 'op': 'Split'}, + 'squeeze_for_indices': {'kind': 'op', 'op': 'Squeeze'}, + 'split_for_dense_shape': {'kind': 'op', 'op': 'Split'}, + 'squeeze_to_scalar': {'kind': 'op', 'op': 'Squeeze'}, + 'cast_segment_ids': {'kind': 'op', 'op': 'Cast'}, + 'cast_default_value': {'kind': 'op', 'op': 'Cast'}, + 'cast_number_segments': {'kind': 'op', 'op': 'Cast'}, + 'embedding_segments_sum': {'kind': 'op', 'op': 'EmbeddingSegmentsSum'}, + + **const('split_for_indices_axis', int64_array(1)), + **const('split_for_dense_shape_axis', int64_array(0)), + **const('squeeze_axis', int64_array([0])), + **const('squeeze_for_indices_axis', int64_array([1])), + + 'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'}, + } + + graph = build_graph(nodes_attributes, + [('input_indices', 'gather0_1', {'out': 0, 'in': 0}), + ('input_dense_shape', 'identity_spw', {'out': 0, 'in': 0}), + ('input_values', 'greaterequal0', {'out': 0, 'in': 0}), + ('input_values', 'gather0_2', {'out': 0, 'in': 0}), + ('input_params_table', 'gather', {'out': 0, 'in': 0}), + ('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}), + + ('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}), + ('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}), + ('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}), + ('reshape0', 'gather0_1', {'out': 0, 'in': 1}), + ('reshape0', 'gather0_2', {'out': 0, 'in': 1}), + ('where0', 'reshape0', {'out': 0, 'in': 0}), + ('greaterequal0', 'where0', {'out': 0, 'in': 0}), + ('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}), + ('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}), + ('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}), + ('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}), + ('unique', 'gather', {'out': 0, 'in': 1}), + ('strided_slice', 'cast', {'out': 0, 'in': 0}), + ('gather', 'sparse_segment_sum', {'out': 0, 'in': 0}), + ('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}), + ('sparse_segment_sum', 'select', {'out': 0, 'in': 2}), + ('reshape', 'tile', {'out': 0, 'in': 0}), + ('tile', 'select', {'out': 0, 'in': 0}), + ('select', 'last', {'out': 0, 'in': 0}), + ], nodes_with_edges_only=True) + graph.stage = 'front' + EmbeddingSegmentsSumFrontReplacer().find_and_replace_pattern(graph) + + graph_ref = build_graph(nodes_attributes, + [('input_indices', 'split_for_indices', {'in': 0}), + ('split_for_indices_axis', 'split_for_indices', {'in': 1}), + ('split_for_indices', 'squeeze_for_indices', {'in': 0}), + ('squeeze_for_indices_axis', 'squeeze_for_indices', {'in': 1}), + ('squeeze_for_indices', 'cast_segment_ids', {'in': 0}), + ('cast_segment_ids', 'embedding_segments_sum', {'in': 2, 'out': 0}), + ('input_values', 'embedding_segments_sum', {'in': 1}), + ('input_dense_shape', 'split_for_dense_shape', {'in': 0}), + ('split_for_dense_shape_axis', 'split_for_dense_shape', {'in': 1}), + ('split_for_dense_shape', 'squeeze_to_scalar', {'in': 0}), + ('squeeze_axis', 'squeeze_to_scalar', {'in': 1}), + ('squeeze_to_scalar', 'cast_number_segments', {'in': 0}), + ('cast_number_segments', 'embedding_segments_sum', {'in': 3, 'out': 0}), + ('input_params_table', 'embedding_segments_sum', {'in': 0}), + ('input_default_value', 'cast_default_value', {'in': 0}), + ('cast_default_value', 'embedding_segments_sum', {'in': 4}), + ('embedding_segments_sum', 'last', {'in': 0}),], + nodes_with_edges_only=True) + + (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True) + self.assertTrue(flag, resp) + + def test2(self): + nodes_attributes = { + 'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + 'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, + + 'identity_spw': {'kind': 'op', 'op': 'Identity'}, + 'gather0_1': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, + 'gather0_2': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, + 'reshape0': {'kind': 'op', 'op': 'Reshape'}, + 'where0': {'kind': 'op', 'op': 'Where'}, + 'greaterequal0': {'kind': 'op', 'op': 'GreaterEqual'}, + 'sparse_fill_empty_rows': {'kind': 'op', 'op': 'SparseFillEmptyRows'}, + 'unique': {'kind': 'op', 'op': 'Unique'}, + 'strided_slice': {'kind': 'op', 'op': 'StridedSlice'}, + 'cast': {'kind': 'op', 'op': 'Cast'}, + 'gather': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, + 'identity': {'kind': 'op', 'op': 'Identity'}, + 'identity_1': {'kind': 'op', 'op': 'Identity'}, + 'sparse_segment_sum': {'kind': 'op', 'op': 'SparseSegmentSum'}, + 'reshape': {'kind': 'op', 'op': 'Reshape'}, + 'tile': {'kind': 'op', 'op': 'Tile', 'type': 'Tile'}, + 'select': {'kind': 'op', 'op': 'Select'}, + + 'split_for_indices': {'kind': 'op', 'op': 'Split'}, + 'squeeze_for_indices': {'kind': 'op', 'op': 'Squeeze'}, + 'split_for_dense_shape': {'kind': 'op', 'op': 'Split'}, + 'squeeze_to_scalar': {'kind': 'op', 'op': 'Squeeze'}, + 'cast_segment_ids': {'kind': 'op', 'op': 'Cast'}, + 'cast_default_value': {'kind': 'op', 'op': 'Cast'}, + 'cast_number_segments': {'kind': 'op', 'op': 'Cast'}, + 'embedding_segments_sum': {'kind': 'op', 'op': 'EmbeddingSegmentsSum'}, + + **const('split_for_indices_axis', int64_array(1)), + **const('split_for_dense_shape_axis', int64_array(0)), + **const('squeeze_axis', int64_array([0])), + **const('squeeze_for_indices_axis', int64_array([1])), + + 'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'}, + } + + graph = build_graph(nodes_attributes, + [('input_indices', 'gather0_1', {'out': 0, 'in': 0}), + ('input_dense_shape', 'identity_spw', {'out': 0, 'in': 0}), + ('input_values', 'greaterequal0', {'out': 0, 'in': 0}), + ('input_values', 'gather0_2', {'out': 0, 'in': 0}), + ('input_params_table', 'gather', {'out': 0, 'in': 0}), + ('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}), + + ('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}), + ('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}), + ('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}), + ('reshape0', 'gather0_1', {'out': 0, 'in': 1}), + ('reshape0', 'gather0_2', {'out': 0, 'in': 1}), + ('where0', 'reshape0', {'out': 0, 'in': 0}), + ('greaterequal0', 'where0', {'out': 0, 'in': 0}), + ('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}), + ('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}), + ('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}), + ('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}), + ('unique', 'gather', {'out': 0, 'in': 1}), + ('strided_slice', 'cast', {'out': 0, 'in': 0}), + ('gather', 'identity', {'out': 0, 'in': 0}), + ('identity', 'identity_1', {'out': 0, 'in': 0}), + ('identity_1', 'sparse_segment_sum', {'out': 0, 'in': 0}), + ('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}), + ('sparse_segment_sum', 'select', {'out': 0, 'in': 2}), + ('reshape', 'tile', {'out': 0, 'in': 0}), + ('tile', 'select', {'out': 0, 'in': 0}), + ('select', 'last', {'out': 0, 'in': 0})], + nodes_with_edges_only=True) + graph.stage = 'front' + EmbeddingSegmentsSumFrontReplacer2().find_and_replace_pattern(graph) + + graph_ref = build_graph(nodes_attributes, + [('input_indices', 'split_for_indices', {'in': 0}), + ('split_for_indices_axis', 'split_for_indices', {'in': 1}), + ('split_for_indices', 'squeeze_for_indices', {'in': 0}), + ('squeeze_for_indices_axis', 'squeeze_for_indices', {'in': 1}), + ('squeeze_for_indices', 'cast_segment_ids', {'in': 0}), + ('cast_segment_ids', 'embedding_segments_sum', {'in': 2, 'out': 0}), + ('input_values', 'embedding_segments_sum', {'in': 1}), + ('input_dense_shape', 'split_for_dense_shape', {'in': 0}), + ('split_for_dense_shape_axis', 'split_for_dense_shape', {'in': 1}), + ('split_for_dense_shape', 'squeeze_to_scalar', {'in': 0}), + ('squeeze_axis', 'squeeze_to_scalar', {'in': 1}), + ('squeeze_to_scalar', 'cast_number_segments', {'in': 0}), + ('cast_number_segments', 'embedding_segments_sum', {'in': 3, 'out': 0}), + ('input_params_table', 'embedding_segments_sum', {'in': 0}), + ('input_default_value', 'cast_default_value', {'in': 0}), + ('cast_default_value', 'embedding_segments_sum', {'in': 4}), + ('embedding_segments_sum', 'last', {'in': 0}),], + nodes_with_edges_only=True) + + (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True) + self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/front/tf/sparse_weighted_sum_test.py b/model-optimizer/extensions/front/tf/sparse_weighted_sum_test.py deleted file mode 100644 index becbac324f93ca..00000000000000 --- a/model-optimizer/extensions/front/tf/sparse_weighted_sum_test.py +++ /dev/null @@ -1,173 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import unittest - -from extensions.front.tf.sparse_weighted_sum import ExperimentalSparseWeightedSumFrontReplacer, \ - ExperimentalSparseWeightedSumFrontReplacer2 -from mo.front.common.partial_infer.utils import int64_array -from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.unittest.graph import build_graph - - -class ExperimentalSparseWeightedSumFrontReplacersTest(unittest.TestCase): - def test1(self): - nodes_attributes = { - 'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - - 'identity_spw': {'kind': 'op', 'op': 'Identity'}, - 'gather0_1': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, - 'gather0_2': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, - 'reshape0': {'kind': 'op', 'op': 'Reshape'}, - 'where0': {'kind': 'op', 'op': 'Where'}, - 'greaterequal0': {'kind': 'op', 'op': 'GreaterEqual'}, - 'sparse_fill_empty_rows': {'kind': 'op', 'op': 'SparseFillEmptyRows'}, - 'unique': {'kind': 'op', 'op': 'Unique'}, - 'strided_slice': {'kind': 'op', 'op': 'StridedSlice'}, - 'cast': {'kind': 'op', 'op': 'Cast'}, - 'gather': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, - 'sparse_segment_sum': {'kind': 'op', 'op': 'SparseSegmentSum'}, - 'reshape': {'kind': 'op', 'op': 'Reshape'}, - 'tile': {'kind': 'op', 'op': 'Tile', 'type': 'Tile'}, - 'select': {'kind': 'op', 'op': 'Select'}, - - 'sparse_weighted_sum': {'kind': 'op', 'op': 'ExperimentalSparseWeightedSum'}, - - 'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'}, - } - - graph = build_graph(nodes_attributes, - [('input_indices', 'gather0_1', {'out': 0, 'in': 0}), - ('input_dense_shape', 'identity_spw', {'out': 0, 'in': 0}), - ('input_values', 'greaterequal0', {'out': 0, 'in': 0}), - ('input_values', 'gather0_2', {'out': 0, 'in': 0}), - ('input_params_table', 'gather', {'out': 0, 'in': 0}), - ('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}), - - ('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}), - ('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}), - ('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}), - ('reshape0', 'gather0_1', {'out': 0, 'in': 1}), - ('reshape0', 'gather0_2', {'out': 0, 'in': 1}), - ('where0', 'reshape0', {'out': 0, 'in': 0}), - ('greaterequal0', 'where0', {'out': 0, 'in': 0}), - ('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}), - ('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}), - ('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}), - ('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}), - ('unique', 'gather', {'out': 0, 'in': 1}), - ('strided_slice', 'cast', {'out': 0, 'in': 0}), - ('gather', 'sparse_segment_sum', {'out': 0, 'in': 0}), - ('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}), - ('sparse_segment_sum', 'select', {'out': 0, 'in': 2}), - ('reshape', 'tile', {'out': 0, 'in': 0}), - ('tile', 'select', {'out': 0, 'in': 0}), - ('select', 'last', {'out': 0, 'in': 0}), - ], nodes_with_edges_only=True) - graph.stage = 'front' - ExperimentalSparseWeightedSumFrontReplacer().find_and_replace_pattern(graph) - - graph_ref = build_graph(nodes_attributes, - [('input_indices', 'sparse_weighted_sum', {'in': 0}), - ('input_values', 'sparse_weighted_sum', {'in': 1}), - ('input_dense_shape', 'sparse_weighted_sum', {'in': 2}), - ('input_params_table', 'sparse_weighted_sum', {'in': 3}), - ('input_default_value', 'sparse_weighted_sum', {'in': 4}), - ('sparse_weighted_sum', 'last', {'in': 0}),], - nodes_with_edges_only=True) - - (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True) - self.assertTrue(flag, resp) - - def test2(self): - nodes_attributes = { - 'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - - 'identity_spw': {'kind': 'op', 'op': 'Identity'}, - 'gather0_1': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, - 'gather0_2': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, - 'reshape0': {'kind': 'op', 'op': 'Reshape'}, - 'where0': {'kind': 'op', 'op': 'Where'}, - 'greaterequal0': {'kind': 'op', 'op': 'GreaterEqual'}, - 'sparse_fill_empty_rows': {'kind': 'op', 'op': 'SparseFillEmptyRows'}, - 'unique': {'kind': 'op', 'op': 'Unique'}, - 'strided_slice': {'kind': 'op', 'op': 'StridedSlice'}, - 'cast': {'kind': 'op', 'op': 'Cast'}, - 'gather': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'}, - 'identity': {'kind': 'op', 'op': 'Identity'}, - 'identity_1': {'kind': 'op', 'op': 'Identity'}, - 'sparse_segment_sum': {'kind': 'op', 'op': 'SparseSegmentSum'}, - 'reshape': {'kind': 'op', 'op': 'Reshape'}, - 'tile': {'kind': 'op', 'op': 'Tile', 'type': 'Tile'}, - 'select': {'kind': 'op', 'op': 'Select'}, - - 'sparse_weighted_sum': {'kind': 'op', 'op': 'ExperimentalSparseWeightedSum'}, - - 'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'}, - } - - graph = build_graph(nodes_attributes, - [('input_indices', 'gather0_1', {'out': 0, 'in': 0}), - ('input_dense_shape', 'identity_spw', {'out': 0, 'in': 0}), - ('input_values', 'greaterequal0', {'out': 0, 'in': 0}), - ('input_values', 'gather0_2', {'out': 0, 'in': 0}), - ('input_params_table', 'gather', {'out': 0, 'in': 0}), - ('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}), - - ('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}), - ('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}), - ('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}), - ('reshape0', 'gather0_1', {'out': 0, 'in': 1}), - ('reshape0', 'gather0_2', {'out': 0, 'in': 1}), - ('where0', 'reshape0', {'out': 0, 'in': 0}), - ('greaterequal0', 'where0', {'out': 0, 'in': 0}), - ('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}), - ('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}), - ('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}), - ('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}), - ('unique', 'gather', {'out': 0, 'in': 1}), - ('strided_slice', 'cast', {'out': 0, 'in': 0}), - ('gather', 'identity', {'out': 0, 'in': 0}), - ('identity', 'identity_1', {'out': 0, 'in': 0}), - ('identity_1', 'sparse_segment_sum', {'out': 0, 'in': 0}), - ('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}), - ('sparse_segment_sum', 'select', {'out': 0, 'in': 2}), - ('reshape', 'tile', {'out': 0, 'in': 0}), - ('tile', 'select', {'out': 0, 'in': 0}), - ('select', 'last', {'out': 0, 'in': 0})], - nodes_with_edges_only=True) - graph.stage = 'front' - ExperimentalSparseWeightedSumFrontReplacer2().find_and_replace_pattern(graph) - - graph_ref = build_graph(nodes_attributes, - [('input_indices', 'sparse_weighted_sum', {'in': 0}), - ('input_values', 'sparse_weighted_sum', {'in': 1}), - ('input_dense_shape', 'sparse_weighted_sum', {'in': 2}), - ('input_params_table', 'sparse_weighted_sum', {'in': 3}), - ('input_default_value', 'sparse_weighted_sum', {'in': 4}), - ('sparse_weighted_sum', 'last', {'in': 0}),], - nodes_with_edges_only=True) - - (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True) - self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/middle/EmbeddingBagResolver.py b/model-optimizer/extensions/middle/EmbeddingBagResolver.py deleted file mode 100644 index 3582677ff2a967..00000000000000 --- a/model-optimizer/extensions/middle/EmbeddingBagResolver.py +++ /dev/null @@ -1,111 +0,0 @@ -""" - Copyright (c) 2018-2019 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" -import logging as log - -import numpy as np - -from extensions.ops.gather import Gather -from extensions.ops.parameter import Parameter -from extensions.ops.sparse_weighted_sum import ExperimentalSparseWeightedSum -from mo.front.common.partial_infer.utils import int64_array -from mo.front.tf.graph_utils import create_op_with_const_inputs -from mo.graph.graph import Graph -from mo.middle.replacement import MiddleReplacementPattern -from mo.ops.concat import Concat - - -class EmbeddingBagResolver(MiddleReplacementPattern): - ''' - Replace EmbeddingBag with Gather or SparseWeightedSum. - If shape of offsets is equal to shape of indices it means that offsets are obsolete because they have to define - "bags" of shape 1 and we can remove offsets and replace EmbeddingBag with just Gather. In another case offsets must - be used and EmbeddingBag can be replaced by SparseWeightedSum, but offsets must be pre-processed. - ''' - enabled = True - force_clean_up = True - - def find_and_replace_pattern(self, graph: Graph): - weighted_sum_nodes = list() - index_shape = None - merge_offsets = True - - for node in graph.get_op_nodes(op='EmbeddingBag'): - weights_shape = node.in_port(0).data.get_shape() - indices_shape = node.in_port(1).data.get_shape() - offsets_shape = node.in_port(2).data.get_shape() - - assert node.scale_grad_by_freq == 0 - - if indices_shape[0] == offsets_shape[0]: - # The simple case when we can replace EmbeddingBag with just Gather and not use offsets node at all - gather = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)}, - {'name': node.name + '/Emb_Bag/Gather_'}) - - node.in_port(0).get_connection().set_destination(gather.in_port(0)) - node.in_port(1).get_connection().set_destination(gather.in_port(1)) - node.out_port(0).get_connection().set_source(gather.out_port(0)) - else: - assert node.mode == 0 - - dense_shape = int64_array([offsets_shape[0], indices_shape[0]]) - default_index = int64_array(weights_shape[0]) - sweightedsum = create_op_with_const_inputs(graph, ExperimentalSparseWeightedSum, - {2: dense_shape, 4: default_index}, - {'name': node.name + '/WeightedSum'}) - if index_shape is None: - index_shape = indices_shape[-1] - else: - merge_offsets = merge_offsets and index_shape == indices_shape[-1] - weighted_sum_nodes.append((sweightedsum, indices_shape[-1])) - - default_embeddings = np.zeros([1, weights_shape[-1]]) - weights_concat = create_op_with_const_inputs(graph, Concat, {1: default_embeddings}, - {'axis': 0, 'in_ports_count': 2}) - node.in_port(0).get_connection().set_destination(weights_concat.in_port(0)) - node.in_port(1).get_connection().set_destination(sweightedsum.in_port(1)) - weights_concat.out_port(0).connect(sweightedsum.in_port(3)) - - node.out_port(0).get_connection().set_source(sweightedsum.out_port(0)) - self.create_offsets_for_weighted_sum(graph, weighted_sum_nodes, merge_offsets, index_shape) - - def create_offsets_for_weighted_sum(self, graph, weighted_sum_nodes, merge_offsets, index_shape): - new_offsets = None - for i, (node, ind_shape) in enumerate(weighted_sum_nodes): - if merge_offsets and len(weighted_sum_nodes) > 1: - # generate single offsets input if possible - if new_offsets is None: - shape = int64_array([len(weighted_sum_nodes), index_shape, 2]) - new_offsets = Parameter(graph, {'name': 'Emb_Bag/offsets', - 'shape': shape, - 'data_type': np.int32}).create_node() - log.error( - 'Pre-process of offsets is needed for generated input "Emb_Bag/offsets" of shape: {}. ' - 'Refer to the documentation on how to convert the ONNX* DLRM model'.format(shape), - extra={'is_warning': True}) - gather = create_op_with_const_inputs(graph, Gather, {1: int64_array(i), 2: int64_array(0)}, - {'name': node.name + '/Gather_'}) - new_offsets.out_port(0).connect(gather.in_port(0)) - gather.out_port(0).connect(node.in_port(0)) - else: - shape = int64_array([ind_shape, 2]) - new_offsets = Parameter(graph, {'name': 'Emb_Bag/offsets{}'.format(i), - 'shape': shape, - 'data_type': np.int32}).create_node() - new_offsets.out_port(0).connect(node.in_port(0)) - log.error( - 'Pre-process of offsets is needed for generated input "Emb_Bag/offsets{}" of shape: {}. ' - 'Refer to the documentation on how to convert the ONNX* DLRM model'.format(i, shape), - extra={'is_warning': True}) diff --git a/model-optimizer/extensions/middle/EmbeddingBagResolver_test.py b/model-optimizer/extensions/middle/EmbeddingBagResolver_test.py deleted file mode 100644 index c3627d076103d6..00000000000000 --- a/model-optimizer/extensions/middle/EmbeddingBagResolver_test.py +++ /dev/null @@ -1,127 +0,0 @@ -""" - Copyright (C) 2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" -import unittest - -import numpy as np - -from extensions.middle.EmbeddingBagResolver import EmbeddingBagResolver -from mo.utils.ir_engine.compare_graphs import compare_graphs -from mo.utils.unittest.graph import build_graph - -nodes_attributes = {'node_1': {'value': None, 'kind': 'op', 'op': 'EmbeddingBag', 'scale_grad_by_freq': 0, 'mode': 0}, - 'node_1_data': {'value': None, 'kind': 'data', 'data_type': None}, - 'node_2': {'value': None, 'kind': 'op', 'op': 'EmbeddingBag', 'scale_grad_by_freq': 0, 'mode': 0}, - 'node_2_data': {'value': None, 'kind': 'data', 'data_type': None}, - 'gather_1': {'type': 'Gather', 'value': None, 'kind': 'op'}, - 'gather_1_data': {'value': None, 'kind': 'data', 'data_type': None}, - 'ws_1': {'type': 'ExperimentalSparseWeightedSum', 'value': None, 'kind': 'op'}, - 'ws_1_data': {'value': None, 'kind': 'data', 'data_type': None}, - 'const_axis': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None, 'shape': None}, - 'axis_data': {'value': None, 'kind': 'data', 'data_type': None}, - 'const_default': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None, 'shape': None}, - 'default_data': {'value': None, 'kind': 'data', 'data_type': None}, - 'const_dense_shape': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None, 'shape': None}, - 'dense_shape_data': {'value': None, 'kind': 'data', 'data_type': None}, - 'const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None, 'shape': None}, - 'const_data': {'value': None, 'kind': 'data', 'data_type': None}, - 'concat': {'type': 'Concat', 'value': None, 'kind': 'op'}, - 'concat_data': {'value': None, 'kind': 'data', 'data_type': None}, - # Placeholders - 'indices': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'indices_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'offsets': {'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}, - 'offsets_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - 'const_weights': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None, 'shape': None}, - 'weights_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None}, - - 'op_output': {'kind': 'op', 'op': 'Result', 'infer': lambda x: None} - } - - -class EmbeddingBagResolverTests(unittest.TestCase): - def test_embedding_bag_to_gather(self): - graph = build_graph(nodes_attributes, - [('const_weights', 'weights_data'), - ('weights_data', 'node_1'), - ('indices', 'indices_data'), - ('indices_data', 'node_1'), - ('offsets', 'offsets_data'), - ('offsets_data', 'node_1'), - ('node_1', 'node_1_data'), - ('node_1_data', 'op_output') - ], - {'indices_data': {'shape': np.array([128])}, - 'offsets_data': {'shape': np.array([128])}}, - nodes_with_edges_only=True) - - graph_ref = build_graph(nodes_attributes, - [('const_weights', 'weights_data'), - ('weights_data', 'gather_1'), - ('indices', 'indices_data'), - ('indices_data', 'gather_1'), - ('const_axis', 'axis_data'), - ('axis_data', 'gather_1'), - ('gather_1', 'gather_1_data'), - ('gather_1_data', 'op_output') - ], - {'indices_data': {'shape': np.array([128])}}, - nodes_with_edges_only=True) - graph.graph['layout'] = 'NCHW' - EmbeddingBagResolver().find_and_replace_pattern(graph) - (flag, resp) = compare_graphs(graph, graph_ref, 'op_output') - self.assertTrue(flag, resp) - - def test_embedding_bag_to_single_weighted_sum(self): - graph = build_graph(nodes_attributes, - [('const_weights', 'weights_data'), - ('weights_data', 'node_1'), - ('indices', 'indices_data'), - ('indices_data', 'node_1'), - ('offsets', 'offsets_data'), - ('offsets_data', 'node_1'), - ('node_1', 'node_1_data'), - ('node_1_data', 'op_output') - ], - {'indices_data': {'shape': np.array([128])}, - 'offsets_data': {'shape': np.array([64])}, - 'weights_data': {'shape': np.array([1024, 16])}}, - nodes_with_edges_only=True) - - graph_ref = build_graph(nodes_attributes, - [('const_weights', 'weights_data'), - ('const', 'const_data'), - ('weights_data', 'concat'), - ('const_data', 'concat'), - ('concat', 'concat_data'), - ('offsets', 'offsets_data'), - ('indices', 'indices_data'), - ('const_default', 'default_data'), - ('const_dense_shape', 'dense_shape_data'), - ('offsets_data', 'ws_1'), - ('indices_data', 'ws_1'), - ('dense_shape_data', 'ws_1'), - ('concat_data', 'ws_1'), - ('default_data', 'ws_1'), - ('ws_1', 'ws_1_data'), - ('ws_1_data', 'op_output') - ], - {'indices_data': {'shape': np.array([128])}, - 'offsets': {'shape': np.array([128, 2])}}, - nodes_with_edges_only=True) - graph.graph['layout'] = 'NCHW' - EmbeddingBagResolver().find_and_replace_pattern(graph) - (flag, resp) = compare_graphs(graph, graph_ref, 'op_output') - self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/middle/sparse_reshape.py b/model-optimizer/extensions/middle/sparse_reshape.py index eeaa381555398a..5db8b162dcd261 100644 --- a/model-optimizer/extensions/middle/sparse_reshape.py +++ b/model-optimizer/extensions/middle/sparse_reshape.py @@ -46,26 +46,34 @@ def replace_pattern(self, graph: Graph, match: dict): input_shape_value = sparse_reshape.in_port(1).data.get_value() output_shape_value = sparse_reshape.out_port(1).data.get_value() + # establish output shape if value of new shape is given as input + new_shape_value = sparse_reshape.in_port(2).data.get_value() + if output_shape_value is None and new_shape_value is not None: + output_shape_value = new_shape_value + if np.count_nonzero(output_shape_value == -1) == 1: + elem = np.prod(input_shape_value) // np.prod(new_shape_value[new_shape_value != -1]) + output_shape_value[output_shape_value == -1] = elem + if input_shape_value is None or output_shape_value is None: raise Error("Input shape and output shape values must be defined for node {}".format(sparse_reshape.id)) if not np.array_equal(input_shape_value, output_shape_value): raise Error("Input shape and output shape values must be equal for node {}".format(sparse_reshape.id)) - input_data_node1 = sparse_reshape.in_node(0) - input_data_node2 = sparse_reshape.in_node(1) - output_data_node1 = sparse_reshape.out_node(0) - output_data_node2 = sparse_reshape.out_node(1) - graph.remove_edge(input_data_node1.id, sparse_reshape.id) - graph.remove_edge(sparse_reshape.id, output_data_node1.id) - graph.remove_edge(input_data_node2.id, sparse_reshape.id) - graph.remove_edge(sparse_reshape.id, output_data_node2.id) - merge_data_nodes(graph, output_data_node1, input_data_node1) - merge_data_nodes(graph, output_data_node2, input_data_node2) - graph.remove_nodes_from([sparse_reshape.id, input_data_node1.id, input_data_node2.id]) - - # TODO: investigate why this second way does not work - #sparse_reshape.out_port(0).get_connection().set_source(sparse_reshape.in_port(0).get_source()) - #sparse_reshape.out_port(1).get_connection().set_source(sparse_reshape.in_port(1).get_source()) - #sparse_reshape.in_port(0).get_connection().set_destination(sparse_reshape.in_port(0) - #sparse_reshape.in_port(2).disconnect() - #graph.remove_nodes_from([sparse_reshape.id]) + nodes_to_remove = [sparse_reshape.id] + if sparse_reshape.is_out_port_connected(0): + sparse_reshape.out_port(0).get_connection().set_source(sparse_reshape.in_port(0).get_source()) + output_data_node = sparse_reshape.out_node(0) + nodes_to_remove.append(output_data_node.id) + else: + input_data_node = sparse_reshape.in_node(0) + nodes_to_remove.append(input_data_node.id) + + if sparse_reshape.is_out_port_connected(1): + sparse_reshape.out_port(1).get_connection().set_source(sparse_reshape.in_port(1).get_source()) + output_data_node = sparse_reshape.out_node(1) + nodes_to_remove.append(output_data_node.id) + else: + input_data_node = sparse_reshape.in_node(1) + nodes_to_remove.append(input_data_node.id) + + graph.remove_nodes_from(nodes_to_remove) diff --git a/model-optimizer/extensions/ops/aten.py b/model-optimizer/extensions/ops/aten.py index f1c68f9aec7efd..92a045d2340cc4 100644 --- a/model-optimizer/extensions/ops/aten.py +++ b/model-optimizer/extensions/ops/aten.py @@ -31,4 +31,4 @@ def __init__(self, graph: Graph, attrs: dict): }, attrs) def supported_attrs(self): - return ['mode', 'operator', 'scale_grad_by_freq'] + return ['mode', 'operator'] diff --git a/model-optimizer/extensions/ops/embedding_bag.py b/model-optimizer/extensions/ops/embedding_bag.py index 28d44fa776ffd7..1d83db0a6375a3 100644 --- a/model-optimizer/extensions/ops/embedding_bag.py +++ b/model-optimizer/extensions/ops/embedding_bag.py @@ -14,53 +14,105 @@ limitations under the License. """ -from mo.front.common.partial_infer.utils import int64_array +import numpy as np + from mo.graph.graph import Node, Graph from mo.ops.op import Op -class EmbeddingBag(Op): - ''' - This is nn.EmbeddingBag from Pytorch. It is a simple lookup table that stores embeddings of a fixed dictionary and - size and computes sums or means of "bags" of embeddings, without instantiating the intermediate embeddings. - Inputs: - 0: Weights (num_embeddings, embedding_dim) - the lookup table - 1: Indices (N,) - indices to get from lookup table - 2: Offsets (B,) - index in indices tensor on which each bag starts - Output: - 0: Embeddings (B, embedding_dim) - ''' - op = 'EmbeddingBag' +class EmbeddingBagBase(Op): enabled = False + op = op_type = None + version = None + in_ports_count = None + def __init__(self, graph: Graph, attrs: dict): super().__init__(graph, { 'op': self.op, - 'type': None, + 'type': self.op_type, + 'version': self.version, 'infer': self.infer, - 'in_ports_count': 3, + 'in_ports_count': self.in_ports_count, 'out_ports_count': 1, }, attrs) - def supported_attrs(self): - return ['mode', 'scale_grad_by_freq'] + @staticmethod + def infer(node: Node): + raise NotImplementedError('Please use specialized EmbeddingBag operation class, EmbeddingBagBase is base class') + + +class EmbeddingBagOffsetsSum(EmbeddingBagBase): + op = op_type = 'EmbeddingBagOffsetsSum' + version = 'opset3' + in_ports_count = 5 @staticmethod def infer(node: Node): name = node.soft_get('name', node.id) connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()} - assert len(connected_in_ports) == 3 and 0 in connected_in_ports and 1 in connected_in_ports and \ - 2 in connected_in_ports, "EmbeddingBag should have 3 connected input port, but it doesn't for " \ - "node: `{}`. Ports: {}".format(name, connected_in_ports) + assert len(connected_in_ports) >= 3 and all(p in connected_in_ports for p in [0, 1, 2]), \ + "EmbeddingBagOffsetsSum should have at least 3 connected input port, but it doesn't " \ + "for node: `{}`. Ports: {}".format(name, connected_in_ports) - weights = node.in_port(0).data.get_value() - assert weights is not None and len(weights.shape) == 2 - input_shape = node.in_port(1).data.get_shape() - assert input_shape is not None + weights_shape = node.in_port(0).data.get_shape() + assert len(weights_shape) >= 2,\ + "EmbeddingBagOffsetsSum should have at least 2D weights for node: `{}`".format(name) offsets_shape = node.in_port(2).data.get_shape() - assert offsets_shape is not None and len(offsets_shape) == 1 + assert offsets_shape is not None and len(offsets_shape) == 1,\ + "Rank of the offsets in EmbeddingBagOffsetsSum should be equal to 1 for node: `{}`".format(name) + + node.out_port(0).data.set_shape(np.concatenate((offsets_shape[:1], weights_shape[1:]))) + + +class EmbeddingBagPackedSum(EmbeddingBagBase): + op = op_type = 'EmbeddingBagPackedSum' + version = 'opset3' + in_ports_count = 3 + + @staticmethod + def infer(node: Node): + name = node.soft_get('name', node.id) + + connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()} + assert len(connected_in_ports) >= 2 and all(p in connected_in_ports for p in [0, 1]), \ + "EmbeddingBagPackedSum should have at least 2 connected input port, but it doesn't for node: `{}`. " \ + "Ports: {}".format(name, connected_in_ports) + + weights_shape = node.in_port(0).data.get_shape() + assert len(weights_shape) >= 2, \ + "EmbeddingBagPackedSum should have at least 2D weights for node: `{}`".format(name) + input_shape = node.in_port(1).data.get_shape() + + node.out_port(0).data.set_shape(np.concatenate((input_shape[:1], weights_shape[1:]))) + + +class EmbeddingSegmentsSum(EmbeddingBagBase): + op = op_type = 'EmbeddingSegmentsSum' + version = 'opset3' + in_ports_count = 6 + + @staticmethod + def infer(node: Node): + name = node.soft_get('name', node.id) + + connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()} + assert len(connected_in_ports) >= 4 and all(p in connected_in_ports for p in [0, 1, 2, 3]), \ + "EmbeddingSegmentsSum should have at least 4 connected input port, but it doesn't for node: `{}`. " \ + "Ports: {}".format(name, connected_in_ports) - node.out_port(0).data.set_shape(int64_array([offsets_shape[0], weights.shape[1]])) + weights_shape = node.in_port(0).data.get_shape() + assert len(weights_shape) >= 2,\ + "EmbeddingSegmentsSum should have at least 2D weights for node: `{}`".format(name) + indices_shape = node.in_port(1).data.get_shape() + segment_ids = node.in_port(2).data.get_shape() + assert len(indices_shape) == 1 and len(segment_ids) == 1 and indices_shape == segment_ids,\ + "Both indices and segment_ids should have the same shape for node: `{}`".format(name) + num_segments = node.in_port(3).data.get_value() + assert num_segments is not None, "EmbeddingSegmentsSum should have a constant num_segments provided, but it " \ + "doesn't for node: `{}`.".format(name) + output_shape = np.concatenate(([num_segments], weights_shape[1:])) + node.out_port(0).data.set_shape(output_shape) diff --git a/model-optimizer/extensions/ops/embedding_bag_test.py b/model-optimizer/extensions/ops/embedding_bag_test.py new file mode 100644 index 00000000000000..beda49b9c88d74 --- /dev/null +++ b/model-optimizer/extensions/ops/embedding_bag_test.py @@ -0,0 +1,89 @@ +""" + Copyright (C) 2018-2020 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import unittest + +import numpy as np + +from extensions.ops.embedding_bag import EmbeddingBagOffsetsSum, EmbeddingBagPackedSum, EmbeddingSegmentsSum +from mo.front.common.partial_infer.utils import int64_array +from mo.graph.graph import Node +from mo.utils.unittest.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, result, \ + connect, FakeAttr + +nodes = { + **valued_const_with_data('data', np.random.randn(3000, 8)), + **regular_op_with_shaped_data('indices1d', [100], {'type': 'Parameter', 'value': None, + '_out_port_data_type': {0: np.int32}}), + **regular_op_with_shaped_data('indices2d', [30, 3], {'type': 'Parameter', 'value': None, + '_out_port_data_type': {0: np.int32}}), + **regular_op_with_shaped_data('offsets', [30], {'type': 'Parameter', 'value': None, + '_out_port_data_type': {0: np.int32}}), + **regular_op_with_shaped_data('segment_ids', [100], {'type': 'Parameter', 'value': None, + '_out_port_data_type': {0: np.int32}}), + **valued_const_with_data('num_segments', np.array(30, dtype=np.int32)), + **regular_op_with_shaped_data('embedding_bag_offsets', None, + {'op': 'EmbeddingBagOffsetsSum', 'type': 'EmbeddingBagOffsetsSum', + 'name': 'embedding_bag_offsets'}), + **regular_op_with_shaped_data('embedding_bag_packed', None, + {'op': 'EmbeddingBagPackedSum', 'type': 'EmbeddingBagPackedSum', + 'name': 'embedding_bag_packed'}), + **regular_op_with_shaped_data('embedding_segments', None, + {'op': 'EmbeddingSegmentsSum', 'type': 'EmbeddingSegmentsSum', + 'name': 'embedding_bag_packed'}), + **result('output'), +} + + +class TestEmbeddingInfer(unittest.TestCase): + def test_embedding_bag_offsets_sum(self): + graph = build_graph(nodes, [ + *connect('data', '0:embedding_bag_offsets'), + *connect('indices1d', '1:embedding_bag_offsets'), + *connect('offsets', '2:embedding_bag_offsets'), + ('embedding_bag_offsets', 'embedding_bag_offsets_d', {'out': 0}), + ('embedding_bag_offsets_d', 'output'), + ], nodes_with_edges_only=True) + eb_node = Node(graph, 'embedding_bag_offsets') + EmbeddingBagOffsetsSum.infer(eb_node) + + self.assertTrue(np.array_equal(eb_node.out_port(0).data.get_shape(), int64_array([30, 8]))) + + def test_embedding_bag_packed_sum(self): + graph = build_graph(nodes, [ + *connect('data', '0:embedding_bag_packed'), + *connect('indices2d', '1:embedding_bag_packed'), + ('embedding_bag_packed', 'embedding_bag_packed_d', {'out': 0}), + ('embedding_bag_packed_d', 'output'), + ], nodes_with_edges_only=True) + eb_node = Node(graph, 'embedding_bag_packed') + EmbeddingBagPackedSum.infer(eb_node) + + self.assertTrue(np.array_equal(eb_node.out_port(0).data.get_shape(), int64_array([30, 8]))) + + def test_embedding_segments_sum(self): + graph = build_graph(nodes, [ + *connect('data', '0:embedding_segments'), + *connect('indices1d', '1:embedding_segments'), + *connect('segment_ids', '2:embedding_segments'), + *connect('num_segments', '3:embedding_segments'), + ('embedding_segments', 'embedding_segments_d', {'out': 0}), + ('embedding_segments_d', 'output'), + ], nodes_with_edges_only=True) + eb_node = Node(graph, 'embedding_segments') + EmbeddingSegmentsSum.infer(eb_node) + + self.assertTrue(np.array_equal(eb_node.out_port(0).data.get_shape(), int64_array([30, 8]))) diff --git a/model-optimizer/extensions/ops/sparse_reshape.py b/model-optimizer/extensions/ops/sparse_reshape.py index 2a2c75ce6c2d02..18c1d062af425f 100644 --- a/model-optimizer/extensions/ops/sparse_reshape.py +++ b/model-optimizer/extensions/ops/sparse_reshape.py @@ -42,6 +42,7 @@ def supported_attrs(self): @staticmethod def infer(node: Node): input_indices_shape = node.in_port(0).data.get_shape() + input_indices_value = node.in_port(0).data.get_value() input_shape_value = node.in_port(1).data.get_value() new_shape_value = node.in_port(2).data.get_value() new_shape_shape = node.in_port(2).data.get_shape() @@ -60,4 +61,6 @@ def infer(node: Node): output_indices_shape = np.concatenate((input_indices_shape[0:1], new_shape_shape)) node.out_port(0).data.set_shape(output_indices_shape) - # TODO: implement for constant input indices value + # TODO: implement constant value propogation for common case + if np.array_equal(input_shape_value, output_shape_value) and input_indices_value is not None: + node.out_port(0).data.set_value(input_indices_value) diff --git a/model-optimizer/extensions/ops/sparse_weighted_sum.py b/model-optimizer/extensions/ops/sparse_weighted_sum.py deleted file mode 100644 index ff76b211f1c3de..00000000000000 --- a/model-optimizer/extensions/ops/sparse_weighted_sum.py +++ /dev/null @@ -1,57 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -from mo.front.common.partial_infer.utils import int64_array -from mo.graph.graph import Node, Graph -from mo.ops.op import Op - - -class ExperimentalSparseWeightedSum(Op): - op = 'ExperimentalSparseWeightedSum' - enabled = True - - def __init__(self, graph: Graph, attrs: dict): - super().__init__(graph, { - 'type': __class__.op, - 'op': __class__.op, - 'version': 'experimental', - 'reduce_op': None, - 'type_infer': self.type_infer, - 'infer': self.infer, - 'in_ports_count': 6, - 'out_ports_count': 1, - }, attrs) - - def supported_attrs(self): - return [] - - @staticmethod - def type_infer(node): - # the output type must be the same as the parameters table type - params_table_type = node.in_port(3).get_data_type() - node.out_port(0).set_data_type(params_table_type) - - @staticmethod - def infer(node: Node): - assert len(node.in_nodes()) == 5 or len(node.in_nodes()) == 6, \ - "Incorrect number of inputs for {} node".format(node.id) - - batch_size = node.in_port(2).data.get_value()[0] - num_features = node.in_port(3).data.get_shape()[1:] - output_shape = int64_array([batch_size] + num_features.tolist()) - node.out_port(0).data.set_shape(output_shape) - - # TODO: implement output value computation if all input is constant diff --git a/model-optimizer/extensions/ops/sparse_weighted_sum_test.py b/model-optimizer/extensions/ops/sparse_weighted_sum_test.py deleted file mode 100644 index 0f1b40028d5af2..00000000000000 --- a/model-optimizer/extensions/ops/sparse_weighted_sum_test.py +++ /dev/null @@ -1,67 +0,0 @@ -""" - Copyright (C) 2018-2020 Intel Corporation - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -import unittest - -import numpy as np - -from extensions.ops.sparse_weighted_sum import ExperimentalSparseWeightedSum -from mo.front.common.partial_infer.utils import int64_array -from mo.graph.graph import Node -from mo.utils.unittest.graph import build_graph - -nodes_attributes = {'input_indices': {'shape': None, 'value': None, 'kind': 'data'}, - 'input_values': {'shape': None, 'value': None, 'kind': 'data'}, - 'input_dense_shape': {'shape': None, 'value': None, 'kind': 'data'}, - 'input_params_table': {'shape': None, 'value': None, 'kind': 'data'}, - 'input_default_value': {'shape': None, 'value': None, 'kind': 'data'}, - 'input_weights': {'shape': None, 'value': None, 'kind': 'data'}, - 'sparse_weighted_sum_node': {'op': 'ExperimentalSparseWeightedSum', 'kind': 'op'}, - 'output': {'shape': None, 'value': None, 'kind': 'data'}} - -# graph 1 -edges1 = [('input_indices', 'sparse_weighted_sum_node', {'in': 0}), - ('input_values', 'sparse_weighted_sum_node', {'in': 1}), - ('input_dense_shape', 'sparse_weighted_sum_node', {'in': 2}), - ('input_params_table', 'sparse_weighted_sum_node', {'in': 3}), - ('input_default_value', 'sparse_weighted_sum_node', {'in': 4}), - ('input_weights', 'sparse_weighted_sum_node', {'in': 5}), - ('sparse_weighted_sum_node', 'output', {'out': 0})] - -inputs1 = {'input_indices': {'shape': int64_array([5, 2]), 'value': None}, - 'input_values': {'shape': int64_array([5]), 'value': None}, - 'input_dense_shape': {'shape': int64_array([2]), 'value': int64_array([4, 3])}, - 'input_params_table': {'shape': int64_array([100, 4, 5]), 'value': None}, - 'input_default_value': {'shape': int64_array([]), 'value': 100.0}, - 'input_weights': {'shape': int64_array([5]), 'value': None}} - - -class TestExperimentalSparseWeightedSum(unittest.TestCase): - def test_partial_infer1(self): - graph = build_graph(nodes_attributes, edges1, inputs1) - sparse_weighted_sum_node = Node(graph, 'sparse_weighted_sum_node') - ExperimentalSparseWeightedSum.infer(sparse_weighted_sum_node) - - # prepare reference results - ref_output_shape = np.array([4, 4, 5], dtype=np.int32) - - # get the result - res_output_shape = graph.node['output']['shape'] - - self.assertTrue(np.array_equal(ref_output_shape, res_output_shape), - 'shapes do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) - -