Commit: Implement transformation for TensorFlow 2 Map Function (aka tf.map_fn) (openvinotoolkit#6836)

* Implement transformation for TensorFlow 2 Map Function primitive

Signed-off-by: Roman Kazantsev <[email protected]>

* Add a description for get_external_node_by_internal_id function

Signed-off-by: Roman Kazantsev <[email protected]>

* Correct a name for get_external_nodes_by_internal_id function

Signed-off-by: Roman Kazantsev <[email protected]>

* Fix a description for get_external_nodes_by_internal_id function

Signed-off-by: Roman Kazantsev <[email protected]>

* Add logging and fix indentation

Signed-off-by: Roman Kazantsev <[email protected]>

* Use skip_nodes_by_condition to bypass StopGradient nodes for tf.map_fn

Signed-off-by: Roman Kazantsev <[email protected]>
rkazants authored and rnugmanx committed Aug 26, 2021
1 parent f6d4a5b commit 92d6d16
Showing 5 changed files with 127 additions and 99 deletions.
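
For context, tf.map_fn applies a callable to each slice of a tensor taken along the first dimension. When traced in TensorFlow 2, it is lowered to a While loop that unstacks the input through TensorListFromTensor/TensorListGetItem and re-stacks the per-iteration results through TensorListReserve/TensorListSetItem/TensorListStack; these are exactly the patterns the transformations in this commit detect. A minimal, illustrative model that produces such a graph (not part of the commit; the module and file names are hypothetical):

import tensorflow as tf

class MapFNModule(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32)])
    def normalize_rows(self, x):
        # tf.map_fn unstacks `x` along axis 0 and runs the lambda on each
        # row inside a While loop in the traced graph
        return tf.map_fn(lambda row: row / (tf.norm(row) + 1e-9), x)

tf.saved_model.save(MapFNModule(), 'map_fn_model')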
2 changes: 1 addition & 1 deletion model-optimizer/automation/package_BOM.txt
@@ -438,11 +438,11 @@ extensions/front/tf/identity_ext.py
 extensions/front/tf/identityN_to_identity.py
 extensions/front/tf/InterpolateTransposes.py
 extensions/front/tf/IteratorGetNext_ext.py
-extensions/front/tf/KerasRNNTransformation.py
 extensions/front/tf/log_softmax_ext.py
 extensions/front/tf/LookupTableInsert_ext.py
 extensions/front/tf/LoopCond_ext.py
 extensions/front/tf/lrn_ext.py
+extensions/front/tf/MapFNTransformation.py
 extensions/front/tf/mask_rcnn_support.json
 extensions/front/tf/mask_rcnn_support_api_v1.11.json
 extensions/front/tf/mask_rcnn_support_api_v1.13.json
model-optimizer/extensions/front/tf/KerasRNNTransformation.py → model-optimizer/extensions/front/tf/MapFNTransformation.py (renamed)
@@ -1,12 +1,15 @@
 # Copyright (C) 2018-2021 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import logging as log
+
 import numpy as np
 
 from extensions.front.tf.WhileNormalize import WhileNormalize
 from extensions.ops.loop import Loop
 from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementSubgraph
+from mo.front.tf.custom_subgraph_call import skip_nodes_by_condition
 from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Graph, Node, rename_nodes
 from mo.middle.pattern_match import find_pattern_matches, inverse_dict
@@ -15,19 +18,6 @@
 from mo.ops.unsqueeze import Unsqueeze
 
 
-def compute_input_port_idx(req_node: Node, loop_node: Node):
-    """
-    Computes input port index by which requested node is passed to Loop node
-    :param req_node: a node for which to find input port index is requested
-    :param loop_node: a node that can receive input data from requested node by some input port
-    :return: input port index
-    """
-    for destination in req_node.out_port(0).get_destinations():
-        if loop_node.id == destination.node.id:
-            return destination.idx
-    return None
-
-
 def find_subgraph_match_to_pattern(graph: Graph, body_pattern: dict):
     """
     Finds sub-graph matches corresponding pattern in graph
@@ -45,26 +35,18 @@ def find_subgraph_match_to_pattern(graph: Graph, body_pattern: dict):
     return matches
 
 
-class KerasRNNInputSlicing(FrontReplacementSubgraph):
+class MapFNInputSlicing(FrontReplacementSubgraph):
     """
-    The transformation detects TensorFlow 2 pattern that corresponds to subsequent slicing of input.
-    It avoids TensorListFromTensor and TensorFlowGetItem operations and replaces the original sub-graph
-    by adding axis attribute for corresponding input port of Loop node.
-    The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
+    The transformation handles input slicing in the While loop created by the TensorFlow 2 Map Function
+    primitive (see tf.map_fn). It avoids TensorListFromTensor and TensorListGetItem operations and replaces
+    the original sub-graph by adding an axis attribute in the Loop node for slicing inputs.
+    The transformation is also applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
     """
     enabled = True
 
     def run_before(self):
         return [WhileNormalize]
 
-    @staticmethod
-    def pattern(**kwargs):
-        return dict(
-            nodes=[('unstack', dict(op='TensorListFromTensor')),
-                   ('while', dict(op='Loop'))],
-            edges=[('unstack', 'while')]
-        )
-
     @staticmethod
     def get_body_pattern():
         return dict(
@@ -84,7 +66,7 @@ def get_body_pattern():
         )
 
     @staticmethod
-    def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict):
+    def transform_map_fn_input_slicing(external_match: dict, internal_match: dict):
         """
         Transforms TensorFlow 2 input slicing into use of axis attribute for input port of Loop node
         :param external_match: a match used for handling a part of the main graph responsible for input slicing
@@ -115,51 +97,48 @@ def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict):
         # remove TensorListFromTensor and pass a tensor to Loop as is
         unstack_node.out_port(0).get_connection().set_source(unstack_node.in_port(0).get_connection().get_source())
 
-    def replace_sub_graph(self, graph: Graph, external_match: dict):
-        loop_node = external_match['while']
-        body_graph = loop_node['body']
-        body_pattern = KerasRNNInputSlicing.get_body_pattern()
-        internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
-
-        # a case of multiple matches is not handled since it is not clear how to select corresponding match
-        if len(internal_matches) == 1:
-            internal_match = internal_matches[0]
-            loop_node = external_match['while']
-            unstack_port_idx = compute_input_port_idx(external_match['unstack'], loop_node)
-            # check that back edges connect correct Parameter and Result nodes in the body
-            # check connections between body input ports and external inputs ports of Loop node
-            if Loop.back_edge_exists(loop_node.back_edges,
-                                     internal_match['increment_iteration_result'].internal_layer_id,
-                                     internal_match['current_iteration'].internal_layer_id) and \
-                    Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
-                                           internal_match['tensor_list'].internal_layer_id):
-                # only if inter-graph match passed it starts to process the sub-graph
-                KerasRNNInputSlicing.transform_keras_rnn_input_slicing(external_match, internal_match)
+    def find_and_replace_pattern(self, graph: Graph):
+        for loop_node in graph.get_op_nodes(op='Loop'):
+            loop_name = loop_node.soft_get('name', loop_node.id)
+            body_graph = loop_node['body']
+            body_pattern = MapFNInputSlicing.get_body_pattern()
+            internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
+
+            for internal_match in internal_matches:
+                # check if TensorListGetItem from the body graph is connected with TensorListFromTensor
+                # from the main graph. If yes, the transformation detects input slicing by this port
+                # and can use the Loop axis attribute
+                unstack_node = Loop.get_external_nodes_by_internal_id(loop_node,
+                                                                      internal_match['tensor_list'].internal_layer_id)
+                unstack_node = unstack_node[0] if (len(unstack_node) == 1
+                                                   and unstack_node[0].op == 'TensorListFromTensor') else None
+                if unstack_node is None:
+                    log.info("A sub-graph around the loop node {} does not match "
+                             "TensorFlow 2 MapFN pattern for input slicing".format(loop_name))
+                    continue
+
+                external_match = {'while': loop_node,
+                                  'unstack': unstack_node}
+                # check that back edges connect correct Parameter and Result nodes in the body
+                # check connections between body input ports and external input ports of Loop node
+                if Loop.back_edge_exists(loop_node.back_edges,
+                                         internal_match['increment_iteration_result'].internal_layer_id,
+                                         internal_match['current_iteration'].internal_layer_id):
+                    MapFNInputSlicing.transform_map_fn_input_slicing(external_match, internal_match)
 
 
-class KerasRNNOutputConcatenation(FrontReplacementSubgraph):
+class MapFNOutputConcatenation(FrontReplacementSubgraph):
     """
-    The transformation detects TensorFlow 2 pattern that corresponds to concatenation of intermediate results
-    generated in each iteration of While operation.
-    It avoids TensorListReserve, TensorListStack, and TensorListSetItem operations and replaces the original sub-graph
-    by adding axis attribute for corresponding output port of Loop node.
-    The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
+    The transformation handles concatenation of intermediate output results in the While loop created by the
+    TensorFlow 2 Map Function primitive (see tf.map_fn). It avoids TensorListReserve, TensorListStack, and
+    TensorListSetItem operations and replaces the original sub-graph by adding an axis attribute in the Loop
+    node for concatenation of intermediate output results.
+    The transformation is also applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
     """
     enabled = True
 
     def run_before(self):
         return [WhileNormalize]
 
-    @staticmethod
-    def pattern(**kwargs):
-        return dict(
-            nodes=[('reserve', dict(op='TensorListReserve')),
-                   ('while', dict(op='Loop')),
-                   ('stack', dict(op='TensorListStack'))],
-            edges=[('reserve', 'while'),
-                   ('while', 'stack')]
-        )
-
     @staticmethod
     def get_body_pattern():
         return dict(
@@ -184,7 +163,7 @@ def get_body_pattern():
         )
 
     @staticmethod
-    def transform_keras_rnn_output_concatenation(external_match: dict, internal_match: dict):
+    def transform_map_fn_output_concatenation(external_match: dict, internal_match: dict):
         """
         Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node
         :param external_match: a match used for handling a part of the main graph responsible for output concatenation
@@ -229,27 +208,50 @@ def transform_keras_rnn_output_concatenation(external_match: dict, internal_match: dict):
         const_true = Const(body_graph, {'value': np.array(True, dtype=np.bool)}).create_node()
         exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))
 
-    def replace_sub_graph(self, graph: Graph, external_match: dict):
-        loop_node = external_match['while']
-        body_graph = loop_node['body']
-        body_pattern = KerasRNNOutputConcatenation.get_body_pattern()
-
-        internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
-
-        if len(internal_matches) == 1:
-            internal_match = internal_matches[0]
-            reserve_port_idx = compute_input_port_idx(external_match['reserve'], loop_node)
-            stack_port_idx = external_match['stack'].in_port(0).get_source().idx
-            # check that back edges connect correct Parameter and Result nodes in the body
-            # check connections between body input ports and external inputs ports of Loop node
-            # check connections between body output ports and external output ports of Loop node
-            if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
-                                     internal_match['container'].internal_layer_id) and \
-                    Loop.back_edge_exists(loop_node.back_edges,
-                                          internal_match['increment_iteration_result'].internal_layer_id,
-                                          internal_match['current_iteration'].internal_layer_id) and \
-                    Loop.inter_edge_exists(loop_node.input_port_map, reserve_port_idx,
-                                           internal_match['container'].internal_layer_id) and \
-                    Loop.inter_edge_exists(loop_node.output_port_map, stack_port_idx,
-                                           internal_match['concatenation_result'].internal_layer_id):
-                KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(external_match, internal_match)
+    def find_and_replace_pattern(self, graph: Graph):
+        for loop_node in graph.get_op_nodes(op='Loop'):
+            loop_name = loop_node.soft_get('name', loop_node.id)
+            body_graph = loop_node['body']
+            body_pattern = MapFNOutputConcatenation.get_body_pattern()
+            internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)
+
+            for internal_match in internal_matches:
+                # check if TensorListReserve from the main graph is connected with the Parameter node from the
+                # body graph that is assigned for storing intermediate output results of the While loop. If yes,
+                # the transformation detects intermediate output concatenation by this port and can use the
+                # Loop axis attribute
+                reserve_node = Loop.get_external_nodes_by_internal_id(loop_node,
+                                                                      internal_match['container'].internal_layer_id)
+                reserve_node = reserve_node[0] if (len(reserve_node) == 1 and
+                                                   reserve_node[0].op == 'TensorListReserve') else None
+                if reserve_node is None:
+                    log.info("A sub-graph around the loop node {} does not match "
+                             "TensorFlow 2 MapFN pattern for intermediate output concatenation".format(loop_name))
+                    continue
+                stack_node = Loop.get_external_nodes_by_internal_id(
+                    loop_node, internal_match['concatenation_result'].internal_layer_id)
+                stack_node = stack_node[0] if len(stack_node) == 1 else None
+
+                if stack_node is None:
+                    log.info("A sub-graph around the loop node {} does not match "
+                             "TensorFlow 2 MapFN pattern for intermediate output concatenation".format(loop_name))
+                    continue
+
+                # skip StopGradient node if it exists between While loop output port and TensorListStack operation
+                stack_node = skip_nodes_by_condition(stack_node, lambda x: x.has_and_set('identity'), True)
+                stack_node = stack_node if stack_node.op == 'TensorListStack' else None
+                if stack_node is None:
+                    log.info("A sub-graph around the loop node {} does not match "
+                             "TensorFlow 2 MapFN pattern for intermediate output concatenation".format(loop_name))
+                    continue
+
+                external_match = {'while': loop_node,
+                                  'reserve': reserve_node,
+                                  'stack': stack_node}
+                # check that back edges connect the Parameter node (the container with intermediate output
+                # results) and the concatenation result produced by the TensorListSetItem node
+                if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
+                                         internal_match['container'].internal_layer_id) and \
+                        Loop.back_edge_exists(loop_node.back_edges,
+                                              internal_match['increment_iteration_result'].internal_layer_id,
+                                              internal_match['current_iteration'].internal_layer_id):
+                    MapFNOutputConcatenation.transform_map_fn_output_concatenation(external_match, internal_match)
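
To make the rewrite concrete, here is a sketch in plain NumPy (not OpenVINO API; the function name is hypothetical) of the semantics that the axis attribute gives a Loop port after both transformations run: the Loop slices the input along the axis once per iteration and concatenates the per-iteration results along the same axis, so the TensorList* operations are no longer needed.

import numpy as np

def loop_with_axis(body, x, axis=0):
    # input port with axis set: the Loop takes one slice per iteration
    slices = np.split(x, x.shape[axis], axis=axis)
    # output port with axis set: intermediate results are concatenated back
    results = [body(np.squeeze(s, axis=axis)) for s in slices]
    return np.stack(results, axis=axis)

# behaves like tf.map_fn(body, x) but without TensorListFromTensor/TensorListStack
print(loop_with_axis(lambda row: row / np.linalg.norm(row), np.arange(12.0).reshape(3, 4)))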
11 changes: 3 additions & 8 deletions model-optimizer/extensions/front/tf/ObjectDetectionAPI.py
@@ -11,7 +11,7 @@
 from extensions.front.split_normalizer import SqueezeAxis
 from extensions.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
 from extensions.front.tf.FakeQuantWithMinMaxVars import FakeQuantWithMinMaxVarsToQuantize
-from extensions.front.tf.KerasRNNTransformation import KerasRNNInputSlicing, KerasRNNOutputConcatenation
+from extensions.front.tf.MapFNTransformation import MapFNInputSlicing, MapFNOutputConcatenation
 from extensions.front.tf.TFSliceToSlice import TFSliceToSliceReplacer
 from extensions.front.tf.pad_tf_to_pad import PadTFToPad
 from extensions.middle.InsertLayoutPropagationTransposes import mark_as_correct_data_layout, \
@@ -31,6 +31,7 @@
 from mo.front.common.replacement import FrontReplacementPattern
 from mo.front.extractor import output_user_data_repack, add_output_ops
 from mo.front.subgraph_matcher import SubgraphMatch
+from mo.front.tf.custom_subgraph_call import skip_nodes_by_condition
 from mo.front.tf.graph_utils import add_activation_function_after_node, add_convolution_to_swap_xy_coordinates, \
     mark_squeeze_reshape_concat_before_detection_output, add_fake_background_loc, create_op_node_with_second_input, \
     create_op_with_const_inputs
@@ -346,12 +347,6 @@ def swap_weights_xy(graph: Graph, nodes: list):
         insert_weights_swap_xy_sub_graph(graph, m.in_port(1).get_connection())
 
 
-def skip_nodes_by_condition(current_node: Node, condition: callable):
-    while condition(current_node):
-        current_node = current_node.in_node()
-    return current_node
-
-
 def calculate_shape_keeping_aspect_ratio(height: int, width: int, min_size: int, max_size: int,
                                          pad_to_max_dimension: bool = False):
     """
@@ -529,7 +524,7 @@ def run_before(self):
         # is removed during removing nodes from the DO sub-graph so the first input to Transpose is missing which
         # results in TransposeOrderNormalizer transformation failure.
         return [Pack, TransposeOrderNormalizer, PadTFToPad, SqueezeAxis, TFSliceToSliceReplacer,
-                KerasRNNOutputConcatenation, KerasRNNInputSlicing]
+                MapFNOutputConcatenation, MapFNInputSlicing]
 
     def find_and_replace_pattern(self, graph: Graph):
         pass
21 changes: 21 additions & 0 deletions model-optimizer/extensions/ops/loop.py
@@ -61,6 +61,27 @@ def get_body_node_by_internal_id(loop_node: Node, internal_id: int):
             'Expected 0 or 1 node with `internal_layer_id`={}, {} found'.format(internal_id, len(suitable_nodes))
         return suitable_nodes[0] if len(suitable_nodes) == 1 else None
 
+    @staticmethod
+    def get_external_nodes_by_internal_id(loop_node: Node, internal_layer_id: int) -> list:
+        """
+        Get a list of nodes from the main graph that are connected with a node with internal_layer_id
+        from the body graph
+        :param loop_node: The Loop node
+        :param internal_layer_id: Internal layer ID of the node in the body graph
+        :return: A list of external nodes (from the main graph) that are connected with a node with
+                 internal_layer_id from the body graph
+        """
+        for map_item in loop_node.input_port_map:
+            if map_item['internal_layer_id'] == internal_layer_id \
+                    and loop_node.is_in_port_connected(map_item['external_port_id']):
+                return [loop_node.in_port(map_item['external_port_id']).get_source().node]
+        for map_item in loop_node.output_port_map:
+            if map_item['internal_layer_id'] == internal_layer_id \
+                    and loop_node.is_out_port_connected(map_item['external_port_id']):
+                return [dest.node for dest in loop_node.out_port(map_item['external_port_id']).get_destinations()]
+        return []
+
     @staticmethod
     def updated_body_parameters_shape(loop_node: Node):
         """
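
A self-contained sketch of the lookup logic with plain dicts instead of MO Node objects (the map contents are hypothetical). It mirrors the search order of get_external_nodes_by_internal_id: input_port_map first, then output_port_map, with an empty result when the id is not mapped anywhere:

input_port_map = [{'internal_layer_id': 3, 'external_port_id': 1}]
output_port_map = [{'internal_layer_id': 7, 'external_port_id': 0}]

def find_external_port_id(internal_layer_id):
    for map_item in input_port_map:
        if map_item['internal_layer_id'] == internal_layer_id:
            return 'in', map_item['external_port_id']
    for map_item in output_port_map:
        if map_item['internal_layer_id'] == internal_layer_id:
            return 'out', map_item['external_port_id']
    return None

assert find_external_port_id(3) == ('in', 1)   # e.g. the sliced 'tensor_list' input
assert find_external_port_id(7) == ('out', 0)  # e.g. the 'concatenation_result' output
assert find_external_port_id(9) is None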
10 changes: 10 additions & 0 deletions model-optimizer/mo/front/tf/custom_subgraph_call.py
@@ -149,3 +149,13 @@ def set_tf_custom_call_node_attrs(node_attrs: dict):
     node_attrs['op'] = 'TFCustomSubgraphCall'
     node_attrs['infer'] = tf_subgraph_infer
     node_attrs['kind'] = 'op'
+
+
+def skip_nodes_by_condition(current_node: Node, condition: callable, forward: bool = False):
+    if forward:
+        while condition(current_node):
+            current_node = current_node.out_node()
+    else:
+        while condition(current_node):
+            current_node = current_node.in_node()
+    return current_node
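
A self-contained usage sketch on a toy node chain (MockNode is hypothetical; in the Model Optimizer the nodes come from a Graph object). It shows how MapFNOutputConcatenation walks forward past a StopGradient wrapper to reach the TensorListStack node:

from mo.front.tf.custom_subgraph_call import skip_nodes_by_condition

class MockNode:
    def __init__(self, op, next_node=None):
        self.op = op
        self.next_node = next_node

    def has_and_set(self, attr):
        # StopGradient is extracted as an identity-like operation, so the
        # transformation skips nodes with the 'identity' attribute set
        return attr == 'identity' and self.op == 'StopGradient'

    def out_node(self):
        return self.next_node

stack = MockNode('TensorListStack')
stop_gradient = MockNode('StopGradient', next_node=stack)

node = skip_nodes_by_condition(stop_gradient, lambda x: x.has_and_set('identity'), forward=True)
assert node.op == 'TensorListStack'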
