Commit

onnx export to support older pytorch with example_outputs argument (Project-MONAI#6309)

Fixes Project-MONAI#6297.

### Description

The PyTorch ONNX exporter API changed in version 1.10, removing `example_outputs` as a
required input argument. This PR adds special handling to support PyTorch versions older
than 1.10. The unit test is also extended to cover this case.
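
As an illustration (not part of this PR's diff), a caller on an older PyTorch would supply example outputs only when exporting in scripting mode; on 1.10+ or with `use_trace=True` nothing changes. The model and tensor shapes below are made up to mirror the updated unit test:

```python
import torch

from monai.networks import convert_to_onnx
from monai.networks.nets import UNet
from monai.utils.module import pytorch_after

model = UNet(
    spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
)

extra_kwargs = {}
if not pytorch_after(1, 10):
    # scripting-mode export on PyTorch < 1.10 requires example outputs of the model
    extra_kwargs["example_outputs"] = [torch.randn(16, 3, 32, 32)]

onnx_model = convert_to_onnx(
    model=model,
    inputs=[torch.randn(16, 1, 32, 32)],
    input_names=["x"],
    output_names=["y"],
    use_trace=False,  # scripting mode; with use_trace=True nothing extra is needed
    **extra_kwargs,
)
```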

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Liqun Fu <[email protected]>
liqunfu authored Apr 6, 2023
1 parent 1832f95 commit 06defb7
Showing 2 changed files with 57 additions and 23 deletions.
32 changes: 24 additions & 8 deletions monai/networks/utils.py
@@ -584,8 +584,10 @@ def convert_to_onnx(
         inputs: input sample data used by pytorch.onnx.export. It is also used in ONNX model verification.
         input_names: optional input names of the ONNX model.
         output_names: optional output names of the ONNX model.
-        opset_version: version of the (ai.onnx) opset to target. Must be >= 7 and <= 16, for more
-            details: https://github.com/onnx/onnx/blob/main/docs/Operators.md.
+        opset_version: version of the (ai.onnx) opset to target. Must be >= 7 and not exceed
+            the latest opset version supported by PyTorch, for more details:
+            https://github.com/onnx/onnx/blob/main/docs/Operators.md and
+            https://github.com/pytorch/pytorch/blob/master/torch/onnx/_constants.py
         dynamic_axes: specifies axes of tensors as dynamic (i.e. known only at run-time). If set to None,
             the exported model will have the shapes of all input and output tensors set to match given
             ones, for more details: https://pytorch.org/docs/stable/onnx.html#torch.onnx.export.
@@ -603,31 +605,45 @@ def convert_to_onnx(
     """
     model.eval()
     with torch.no_grad():
+        torch_versioned_kwargs = {}
         if use_trace:
-            script_module = torch.jit.trace(model, example_inputs=inputs)
+            # let torch.onnx.export trace the model.
+            mode_to_export = model
         else:
-            script_module = torch.jit.script(model, **kwargs)
+            if not pytorch_after(1, 10):
+                if "example_outputs" not in kwargs:
+                    # https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182
+                    raise TypeError(
+                        "example_outputs is required in scripting mode before PyTorch 1.10."
+                        "Please provide example outputs or use trace mode to export onnx model."
+                    )
+                torch_versioned_kwargs["example_outputs"] = kwargs["example_outputs"]
+                del kwargs["example_outputs"]
+            mode_to_export = torch.jit.script(model, **kwargs)

         if filename is None:
             f = io.BytesIO()
             torch.onnx.export(
-                script_module,
-                inputs,
+                mode_to_export,
+                tuple(inputs),
                 f=f,
                 input_names=input_names,
                 output_names=output_names,
                 dynamic_axes=dynamic_axes,
                 opset_version=opset_version,
+                **torch_versioned_kwargs,
             )
             onnx_model = onnx.load_model_from_string(f.getvalue())
         else:
             torch.onnx.export(
-                script_module,
-                inputs,
+                mode_to_export,
+                tuple(inputs),
                 f=filename,
                 input_names=input_names,
                 output_names=output_names,
                 dynamic_axes=dynamic_axes,
                 opset_version=opset_version,
+                **torch_versioned_kwargs,
             )
             onnx_model = onnx.load(filename)

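For readers exporting without MONAI's wrapper, the same version gate can be applied directly around `torch.onnx.export`. This is a minimal sketch, not code from this commit; `TinyNet`, the input shape, and the opset choice are arbitrary:

```python
import io

import torch

from monai.utils.module import pytorch_after


class TinyNet(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


script_module = torch.jit.script(TinyNet())
dummy_input = torch.randn(1, 1, 32, 32)

versioned_kwargs = {}
if not pytorch_after(1, 10):
    # exporting a ScriptModule on PyTorch < 1.10 requires example_outputs
    versioned_kwargs["example_outputs"] = script_module(dummy_input)

buffer = io.BytesIO()
torch.onnx.export(
    script_module,
    (dummy_input,),
    f=buffer,
    input_names=["x"],
    output_names=["y"],
    opset_version=13,
    **versioned_kwargs,
)
```
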
48 changes: 33 additions & 15 deletions tests/test_convert_to_onnx.py
@@ -19,42 +19,60 @@

 from monai.networks import convert_to_onnx
 from monai.networks.nets import SegResNet, UNet
+from monai.utils.module import pytorch_after
 from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule, optional_import

 if torch.cuda.is_available():
     TORCH_DEVICE_OPTIONS = ["cpu", "cuda"]
 else:
     TORCH_DEVICE_OPTIONS = ["cpu"]
-TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False]))
+TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False]))
 TESTS_ORT = list(itertools.product(TORCH_DEVICE_OPTIONS, [True]))

 onnx, _ = optional_import("onnx")


 @SkipIfNoModule("onnx")
-@SkipIfBeforePyTorchVersion((1, 10))
+@SkipIfBeforePyTorchVersion((1, 9))
 class TestConvertToOnnx(unittest.TestCase):
     @parameterized.expand(TESTS)
-    def test_unet(self, device, use_ort):
+    def test_unet(self, device, use_trace, use_ort):
         if use_ort:
             _, has_onnxruntime = optional_import("onnxruntime")
             if not has_onnxruntime:
                 self.skipTest("onnxruntime is not installed probably due to python version >= 3.11.")
         model = UNet(
             spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
         )
-        onnx_model = convert_to_onnx(
-            model=model,
-            inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
-            input_names=["x"],
-            output_names=["y"],
-            verify=True,
-            device=device,
-            use_ort=use_ort,
-            use_trace=True,
-            rtol=1e-3,
-            atol=1e-4,
-        )
+        if pytorch_after(1, 10) or use_trace:
+            onnx_model = convert_to_onnx(
+                model=model,
+                inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
+                input_names=["x"],
+                output_names=["y"],
+                verify=True,
+                device=device,
+                use_ort=use_ort,
+                use_trace=use_trace,
+                rtol=1e-3,
+                atol=1e-4,
+            )
+        else:
+            # https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182
+            # example_outputs is required in scripting mode before PyTorch 1.10
+            onnx_model = convert_to_onnx(
+                model=model,
+                inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
+                input_names=["x"],
+                output_names=["y"],
+                example_outputs=[torch.randn((16, 3, 32, 32), requires_grad=False)],
+                verify=True,
+                device=device,
+                use_ort=use_ort,
+                use_trace=use_trace,
+                rtol=1e-3,
+                atol=1e-4,
+            )
         self.assertTrue(isinstance(onnx_model, onnx.ModelProto))

     @parameterized.expand(TESTS_ORT)
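
The `verify=True` / `use_ort=True` path exercised above compares PyTorch and ONNX Runtime outputs. A rough standalone equivalent of that check, assuming `onnxruntime` is installed and using a throwaway file name (`unet_check.onnx`), could look like:

```python
import numpy as np
import onnxruntime as ort
import torch

from monai.networks import convert_to_onnx
from monai.networks.nets import UNet

model = UNet(
    spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
)
x = torch.randn(16, 1, 32, 32)

# export via tracing (works on both old and new PyTorch) and save to disk
convert_to_onnx(
    model=model, inputs=[x], input_names=["x"], output_names=["y"], use_trace=True, filename="unet_check.onnx"
)

model.eval()
with torch.no_grad():
    torch_y = model(x)

session = ort.InferenceSession("unet_check.onnx", providers=["CPUExecutionProvider"])
(ort_y,) = session.run(None, {"x": x.numpy()})

# same tolerances as the updated unit test
np.testing.assert_allclose(torch_y.numpy(), ort_y, rtol=1e-3, atol=1e-4)
```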
