[MO] Support TF2 Keras ConvLSTM2D operation #4197

Merged
1 change: 1 addition & 0 deletions model-optimizer/automation/package_BOM.txt
@@ -569,6 +569,7 @@ extensions/middle/LayoutChangeForConstantShapePaths.py
extensions/middle/LeakyReluPattern.py
extensions/middle/LSTMRNNSequenceToTensorIterator.py
extensions/middle/MarkSubgraphsWithCorrectLayout.py
extensions/middle/MoveConstToLoopBody.py
extensions/middle/MulFakeQuantizeFuse.py
extensions/middle/MXNetRNNSequenceNormalize.py
extensions/middle/MXNetSplitMultiLayers.py
58 changes: 58 additions & 0 deletions model-optimizer/extensions/middle/MoveConstToLoopBody.py
@@ -0,0 +1,58 @@
"""
Copyright (C) 2017-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from extensions.ops.loop import Loop
from mo.middle.replacement import MiddleReplacementPattern
from mo.graph.graph import Graph


class MoveConstToLoopBody(MiddleReplacementPattern):
"""
It moves constant producers for Loop node into the body graph and removes input ports for them.
This transformations helps to continue constant folding inside the body graph if possible.
The constant folding serves as optimization path and helps to avoid issue connecting with constants
lying on weights path to Convolution node.
"""
    enabled = True
    force_shape_inference = True

    def run_after(self):
        from extensions.middle.pass_separator import PostMiddleStart
        return [PostMiddleStart]

    def run_before(self):
        from extensions.middle.ApplyPermutations import ApplyPermutation
        return [ApplyPermutation]

    def find_and_replace_pattern(self, graph: Graph):
        cleanup_called_once = False

        # walk through all Loop nodes and find Const inputs
        for loop_node in graph.get_op_nodes(op='Loop'):
            # call clean-up, which performs constant folding, only once
            if not cleanup_called_once:
                graph.clean_up()
                cleanup_called_once = True

            # move constant nodes into the body graph and remove the body Parameter nodes corresponding to them
            Loop.pull_constant_inputs_into_body(loop_node)

            # since some input ports can be removed after pulling in constants, the Loop node ports must be normalized
            Loop.normalize_input_output_ports(loop_node)

            # perform shape inference for the Loop node again since new constants may have appeared,
            # and constant folding may help on the weights path to a Convolution node inside the body graph
            loop_node['need_shape_inference'] = True
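A minimal, framework-free sketch of what this pass does, using toy dicts in place of MO's Graph/Node classes (all names below are illustrative, not MO's real API). Ports 0 and 1 are assumed to carry the trip count and the execution condition, matching the port_idx > 1 guard added in loop.py further down:

    loop_node = {
        'op': 'Loop',
        'inputs': {0: {'op': 'Const', 'value': 10},          # trip count
                   1: {'op': 'Const', 'value': True},        # execution condition
                   2: {'op': 'Const', 'value': [1, 2, 3]},   # weights-like constant input
                   3: {'op': 'Parameter', 'name': 'data'}},
        'body': {'nodes': [], 'parameters': {2: 'p_weights', 3: 'p_data'}},
    }

    def pull_constant_inputs_into_body(loop):
        # iterate ports in reverse so removing one never shifts a port not yet visited
        for port_idx in sorted(loop['inputs'], reverse=True):
            producer = loop['inputs'][port_idx]
            if port_idx > 1 and producer['op'] == 'Const':
                loop['body']['nodes'].append(dict(producer))    # clone the Const into the body
                loop['body']['parameters'].pop(port_idx, None)  # drop the body Parameter for it
                del loop['inputs'][port_idx]                    # drop the external input port

    pull_constant_inputs_into_body(loop_node)
    print(loop_node['inputs'].keys())   # dict_keys([0, 1, 3])
    print(loop_node['body']['nodes'])   # [{'op': 'Const', 'value': [1, 2, 3]}]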
4 changes: 2 additions & 2 deletions model-optimizer/extensions/middle/fusings.py
@@ -1,5 +1,5 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Copyright (C) 2018-2021 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -96,7 +96,7 @@ def find_and_replace_pattern(self, graph: Graph):
        for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
        for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

        normalize_eltwise_inputs(graph)
        for_graph_and_each_sub_graph_recursively(graph, normalize_eltwise_inputs)
        for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

        MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph)
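The functional change in this file is that normalize_eltwise_inputs now also runs on every sub-graph (e.g. Loop bodies), not just the main graph. A toy sketch of what for_graph_and_each_sub_graph_recursively is assumed to do, with dict-based graphs used purely for illustration:

    def for_graph_and_each_sub_graph_recursively(graph, func):
        # apply the function to the outer graph first, then descend into each sub-graph
        func(graph)
        for node in graph['nodes']:
            for sub_attr in node.get('sub_graphs', []):
                for_graph_and_each_sub_graph_recursively(node[sub_attr], func)

    body = {'name': 'loop_body', 'nodes': []}
    main = {'name': 'main', 'nodes': [{'op': 'Loop', 'sub_graphs': ['body'], 'body': body}]}
    for_graph_and_each_sub_graph_recursively(main, lambda g: print('normalizing', g['name']))
    # prints "normalizing main" and then "normalizing loop_body"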
5 changes: 2 additions & 3 deletions model-optimizer/extensions/ops/loop.py
@@ -315,8 +315,7 @@ def add_back_edge(loop_node: Node, internal_parameter: Node, internal_result: No
    @staticmethod
    def pull_constant_inputs_into_body(loop_node: Node):
        for port_idx, in_port in reversed(loop_node.in_ports().items()):
            # TODO add a check that the input does not correspond to execution_condition
            if not in_port.disconnected() and in_port.get_source().node.soft_get('type') == 'Const':
            if port_idx > 1 and not in_port.disconnected() and in_port.get_source().node.soft_get('type') == 'Const':
                original_const_node = in_port.get_source().node
                new_const_node = Const(loop_node.body, original_const_node.attrs()).create_node()

@@ -463,7 +462,7 @@ def remove_unused_ops_from_port_map(loop_node: Node, port_map: dict, port_map_at
        port_to_remove = port_map[record_id_to_remove]['external_port_id']
        if port_to_remove != -1:
            if dir == 'in':
                if port_to_remove not in [0, 1]:  # input port 0 and 1 are mandatory for the Loop node
                if port_to_remove not in [0, 1] and port_to_remove in loop_node.in_ports().keys():  # input port 0 and 1 are mandatory for the Loop node
                    loop_node.delete_input_port(port_to_remove)
            elif dir == 'out' and port_to_remove in loop_node.out_ports():
                loop_node.delete_output_port(port_to_remove)
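Two guards land in this file. The port_idx > 1 check replaces the old TODO: input ports 0 and 1 of a Loop carry the trip count and execution condition, so they must never be pulled into the body or deleted. The new membership check guards against deleting a port that is already gone (e.g. removed together with a pulled-in constant). A self-contained sketch of the combined condition, with toy data and a hypothetical helper name:

    existing_in_ports = {0, 1, 3}   # port 2 was already removed when its Const was pulled in

    def can_delete_input_port(port_to_remove):
        # ports 0 and 1 are mandatory for the Loop node; a port that is
        # already gone must not be deleted a second time
        return port_to_remove not in (0, 1) and port_to_remove in existing_in_ports

    assert not can_delete_input_port(0)   # mandatory port stays
    assert not can_delete_input_port(2)   # already removed, nothing to delete
    assert can_delete_input_port(3)       # regular data port can go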
10 changes: 5 additions & 5 deletions model-optimizer/mo/graph/connection.py
@@ -212,8 +212,8 @@ def check_and_remove_edge():
"Broken Connection object! Destination (node:{}) is not connected to source.".format(
destination.node.name))
destination.disconnect()
return edge_attrs
return None
return edge_attrs, key
return {}, None

if self.destinations and len(self.destinations) > 1:
raise Error("set_destination applicable only for connections that has exactly one destination or \
@@ -228,7 +228,7 @@ def check_and_remove_edge():
if self.graph.stage == 'front':
if self.source is not None:
node = self.source.node
source_attrs = check_and_remove_edge() or {}
source_attrs, _ = check_and_remove_edge()
dest_attrs = port.get_in_edge_attrs() or {}

edge_attrs = {}
@@ -243,7 +243,7 @@ def check_and_remove_edge():
# in case if data node exists just use it as is
if self.source is not None:
data_node = self.source._create_data_if_necessary()
edge_attrs = check_and_remove_edge() or {}
edge_attrs, key = check_and_remove_edge()
edge_attrs.update({'in': port.idx})

dest_attrs = {}
@@ -253,7 +253,7 @@ def check_and_remove_edge():
new_tensor_info = self._get_new_tensor_debug_info(attributes_save_mode, data_node.attrs(), dest_attrs)
self._update_tensor_debug_info(data_node.attrs(), new_tensor_info)

self.graph.add_edge(data_node.id, port.node.id, **edge_attrs)
self.graph.add_edge(data_node.id, port.node.id, key=key, **edge_attrs)
self.destinations = [port]

def add_destination(self, port):
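Taken together, the changes in this file make check_and_remove_edge report the key of the removed edge and have set_destination pass that key back to add_edge. Assuming MO's Graph derives from networkx MultiDiGraph (where parallel edges between the same node pair are distinguished by key), the sketch below shows why preserving the key matters when an edge is removed and re-added:

    import networkx as nx

    g = nx.MultiDiGraph()
    g.add_edge('data', 'op', key=0, **{'in': 0})
    g.add_edge('data', 'op', key=1, **{'in': 1})   # parallel edge, distinct key

    # reconnect the second edge while preserving its key, as the fix does
    attrs = dict(g.get_edge_data('data', 'op', key=1))
    g.remove_edge('data', 'op', key=1)
    g.add_edge('data', 'op', key=1, **attrs)

    assert g.get_edge_data('data', 'op', key=1) == {'in': 1}
    # without key=..., add_edge would auto-assign a key, which is not guaranteed
    # to match the key that other bookkeeping (e.g. back edges) remembered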