[PTQ][Torch][KQV self attention] Align FQ placement between OV and Torch backend #2166

Merged
Changes from all commits
Commits (21)
4260614
Add softmax -> dropout -> mm <- non pattern pattern / add new names t…
daniil-lyakhov Sep 29, 2023
9698e6c
LayerNorm metatype was added to ignored metatypes for MinMax
daniil-lyakhov Oct 4, 2023
39d785b
Add GroupNorm to ignored MinMax metatypes
daniil-lyakhov Oct 5, 2023
4aa0811
Duplicates are removed from ignored torch patterns
daniil-lyakhov Oct 5, 2023
c6dcc08
Dropout removing pass is added to function
daniil-lyakhov Oct 6, 2023
48da94e
metatypes_to_ignore quantization propagation solver test
daniil-lyakhov Oct 9, 2023
42fa63b
Fix test_passes
daniil-lyakhov Oct 9, 2023
8abc6ea
get_inference_graph method is refactored
daniil-lyakhov Oct 11, 2023
bb90e81
Metrics update
daniil-lyakhov Oct 11, 2023
b4bb243
get_inference_graph fix
daniil-lyakhov Oct 11, 2023
c4225e6
Merge remote-tracking branch 'origin/develop' into dl/torch/patterns_…
daniil-lyakhov Oct 11, 2023
9837d2b
ptq_params test microfix
daniil-lyakhov Oct 11, 2023
d478e2b
Revert transform_to_inferece_graph function
daniil-lyakhov Oct 17, 2023
237ba50
Fix tests
daniil-lyakhov Oct 17, 2023
0638a35
Merge remote-tracking branch 'origin/develop' into dl/torch/patterns_…
daniil-lyakhov Oct 18, 2023
2ecb439
Metrics update
daniil-lyakhov Oct 18, 2023
a662180
Dropout removal original graph reference
daniil-lyakhov Oct 18, 2023
a9d99d0
Clean
daniil-lyakhov Oct 18, 2023
03c2b5e
function remove_dropout_node is refactored
daniil-lyakhov Oct 19, 2023
ce14175
Make passes return values
daniil-lyakhov Oct 19, 2023
5b962ab
Merge remote-tracking branch 'origin/develop' into dl/torch/patterns_…
daniil-lyakhov Oct 23, 2023
12 changes: 11 additions & 1 deletion nncf/common/graph/graph.py
@@ -597,12 +597,22 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph
attrs_edge = {}
u = u.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
v = v.replace(__RESERVED_DOT_CHARACTER, __CHARACTER_REPLACE_TO)
label = {}
if edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]:
label["parallel_input_port_ids"] = edge[NNCFGraph.PARALLEL_INPUT_PORT_IDS_ATTR]

if extended:
if edge[NNCFGraph.DTYPE_EDGE_ATTR] is Dtype.INTEGER:
attrs_edge["style"] = "dashed"
else:
attrs_edge["style"] = "solid"
attrs_edge["label"] = edge[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]
label["shape"] = edge[NNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR]

if label:
if "shape" in label and len(label) == 1:
attrs_edge["label"] = label["shape"]
else:
attrs_edge["label"] = ", ".join((f"{k}:{v}" for k, v in label.items()))
out_graph.add_edge(u, v, **attrs_edge)
return out_graph

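The change above composes the DOT edge label from the parallel input port ids and, in extended mode, the activation shape. A minimal standalone sketch of that label rule (an illustration only; `compose_edge_label` is not an NNCF helper):

```python
from typing import Any, Dict


def compose_edge_label(label: Dict[str, Any]) -> str:
    # A bare shape keeps the old plain-shape label; anything else is
    # rendered as comma-separated "key:value" pairs.
    if "shape" in label and len(label) == 1:
        return str(label["shape"])
    return ", ".join(f"{k}:{v}" for k, v in label.items())


print(compose_edge_label({"shape": [1, 3, 224, 224]}))
# [1, 3, 224, 224]
print(compose_edge_label({"parallel_input_port_ids": [2, 3], "shape": [1, 16]}))
# parallel_input_port_ids:[2, 3], shape:[1, 16]
```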
5 changes: 4 additions & 1 deletion nncf/quantization/algorithms/min_max/algorithm.py
@@ -505,7 +505,10 @@ def _get_quantization_target_points(
hw_patterns = PatternsManager.get_full_hw_pattern_graph(backend=backend, device=device, model_type=model_type)

inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph), self._backend_entity.shapeof_metatypes, self._backend_entity.read_variable_metatypes
deepcopy(nncf_graph),
self._backend_entity.shapeof_metatypes,
self._backend_entity.dropout_metatypes,
self._backend_entity.read_variable_metatypes,
)

quantizer_setup = self._get_quantizer_setup(nncf_graph, inference_nncf_graph, hw_patterns, ignored_patterns)
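With the new property wired in, the MinMax algorithm forwards the backend's Dropout metatypes to the graph pre-processing. A hedged call sketch restating the argument order (the `nncf_graph` and `backend_entity` objects are assumed to exist in the caller's scope):

```python
from copy import deepcopy

from nncf.quantization.passes import transform_to_inference_graph

inference_nncf_graph = transform_to_inference_graph(
    deepcopy(nncf_graph),                    # keep the original graph intact
    backend_entity.shapeof_metatypes,        # ShapeOf subgraphs to strip
    backend_entity.dropout_metatypes,        # new: Dropout nodes to bypass
    backend_entity.read_variable_metatypes,  # ReadValue-like inputs
)
```

The new argument is positional, so every caller has to be updated together with the backends.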
19 changes: 13 additions & 6 deletions nncf/quantization/algorithms/min_max/backend.py
@@ -51,23 +51,23 @@ def post_processing_metatypes(self) -> List[OperatorMetatype]:

@property
@abstractmethod
def shapeof_metatypes(self) -> List[OperatorMetatype]:
def conv_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific ShapeOf metatypes.
Property for the backend-specific Convolution metatypes.
"""

@property
@abstractmethod
def conv_metatypes(self) -> List[OperatorMetatype]:
def shapeof_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific Convolution metatypes.
Property for the backend-specific ShapeOf metatypes.
"""

@property
@abstractmethod
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
def dropout_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific metatypes for which overflow_fix is applicable.
Property for the backend-specific Dropout metatypes.
"""

@property
@@ -77,6 +77,13 @@ def read_variable_metatypes(self) -> List[OperatorMetatype]:
Property for the backend-specific metatypes that also can be interpreted as inputs (ReadValue).
"""

@property
@abstractmethod
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific metatypes for which overflow_fix is applicable.
"""

@property
@abstractmethod
def add_metatypes(self) -> List[OperatorMetatype]:
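Every backend now has to implement the added `dropout_metatypes` abstract property. A hedged sketch of the minimal implementation (the class name is illustrative; a real backend derives from the MinMax backend base class):

```python
from typing import List

from nncf.common.graph.operator_metatypes import OperatorMetatype


class SomeAlgoBackend:  # illustrative stand-in for a MinMax backend class
    @property
    def dropout_metatypes(self) -> List[OperatorMetatype]:
        # Backends whose graphs never contain Dropout nodes simply return [].
        return []
```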
20 changes: 12 additions & 8 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
@@ -56,10 +56,6 @@ def mat_mul_metatypes(self) -> List[OperatorMetatype]:
def post_processing_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXTopKMetatype, om.ONNXNonMaxSuppressionMetatype]

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXShapeMetatype]

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXConvolutionMetatype]
@@ -68,10 +64,6 @@ def conv_metatypes(self) -> List[OperatorMetatype]:
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXConvolutionMetatype, om.ONNXConvolutionTransposeMetatype, *MATMUL_METATYPES]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXAddLayerMetatype]
@@ -80,6 +72,18 @@ def add_metatypes(self) -> List[OperatorMetatype]:
def group_conv_metatypes(self) -> List[OperatorMetatype]:
return self.conv_metatypes

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXShapeMetatype]

@property
def dropout_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.ONNXConcatMetatype: self.overflow_fix_metatypes}
20 changes: 12 additions & 8 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
@@ -56,10 +56,6 @@ def mat_mul_metatypes(self) -> List[OperatorMetatype]:
def post_processing_metatypes(self) -> List[OperatorMetatype]:
return [om.OVTopKMetatype, om.OVNonMaxSuppressionMetatype]

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.OVShapeOfMetatype]

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.OVConvolutionMetatype]
@@ -74,10 +70,6 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
om.OVMatMulMetatype,
]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return [om.OVReadValueMetatype]

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.OVAddMetatype]
@@ -86,6 +78,18 @@ def add_metatypes(self) -> List[OperatorMetatype]:
def group_conv_metatypes(self) -> List[OperatorMetatype]:
return [om.OVGroupConvolutionMetatype]

@property
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return [om.OVShapeOfMetatype]

@property
def dropout_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return [om.OVReadValueMetatype]

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.OVConcatMetatype: self.overflow_fix_metatypes}
16 changes: 12 additions & 4 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -69,6 +69,14 @@ def post_processing_metatypes(self) -> List[OperatorMetatype]:
def shapeof_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def dropout_metatypes(self) -> List[OperatorMetatype]:
return [om.PTDropoutMetatype]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype]
@@ -85,10 +93,6 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
om.PTModuleConvTranspose3dMetatype,
]

@property
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.PTAddMetatype]
@@ -307,6 +311,10 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O
om.PTDivMetatype,
om.PTMaxMetatype,
om.PTSqueezeMetatype,
om.PTLayerNormMetatype,
om.PTModuleLayerNormMetatype,
om.PTGroupNormMetatype,
om.PTModuleGroupNormMetatype,
]
if device != TargetDevice.CPU_SPR:
types.append(om.PTMulMetatype)
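LayerNorm and GroupNorm metatypes join the Torch ignored list for the transformer path, which aligns FQ placement with the OpenVINO backend. A usage sketch of the code path on which this list takes effect (the model and calibration data below are placeholders, not taken from this PR):

```python
import nncf
import torch

model = torch.nn.TransformerEncoderLayer(d_model=64, nhead=4)  # any traceable PyTorch model
calibration_data = [torch.randn(8, 4, 64) for _ in range(10)]
calibration_dataset = nncf.Dataset(calibration_data, transform_func=lambda x: x)

quantized_model = nncf.quantize(
    model,
    calibration_dataset,
    model_type=nncf.ModelType.TRANSFORMER,  # enables the transformer-specific ignored metatypes
)
```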
65 changes: 59 additions & 6 deletions nncf/quantization/passes.py
@@ -23,20 +23,23 @@
def transform_to_inference_graph(
nncf_graph: NNCFGraph,
shapeof_metatypes: List[OperatorMetatype],
dropout_metatypes: List[OperatorMetatype],
read_variable_metatypes: Optional[List[OperatorMetatype]] = None,
) -> NNCFGraph:
"""
This method contains pipeline of the passes that uses to provide inference graph without constant flows.
This method contains an inplace pipeline of passes used to produce an inference graph without constant flows.

:param nncf_graph: NNCFGraph instance for the transformation.
:param shapeof_metatypes: List of backend-specific ShapeOf metatypes.
:param dropout_metatypes: List of backend-specific Dropout metatypes.
:param read_variable_metatypes: List of backend-specific metatypes
that also can be interpreted as inputs (ReadValue).
:return: NNCFGraph in the inference style.
"""
inference_nncf_graph = remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, read_variable_metatypes)
inference_nncf_graph = filter_constant_nodes(nncf_graph, read_variable_metatypes)
return inference_nncf_graph
remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, read_variable_metatypes)
remove_nodes_and_reconnect_graph(nncf_graph, dropout_metatypes)
filter_constant_nodes(nncf_graph, read_variable_metatypes)
return nncf_graph


def remove_shapeof_subgraphs(
@@ -45,7 +48,7 @@ def remove_shapeof_subgraphs(
read_variable_metatypes: Optional[List[OperatorMetatype]] = None,
) -> NNCFGraph:
"""
Removes the ShapeOf subgraphs from the provided NNCFGraph instance.
Removes the ShapeOf subgraphs from the provided NNCFGraph instance inplace.

:param nncf_graph: NNCFGraph instance for the transformation.
:param shapeof_metatypes: List of backend-specific ShapeOf metatypes.
@@ -88,11 +91,61 @@ def remove_shapeof_subgraphs(
return nncf_graph


def remove_nodes_and_reconnect_graph(
nncf_graph: NNCFGraph,
metatypes: List[OperatorMetatype],
) -> NNCFGraph:
"""
Removes nodes whose metatypes are listed in the `metatypes` parameter from
the provided NNCFGraph instance inplace and, for each removed node, connects
its previous node to its next nodes.
Matched nodes must have exactly one input node and only one output port.

:param nncf_graph: NNCFGraph instance for the transformation.
:param metatypes: List of backend-specific metatypes.
:return: Resulting NNCFGraph.
"""
if not metatypes:
return nncf_graph

nodes_to_drop = []
for node in nncf_graph.get_nodes_by_metatypes(metatypes):
if node.metatype in metatypes:
nodes_to_drop.append(node)

prev_nodes = nncf_graph.get_previous_nodes(node)
input_edges = nncf_graph.get_input_edges(node)
assert len(prev_nodes) == len(input_edges) == 1
prev_node = prev_nodes[0]
input_edge = input_edges[0]
assert not input_edge.parallel_input_port_ids

# nncf_graph.get_next_edges is not used here so that
# parallel_input_port_ids are preserved
for output_node in nncf_graph.get_next_nodes(node):
output_edge = nncf_graph.get_edge(node, output_node)
# Connects previous node with all next nodes
# to keep NNCFGraph connected.
assert input_edge.dtype == output_edge.dtype
assert input_edge.tensor_shape == output_edge.tensor_shape
nncf_graph.add_edge_between_nncf_nodes(
from_node_id=prev_node.node_id,
to_node_id=output_edge.to_node.node_id,
tensor_shape=input_edge.tensor_shape,
input_port_id=output_edge.input_port_id,
output_port_id=input_edge.output_port_id,
dtype=input_edge.dtype,
parallel_input_port_ids=output_edge.parallel_input_port_ids,
)
nncf_graph.remove_nodes_from(nodes_to_drop)
return nncf_graph


def filter_constant_nodes(
nncf_graph: NNCFGraph, read_variable_metatypes: Optional[List[OperatorMetatype]] = None
) -> NNCFGraph:
"""
Removes all Constant nodes from NNCFGraph, making it inference graph.
Removes all Constant nodes from the NNCFGraph inplace, making it an inference graph.
The traversing starts from the input nodes and nodes with weights.

:param nncf_graph: NNCFGraph instance for the transformation.
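The new `remove_nodes_and_reconnect_graph` pass bypasses single-input nodes such as Dropout and rewires their producers straight to their consumers. A simplified standalone illustration of that reconnection idea on a plain `networkx.DiGraph` (this is not NNCF's `NNCFGraph` API; ports and dtypes are omitted):

```python
import networkx as nx


def bypass_nodes(graph: nx.DiGraph, node_type: str) -> nx.DiGraph:
    """Drop every node of `node_type` and connect its producer to its consumers."""
    to_drop = [n for n, data in graph.nodes(data=True) if data.get("type") == node_type]
    for node in to_drop:
        preds = list(graph.predecessors(node))
        assert len(preds) == 1, "the pass only supports single-input nodes"
        for succ in graph.successors(node):
            graph.add_edge(preds[0], succ)
    graph.remove_nodes_from(to_drop)
    return graph


g = nx.DiGraph()
g.add_nodes_from([("input", {"type": "input"}), ("dropout", {"type": "dropout"}), ("matmul", {"type": "matmul"})])
g.add_edges_from([("input", "dropout"), ("dropout", "matmul")])
bypass_nodes(g, "dropout")
print(list(g.edges()))  # [('input', 'matmul')]
```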
11 changes: 9 additions & 2 deletions nncf/torch/quantization/ignored_patterns.py
@@ -67,8 +67,15 @@ def _add_softmax_reshape_matmul(

@PT_IGNORED_PATTERNS.register(IgnoredPatternNames.MULTIHEAD_ATTENTION_OUTPUT)
def create_multihead_attention_output() -> GraphPattern:
matmul_aliases = ["linear", "addmm", "matmul", "bmm", "mm", "baddbmm"]
reshape_squeeze_aliases = ["reshape", "view", "flatten", "squeeze", "unsqueeze", "squeeze", "flatten", "unsqueeze"]
matmul_aliases = ["linear", "addmm", "matmul", "bmm", "mm", "baddbmm", "__matmul__"]
reshape_squeeze_aliases = [
"reshape",
"view",
"flatten",
"unsqueeze",
"squeeze",
"unbind",
]
gather_aliases = ["gather", "index_select", "where", "index_select", "__getitem__"]
transpose_aliases = ["transpose", "permute", "transpose_"]

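The `__matmul__` and `unbind` aliases extend the ignored multi-head-attention output pattern to PyTorch attention code written with the `@` operator and `Tensor.unbind`. A generic fragment of the kind of code these aliases are meant to match (a hedged illustration, not code from this PR):

```python
import torch


def attention(qkv: torch.Tensor) -> torch.Tensor:
    # qkv: (3, batch, heads, tokens, head_dim)
    q, k, v = qkv.unbind(0)                       # traced as "unbind"
    attn = (q @ k.transpose(-2, -1)).softmax(-1)  # "@" is traced as "__matmul__"
    return attn @ v


out = attention(torch.randn(3, 2, 4, 16, 8))
print(out.shape)  # torch.Size([2, 4, 16, 8])
```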
@@ -0,0 +1,19 @@
strict digraph {
"0 /Input_1_0" [id=0, type=Input_1];
"1 /Split_1_0" [id=1, type=Split_1];
"5 /Output_1_0" [id=5, type=Output_1];
"6 /Output_2_1_0" [id=6, type=Output_2_1];
"7 /Output_2_2_0" [id=7, type=Output_2_2];
"8 /Output_2_3_0" [id=8, type=Output_2_3];
"9 /Output_3_0" [id=9, type=Output_3];
"10 /Output_2_4_0" [id=10, type=output];
"11 /Output_3_1_0" [id=11, type=output];
"0 /Input_1_0" -> "1 /Split_1_0";
"1 /Split_1_0" -> "5 /Output_1_0";
"1 /Split_1_0" -> "6 /Output_2_1_0";
"1 /Split_1_0" -> "7 /Output_2_2_0";
"1 /Split_1_0" -> "8 /Output_2_3_0";
"1 /Split_1_0" -> "9 /Output_3_0";
"1 /Split_1_0" -> "10 /Output_2_4_0";
"1 /Split_1_0" -> "11 /Output_3_1_0" [label="parallel_input_port_ids:[2, 3, 4, 5, 6, 7, 8, 9]"];
}
@@ -0,0 +1,25 @@
strict digraph {
"0 /Input_1_0" [id=0, type=Input_1];
"1 /Split_1_0" [id=1, type=Split_1];
"2 /Dropout_1_0" [id=2, type=Dropout_1];
"3 /Dropout_2_0" [id=3, type=Dropout_2];
"4 /Dropout_3_0" [id=4, type=Dropout_3];
"5 /Output_1_0" [id=5, type=Output_1];
"6 /Output_2_1_0" [id=6, type=Output_2_1];
"7 /Output_2_2_0" [id=7, type=Output_2_2];
"8 /Output_2_3_0" [id=8, type=Output_2_3];
"9 /Output_3_0" [id=9, type=Output_3];
"10 /Output_2_4_0" [id=10, type=output];
"11 /Output_3_1_0" [id=11, type=output];
"0 /Input_1_0" -> "1 /Split_1_0";
"1 /Split_1_0" -> "2 /Dropout_1_0";
"1 /Split_1_0" -> "3 /Dropout_2_0";
"1 /Split_1_0" -> "4 /Dropout_3_0";
"2 /Dropout_1_0" -> "5 /Output_1_0";
"3 /Dropout_2_0" -> "6 /Output_2_1_0";
"3 /Dropout_2_0" -> "7 /Output_2_2_0";
"3 /Dropout_2_0" -> "8 /Output_2_3_0";
"3 /Dropout_2_0" -> "10 /Output_2_4_0";
"4 /Dropout_3_0" -> "9 /Output_3_0";
"4 /Dropout_3_0" -> "11 /Output_3_1_0" [label="parallel_input_port_ids:[2, 3, 4, 5, 6, 7, 8, 9]"];
}