From d84ee36e983e92d6cc0d19abb646a16c7ff4f532 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 3 Oct 2023 19:42:42 +0200 Subject: [PATCH] test: fix compatibility with `onnxruntime` 0.16+ (#18692) --- requirements/pytorch/test.txt | 2 +- tests/tests_pytorch/models/test_onnx.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 09a11afba2cb0..8c2ade615915c 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -9,7 +9,7 @@ pytest-random-order ==1.1.0 cloudpickle >=1.3, <2.3.0 scikit-learn >0.22.1, <1.4.0 onnx >=0.14.0, <1.15.0 -onnxruntime >=0.15.0, <1.16.0 +onnxruntime >=0.15.0, <1.17.0 psutil <5.9.6 # for `DeviceStatsMonitor` pandas >1.0, <2.1.0 # needed in benchmarks fastapi # for `ServableModuleValidator` # not setting version as re-defined in App diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index b1ca0b4fd1789..cc0776ddcc40d 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import operator import os from unittest.mock import patch @@ -21,6 +22,7 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel +from lightning_utilities import compare_version import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.runif import RunIf @@ -150,7 +152,8 @@ def test_if_inference_output_is_valid(tmpdir): file_path = os.path.join(tmpdir, "model.onnx") model.to_onnx(file_path, model.example_input_array, export_params=True) - ort_session = onnxruntime.InferenceSession(file_path) + ort_kwargs = {"providers": "CPUExecutionProvider"} if compare_version("onnxruntime", operator.ge, "1.16.0") else {} + ort_session = onnxruntime.InferenceSession(file_path, **ort_kwargs) def to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()