
hooks: add hook for onnxruntime
Add hook for `onnxruntime` to ensure that provider plugins (located
in `onnxruntime/capi` directory) are collected.
rokm committed Oct 15, 2024
1 parent f9f6e7e commit bf3d572
Showing 3 changed files with 118 additions and 0 deletions.
16 changes: 16 additions & 0 deletions _pyinstaller_hooks_contrib/stdhooks/hook-onnxruntime.py
@@ -0,0 +1,16 @@
# ------------------------------------------------------------------
# Copyright (c) 2024 PyInstaller Development Team.
#
# This file is distributed under the terms of the GNU General Public
# License (version 2.0 or later).
#
# The full license is available in LICENSE, distributed with
# this software.
#
# SPDX-License-Identifier: GPL-2.0-or-later
# ------------------------------------------------------------------

from PyInstaller.utils.hooks import collect_dynamic_libs

# Collect provider plugins from onnxruntime/capi.
binaries = collect_dynamic_libs("onnxruntime")
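
To see what this hook will pick up on a given machine, the same helper can be called directly in an unfrozen environment. This is a minimal inspection sketch (not part of the commit), assuming onnxruntime and PyInstaller are installed:

from PyInstaller.utils.hooks import collect_dynamic_libs

# collect_dynamic_libs() returns (source-path, destination-dir) 2-tuples;
# for onnxruntime, this includes the provider plugin libraries located
# in the onnxruntime/capi directory.
for src, dest in collect_dynamic_libs("onnxruntime"):
    print(f"{src} -> {dest}")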
2 changes: 2 additions & 0 deletions news/817.new.rst
@@ -0,0 +1,2 @@
Add hook for ``onnxruntime`` to ensure that provider plugins are
collected.
100 changes: 100 additions & 0 deletions tests/test_deep_learning.py
@@ -369,3 +369,103 @@ def test_ultralytics_yolo(pyi_builder):
        model = YOLO("yolov8n.pt")  # Download and load pre-trained model
        results = model("https://ultralytics.com/images/bus.jpg")
        """)


# Basic inference test with ONNX Runtime (as well as model export with ONNX + Torch).
@importorskip('onnxruntime')
@importorskip('onnx')
@importorskip('torch')
def test_onnxruntime_gpu_inference(pyi_builder, tmpdir):
    model_file = tmpdir / "test-model.onnx"

    # Build first application: model creation + export (Torch + ONNX).
    pyi_builder.test_source("""
        import sys

        import torch

        if len(sys.argv) != 2:
            print(f"Usage: {sys.argv[0]} <model-filename>", file=sys.stderr)
            sys.exit(1)
        model_filename = sys.argv[1]

        # Create a model that performs addition of the given inputs.
        class Model(torch.nn.Module):
            def __init__(self):
                super(Model, self).__init__()

            def forward(self, x, y):
                return x.add(y)

        model = Model()

        sample_x = torch.ones(3, dtype=torch.float32)
        sample_y = torch.zeros(3, dtype=torch.float32)

        # Export the model to ONNX graph format.
        torch.onnx.export(
            model,
            (sample_x, sample_y),
            model_filename,
            input_names=["x", "y"],
            output_names=["z"],
            # Mark the first dimension of each input as dynamic, so the exported graph accepts arrays of any length.
            dynamic_axes={
                "x": {0: "array_length_x"},
                "y": {0: "array_length_y"},
            },
        )
        """, app_name="model_builder", app_args=[str(model_file)])

    assert model_file.isfile(), "Model file was not created!"

    # Build second application: inference (ONNX Runtime) on CPU (and, optionally, with CUDA on Linux).
    pyi_builder.test_source("""
        import sys

        import numpy as np
        import onnxruntime
        # torch is used for the CUDA check; on Linux, collecting PyPI-installed torch also ensures that the
        # CUDA libraries it bundles are collected, which the CUDA execution provider requires.
        try:
            import torch
            cuda_available = torch.cuda.is_available()
        except ImportError:
            cuda_available = False
        if len(sys.argv) != 2:
            print(f"Usage: {sys.argv[0]} <model-filename>", file=sys.stderr)
            sys.exit(1)
        model_filename = sys.argv[1]

        test_providers = ['CPUExecutionProvider']
        if cuda_available:
            test_providers += ['CUDAExecutionProvider']

        for test_provider in test_providers:
            print(f"Running test with provider: {test_provider}", file=sys.stderr)

            # Load the model into an ONNX Runtime session.
            session = onnxruntime.InferenceSession(
                model_filename,
                providers=[test_provider],
            )

            # Check that the requested provider appears in the list returned by session.get_providers(); if the
            # requested provider failed to initialize, onnxruntime seems to fall back to CPUExecutionProvider.
            session_providers = session.get_providers()
            print(f"session.get_providers(): {session_providers}", file=sys.stderr)
            if test_provider not in session_providers:
                raise RuntimeError(f"Provider {test_provider} is missing!")

            # Run the model.
            x = np.float32([1.0, 2.0, 3.0])
            y = np.float32([4.0, 5.0, 6.0])
            print(f"x = {x}", file=sys.stderr)
            print(f"y = {y}", file=sys.stderr)

            z = session.run(["z"], {"x": x, "y": y})
            z = z[0]
            print(f"z = {z}", file=sys.stderr)

            assert (z == x + y).all()
        """, app_name="model_test", app_args=[str(model_file)])
