Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Conformance] TorchFX/OV backends Alignment #2996

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# limitations under the License.

from copy import copy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -660,22 +659,37 @@ def _get_node_inputs(node: torch.fx.Node, model: torch.fx.GraphModule) -> Option
return tuple(args)


def _get_value(
arg: Optional[Union[torch.fx.Node, float, int]], model: torch.fx.GraphModule
) -> Union[torch.nn.Parameter, float, int]:
"""
Retrieves value from the given argument. It can be either torch.fx.Node or float/int value.

:param arg: Given arg to retrieve value.
:param model: torch.fx.GraphModule instance.
:return: value from the given argument.
"""
if isinstance(arg, torch.fx.Node):
return get_tensor_constant_from_node(arg, model)
return arg


def _compress_qdq_constant_transformation(model: torch.fx.GraphModule, matches) -> None:
"""
Change the FP32 weight value to Int8 and also reshape the scale for per_channel_quantization.

:param: model: Model to apply transformations to.
"""

for match in matches:
mul_node = match.replacements[0]
sub_node = match.replacements[1]
weight_node, scale_node, zp_node, axis = None, None, None, None
nodes_map = {node.name: match.nodes_map[node] for node in match.nodes_map}
get_const = partial(get_tensor_constant_from_node, model=model)
weight_node = get_const(nodes_map["weight"])
scale_node = get_const(nodes_map["scale"])
zp_node = get_const(nodes_map["zero_point"])
axis = nodes_map["axis"]

weight_node = _get_value(nodes_map["weight"], model)
scale_node = _get_value(nodes_map["scale"], model)
zp_node = _get_value(nodes_map["zero_point"], model)
axis = _get_value(nodes_map.get("axis"), model)
port_id = 0
if axis is not None:
result = torch.ops.quantized_decomposed.quantize_per_channel.default(
Expand Down
3 changes: 3 additions & 0 deletions tests/post_training/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,9 @@ def save_compressed_model(self) -> None:
ov.serialize(ov_model, self.path_compressed_ir)
elif self.backend in OV_BACKENDS:
self.path_compressed_ir = self.output_model_dir / "model.xml"
from openvino._offline_transformations import apply_moc_transformations

apply_moc_transformations(self.compressed_model, cf=True)
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
ov.serialize(self.compressed_model, str(self.path_compressed_ir))

def get_num_compressed(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch
class VisionModelParams:
weights: models.WeightsEnum
export_fn: Callable[[torch.nn.Module, Tuple[Any, ...]], torch.fx.GraphModule]
export_torch_before_ov_convert: bool = False


class ImageClassificationTorchvision(ImageClassificationBase):
Expand All @@ -47,8 +48,12 @@ class ImageClassificationTorchvision(ImageClassificationBase):
models.mobilenet_v3_small: VisionModelParams(
models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module
),
models.vit_b_16: VisionModelParams(models.ViT_B_16_Weights.DEFAULT, _export_graph_module),
models.swin_v2_s: VisionModelParams(models.Swin_V2_S_Weights.DEFAULT, _export_graph_module),
models.vit_b_16: VisionModelParams(
models.ViT_B_16_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
),
models.swin_v2_s: VisionModelParams(
models.Swin_V2_S_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
),
}

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -92,9 +97,10 @@ def prepare_model(self) -> None:

elif self.backend in [BackendType.OV, BackendType.FP32]:
with torch.no_grad():
with disable_patching():
m = torch.export.export(model, args=(self.dummy_tensor,))
self.model = ov.convert_model(m, example_input=self.dummy_tensor, input=self.input_size)
if self.model_params.export_torch_before_ov_convert:
with disable_patching():
model = torch.export.export(model, (self.dummy_tensor,))
self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size)
self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0]

self._dump_model_fp32()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 scale_node0" [id=3, type=get_attr];
"4 weight_node0" [id=4, type=get_attr];
"5 quantize_per_channel_default" [id=5, type=quantize_per_channel];
"6 add_tensor_2" [id=6, type=add];
"7 dequantize_per_channel_default" [id=7, type=dequantize_per_channel];
"8 conv2d" [id=8, type=conv2d];
"9 _param_constant2" [id=9, type=get_attr];
"10 _param_constant3" [id=10, type=get_attr];
"11 conv2d_1" [id=11, type=conv2d];
"12 _tensor_constant0" [id=12, type=get_attr];
"13 add_" [id=13, type=add_];
"14 _tensor_constant0_1" [id=14, type=get_attr];
"15 add__1" [id=15, type=add_];
"16 add" [id=16, type=add];
"17 _param_constant4" [id=17, type=get_attr];
"18 _param_constant5" [id=18, type=get_attr];
"19 conv2d_2" [id=19, type=conv2d];
"20 _tensor_constant0_2" [id=20, type=get_attr];
"21 add_1" [id=21, type=add];
"22 output" [id=22, type=output];
"0 arg0_1" -> "8 conv2d" [label="(1, 3, 224, 224)", style=solid];
"1 _param_constant0" -> "5 quantize_per_channel_default" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "8 conv2d" [label="(3,)", style=solid];
"3 scale_node0" -> "5 quantize_per_channel_default" [label="(3,)", style=solid];
"3 scale_node0" -> "7 dequantize_per_channel_default" [label="(3,)", style=solid];
"4 weight_node0" -> "5 quantize_per_channel_default" [label="(3,)", style=solid];
"4 weight_node0" -> "7 dequantize_per_channel_default" [label="(3,)", style=solid];
"5 quantize_per_channel_default" -> "6 add_tensor_2" [label="(3, 3, 1, 1)", style=solid];
"6 add_tensor_2" -> "7 dequantize_per_channel_default" [label="(3, 3, 1, 1)", style=solid];
"7 dequantize_per_channel_default" -> "8 conv2d" [label="(3, 3, 1, 1)", style=solid];
"8 conv2d" -> "11 conv2d_1" [label="(1, 3, 224, 224)", style=solid];
"8 conv2d" -> "13 add_" [label="(1, 3, 224, 224)", style=solid];
"9 _param_constant2" -> "11 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"10 _param_constant3" -> "11 conv2d_1" [label="(3,)", style=solid];
"11 conv2d_1" -> "15 add__1" [label="(1, 3, 224, 224)", style=solid];
"12 _tensor_constant0" -> "13 add_" [label="(1,)", style=solid];
"13 add_" -> "16 add" [label="(1, 3, 224, 224)", style=solid];
"14 _tensor_constant0_1" -> "15 add__1" [label="(1,)", style=solid];
"15 add__1" -> "16 add" [label="(1, 3, 224, 224)", style=solid];
"16 add" -> "19 conv2d_2" [label="(1, 3, 224, 224)", style=solid];
"17 _param_constant4" -> "19 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"18 _param_constant5" -> "19 conv2d_2" [label="(3,)", style=solid];
"19 conv2d_2" -> "21 add_1" [label="(1, 3, 224, 224)", style=solid];
"20 _tensor_constant0_2" -> "21 add_1" [label="(1,)", style=solid];
"21 add_1" -> "22 output" [label="(1, 3, 224, 224)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant1" [id=1, type=get_attr];
"2 scale_updated_constant0" [id=2, type=get_attr];
"3 compressed_weight_updated_constant0" [id=3, type=get_attr];
"4 mul_tensor" [id=4, type=mul];
"5 zero_point_updated_constant0" [id=5, type=get_attr];
"6 sub_tensor" [id=6, type=sub];
"7 conv2d" [id=7, type=conv2d];
"8 _param_constant2" [id=8, type=get_attr];
"9 _param_constant3" [id=9, type=get_attr];
"10 conv2d_1" [id=10, type=conv2d];
"11 _tensor_constant0" [id=11, type=get_attr];
"12 add_" [id=12, type=add_];
"13 _tensor_constant0_1" [id=13, type=get_attr];
"14 add__1" [id=14, type=add_];
"15 add" [id=15, type=add];
"16 _param_constant4" [id=16, type=get_attr];
"17 _param_constant5" [id=17, type=get_attr];
"18 conv2d_2" [id=18, type=conv2d];
"19 _tensor_constant0_2" [id=19, type=get_attr];
"20 add_1" [id=20, type=add];
"21 output" [id=21, type=output];
"0 arg0_1" -> "7 conv2d" [label="(1, 3, 224, 224)", style=solid];
"1 _param_constant1" -> "7 conv2d" [label="(3,)", style=solid];
"2 scale_updated_constant0" -> "4 mul_tensor" [label="(3, 1, 1, 1)", style=solid];
"3 compressed_weight_updated_constant0" -> "4 mul_tensor" [label="(3, 3, 1, 1)", style=solid];
"4 mul_tensor" -> "6 sub_tensor" [label="(3, 3, 1, 1)", style=solid];
"5 zero_point_updated_constant0" -> "6 sub_tensor" [label="(3, 1, 1, 1)", style=solid];
"6 sub_tensor" -> "7 conv2d" [label="(3, 3, 1, 1)", style=solid];
"7 conv2d" -> "10 conv2d_1" [label="(1, 3, 224, 224)", style=solid];
"7 conv2d" -> "12 add_" [label="(1, 3, 224, 224)", style=solid];
"8 _param_constant2" -> "10 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"9 _param_constant3" -> "10 conv2d_1" [label="(3,)", style=solid];
"10 conv2d_1" -> "14 add__1" [label="(1, 3, 224, 224)", style=solid];
"11 _tensor_constant0" -> "12 add_" [label="(1,)", style=solid];
"12 add_" -> "15 add" [label="(1, 3, 224, 224)", style=solid];
"13 _tensor_constant0_1" -> "14 add__1" [label="(1,)", style=solid];
"14 add__1" -> "15 add" [label="(1, 3, 224, 224)", style=solid];
"15 add" -> "18 conv2d_2" [label="(1, 3, 224, 224)", style=solid];
"16 _param_constant4" -> "18 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"17 _param_constant5" -> "18 conv2d_2" [label="(3,)", style=solid];
"18 conv2d_2" -> "20 add_1" [label="(1, 3, 224, 224)", style=solid];
"19 _tensor_constant0_2" -> "20 add_1" [label="(1,)", style=solid];
"20 add_1" -> "21 output" [label="(1, 3, 224, 224)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 quantize_per_tensor_default" [id=3, type=quantize_per_tensor];
"4 add_tensor_2" [id=4, type=add];
"5 dequantize_per_tensor_default" [id=5, type=dequantize_per_tensor];
"6 conv2d" [id=6, type=conv2d];
"7 _param_constant2" [id=7, type=get_attr];
"8 _param_constant3" [id=8, type=get_attr];
"9 conv2d_1" [id=9, type=conv2d];
"10 _tensor_constant0" [id=10, type=get_attr];
"11 add_" [id=11, type=add_];
"12 _tensor_constant0_1" [id=12, type=get_attr];
"13 add__1" [id=13, type=add_];
"14 add" [id=14, type=add];
"15 _param_constant4" [id=15, type=get_attr];
"16 _param_constant5" [id=16, type=get_attr];
"17 conv2d_2" [id=17, type=conv2d];
"18 _tensor_constant0_2" [id=18, type=get_attr];
"19 add_1" [id=19, type=add];
"20 output" [id=20, type=output];
"0 arg0_1" -> "6 conv2d" [label="(1, 3, 224, 224)", style=solid];
"1 _param_constant0" -> "3 quantize_per_tensor_default" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "6 conv2d" [label="(3,)", style=solid];
"3 quantize_per_tensor_default" -> "4 add_tensor_2" [label="(3, 3, 1, 1)", style=solid];
"4 add_tensor_2" -> "5 dequantize_per_tensor_default" [label="(3, 3, 1, 1)", style=solid];
"5 dequantize_per_tensor_default" -> "6 conv2d" [label="(3, 3, 1, 1)", style=solid];
"6 conv2d" -> "9 conv2d_1" [label="(1, 3, 224, 224)", style=solid];
"6 conv2d" -> "11 add_" [label="(1, 3, 224, 224)", style=solid];
"7 _param_constant2" -> "9 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"8 _param_constant3" -> "9 conv2d_1" [label="(3,)", style=solid];
"9 conv2d_1" -> "13 add__1" [label="(1, 3, 224, 224)", style=solid];
"10 _tensor_constant0" -> "11 add_" [label="(1,)", style=solid];
"11 add_" -> "14 add" [label="(1, 3, 224, 224)", style=solid];
"12 _tensor_constant0_1" -> "13 add__1" [label="(1,)", style=solid];
"13 add__1" -> "14 add" [label="(1, 3, 224, 224)", style=solid];
"14 add" -> "17 conv2d_2" [label="(1, 3, 224, 224)", style=solid];
"15 _param_constant4" -> "17 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"16 _param_constant5" -> "17 conv2d_2" [label="(3,)", style=solid];
"17 conv2d_2" -> "19 add_1" [label="(1, 3, 224, 224)", style=solid];
"18 _tensor_constant0_2" -> "19 add_1" [label="(1,)", style=solid];
"19 add_1" -> "20 output" [label="(1, 3, 224, 224)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant1" [id=1, type=get_attr];
"2 compressed_weight_updated_constant0" [id=2, type=get_attr];
"3 mul_tensor" [id=3, type=mul];
"4 sub_tensor" [id=4, type=sub];
"5 conv2d" [id=5, type=conv2d];
"6 _param_constant2" [id=6, type=get_attr];
"7 _param_constant3" [id=7, type=get_attr];
"8 conv2d_1" [id=8, type=conv2d];
"9 _tensor_constant0" [id=9, type=get_attr];
"10 add_" [id=10, type=add_];
"11 _tensor_constant0_1" [id=11, type=get_attr];
"12 add__1" [id=12, type=add_];
"13 add" [id=13, type=add];
"14 _param_constant4" [id=14, type=get_attr];
"15 _param_constant5" [id=15, type=get_attr];
"16 conv2d_2" [id=16, type=conv2d];
"17 _tensor_constant0_2" [id=17, type=get_attr];
"18 add_1" [id=18, type=add];
"19 output" [id=19, type=output];
"0 arg0_1" -> "5 conv2d" [label="(1, 3, 224, 224)", style=solid];
"1 _param_constant1" -> "5 conv2d" [label="(3,)", style=solid];
"2 compressed_weight_updated_constant0" -> "3 mul_tensor" [label="(3, 3, 1, 1)", style=solid];
"3 mul_tensor" -> "4 sub_tensor" [label="(3, 3, 1, 1)", style=solid];
"4 sub_tensor" -> "5 conv2d" [label="(3, 3, 1, 1)", style=solid];
"5 conv2d" -> "8 conv2d_1" [label="(1, 3, 224, 224)", style=solid];
"5 conv2d" -> "10 add_" [label="(1, 3, 224, 224)", style=solid];
"6 _param_constant2" -> "8 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"7 _param_constant3" -> "8 conv2d_1" [label="(3,)", style=solid];
"8 conv2d_1" -> "12 add__1" [label="(1, 3, 224, 224)", style=solid];
"9 _tensor_constant0" -> "10 add_" [label="(1,)", style=solid];
"10 add_" -> "13 add" [label="(1, 3, 224, 224)", style=solid];
"11 _tensor_constant0_1" -> "12 add__1" [label="(1,)", style=solid];
"12 add__1" -> "13 add" [label="(1, 3, 224, 224)", style=solid];
"13 add" -> "16 conv2d_2" [label="(1, 3, 224, 224)", style=solid];
"14 _param_constant4" -> "16 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"15 _param_constant5" -> "16 conv2d_2" [label="(3,)", style=solid];
"16 conv2d_2" -> "18 add_1" [label="(1, 3, 224, 224)", style=solid];
"17 _tensor_constant0_2" -> "18 add_1" [label="(1,)", style=solid];
"18 add_1" -> "19 output" [label="(1, 3, 224, 224)", style=solid];
}
Loading
Loading