Skip to content

Commit

Permalink
Transpose FQ optimization (#5763)
Browse files Browse the repository at this point in the history
* Transpose FQ optimization

* Tests added
  • Loading branch information
Evgenya Stepyreva authored May 25, 2021
1 parent d509fe6 commit 74293c5
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
9 changes: 8 additions & 1 deletion model-optimizer/extensions/back/MatMulNormalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np

from extensions.back.TransposeReduceFusing import TransposeReduce
from extensions.ops.transpose import Transpose
from mo.back.replacement import BackReplacementPattern
from mo.front.caffe.extractors.utils import get_canonical_axis_index
Expand Down Expand Up @@ -86,7 +87,8 @@ class PullTransposeThroughFQUp(BackReplacementPattern):
force_clean_up = True

def run_after(self):
    """Return the transformations that this pass is ordered after.

    Both ``MatMulConstTransposesExtraction`` and ``TransposeReduce`` are listed,
    so this pass is scheduled only once they have had a chance to run.
    """
    # in case FQ->Transpose->Reduce we should first try to optimize out Transpose
    return [MatMulConstTransposesExtraction, TransposeReduce]

@staticmethod
def pattern():
Expand All @@ -105,6 +107,11 @@ def pattern():
@staticmethod
def replace_pattern(graph: Graph, match: dict):
fq = match['fq']

if len(fq.out_port(0).get_destinations()) > 1:
# FQ should have only one child -- Transpose for optimization
return

transpose = match['transpose']
name = fq.soft_get('name', fq.id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
import unittest
from argparse import Namespace

import numpy as np
from generator import generate, generator

from extensions.back.MatMulNormalizer import SmartReshape_HC_Reshape_MatMul
from extensions.back.MatMulNormalizer import SmartReshape_HC_Reshape_MatMul, PullTransposeThroughFQUp
from extensions.ops.MatMul import MatMul
from extensions.ops.fakequantize import FakeQuantize
from extensions.ops.transpose import Transpose
from mo.front.common.partial_infer.utils import int64_array
from mo.ops.reshape import Reshape
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \
result, connect
result, connect, connect_data
from unit_tests.utils.graph import regular_op_with_empty_data as op_with_empty_data


Expand Down Expand Up @@ -95,3 +98,73 @@ def test_reshape_on_the_B_input(self,

(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)


class FQTransposePullerTest(unittest.TestCase):
    """Unit tests for ``PullTransposeThroughFQUp`` on FakeQuantize -> Transpose sub-graphs."""

    @staticmethod
    def _fq_range_edges():
        """Return the edges wiring the range constants il/ih/ol/oh into FQ input ports 1..4."""
        return [
            edge
            for port, const_name in enumerate(('il', 'ih', 'ol', 'oh'), start=1)
            for edge in connect(const_name, '{}:FQ'.format(port))
        ]

    def nodes(self, input_shape, transpose_shape, fq_shape):
        """Assemble the node-attribute dictionary shared by all graphs in this test case.

        :param input_shape: shape assigned to the 'input' Parameter output
        :param transpose_shape: shape assigned to the 'transpose' output
        :param fq_shape: shape assigned to the 'FQ' output (also used for 'relu')
        :return: dict of node attributes suitable for ``build_graph``
        """
        attrs = {}
        attrs.update(regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter', 'op': 'Parameter'}))

        # FakeQuantize range constants: input low/high and output low/high
        for const_name, limit in (('il', 0), ('ih', 255), ('ol', 0), ('oh', 255)):
            attrs.update(valued_const_with_data(const_name, np.array([[[[limit]]]])))

        fq_attrs = {'type': 'FakeQuantize', 'op': 'FakeQuantize', 'infer': FakeQuantize.infer}
        attrs.update(regular_op_with_shaped_data('FQ', fq_shape, fq_attrs))

        attrs.update(valued_const_with_data('order', int64_array([0, 2, 3, 1])))
        transpose_attrs = {'type': 'Transpose', 'op': 'Transpose', 'infer': Transpose.infer}
        attrs.update(regular_op_with_shaped_data('transpose', transpose_shape, transpose_attrs))

        # 'relu' is only wired in by the negative test as a second consumer of FQ
        attrs.update(regular_op_with_shaped_data('relu', fq_shape, {'type': 'Relu', 'op': 'Relu'}))

        attrs.update(result())
        return attrs

    def test_positive(self):
        """FQ feeding only a Transpose: the reference graph has Transpose placed before FQ."""
        source_edges = [
            *connect('input', '0:FQ'),
            *self._fq_range_edges(),
            *connect('FQ:0', '0:transpose'),
            *connect('order:0', '1:transpose'),
            *connect('transpose:0', 'output'),
        ]
        graph = build_graph(
            nodes_attrs=self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224]),
            edges=source_edges,
            nodes_with_edges_only=True,
        )
        PullTransposeThroughFQUp().find_and_replace_pattern(graph)
        graph.clean_up()

        reference_edges = [
            *connect('input', '0:transpose'),
            *connect('order:0', '1:transpose'),
            *connect('transpose', '0:FQ'),
            *self._fq_range_edges(),
            *connect('FQ:0', 'output'),
        ]
        graph_ref = build_graph(
            nodes_attrs=self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 224, 224, 3]),
            edges=reference_edges,
            nodes_with_edges_only=True,
        )

        graphs_match, message = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(graphs_match, message)

    def test_negative(self):
        """FQ output also wired to 'relu': the graph must compare equal to its pre-transformation copy."""
        source_edges = [
            *connect('input', '0:FQ'),
            *self._fq_range_edges(),
            *connect('FQ:0', '0:transpose'),
            *connect_data('FQ:0', 'relu'),
            *connect('order:0', '1:transpose'),
            *connect('transpose:0', 'output'),
        ]
        graph = build_graph(
            nodes_attrs=self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224]),
            edges=source_edges,
            nodes_with_edges_only=True,
        )
        # snapshot taken before the transformation serves as the reference
        graph_ref = graph.copy()
        PullTransposeThroughFQUp().find_and_replace_pattern(graph)

        graphs_match, message = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(graphs_match, message)

0 comments on commit 74293c5

Please sign in to comment.