diff --git a/tests/post_training/pipelines/image_classification_torchvision.py b/tests/post_training/pipelines/image_classification_torchvision.py index 4583558e14c..064b226fef4 100644 --- a/tests/post_training/pipelines/image_classification_torchvision.py +++ b/tests/post_training/pipelines/image_classification_torchvision.py @@ -41,9 +41,10 @@ class ImageClassificationTorchvision(PTQTestPipeline): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.model_weights: models.WeightsEnum = None + self.input_name: str = None def prepare_model(self) -> None: - if self.backend != BackendType.TORCH_FX: + if self.backend not in [BackendType.FP32, BackendType.TORCH_FX]: raise RuntimeError("Torchvision classification models supports only torch fx quantization.") model_cls = models.__dict__.get(self.model_id) @@ -59,6 +60,11 @@ def prepare_model(self) -> None: with disable_patching(): self.model = capture_pre_autograd_graph(model, (torch.ones(self.input_size),)) + elif self.backend == BackendType.FP32: + with torch.no_grad(): + self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size) + self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0] + self._dump_model_fp32() def _dump_model_fp32(self) -> None: @@ -68,15 +74,25 @@ def _dump_model_fp32(self) -> None: ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor, input=self.input_size) ov.serialize(ov_model, self.fp32_model_dir / "fx_model_fp32.xml") + if self.backend == BackendType.FP32: + ov.serialize(self.model, self.fp32_model_dir / "model_fp32.xml") + def prepare_preprocessor(self) -> None: self.transform = self.model_weights.transforms() def get_transform_calibration_fn(self): - device = torch.device("cuda" if self.backend == BackendType.CUDA_TORCH else "cpu") + if self.backend == BackendType.TORCH_FX: + device = torch.device("cuda" if self.backend == BackendType.CUDA_TORCH else "cpu") + + def transform_fn(data_item): + images, _ = data_item + return images.to(device) + + else: - def transform_fn(data_item): - images, _ = data_item - return images.to(device) + def transform_fn(data_item): + images, _ = data_item + return {self.input_name: np.array(images, dtype=np.float32)} return transform_fn