From 74293c54df3c33d785d330e9a020cd9f2468907d Mon Sep 17 00:00:00 2001
From: Evgenya Stepyreva
Date: Tue, 25 May 2021 16:42:21 +0300
Subject: [PATCH] Transpose FQ optimization (#5763)

* Transpose FQ optimization

* Tests added
---
 .../extensions/back/MatMulNormalizer.py      |  9 ++-
 .../extensions/back/MatMulNormalizer_test.py | 77 ++++++++++++++++++-
 2 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/model-optimizer/extensions/back/MatMulNormalizer.py b/model-optimizer/extensions/back/MatMulNormalizer.py
index 76cf9bfd173b09..fb860f86b5678e 100644
--- a/model-optimizer/extensions/back/MatMulNormalizer.py
+++ b/model-optimizer/extensions/back/MatMulNormalizer.py
@@ -3,6 +3,7 @@
 
 import numpy as np
 
+from extensions.back.TransposeReduceFusing import TransposeReduce
 from extensions.ops.transpose import Transpose
 from mo.back.replacement import BackReplacementPattern
 from mo.front.caffe.extractors.utils import get_canonical_axis_index
@@ -86,7 +87,8 @@ class PullTransposeThroughFQUp(BackReplacementPattern):
     force_clean_up = True
 
     def run_after(self):
-        return [MatMulConstTransposesExtraction]
+        # in case FQ->Transpose->Reduce we should first try to optimize out Transpose
+        return [MatMulConstTransposesExtraction, TransposeReduce]
 
     @staticmethod
     def pattern():
@@ -105,6 +107,11 @@ def pattern():
     @staticmethod
     def replace_pattern(graph: Graph, match: dict):
         fq = match['fq']
+
+        if len(fq.out_port(0).get_destinations()) > 1:
+            # FQ should have only one child -- Transpose for optimization
+            return
+
         transpose = match['transpose']
         name = fq.soft_get('name', fq.id)
 
diff --git a/model-optimizer/unit_tests/extensions/back/MatMulNormalizer_test.py b/model-optimizer/unit_tests/extensions/back/MatMulNormalizer_test.py
index 1d019526fcc46e..30805aa19aa84b 100644
--- a/model-optimizer/unit_tests/extensions/back/MatMulNormalizer_test.py
+++ b/model-optimizer/unit_tests/extensions/back/MatMulNormalizer_test.py
@@ -4,15 +4,18 @@
 import unittest
 from argparse import Namespace
 
+import numpy as np
 from generator import generate, generator
 
-from extensions.back.MatMulNormalizer import SmartReshape_HC_Reshape_MatMul
+from extensions.back.MatMulNormalizer import SmartReshape_HC_Reshape_MatMul, PullTransposeThroughFQUp
 from extensions.ops.MatMul import MatMul
+from extensions.ops.fakequantize import FakeQuantize
+from extensions.ops.transpose import Transpose
 from mo.front.common.partial_infer.utils import int64_array
 from mo.ops.reshape import Reshape
 from mo.utils.ir_engine.compare_graphs import compare_graphs
 from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \
-    result, connect
+    result, connect, connect_data
 from unit_tests.utils.graph import regular_op_with_empty_data as op_with_empty_data
 
 
@@ -95,3 +98,73 @@ def test_reshape_on_the_B_input(self,
 
         (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
         self.assertTrue(flag, resp)
+
+
+class FQTransposePullerTest(unittest.TestCase):
+    def nodes(self, input_shape, transpose_shape, fq_shape):
+        return {
+            **regular_op_with_shaped_data('input', input_shape, dict(type='Parameter', op='Parameter')),
+            **valued_const_with_data('il', np.array([[[[0]]]])),
+            **valued_const_with_data('ih', np.array([[[[255]]]])),
+            **valued_const_with_data('ol', np.array([[[[0]]]])),
+            **valued_const_with_data('oh', np.array([[[[255]]]])),
+            **regular_op_with_shaped_data('FQ', fq_shape, dict(type='FakeQuantize', op='FakeQuantize', infer=FakeQuantize.infer)),
+            **valued_const_with_data('order', int64_array([0, 2, 3, 1])),
+            **regular_op_with_shaped_data('transpose', transpose_shape, dict(type='Transpose', op='Transpose', infer=Transpose.infer)),
+            **regular_op_with_shaped_data('relu', fq_shape, dict(type='Relu', op='Relu')),
+
+            **result(),
+        }
+
+    def test_positive(self):
+        nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224])
+        edges = [
+            *connect('input', '0:FQ'),
+            *connect('il', '1:FQ'),
+            *connect('ih', '2:FQ'),
+            *connect('ol', '3:FQ'),
+            *connect('oh', '4:FQ'),
+            *connect('FQ:0', '0:transpose'),
+            *connect('order:0', '1:transpose'),
+            *connect('transpose:0', 'output'),
+        ]
+        graph = build_graph(nodes_attrs=nodes, edges=edges, nodes_with_edges_only=True)
+        PullTransposeThroughFQUp().find_and_replace_pattern(graph)
+        graph.clean_up()
+
+        nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 224, 224, 3])
+        edges = [
+            *connect('input', '0:transpose'),
+            *connect('order:0', '1:transpose'),
+            *connect('transpose', '0:FQ'),
+            *connect('il', '1:FQ'),
+            *connect('ih', '2:FQ'),
+            *connect('ol', '3:FQ'),
+            *connect('oh', '4:FQ'),
+            *connect('FQ:0', 'output'),
+        ]
+        graph_ref = build_graph(nodes_attrs=nodes, edges=edges, nodes_with_edges_only=True)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    def test_negative(self):
+        nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224])
+        edges = [
+            *connect('input', '0:FQ'),
+            *connect('il', '1:FQ'),
+            *connect('ih', '2:FQ'),
+            *connect('ol', '3:FQ'),
+            *connect('oh', '4:FQ'),
+            *connect('FQ:0', '0:transpose'),
+            *connect_data('FQ:0', 'relu'),
+            *connect('order:0', '1:transpose'),
+            *connect('transpose:0', 'output'),
+        ]
+        graph = build_graph(nodes_attrs=nodes, edges=edges, nodes_with_edges_only=True)
+        graph_ref = graph.copy()
+        PullTransposeThroughFQUp().find_and_replace_pattern(graph)
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+