Skip to content

Commit

Permalink
[WIP][TorchFX] Migrate from capture_pre_autograd_graph to torch.export
Browse files Browse the repository at this point in the history
Metrics update
  • Loading branch information
daniil-lyakhov committed Sep 27, 2024
1 parent b4135c8 commit 8de4526
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 13 deletions.
2 changes: 1 addition & 1 deletion nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get_edge_params(
output_port_id = 0
tensor_shape = None
if source_node.op in ("get_attr",):
tensor_shape = tuple(getattr(model, source_node.target).shape)
tensor_shape = tuple(get_tensor_constant_from_node(source_node, model).shape)
elif "val" in source_node.meta:
if source_nncf_node.metatype is om.PTBatchNormMetatype:
tensor = source_node.meta["val"][0]
Expand Down
2 changes: 1 addition & 1 deletion tests/post_training/data/ptq_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ torchvision/swin_v2_s_backend_FP32:
torchvision/swin_v2_s_backend_OV:
metric_value: 0.83638
torchvision/swin_v2_s_backend_FX_TORCH:
metric_value: 0.8296
metric_value: 0.8360
timm/crossvit_9_240_backend_CUDA_TORCH:
metric_value: 0.689
timm/crossvit_9_240_backend_FP32:
Expand Down
41 changes: 30 additions & 11 deletions tests/post_training/pipelines/image_classification_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Callable, Tuple

import numpy as np
import onnx
import openvino as ov
Expand All @@ -22,29 +25,45 @@
from tests.post_training.pipelines.image_classification_base import ImageClassificationBase


def _capture_pre_autograd_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule:
return capture_pre_autograd_graph(model, args)


def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule:
return torch.export.export(model, args).module()


@dataclass
class VisionModelParams:
weights: models.WeightsEnum
export_fn: Callable[[torch.nn.Module, Tuple[Any, ...]], torch.fx.GraphModule]


class ImageClassificationTorchvision(ImageClassificationBase):
"""Pipeline for Image Classification model from torchvision repository"""

models_vs_imagenet_weights = {
models.resnet18: models.ResNet18_Weights.DEFAULT,
models.mobilenet_v3_small: models.MobileNet_V3_Small_Weights.DEFAULT,
models.vit_b_16: models.ViT_B_16_Weights.DEFAULT,
models.swin_v2_s: models.Swin_V2_S_Weights.DEFAULT,
models_vs_model_params = {
models.resnet18: VisionModelParams(models.ResNet18_Weights.DEFAULT, _capture_pre_autograd_module),
models.mobilenet_v3_small: VisionModelParams(
models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module
),
models.vit_b_16: VisionModelParams(models.ViT_B_16_Weights.DEFAULT, _export_graph_module),
models.swin_v2_s: VisionModelParams(models.Swin_V2_S_Weights.DEFAULT, _export_graph_module),
}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_weights: models.WeightsEnum = None
self.model_params: VisionModelParams
self.input_name: str = None

def prepare_model(self) -> None:
model_cls = models.__dict__.get(self.model_id)
self.model_weights = self.models_vs_imagenet_weights[model_cls]
model = model_cls(weights=self.model_weights)
self.model_params = self.models_vs_model_params[model_cls]
model = model_cls(weights=self.model_params.weights)
model.eval()

default_input_size = [self.batch_size, 3, 224, 224]
self.dummy_tensor = self.model_weights.transforms()(torch.rand(default_input_size))
self.dummy_tensor = self.model_params.weights.transforms()(torch.rand(default_input_size))
self.static_input_size = list(self.dummy_tensor.shape)

self.input_size = self.static_input_size.copy()
Expand All @@ -54,7 +73,7 @@ def prepare_model(self) -> None:
if self.backend == BackendType.FX_TORCH:
with torch.no_grad():
with disable_patching():
self.model = capture_pre_autograd_graph(model, (self.dummy_tensor,))
self.model = self.model_params.export_fn(model, (self.dummy_tensor,))

elif self.backend in PT_BACKENDS:
self.model = model
Expand Down Expand Up @@ -103,7 +122,7 @@ def _dump_model_fp32(self) -> None:
ov.serialize(self.model, self.fp32_model_dir / "model_fp32.xml")

def prepare_preprocessor(self) -> None:
self.transform = self.model_weights.transforms()
self.transform = self.model_params.weights.transforms()

def get_transform_calibration_fn(self):
if self.backend in [BackendType.FX_TORCH] + PT_BACKENDS:
Expand Down

0 comments on commit 8de4526

Please sign in to comment.