diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt
index e4dc75cb2142e2..4e1df9034340bd 100644
--- a/model-optimizer/automation/package_BOM.txt
+++ b/model-optimizer/automation/package_BOM.txt
@@ -438,11 +438,11 @@ extensions/front/tf/identity_ext.py
 extensions/front/tf/identityN_to_identity.py
 extensions/front/tf/InterpolateTransposes.py
 extensions/front/tf/IteratorGetNext_ext.py
-extensions/front/tf/KerasRNNTransformation.py
 extensions/front/tf/log_softmax_ext.py
 extensions/front/tf/LookupTableInsert_ext.py
 extensions/front/tf/LoopCond_ext.py
 extensions/front/tf/lrn_ext.py
+extensions/front/tf/MapFNTransformation.py
 extensions/front/tf/mask_rcnn_support.json
 extensions/front/tf/mask_rcnn_support_api_v1.11.json
 extensions/front/tf/mask_rcnn_support_api_v1.13.json
diff --git a/model-optimizer/extensions/front/tf/KerasRNNTransformation.py b/model-optimizer/extensions/front/tf/MapFNTransformation.py
similarity index 57%
rename from model-optimizer/extensions/front/tf/KerasRNNTransformation.py
rename to model-optimizer/extensions/front/tf/MapFNTransformation.py
index 70c853181a9624..167989cdfebc64 100644
--- a/model-optimizer/extensions/front/tf/KerasRNNTransformation.py
+++ b/model-optimizer/extensions/front/tf/MapFNTransformation.py
@@ -1,12 +1,15 @@
 # Copyright (C) 2018-2021 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import logging as log
+
 import numpy as np
 
 from extensions.front.tf.WhileNormalize import WhileNormalize
 from extensions.ops.loop import Loop
 from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.front.tf.custom_subgraph_call import skip_nodes_by_condition
 from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Graph, Node, rename_nodes
 from mo.middle.pattern_match import find_pattern_matches, inverse_dict
@@ -15,19 +18,6 @@
 from mo.ops.unsqueeze import Unsqueeze
 
 
-def compute_input_port_idx(req_node: Node, loop_node: Node):
-    """
-    Computes input port index by which requested node is passed to Loop node
-    :param req_node: a node for which to find input port index is requested
-    :param loop_node: a node that can receive input data from requested node by some input port
-    :return: input port index
-    """
-    for destination in req_node.out_port(0).get_destinations():
-        if loop_node.id == destination.node.id:
-            return destination.idx
-    return None
-
-
 def find_subgraph_match_to_pattern(graph: Graph, body_pattern: dict):
     """
     Finds sub-graph matches corresponding pattern in graph
@@ -45,26 +35,18 @@ def find_subgraph_match_to_pattern(graph: Graph, body_pattern: dict):
     return matches
 
 
-class KerasRNNInputSlicing(FrontReplacementSubgraph):
+class MapFNInputSlicing(FrontReplacementSubgraph):
     """
-    The transformation detects TensorFlow 2 pattern that corresponds to subsequent slicing of input.
-    It avoids TensorListFromTensor and TensorFlowGetItem operations and replaces the original sub-graph
-    by adding axis attribute for corresponding input port of Loop node.
-    The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
+    The transformation handles input slicing in a While loop created by the TensorFlow 2 Map Function primitive
+    (see tf.map_fn). It avoids TensorListFromTensor and TensorListGetItem operations and replaces the original
+    sub-graph by adding the axis attribute to the corresponding input port of the Loop node.
+    The transformation is also applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
     """
     enabled = True
 
     def run_before(self):
         return [WhileNormalize]
 
-    @staticmethod
-    def pattern(**kwargs):
-        return dict(
-            nodes=[('unstack', dict(op='TensorListFromTensor')),
-                   ('while', dict(op='Loop'))],
-            edges=[('unstack', 'while')]
-        )
-
     @staticmethod
     def get_body_pattern():
         return dict(
@@ -84,7 +66,7 @@ def get_body_pattern():
         )
 
     @staticmethod
-    def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict):
+    def transform_map_fn_input_slicing(external_match: dict, internal_match: dict):
         """
         Transforms TensorFlow 2 input slicing into use of axis attribute for input port of Loop node
         :param external_match: a match used for handling a part of the main graph responsible for input slicing
@@ -115,51 +97,48 @@ def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict
         # remove TensorListFromTensor and pass a tensor to Loop as is
         unstack_node.out_port(0).get_connection().set_source(unstack_node.in_port(0).get_connection().get_source())
 
-    def replace_sub_graph(self, graph: Graph, external_match: dict):
-        loop_node = external_match['while']
-        body_graph = loop_node['body']
-        body_pattern = KerasRNNInputSlicing.get_body_pattern()
-        internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
-
-        # a case of multiple matches is not handled since it is not clear how to select corresponding match
-        if len(internal_matches) == 1:
-            internal_match = internal_matches[0]
-            loop_node = external_match['while']
-            unstack_port_idx = compute_input_port_idx(external_match['unstack'], loop_node)
-            # check that back edges connect correct Parameter and Result nodes in the body
-            # check connections between body input ports and external inputs ports of Loop node
-            if Loop.back_edge_exists(loop_node.back_edges,
-                                     internal_match['increment_iteration_result'].internal_layer_id,
-                                     internal_match['current_iteration'].internal_layer_id) and \
-                    Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
-                                           internal_match['tensor_list'].internal_layer_id):
-                # only if inter-graph match passed it starts to process the sub-graph
-                KerasRNNInputSlicing.transform_keras_rnn_input_slicing(external_match, internal_match)
-
-
-class KerasRNNOutputConcatenation(FrontReplacementSubgraph):
+    def find_and_replace_pattern(self, graph: Graph):
+        for loop_node in graph.get_op_nodes(op='Loop'):
+            loop_name = loop_node.soft_get('name', loop_node.id)
+            body_graph = loop_node['body']
+            body_pattern = MapFNInputSlicing.get_body_pattern()
+            internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
+
+            for internal_match in internal_matches:
+                # check if TensorListGetItem from the body graph is connected with TensorListFromTensor
+                # from the main graph. If yes, the transformation detects input slicing by this port
+                # and can use Loop axis attribute
+                unstack_node = Loop.get_external_nodes_by_internal_id(loop_node,
+                                                                      internal_match['tensor_list'].internal_layer_id)
+                unstack_node = unstack_node[0] if (len(unstack_node) == 1
+                                                   and unstack_node[0].op == 'TensorListFromTensor') else None
+                if unstack_node is None:
+                    log.info("A sub-graph around the loop node {} does not match "
+                             "TensorFlow 2 MapFN pattern for input slicing".format(loop_name))
+                    continue
+
+                external_match = {'while': loop_node,
+                                  'unstack': unstack_node}
+                # check that back edges connect correct Parameter and Result nodes in the body
+                # check connections between body input ports and external input ports of Loop node
+                if Loop.back_edge_exists(loop_node.back_edges,
+                                         internal_match['increment_iteration_result'].internal_layer_id,
+                                         internal_match['current_iteration'].internal_layer_id):
+                    MapFNInputSlicing.transform_map_fn_input_slicing(external_match, internal_match)
+
+
+class MapFNOutputConcatenation(FrontReplacementSubgraph):
     """
-    The transformation detects TensorFlow 2 pattern that corresponds to concatenation of intermediate results
-    generated in each iteration of While operation.
-    It avoids TensorListReserve, TensorListStack, and TensorListSetItem operations and replaces the original sub-graph
-    by adding axis attribute for corresponding output port of Loop node.
-    The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
+    The transformation handles concatenation of intermediate output results in a While loop created by the
+    TensorFlow 2 Map Function primitive (see tf.map_fn). It avoids TensorListReserve, TensorListStack, and
+    TensorListSetItem operations and replaces the original sub-graph by adding the axis attribute to the output
+    port of the Loop node. The transformation is also applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
""" enabled = True def run_before(self): return [WhileNormalize] - @staticmethod - def pattern(**kwargs): - return dict( - nodes=[('reserve', dict(op='TensorListReserve')), - ('while', dict(op='Loop')), - ('stack', dict(op='TensorListStack'))], - edges=[('reserve', 'while'), - ('while', 'stack')] - ) - @staticmethod def get_body_pattern(): return dict( @@ -184,7 +163,7 @@ def get_body_pattern(): ) @staticmethod - def transform_keras_rnn_output_concatenation(external_match: dict, internal_match: dict): + def transform_map_fn_output_concatenation(external_match: dict, internal_match: dict): """ Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node :param external_match: a match used for handling a part of the main graph responsible for output concatenation @@ -229,27 +208,50 @@ def transform_keras_rnn_output_concatenation(external_match: dict, internal_matc const_true = Const(body_graph, {'value': np.array(True, dtype=np.bool)}).create_node() exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0)) - def replace_sub_graph(self, graph: Graph, external_match: dict): - loop_node = external_match['while'] - body_graph = loop_node['body'] - body_pattern = KerasRNNOutputConcatenation.get_body_pattern() - - internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern) - - if len(internal_matches) == 1: - internal_match = internal_matches[0] - reserve_port_idx = compute_input_port_idx(external_match['reserve'], loop_node) - stack_port_idx = external_match['stack'].in_port(0).get_source().idx - # check that back edges connect correct Parameter and Result nodes in the body - # check connections between body input ports and external inputs ports of Loop node - # check connections between body output ports and external output ports of Loop node - if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id, - internal_match['container'].internal_layer_id) and \ - Loop.back_edge_exists(loop_node.back_edges, - internal_match['increment_iteration_result'].internal_layer_id, - internal_match['current_iteration'].internal_layer_id) and \ - Loop.inter_edge_exists(loop_node.input_port_map, reserve_port_idx, - internal_match['container'].internal_layer_id) and \ - Loop.inter_edge_exists(loop_node.output_port_map, stack_port_idx, - internal_match['concatenation_result'].internal_layer_id): - KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(external_match, internal_match) + def find_and_replace_pattern(self, graph: Graph): + for loop_node in graph.get_op_nodes(op='Loop'): + loop_name = loop_node.soft_get('name', loop_node.id) + body_graph = loop_node['body'] + body_pattern = MapFNOutputConcatenation.get_body_pattern() + internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern) + + for internal_match in internal_matches: + # check if TensorListReserve from the main graph is connected with Parameter node from the body graph + # that is assigned for storing intermediate output results of While Loop. 
+                # detects intermediate outputs concatenation by this port and can use Loop axis attribute
+                reserve_node = Loop.get_external_nodes_by_internal_id(loop_node,
+                                                                      internal_match['container'].internal_layer_id)
+                reserve_node = reserve_node[0] if (len(reserve_node) == 1 and
+                                                   reserve_node[0].op == 'TensorListReserve') else None
+                if reserve_node is None:
+                    log.info("A sub-graph around the loop node {} does not match "
+                             "TensorFlow 2 MapFN pattern for intermediate outputs concatenation".format(loop_name))
+                    continue
+                stack_node = Loop.get_external_nodes_by_internal_id(
+                    loop_node, internal_match['concatenation_result'].internal_layer_id)
+                stack_node = stack_node[0] if len(stack_node) == 1 else None
+
+                if stack_node is None:
+                    log.info("A sub-graph around the loop node {} does not match "
+                             "TensorFlow 2 MapFN pattern for intermediate outputs concatenation".format(loop_name))
+                    continue
+
+                # skip StopGradient node if it exists between While loop output port and TensorListStack operation
+                stack_node = skip_nodes_by_condition(stack_node, lambda x: x.has_and_set('identity'), True)
+                stack_node = stack_node if stack_node.op == 'TensorListStack' else None
+                if stack_node is None:
+                    log.info("A sub-graph around the loop node {} does not match "
+                             "TensorFlow 2 MapFN pattern for intermediate outputs concatenation".format(loop_name))
+                    continue
+
+                external_match = {'while': loop_node,
+                                  'reserve': reserve_node,
+                                  'stack': stack_node}
+                # check that back edges connect Parameter node (or container with intermediate output results)
+                # and concatenation result produced by TensorListSetItem node
+                if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
+                                         internal_match['container'].internal_layer_id) and \
+                        Loop.back_edge_exists(loop_node.back_edges,
+                                              internal_match['increment_iteration_result'].internal_layer_id,
+                                              internal_match['current_iteration'].internal_layer_id):
+                    MapFNOutputConcatenation.transform_map_fn_output_concatenation(external_match, internal_match)
diff --git a/model-optimizer/extensions/front/tf/ObjectDetectionAPI.py b/model-optimizer/extensions/front/tf/ObjectDetectionAPI.py
index 736457284c6d41..87fce4fa6e7009 100644
--- a/model-optimizer/extensions/front/tf/ObjectDetectionAPI.py
+++ b/model-optimizer/extensions/front/tf/ObjectDetectionAPI.py
@@ -11,7 +11,7 @@
 from extensions.front.split_normalizer import SqueezeAxis
 from extensions.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
 from extensions.front.tf.FakeQuantWithMinMaxVars import FakeQuantWithMinMaxVarsToQuantize
-from extensions.front.tf.KerasRNNTransformation import KerasRNNInputSlicing, KerasRNNOutputConcatenation
+from extensions.front.tf.MapFNTransformation import MapFNInputSlicing, MapFNOutputConcatenation
 from extensions.front.tf.TFSliceToSlice import TFSliceToSliceReplacer
 from extensions.front.tf.pad_tf_to_pad import PadTFToPad
 from extensions.middle.InsertLayoutPropagationTransposes import mark_as_correct_data_layout, \
@@ -31,6 +31,7 @@
 from mo.front.common.replacement import FrontReplacementPattern
 from mo.front.extractor import output_user_data_repack, add_output_ops
 from mo.front.subgraph_matcher import SubgraphMatch
+from mo.front.tf.custom_subgraph_call import skip_nodes_by_condition
 from mo.front.tf.graph_utils import add_activation_function_after_node, add_convolution_to_swap_xy_coordinates, \
     mark_squeeze_reshape_concat_before_detection_output, add_fake_background_loc, create_op_node_with_second_input, \
     create_op_with_const_inputs
@@ -346,12 +347,6 @@ def swap_weights_xy(graph: Graph, nodes: list):
         insert_weights_swap_xy_sub_graph(graph, m.in_port(1).get_connection())
 
 
-def skip_nodes_by_condition(current_node: Node, condition: callable):
-    while condition(current_node):
-        current_node = current_node.in_node()
-    return current_node
-
-
 def calculate_shape_keeping_aspect_ratio(height: int, width: int, min_size: int, max_size: int,
                                          pad_to_max_dimension: bool = False):
     """
@@ -529,7 +524,7 @@ def run_before(self):
         # is removed during removing nodes from the DO sub-graph so the first input to Transpose is missing which
         # results in TransposeOrderNormalizer transformation failure.
         return [Pack, TransposeOrderNormalizer, PadTFToPad, SqueezeAxis, TFSliceToSliceReplacer,
-                KerasRNNOutputConcatenation, KerasRNNInputSlicing]
+                MapFNOutputConcatenation, MapFNInputSlicing]
 
     def find_and_replace_pattern(self, graph: Graph):
         pass
diff --git a/model-optimizer/extensions/ops/loop.py b/model-optimizer/extensions/ops/loop.py
index 6aca93bebc8e2e..4089983bc3b478 100644
--- a/model-optimizer/extensions/ops/loop.py
+++ b/model-optimizer/extensions/ops/loop.py
@@ -61,6 +61,27 @@ def get_body_node_by_internal_id(loop_node: Node, internal_id: int):
             'Expected 0 or 1 node with `internal_layer_id`={}, {} found'.format(internal_id, len(suitable_nodes))
         return suitable_nodes[0] if len(suitable_nodes) == 1 else None
 
+    @staticmethod
+    def get_external_nodes_by_internal_id(loop_node: Node, internal_layer_id: int) -> list:
+        """
+        Get a list of nodes from the main graph that are connected with a node with internal_layer_id
+        from the body graph
+
+        :param loop_node: The Loop node
+        :param internal_layer_id: Internal layer ID of the node in the body graph
+        :return: A list of external nodes (from the main graph) that are connected with a node with
+                 internal_layer_id from the body graph
+        """
+        for map_item in loop_node.input_port_map:
+            if map_item['internal_layer_id'] == internal_layer_id \
+                    and loop_node.is_in_port_connected(map_item['external_port_id']):
+                return [loop_node.in_port(map_item['external_port_id']).get_source().node]
+        for map_item in loop_node.output_port_map:
+            if map_item['internal_layer_id'] == internal_layer_id \
+                    and loop_node.is_out_port_connected(map_item['external_port_id']):
+                return [dest.node for dest in loop_node.out_port(map_item['external_port_id']).get_destinations()]
+        return []
+
     @staticmethod
     def updated_body_parameters_shape(loop_node: Node):
         """
diff --git a/model-optimizer/mo/front/tf/custom_subgraph_call.py b/model-optimizer/mo/front/tf/custom_subgraph_call.py
index 53ec45a7b0cd04..ab2062f5a29f22 100644
--- a/model-optimizer/mo/front/tf/custom_subgraph_call.py
+++ b/model-optimizer/mo/front/tf/custom_subgraph_call.py
@@ -149,3 +149,13 @@ def set_tf_custom_call_node_attrs(node_attrs: dict):
     node_attrs['op'] = 'TFCustomSubgraphCall'
     node_attrs['infer'] = tf_subgraph_infer
     node_attrs['kind'] = 'op'
+
+
+def skip_nodes_by_condition(current_node: Node, condition: callable, forward: bool = False):
+    if forward:
+        while condition(current_node):
+            current_node = current_node.out_node()
+    else:
+        while condition(current_node):
+            current_node = current_node.in_node()
+    return current_node
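
For illustration (not part of the patch): the new forward=True mode of skip_nodes_by_condition is what lets MapFNOutputConcatenation tolerate a StopGradient between the Loop output port and TensorListStack, since the helper can now walk forward along out_node() instead of only backward along in_node(). Below is a minimal, self-contained sketch of that contract; MiniNode is a hypothetical stand-in for mo.graph.graph.Node (only in_node/out_node/has_and_set are mimicked), not the real class, and the chain is assumed to be single-producer/single-consumer.

# sketch.py -- illustrates the forward/backward skipping contract
class MiniNode:
    def __init__(self, op, identity=False):
        self.op = op
        self.identity = identity  # mirrors the 'identity' attribute checked via has_and_set
        self.producer = None      # analogue of Node.in_node()
        self.consumer = None      # analogue of Node.out_node()

    def in_node(self):
        return self.producer

    def out_node(self):
        return self.consumer

    def has_and_set(self, attr):
        return bool(getattr(self, attr, False))


def skip_nodes_by_condition(current_node, condition, forward=False):
    # same logic as the helper moved into mo/front/tf/custom_subgraph_call.py:
    # keep hopping while the condition holds, in the chosen direction
    if forward:
        while condition(current_node):
            current_node = current_node.out_node()
    else:
        while condition(current_node):
            current_node = current_node.in_node()
    return current_node


if __name__ == '__main__':
    # chain: Loop output -> StopGradient (identity-like) -> TensorListStack
    stop_gradient = MiniNode('StopGradient', identity=True)
    stack = MiniNode('TensorListStack')
    stop_gradient.consumer = stack

    # starting right after the Loop, hop forward over identity-like nodes,
    # as MapFNOutputConcatenation.find_and_replace_pattern does before the op check
    found = skip_nodes_by_condition(stop_gradient, lambda x: x.has_and_set('identity'), forward=True)
    assert found.op == 'TensorListStack'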