-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement LookupTableInsert shape inference (#2348)
* Implement LookupTableInsertV2 shape inference It is needed if other nodes not beeing pruned in the graph have a conditional dependence on LookupTableInsertV2 node. Signed-off-by: Roman Kazantsev <[email protected]> * Fix after core-review #1 Signed-off-by: Roman Kazantsev <[email protected]> * Fix the code after review #2 Signed-off-by: Roman Kazantsev <[email protected]> * Fix after code review #3
- Loading branch information
Showing
4 changed files
with
170 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
38 changes: 38 additions & 0 deletions
38
model-optimizer/extensions/front/tf/LookupTableInsert_ext.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
Copyright (C) 2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
from extensions.ops.LookupTableInsert import LookupTableInsert | ||
from mo.front.extractor import FrontExtractorOp | ||
|
||
|
||
class LookupTableInsertFrontExtractor(FrontExtractorOp): | ||
op = 'LookupTableInsert' | ||
enabled = True | ||
|
||
@classmethod | ||
def extract(cls, node): | ||
LookupTableInsert.update_node_stat(node, {}) | ||
return cls.enabled | ||
|
||
|
||
class LookupTableInsertV2FrontExtractor(FrontExtractorOp): | ||
op = 'LookupTableInsertV2' | ||
enabled = True | ||
|
||
@classmethod | ||
def extract(cls, node): | ||
LookupTableInsert.update_node_stat(node, {}) | ||
return cls.enabled |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
""" | ||
Copyright (C) 2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import numpy as np | ||
|
||
from mo.front.common.partial_infer.utils import int64_array | ||
from mo.graph.graph import Node, Graph | ||
from mo.ops.op import Op | ||
|
||
|
||
class LookupTableInsert(Op): | ||
''' | ||
This operation has only output control flow edges and no output data edges in some models. | ||
And for these cases implementation of the shape inference is needed since the shape inference is executed | ||
before control flow edges resolving. This operation has non-tensor output so the output shape is empty. | ||
''' | ||
enabled = False | ||
op = 'LookupTableInsert' | ||
|
||
def __init__(self, graph: Graph, attrs: dict): | ||
mandatory_props = { | ||
'type': None, | ||
'op': self.op, | ||
'infer': self.infer, | ||
'in_ports_count': 3, | ||
'out_ports_count': 1, | ||
} | ||
super().__init__(graph, mandatory_props, attrs) | ||
|
||
@staticmethod | ||
def infer(node: Node): | ||
node_name = node.soft_get('name', node.id) | ||
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()] | ||
assert len(connected_in_ports) == 3, \ | ||
"Incorrect number of inputs for {} node".format(node_name) | ||
|
||
# check shapes of input tensors | ||
keys_shape = node.in_port(1).data.get_shape() | ||
values_shape = node.in_port(2).data.get_shape() | ||
assert np.array_equal(keys_shape, values_shape), \ | ||
'Shapes of tensors with keys and values must be equal for {} node'.format(node_name) | ||
|
||
# set output shape that must be empty | ||
# since output is not a tensor | ||
node.out_port(0).data.set_shape(int64_array([])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
""" | ||
Copyright (C) 2020 Intel Corporation | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
|
||
from extensions.ops.LookupTableInsert import LookupTableInsert | ||
from mo.front.common.partial_infer.utils import int64_array | ||
from mo.graph.graph import Node | ||
from mo.utils.unittest.graph import build_graph | ||
|
||
nodes_attributes = {'table': {'kind': 'op'}, | ||
'table_data': {'shape': None, 'value': None, 'kind': 'data'}, | ||
'keys': {'kind': 'op'}, | ||
'keys_data': {'shape': None, 'value': None, 'kind': 'data'}, | ||
'values': {'kind': 'op'}, | ||
'values_data': {'shape': None, 'value': None, 'kind': 'data'}, | ||
'lookuptableinsert_node': {'op': 'LookupTableInsert', 'kind': 'op'}, | ||
'output': {'shape': None, 'value': None, 'kind': 'data'}} | ||
|
||
# graph 1 | ||
edges1 = [('table', 'table_data'), | ||
('keys', 'keys_data'), | ||
('values', 'values_data'), | ||
('table_data', 'lookuptableinsert_node', {'in': 0}), | ||
('keys_data', 'lookuptableinsert_node', {'in': 1}), | ||
('values_data', 'lookuptableinsert_node', {'in': 2}), | ||
('lookuptableinsert_node', 'output')] | ||
|
||
# valid test case | ||
inputs1 = {'table_data': {}, | ||
'keys_data': {'shape': int64_array([4])}, | ||
'values_data': {'shape': int64_array([4])}} | ||
|
||
# invalid test case | ||
inputs2 = {'table_data': {}, | ||
'keys_data': {'shape': int64_array([5, 2])}, | ||
'values_data': {'shape': int64_array([4])}} | ||
|
||
class TestLookupTableInsert(unittest.TestCase): | ||
def test_infer1(self): | ||
graph = build_graph(nodes_attributes, edges1, inputs1) | ||
lookuptableinsert_node = Node(graph, 'lookuptableinsert_node') | ||
LookupTableInsert.infer(lookuptableinsert_node) | ||
|
||
# prepare reference results | ||
ref_output_shape = int64_array([]) | ||
|
||
# get the result | ||
res_output_shape = graph.node['output']['shape'] | ||
|
||
self.assertTrue(np.array_equal(ref_output_shape, res_output_shape), | ||
'shapes do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape)) | ||
|
||
def test_infer_invalid1(self): | ||
graph = build_graph(nodes_attributes, edges1, inputs2) | ||
lookuptableinsert_node = Node(graph, 'lookuptableinsert_node') | ||
self.assertRaises(AssertionError, LookupTableInsert.infer, lookuptableinsert_node) |