Skip to content

Commit

Permalink
[Conformance] TorchFX/OV backends Alignment (#2996)
Browse files Browse the repository at this point in the history
### Changes

* ~~Constant folding is applied to all TorchFX models before the
quantization~~
* Some torchvision models (swin_v2_s, vit_16_b) are exported by
`torch.export.export` before ov conversation
* Moc transformations are applied to openvino compressed models after
the compression

After the #2984 
* Fixed `_compress_qdq_constant_transformation` for per tensor case

### Reason for changes

* To align TorchFX/OV quantized models

### Related tickets

#2766

### Tests

post_training_quantization/504/ is finished successfully
  • Loading branch information
daniil-lyakhov authored Oct 30, 2024
1 parent c01a50f commit b280eb7
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 56 deletions.
28 changes: 21 additions & 7 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# limitations under the License.

from copy import copy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -660,22 +659,37 @@ def _get_node_inputs(node: torch.fx.Node, model: torch.fx.GraphModule) -> Option
return tuple(args)


def _get_value(
arg: Optional[Union[torch.fx.Node, float, int]], model: torch.fx.GraphModule
) -> Union[torch.nn.Parameter, float, int]:
"""
Retrieves value from the given argument. It can be either torch.fx.Node or float/int value.
:param arg: Given arg to retrieve value.
:param model: torch.fx.GraphModule instance.
:return: value from the given argument.
"""
if isinstance(arg, torch.fx.Node):
return get_tensor_constant_from_node(arg, model)
return arg


def _compress_qdq_constant_transformation(model: torch.fx.GraphModule, matches) -> None:
"""
Change the FP32 weight value to Int8 and also reshape the scale for per_channel_quantization.
:param: model: Model to apply transformations to.
"""

for match in matches:
mul_node = match.replacements[0]
sub_node = match.replacements[1]
weight_node, scale_node, zp_node, axis = None, None, None, None
nodes_map = {node.name: match.nodes_map[node] for node in match.nodes_map}
get_const = partial(get_tensor_constant_from_node, model=model)
weight_node = get_const(nodes_map["weight"])
scale_node = get_const(nodes_map["scale"])
zp_node = get_const(nodes_map["zero_point"])
axis = nodes_map["axis"]

weight_node = _get_value(nodes_map["weight"], model)
scale_node = _get_value(nodes_map["scale"], model)
zp_node = _get_value(nodes_map["zero_point"], model)
axis = _get_value(nodes_map.get("axis"), model)
port_id = 0
if axis is not None:
result = torch.ops.quantized_decomposed.quantize_per_channel.default(
Expand Down
3 changes: 3 additions & 0 deletions tests/post_training/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ def save_compressed_model(self) -> None:
ov.serialize(ov_model, self.path_compressed_ir)
elif self.backend in OV_BACKENDS:
self.path_compressed_ir = self.output_model_dir / "model.xml"
from openvino._offline_transformations import apply_moc_transformations

apply_moc_transformations(self.compressed_model, cf=True)
ov.serialize(self.compressed_model, str(self.path_compressed_ir))

def get_num_compressed(self) -> None:
Expand Down
16 changes: 11 additions & 5 deletions tests/post_training/pipelines/image_classification_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch
class VisionModelParams:
weights: models.WeightsEnum
export_fn: Callable[[torch.nn.Module, Tuple[Any, ...]], torch.fx.GraphModule]
export_torch_before_ov_convert: bool = False


class ImageClassificationTorchvision(ImageClassificationBase):
Expand All @@ -47,8 +48,12 @@ class ImageClassificationTorchvision(ImageClassificationBase):
models.mobilenet_v3_small: VisionModelParams(
models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module
),
models.vit_b_16: VisionModelParams(models.ViT_B_16_Weights.DEFAULT, _export_graph_module),
models.swin_v2_s: VisionModelParams(models.Swin_V2_S_Weights.DEFAULT, _export_graph_module),
models.vit_b_16: VisionModelParams(
models.ViT_B_16_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
),
models.swin_v2_s: VisionModelParams(
models.Swin_V2_S_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
),
}

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -92,9 +97,10 @@ def prepare_model(self) -> None:

elif self.backend in [BackendType.OV, BackendType.FP32]:
with torch.no_grad():
with disable_patching():
m = torch.export.export(model, args=(self.dummy_tensor,))
self.model = ov.convert_model(m, example_input=self.dummy_tensor, input=self.input_size)
if self.model_params.export_torch_before_ov_convert:
with disable_patching():
model = torch.export.export(model, (self.dummy_tensor,))
self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size)
self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0]

self._dump_model_fp32()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 scale_node0" [id=3, type=get_attr];
"4 weight_node0" [id=4, type=get_attr];
"5 quantize_per_channel_default" [id=5, type=quantize_per_channel];
"6 add_tensor_2" [id=6, type=add];
"7 dequantize_per_channel_default" [id=7, type=dequantize_per_channel];
"8 conv2d" [id=8, type=conv2d];
"9 _param_constant2" [id=9, type=get_attr];
"10 _param_constant3" [id=10, type=get_attr];
"11 conv2d_1" [id=11, type=conv2d];
"12 _tensor_constant0" [id=12, type=get_attr];
"13 add_" [id=13, type=add_];
"14 _tensor_constant0_1" [id=14, type=get_attr];
"15 add__1" [id=15, type=add_];
"16 add" [id=16, type=add];
"17 _param_constant4" [id=17, type=get_attr];
"18 _param_constant5" [id=18, type=get_attr];
"19 conv2d_2" [id=19, type=conv2d];
"20 _tensor_constant0_2" [id=20, type=get_attr];
"21 add_1" [id=21, type=add];
"22 output" [id=22, type=output];
"0 arg0_1" -> "8 conv2d" [label="(1, 3, 224, 224)", style=solid];
"1 _param_constant0" -> "5 quantize_per_channel_default" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "8 conv2d" [label="(3,)", style=solid];
"3 scale_node0" -> "5 quantize_per_channel_default" [label="(3,)", style=solid];
"3 scale_node0" -> "7 dequantize_per_channel_default" [label="(3,)", style=solid];
"4 weight_node0" -> "5 quantize_per_channel_default" [label="(3,)", style=solid];
"4 weight_node0" -> "7 dequantize_per_channel_default" [label="(3,)", style=solid];
"5 quantize_per_channel_default" -> "6 add_tensor_2" [label="(3, 3, 1, 1)", style=solid];
"6 add_tensor_2" -> "7 dequantize_per_channel_default" [label="(3, 3, 1, 1)", style=solid];
"7 dequantize_per_channel_default" -> "8 conv2d" [label="(3, 3, 1, 1)", style=solid];
"8 conv2d" -> "11 conv2d_1" [label="(1, 3, 224, 224)", style=solid];
"8 conv2d" -> "13 add_" [label="(1, 3, 224, 224)", style=solid];
"9 _param_constant2" -> "11 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"10 _param_constant3" -> "11 conv2d_1" [label="(3,)", style=solid];
"11 conv2d_1" -> "15 add__1" [label="(1, 3, 224, 224)", style=solid];
"12 _tensor_constant0" -> "13 add_" [label="(1,)", style=solid];
"13 add_" -> "16 add" [label="(1, 3, 224, 224)", style=solid];
"14 _tensor_constant0_1" -> "15 add__1" [label="(1,)", style=solid];
"15 add__1" -> "16 add" [label="(1, 3, 224, 224)", style=solid];
"16 add" -> "19 conv2d_2" [label="(1, 3, 224, 224)", style=solid];
"17 _param_constant4" -> "19 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"18 _param_constant5" -> "19 conv2d_2" [label="(3,)", style=solid];
"19 conv2d_2" -> "21 add_1" [label="(1, 3, 224, 224)", style=solid];
"20 _tensor_constant0_2" -> "21 add_1" [label="(1,)", style=solid];
"21 add_1" -> "22 output" [label="(1, 3, 224, 224)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant1" [id=1, type=get_attr];
"2 scale_updated_constant0" [id=2, type=get_attr];
"3 compressed_weight_updated_constant0" [id=3, type=get_attr];
"4 mul_tensor" [id=4, type=mul];
"5 zero_point_updated_constant0" [id=5, type=get_attr];
"6 sub_tensor" [id=6, type=sub];
"7 conv2d" [id=7, type=conv2d];
"8 _param_constant2" [id=8, type=get_attr];
"9 _param_constant3" [id=9, type=get_attr];
"10 conv2d_1" [id=10, type=conv2d];
"11 _tensor_constant0" [id=11, type=get_attr];
"12 add_" [id=12, type=add_];
"13 _tensor_constant0_1" [id=13, type=get_attr];
"14 add__1" [id=14, type=add_];
"15 add" [id=15, type=add];
"16 _param_constant4" [id=16, type=get_attr];
"17 _param_constant5" [id=17, type=get_attr];
"18 conv2d_2" [id=18, type=conv2d];
"19 _tensor_constant0_2" [id=19, type=get_attr];
"20 add_1" [id=20, type=add];
"21 output" [id=21, type=output];
"0 arg0_1" -> "7 conv2d" [label="(1, 3, 224, 224)", style=solid];
"1 _param_constant1" -> "7 conv2d" [label="(3,)", style=solid];
"2 scale_updated_constant0" -> "4 mul_tensor" [label="(3, 1, 1, 1)", style=solid];
"3 compressed_weight_updated_constant0" -> "4 mul_tensor" [label="(3, 3, 1, 1)", style=solid];
"4 mul_tensor" -> "6 sub_tensor" [label="(3, 3, 1, 1)", style=solid];
"5 zero_point_updated_constant0" -> "6 sub_tensor" [label="(3, 1, 1, 1)", style=solid];
"6 sub_tensor" -> "7 conv2d" [label="(3, 3, 1, 1)", style=solid];
"7 conv2d" -> "10 conv2d_1" [label="(1, 3, 224, 224)", style=solid];
"7 conv2d" -> "12 add_" [label="(1, 3, 224, 224)", style=solid];
"8 _param_constant2" -> "10 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"9 _param_constant3" -> "10 conv2d_1" [label="(3,)", style=solid];
"10 conv2d_1" -> "14 add__1" [label="(1, 3, 224, 224)", style=solid];
"11 _tensor_constant0" -> "12 add_" [label="(1,)", style=solid];
"12 add_" -> "15 add" [label="(1, 3, 224, 224)", style=solid];
"13 _tensor_constant0_1" -> "14 add__1" [label="(1,)", style=solid];
"14 add__1" -> "15 add" [label="(1, 3, 224, 224)", style=solid];
"15 add" -> "18 conv2d_2" [label="(1, 3, 224, 224)", style=solid];
"16 _param_constant4" -> "18 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"17 _param_constant5" -> "18 conv2d_2" [label="(3,)", style=solid];
"18 conv2d_2" -> "20 add_1" [label="(1, 3, 224, 224)", style=solid];
"19 _tensor_constant0_2" -> "20 add_1" [label="(1,)", style=solid];
"20 add_1" -> "21 output" [label="(1, 3, 224, 224)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 quantize_per_tensor_default" [id=3, type=quantize_per_tensor];
"4 add_tensor_2" [id=4, type=add];
"5 dequantize_per_tensor_default" [id=5, type=dequantize_per_tensor];
"6 conv2d" [id=6, type=conv2d];
"7 _param_constant2" [id=7, type=get_attr];
"8 _param_constant3" [id=8, type=get_attr];
"9 conv2d_1" [id=9, type=conv2d];
"10 _tensor_constant0" [id=10, type=get_attr];
"11 add_" [id=11, type=add_];
"12 _tensor_constant0_1" [id=12, type=get_attr];
"13 add__1" [id=13, type=add_];
"14 add" [id=14, type=add];
"15 _param_constant4" [id=15, type=get_attr];
"16 _param_constant5" [id=16, type=get_attr];
"17 conv2d_2" [id=17, type=conv2d];
"18 _tensor_constant0_2" [id=18, type=get_attr];
"19 add_1" [id=19, type=add];
"20 output" [id=20, type=output];
"0 arg0_1" -> "6 conv2d" [label="(1, 3, 224, 224)", style=solid];
"1 _param_constant0" -> "3 quantize_per_tensor_default" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "6 conv2d" [label="(3,)", style=solid];
"3 quantize_per_tensor_default" -> "4 add_tensor_2" [label="(3, 3, 1, 1)", style=solid];
"4 add_tensor_2" -> "5 dequantize_per_tensor_default" [label="(3, 3, 1, 1)", style=solid];
"5 dequantize_per_tensor_default" -> "6 conv2d" [label="(3, 3, 1, 1)", style=solid];
"6 conv2d" -> "9 conv2d_1" [label="(1, 3, 224, 224)", style=solid];
"6 conv2d" -> "11 add_" [label="(1, 3, 224, 224)", style=solid];
"7 _param_constant2" -> "9 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"8 _param_constant3" -> "9 conv2d_1" [label="(3,)", style=solid];
"9 conv2d_1" -> "13 add__1" [label="(1, 3, 224, 224)", style=solid];
"10 _tensor_constant0" -> "11 add_" [label="(1,)", style=solid];
"11 add_" -> "14 add" [label="(1, 3, 224, 224)", style=solid];
"12 _tensor_constant0_1" -> "13 add__1" [label="(1,)", style=solid];
"13 add__1" -> "14 add" [label="(1, 3, 224, 224)", style=solid];
"14 add" -> "17 conv2d_2" [label="(1, 3, 224, 224)", style=solid];
"15 _param_constant4" -> "17 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"16 _param_constant5" -> "17 conv2d_2" [label="(3,)", style=solid];
"17 conv2d_2" -> "19 add_1" [label="(1, 3, 224, 224)", style=solid];
"18 _tensor_constant0_2" -> "19 add_1" [label="(1,)", style=solid];
"19 add_1" -> "20 output" [label="(1, 3, 224, 224)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant1" [id=1, type=get_attr];
"2 compressed_weight_updated_constant0" [id=2, type=get_attr];
"3 mul_tensor" [id=3, type=mul];
"4 sub_tensor" [id=4, type=sub];
"5 conv2d" [id=5, type=conv2d];
"6 _param_constant2" [id=6, type=get_attr];
"7 _param_constant3" [id=7, type=get_attr];
"8 conv2d_1" [id=8, type=conv2d];
"9 _tensor_constant0" [id=9, type=get_attr];
"10 add_" [id=10, type=add_];
"11 _tensor_constant0_1" [id=11, type=get_attr];
"12 add__1" [id=12, type=add_];
"13 add" [id=13, type=add];
"14 _param_constant4" [id=14, type=get_attr];
"15 _param_constant5" [id=15, type=get_attr];
"16 conv2d_2" [id=16, type=conv2d];
"17 _tensor_constant0_2" [id=17, type=get_attr];
"18 add_1" [id=18, type=add];
"19 output" [id=19, type=output];
"0 arg0_1" -> "5 conv2d" [label="(1, 3, 224, 224)", style=solid];
"1 _param_constant1" -> "5 conv2d" [label="(3,)", style=solid];
"2 compressed_weight_updated_constant0" -> "3 mul_tensor" [label="(3, 3, 1, 1)", style=solid];
"3 mul_tensor" -> "4 sub_tensor" [label="(3, 3, 1, 1)", style=solid];
"4 sub_tensor" -> "5 conv2d" [label="(3, 3, 1, 1)", style=solid];
"5 conv2d" -> "8 conv2d_1" [label="(1, 3, 224, 224)", style=solid];
"5 conv2d" -> "10 add_" [label="(1, 3, 224, 224)", style=solid];
"6 _param_constant2" -> "8 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"7 _param_constant3" -> "8 conv2d_1" [label="(3,)", style=solid];
"8 conv2d_1" -> "12 add__1" [label="(1, 3, 224, 224)", style=solid];
"9 _tensor_constant0" -> "10 add_" [label="(1,)", style=solid];
"10 add_" -> "13 add" [label="(1, 3, 224, 224)", style=solid];
"11 _tensor_constant0_1" -> "12 add__1" [label="(1,)", style=solid];
"12 add__1" -> "13 add" [label="(1, 3, 224, 224)", style=solid];
"13 add" -> "16 conv2d_2" [label="(1, 3, 224, 224)", style=solid];
"14 _param_constant4" -> "16 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"15 _param_constant5" -> "16 conv2d_2" [label="(3,)", style=solid];
"16 conv2d_2" -> "18 add_1" [label="(1, 3, 224, 224)", style=solid];
"17 _tensor_constant0_2" -> "18 add_1" [label="(1,)", style=solid];
"18 add_1" -> "19 output" [label="(1, 3, 224, 224)", style=solid];
}
Loading

0 comments on commit b280eb7

Please sign in to comment.