From 0c7a8d5efbc2a65e6d3576851d41a07a591fd937 Mon Sep 17 00:00:00 2001 From: Aleksei Kashapov Date: Fri, 17 Nov 2023 13:33:58 +0100 Subject: [PATCH] Update Torch to ONNX export in conformance (#2269) ### Changes Do constant folding while exporting to ONNX from Torch ### Reason for changes Conformance test regressgion of ONNX after updating torch to 2.1 Model graphs are updated and contain BatchNorm. Therefore bias locates no more as Conv attribute but in BatchNorm layer. It leads to not applying FBC and BC algorithms to these biases. ### Related tickets 125203 ### Tests N/A --- .../post_training/pipelines/image_classification_timm.py | 9 +-------- tests/post_training/reference_data.yaml | 4 ++-- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/post_training/pipelines/image_classification_timm.py b/tests/post_training/pipelines/image_classification_timm.py index 00171d701b7..bd02f6635d7 100644 --- a/tests/post_training/pipelines/image_classification_timm.py +++ b/tests/post_training/pipelines/image_classification_timm.py @@ -54,14 +54,7 @@ def prepare_model(self) -> None: if self.backend == BackendType.ONNX: onnx_path = self.output_model_dir / "model_fp32.onnx" - torch.onnx.export( - timm_model, - self.dummy_tensor, - onnx_path, - export_params=True, - opset_version=13, - do_constant_folding=False, - ) + torch.onnx.export(timm_model, self.dummy_tensor, onnx_path, export_params=True, opset_version=13) self.model = onnx.load(onnx_path) self.input_name = self.model.graph.input[0].name diff --git a/tests/post_training/reference_data.yaml b/tests/post_training/reference_data.yaml index 9309b93b7be..1ef8b61cc0b 100644 --- a/tests/post_training/reference_data.yaml +++ b/tests/post_training/reference_data.yaml @@ -175,7 +175,7 @@ timm/levit_128_backend_TORCH: metric_value: 0.73346 metric_value_fp32: 0.7405 timm/levit_128_backend_ONNX: - metric_value: 0.73184 + metric_value: 0.73286 metric_value_fp32: 0.7405 timm/levit_128_backend_OV: metric_value: 0.7334 @@ -321,7 +321,7 @@ timm/visformer_small_backend_TORCH: metric_value: 0.77728 metric_value_fp32: 0.77902 timm/visformer_small_backend_ONNX: - metric_value: 0.77432 + metric_value: 0.77678 metric_value_fp32: 0.77902 timm/visformer_small_backend_OV: metric_value: 0.77686