diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py index 2b7727d97ef..5247fbf6d66 100644 --- a/nncf/experimental/torch/fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -10,7 +10,6 @@ # limitations under the License. from copy import copy -from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -660,22 +659,37 @@ def _get_node_inputs(node: torch.fx.Node, model: torch.fx.GraphModule) -> Option return tuple(args) +def _get_value( + arg: Optional[Union[torch.fx.Node, float, int]], model: torch.fx.GraphModule +) -> Union[torch.nn.Parameter, float, int]: + """ + Retrieves value from the given argument. It can be either torch.fx.Node or float/int value. + + :param arg: Given arg to retrieve value. + :param model: torch.fx.GraphModule instance. + :return: value from the given argument. + """ + if isinstance(arg, torch.fx.Node): + return get_tensor_constant_from_node(arg, model) + return arg + + def _compress_qdq_constant_transformation(model: torch.fx.GraphModule, matches) -> None: """ Change the FP32 weight value to Int8 and also reshape the scale for per_channel_quantization. :param: model: Model to apply transformations to. """ + for match in matches: mul_node = match.replacements[0] sub_node = match.replacements[1] - weight_node, scale_node, zp_node, axis = None, None, None, None nodes_map = {node.name: match.nodes_map[node] for node in match.nodes_map} - get_const = partial(get_tensor_constant_from_node, model=model) - weight_node = get_const(nodes_map["weight"]) - scale_node = get_const(nodes_map["scale"]) - zp_node = get_const(nodes_map["zero_point"]) - axis = nodes_map["axis"] + + weight_node = _get_value(nodes_map["weight"], model) + scale_node = _get_value(nodes_map["scale"], model) + zp_node = _get_value(nodes_map["zero_point"], model) + axis = _get_value(nodes_map.get("axis"), model) port_id = 0 if axis is not None: result = torch.ops.quantized_decomposed.quantize_per_channel.default( diff --git a/tests/post_training/pipelines/base.py b/tests/post_training/pipelines/base.py index 8fe6581640d..41e9bcf17e4 100644 --- a/tests/post_training/pipelines/base.py +++ b/tests/post_training/pipelines/base.py @@ -418,6 +418,9 @@ def save_compressed_model(self) -> None: ov.serialize(ov_model, self.path_compressed_ir) elif self.backend in OV_BACKENDS: self.path_compressed_ir = self.output_model_dir / "model.xml" + from openvino._offline_transformations import apply_moc_transformations + + apply_moc_transformations(self.compressed_model, cf=True) ov.serialize(self.compressed_model, str(self.path_compressed_ir)) def get_num_compressed(self) -> None: diff --git a/tests/post_training/pipelines/image_classification_torchvision.py b/tests/post_training/pipelines/image_classification_torchvision.py index 36cdddd4439..b675bd6e814 100644 --- a/tests/post_training/pipelines/image_classification_torchvision.py +++ b/tests/post_training/pipelines/image_classification_torchvision.py @@ -37,6 +37,7 @@ def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch class VisionModelParams: weights: models.WeightsEnum export_fn: Callable[[torch.nn.Module, Tuple[Any, ...]], torch.fx.GraphModule] + export_torch_before_ov_convert: bool = False class ImageClassificationTorchvision(ImageClassificationBase): @@ -47,8 +48,12 @@ class ImageClassificationTorchvision(ImageClassificationBase): models.mobilenet_v3_small: VisionModelParams( models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module ), - models.vit_b_16: VisionModelParams(models.ViT_B_16_Weights.DEFAULT, _export_graph_module), - models.swin_v2_s: VisionModelParams(models.Swin_V2_S_Weights.DEFAULT, _export_graph_module), + models.vit_b_16: VisionModelParams( + models.ViT_B_16_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True + ), + models.swin_v2_s: VisionModelParams( + models.Swin_V2_S_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True + ), } def __init__(self, *args, **kwargs): @@ -92,9 +97,10 @@ def prepare_model(self) -> None: elif self.backend in [BackendType.OV, BackendType.FP32]: with torch.no_grad(): - with disable_patching(): - m = torch.export.export(model, args=(self.dummy_tensor,)) - self.model = ov.convert_model(m, example_input=self.dummy_tensor, input=self.input_size) + if self.model_params.export_torch_before_ov_convert: + with disable_patching(): + model = torch.export.export(model, (self.dummy_tensor,)) + self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size) self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0] self._dump_model_fp32() diff --git a/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_channel_invalid.dot b/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_channel_invalid.dot new file mode 100644 index 00000000000..5db5924962f --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_channel_invalid.dot @@ -0,0 +1,50 @@ +strict digraph { +"0 arg0_1" [id=0, type=input]; +"1 _param_constant0" [id=1, type=get_attr]; +"2 _param_constant1" [id=2, type=get_attr]; +"3 scale_node0" [id=3, type=get_attr]; +"4 weight_node0" [id=4, type=get_attr]; +"5 quantize_per_channel_default" [id=5, type=quantize_per_channel]; +"6 add_tensor_2" [id=6, type=add]; +"7 dequantize_per_channel_default" [id=7, type=dequantize_per_channel]; +"8 conv2d" [id=8, type=conv2d]; +"9 _param_constant2" [id=9, type=get_attr]; +"10 _param_constant3" [id=10, type=get_attr]; +"11 conv2d_1" [id=11, type=conv2d]; +"12 _tensor_constant0" [id=12, type=get_attr]; +"13 add_" [id=13, type=add_]; +"14 _tensor_constant0_1" [id=14, type=get_attr]; +"15 add__1" [id=15, type=add_]; +"16 add" [id=16, type=add]; +"17 _param_constant4" [id=17, type=get_attr]; +"18 _param_constant5" [id=18, type=get_attr]; +"19 conv2d_2" [id=19, type=conv2d]; +"20 _tensor_constant0_2" [id=20, type=get_attr]; +"21 add_1" [id=21, type=add]; +"22 output" [id=22, type=output]; +"0 arg0_1" -> "8 conv2d" [label="(1, 3, 224, 224)", style=solid]; +"1 _param_constant0" -> "5 quantize_per_channel_default" [label="(3, 3, 1, 1)", style=solid]; +"2 _param_constant1" -> "8 conv2d" [label="(3,)", style=solid]; +"3 scale_node0" -> "5 quantize_per_channel_default" [label="(3,)", style=solid]; +"3 scale_node0" -> "7 dequantize_per_channel_default" [label="(3,)", style=solid]; +"4 weight_node0" -> "5 quantize_per_channel_default" [label="(3,)", style=solid]; +"4 weight_node0" -> "7 dequantize_per_channel_default" [label="(3,)", style=solid]; +"5 quantize_per_channel_default" -> "6 add_tensor_2" [label="(3, 3, 1, 1)", style=solid]; +"6 add_tensor_2" -> "7 dequantize_per_channel_default" [label="(3, 3, 1, 1)", style=solid]; +"7 dequantize_per_channel_default" -> "8 conv2d" [label="(3, 3, 1, 1)", style=solid]; +"8 conv2d" -> "11 conv2d_1" [label="(1, 3, 224, 224)", style=solid]; +"8 conv2d" -> "13 add_" [label="(1, 3, 224, 224)", style=solid]; +"9 _param_constant2" -> "11 conv2d_1" [label="(3, 3, 1, 1)", style=solid]; +"10 _param_constant3" -> "11 conv2d_1" [label="(3,)", style=solid]; +"11 conv2d_1" -> "15 add__1" [label="(1, 3, 224, 224)", style=solid]; +"12 _tensor_constant0" -> "13 add_" [label="(1,)", style=solid]; +"13 add_" -> "16 add" [label="(1, 3, 224, 224)", style=solid]; +"14 _tensor_constant0_1" -> "15 add__1" [label="(1,)", style=solid]; +"15 add__1" -> "16 add" [label="(1, 3, 224, 224)", style=solid]; +"16 add" -> "19 conv2d_2" [label="(1, 3, 224, 224)", style=solid]; +"17 _param_constant4" -> "19 conv2d_2" [label="(3, 3, 1, 1)", style=solid]; +"18 _param_constant5" -> "19 conv2d_2" [label="(3,)", style=solid]; +"19 conv2d_2" -> "21 add_1" [label="(1, 3, 224, 224)", style=solid]; +"20 _tensor_constant0_2" -> "21 add_1" [label="(1,)", style=solid]; +"21 add_1" -> "22 output" [label="(1, 3, 224, 224)", style=solid]; +} diff --git a/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_channel_valid.dot b/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_channel_valid.dot new file mode 100644 index 00000000000..232b2ba544a --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_channel_valid.dot @@ -0,0 +1,46 @@ +strict digraph { +"0 arg0_1" [id=0, type=input]; +"1 _param_constant1" [id=1, type=get_attr]; +"2 scale_updated_constant0" [id=2, type=get_attr]; +"3 compressed_weight_updated_constant0" [id=3, type=get_attr]; +"4 mul_tensor" [id=4, type=mul]; +"5 zero_point_updated_constant0" [id=5, type=get_attr]; +"6 sub_tensor" [id=6, type=sub]; +"7 conv2d" [id=7, type=conv2d]; +"8 _param_constant2" [id=8, type=get_attr]; +"9 _param_constant3" [id=9, type=get_attr]; +"10 conv2d_1" [id=10, type=conv2d]; +"11 _tensor_constant0" [id=11, type=get_attr]; +"12 add_" [id=12, type=add_]; +"13 _tensor_constant0_1" [id=13, type=get_attr]; +"14 add__1" [id=14, type=add_]; +"15 add" [id=15, type=add]; +"16 _param_constant4" [id=16, type=get_attr]; +"17 _param_constant5" [id=17, type=get_attr]; +"18 conv2d_2" [id=18, type=conv2d]; +"19 _tensor_constant0_2" [id=19, type=get_attr]; +"20 add_1" [id=20, type=add]; +"21 output" [id=21, type=output]; +"0 arg0_1" -> "7 conv2d" [label="(1, 3, 224, 224)", style=solid]; +"1 _param_constant1" -> "7 conv2d" [label="(3,)", style=solid]; +"2 scale_updated_constant0" -> "4 mul_tensor" [label="(3, 1, 1, 1)", style=solid]; +"3 compressed_weight_updated_constant0" -> "4 mul_tensor" [label="(3, 3, 1, 1)", style=solid]; +"4 mul_tensor" -> "6 sub_tensor" [label="(3, 3, 1, 1)", style=solid]; +"5 zero_point_updated_constant0" -> "6 sub_tensor" [label="(3, 1, 1, 1)", style=solid]; +"6 sub_tensor" -> "7 conv2d" [label="(3, 3, 1, 1)", style=solid]; +"7 conv2d" -> "10 conv2d_1" [label="(1, 3, 224, 224)", style=solid]; +"7 conv2d" -> "12 add_" [label="(1, 3, 224, 224)", style=solid]; +"8 _param_constant2" -> "10 conv2d_1" [label="(3, 3, 1, 1)", style=solid]; +"9 _param_constant3" -> "10 conv2d_1" [label="(3,)", style=solid]; +"10 conv2d_1" -> "14 add__1" [label="(1, 3, 224, 224)", style=solid]; +"11 _tensor_constant0" -> "12 add_" [label="(1,)", style=solid]; +"12 add_" -> "15 add" [label="(1, 3, 224, 224)", style=solid]; +"13 _tensor_constant0_1" -> "14 add__1" [label="(1,)", style=solid]; +"14 add__1" -> "15 add" [label="(1, 3, 224, 224)", style=solid]; +"15 add" -> "18 conv2d_2" [label="(1, 3, 224, 224)", style=solid]; +"16 _param_constant4" -> "18 conv2d_2" [label="(3, 3, 1, 1)", style=solid]; +"17 _param_constant5" -> "18 conv2d_2" [label="(3,)", style=solid]; +"18 conv2d_2" -> "20 add_1" [label="(1, 3, 224, 224)", style=solid]; +"19 _tensor_constant0_2" -> "20 add_1" [label="(1,)", style=solid]; +"20 add_1" -> "21 output" [label="(1, 3, 224, 224)", style=solid]; +} diff --git a/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_tensor_invalid.dot b/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_tensor_invalid.dot new file mode 100644 index 00000000000..c603317a2d1 --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_tensor_invalid.dot @@ -0,0 +1,44 @@ +strict digraph { +"0 arg0_1" [id=0, type=input]; +"1 _param_constant0" [id=1, type=get_attr]; +"2 _param_constant1" [id=2, type=get_attr]; +"3 quantize_per_tensor_default" [id=3, type=quantize_per_tensor]; +"4 add_tensor_2" [id=4, type=add]; +"5 dequantize_per_tensor_default" [id=5, type=dequantize_per_tensor]; +"6 conv2d" [id=6, type=conv2d]; +"7 _param_constant2" [id=7, type=get_attr]; +"8 _param_constant3" [id=8, type=get_attr]; +"9 conv2d_1" [id=9, type=conv2d]; +"10 _tensor_constant0" [id=10, type=get_attr]; +"11 add_" [id=11, type=add_]; +"12 _tensor_constant0_1" [id=12, type=get_attr]; +"13 add__1" [id=13, type=add_]; +"14 add" [id=14, type=add]; +"15 _param_constant4" [id=15, type=get_attr]; +"16 _param_constant5" [id=16, type=get_attr]; +"17 conv2d_2" [id=17, type=conv2d]; +"18 _tensor_constant0_2" [id=18, type=get_attr]; +"19 add_1" [id=19, type=add]; +"20 output" [id=20, type=output]; +"0 arg0_1" -> "6 conv2d" [label="(1, 3, 224, 224)", style=solid]; +"1 _param_constant0" -> "3 quantize_per_tensor_default" [label="(3, 3, 1, 1)", style=solid]; +"2 _param_constant1" -> "6 conv2d" [label="(3,)", style=solid]; +"3 quantize_per_tensor_default" -> "4 add_tensor_2" [label="(3, 3, 1, 1)", style=solid]; +"4 add_tensor_2" -> "5 dequantize_per_tensor_default" [label="(3, 3, 1, 1)", style=solid]; +"5 dequantize_per_tensor_default" -> "6 conv2d" [label="(3, 3, 1, 1)", style=solid]; +"6 conv2d" -> "9 conv2d_1" [label="(1, 3, 224, 224)", style=solid]; +"6 conv2d" -> "11 add_" [label="(1, 3, 224, 224)", style=solid]; +"7 _param_constant2" -> "9 conv2d_1" [label="(3, 3, 1, 1)", style=solid]; +"8 _param_constant3" -> "9 conv2d_1" [label="(3,)", style=solid]; +"9 conv2d_1" -> "13 add__1" [label="(1, 3, 224, 224)", style=solid]; +"10 _tensor_constant0" -> "11 add_" [label="(1,)", style=solid]; +"11 add_" -> "14 add" [label="(1, 3, 224, 224)", style=solid]; +"12 _tensor_constant0_1" -> "13 add__1" [label="(1,)", style=solid]; +"13 add__1" -> "14 add" [label="(1, 3, 224, 224)", style=solid]; +"14 add" -> "17 conv2d_2" [label="(1, 3, 224, 224)", style=solid]; +"15 _param_constant4" -> "17 conv2d_2" [label="(3, 3, 1, 1)", style=solid]; +"16 _param_constant5" -> "17 conv2d_2" [label="(3,)", style=solid]; +"17 conv2d_2" -> "19 add_1" [label="(1, 3, 224, 224)", style=solid]; +"18 _tensor_constant0_2" -> "19 add_1" [label="(1,)", style=solid]; +"19 add_1" -> "20 output" [label="(1, 3, 224, 224)", style=solid]; +} diff --git a/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_tensor_valid.dot b/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_tensor_valid.dot new file mode 100644 index 00000000000..d4b52ec0840 --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/transformed/compress_post_quantize_per_tensor_valid.dot @@ -0,0 +1,42 @@ +strict digraph { +"0 arg0_1" [id=0, type=input]; +"1 _param_constant1" [id=1, type=get_attr]; +"2 compressed_weight_updated_constant0" [id=2, type=get_attr]; +"3 mul_tensor" [id=3, type=mul]; +"4 sub_tensor" [id=4, type=sub]; +"5 conv2d" [id=5, type=conv2d]; +"6 _param_constant2" [id=6, type=get_attr]; +"7 _param_constant3" [id=7, type=get_attr]; +"8 conv2d_1" [id=8, type=conv2d]; +"9 _tensor_constant0" [id=9, type=get_attr]; +"10 add_" [id=10, type=add_]; +"11 _tensor_constant0_1" [id=11, type=get_attr]; +"12 add__1" [id=12, type=add_]; +"13 add" [id=13, type=add]; +"14 _param_constant4" [id=14, type=get_attr]; +"15 _param_constant5" [id=15, type=get_attr]; +"16 conv2d_2" [id=16, type=conv2d]; +"17 _tensor_constant0_2" [id=17, type=get_attr]; +"18 add_1" [id=18, type=add]; +"19 output" [id=19, type=output]; +"0 arg0_1" -> "5 conv2d" [label="(1, 3, 224, 224)", style=solid]; +"1 _param_constant1" -> "5 conv2d" [label="(3,)", style=solid]; +"2 compressed_weight_updated_constant0" -> "3 mul_tensor" [label="(3, 3, 1, 1)", style=solid]; +"3 mul_tensor" -> "4 sub_tensor" [label="(3, 3, 1, 1)", style=solid]; +"4 sub_tensor" -> "5 conv2d" [label="(3, 3, 1, 1)", style=solid]; +"5 conv2d" -> "8 conv2d_1" [label="(1, 3, 224, 224)", style=solid]; +"5 conv2d" -> "10 add_" [label="(1, 3, 224, 224)", style=solid]; +"6 _param_constant2" -> "8 conv2d_1" [label="(3, 3, 1, 1)", style=solid]; +"7 _param_constant3" -> "8 conv2d_1" [label="(3,)", style=solid]; +"8 conv2d_1" -> "12 add__1" [label="(1, 3, 224, 224)", style=solid]; +"9 _tensor_constant0" -> "10 add_" [label="(1,)", style=solid]; +"10 add_" -> "13 add" [label="(1, 3, 224, 224)", style=solid]; +"11 _tensor_constant0_1" -> "12 add__1" [label="(1,)", style=solid]; +"12 add__1" -> "13 add" [label="(1, 3, 224, 224)", style=solid]; +"13 add" -> "16 conv2d_2" [label="(1, 3, 224, 224)", style=solid]; +"14 _param_constant4" -> "16 conv2d_2" [label="(3, 3, 1, 1)", style=solid]; +"15 _param_constant5" -> "16 conv2d_2" [label="(3,)", style=solid]; +"16 conv2d_2" -> "18 add_1" [label="(1, 3, 224, 224)", style=solid]; +"17 _tensor_constant0_2" -> "18 add_1" [label="(1,)", style=solid]; +"18 add_1" -> "19 output" [label="(1, 3, 224, 224)", style=solid]; +} diff --git a/tests/torch/fx/test_model_transformer.py b/tests/torch/fx/test_model_transformer.py index 2d13ba758e5..d0799a0d5ed 100644 --- a/tests/torch/fx/test_model_transformer.py +++ b/tests/torch/fx/test_model_transformer.py @@ -35,6 +35,7 @@ from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node from nncf.experimental.torch.fx.transformations import _set_new_node_meta from nncf.experimental.torch.fx.transformations import bias_update_transformation_builder +from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder from nncf.experimental.torch.fx.transformations import leaf_module_insertion_transformation_builder from nncf.experimental.torch.fx.transformations import module_insertion_transformation_builder @@ -46,7 +47,6 @@ from nncf.torch.graph.operator_metatypes import CONST_NOOP_METATYPES from nncf.torch.graph.transformations.commands import PTModelExtractionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint -from tests.torch.fx.test_sanity import count_q_dq from tests.torch.test_compressed_graph import check_graph from tests.torch.test_models.synthetic import ConvolutionWithAllConstantInputsModel from tests.torch.test_models.synthetic import ConvolutionWithNotTensorBiasModel @@ -406,56 +406,82 @@ def get_shared_constant_nodes(nncf_graph: NNCFGraph): return shared_const_node_consumer_node -def insert_qdq_add_nodes(model: torch.fx.GraphModule): - const_node = get_graph_node_by_name(model.graph, "_param_constant0") - quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default - dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default - add_op = torch.add - conv_node = get_graph_node_by_name(model.graph, "conv2d") - with model.graph.inserting_before(conv_node): - scale_node = create_getattr_from_value( - model, - model.graph, - "scale_node", - torch.ones( - [ - 3, - ] - ), - ) - zp_node = create_getattr_from_value( - model, - model.graph, - "weight_node", - torch.ones( - [ - 3, - ] - ), - ) +def insert_qdq_nodes( + model: torch.fx.GraphModule, + correct_pattern: bool, + per_channel: bool, + node_name: str = "conv2d", + w_const_node_name: str = "_param_constant0", +): + const_node = get_graph_node_by_name(model.graph, w_const_node_name) + if per_channel: + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default + else: + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + + conv_node = get_graph_node_by_name(model.graph, node_name) + if per_channel: + with model.graph.inserting_before(conv_node): + scale_node = create_getattr_from_value( + model, + model.graph, + "scale_node", + torch.ones([3]), + ) + zp_node = create_getattr_from_value( + model, + model.graph, + "weight_node", + torch.ones([3]), + ) qdq_args = (scale_node, zp_node, 0, -128, 127, torch.int8) + else: + qdq_args = (1.0, 1, -128, 127, torch.int8) + with model.graph.inserting_before(conv_node): q_node = model.graph.create_node("call_function", quantize_op, (const_node,) + qdq_args, {}) - add_node = model.graph.create_node("call_function", add_op, (q_node, 0), {}) - dq_node = model.graph.create_node("call_function", dequantize_op, (add_node,) + qdq_args, {}) - _set_new_node_meta(q_node, (const_node,) + qdq_args, quantize_op, model) - _set_new_node_meta(add_node, (q_node, 0), add_op, model) - _set_new_node_meta(dq_node, (add_node,) + qdq_args, dequantize_op, model) + if not correct_pattern: + add_op = torch.ops.aten.add.Tensor + add_node = model.graph.create_node("call_function", add_op, (q_node, 0), {}) + dq_node = model.graph.create_node("call_function", dequantize_op, (add_node,) + qdq_args, {}) + _set_new_node_meta(q_node, (const_node,) + qdq_args, quantize_op, model) + _set_new_node_meta(add_node, (q_node, 0), add_op, model) + _set_new_node_meta(dq_node, (add_node,) + qdq_args, dequantize_op, model) + else: + dq_node = model.graph.create_node("call_function", dequantize_op, (q_node,) + qdq_args, {}) + _set_new_node_meta(q_node, (const_node,) + qdq_args, quantize_op, model) + _set_new_node_meta(dq_node, (q_node,) + qdq_args, dequantize_op, model) conv_node.replace_input_with(const_node, dq_node) + model.graph.eliminate_dead_code() + model.recompile() -def test_different_qdq_pattern(): +def test_compress_post_quantize_transformation(is_per_channel: bool): model = MultiBranchesConnectedModel() ex_input = torch.ones(1, 3, 224, 224) - captured_model = _capture_model(model, ex_input) - quantized_before_insertion = nncf.quantize(captured_model, nncf.Dataset([ex_input])) - q_before, dq_before = count_q_dq(quantized_before_insertion) - insert_qdq_add_nodes(captured_model) - quantized_after_insertion = nncf.quantize(captured_model, nncf.Dataset([ex_input])) - q_after, dq_after = count_q_dq(quantized_after_insertion) - assert q_before == 5 - assert dq_before == 6 - assert q_after == 6 - assert dq_after == 7 + + model_with_correct_pattern = _capture_model(model, ex_input) + insert_qdq_nodes(model_with_correct_pattern, correct_pattern=True, per_channel=is_per_channel) + compress_post_quantize_transformation(model_with_correct_pattern) + graph_name = f"compress_post_quantize_{'per_channel' if is_per_channel else 'per_tensor'}_valid.dot" + check_graph( + NNCFGraphFactory.create(model_with_correct_pattern), + graph_name, + TRANSFORMED_GRAPH_DIR_NAME, + extended=True, + ) + + model_with_incorrect_pattern = _capture_model(model, ex_input) + insert_qdq_nodes(model_with_incorrect_pattern, correct_pattern=False, per_channel=is_per_channel) + compress_post_quantize_transformation(model_with_incorrect_pattern) + graph_name = f"compress_post_quantize_{'per_channel' if is_per_channel else 'per_tensor'}_invalid.dot" + check_graph( + NNCFGraphFactory.create(model_with_incorrect_pattern), + graph_name, + TRANSFORMED_GRAPH_DIR_NAME, + extended=True, + ) def test_update_shared_constant():