From b9d67927fd797975373cbfabbfe3d616f138b387 Mon Sep 17 00:00:00 2001 From: Vladimir Gavrilov Date: Mon, 29 Jun 2020 12:49:29 +0300 Subject: [PATCH] Fixed deleting Transpose layers after and before Interpolate layers. (#1071) * Fixed deleting Transpose layers after and before Interpolate layers. * Added run_after() for the transformation InterpolateTranspose. * Some checks were moved from the replacement function to the pattern. * Added a check of the attribute 'axes' into the pattern. --- .../front/tf/InterpolateTransposes.py | 79 +++++++++++-------- 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/model-optimizer/extensions/front/tf/InterpolateTransposes.py b/model-optimizer/extensions/front/tf/InterpolateTransposes.py index 91616c0508d912..6c19a05d103004 100644 --- a/model-optimizer/extensions/front/tf/InterpolateTransposes.py +++ b/model-optimizer/extensions/front/tf/InterpolateTransposes.py @@ -16,49 +16,64 @@ import numpy as np from mo.front.common.partial_infer.utils import int64_array -from mo.front.tf.replacement import FrontReplacementFromConfigFileGeneral -from mo.graph.graph import Graph, Node -from mo.middle.pattern_match import find_pattern_matches, inverse_dict +from mo.front.tf.replacement import FrontReplacementSubgraph +from mo.graph.graph import Graph -class InterpolateTranspose(FrontReplacementFromConfigFileGeneral): +class InterpolateTranspose(FrontReplacementSubgraph): """ Delete useless transposes around ResizeNearestNeighbor op. In TF this op is working in NHWC layout, Resample in OpenVINO working in NCHW layout. If all graph has NCHW layout we should delete transposes around Resample: (NCHW->NHWC) -> Resample -> (NHWC -> NCHW) to run this op in NCHW without changes of layout. """ enabled = True - replacement_id = 'InterpolateTranspose' graph_condition = [lambda graph: graph.graph['layout'] == 'NCHW'] - pattern_nodes = [ - ('interpolate', {'kind': 'op', 'op': 'Interpolate'}), - ('transpose_1', {'kind': 'op', 'op': 'Transpose'}), - ('transpose_1_order', {'kind': 'op', 'op': 'Const', - 'value': lambda x: x is not None and np.array_equal(x, int64_array([0, 2, 3, 1]))}), - ('transpose_2', {'kind': 'op', 'op': 'Transpose'}), - ('transpose_2_order', {'kind': 'op', 'op': 'Const', - 'value': lambda x: x is not None and np.array_equal(x, int64_array([0, 3, 1, 2]))}), - ] - pattern_edges = [ - ('transpose_1', 'interpolate', {'in': 0, 'out': 0}), - ('transpose_1_order', 'transpose_1', {'in': 1, 'out': 0}), - ('interpolate', 'transpose_2', {'in': 0, 'out': 0}), - ('transpose_2_order', 'transpose_2', {'in': 1, 'out': 0}), - ] + def pattern(self): + return dict( + nodes=[ + ('interpolate', + { + 'kind': 'op', + 'op': 'Interpolate', + 'axes': lambda axes: axes is not None and np.array_equal(axes, int64_array([1, 2])) + }), + ('transpose_1', {'kind': 'op', 'op': 'Transpose'}), + ('transpose_1_order', + { + 'kind': 'op', + 'op': 'Const', + 'value': lambda value: value is not None and np.array_equal(value, int64_array([0, 2, 3, 1])) + }), + ('transpose_2', {'kind': 'op', 'op': 'Transpose'}), + ('transpose_2_order', + { + 'kind': 'op', + 'op': 'Const', + 'value': lambda value: value is not None and np.array_equal(value, int64_array([0, 3, 1, 2])) + }), + ], + edges=[ + ('transpose_1', 'interpolate', {'in': 0, 'out': 0}), + ('transpose_1_order', 'transpose_1', {'in': 1, 'out': 0}), + ('interpolate', 'transpose_2', {'in': 0, 'out': 0}), + ('transpose_2_order', 'transpose_2', {'in': 1, 'out': 0}), + ] + ) - def transform_graph(self, graph: Graph, replacement_descriptions: dict): - matches = find_pattern_matches(graph, self.pattern_nodes, self.pattern_edges) - for match in list(matches): - inverse_match = inverse_dict(match) - interpolate = Node(graph, inverse_match['interpolate']) - transpose_1 = Node(graph, inverse_match['transpose_1']) - transpose_2 = Node(graph, inverse_match['transpose_2']) + def run_after(self): + from extensions.front.InterpolateNormalizer import InterpolateNormalizer + return [InterpolateNormalizer] - # because we remove Transpose layers the ResizeNearestNeighbor should be updated for NCHW layout - interpolate.axes = int64_array([2, 3]) + def replace_sub_graph(self, graph: Graph, match: dict): + interpolate = match['interpolate'] + transpose_1 = match['transpose_1'] + transpose_2 = match['transpose_2'] - transpose_1.in_port(0).get_connection().set_destination(interpolate.in_port(0)) - transpose_2.out_port(0).get_connection().set_source(interpolate.out_port(0)) + # because we remove Transpose layers the ResizeNearestNeighbor should be updated for NCHW layout + interpolate.axes = int64_array([2, 3]) - graph.remove_nodes_from([transpose_1.id, transpose_2.id]) + transpose_1.in_port(0).get_connection().set_destination(interpolate.in_port(0)) + transpose_2.out_port(0).get_connection().set_source(interpolate.out_port(0)) + + graph.remove_nodes_from([transpose_1.id, transpose_2.id])