From bacb8420f032824b083cfcfba7db2ce346ab22b2 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Thu, 21 Jan 2021 17:39:57 +0300 Subject: [PATCH] [MO] Implement TensorFlow 2 While and Keras RNN support in MO (#3573) * [MO] Implement TensorFlow 2 While support in MO Signed-off-by: Roman Kazantsev * Add extractors for both While and StatelessWhile and do minor changes Signed-off-by: Roman Kazantsev * Improve update_body_graph function and manage graph names properly Signed-off-by: Roman Kazantsev * Fix a map for original name of parameters from body and cond Signed-off-by: Roman Kazantsev * Implement draft version of support of TF2 Keras RNN Signed-off-by: Roman Kazantsev * Implement Keras LSTM and GRU support in MO Signed-off-by: Roman Kazantsev * Improve code for Keras RNN support Signed-off-by: Roman Kazantsev * Finalize implementation of TF2 Keras RNN support in MO Signed-off-by: Roman Kazantsev * Apply the first part of the comments after review #1 Signed-off-by: Roman Kazantsev * Avoid use of explicit values of port indices in the transformation Signed-off-by: Roman Kazantsev * Finalize code after the first-round review Signed-off-by: Roman Kazantsev * Apply comments after the second-round review Signed-off-by: Roman Kazantsev --- model-optimizer/automation/package_BOM.txt | 3 + .../back/SpecialNodesFinalization.py | 8 +- .../front/standalone_const_eraser.py | 5 +- .../front/tf/KerasRNNTransformation.py | 268 ++++++++++++++++++ .../extensions/front/tf/WhileNormalize.py | 53 ++++ .../extensions/front/tf/while_ext.py | 207 ++++++++++++++ model-optimizer/extensions/load/tf/loader.py | 3 +- model-optimizer/extensions/ops/loop.py | 60 +++- model-optimizer/mo/front/tf/loader.py | 13 +- model-optimizer/mo/graph/graph.py | 4 +- model-optimizer/mo/ops/op.py | 7 +- 11 files changed, 618 insertions(+), 13 deletions(-) create mode 100644 model-optimizer/extensions/front/tf/KerasRNNTransformation.py create mode 100644 model-optimizer/extensions/front/tf/WhileNormalize.py create mode 100644 model-optimizer/extensions/front/tf/while_ext.py diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 7eaa40186867e0..e4080d168e1274 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -404,6 +404,7 @@ extensions/front/tf/identity_ext.py extensions/front/tf/identityN_to_identity.py extensions/front/tf/InterpolateTransposes.py extensions/front/tf/IteratorGetNext_ext.py +extensions/front/tf/KerasRNNTransformation.py extensions/front/tf/log_softmax_ext.py extensions/front/tf/LookupTableInsert_ext.py extensions/front/tf/LoopCond_ext.py @@ -483,6 +484,8 @@ extensions/front/tf/UnpackPackReverseInputChannels.py extensions/front/tf/variable_ext.py extensions/front/tf/variables_values_freezing.py extensions/front/tf/WhereDecomposition.py +extensions/front/tf/while_ext.py +extensions/front/tf/WhileNormalize.py extensions/front/tf/yolo_v1.json extensions/front/tf/yolo_v1_tiny.json extensions/front/tf/yolo_v2.json diff --git a/model-optimizer/extensions/back/SpecialNodesFinalization.py b/model-optimizer/extensions/back/SpecialNodesFinalization.py index cd1e599485da3b..e19d413262de5e 100644 --- a/model-optimizer/extensions/back/SpecialNodesFinalization.py +++ b/model-optimizer/extensions/back/SpecialNodesFinalization.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2018-2020 Intel Corporation + Copyright (C) 2018-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file 
 except in compliance with the License.
@@ -15,7 +15,6 @@
 """
 import logging as log
 from collections import defaultdict
-from copy import copy
 
 import numpy as np
 
@@ -125,6 +124,11 @@ class RemoveConstToResult(BackReplacementPattern):
     """
     enabled = True
     force_clean_up = True
+    # TODO: remove this transformation once all plugins support constant value networks.
+    # Do not run recursively since a Const->Result sub-graph can be encountered in a body graph of a Loop node
+    # and this sub-graph is needed to avoid dynamism created by the Loop node
+    # when the axis attribute is used in the output port map
+    run_not_recursively = True
 
     @staticmethod
     def pattern():
diff --git a/model-optimizer/extensions/front/standalone_const_eraser.py b/model-optimizer/extensions/front/standalone_const_eraser.py
index c0290968964b50..6cbae5c77757be 100644
--- a/model-optimizer/extensions/front/standalone_const_eraser.py
+++ b/model-optimizer/extensions/front/standalone_const_eraser.py
@@ -1,5 +1,5 @@
 """
- Copyright (C) 2018-2020 Intel Corporation
+ Copyright (C) 2018-2021 Intel Corporation
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
@@ -22,6 +22,9 @@
 class StandaloneConstEraser(FrontReplacementSubgraph):
     enabled = True
+    # TODO: remove this transformation once all plugins support constant value networks.
+    # It is not run recursively since a Const->Result sub-graph can be encountered in a body graph of a Loop node
+    run_not_recursively = True
 
     @staticmethod
     def pattern():
diff --git a/model-optimizer/extensions/front/tf/KerasRNNTransformation.py b/model-optimizer/extensions/front/tf/KerasRNNTransformation.py
new file mode 100644
index 00000000000000..7ba9fd41783630
--- /dev/null
+++ b/model-optimizer/extensions/front/tf/KerasRNNTransformation.py
@@ -0,0 +1,268 @@
+"""
+ Copyright (C) 2017-2021 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+""" + +import numpy as np + +from extensions.front.tf.WhileNormalize import WhileNormalize +from extensions.ops.loop import Loop +from mo.front.common.partial_infer.utils import int64_array +from mo.front.common.replacement import FrontReplacementSubgraph +from mo.front.tf.graph_utils import create_op_with_const_inputs +from mo.graph.graph import Graph, Node, rename_nodes +from mo.middle.pattern_match import find_pattern_matches, inverse_dict +from mo.ops.const import Const +from mo.ops.squeeze import Squeeze +from mo.ops.unsqueeze import Unsqueeze + + +def compute_input_port_idx(req_node: Node, loop_node: Node): + """ + Computes input port index by which requested node is passed to Loop node + :param req_node: a node for which to find input port index is requested + :param loop_node: a node that can receive input data from requested node by some input port + :return: input port index + """ + for destination in req_node.out_port(0).get_destinations(): + if loop_node.id == destination.node.id: + return destination.idx + return None + + +def find_subgraph_match_to_pattern(graph: Graph, body_pattern: dict): + """ + Finds sub-graph matches corresponding pattern in graph + :param graph: a graph where to search for matched sub-graph + :param body_pattern: a pattern + :return: a list of sub-graph matches + """ + matches = [] + for match in find_pattern_matches(graph, **body_pattern): + match = inverse_dict(match) + for k in match: + match[k] = Node(graph, match[k]) + matches.append(match) + + return matches + + +class KerasRNNInputSlicing(FrontReplacementSubgraph): + """ + The transformation detects TensorFlow 2 pattern that corresponds to subsequent slicing of input. + It avoids TensorListFromTensor and TensorFlowGetItem operations and replaces the original sub-graph + by adding axis attribute for corresponding input port of Loop node. + The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers. 
+ """ + enabled = True + + def run_before(self): + return [WhileNormalize] + + @staticmethod + def pattern(**kwargs): + return dict( + nodes=[('unstack', dict(op='TensorListFromTensor')), + ('while', dict(op='Loop'))], + edges=[('unstack', 'while')] + ) + + @staticmethod + def get_body_pattern(): + return dict( + nodes=[('tensor_list', dict(op='Parameter')), + ('current_iteration', dict(op='Parameter')), + ('slicing', dict(op='TensorListGetItem')), + ('const_increment', dict(op='Const')), + ('increment_iteration', dict(op='Add')), + ('increment_iteration_identity', dict(op='Identity')), + ('increment_iteration_result', dict(op='Result'))], + edges=[('tensor_list', 'slicing', {'in': 0}), + ('current_iteration', 'slicing', {'in': 1}), + ('const_increment', 'increment_iteration', {'in': 1}), + ('current_iteration', 'increment_iteration', {'in': 0}), + ('increment_iteration', 'increment_iteration_identity', {'in': 0}), + ('increment_iteration_identity', 'increment_iteration_result', {'in': 0})] + ) + + @staticmethod + def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict): + """ + Transforms TensorFlow 2 input slicing into use of axis attribute for input port of Loop node + :param external_match: a match used for handling a part of the main graph responsible for input slicing + :param internal_match: a match used for handling a part of the body graph responsible for input slicing + """ + loop_node = external_match['while'] + unstack_node = external_match['unstack'] + body_graph = loop_node['body'] + + tensor_list_get_item_node = internal_match['slicing'] + unstack_placeholder = internal_match['tensor_list'] + tensor_list_get_item_node_name = tensor_list_get_item_node.soft_get('name', tensor_list_get_item_node.id) + + # 1. process the body graph to avoid unsupported operations: TensorListGetItem and TensorListSetItem + # replace TensorListGetItem with Squeeze node and iterate through slices using axis for input port + squeeze_list_element = create_op_with_const_inputs(body_graph, Squeeze, {1: int64_array(0)}, + {'name': 'TensorListGetItemSqueeze'}) + tensor_list_get_item_node.in_port(0).get_connection().set_destination(squeeze_list_element.in_port(0)) + tensor_list_get_item_node.out_port(0).get_connection().set_source(squeeze_list_element.out_port(0)) + rename_nodes([(tensor_list_get_item_node, tensor_list_get_item_node_name + '/AbandonedName'), + (squeeze_list_element, tensor_list_get_item_node_name)]) + unstack_placeholder_layer_id = unstack_placeholder.internal_layer_id + Loop.update_port_map_value_ext(loop_node.input_port_map, 'internal_layer_id', unstack_placeholder_layer_id, + 'axis', 0) + + # 2. 
process the main graph in the vicinity of the Loop node to avoid unsupported operations:
+        # TensorListFromTensor, TensorListReserve, and TensorListStack
+        # remove TensorListFromTensor and pass the tensor to Loop as is
+        unstack_node.out_port(0).get_connection().set_source(unstack_node.in_port(0).get_connection().get_source())
+
+    def replace_sub_graph(self, graph: Graph, external_match: dict):
+        loop_node = external_match['while']
+        body_graph = loop_node['body']
+        body_pattern = KerasRNNInputSlicing.get_body_pattern()
+        internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
+
+        # the case of multiple matches is not handled since it is not clear how to select the corresponding match
+        if len(internal_matches) == 1:
+            internal_match = internal_matches[0]
+            loop_node = external_match['while']
+            unstack_port_idx = compute_input_port_idx(external_match['unstack'], loop_node)
+            # check that back edges connect correct Parameter and Result nodes in the body
+            # check connections between body input ports and external input ports of the Loop node
+            if Loop.back_edge_exists(loop_node.back_edges,
+                                     internal_match['increment_iteration_result'].internal_layer_id,
+                                     internal_match['current_iteration'].internal_layer_id) and \
+                    Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
+                                           internal_match['tensor_list'].internal_layer_id):
+                # process the sub-graph only if the inter-graph match check passed
+                KerasRNNInputSlicing.transform_keras_rnn_input_slicing(external_match, internal_match)
+
+
+class KerasRNNOutputConcatenation(FrontReplacementSubgraph):
+    """
+    The transformation detects a TensorFlow 2 pattern that corresponds to the concatenation of intermediate results
+    generated on each iteration of the While operation.
+    It avoids TensorListReserve, TensorListStack, and TensorListSetItem operations and replaces the original sub-graph
+    by adding the axis attribute to the corresponding output port of the Loop node.
+    The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
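+
+    A minimal TF 2 sketch that traces to the matched pattern (illustrative; the op mapping in the
+    comments is an assumption about how AutoGraph lowers the function):
+
+        import tensorflow as tf
+
+        @tf.function
+        def accumulate(x):
+            ta = tf.TensorArray(tf.float32, size=tf.shape(x)[0])  # -> TensorListReserve
+            for i in tf.range(tf.shape(x)[0]):                    # -> While/StatelessWhile
+                ta = ta.write(i, x[i] * 2.0)                      # -> TensorListSetItem
+            return ta.stack()                                     # -> TensorListStack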
+ """ + enabled = True + + def run_before(self): + return [WhileNormalize] + + @staticmethod + def pattern(**kwargs): + return dict( + nodes=[('reserve', dict(op='TensorListReserve')), + ('while', dict(op='Loop')), + ('stack', dict(op='TensorListStack'))], + edges=[('reserve', 'while'), + ('while', 'stack')] + ) + + @staticmethod + def get_body_pattern(): + return dict( + nodes=[('container', dict(op='Parameter')), + ('current_iteration', dict(op='Parameter')), + ('const_increment', dict(op='Const')), + ('increment_iteration', dict(op='Add')), + ('increment_iteration_identity', dict(op='Identity')), + ('increment_iteration_result', dict(op='Result')), + ('concatenation', dict(op='TensorListSetItem')), + ('concatenation_identity', dict(op='Identity')), + ('concatenation_result', dict(op='Result')), + ], + edges=[('const_increment', 'increment_iteration', {'in': 1}), + ('current_iteration', 'increment_iteration', {'in': 0}), + ('container', 'concatenation', {'in': 0}), + ('current_iteration', 'concatenation', {'in': 1}), + ('concatenation', 'concatenation_identity', {'in': 0}), + ('concatenation_identity', 'concatenation_result', {'in': 0}), + ('increment_iteration', 'increment_iteration_identity', {'in': 0}), + ('increment_iteration_identity', 'increment_iteration_result', {'in': 0})] + ) + + @staticmethod + def transform_keras_rnn_output_concatenation(external_match: dict, internal_match: dict): + """ + Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node + :param external_match: a match used for handling a part of the main graph responsible for output concatenation + :param internal_match: a match used for handling a part of the body graph responsible for output concatenation + """ + loop_node = external_match['while'] + stack_node = external_match['stack'] + list_reserve_node = external_match['reserve'] + body_graph = loop_node['body'] + + tensor_list_set_item_node = internal_match['concatenation'] + tensor_list_set_item_node_name = tensor_list_set_item_node.soft_get('name', tensor_list_set_item_node.id) + list_result_node = internal_match['concatenation_result'] + + # replace TensorListSetItem with Unsqueeze and use axis attribute for corresponding Result node + # to concatenate results from different iterations + unsqueeze_list_element = create_op_with_const_inputs(body_graph, Unsqueeze, {1: int64_array(0)}, + {'name': 'TensorListSetItemUnsqueeze'}) + tensor_list_set_item_node.in_port(2).get_connection().set_destination(unsqueeze_list_element.in_port(0)) + tensor_list_set_item_node.out_port(0).get_connection().set_source(unsqueeze_list_element.out_port(0)) + rename_nodes([(tensor_list_set_item_node, tensor_list_set_item_node_name + '/AbandonedName'), + (unsqueeze_list_element, tensor_list_set_item_node_name)]) + list_result_node_layer_id = list_result_node.internal_layer_id + Loop.update_port_map_value_ext(loop_node.output_port_map, 'internal_layer_id', list_result_node_layer_id, + 'axis', 0) + + # remove TensorListStack to by-pass the node since the result from the Loop node is already concatenated + stack_node.out_port(0).get_connection().set_source(stack_node.in_port(0).get_connection().get_source()) + + # disconnect ListReserve node because it is no longer needed for Loop + list_reserve_node.out_port(0).disconnect() + + # connect a number of iterations with trip count that can be received from the second input of ListReserve + # create a constant network with True value for execution_condition so that IE can ignore execution condition 
+        loop_node.in_port(1).disconnect()
+        list_reserve_node.in_port(1).get_source().connect(loop_node.in_port(1))
+        for record in loop_node.output_port_map:
+            if 'purpose' in record and record['purpose'] == 'execution_condition':
+                exec_cond_layer_id = record['internal_layer_id']
+                exec_cond_node = Loop.get_body_node_by_internal_id(loop_node, exec_cond_layer_id)
+                const_true = Const(body_graph, {'value': np.array(True, dtype=np.bool)}).create_node()
+                exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))
+
+    def replace_sub_graph(self, graph: Graph, external_match: dict):
+        loop_node = external_match['while']
+        body_graph = loop_node['body']
+        body_pattern = KerasRNNOutputConcatenation.get_body_pattern()
+
+        internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
+
+        if len(internal_matches) == 1:
+            internal_match = internal_matches[0]
+            reserve_port_idx = compute_input_port_idx(external_match['reserve'], loop_node)
+            stack_port_idx = external_match['stack'].in_port(0).get_source().idx
+            # check that back edges connect correct Parameter and Result nodes in the body
+            # check connections between body input ports and external input ports of the Loop node
+            # check connections between body output ports and external output ports of the Loop node
+            if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
+                                     internal_match['container'].internal_layer_id) and \
+                    Loop.back_edge_exists(loop_node.back_edges,
+                                          internal_match['increment_iteration_result'].internal_layer_id,
+                                          internal_match['current_iteration'].internal_layer_id) and \
+                    Loop.inter_edge_exists(loop_node.input_port_map, reserve_port_idx,
+                                           internal_match['container'].internal_layer_id) and \
+                    Loop.inter_edge_exists(loop_node.output_port_map, stack_port_idx,
+                                           internal_match['concatenation_result'].internal_layer_id):
+                KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(external_match, internal_match)
diff --git a/model-optimizer/extensions/front/tf/WhileNormalize.py b/model-optimizer/extensions/front/tf/WhileNormalize.py
new file mode 100644
index 00000000000000..a5d37216b16fec
--- /dev/null
+++ b/model-optimizer/extensions/front/tf/WhileNormalize.py
@@ -0,0 +1,53 @@
+"""
+ Copyright (C) 2017-2021 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
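+
+ Sketch of the Loop input layout handled here (derived from the WhileNormalize docstring below;
+ spelled out for readability):
+     before: port 0 - current iteration, port 1 - trip count,  ports 2.. - loop-carried values
+     after:  port 0 - trip count,        port 1 - Const(True), ports 2.. - loop-carried values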
+""" + +import numpy as np + +from extensions.ops.loop import Loop +from mo.front.common.replacement import FrontReplacementSubgraph +from mo.graph.graph import Graph, Node +from mo.ops.const import Const + + +class WhileNormalize(FrontReplacementSubgraph): + """ + Normalize inputs for Loop replacing TensorFlow 2 While operation: + 1) Remove external input port for current iteration + 2) Move trip count from port #1 to port #0 + 3) Occupy port #1 for execution condition + """ + enabled = True + + def find_and_replace_pattern(self, graph: Graph): + for node in graph.get_op_nodes(op='Loop'): + self.normalize_loop_node(graph, node) + + @staticmethod + def normalize_loop_node(graph: Graph, loop_node: Node): + loop_name = loop_node.soft_get('name', loop_node.id) + + # disconnect current iteration from external port #0 and move trip count to this port + loop_node.in_port(0).disconnect() + loop_node.in_port(1).get_connection().add_destination(loop_node.in_port(0)) + Loop.update_port_map_value(loop_node.input_port_map, 'external_port_id', 1, 0) + + # connect execution condition port + exec_cond_node = Const(graph, {'name': loop_name + '/ExecutionConditionValue', + 'value': np.array(True, dtype=np.bool)}).create_node() + loop_node.in_port(1).get_connection().set_source(exec_cond_node.out_port(0)) + + loop_node.body.clean_up() + Loop.normalize_input_output_ports(loop_node) diff --git a/model-optimizer/extensions/front/tf/while_ext.py b/model-optimizer/extensions/front/tf/while_ext.py new file mode 100644 index 00000000000000..52f7defad3d152 --- /dev/null +++ b/model-optimizer/extensions/front/tf/while_ext.py @@ -0,0 +1,207 @@ +""" + Copyright (C) 2017-2021 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" +import copy + +from extensions.front.onnx.loop_ext import connect_body_input, connect_body_output +from extensions.ops.loop import Loop +from extensions.ops.parameter import Parameter +from mo.front.common.register_custom_ops import check_for_duplicates +from mo.front.extractor import extract_node_attrs, FrontExtractorOp +from mo.front.tf.extractor import tf_op_extractor, tf_op_extractors +from mo.front.tf.extractors.utils import tf_dtype_extractor +from mo.graph.graph import add_opoutput, Graph, Node +from mo.ops.op import PermuteAttrs + + +def update_body_graph(body_graph: Graph, subgraph_proto: dict, + body_parameter_names: list, body_results: list): + """ + Updates the loop body graph with a sub-graph (for body or condition functions) + :param body_graph: a loop body graph to be updated + :param subgraph_proto: a sub-graph in a protobuf format to be added into the loop body graph + :param body_parameter_names: a (unchanged) list of parameters in the loop body graph + :param body_results: a list of Result nodes that is extended with a list from a sub-graph + """ + # create a map from a node name in original model to a name in a loop body graph assuming + # that names in the original model are unique + # initially, the map contains names for parameters that are common for the body and condition graphs + map_original_name = {} + for idx, pb_node in enumerate(subgraph_proto['input_arg']): + map_original_name[pb_node.name] = body_parameter_names[idx] + + # walk through all nodes (non-parameter and non-result nodes) and add into the loop body graph + for pb_node in subgraph_proto['node_def']: + # create an NX node + id = body_graph.unique_id(pb_node.name) + map_original_name[pb_node.name] = id + body_graph.add_node(id, pb=pb_node, kind='op') + + # add incoming edges based on data_nodes_map + for dst_port, inp in enumerate(pb_node.input): + orig_src_id = inp.split(":")[0] + src_id = map_original_name[orig_src_id] + src_port = 0 if len(inp.split(":")) == 1 else int(inp.split(":")[-1]) + assert (body_graph.has_node(src_id)) + edge_attrs = { + 'out': src_port, + 'in': dst_port, + 'name': src_id, + 'fw_tensor_debug_info': [(src_id, src_port)], + 'in_attrs': ['in', 'name'], + 'out_attrs': ['out', 'name'], + 'data_attrs': ['fw_tensor_debug_info'] + } + body_graph.add_edge(src_id, id, **edge_attrs) + + # create Result nodes in the loop body graph + for output in subgraph_proto['output_arg']: + output_name = subgraph_proto['ret'][output.name] + orig_src_id = output_name.split(":")[0] + src_id = map_original_name[orig_src_id] + src_port = 0 if len(output_name.split(":")) == 1\ + else int(output_name.split(":")[-1]) + assert body_graph.has_node(src_id), 'The body graph does not contain output with name "{}"'.format( + src_id) + body_results.append(Node(body_graph, add_opoutput(body_graph, src_id, src_port, False))) + + +class WhileExtractor(FrontExtractorOp): + """ + The While operation is a variation of the while_loop primitive from TensorFlow 2 Python API. + While can have stateful operations in the body and condition graphs that does not influence on inference so + the logic for handling While and StatelessWhile (see below) is the same. 
+ """ + op = 'While' + enabled = True + + @classmethod + def extract(cls, loop_node): + Loop.update_node_stat(loop_node, {}) + loop_name = loop_node.soft_get('name', loop_node.id) + + # check that required body and condition functions exist in the graph library + main_graph = loop_node.graph + body_graph_name = loop_node.pb.attr['body'].func.name + cond_graph_name = loop_node.pb.attr['cond'].func.name + assert 'library' in main_graph.graph, 'The graph does not contain a library that is required ' \ + 'by node with name "{}".'.format(loop_name) + library_graph = main_graph.graph['library'] + + assert body_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \ + 'that is required by node ' \ + 'with name "{}".'.format(body_graph_name, loop_name) + body_graph_proto = library_graph[body_graph_name] + + assert cond_graph_name in library_graph, 'The library does not contain a function with name "{}" ' \ + 'that is required by node ' \ + 'with name "{}".'.format(cond_graph_name, loop_name) + cond_graph_proto = library_graph[cond_graph_name] + + body_graph = Graph() + # fill the body graph + for attr_key in main_graph.graph.keys(): + if attr_key != 'library': + body_graph.graph[attr_key] = copy.deepcopy(main_graph.graph[attr_key]) + else: + # it is sufficient to have a link to the library + body_graph.graph['library'] = main_graph.graph['library'] + loop_node['body'] = body_graph + + # create Parameter nodes for the body graph + body_parameters = [] + body_parameter_names = [] + for idx, pb_node in enumerate(body_graph_proto['input_arg']): + param_id = body_graph.unique_id(pb_node.name) + body_graph.add_node(param_id, name=param_id, kind='op', op='Parameter', pb=None, shape=None) + parameter_node = Node(body_graph, pb_node.name) + Parameter.update_node_stat(parameter_node, + {'data_type': tf_dtype_extractor(pb_node.type), + 'permute_attrs': PermuteAttrs().update_attrs(attrs=[('shape', 'output:0')])} + ) + body_parameters.append(parameter_node) + body_parameter_names.append(param_id) + + # update the loop body graph with the body function graph + body_results = [] + update_body_graph(body_graph, body_graph_proto, body_parameter_names, body_results) + + # update the loop body graph with the condition function graph + update_body_graph(body_graph, cond_graph_proto, body_parameter_names, body_results) + + # add 'internal_layer_id' attribute which is a must have attribute for the loop body node + for idx, body_node in enumerate(body_graph.get_op_nodes()): + body_node['internal_layer_id'] = idx + + body_graph.stage = 'front' + + # Currently, + # Loop Inputs Order: + # 0 - current iteration + # 1 - trip count + # 2.. - "loop carried" dependencies variables + # + # Body Inputs Order: + # 0 - current iteration + # 1 - trip count + # 2.. - "loop carried" dependencies variables + # + # Body Outputs Order: + # 0 - current iteration + # 1 - trip count + # 2.. - "loop carried" dependencies variables + # + # Loop Outputs Order: + # 0 - current iteration + # 1 - trip count + # 2.. 
- "loop carried" dependencies variables + # + # so inputs must be reordered and execution condition must be created in the front transformation + # to be aligned with the specification + + # connect external input ports with body parameter nodes except current iteration + # since it must be disconnected from external port + for idx in range(1, len(body_parameters)): + connect_body_input(loop_node, idx, body_parameters[idx]) + + # mark current iteration input Parameter node and execution condition Result node + Loop.mark_current_iteration_parameter_node(loop_node, body_parameters[0]) + Loop.mark_execution_condition_result_node(loop_node, body_results[-1]) + + # connect back edges in the body except current iteration + for idx in range(1, len(body_parameters)): + Loop.add_back_edge(loop_node, body_parameters[idx], body_results[idx]) + + # connect body outputs with Loop operation output ports except the execution condition result + for idx in range(len(body_results)-1): + connect_body_output(loop_node, idx, body_results[idx]) + + # run function to parse body nodes attributes similar to the main graph + extract_node_attrs(body_graph, lambda node: tf_op_extractor(node, check_for_duplicates(tf_op_extractors))) + return cls.enabled + + +class StatelessWhileExtractor(FrontExtractorOp): + """ + The StatelessWhile operation is a variation of the while_loop primitive from TensorFlow 2 Python API. + StatelessWhile does not have stateful operations in the body and condition graphs. + """ + op = 'StatelessWhile' + enabled = True + + @classmethod + def extract(cls, loop_node): + WhileExtractor.extract(loop_node) + return cls.enabled diff --git a/model-optimizer/extensions/load/tf/loader.py b/model-optimizer/extensions/load/tf/loader.py index deb6929d6f35fd..b1e219f6429044 100644 --- a/model-optimizer/extensions/load/tf/loader.py +++ b/model-optimizer/extensions/load/tf/loader.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2020 Intel Corporation + Copyright (C) 2020-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -42,6 +42,7 @@ class TFLoader(Loader): enabled = True + run_not_recursively = True def load(self, graph: Graph): argv = graph.graph['cmd_params'] diff --git a/model-optimizer/extensions/ops/loop.py b/model-optimizer/extensions/ops/loop.py index 94aff3be31a3a6..39c2fb6ac18284 100644 --- a/model-optimizer/extensions/ops/loop.py +++ b/model-optimizer/extensions/ops/loop.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2017-2020 Intel Corporation + Copyright (C) 2017-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -94,6 +94,9 @@ def updated_body_parameters_shape(loop_node: Node):
             loop_port_idx = record['external_port_id']
             if loop_port_idx != -1:
                 input_shape = loop_node.in_port(loop_port_idx).get_connection().get_source().data.get_shape()
+                slice_axis = record['axis']
+                if slice_axis is not None:
+                    input_shape[slice_axis] = 1
                 body_node.shape = input_shape
                 log.debug('Updated shape for the body node with internal_id "{}" with value {}'
                           ''.format(record['internal_layer_id'], body_node.shape))
@@ -155,6 +158,8 @@ def iterations_count(loop_node: Node):
             num_iterations = loop_node.in_port(0).data.get_value()
             if num_iterations is not None:
                 num_iterations = num_iterations.item(0)
+                if num_iterations < 0:
+                    return None
         return num_iterations
 
     @staticmethod
@@ -317,9 +322,57 @@ def update_port_map_value(port_map: dict, attr: str, original_value: int, new_va
             if record[attr] == original_value:
                 record[attr] = new_value
                 matched += 1
-        assert matched == 1, 'More than one record in the portmap for attr "{}" wil original value "{}"' \
+        assert matched == 1, 'More than one record in the portmap for attr "{}" with original value "{}"' \
                              ''.format(attr, original_value)
 
+    @staticmethod
+    def update_port_map_value_ext(port_map: dict, layer_id_attr: str, layer_id_value: int,
+                                  updated_attr: str, new_attr_value: int):
+        """
+        Updates the value of the requested attribute for a certain layer id in a port map
+        :param port_map: a map of external ports to internal layer ids
+        :param layer_id_attr: a layer id attribute by which to find the record to update
+        :param layer_id_value: a layer id value for which to update the attribute
+        :param updated_attr: a name of the attribute to update
+        :param new_attr_value: a new value of the attribute
+        """
+        matched = 0
+        for record in port_map:
+            if record.get(layer_id_attr) == layer_id_value:
+                record[updated_attr] = new_attr_value
+                matched += 1
+        assert matched == 1, 'More than one record in the portmap for attr "{}" with original value "{}"' \
+                             ''.format(layer_id_attr, layer_id_value)
+
+    @staticmethod
+    def back_edge_exists(back_edges_map: dict, from_layer: int, to_layer: int):
+        """
+        Checks if a back edge connecting the specified nodes exists in the back_edges_map
+        :param back_edges_map: a map in which to search for the specified back edge
+        :param from_layer: id of the Result node that belongs to the back edge
+        :param to_layer: id of the Parameter node that belongs to the back edge
+        :return: True or False
+        """
+        for back_edge in back_edges_map:
+            if back_edge['from_layer'] == from_layer and back_edge['to_layer'] == to_layer:
+                return True
+        return False
+
+    @staticmethod
+    def inter_edge_exists(port_map: dict, external_port_id: int, internal_layer_id: int):
+        """
+        Checks if an inter-graph edge (i.e.
an edge between the main graph and body graph) exists + :param port_map: a port map where to search for inter-graph edge + :param external_port_id: port index from/to which edge goes + :param internal_layer_id: layer id from/to which edge goes + :return: True or False + """ + for i_port in port_map: + if i_port['external_port_id'] == external_port_id and \ + i_port['internal_layer_id'] == internal_layer_id: + return True + return False + @staticmethod def re_numerate_input_ports(loop_node: Node): """ @@ -372,7 +425,8 @@ def re_number_output_port(loop_node: Node, old_port_id: int, new_port_id: int): new_port_id += 1 for port_idx_to_remove in reversed(range(new_port_id, max_port_id + 1)): - loop_node.delete_output_port(port_idx_to_remove) + if port_idx_to_remove in loop_node.out_ports().keys(): + loop_node.delete_output_port(port_idx_to_remove) @staticmethod def remove_unused_ops_from_port_map(loop_node: Node, port_map: dict, port_map_attr: str, dir: [None, str] = None): diff --git a/model-optimizer/mo/front/tf/loader.py b/model-optimizer/mo/front/tf/loader.py index f78eea18ca1ffe..7e3813a8059409 100644 --- a/model-optimizer/mo/front/tf/loader.py +++ b/model-optimizer/mo/front/tf/loader.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2018-2020 Intel Corporation + Copyright (C) 2018-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -272,6 +272,17 @@ def protobuf_attrs(pb:tf_v1.NodeDef): def protobuf2nx(graph, pb: tf_v1.GraphDef): fill_graph_with_nodes(graph, pb.node, get_id=lambda pb: pb.name, get_attrs=protobuf_attrs) + + # Create a library with auxiliary functions used in TensorFlow 2 operations + if hasattr(pb, 'library') and hasattr(pb.library, 'function'): + graph.graph['library'] = {} + for library_function in pb.library.function: + function_name = library_function.signature.name + graph.graph['library'][function_name] = {} + graph.graph['library'][function_name]['input_arg'] = library_function.signature.input_arg + graph.graph['library'][function_name]['output_arg'] = library_function.signature.output_arg + graph.graph['library'][function_name]['node_def'] = library_function.node_def + graph.graph['library'][function_name]['ret'] = library_function.ret # initial order of nodes in the GraphDef. It is used to specify order in # which merged nodes are added to the generated sub-graph GraphDef for the TensorFlow offload feature. graph.graph['initial_nodes_order'] = [node.name for node in pb.node] diff --git a/model-optimizer/mo/graph/graph.py b/model-optimizer/mo/graph/graph.py index a27cbe45773aaf..1b2ca67dfb2e5f 100644 --- a/model-optimizer/mo/graph/graph.py +++ b/model-optimizer/mo/graph/graph.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2018-2020 Intel Corporation + Copyright (C) 2018-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
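The mo/front/tf/loader.py hunk above stores TF 2 FunctionDef entries under graph.graph['library'] so
that WhileExtractor can look up body and condition graphs by name. A minimal sketch of inspecting these
structures with the standard GraphDef/FunctionDef protobuf API (the file name is illustrative):

    import tensorflow.compat.v1 as tf_v1

    graph_def = tf_v1.GraphDef()
    with open('frozen_model.pb', 'rb') as f:
        graph_def.ParseFromString(f.read())
    for func in graph_def.library.function:
        # the same fields that protobuf2nx records: input_arg, output_arg, node_def, ret
        print(func.signature.name, len(func.signature.input_arg), len(func.node_def), dict(func.ret))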
@@ -120,7 +120,7 @@ def delete_output_port(self, idx, skip_if_absent=False): # no handling of control flow edges -- TODO control_flow = False if not skip_if_absent and idx not in self.out_ports(control_flow=control_flow): - raise Error("Input port with index {} doesn't exist in node {}.".format(idx, self.soft_get('name'))) + raise Error("Output port with index {} doesn't exist in node {}.".format(idx, self.soft_get('name'))) if not self.out_port(idx).disconnected(): self.out_port(idx).disconnect() del self._out_ports[idx] diff --git a/model-optimizer/mo/ops/op.py b/model-optimizer/mo/ops/op.py index 4f881c8dfc32d6..cf4c44aa4204fe 100644 --- a/model-optimizer/mo/ops/op.py +++ b/model-optimizer/mo/ops/op.py @@ -1,5 +1,5 @@ """ - Copyright (C) 2018-2020 Intel Corporation + Copyright (C) 2018-2021 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import networkx as nx import numpy as np +from mo.front.common.partial_infer.utils import int64_array from mo.front.extractor import add_attrs_props, update_ie_fields from mo.graph.graph import Node, Graph from mo.utils import class_registration @@ -445,7 +446,7 @@ def get_nhwc_to_nchw_permutation(dims_number: int): # Exclude 3D shapes from permutation process: identity permutation perm = list(range(0, dims_number)) inv = PermuteAttrs.get_inverse_permutation(perm) - return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv)) + return PermuteAttrs.Permutation(perm=int64_array(perm), inv=int64_array(inv)) @staticmethod def get_nchw_to_nhwc_permutation(dims_number: int): @@ -456,4 +457,4 @@ def get_nchw_to_nhwc_permutation(dims_number: int): # Exclude 3D shapes from permutation process: identity permutation perm = list(range(0, dims_number)) inv = PermuteAttrs.get_inverse_permutation(perm) - return PermuteAttrs.Permutation(perm=np.array(perm), inv=np.array(inv)) + return PermuteAttrs.Permutation(perm=int64_array(perm), inv=int64_array(inv))
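An end-to-end sketch that exercises the new support (hedged: the model and directory names are
illustrative, and the MO invocation shown is one common form, assuming TF 2.x):

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.layers.LSTM(16, return_sequences=True, input_shape=(10, 8)),
    ])
    model.save('lstm_saved_model')  # SavedModel tracing yields StatelessWhile + TensorList* ops

    # then convert with Model Optimizer, e.g.:
    #   python mo_tf.py --saved_model_dir lstm_saved_model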