Skip to content

Commit

Permalink
FP32 torchvision models support
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jul 26, 2024
1 parent 93422f9 commit 5daacf3
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions tests/post_training/pipelines/image_classification_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class ImageClassificationTorchvision(PTQTestPipeline):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_weights: models.WeightsEnum = None
self.input_name: str = None

def prepare_model(self) -> None:
if self.backend != BackendType.TORCH_FX:
if self.backend not in [BackendType.FP32, BackendType.TORCH_FX]:
raise RuntimeError("Torchvision classification models supports only torch fx quantization.")

model_cls = models.__dict__.get(self.model_id)
Expand All @@ -59,6 +60,11 @@ def prepare_model(self) -> None:
with disable_patching():
self.model = capture_pre_autograd_graph(model, (torch.ones(self.input_size),))

elif self.backend == BackendType.FP32:
with torch.no_grad():
self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size)
self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0]

self._dump_model_fp32()

def _dump_model_fp32(self) -> None:
Expand All @@ -68,15 +74,25 @@ def _dump_model_fp32(self) -> None:
ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor, input=self.input_size)
ov.serialize(ov_model, self.fp32_model_dir / "fx_model_fp32.xml")

if self.backend == BackendType.FP32:
ov.serialize(self.model, self.fp32_model_dir / "model_fp32.xml")

def prepare_preprocessor(self) -> None:
self.transform = self.model_weights.transforms()

def get_transform_calibration_fn(self):
device = torch.device("cuda" if self.backend == BackendType.CUDA_TORCH else "cpu")
if self.backend == BackendType.TORCH_FX:
device = torch.device("cuda" if self.backend == BackendType.CUDA_TORCH else "cpu")

def transform_fn(data_item):
images, _ = data_item
return images.to(device)

else:

def transform_fn(data_item):
images, _ = data_item
return images.to(device)
def transform_fn(data_item):
images, _ = data_item
return {self.input_name: np.array(images, dtype=np.float32)}

return transform_fn

Expand Down

0 comments on commit 5daacf3

Please sign in to comment.