Skip to content

Commit

Permalink
Implement LookupTableInsert shape inference (openvinotoolkit#2348)
Browse files Browse the repository at this point in the history
* Implement LookupTableInsertV2 shape inference

It is needed if other nodes not beeing pruned in the graph
have a conditional dependence on LookupTableInsertV2 node.

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix after core-review #1

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix the code after review #2

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix after code review #3
  • Loading branch information
rkazants authored Oct 20, 2020
1 parent 347e92c commit c239450
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 0 deletions.
2 changes: 2 additions & 0 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ extensions/front/tf/identity_ext.py
extensions/front/tf/identityN_to_identity.py
extensions/front/tf/InterpolateTransposes.py
extensions/front/tf/IteratorGetNext_ext.py
extensions/front/tf/LookupTableInsert_ext.py
extensions/front/tf/LoopCond_ext.py
extensions/front/tf/lrn_ext.py
extensions/front/tf/mask_rcnn_support.json
Expand Down Expand Up @@ -630,6 +631,7 @@ extensions/ops/identity.py
extensions/ops/instance_normalization.py
extensions/ops/interp.py
extensions/ops/interpolate.py
extensions/ops/LookupTableInsert.py
extensions/ops/LSTM.py
extensions/ops/lstm_cell.py
extensions/ops/lstm_sequence.py
Expand Down
38 changes: 38 additions & 0 deletions model-optimizer/extensions/front/tf/LookupTableInsert_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from extensions.ops.LookupTableInsert import LookupTableInsert
from mo.front.extractor import FrontExtractorOp


class LookupTableInsertFrontExtractor(FrontExtractorOp):
op = 'LookupTableInsert'
enabled = True

@classmethod
def extract(cls, node):
LookupTableInsert.update_node_stat(node, {})
return cls.enabled


class LookupTableInsertV2FrontExtractor(FrontExtractorOp):
op = 'LookupTableInsertV2'
enabled = True

@classmethod
def extract(cls, node):
LookupTableInsert.update_node_stat(node, {})
return cls.enabled
58 changes: 58 additions & 0 deletions model-optimizer/extensions/ops/LookupTableInsert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.ops.op import Op


class LookupTableInsert(Op):
'''
This operation has only output control flow edges and no output data edges in some models.
And for these cases implementation of the shape inference is needed since the shape inference is executed
before control flow edges resolving. This operation has non-tensor output so the output shape is empty.
'''
enabled = False
op = 'LookupTableInsert'

def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': None,
'op': self.op,
'infer': self.infer,
'in_ports_count': 3,
'out_ports_count': 1,
}
super().__init__(graph, mandatory_props, attrs)

@staticmethod
def infer(node: Node):
node_name = node.soft_get('name', node.id)
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
assert len(connected_in_ports) == 3, \
"Incorrect number of inputs for {} node".format(node_name)

# check shapes of input tensors
keys_shape = node.in_port(1).data.get_shape()
values_shape = node.in_port(2).data.get_shape()
assert np.array_equal(keys_shape, values_shape), \
'Shapes of tensors with keys and values must be equal for {} node'.format(node_name)

# set output shape that must be empty
# since output is not a tensor
node.out_port(0).data.set_shape(int64_array([]))
72 changes: 72 additions & 0 deletions model-optimizer/extensions/ops/LookupTableInsert_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

import numpy as np

from extensions.ops.LookupTableInsert import LookupTableInsert
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph

nodes_attributes = {'table': {'kind': 'op'},
'table_data': {'shape': None, 'value': None, 'kind': 'data'},
'keys': {'kind': 'op'},
'keys_data': {'shape': None, 'value': None, 'kind': 'data'},
'values': {'kind': 'op'},
'values_data': {'shape': None, 'value': None, 'kind': 'data'},
'lookuptableinsert_node': {'op': 'LookupTableInsert', 'kind': 'op'},
'output': {'shape': None, 'value': None, 'kind': 'data'}}

# graph 1
edges1 = [('table', 'table_data'),
('keys', 'keys_data'),
('values', 'values_data'),
('table_data', 'lookuptableinsert_node', {'in': 0}),
('keys_data', 'lookuptableinsert_node', {'in': 1}),
('values_data', 'lookuptableinsert_node', {'in': 2}),
('lookuptableinsert_node', 'output')]

# valid test case
inputs1 = {'table_data': {},
'keys_data': {'shape': int64_array([4])},
'values_data': {'shape': int64_array([4])}}

# invalid test case
inputs2 = {'table_data': {},
'keys_data': {'shape': int64_array([5, 2])},
'values_data': {'shape': int64_array([4])}}

class TestLookupTableInsert(unittest.TestCase):
def test_infer1(self):
graph = build_graph(nodes_attributes, edges1, inputs1)
lookuptableinsert_node = Node(graph, 'lookuptableinsert_node')
LookupTableInsert.infer(lookuptableinsert_node)

# prepare reference results
ref_output_shape = int64_array([])

# get the result
res_output_shape = graph.node['output']['shape']

self.assertTrue(np.array_equal(ref_output_shape, res_output_shape),
'shapes do not match expected: {} and given: {}'.format(ref_output_shape, res_output_shape))

def test_infer_invalid1(self):
graph = build_graph(nodes_attributes, edges1, inputs2)
lookuptableinsert_node = Node(graph, 'lookuptableinsert_node')
self.assertRaises(AssertionError, LookupTableInsert.infer, lookuptableinsert_node)

0 comments on commit c239450

Please sign in to comment.