[PT] Support scaled_dot_product_attention (#2761)
### Changes

Add `PTScaledDotProductAttentionMetatype`, registering `torch.nn.functional.scaled_dot_product_attention` as a known operator for the PyTorch backend. The MinMax algorithm now reports this metatype, so quantizers are inserted on the query and key inputs (ports 0 and 1) while the value input is left unquantized.
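
With this in place, post-training quantization of a model that calls `scaled_dot_product_attention` should place fake-quantizers on the query and key tensors. A minimal sketch (not part of this commit; the model, shapes, and `subset_size` value are illustrative) using the public `nncf.quantize` API:

```python
import torch
from torch import nn

import nncf


class SDPAModel(nn.Module):
    """Toy model with a single scaled_dot_product_attention call."""

    def forward(self, query, key, value):
        return nn.functional.scaled_dot_product_attention(query, key, value)


# One calibration sample; a dict item is passed to the model as keyword arguments.
sample = {
    "query": torch.randn(1, 8, 16),
    "key": torch.randn(1, 8, 16),
    "value": torch.randn(1, 8, 16),
}
calibration_dataset = nncf.Dataset([sample])

# Quantizers are expected on the query and key inputs (ports 0 and 1);
# the value input stays in floating point.
quantized_model = nncf.quantize(SDPAModel(), calibration_dataset, subset_size=1)
```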

### Related tickets

130925
AlexanderDokuchaev authored Jun 25, 2024
1 parent 2871993 commit 4924b7e
Showing 5 changed files with 53 additions and 8 deletions.
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/torch_backend.py
@@ -105,7 +105,7 @@ def group_conv_metatypes(self) -> List[OperatorMetatype]:

    @property
    def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]:
        return []
        return [om.PTScaledDotProductAttentionMetatype]

    @property
    def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]:
10 changes: 10 additions & 0 deletions nncf/torch/graph/operator_metatypes.py
@@ -1065,6 +1065,16 @@ class PTReduceL2(PTOperatorMetatype):
    num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
    name = "ScaledDotProductAttentionOp"
    module_to_function_names = {
        NamespaceTarget.TORCH_NN_FUNCTIONAL: ["scaled_dot_product_attention"],
    }
    hw_config_names = [HWConfigOpName.SCALED_DOT_PRODUCT_ATTENTION]
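    # Only input ports 0 (query) and 1 (key) are targeted for quantization;
    # the value input (port 2) is left unquantized.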
    target_input_ports = [0, 1]


def get_operator_metatypes() -> List[Type[OperatorMetatype]]:
    """
    Returns a list of the operator metatypes.
8 changes: 8 additions & 0 deletions tests/post_training/test_templates/helpers.py
@@ -424,3 +424,11 @@ def forward(self, x):
        x = self.conv(x)
        x = self.linear(x)
        return x


class ScaledDotProductAttentionModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        return nn.functional.scaled_dot_product_attention(query, key, value)
@@ -0,0 +1,15 @@
strict digraph {
"0 /nncf_model_input_0" [id=0, type=nncf_model_input];
"1 SymmetricQuantizer/symmetric_quantize_0" [id=1, type=symmetric_quantize];
"2 /nncf_model_input_1" [id=2, type=nncf_model_input];
"3 SymmetricQuantizer/symmetric_quantize_1" [id=3, type=symmetric_quantize];
"4 /nncf_model_input_2" [id=4, type=nncf_model_input];
"5 ScaledDotProductAttentionModel/scaled_dot_product_attention_0" [id=5, type=scaled_dot_product_attention];
"6 /nncf_model_output_0" [id=6, type=nncf_model_output];
"0 /nncf_model_input_0" -> "1 SymmetricQuantizer/symmetric_quantize_0";
"1 SymmetricQuantizer/symmetric_quantize_0" -> "5 ScaledDotProductAttentionModel/scaled_dot_product_attention_0";
"2 /nncf_model_input_1" -> "3 SymmetricQuantizer/symmetric_quantize_1";
"3 SymmetricQuantizer/symmetric_quantize_1" -> "5 ScaledDotProductAttentionModel/scaled_dot_product_attention_0";
"4 /nncf_model_input_2" -> "5 ScaledDotProductAttentionModel/scaled_dot_product_attention_0";
"5 ScaledDotProductAttentionModel/scaled_dot_product_attention_0" -> "6 /nncf_model_output_0";
}
26 changes: 19 additions & 7 deletions tests/torch/ptq/test_graphs.py
@@ -15,14 +15,15 @@
import pytest
import torch

from nncf import Dataset
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.torch import wrap_model
from nncf.torch.layers import NNCF_RNN
from nncf.torch.layers import LSTMCellNNCF
from tests.post_training.test_templates.helpers import EmbeddingModel
from tests.post_training.test_templates.helpers import get_static_dataset
from tests.post_training.test_templates.helpers import ScaledDotProductAttentionModel
from tests.torch import test_models
from tests.torch.quantization.test_algo_quantization import SharedLayersModel
from tests.torch.test_compressed_graph import ModelDesc
@@ -49,6 +50,14 @@ def get_model_name(description):

TEST_MODELS_DESC = [
    (ModelDesc("embedding_model", EmbeddingModel, [1, 10]), {}),
    (
        ModelDesc(
            "scaled_dot_product_attention_model",
            ScaledDotProductAttentionModel,
            {"query": [1, 8, 16], "key": [1, 8, 16], "value": [1, 8, 16]},
        ),
        {},
    ),
    (ModelDesc("shared_model", SharedLayersModel, [1, 1, 5, 6]), {}),
    (ModelDesc("alexnet", test_models.AlexNet, [1, 3, 32, 32]), {}),
    (ModelDesc("lenet", test_models.LeNet, [1, 3, 32, 32]), {}),
@@ -96,18 +105,21 @@ def get_model_name(description):
def test_min_max_classification_quantized_graphs(desc: ModelDesc, quantization_parameters, graph_dir, mocker):
    model = desc.model_builder()

    nncf_network = wrap_model(model, torch.ones(desc.input_sample_sizes), trace_parameters=True)
    if isinstance(desc.input_sample_sizes, dict):
        example_input = {}
        for name, size in desc.input_sample_sizes.items():
            example_input[name] = torch.ones(size)
    else:
        example_input = torch.ones(desc.input_sample_sizes)

    nncf_network = wrap_model(model, example_input, trace_parameters=True)
    quantization_parameters["advanced_parameters"] = AdvancedQuantizationParameters(disable_bias_correction=True)
    quantization_parameters["subset_size"] = 1
    quantization_algorithm = PostTrainingQuantization(**quantization_parameters)

    def transform_fn(input_) -> torch.Tensor:
        return torch.tensor(input_[0])

    quantized_model = quantization_algorithm.apply(
        nncf_network,
        nncf_network.nncf.get_graph(),
        dataset=get_static_dataset(desc.input_sample_sizes, transform_fn, None),
        dataset=Dataset([example_input]),
    )

    check_graph(quantized_model.nncf.get_graph(), desc.dot_filename(), graph_dir)
