diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py
index a92b25669ab..e372b32d640 100644
--- a/nncf/quantization/algorithms/min_max/torch_backend.py
+++ b/nncf/quantization/algorithms/min_max/torch_backend.py
@@ -105,7 +105,7 @@ def group_conv_metatypes(self) -> List[OperatorMetatype]:
 
     @property
     def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:
-        return []
+        return [om.PTScaledDotProductAttentionMetatype]
 
     @property
     def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py
index db09f1edab5..0e0bca6e770 100644
--- a/nncf/torch/graph/operator_metatypes.py
+++ b/nncf/torch/graph/operator_metatypes.py
@@ -1065,6 +1065,16 @@ class PTReduceL2(PTOperatorMetatype):
     num_expected_input_edges = 1
 
 
+@PT_OPERATOR_METATYPES.register()
+class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
+    name = "ScaledDotProductAttentionOp"
+    module_to_function_names = {
+        NamespaceTarget.TORCH_NN_FUNCTIONAL: ["scaled_dot_product_attention"],
+    }
+    hw_config_names = [HWConfigOpName.SCALED_DOT_PRODUCT_ATTENTION]
+    target_input_ports = [0, 1]
+
+
 def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
     """
     Returns a list of the operator metatypes.
diff --git a/tests/post_training/test_templates/helpers.py b/tests/post_training/test_templates/helpers.py
index 054a15a00b2..da969214914 100644
--- a/tests/post_training/test_templates/helpers.py
+++ b/tests/post_training/test_templates/helpers.py
@@ -424,3 +424,11 @@ def forward(self, x):
         x = self.conv(x)
         x = self.linear(x)
         return x
+
+
+class ScaledDotProductAttentionModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, query, key, value):
+        return nn.functional.scaled_dot_product_attention(query, key, value)
diff --git a/tests/torch/data/reference_graphs/quantized/ptq/symmetric/scaled_dot_product_attention_model.dot b/tests/torch/data/reference_graphs/quantized/ptq/symmetric/scaled_dot_product_attention_model.dot
new file mode 100644
index 00000000000..63ac9291dcf
--- /dev/null
+++ b/tests/torch/data/reference_graphs/quantized/ptq/symmetric/scaled_dot_product_attention_model.dot
@@ -0,0 +1,15 @@
+strict digraph {
+"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
+"1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize];
+"2 /nncf_model_input_1" [id=2, type=nncf_model_input];
+"3 SymmetricQuantizer/symmetric_quantize_1" [id=3, type=symmetric_quantize];
+"4 /nncf_model_input_2" [id=4, type=nncf_model_input];
+"5 ScaledDotProductAttentionModel/scaled_dot_product_attention_0" [id=5, type=scaled_dot_product_attention];
+"6 /nncf_model_output_0" [id=6, type=nncf_model_output];
+"0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0";
+"1 SymmetricQuantizer/symmetric_quantize_0" -> "5 ScaledDotProductAttentionModel/scaled_dot_product_attention_0";
+"2 /nncf_model_input_1" -> "3 SymmetricQuantizer/symmetric_quantize_1";
+"3 SymmetricQuantizer/symmetric_quantize_1" -> "5 ScaledDotProductAttentionModel/scaled_dot_product_attention_0";
+"4 /nncf_model_input_2" -> "5 ScaledDotProductAttentionModel/scaled_dot_product_attention_0";
+"5 ScaledDotProductAttentionModel/scaled_dot_product_attention_0" -> "6 /nncf_model_output_0";
+}
diff --git a/tests/torch/ptq/test_graphs.py b/tests/torch/ptq/test_graphs.py
index 93281435104..aa427735b53 100644
--- a/tests/torch/ptq/test_graphs.py
+++ b/tests/torch/ptq/test_graphs.py
@@ -15,6 +15,7 @@
 import pytest
 import torch
 
+from nncf import Dataset
 from nncf.parameters import TargetDevice
 from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
 from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
@@ -22,7 +23,7 @@
 from nncf.torch.layers import NNCF_RNN
 from nncf.torch.layers import LSTMCellNNCF
 from tests.post_training.test_templates.helpers import EmbeddingModel
-from tests.post_training.test_templates.helpers import get_static_dataset
+from tests.post_training.test_templates.helpers import ScaledDotProductAttentionModel
 from tests.torch import test_models
 from tests.torch.quantization.test_algo_quantization import SharedLayersModel
 from tests.torch.test_compressed_graph import ModelDesc
@@ -49,6 +50,14 @@ def get_model_name(description):
 
 TEST_MODELS_DESC = [
     (ModelDesc("embedding_model", EmbeddingModel, [1, 10]), {}),
+    (
+        ModelDesc(
+            "scaled_dot_product_attention_model",
+            ScaledDotProductAttentionModel,
+            {"query": [1, 8, 16], "key": [1, 8, 16], "value": [1, 8, 16]},
+        ),
+        {},
+    ),
     (ModelDesc("shared_model", SharedLayersModel, [1, 1, 5, 6]), {}),
     (ModelDesc("alexnet", test_models.AlexNet, [1, 3, 32, 32]), {}),
     (ModelDesc("lenet", test_models.LeNet, [1, 3, 32, 32]), {}),
@@ -96,18 +105,21 @@ def get_model_name(description):
 def test_min_max_classification_quantized_graphs(desc: ModelDesc, quantization_parameters, graph_dir, mocker):
     model = desc.model_builder()
 
-    nncf_network = wrap_model(model, torch.ones(desc.input_sample_sizes), trace_parameters=True)
+    if isinstance(desc.input_sample_sizes, dict):
+        example_input = {}
+        for name, size in desc.input_sample_sizes.items():
+            example_input[name] = torch.ones(size)
+    else:
+        example_input = torch.ones(desc.input_sample_sizes)
+
+    nncf_network = wrap_model(model, example_input, trace_parameters=True)
     quantization_parameters["advanced_parameters"] = AdvancedQuantizationParameters(disable_bias_correction=True)
     quantization_parameters["subset_size"] = 1
     quantization_algorithm = PostTrainingQuantization(**quantization_parameters)
 
-    def transform_fn(input_) -> torch.Tensor:
-        return torch.tensor(input_[0])
-
     quantized_model = quantization_algorithm.apply(
         nncf_network,
         nncf_network.nncf.get_graph(),
-        dataset=get_static_dataset(desc.input_sample_sizes, transform_fn, None),
+        dataset=Dataset([example_input]),
     )
-
     check_graph(quantized_model.nncf.get_graph(), desc.dot_filename(), graph_dir)
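
Usage sketch (illustrative, not part of the patch): with PTScaledDotProductAttentionMetatype registered and target_input_ports = [0, 1], min-max PTQ is expected to quantize the query and key inputs of torch.nn.functional.scaled_dot_product_attention while leaving value in float, as in the reference graph added above. The snippet below exercises this through the public nncf.quantize API; the tensor shapes and subset_size are arbitrary assumptions, and it presumes dict dataset items are passed to the model as keyword inputs, mirroring the updated test.

    import torch
    from torch import nn

    import nncf


    class ScaledDotProductAttentionModel(nn.Module):
        def forward(self, query, key, value):
            return nn.functional.scaled_dot_product_attention(query, key, value)


    # A single calibration sample; shapes are illustrative only.
    example_input = {
        "query": torch.randn(1, 8, 16),
        "key": torch.randn(1, 8, 16),
        "value": torch.randn(1, 8, 16),
    }
    calibration_dataset = nncf.Dataset([example_input])

    # Quantizers should appear on input ports 0 (query) and 1 (key), matching
    # target_input_ports in the new metatype; port 2 (value) stays unquantized.
    quantized_model = nncf.quantize(
        ScaledDotProductAttentionModel(),
        calibration_dataset,
        subset_size=1,
    )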