Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement transformation for TensorFlow 2 Map Function (aka tf.map_fn) #6836

2 changes: 1 addition & 1 deletion model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -438,11 +438,11 @@ extensions/front/tf/identity_ext.py
extensions/front/tf/identityN_to_identity.py
extensions/front/tf/InterpolateTransposes.py
extensions/front/tf/IteratorGetNext_ext.py
extensions/front/tf/KerasRNNTransformation.py
extensions/front/tf/log_softmax_ext.py
extensions/front/tf/LookupTableInsert_ext.py
extensions/front/tf/LoopCond_ext.py
extensions/front/tf/lrn_ext.py
extensions/front/tf/MapFNTransformation.py
extensions/front/tf/mask_rcnn_support.json
extensions/front/tf/mask_rcnn_support_api_v1.11.json
extensions/front/tf/mask_rcnn_support_api_v1.13.json
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,6 @@
from mo.ops.unsqueeze import Unsqueeze


def compute_input_port_idx(req_node: Node, loop_node: Node):
"""
Computes input port index by which requested node is passed to Loop node
:param req_node: a node for which to find input port index is requested
:param loop_node: a node that can receive input data from requested node by some input port
:return: input port index
"""
for destination in req_node.out_port(0).get_destinations():
if loop_node.id == destination.node.id:
return destination.idx
return None


def find_subgraph_match_to_pattern(graph: Graph, body_pattern: dict):
"""
Finds sub-graph matches corresponding pattern in graph
Expand All @@ -45,26 +32,18 @@ def find_subgraph_match_to_pattern(graph: Graph, body_pattern: dict):
return matches


class KerasRNNInputSlicing(FrontReplacementSubgraph):
class MapFNInputSlicing(FrontReplacementSubgraph):
"""
The transformation detects TensorFlow 2 pattern that corresponds to subsequent slicing of input.
It avoids TensorListFromTensor and TensorFlowGetItem operations and replaces the original sub-graph
by adding axis attribute for corresponding input port of Loop node.
The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
The transformation handles inputs slicing in While loop created by TensorFlow 2 Map Function primitive
(see tf.map_fn). It avoids TensorListFromTensor and TensorFlowGetItem operations and replaces the original
sub-graph by adding axis attribute in Loop node for slicing inputs.
The transformation is also applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
"""
enabled = True

def run_before(self):
return [WhileNormalize]

@staticmethod
def pattern(**kwargs):
return dict(
nodes=[('unstack', dict(op='TensorListFromTensor')),
('while', dict(op='Loop'))],
edges=[('unstack', 'while')]
)

@staticmethod
def get_body_pattern():
return dict(
Expand All @@ -84,7 +63,7 @@ def get_body_pattern():
)

@staticmethod
def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict):
def transform_map_fn_input_slicing(external_match: dict, internal_match: dict):
"""
Transforms TensorFlow 2 input slicing into use of axis attribute for input port of Loop node
:param external_match: a match used for handling a part of the main graph responsible for input slicing
Expand Down Expand Up @@ -115,51 +94,45 @@ def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict
# remove TensorListFromTensor and pass a tensor to Loop as is
unstack_node.out_port(0).get_connection().set_source(unstack_node.in_port(0).get_connection().get_source())

def replace_sub_graph(self, graph: Graph, external_match: dict):
loop_node = external_match['while']
body_graph = loop_node['body']
body_pattern = KerasRNNInputSlicing.get_body_pattern()
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

# a case of multiple matches is not handled since it is not clear how to select corresponding match
if len(internal_matches) == 1:
internal_match = internal_matches[0]
loop_node = external_match['while']
unstack_port_idx = compute_input_port_idx(external_match['unstack'], loop_node)
# check that back edges connect correct Parameter and Result nodes in the body
# check connections between body input ports and external inputs ports of Loop node
if Loop.back_edge_exists(loop_node.back_edges,
internal_match['increment_iteration_result'].internal_layer_id,
internal_match['current_iteration'].internal_layer_id) and \
Loop.inter_edge_exists(loop_node.input_port_map, unstack_port_idx,
internal_match['tensor_list'].internal_layer_id):
# only if inter-graph match passed it starts to process the sub-graph
KerasRNNInputSlicing.transform_keras_rnn_input_slicing(external_match, internal_match)


class KerasRNNOutputConcatenation(FrontReplacementSubgraph):
def find_and_replace_pattern(self, graph: Graph):
for loop_node in graph.get_op_nodes(op='Loop'):
body_graph = loop_node['body']
body_pattern = MapFNInputSlicing.get_body_pattern()
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

for internal_match in internal_matches:
# check if TensorListGetItem from the body graph is connected with TensorListFromTensor
# from the main graph. If yes, the transformation detects input slicing by this port
# and can use Loop axis attribute
unstack_node = Loop.get_external_nodes_by_internal_id(loop_node,
internal_match['tensor_list'].internal_layer_id)
rkazants marked this conversation as resolved.
Show resolved Hide resolved
unstack_node = unstack_node[0] if (len(unstack_node) == 1
and unstack_node[0].op == 'TensorListFromTensor') else None
if unstack_node is None:
continue
rkazants marked this conversation as resolved.
Show resolved Hide resolved

external_match = {'while': loop_node,
'unstack': unstack_node}
# check that back edges connect correct Parameter and Result nodes in the body
# check connections between body input ports and external inputs ports of Loop node
if Loop.back_edge_exists(loop_node.back_edges,
internal_match['increment_iteration_result'].internal_layer_id,
internal_match['current_iteration'].internal_layer_id):
MapFNInputSlicing.transform_map_fn_input_slicing(external_match, internal_match)


class MapFNOutputConcatenation(FrontReplacementSubgraph):
"""
The transformation detects TensorFlow 2 pattern that corresponds to concatenation of intermediate results
generated in each iteration of While operation.
It avoids TensorListReserve, TensorListStack, and TensorListSetItem operations and replaces the original sub-graph
by adding axis attribute for corresponding output port of Loop node.
The transformation is applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
The transformation handles inputs slicing in While loop created by TensorFlow 2 Map Function primitive
(see tf.map_fn). It avoids TensorListReserve, TensorListStack, and TensorListSetItem operations and replaces
the original sub-graph by adding axis attribute in Loop node for concatenation of intermediate output results.
The transformation is also applicable to TensorFlow 2 Keras Simple RNN, GRU, and LSTM layers.
"""
enabled = True

def run_before(self):
return [WhileNormalize]

@staticmethod
def pattern(**kwargs):
return dict(
nodes=[('reserve', dict(op='TensorListReserve')),
('while', dict(op='Loop')),
('stack', dict(op='TensorListStack'))],
edges=[('reserve', 'while'),
('while', 'stack')]
)

@staticmethod
def get_body_pattern():
return dict(
Expand All @@ -184,7 +157,7 @@ def get_body_pattern():
)

@staticmethod
def transform_keras_rnn_output_concatenation(external_match: dict, internal_match: dict):
def transform_map_fn_output_concatenation(external_match: dict, internal_match: dict):
"""
Transforms TensorFlow 2 output concatenation into use of axis attribute for output port of Loop node
:param external_match: a match used for handling a part of the main graph responsible for output concatenation
Expand Down Expand Up @@ -229,27 +202,46 @@ def transform_keras_rnn_output_concatenation(external_match: dict, internal_matc
const_true = Const(body_graph, {'value': np.array(True, dtype=np.bool)}).create_node()
exec_cond_node.in_port(0).get_connection().set_source(const_true.out_port(0))

def replace_sub_graph(self, graph: Graph, external_match: dict):
loop_node = external_match['while']
body_graph = loop_node['body']
body_pattern = KerasRNNOutputConcatenation.get_body_pattern()

internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

if len(internal_matches) == 1:
internal_match = internal_matches[0]
reserve_port_idx = compute_input_port_idx(external_match['reserve'], loop_node)
stack_port_idx = external_match['stack'].in_port(0).get_source().idx
# check that back edges connect correct Parameter and Result nodes in the body
# check connections between body input ports and external inputs ports of Loop node
# check connections between body output ports and external output ports of Loop node
if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
internal_match['container'].internal_layer_id) and \
Loop.back_edge_exists(loop_node.back_edges,
internal_match['increment_iteration_result'].internal_layer_id,
internal_match['current_iteration'].internal_layer_id) and \
Loop.inter_edge_exists(loop_node.input_port_map, reserve_port_idx,
internal_match['container'].internal_layer_id) and \
Loop.inter_edge_exists(loop_node.output_port_map, stack_port_idx,
internal_match['concatenation_result'].internal_layer_id):
KerasRNNOutputConcatenation.transform_keras_rnn_output_concatenation(external_match, internal_match)
def find_and_replace_pattern(self, graph: Graph):
for loop_node in graph.get_op_nodes(op='Loop'):
body_graph = loop_node['body']
body_pattern = MapFNOutputConcatenation.get_body_pattern()
internal_matches = find_subgraph_match_to_pattern(body_graph, body_pattern)

for internal_match in internal_matches:
# check if TensorListReserve from the main graph is connected with Parameter node from the body graph
# that is assigned for storing intermediate output results of While Loop. If yes, the transformation
# detects intermediate outputs concatentation by this port and can use Loop axis attribute
reserve_node = Loop.get_external_nodes_by_internal_id(loop_node,
internal_match['container'].internal_layer_id)
rkazants marked this conversation as resolved.
Show resolved Hide resolved
reserve_node = reserve_node[0] if (len(reserve_node) == 1 and
reserve_node[0].op == 'TensorListReserve') else None
if reserve_node is None:
continue
rkazants marked this conversation as resolved.
Show resolved Hide resolved
stack_node = Loop.get_external_nodes_by_internal_id(loop_node,
internal_match[
'concatenation_result'].internal_layer_id)
rkazants marked this conversation as resolved.
Show resolved Hide resolved
stack_node = stack_node if len(stack_node) == 1 else None

# skip StopGradient node if it exists between While loop output port and TensorListStack operation
if stack_node is None:
continue
if stack_node[0].op == 'StopGradient':
stack_node = [dest.node for dest in stack_node[0].out_port(0).get_destinations()]
rkazants marked this conversation as resolved.
Show resolved Hide resolved

stack_node = stack_node[0] if (len(stack_node) == 1 and
stack_node[0].op == 'TensorListStack') else None
if stack_node is None:
continue

external_match = {'while': loop_node,
'reserve': reserve_node,
'stack': stack_node}
# check that back edges connect Parameter node (or container with intermediate output results)
# and concatenation result produced by TensorListSetItem node
if Loop.back_edge_exists(loop_node.back_edges, internal_match['concatenation_result'].internal_layer_id,
internal_match['container'].internal_layer_id) and \
Loop.back_edge_exists(loop_node.back_edges,
internal_match['increment_iteration_result'].internal_layer_id,
internal_match['current_iteration'].internal_layer_id):
MapFNOutputConcatenation.transform_map_fn_output_concatenation(external_match, internal_match)
4 changes: 2 additions & 2 deletions model-optimizer/extensions/front/tf/ObjectDetectionAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from extensions.front.split_normalizer import SqueezeAxis
from extensions.front.tf.CropAndResizeReplacement import CropAndResizeReplacement
from extensions.front.tf.FakeQuantWithMinMaxVars import FakeQuantWithMinMaxVarsToQuantize
from extensions.front.tf.KerasRNNTransformation import KerasRNNInputSlicing, KerasRNNOutputConcatenation
from extensions.front.tf.MapFNTransformation import MapFNInputSlicing, MapFNOutputConcatenation
from extensions.front.tf.TFSliceToSlice import TFSliceToSliceReplacer
from extensions.front.tf.pad_tf_to_pad import PadTFToPad
from extensions.middle.InsertLayoutPropagationTransposes import mark_as_correct_data_layout, \
Expand Down Expand Up @@ -529,7 +529,7 @@ def run_before(self):
# is removed during removing nodes from the DO sub-graph so the first input to Transpose is missing which
# results in TransposeOrderNormalizer transformation failure.
return [Pack, TransposeOrderNormalizer, PadTFToPad, SqueezeAxis, TFSliceToSliceReplacer,
KerasRNNOutputConcatenation, KerasRNNInputSlicing]
MapFNOutputConcatenation, MapFNInputSlicing]

def find_and_replace_pattern(self, graph: Graph):
pass
Expand Down
21 changes: 21 additions & 0 deletions model-optimizer/extensions/ops/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,27 @@ def get_body_node_by_internal_id(loop_node: Node, internal_id: int):
'Expected 0 or 1 node with `internal_layer_id`={}, {} found'.format(internal_id, len(suitable_nodes))
return suitable_nodes[0] if len(suitable_nodes) == 1 else None

@staticmethod
def get_external_nodes_by_internal_id(loop_node: Node, internal_layer_id: int) -> list:
"""
Get a list of nodes from the main graph that are connected with a node with internal_layer_id
from the body graph

:param loop_node: The Loop node
:param internal_layer_id: Internal layer ID of the node in the body graph
:return: A list of external nodes (from the main graph) that are connected with a node with
internal_layer_id from the body graph
"""
for map_item in loop_node.input_port_map:
if map_item['internal_layer_id'] == internal_layer_id \
and not loop_node.in_port(map_item['external_port_id']).disconnected():
rkazants marked this conversation as resolved.
Show resolved Hide resolved
return [loop_node.in_port(map_item['external_port_id']).get_source().node]
for map_item in loop_node.output_port_map:
if map_item['internal_layer_id'] == internal_layer_id \
and not loop_node.out_port(map_item['external_port_id']).disconnected():
return [dest.node for dest in loop_node.out_port(map_item['external_port_id']).get_destinations()]
return []

@staticmethod
def updated_body_parameters_shape(loop_node: Node):
"""
Expand Down