Skip to content

Commit

Permalink
test: fix compatibility with onnxruntime 0.16+ (#18692)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Oct 3, 2023
1 parent b69f3c6 commit d84ee36
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion requirements/pytorch/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pytest-random-order ==1.1.0
cloudpickle >=1.3, <2.3.0
scikit-learn >0.22.1, <1.4.0
onnx >=0.14.0, <1.15.0
onnxruntime >=0.15.0, <1.16.0
onnxruntime >=0.15.0, <1.17.0
psutil <5.9.6 # for `DeviceStatsMonitor`
pandas >1.0, <2.1.0 # needed in benchmarks
fastapi # for `ServableModuleValidator` # not setting version as re-defined in App
Expand Down
5 changes: 4 additions & 1 deletion tests/tests_pytorch/models/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
import os
from unittest.mock import patch

Expand All @@ -21,6 +22,7 @@
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning_utilities import compare_version

import tests_pytorch.helpers.pipelines as tpipes
from tests_pytorch.helpers.runif import RunIf
Expand Down Expand Up @@ -150,7 +152,8 @@ def test_if_inference_output_is_valid(tmpdir):
file_path = os.path.join(tmpdir, "model.onnx")
model.to_onnx(file_path, model.example_input_array, export_params=True)

ort_session = onnxruntime.InferenceSession(file_path)
ort_kwargs = {"providers": "CPUExecutionProvider"} if compare_version("onnxruntime", operator.ge, "1.16.0") else {}
ort_session = onnxruntime.InferenceSession(file_path, **ort_kwargs)

def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
Expand Down

0 comments on commit d84ee36

Please sign in to comment.