WIP OV and Torch support
daniil-lyakhov committed Jul 29, 2024
1 parent 237771d commit 9c8daec
Showing 4 changed files with 33 additions and 5 deletions.
1 change: 1 addition & 0 deletions nncf/torch/quantization/default_quantization.py
@@ -75,6 +75,7 @@
operator_metatypes.PTMaxMetatype,
operator_metatypes.PTMinMetatype,
operator_metatypes.PTTransposeMetatype,
operator_metatypes.PTGatherMetatype,
operator_metatypes.PTScatterMetatype,
operator_metatypes.PTReshapeMetatype,
operator_metatypes.PTSqueezeMetatype,
4 changes: 4 additions & 0 deletions tests/post_training/data/ptq_reference_data.yaml
@@ -30,6 +30,10 @@ torchvision/vit_b_16_backend_FP32:
metric_value: 0.81072
torchvision/vit_b_16_backend_TORCH_FX:
metric_value: 0.79432
torchvision/vit_b_16_backend_OV:
metric_value: 0.79432
torchvision/vit_b_16_backend_TORCH:
metric_value: 0.79432
timm/crossvit_9_240_backend_CUDA_TORCH:
metric_value: 0.689
timm/crossvit_9_240_backend_FP32:
2 changes: 1 addition & 1 deletion tests/post_training/model_scope.py
@@ -83,7 +83,7 @@
"model_type": ModelType.TRANSFORMER,
"advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=-1.0),
},
"backends": [BackendType.TORCH_FX],
"backends": [BackendType.TORCH_FX, BackendType.OV, BackendType.TORCH],
"batch_size": 1,
},
# Timm models
31 changes: 27 additions & 4 deletions tests/post_training/pipelines/image_classification_torchvision.py
@@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs):
self.input_name: str = None

def prepare_model(self) -> None:
if self.backend not in [BackendType.FP32, BackendType.TORCH_FX]:
if self.backend not in [BackendType.FP32, BackendType.TORCH_FX, BackendType.OV, BackendType.TORCH]:
raise RuntimeError("Torchvision classification models support only FP32, TORCH_FX, OV and TORCH backends.")

model_cls = models.__dict__.get(self.model_id)
@@ -60,7 +60,20 @@ def prepare_model(self) -> None:
with disable_patching():
self.model = capture_pre_autograd_graph(model, (torch.ones(self.input_size),))

elif self.backend == BackendType.FP32:
elif self.backend == BackendType.TORCH:
self.model = model

elif self.backend in [BackendType.OV]:
with torch.no_grad():
with disable_patching():
# Export the eager model with torch.export and convert the exported program to OpenVINO IR.
exported_model = torch.export.export(model, (self.dummy_tensor,))
self.model = ov.convert_model(exported_model, example_input=self.dummy_tensor, input=self.input_size)
self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0]

elif self.backend in [BackendType.FP32]:
with torch.no_grad():
self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size)
self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0]
@@ -69,19 +82,29 @@ def prepare_model(self) -> None:

def _dump_model_fp32(self) -> None:
"""Dump IRs of fp32 models, to help debugging."""
if self.backend == BackendType.TORCH:
with disable_patching():
ov_model = ov.convert_model(
torch.export.export(self.model, args=(self.dummy_tensor,)),
example_input=self.dummy_tensor,
input=self.input_size,
)
ov.serialize(ov_model, self.fp32_model_dir / "model_fp32.xml")

if self.backend == BackendType.TORCH_FX:
exported_model = torch.export.export(self.model, (self.dummy_tensor,))
ov_model = ov.convert_model(exported_model, example_input=self.dummy_tensor, input=self.input_size)
ov.serialize(ov_model, self.fp32_model_dir / "fx_model_fp32.xml")

if self.backend == BackendType.FP32:
if self.backend in [BackendType.FP32, BackendType.OV]:
ov.serialize(self.model, self.fp32_model_dir / "model_fp32.xml")

def prepare_preprocessor(self) -> None:
self.transform = self.model_weights.transforms()

def get_transform_calibration_fn(self):
if self.backend == BackendType.TORCH_FX:
if self.backend in [BackendType.TORCH_FX, BackendType.TORCH]:
device = torch.device("cuda" if self.backend == BackendType.CUDA_TORCH else "cpu")

def transform_fn(data_item):
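The body of transform_fn is cut off in the hunk above. For context, the sketch below is an illustrative reconstruction, not the committed code: it assumes the usual NNCF calibration pattern in which torch-based backends receive a tensor moved to the target device, while the OpenVINO backend receives a dict keyed by the model input name (which is why prepare_model() stores self.input_name). The helper make_transform_fn and its parameters are hypothetical names introduced only for illustration.

# Illustrative sketch only: not the committed transform_fn.
import numpy as np
import torch


def make_transform_fn(is_torch_backend: bool, input_name: str, device: torch.device):
    """Build a calibration transform that feeds data items to quantization."""

    def transform_fn(data_item):
        images, _ = data_item  # assumes (image batch, label) items, as in torchvision loaders
        if is_torch_backend:
            # Torch / TorchFX backends consume tensors placed on the target device.
            return images.to(device)
        # The OpenVINO backend consumes a dict keyed by the model's input name.
        return {input_name: np.asarray(images)}

    return transform_fn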
