Skip to content

Commit

Permalink
[MO] Support TF2 Keras ConvLSTM2D operation (#4197)
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Kazantsev <[email protected]>
  • Loading branch information
rkazants authored Feb 9, 2021
1 parent 9df1381 commit 636f5c4
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 10 deletions.
1 change: 1 addition & 0 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ extensions/middle/LayoutChangeForConstantShapePaths.py
extensions/middle/LeakyReluPattern.py
extensions/middle/LSTMRNNSequenceToTensorIterator.py
extensions/middle/MarkSubgraphsWithCorrectLayout.py
extensions/middle/MoveConstToLoopBody.py
extensions/middle/MulFakeQuantizeFuse.py
extensions/middle/MXNetRNNSequenceNormalize.py
extensions/middle/MXNetSplitMultiLayers.py
Expand Down
58 changes: 58 additions & 0 deletions model-optimizer/extensions/middle/MoveConstToLoopBody.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Copyright (C) 2017-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from extensions.ops.loop import Loop
from mo.middle.replacement import MiddleReplacementPattern
from mo.graph.graph import Graph


class MoveConstToLoopBody(MiddleReplacementPattern):
"""
It moves constant producers for Loop node into the body graph and removes input ports for them.
This transformations helps to continue constant folding inside the body graph if possible.
The constant folding serves as optimization path and helps to avoid issue connecting with constants
lying on weights path to Convolution node.
"""
enabled = True
force_shape_inference = True

def run_after(self):
from extensions.middle.pass_separator import PostMiddleStart
return [PostMiddleStart]

def run_before(self):
from extensions.middle.ApplyPermutations import ApplyPermutation
return [ApplyPermutation]

def find_and_replace_pattern(self, graph: Graph):
cleanup_called_once = False

# walk through all Loop nodes and find Const inputs
for loop_node in graph.get_op_nodes(op='Loop'):
# call clean-up only once that performs constant folding
if not cleanup_called_once:
graph.clean_up()
cleanup_called_once = True

# move constant node into the body graph and removes body parameter nodes corresponding to them
Loop.pull_constant_inputs_into_body(loop_node)

# since some input ports can be removed after the pulling constants, normalization of Loop node is required
Loop.normalize_input_output_ports(loop_node)

# perform shape inference for the Loop node again since new constant can be appeared
# and constant folding can be helpful for weights path to Convolution node inside the body graph
loop_node['need_shape_inference'] = True
4 changes: 2 additions & 2 deletions model-optimizer/extensions/middle/fusings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -96,7 +96,7 @@ def find_and_replace_pattern(self, graph: Graph):
for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

normalize_eltwise_inputs(graph)
for_graph_and_each_sub_graph_recursively(graph, normalize_eltwise_inputs)
for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph)
Expand Down
5 changes: 2 additions & 3 deletions model-optimizer/extensions/ops/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,7 @@ def add_back_edge(loop_node: Node, internal_parameter: Node, internal_result: No
@staticmethod
def pull_constant_inputs_into_body(loop_node: Node):
for port_idx, in_port in reversed(loop_node.in_ports().items()):
# TODO add a check that the input does not correspond to execution_condition
if not in_port.disconnected() and in_port.get_source().node.soft_get('type') == 'Const':
if port_idx > 1 and not in_port.disconnected() and in_port.get_source().node.soft_get('type') == 'Const':
original_const_node = in_port.get_source().node
new_const_node = Const(loop_node.body, original_const_node.attrs()).create_node()

Expand Down Expand Up @@ -463,7 +462,7 @@ def remove_unused_ops_from_port_map(loop_node: Node, port_map: dict, port_map_at
port_to_remove = port_map[record_id_to_remove]['external_port_id']
if port_to_remove != -1:
if dir == 'in':
if port_to_remove not in [0, 1]: # input port 0 and 1 are mandatory for the Loop node
if port_to_remove not in [0, 1] and port_to_remove in loop_node.in_ports().keys(): # input port 0 and 1 are mandatory for the Loop node
loop_node.delete_input_port(port_to_remove)
elif dir == 'out' and port_to_remove in loop_node.out_ports():
loop_node.delete_output_port(port_to_remove)
Expand Down
10 changes: 5 additions & 5 deletions model-optimizer/mo/graph/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def check_and_remove_edge():
"Broken Connection object! Destination (node:{}) is not connected to source.".format(
destination.node.name))
destination.disconnect()
return edge_attrs
return None
return edge_attrs, key
return {}, None

if self.destinations and len(self.destinations) > 1:
raise Error("set_destination applicable only for connections that has exactly one destination or \
Expand All @@ -228,7 +228,7 @@ def check_and_remove_edge():
if self.graph.stage == 'front':
if self.source is not None:
node = self.source.node
source_attrs = check_and_remove_edge() or {}
source_attrs, _ = check_and_remove_edge()
dest_attrs = port.get_in_edge_attrs() or {}

edge_attrs = {}
Expand All @@ -243,7 +243,7 @@ def check_and_remove_edge():
# in case if data node exists just use it as is
if self.source is not None:
data_node = self.source._create_data_if_necessary()
edge_attrs = check_and_remove_edge() or {}
edge_attrs, key = check_and_remove_edge()
edge_attrs.update({'in': port.idx})

dest_attrs = {}
Expand All @@ -253,7 +253,7 @@ def check_and_remove_edge():
new_tensor_info = self._get_new_tensor_debug_info(attributes_save_mode, data_node.attrs(), dest_attrs)
self._update_tensor_debug_info(data_node.attrs(), new_tensor_info)

self.graph.add_edge(data_node.id, port.node.id, **edge_attrs)
self.graph.add_edge(data_node.id, port.node.id, key=key, **edge_attrs)
self.destinations = [port]

def add_destination(self, port):
Expand Down

0 comments on commit 636f5c4

Please sign in to comment.