Skip to content

Commit

Permalink
Fixed deleting Transpose layers after and before Interpolate layers. (o…
Browse files Browse the repository at this point in the history
…penvinotoolkit#1071)

* Fixed deleting Transpose layers after and before Interpolate layers.

* Added run_after() for the transformation InterpolateTranspose.

* Some checks were moved from the replacement function to the pattern.

* Added a check of the attribute 'axes' into the pattern.
  • Loading branch information
vgavrilo authored Jun 29, 2020
1 parent 182499c commit b9d6792
Showing 1 changed file with 47 additions and 32 deletions.
79 changes: 47 additions & 32 deletions model-optimizer/extensions/front/tf/InterpolateTransposes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,49 +16,64 @@
import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.replacement import FrontReplacementFromConfigFileGeneral
from mo.graph.graph import Graph, Node
from mo.middle.pattern_match import find_pattern_matches, inverse_dict
from mo.front.tf.replacement import FrontReplacementSubgraph
from mo.graph.graph import Graph


class InterpolateTranspose(FrontReplacementFromConfigFileGeneral):
class InterpolateTranspose(FrontReplacementSubgraph):
"""
Delete useless transposes around ResizeNearestNeighbor op. In TF this op is working in NHWC layout,
Resample in OpenVINO working in NCHW layout. If all graph has NCHW layout we should delete transposes around
Resample: (NCHW->NHWC) -> Resample -> (NHWC -> NCHW) to run this op in NCHW without changes of layout.
"""
enabled = True
replacement_id = 'InterpolateTranspose'
graph_condition = [lambda graph: graph.graph['layout'] == 'NCHW']

pattern_nodes = [
('interpolate', {'kind': 'op', 'op': 'Interpolate'}),
('transpose_1', {'kind': 'op', 'op': 'Transpose'}),
('transpose_1_order', {'kind': 'op', 'op': 'Const',
'value': lambda x: x is not None and np.array_equal(x, int64_array([0, 2, 3, 1]))}),
('transpose_2', {'kind': 'op', 'op': 'Transpose'}),
('transpose_2_order', {'kind': 'op', 'op': 'Const',
'value': lambda x: x is not None and np.array_equal(x, int64_array([0, 3, 1, 2]))}),
]
pattern_edges = [
('transpose_1', 'interpolate', {'in': 0, 'out': 0}),
('transpose_1_order', 'transpose_1', {'in': 1, 'out': 0}),
('interpolate', 'transpose_2', {'in': 0, 'out': 0}),
('transpose_2_order', 'transpose_2', {'in': 1, 'out': 0}),
]
def pattern(self):
return dict(
nodes=[
('interpolate',
{
'kind': 'op',
'op': 'Interpolate',
'axes': lambda axes: axes is not None and np.array_equal(axes, int64_array([1, 2]))
}),
('transpose_1', {'kind': 'op', 'op': 'Transpose'}),
('transpose_1_order',
{
'kind': 'op',
'op': 'Const',
'value': lambda value: value is not None and np.array_equal(value, int64_array([0, 2, 3, 1]))
}),
('transpose_2', {'kind': 'op', 'op': 'Transpose'}),
('transpose_2_order',
{
'kind': 'op',
'op': 'Const',
'value': lambda value: value is not None and np.array_equal(value, int64_array([0, 3, 1, 2]))
}),
],
edges=[
('transpose_1', 'interpolate', {'in': 0, 'out': 0}),
('transpose_1_order', 'transpose_1', {'in': 1, 'out': 0}),
('interpolate', 'transpose_2', {'in': 0, 'out': 0}),
('transpose_2_order', 'transpose_2', {'in': 1, 'out': 0}),
]
)

def transform_graph(self, graph: Graph, replacement_descriptions: dict):
matches = find_pattern_matches(graph, self.pattern_nodes, self.pattern_edges)
for match in list(matches):
inverse_match = inverse_dict(match)
interpolate = Node(graph, inverse_match['interpolate'])
transpose_1 = Node(graph, inverse_match['transpose_1'])
transpose_2 = Node(graph, inverse_match['transpose_2'])
def run_after(self):
from extensions.front.InterpolateNormalizer import InterpolateNormalizer
return [InterpolateNormalizer]

# because we remove Transpose layers the ResizeNearestNeighbor should be updated for NCHW layout
interpolate.axes = int64_array([2, 3])
def replace_sub_graph(self, graph: Graph, match: dict):
interpolate = match['interpolate']
transpose_1 = match['transpose_1']
transpose_2 = match['transpose_2']

transpose_1.in_port(0).get_connection().set_destination(interpolate.in_port(0))
transpose_2.out_port(0).get_connection().set_source(interpolate.out_port(0))
# because we remove Transpose layers the ResizeNearestNeighbor should be updated for NCHW layout
interpolate.axes = int64_array([2, 3])

graph.remove_nodes_from([transpose_1.id, transpose_2.id])
transpose_1.in_port(0).get_connection().set_destination(interpolate.in_port(0))
transpose_2.out_port(0).get_connection().set_source(interpolate.out_port(0))

graph.remove_nodes_from([transpose_1.id, transpose_2.id])

0 comments on commit b9d6792

Please sign in to comment.