Skip to content

Commit

Permalink
Export and mock transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Oct 4, 2024
1 parent bca8501 commit 61e48e7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
5 changes: 5 additions & 0 deletions tests/post_training/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,11 @@ def save_compressed_model(self) -> None:
ov.serialize(ov_model, self.path_compressed_ir)
elif self.backend in OV_BACKENDS:
self.path_compressed_ir = self.output_model_dir / "model.xml"
from openvino._offline_transformations import (
apply_moc_transformations, # pylint: disable=import-error,no-name-in-module
)

apply_moc_transformations(self.compressed_model, cf=True)
ov.serialize(self.compressed_model, str(self.path_compressed_ir))

def get_num_compressed(self) -> None:
Expand Down
12 changes: 10 additions & 2 deletions tests/post_training/pipelines/image_classification_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch
class VisionModelParams:
weights: models.WeightsEnum
export_fn: Callable[[torch.nn.Module, Tuple[Any, ...]], torch.fx.GraphModule]
export_torch_before_ov_convert: bool = False


class ImageClassificationTorchvision(ImageClassificationBase):
Expand All @@ -48,8 +49,12 @@ class ImageClassificationTorchvision(ImageClassificationBase):
models.mobilenet_v3_small: VisionModelParams(
models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module
),
models.vit_b_16: VisionModelParams(models.ViT_B_16_Weights.DEFAULT, _export_graph_module),
models.swin_v2_s: VisionModelParams(models.Swin_V2_S_Weights.DEFAULT, _export_graph_module),
models.vit_b_16: VisionModelParams(
models.ViT_B_16_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
),
models.swin_v2_s: VisionModelParams(
models.Swin_V2_S_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
),
}

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -94,6 +99,9 @@ def prepare_model(self) -> None:

elif self.backend in [BackendType.OV, BackendType.FP32]:
with torch.no_grad():
if self.model_params.export_torch_before_ov_convert:
with disable_patching():
model = torch.export.export(model, (self.dummy_tensor,))
self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size)
self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0]

Expand Down

0 comments on commit 61e48e7

Please sign in to comment.