diff --git a/model-optimizer/extensions/front/tf/InterpolateTransposes.py b/model-optimizer/extensions/front/tf/InterpolateTransposes.py index 91616c0508d912..6c19a05d103004 100644 --- a/model-optimizer/extensions/front/tf/InterpolateTransposes.py +++ b/model-optimizer/extensions/front/tf/InterpolateTransposes.py @@ -16,49 +16,64 @@ import numpy as np from mo.front.common.partial_infer.utils import int64_array -from mo.front.tf.replacement import FrontReplacementFromConfigFileGeneral -from mo.graph.graph import Graph, Node -from mo.middle.pattern_match import find_pattern_matches, inverse_dict +from mo.front.tf.replacement import FrontReplacementSubgraph +from mo.graph.graph import Graph -class InterpolateTranspose(FrontReplacementFromConfigFileGeneral): +class InterpolateTranspose(FrontReplacementSubgraph): """ Delete useless transposes around ResizeNearestNeighbor op. In TF this op is working in NHWC layout, Resample in OpenVINO working in NCHW layout. If all graph has NCHW layout we should delete transposes around Resample: (NCHW->NHWC) -> Resample -> (NHWC -> NCHW) to run this op in NCHW without changes of layout. """ enabled = True - replacement_id = 'InterpolateTranspose' graph_condition = [lambda graph: graph.graph['layout'] == 'NCHW'] - pattern_nodes = [ - ('interpolate', {'kind': 'op', 'op': 'Interpolate'}), - ('transpose_1', {'kind': 'op', 'op': 'Transpose'}), - ('transpose_1_order', {'kind': 'op', 'op': 'Const', - 'value': lambda x: x is not None and np.array_equal(x, int64_array([0, 2, 3, 1]))}), - ('transpose_2', {'kind': 'op', 'op': 'Transpose'}), - ('transpose_2_order', {'kind': 'op', 'op': 'Const', - 'value': lambda x: x is not None and np.array_equal(x, int64_array([0, 3, 1, 2]))}), - ] - pattern_edges = [ - ('transpose_1', 'interpolate', {'in': 0, 'out': 0}), - ('transpose_1_order', 'transpose_1', {'in': 1, 'out': 0}), - ('interpolate', 'transpose_2', {'in': 0, 'out': 0}), - ('transpose_2_order', 'transpose_2', {'in': 1, 'out': 0}), - ] + def pattern(self): + return dict( + nodes=[ + ('interpolate', + { + 'kind': 'op', + 'op': 'Interpolate', + 'axes': lambda axes: axes is not None and np.array_equal(axes, int64_array([1, 2])) + }), + ('transpose_1', {'kind': 'op', 'op': 'Transpose'}), + ('transpose_1_order', + { + 'kind': 'op', + 'op': 'Const', + 'value': lambda value: value is not None and np.array_equal(value, int64_array([0, 2, 3, 1])) + }), + ('transpose_2', {'kind': 'op', 'op': 'Transpose'}), + ('transpose_2_order', + { + 'kind': 'op', + 'op': 'Const', + 'value': lambda value: value is not None and np.array_equal(value, int64_array([0, 3, 1, 2])) + }), + ], + edges=[ + ('transpose_1', 'interpolate', {'in': 0, 'out': 0}), + ('transpose_1_order', 'transpose_1', {'in': 1, 'out': 0}), + ('interpolate', 'transpose_2', {'in': 0, 'out': 0}), + ('transpose_2_order', 'transpose_2', {'in': 1, 'out': 0}), + ] + ) - def transform_graph(self, graph: Graph, replacement_descriptions: dict): - matches = find_pattern_matches(graph, self.pattern_nodes, self.pattern_edges) - for match in list(matches): - inverse_match = inverse_dict(match) - interpolate = Node(graph, inverse_match['interpolate']) - transpose_1 = Node(graph, inverse_match['transpose_1']) - transpose_2 = Node(graph, inverse_match['transpose_2']) + def run_after(self): + from extensions.front.InterpolateNormalizer import InterpolateNormalizer + return [InterpolateNormalizer] - # because we remove Transpose layers the ResizeNearestNeighbor should be updated for NCHW layout - interpolate.axes = int64_array([2, 3]) + def replace_sub_graph(self, graph: Graph, match: dict): + interpolate = match['interpolate'] + transpose_1 = match['transpose_1'] + transpose_2 = match['transpose_2'] - transpose_1.in_port(0).get_connection().set_destination(interpolate.in_port(0)) - transpose_2.out_port(0).get_connection().set_source(interpolate.out_port(0)) + # because we remove Transpose layers the ResizeNearestNeighbor should be updated for NCHW layout + interpolate.axes = int64_array([2, 3]) - graph.remove_nodes_from([transpose_1.id, transpose_2.id]) + transpose_1.in_port(0).get_connection().set_destination(interpolate.in_port(0)) + transpose_2.out_port(0).get_connection().set_source(interpolate.out_port(0)) + + graph.remove_nodes_from([transpose_1.id, transpose_2.id])