[OV] Support for ScaledDotProductAttention operation (#2268)
### Changes

- Introduced `OVScaledDotProductAttentionMetatype`
- Added the `target_input_ports` attribute to `OperatorMetatype`
- Implemented handling of `OVScaledDotProductAttentionMetatype` in
`QuantizerPropagationSolver`

### Reason for changes

Support quantization of the `ScaledDotProductAttention` operation from
OpenVINO opset13.
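
For reference, `ScaledDotProductAttention` computes

$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{T}}{\sqrt{d}} + M\right)V$$

where $Q$, $K$, $V$ are the query, key, and value inputs and $M$ is the optional attention mask. The changes below place quantizers only on the query and key inputs (input ports 0 and 1) and force symmetric per-tensor quantization for them via the new `q8_a_sym` preset.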

### Related tickets

124573

### Tests

`test_scaled_dot_product_attention_placement`

---------

Co-authored-by: Nikita Malinin <[email protected]>
alexsu52 and KodiaqQ authored Nov 17, 2023
1 parent 0c7a8d5 commit 9898876
Showing 16 changed files with 149 additions and 5 deletions.
2 changes: 1 addition & 1 deletion nncf/common/graph/graph.py
@@ -576,7 +576,7 @@ def get_graph_for_structure_analysis(self, extended: bool = False) -> nx.DiGraph
"""
out_graph = nx.DiGraph()
for node_name, node in self._nx_graph.nodes.items():
attrs_node = {"id": node[NNCFNode.ID_NODE_ATTR], "type": node[NNCFNode.NODE_TYPE_ATTR]}
attrs_node = {"id": str(node[NNCFNode.ID_NODE_ATTR]), "type": node[NNCFNode.NODE_TYPE_ATTR]}
for attr in ["color", "label", "style"]:
if attr in node:
attrs_node[attr] = node[attr]
1 change: 1 addition & 0 deletions nncf/common/graph/operator_metatypes.py
@@ -29,6 +29,7 @@ class OperatorMetatype:
hw_config_names: List[str] = []
output_channel_axis: Optional[int] = None
ignored_input_ports: List[int] = []
target_input_ports: Optional[List[int]] = None

@classmethod
def get_all_aliases(cls) -> List[str]:
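The new `target_input_ports` attribute complements `ignored_input_ports`: when it is set, only the listed input ports are considered for activation quantization. A minimal sketch of how a metatype opts in (`MyAttentionMetatype` is hypothetical, shown only to illustrate the attribute):

```python
from typing import List, Optional


class OperatorMetatype:
    # Ports listed here are always skipped during quantizer placement.
    ignored_input_ports: List[int] = []
    # When not None, ONLY these ports are considered for quantizer placement.
    target_input_ports: Optional[List[int]] = None


class MyAttentionMetatype(OperatorMetatype):
    # Quantize only the first two inputs (query and key); leave the rest alone.
    target_input_ports = [0, 1]
```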
15 changes: 14 additions & 1 deletion nncf/common/hardware/configs/cpu.json
@@ -2,6 +2,13 @@
"target_device": "CPU",
"config": {
"quantization": {
"q8_a_sym": {
"bits": 8,
"mode": [
"symmetric"
],
"granularity": "pertensor"
},
"q8_a": {
"bits": 8,
"mode": [
@@ -231,7 +238,6 @@
"activations": "q8_a"
}
},
{"type": "Reshape"},
{
"type": "Concat",
"attributes": {
@@ -264,6 +270,13 @@
"activations": "q8_a"
}
},
{
"type": "ScaledDotProductAttention",
"quantization": {
"activations": "q8_a_sym"
}
},
{"type": "Reshape"},
{"type": "Flatten"},
{"type": "Squeeze"},
{"type": "Unsqueeze"},
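The new `q8_a_sym` preset pins activations to 8-bit symmetric per-tensor quantization. A rough numpy illustration of the fake-quantize arithmetic such a preset implies (a sketch for intuition only, not NNCF's actual kernel):

```python
import numpy as np


def fake_quantize_sym_per_tensor(x: np.ndarray, num_bits: int = 8) -> np.ndarray:
    """Simulate symmetric per-tensor quantization: one zero-centered scale for the whole tensor."""
    levels = 2 ** (num_bits - 1) - 1        # 127 for 8 bits
    scale = float(np.max(np.abs(x))) / levels
    if scale == 0.0:
        return x
    # Round to the integer grid, clamp to the signed range, and dequantize back.
    return np.clip(np.round(x / scale), -levels - 1, levels) * scale


activations = np.random.randn(1, 1, 1, 64).astype(np.float32)
simulated = fake_quantize_sym_per_tensor(activations)
```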
1 change: 1 addition & 0 deletions nncf/common/hardware/opset.py
@@ -58,3 +58,4 @@ class HWConfigOpName:
LSTMSEQUENCE = "LSTMSequence"
GRUSEQUENCE = "GRUSequence"
GROUPNORMALIZATION = "GroupNormalization"
SCALED_DOT_PRODUCT_ATTENTION = "ScaledDotProductAttention"
3 changes: 3 additions & 0 deletions nncf/common/quantization/quantizer_propagation/solver.py
@@ -1154,6 +1154,9 @@ def _setup_initial_quantizers_for_operator_node(
if input_port_id in metatype.ignored_input_ports:
continue

if metatype.target_input_ports is not None and input_port_id not in metatype.target_input_ports:
continue

edge = quant_prop_graph.edges[pred_ip_key, operator_node_key]
if not edge[QuantizerPropagationStateGraph.IS_INTEGER_PATH_EDGE_ATTR]:
pred_ip_key_vs_qconf_dict[pred_ip_key] = qconf_list
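With this check in place, the solver proposes quantizers only for the ports a metatype explicitly targets. A condensed, standalone sketch of the resulting port selection (simplified types, not the solver's real signature):

```python
from typing import List, Optional


def ports_to_quantize(
    num_inputs: int,
    ignored_input_ports: List[int],
    target_input_ports: Optional[List[int]],
) -> List[int]:
    selected = []
    for port in range(num_inputs):
        if port in ignored_input_ports:
            continue
        # The check added by this commit: skip ports outside the target list.
        if target_input_ports is not None and port not in target_input_ports:
            continue
        selected.append(port)
    return selected


# ScaledDotProductAttention: quantize query (0) and key (1); skip value (2) and mask (3).
assert ports_to_quantize(4, [], [0, 1]) == [0, 1]
```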
1 change: 1 addition & 0 deletions nncf/openvino/graph/metatypes/groups.py
@@ -81,6 +81,7 @@
ov_metatypes.OVLSTMSequenceMetatype,
ov_metatypes.OVGRUSequenceMetatype,
ov_metatypes.OVGroupNormalizationMetatype,
ov_metatypes.OVScaledDotProductAttentionMetatype,
]


8 changes: 8 additions & 0 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
@@ -692,6 +692,14 @@ class OVGroupNormalizationMetatype(OVOpMetatype):
hw_config_names = [HWConfigOpName.GROUPNORMALIZATION]


@OV_OPERATOR_METATYPES.register()
class OVScaledDotProductAttentionMetatype(OVOpMetatype):
name = "ScaledDotProductAttentionOp"
op_names = ["ScaledDotProductAttention"]
hw_config_names = [HWConfigOpName.SCALED_DOT_PRODUCT_ATTENTION]
target_input_ports = [0, 1]


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
"""
Returns a list of the operator metatypes.
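A quick way to inspect the registered metatype (assuming `get_all_aliases` surfaces `op_names`, as it does for the other OV metatypes):

```python
from nncf.openvino.graph.metatypes.openvino_metatypes import OVScaledDotProductAttentionMetatype

print(OVScaledDotProductAttentionMetatype.get_all_aliases())   # ['ScaledDotProductAttention']
print(OVScaledDotProductAttentionMetatype.target_input_ports)  # [0, 1] -> query and key only
```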
22 changes: 22 additions & 0 deletions nncf/quantization/algorithms/min_max/algorithm.py
@@ -382,6 +382,25 @@ def _get_ignored_scope_by_algorithm(self, inference_nncf_graph: NNCFGraph) -> IgnoredScope:
nncf_node_names.append(nncf_node.node_name)
return IgnoredScope(names=nncf_node_names)

def _get_scope_overrides(self, inference_nncf_graph: NNCFGraph) -> Dict:
"""
Returns a dictionary of quantization configuration overrides for inputs to matching operation nodes.
:param inference_nncf_graph: Inference NNCFGraph instance.
:return: A dictionary of quantization configuration overrides for inputs to matching operation nodes.
"""
scaled_dot_product_attention_node_names = [
node.node_name
for node in inference_nncf_graph.get_nodes_by_metatypes(
self._backend_entity.scaled_dot_product_attention_metatypes
)
]

scope_overrides_activations = {}
for node_name in scaled_dot_product_attention_node_names:
scope_overrides_activations[node_name] = {"mode": "symmetric"}
return {"activations": scope_overrides_activations}

def _get_quantizer_setup(
self,
nncf_graph: NNCFGraph,
@@ -416,6 +435,8 @@ def _get_quantizer_setup(
QuantizableWeightedLayerNode(node, qconf_list) for node, qconf_list in weighted_node_and_qconf_lists.items()
]

scope_overrides = self._get_scope_overrides(inference_nncf_graph)

ip_graph = InsertionPointGraph(inference_nncf_graph)
ip_graph = ip_graph.get_ip_graph_with_merged_hw_optimized_operations(hw_patterns)
post_processing_types = self._backend_entity.post_processing_metatypes
@@ -434,6 +455,7 @@
post_processing_marker_metatypes=post_processing_types,
metatypes_to_ignore=metatypes_to_ignore,
scales_unification_map=self._backend_entity.scales_unification_map,
scope_overrides=scope_overrides,
)

quantization_proposal = solver.run_on_ip_graph(ip_graph)
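For a graph containing a single attention node, `_get_scope_overrides` yields a nested dictionary keyed by node name, roughly of this shape (node name illustrative):

```python
# Expected structure of the override dictionary for one SDPA node:
scope_overrides = {
    "activations": {
        "ScaledDotProductAttention_5": {"mode": "symmetric"},
    }
}
```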
7 changes: 7 additions & 0 deletions nncf/quantization/algorithms/min_max/backend.py
@@ -95,6 +95,13 @@ def group_conv_metatypes(self) -> List[OperatorMetatype]:
Property for the backend-specific Grouped Convolution metatypes.
"""

@property
@abstractmethod
def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:
"""
Property for the backend-specific Scaled Dot Product Attention metatypes.
"""

@property
@abstractmethod
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
@@ -80,6 +80,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.ONNXConcatMetatype: self.overflow_fix_metatypes}
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
@@ -86,6 +86,10 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return [om.OVReadValueMetatype]

@property
def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:
return [om.OVScaledDotProductAttentionMetatype]

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.OVConcatMetatype: self.overflow_fix_metatypes}
4 changes: 4 additions & 0 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -97,6 +97,10 @@ def add_metatypes(self) -> List[OperatorMetatype]:
def group_conv_metatypes(self) -> List[OperatorMetatype]:
return self.conv_metatypes

@property
def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:
return []

@property
def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
return {om.PTCatMetatype: self.overflow_fix_metatypes}
11 changes: 8 additions & 3 deletions tests/openvino/native/common.py
@@ -10,6 +10,7 @@
# limitations under the License.
import json
from copy import deepcopy
from typing import Tuple

import numpy as np
import openvino.runtime as ov
@@ -65,13 +66,17 @@ def dump_to_json(local_path, data):
json.dump(deepcopy(data), file, indent=4, cls=NumpyEncoder)


def get_openvino_version() -> str:
def get_openvino_major_minor_version() -> Tuple[int, int]:
ov_version = ov.__version__
pos = ov_version.find("-")
if pos != -1:
ov_version = ov_version[:pos]

ov_version = version.parse(ov_version).base_version
version_major, version_minor = ov_version.split(".")[:2]
return tuple(map(int, ov_version.split(".")[:2]))


def get_openvino_version() -> str:
major_version, minor_version = get_openvino_major_minor_version()

return f"{version_major}.{version_minor}"
return f"{major_version}.{minor_version}"
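Returning the version as a `(major, minor)` tuple lets callers gate on a release with one lexicographic comparison, e.g.:

```python
# Tuple comparison reads more naturally than separate major/minor checks:
if get_openvino_major_minor_version() < (2023, 3):
    print("ScaledDotProductAttention requires OpenVINO 2023.3 or newer")
```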
33 changes: 33 additions & 0 deletions scaled_dot_product_attention.dot (new reference graph)
@@ -0,0 +1,33 @@
strict digraph {
"0 Input_1" [id=0, type=Parameter];
"1 Input_2" [id=1, type=Parameter];
"2 Input_3" [id=2, type=Parameter];
"3 Input_4" [id=3, type=Parameter];
"4 Input_1/fq_output_0" [id=4, type=FakeQuantize];
"5 Input_2/fq_output_0" [id=5, type=FakeQuantize];
"6 ScaledDotProductAttention_5" [id=6, type=ScaledDotProductAttention];
"7 Result" [id=7, type=Result];
"8 Constant_2553" [id=8, type=Constant];
"9 Constant_2552" [id=9, type=Constant];
"10 Constant_2551" [id=10, type=Constant];
"11 Constant_2550" [id=11, type=Constant];
"12 Constant_2548" [id=12, type=Constant];
"13 Constant_2547" [id=13, type=Constant];
"14 Constant_2546" [id=14, type=Constant];
"15 Constant_2545" [id=15, type=Constant];
"0 Input_1" -> "4 Input_1/fq_output_0" [label="[1, 1, 1, 64]", style=solid];
"1 Input_2" -> "5 Input_2/fq_output_0" [label="[1, 1, 1, 64]", style=solid];
"2 Input_3" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
"3 Input_4" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 1]", style=solid];
"4 Input_1/fq_output_0" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
"5 Input_2/fq_output_0" -> "6 ScaledDotProductAttention_5" [label="[1, 1, 1, 64]", style=solid];
"6 ScaledDotProductAttention_5" -> "7 Result" [label="[1, 1, 1, 64]", style=solid];
"8 Constant_2553" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
"9 Constant_2552" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
"10 Constant_2551" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
"11 Constant_2550" -> "5 Input_2/fq_output_0" [label="[]", style=solid];
"12 Constant_2548" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
"13 Constant_2547" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
"14 Constant_2546" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
"15 Constant_2545" -> "4 Input_1/fq_output_0" [label="[]", style=solid];
}
15 changes: 15 additions & 0 deletions tests/openvino/native/models.py
@@ -16,6 +16,7 @@
import openvino.runtime as ov
from openvino.runtime import opset9 as opset
from openvino.runtime import opset12
from openvino.runtime import opset13

from nncf.common.utils.registry import Registry

@@ -800,3 +801,17 @@ def _create_ov_model(self):
result = opset.result(gather_2, name="Result")
model = ov.Model([result], [input_1])
return model


class ScaledDotProductAttentionModel(OVReferenceModel):
def _create_ov_model(self):
query = opset.parameter([1, 1, 1, 64], name="Input_1")
key = opset.parameter([1, 1, 1, 64], name="Input_2")
value = opset.parameter([1, 1, 1, 64], name="Input_3")
attn_mask = opset.parameter([1, 1, 1, 1], name="Input_4")

attn = opset13.scaled_dot_product_attention(query, key, value, attn_mask)
result = opset.result(attn, name="Result")
result.get_output_tensor(0).set_names(set(["Result"]))
model = ov.Model([result], [query, key, value, attn_mask])
return model
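The new model can be exercised directly with the OpenVINO runtime; a minimal sketch (assuming a CPU device is available):

```python
import numpy as np
import openvino.runtime as ov

from tests.openvino.native.models import ScaledDotProductAttentionModel

model = ScaledDotProductAttentionModel().ov_model
compiled = ov.Core().compile_model(model, "CPU")

query = np.random.randn(1, 1, 1, 64).astype(np.float32)
key = np.random.randn(1, 1, 1, 64).astype(np.float32)
value = np.random.randn(1, 1, 1, 64).astype(np.float32)
mask = np.zeros((1, 1, 1, 1), dtype=np.float32)

outputs = compiled([query, key, value, mask])  # dict-like map of output tensors
```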
23 changes: 23 additions & 0 deletions tests/openvino/native/quantization/test_graphs.py
@@ -28,6 +28,7 @@
from tests.openvino.native.common import compare_nncf_graphs
from tests.openvino.native.common import dump_model
from tests.openvino.native.common import get_dataset_for_test
from tests.openvino.native.common import get_openvino_major_minor_version
from tests.openvino.native.common import get_openvino_version
from tests.openvino.native.models import SYNTHETIC_MODELS
from tests.openvino.native.models import DepthwiseConv3DModel
@@ -36,6 +37,7 @@
from tests.openvino.native.models import GRUSequenceModel
from tests.openvino.native.models import IfModel
from tests.openvino.native.models import MatmulSoftmaxMatmulBlock
from tests.openvino.native.models import ScaledDotProductAttentionModel
from tests.openvino.native.quantization.test_fq_params_calculation import quantize_model
from tests.openvino.omz_helpers import convert_model
from tests.openvino.omz_helpers import download_model
@@ -196,3 +198,24 @@ def test_if_model_fq_placement():
compare_nncf_graphs(quantized_model, QUANTIZED_REF_GRAPHS_DIR / main_model_path)
compare_nncf_graphs(if_op.get_function(0), QUANTIZED_REF_GRAPHS_DIR / then_body_path)
compare_nncf_graphs(if_op.get_function(1), QUANTIZED_REF_GRAPHS_DIR / else_body_path)


@pytest.mark.parametrize("q_params", [{}, {"model_type": ModelType.TRANSFORMER}], ids=["default", "transformer"])
def test_scaled_dot_product_attention_placement(q_params, tmp_path):
ov_major_version, ov_minor_version = get_openvino_major_minor_version()
if ov_major_version < 2023 or (ov_major_version == 2023 and ov_minor_version < 3):
pytest.xfail("ScaledDotProductAttention is not supported until 2023.3")
model = ScaledDotProductAttentionModel().ov_model
quantized_model = quantize_model(model, q_params)

if q_params:
params_str = "_".join([param.value for param in q_params.values()])
else:
params_str = "default"

path_ref_graph = QUANTIZED_REF_GRAPHS_DIR / "scaled_dot_product_attention.dot"
result_name = f"scaled_dot_product_attention_{params_str}"
xml_path = tmp_path / (result_name + ".xml")
bin_path = tmp_path / (result_name + ".bin")
dump_model(quantized_model, str(xml_path), str(bin_path))
compare_nncf_graphs(quantized_model, path_ref_graph)

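For reference, the user-facing path the test exercises reduces to `nncf.quantize`; a minimal sketch with random calibration data (just to keep the example self-contained):

```python
import numpy as np
import nncf

from tests.openvino.native.models import ScaledDotProductAttentionModel

model = ScaledDotProductAttentionModel().ov_model

# Four random calibration samples; each item holds the (query, key, value, mask) inputs.
samples = [
    [np.random.randn(1, 1, 1, 64).astype(np.float32) for _ in range(3)]
    + [np.zeros((1, 1, 1, 1), dtype=np.float32)]
    for _ in range(4)
]

quantized_model = nncf.quantize(
    model, nncf.Dataset(samples), model_type=nncf.ModelType.TRANSFORMER
)
```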