From 8de452659bd31d05a2d783819d2ea3153dec3461 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 25 Sep 2024 16:13:07 +0200 Subject: [PATCH] [WIP][TorchFX] Migrate from capture_pre_autograd_graph to torch.export Metrics update --- .../torch/fx/nncf_graph_builder.py | 2 +- .../data/ptq_reference_data.yaml | 2 +- .../image_classification_torchvision.py | 41 ++++++++++++++----- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/nncf/experimental/torch/fx/nncf_graph_builder.py b/nncf/experimental/torch/fx/nncf_graph_builder.py index 946ac27ce84..1a91acde109 100644 --- a/nncf/experimental/torch/fx/nncf_graph_builder.py +++ b/nncf/experimental/torch/fx/nncf_graph_builder.py @@ -167,7 +167,7 @@ def get_edge_params( output_port_id = 0 tensor_shape = None if source_node.op in ("get_attr",): - tensor_shape = tuple(getattr(model, source_node.target).shape) + tensor_shape = tuple(get_tensor_constant_from_node(source_node, model).shape) elif "val" in source_node.meta: if source_nncf_node.metatype is om.PTBatchNormMetatype: tensor = source_node.meta["val"][0] diff --git a/tests/post_training/data/ptq_reference_data.yaml b/tests/post_training/data/ptq_reference_data.yaml index b23f446c7ac..d6e8428bc6a 100644 --- a/tests/post_training/data/ptq_reference_data.yaml +++ b/tests/post_training/data/ptq_reference_data.yaml @@ -57,7 +57,7 @@ torchvision/swin_v2_s_backend_FP32: torchvision/swin_v2_s_backend_OV: metric_value: 0.83638 torchvision/swin_v2_s_backend_FX_TORCH: - metric_value: 0.8296 + metric_value: 0.8360 timm/crossvit_9_240_backend_CUDA_TORCH: metric_value: 0.689 timm/crossvit_9_240_backend_FP32: diff --git a/tests/post_training/pipelines/image_classification_torchvision.py b/tests/post_training/pipelines/image_classification_torchvision.py index c42aa9ab1bb..4fae4787b90 100644 --- a/tests/post_training/pipelines/image_classification_torchvision.py +++ b/tests/post_training/pipelines/image_classification_torchvision.py @@ -9,6 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass +from typing import Any, Callable, Tuple + import numpy as np import onnx import openvino as ov @@ -22,29 +25,45 @@ from tests.post_training.pipelines.image_classification_base import ImageClassificationBase +def _capture_pre_autograd_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule: + return capture_pre_autograd_graph(model, args) + + +def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule: + return torch.export.export(model, args).module() + + +@dataclass +class VisionModelParams: + weights: models.WeightsEnum + export_fn: Callable[[torch.nn.Module, Tuple[Any, ...]], torch.fx.GraphModule] + + class ImageClassificationTorchvision(ImageClassificationBase): """Pipeline for Image Classification model from torchvision repository""" - models_vs_imagenet_weights = { - models.resnet18: models.ResNet18_Weights.DEFAULT, - models.mobilenet_v3_small: models.MobileNet_V3_Small_Weights.DEFAULT, - models.vit_b_16: models.ViT_B_16_Weights.DEFAULT, - models.swin_v2_s: models.Swin_V2_S_Weights.DEFAULT, + models_vs_model_params = { + models.resnet18: VisionModelParams(models.ResNet18_Weights.DEFAULT, _capture_pre_autograd_module), + models.mobilenet_v3_small: VisionModelParams( + models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module + ), + models.vit_b_16: VisionModelParams(models.ViT_B_16_Weights.DEFAULT, _export_graph_module), + models.swin_v2_s: VisionModelParams(models.Swin_V2_S_Weights.DEFAULT, _export_graph_module), } def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.model_weights: models.WeightsEnum = None + self.model_params: VisionModelParams self.input_name: str = None def prepare_model(self) -> None: model_cls = models.__dict__.get(self.model_id) - self.model_weights = self.models_vs_imagenet_weights[model_cls] - model = model_cls(weights=self.model_weights) + self.model_params = self.models_vs_model_params[model_cls] + model = model_cls(weights=self.model_params.weights) model.eval() default_input_size = [self.batch_size, 3, 224, 224] - self.dummy_tensor = self.model_weights.transforms()(torch.rand(default_input_size)) + self.dummy_tensor = self.model_params.weights.transforms()(torch.rand(default_input_size)) self.static_input_size = list(self.dummy_tensor.shape) self.input_size = self.static_input_size.copy() @@ -54,7 +73,7 @@ def prepare_model(self) -> None: if self.backend == BackendType.FX_TORCH: with torch.no_grad(): with disable_patching(): - self.model = capture_pre_autograd_graph(model, (self.dummy_tensor,)) + self.model = self.model_params.export_fn(model, (self.dummy_tensor,)) elif self.backend in PT_BACKENDS: self.model = model @@ -103,7 +122,7 @@ def _dump_model_fp32(self) -> None: ov.serialize(self.model, self.fp32_model_dir / "model_fp32.xml") def prepare_preprocessor(self) -> None: - self.transform = self.model_weights.transforms() + self.transform = self.model_params.weights.transforms() def get_transform_calibration_fn(self): if self.backend in [BackendType.FX_TORCH] + PT_BACKENDS: