Skip to content

Commit

Permalink
[ MO ] Complete weights layout permutation (#2299)
Browse files Browse the repository at this point in the history
* MO TF: FQPerChannel extractor

* [ MO ] Complete weights layout permutation

* removed deleted file out of BOM

* Bring back stashed changes

* Skip if no weights permutation

* Conditional permutation

* Comments
  • Loading branch information
Evgenya Stepyreva authored Sep 18, 2020
1 parent 8dcff4a commit cd39138
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 204 deletions.
1 change: 0 additions & 1 deletion model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,6 @@ extensions/middle/UnsqueezeTileReshapeBlockToInterpolate.py
extensions/middle/UpsampleToResample.py
extensions/middle/UselessMerge.py
extensions/middle/UselessSplitEraser.py
extensions/middle/wights_permute_normalizer.py
extensions/ops/__init__.py
extensions/ops/accum.py
extensions/ops/activation_ops.py
Expand Down
18 changes: 18 additions & 0 deletions model-optimizer/extensions/front/tf/FakeQuantWithMinMaxVars_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,21 @@ def extract(cls, node):
'narrow_range': narrow_range, 'num_bits': num_bits})

return cls.enabled


class FakeQuantWithMinMaxVarsPerChannelExtractor(FrontExtractorOp):
    """Front extractor for the TF FakeQuantWithMinMaxVarsPerChannel operation."""
    op = 'FakeQuantWithMinMaxVarsPerChannel'
    enabled = True

    @classmethod
    def extract(cls, node):
        pb_attrs = node.pb.attr
        narrow_range = pb_attrs['narrow_range'].b
        num_bits = pb_attrs['num_bits'].i
        # Quantization level count: 2^num_bits, reduced by one when the narrow range is used.
        levels = 2 ** num_bits - int(narrow_range)

        # The node is prepared for a later conversion to the FakeQuantize op;
        # since input reconnection is still required, neither an infer function
        # nor a type attribute is set here.
        Op.update_node_stat(node, {'op': 'FakeQuantWithMinMaxVars', 'levels': levels,
                                   'narrow_range': narrow_range, 'num_bits': num_bits})

        return cls.enabled
14 changes: 9 additions & 5 deletions model-optimizer/extensions/middle/ApplyPermutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from extensions.middle.pass_separator import PostMiddleStart
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.graph.perm_inputs import get_node_with_permutation
from mo.graph.port import Port
from mo.middle.replacement import MiddleReplacementPattern
from mo.utils.error import Error
Expand All @@ -47,6 +48,7 @@ def find_and_replace_pattern(self, graph: Graph):
self.permute_op_nodes_attrs(graph)
self.shape_of_sub_graph_reinference(graph)
self.permute_input_data(graph)
graph.graph['layout'] = 'NCHW'

@staticmethod
def merge_nodes_permutations(graph: Graph):
Expand Down Expand Up @@ -94,7 +96,8 @@ def merge_nodes_permutations(graph: Graph):
def permute_data_nodes_attrs(graph: Graph):
# Iterate over all data nodes and apply permutation if exists
for node in graph.get_data_nodes():
if not node.has_valid('permutation'):
if not node.has_valid('permutation') or \
all([attrs.get('input_permutation', False) for u, v, attrs in graph.out_edges(node.id, data=True)]):
continue

if len(
Expand Down Expand Up @@ -126,8 +129,6 @@ def permute_op_nodes_attrs(graph: Graph):

@staticmethod
def permute_input_data(graph: Graph):
if graph.graph['layout'] != 'NHWC':
return
for node in graph.get_op_nodes():
input_permutations = [(in_port, edge_attrs['input_permutation']) for in_port, edge_attrs in
node.in_edges().items() if edge_attrs.get('input_permutation') is not None]
Expand All @@ -136,9 +137,12 @@ def permute_input_data(graph: Graph):
direction, port = port_info.split(':')
port = int(port)
port_to_check = node.in_port(port) if direction == 'input' else node.out_port(port)
if not is_input_data_in_correct_layout(node, in_port) and len(port_to_check.data.get_shape()) >= 4:
permutation_data_node = get_node_with_permutation(node, port_info)

if permutation_data_node.has_and_set('permutation') and \
not is_input_data_in_correct_layout(node, in_port) and \
len(port_to_check.data.get_shape()) >= 4:
permutation(node, port_info, in_port)
graph.graph['layout'] = 'NCHW'

@staticmethod
def shape_of_sub_graph_reinference(graph: Graph):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,6 @@ def get_next_in_ports(in_port: Port) -> Set[Port]:
next_in_ports.update(out_port.get_destinations())
return next_in_ports

def mark_node_as_in_correct_layout_by_in_port(self, in_port):
next_in_ports = self.get_next_in_ports(in_port)
in_port.__setattr__('input_permutation', None)
mark_input_as_in_correct_layout(in_port.node, in_port.idx)
for port in next_in_ports:
mark_output_as_in_correct_layout(port.get_source().node, port.get_source().idx)

def find_shape_subgraph_endpoints(self, out_ports: List[Port], visited: set = None,
action: callable = None) -> Set[Port]:
"""
Expand Down Expand Up @@ -108,8 +101,7 @@ def find_and_replace_pattern(self, graph: Graph):
shape.out_port(0).get_connection().insert_node(gather)

# 2. Inserting Gather/Transpose to NC* format
shape_sub_graph_end_points = self.find_shape_subgraph_endpoints(
[shape.out_port(0) for shape in shape_ops], None, self.mark_node_as_in_correct_layout_by_in_port)
shape_sub_graph_end_points = self.find_shape_subgraph_endpoints([shape.out_port(0) for shape in shape_ops])
for in_port in shape_sub_graph_end_points:
name = in_port.node.soft_get('name', in_port.node.id)
shape = in_port.data.get_shape()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
import logging as log
from collections import deque

from typing import Set

from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \
mark_as_correct_data_layout
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
from extensions.middle.pass_separator import PostMiddleStart
from mo.graph.graph import Graph, Node
from mo.graph.port import Port
from mo.middle.replacement import MiddleReplacementPattern


Expand Down Expand Up @@ -121,3 +125,88 @@ def find_and_replace_pattern(self, graph: Graph):
for visited_node in marked_nodes:
mark_as_correct_data_layout(visited_node)
visited_node['nchw_layout'] = True

for node in self.get_ports_and_nodes_on_weights(graph)[1]:
mark_as_correct_data_layout(node)
node['nchw_layout'] = True
if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
node.out_node()['nchw_layout'] = True

for node in self.get_ports_and_nodes_on_shape_subgraphs(graph)[1]:
mark_as_correct_data_layout(node)
node['nchw_layout'] = True
if node.soft_get('type') == 'Const': # WA for Const op deletion during clean_up
node.out_node()['nchw_layout'] = True

@staticmethod
def walk_up_from_in_ports_to_out_ports(in_ports: Set[Port], out_ports: Set[Port],
                                       visited_ports: Set[Port] = None, visited_nodes: Set[Node] = None):
    """
    Returns all intermediate ports and nodes of such a sub-graph:

        out_ports
        |       |
        \/     \/
         .  .  .
        |       |
        \/     \/
        in_ports

    The search is a BFS that walks upwards (from consumers to producers),
    starting at `in_ports` and stopping whenever a port from `out_ports`
    is reached.

    :param in_ports: input ports the walk starts from (the bottom of the sub-graph)
    :param out_ports: source ports at which the walk stops (the top of the sub-graph)
    :param visited_ports: optional pre-seeded set of already visited ports (accumulated into)
    :param visited_nodes: optional pre-seeded set of already visited nodes (accumulated into)
    :return: tuple (visited_ports, visited_nodes) covering the traversed sub-graph
    """
    # NOTE: defaults are None (not set()) on purpose — a mutable default would
    # be shared across calls.
    if visited_ports is None:
        visited_ports = set()
    if visited_nodes is None:
        visited_nodes = set()

    deque_of_in_ports = deque(in_ports)
    while len(deque_of_in_ports):
        in_port = deque_of_in_ports.popleft()
        source_node = in_port.get_source().node
        if in_port in visited_ports:  # do not check visited_nodes as search is based on ports
            continue
        # mark both ends of the edge (consumer in_port and producer out_port) as seen
        visited_ports.update({in_port, in_port.get_source()})
        if in_port.get_source() in out_ports:  # reached source marked to stop the search
            # a stop-source with no input ports (e.g. Constants and Parameters)
            # would never be enqueued, so record the node itself here
            if not len(in_port.get_source().node.in_ports()):  # for Constants and Parameters to be visited
                visited_nodes.add(in_port.get_source().node)
            continue
        # continue walking up through every connected input of the producer node
        deque_of_in_ports.extend([port for port in source_node.in_ports().values() if not port.disconnected()])
        visited_nodes.add(source_node)
    return visited_ports, visited_nodes

@staticmethod
def get_ports_and_nodes_on_weights(graph):
    """
    Collects the ports and nodes of all weights sub-graphs: paths that start at a
    Constant/Parameter/ShapeOf output and end at the weights input port of a
    convolution-like layer. Returns a (ports, nodes) pair of sets.
    """
    def get_weights_port_index(node):
        # a layer may override its weights input index; port 1 is the default
        return node.weights_index if node.has_valid('weights_index') else 1

    # every value of the original mapping was the same accessor, so a plain
    # list of weighted layer types is enough
    weighted_types = ['Convolution', 'DeformableConvolution', 'Deconvolution', 'BinaryConvolution']

    # collect all input ports with weights
    weight_ports = set()
    start_ports = set()
    for node in graph.get_op_nodes():
        node_type = node.soft_get('type', 'unknown')
        if node_type not in weighted_types:
            if node_type in ['Const', 'Parameter', 'ShapeOf']:
                start_ports.add(node.out_port(0))
            continue
        weight_port_idx = get_weights_port_index(node)
        assert node.is_in_port_connected(weight_port_idx), \
            'Unexpected port configuration of {} node with name=`{}`'.format(node_type,
                                                                             node.soft_get('name', node.id))
        weight_ports.add(node.in_port(weight_port_idx))

    # collect all sub-graphs that start with Constant/Parameter/ShapeOf and end at in_port as weights
    return MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(weight_ports, start_ports)

@staticmethod
def get_ports_and_nodes_on_shape_subgraphs(graph):
    """
    Collects the ports and nodes of all shape-computing sub-graphs: everything
    between a ShapeOf output and the shape sub-graph end points found by
    LayoutChangeForConstantShapePaths. Returns a (ports, nodes) pair of sets.
    """
    # Query the ShapeOf nodes once and reuse the port list; the original code
    # iterated graph.get_op_nodes(type='ShapeOf') twice to build the same ports.
    shape_out_ports = [shape_of.out_port(0) for shape_of in graph.get_op_nodes(type='ShapeOf')]
    shape_sources = set(shape_out_ports)
    end_points = LayoutChangeForConstantShapePaths().find_shape_subgraph_endpoints(shape_out_ports)
    ports, nodes = MarkSubGraphsWithCorrectLayout.walk_up_from_in_ports_to_out_ports(end_points, shape_sources)
    return ports, nodes
15 changes: 1 addition & 14 deletions model-optimizer/extensions/middle/quantize_fuses.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,4 @@ def find_and_replace_pattern(self, graph: Graph):
port.get_source().connect(fuse_node_duplicate.in_port(idx))
fuse_node_duplicate.infer(fuse_node_duplicate)

first_port_fusion = False

if 'permutation' in quantize_node.in_edge(0):
permutation = quantize_node.in_edge(0)['permutation']
if permutation is None:
continue

perm_rank = permutation.perm.size

if not all([quantize_node.in_port(i).data.get_shape().size == perm_rank for i in range(1, 5)]):
continue

for i in range(1, 5):
quantize_node.in_edge(i)['permutation'] = permutation
first_port_fusion = False
118 changes: 0 additions & 118 deletions model-optimizer/extensions/middle/weights_permute_normalizer_test.py

This file was deleted.

51 changes: 0 additions & 51 deletions model-optimizer/extensions/middle/wights_permute_normalizer.py

This file was deleted.

Loading

0 comments on commit cd39138

Please sign in to comment.