[Solver] Check all output branches are quantized before merging quantizers (#2854)

### Changes

Quantizer merge logic is updated to check that all output branches are
quantized before the quantizers are merged and propagated up.
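
As shown in the sketch below, the guard is essentially a breadth-first walk over all downstream branches of the branching node: the merge is allowed only if every branch meets a propagating quantizer before any non-agnostic operator. This is a minimal standalone distillation of the idea; the helper name, adjacency dict, and attribute keys are hypothetical, not the actual NNCF classes:

```python
from collections import deque


def all_output_branches_quantized(successors, attrs, start):
    """Sketch: every output branch of `start` must meet a quantizer
    node before it meets a non-quantization-agnostic operator."""
    queue = deque(successors.get(start, []))
    while queue:
        node = queue.popleft()
        kind = attrs[node]["kind"]
        if kind == "operator":
            if attrs[node]["trait"] != "agnostic":
                return False  # quantizable op reached with no quantizer above it
        elif kind in ("pre_hook", "post_hook") and attrs[node].get("quantizer"):
            continue  # this branch is already quantized; stop descending it
        queue.extend(successors.get(node, []))
    return True


# Toy example: one of two branches lacks a quantizer, so the merge is refused.
successors = {"branch": ["hook_a", "hook_b"], "hook_a": ["op_a"], "hook_b": ["op_b"]}
attrs = {
    "hook_a": {"kind": "pre_hook", "quantizer": object()},
    "hook_b": {"kind": "pre_hook", "quantizer": None},
    "op_a": {"kind": "operator", "trait": "inputs_quantizable"},
    "op_b": {"kind": "operator", "trait": "inputs_quantizable"},
}
assert not all_output_branches_quantized(successors, attrs, "branch")
```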

### Reason for changes

To prevent merging of quantizers for the ScaledDotProductAttention
op, which should have quantizers on input ports [0, 1] and shouldn't
have a quantizer on input port 3.
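
For reference, the intended placement fake-quantizes only the query and key inputs and leaves the remaining inputs untouched. A hedged PyTorch sketch of the pattern (the function name and scale value are made up; this is not NNCF's actual transformation):

```python
import torch
import torch.nn.functional as F


def sdpa_with_quantized_qk(query, key, value, attn_mask, scale=0.1):
    # Fake-quantize query (port 0) and key (port 1) only; value and the
    # attention mask stay in floating point.
    q = torch.fake_quantize_per_tensor_affine(query, scale, 0, -128, 127)
    k = torch.fake_quantize_per_tensor_affine(key, scale, 0, -128, 127)
    return F.scaled_dot_product_attention(q, k, value, attn_mask=attn_mask)


q = k = v = torch.randn(1, 1, 1, 64)
mask = torch.zeros(1, 1, 1, 1)
out = sdpa_with_quantized_qk(q, k, v, mask)  # shape: [1, 1, 1, 64]
```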

### Related tickets

148211
#2766 

### Tests

* Common solver test for ScaledDotProductAttention branch merging and quantization initialization
* Graph tests for torch/ov backends
daniil-lyakhov authored Aug 5, 2024
1 parent 63fcb15 commit ef49c75
Showing 9 changed files with 242 additions and 52 deletions.
29 changes: 29 additions & 0 deletions nncf/common/quantization/quantizer_propagation/graph.py
@@ -769,6 +769,35 @@ def recursive_helper(curr_node_key: str, target_node_list: List[str]):
recursive_helper(node_key, ret_node_key_list)
return ret_node_key_list

def all_outputs_are_quantized(self, node_key: str) -> bool:
    """
    Returns True if all paths from the given node to the first
    input-quantizable nodes have an activation quantizer, False otherwise.

    :param node_key: Given node key.
    :return: True if all paths from the given node to the first
        input-quantizable nodes have an activation quantizer, False otherwise.
    """
    # Breadth-first walk over every downstream branch of the given node.
    nodes_keys_stack = deque(self.successors(node_key))
    while nodes_keys_stack:
        node_key = nodes_keys_stack.popleft()
        node = self.nodes[node_key]
        node_type = node[QuantizerPropagationStateGraph.NODE_TYPE_NODE_ATTR]
        if node_type == QuantizerPropagationStateGraphNodeType.OPERATOR:
            trait = node[QuantizerPropagationStateGraph.QUANTIZATION_TRAIT_NODE_ATTR]
            if trait != QuantizationTrait.QUANTIZATION_AGNOSTIC:
                # A quantizable operator was reached before any quantizer on this path.
                return False
        elif node_type in [
            QuantizerPropagationStateGraphNodeType.PRE_HOOK,
            QuantizerPropagationStateGraphNodeType.POST_HOOK,
        ]:
            quantizer = node[QuantizerPropagationStateGraph.PROPAGATING_QUANTIZER_NODE_ATTR]
            if quantizer:
                # This branch is already quantized; stop descending it.
                continue
        nodes_keys_stack.extend(self.successors(node_key))
    return True

def get_paths_to_immediately_dominating_insertion_points(
self, insertion_point_node_key: str
) -> List[PropagationPath]:
7 changes: 7 additions & 0 deletions nncf/common/quantization/quantizer_propagation/solver.py
@@ -1233,6 +1233,13 @@ def check_branching_transition(
dom_op_quantizers = set()
for op_node_key in dom_op_node_keys:
op_node = quant_prop_graph.nodes[op_node_key]

# Check that every output branch has a quantizer before the merge
if op_node["op_meta"].target_input_ports is not None:
all_branches_are_quantized = quant_prop_graph.all_outputs_are_quantized(branching_node_key)
if not all_branches_are_quantized:
return TransitionStatus.SHOULD_NOT_TRANSITION

trait = op_node[QuantizerPropagationStateGraph.QUANTIZATION_TRAIT_NODE_ATTR]
affecting_prop_quantizers = op_node[QuantizerPropagationStateGraph.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]
if affecting_prop_quantizers:
7 changes: 7 additions & 0 deletions tests/common/quantization/metatypes.py
@@ -178,6 +178,12 @@ class DequantizeTestMetatype(TestMetatype):
name = "dequantize"


@METATYPES_FOR_TEST.register()
class ScaledDotProductAttentionMetatype(TestMetatype):
name = "scaled_dot_product_attention"
target_input_ports = [0, 1]


WEIGHT_LAYER_METATYPES = [LinearTestMetatype, Conv2dTestMetatype, MatMulTestMetatype]


@@ -189,6 +195,7 @@ class DequantizeTestMetatype(TestMetatype):
GeluTestMetatype,
LinearTestMetatype,
AddTestMetatype,
ScaledDotProductAttentionMetatype,
],
QuantizationTrait.CONCAT: [CatTestMetatype],
}
112 changes: 96 additions & 16 deletions tests/common/quantization/test_quantizer_propagation_solver.py
@@ -12,8 +12,9 @@

from collections import Counter
from collections import namedtuple
from dataclasses import dataclass
from itertools import permutations
from typing import Dict, List, Optional, Set, Tuple
from typing import Callable, Dict, List, Optional, Set, Tuple

import networkx as nx
import pytest
@@ -49,6 +50,7 @@
from tests.common.quantization.metatypes import MatMulTestMetatype
from tests.common.quantization.metatypes import MaxPool2dTestMetatype
from tests.common.quantization.metatypes import MinTestMetatype
from tests.common.quantization.metatypes import ScaledDotProductAttentionMetatype
from tests.common.quantization.metatypes import SoftmaxTestMetatype
from tests.common.quantization.mock_graphs import get_ip_graph_for_test
from tests.common.quantization.mock_graphs import get_mock_nncf_node_attrs
@@ -146,6 +148,30 @@ def get_branching_model_graph() -> NNCFGraph:
return get_nncf_graph_from_mock_nx_graph(mock_graph)


def get_scaled_dot_product_graph() -> NNCFGraph:
mock_graph = nx.DiGraph()

node_keys = ["input", "branch_node", "reshape", "reshape_1", "reshape_2", "scaled_dot_product_attention"]
for node_key in node_keys:
mock_node_attrs = get_mock_nncf_node_attrs(op_name=node_key)
mock_graph.add_node(node_key, **mock_node_attrs)

mock_graph.add_edges_from(
[
("input", "branch_node"),
("branch_node", "reshape"),
("branch_node", "reshape_1"),
("branch_node", "reshape_2"),
("reshape", "scaled_dot_product_attention"),
("reshape_1", "scaled_dot_product_attention"),
("reshape_2", "scaled_dot_product_attention"),
]
)

mark_input_ports_lexicographically_based_on_input_node_key(mock_graph)
return get_nncf_graph_from_mock_nx_graph(mock_graph)


class MultiQPSerializedDataForTest:
def __init__(
self,
@@ -248,6 +274,38 @@ def test_setup_initial_quantizers_in_quant_prop_graph(self):
edge = qp_graph.edges[pred_ip_key, actual_key]
assert not edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR]

def test_setup_initial_quantizers_sdpa(self):
nncf_graph = get_scaled_dot_product_graph()
ip_graph = get_ip_graph_for_test(nncf_graph)

qp_graph = QPSG(ip_graph)

sdpa_node_key = "5 /scaled_dot_product_attention_0"
quant_prop_solver = QuantizerPropagationSolver(
run_consistency_checks=True,
default_trait_to_metatype_map=DEFAULT_TEST_QUANT_TRAIT_MAP,
)

qp_graph = quant_prop_solver.set_allowed_quantization_types_for_operator_nodes(qp_graph)
qp_graph = quant_prop_solver.setup_initial_quantizers(qp_graph)
qp_graph.run_consistency_check()

for port_id, pred_ip_key in enumerate(qp_graph.predecessors(sdpa_node_key)):
node = qp_graph.nodes[sdpa_node_key]
pred_ip_node = qp_graph.nodes[pred_ip_key]
prop_quant = pred_ip_node[QPSG.PROPAGATING_QUANTIZER_NODE_ATTR]
if port_id in ScaledDotProductAttentionMetatype.target_input_ports:
assert prop_quant is not None
assert node[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR][port_id] == prop_quant

edge = qp_graph.edges[pred_ip_key, sdpa_node_key]
assert edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == [prop_quant]
else:
assert prop_quant is None

edge = qp_graph.edges[pred_ip_key, sdpa_node_key]
assert edge[QPSG.AFFECTING_PROPAGATING_QUANTIZERS_ATTR] == []

MergeQConfigSolution = namedtuple(
"MergeQConfigSolution", ("merge_qconfig_list", "branch_qconfig_lists_after_merge")
)
@@ -821,23 +879,36 @@ def test_merged_qconfig_list_is_independent_of_branch_qconfig_list_order(

BRANCHING_MODEL_GRAPH = get_branching_model_graph()

BranchTransitionTestStruct = namedtuple(
"BranchTransitionTestStruct",
( # Unspecified nodes are marked as quantization agnostic
"init_node_to_trait_and_configs_dict",
"starting_primary_quantizer_ip_node",
"target_branching_node_for_primary_quantizer",
"expected_status",
),
)

class InitNodeTestStruct:
def __init__(self, quantization_trait, config, op_meta=UnknownMetatype):
self.quantization_trait = quantization_trait
self.config = config
self.op_meta = op_meta

@dataclass
class BranchTransitionTestStruct:
# Unspecified nodes are marked as quantization agnostic
init_node_to_trait_and_configs_dict: Dict[str, "TestQuantizerPropagationSolver.InitNodeTestStruct"]
starting_primary_quantizer_ip_node: str
target_branching_node_for_primary_quantizer: str
expected_status: TransitionStatus
nncf_graph_builder: Optional[Callable[[], NNCFGraph]] = None

BRANCH_TRANSITION_TEST_CASES = [
# Scaled dot product attention case
BranchTransitionTestStruct(
init_node_to_trait_and_configs_dict=
{
'5 /scaled_dot_product_attention_0': InitNodeTestStruct(QuantizationTrait.INPUTS_QUANTIZABLE,
QuantizerConfig(), ScaledDotProductAttentionMetatype),
},
starting_primary_quantizer_ip_node=
InsertionPointGraph.get_pre_hook_node_key('5 /scaled_dot_product_attention_0'),
target_branching_node_for_primary_quantizer=InsertionPointGraph.get_post_hook_node_key('1 /branch_node_0'),
expected_status=TransitionStatus.SHOULD_NOT_TRANSITION,
nncf_graph_builder=get_scaled_dot_product_graph
),

# Downward branches are quantization-agnostic
BranchTransitionTestStruct(
init_node_to_trait_and_configs_dict=
@@ -1117,7 +1188,10 @@ def test_check_branching_transition(self, branch_transition_test_struct: BranchT
expected_status = branch_transition_test_struct.expected_status

# Graph preparation
nncf_graph = get_branching_model_graph()
if branch_transition_test_struct.nncf_graph_builder is None:
nncf_graph = get_branching_model_graph()
else:
nncf_graph = branch_transition_test_struct.nncf_graph_builder()
ip_graph = get_ip_graph_for_test(nncf_graph)

# Metatypes must be assigned before QPSG creation, because
@@ -1137,10 +1211,16 @@ def test_check_branching_transition(self, branch_transition_test_struct: BranchT
trait = init_node_struct.quantization_trait
quant_prop_graph.nodes[node_key][QPSG.QUANTIZATION_TRAIT_NODE_ATTR] = trait
if trait == QuantizationTrait.INPUTS_QUANTIZABLE:
ip_node_key = InsertionPointGraph.get_pre_hook_node_key(node_key)
prop_quant = quant_prop_graph.add_propagating_quantizer(qconfigs, ip_node_key)
if ip_node_key == starting_primary_quantizer_ip_node:
primary_prop_quant = prop_quant
target_input_ports = [0]
metatype = quant_prop_graph.nodes[node_key]["op_meta"]
if metatype.target_input_ports is not None:
target_input_ports = metatype.target_input_ports

for input_port_id in target_input_ports:
ip_node_key = InsertionPointGraph.get_pre_hook_node_key(node_key, input_port_id=input_port_id)
prop_quant = quant_prop_graph.add_propagating_quantizer(qconfigs, ip_node_key)
if ip_node_key == starting_primary_quantizer_ip_node:
primary_prop_quant = prop_quant
elif trait == QuantizationTrait.CONCAT and qconfigs:
# Assuming two-port concat nodes are used in the test graph, adjust as necessary
for input_port_id in [0, 1]:
@@ -1,33 +1,63 @@
strict digraph {
"0 Input_1" [id=0, type=Parameter];
"1 Input_2" [id=1, type=Parameter];
"2 Input_3" [id=2, type=Parameter];
"3 Input_4" [id=3, type=Parameter];
"4 Input_1/fq_output_0" [id=4, type=FakeQuantize];
"5 Input_2/fq_output_0" [id=5, type=FakeQuantize];
"6 ScaledDotProductAttention_5" [id=6, type=ScaledDotProductAttention];
"7 Result" [id=7, type=Result];
"8 Constant_2553" [id=8, type=Constant];
"9 Constant_2552" [id=9, type=Constant];
"10 Constant_2551" [id=10, type=Constant];
"11 Constant_2550" [id=11, type=Constant];
"12 Constant_2548" [id=12, type=Constant];
"13 Constant_2547" [id=13, type=Constant];
"14 Constant_2546" [id=14, type=Constant];
"15 Constant_2545" [id=15, type=Constant];
"0 Input_1" -> "4 Input_1/fq_output_0" [label="[1, 1, 1, 64]", style=solid];
"1 Input_2" -> "5 Input_2/fq_output_0" [label="[1, 1, 1, 64]", style=solid];
"2 Input_3" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
"3 Input_4" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 1]", style=solid];
"4 Input_1/fq_output_0" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
"5 Input_2/fq_output_0" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
"6 ScaledDotProductAttention_5" -> "7 Result" [label="[1, 1, 1, 64]", style=solid];
"8 Constant_2553" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
"9 Constant_2552" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
"10 Constant_2551" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
"11 Constant_2550" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
"12 Constant_2548" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
"13 Constant_2547" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
"14 Constant_2546" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
"15 Constant_2545" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
"2 Reshape_3835" [id=2, type=Reshape];
"3 ScaledDotProductAttention_3850" [id=3, type=ScaledDotProductAttention];
"4 Reshape_3837" [id=4, type=Reshape];
"5 Result" [id=5, type=Result];
"6 Reshape_3839/fq_input_0" [id=6, type=FakeQuantize];
"7 Reshape_3843/fq_input_0" [id=7, type=FakeQuantize];
"8 Reshape_3847" [id=8, type=Reshape];
"9 Reshape_3839" [id=9, type=Reshape];
"10 Reshape_3843" [id=10, type=Reshape];
"11 Reshape_3849" [id=11, type=Reshape];
"12 Reshape_3841" [id=12, type=Reshape];
"13 Reshape_3845" [id=13, type=Reshape];
"14 Constant_3848" [id=14, type=Constant];
"15 Constant_3846" [id=15, type=Constant];
"16 Constant_3836" [id=16, type=Constant];
"17 Constant_3834" [id=17, type=Constant];
"18 Constant_3844" [id=18, type=Constant];
"19 Constant_3842" [id=19, type=Constant];
"20 Reshape_3843/fq_input_0/output_high" [id=20, type=Constant];
"21 Reshape_3843/fq_input_0/output_low" [id=21, type=Constant];
"22 Reshape_3843/fq_input_0/input_high" [id=22, type=Constant];
"23 Reshape_3843/fq_input_0/input_low" [id=23, type=Constant];
"24 Constant_3840" [id=24, type=Constant];
"25 Constant_3838" [id=25, type=Constant];
"26 Reshape_3839/fq_input_0/output_high" [id=26, type=Constant];
"27 Reshape_3839/fq_input_0/output_low" [id=27, type=Constant];
"28 Reshape_3839/fq_input_0/input_high" [id=28, type=Constant];
"29 Reshape_3839/fq_input_0/input_low" [id=29, type=Constant];
"0 Input_1" -> "2 Reshape_3835" [label="[1, 1, 1, 64]", style=solid];
"1 Input_2" -> "3 ScaledDotProductAttention_3850" [label="[1, 1, 1, 1]", style=solid];
"2 Reshape_3835" -> "4 Reshape_3837" [label="[64]", style=solid];
"3 ScaledDotProductAttention_3850" -> "5 Result" [label="[1, 1, 1, 64]", style=solid];
"4 Reshape_3837" -> "6 Reshape_3839/fq_input_0" [label="[1, 1, 1, 64]", style=solid];
"4 Reshape_3837" -> "7 Reshape_3843/fq_input_0" [label="[1, 1, 1, 64]", style=solid];
"4 Reshape_3837" -> "8 Reshape_3847" [label="[1, 1, 1, 64]", style=solid];
"6 Reshape_3839/fq_input_0" -> "9 Reshape_3839" [label="[1, 1, 1, 64]", style=solid];
"7 Reshape_3843/fq_input_0" -> "10 Reshape_3843" [label="[1, 1, 1, 64]", style=solid];
"8 Reshape_3847" -> "11 Reshape_3849" [label="[64]", style=solid];
"9 Reshape_3839" -> "12 Reshape_3841" [label="[64]", style=solid];
"10 Reshape_3843" -> "13 Reshape_3845" [label="[64]", style=solid];
"11 Reshape_3849" -> "3 ScaledDotProductAttention_3850" [label="[1, 1, 1, 64]", style=solid];
"12 Reshape_3841" -> "3 ScaledDotProductAttention_3850" [label="[1, 1, 1, 64]", style=solid];
"13 Reshape_3845" -> "3 ScaledDotProductAttention_3850" [label="[1, 1, 1, 64]", style=solid];
"14 Constant_3848" -> "11 Reshape_3849" [label="[4]", style=dashed];
"15 Constant_3846" -> "8 Reshape_3847" [label="[1]", style=dashed];
"16 Constant_3836" -> "4 Reshape_3837" [label="[4]", style=dashed];
"17 Constant_3834" -> "2 Reshape_3835" [label="[1]", style=dashed];
"18 Constant_3844" -> "13 Reshape_3845" [label="[4]", style=dashed];
"19 Constant_3842" -> "10 Reshape_3843" [label="[1]", style=dashed];
"20 Reshape_3843/fq_input_0/output_high" -> "7 Reshape_3843/fq_input_0" [label="[]", style=solid];
"21 Reshape_3843/fq_input_0/output_low" -> "7 Reshape_3843/fq_input_0" [label="[]", style=solid];
"22 Reshape_3843/fq_input_0/input_high" -> "7 Reshape_3843/fq_input_0" [label="[]", style=solid];
"23 Reshape_3843/fq_input_0/input_low" -> "7 Reshape_3843/fq_input_0" [label="[]", style=solid];
"24 Constant_3840" -> "12 Reshape_3841" [label="[4]", style=dashed];
"25 Constant_3838" -> "9 Reshape_3839" [label="[1]", style=dashed];
"26 Reshape_3839/fq_input_0/output_high" -> "6 Reshape_3839/fq_input_0" [label="[]", style=solid];
"27 Reshape_3839/fq_input_0/output_low" -> "6 Reshape_3839/fq_input_0" [label="[]", style=solid];
"28 Reshape_3839/fq_input_0/input_high" -> "6 Reshape_3839/fq_input_0" [label="[]", style=solid];
"29 Reshape_3839/fq_input_0/input_low" -> "6 Reshape_3839/fq_input_0" [label="[]", style=solid];
}
22 changes: 15 additions & 7 deletions tests/openvino/native/models.py
@@ -871,15 +871,23 @@ def _create_ov_model(self):

class ScaledDotProductAttentionModel(OVReferenceModel):
def _create_ov_model(self):
query = opset.parameter([1, 1, 1, 64], name="Input_1")
key = opset.parameter([1, 1, 1, 64], name="Input_2")
value = opset.parameter([1, 1, 1, 64], name="Input_3")
attn_mask = opset.parameter([1, 1, 1, 1], name="Input_4")

attn = opset.scaled_dot_product_attention(query, key, value, attn_mask)
input_ = opset.parameter([1, 1, 1, 64], name="Input_1")
attn_mask = opset.parameter([1, 1, 1, 1], name="Input_2")
x = opset.reshape(input_, [64], False)
x = opset.reshape(x, [1, 1, 1, 64], False)

# Parallel edges are not supported by PTQ for now.
# Ref 148498
inputs = []
for _ in range(3):
x_ = opset.reshape(x, [64], False)
x_ = opset.reshape(x_, [1, 1, 1, 64], False)
inputs.append(x_)

attn = opset.scaled_dot_product_attention(*inputs, attn_mask)
result = opset.result(attn, name="Result")
result.get_output_tensor(0).set_names(set(["Result"]))
model = ov.Model([result], [query, key, value, attn_mask])
model = ov.Model([result], [input_, attn_mask])
return model

