Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Export for Model Conversion and Saving #7934

Merged
merged 23 commits into from
Aug 23, 2024
Merged
Changes from 9 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
a769c84
Modify _export with Saver and onnx_export
Han123su Jul 20, 2024
c70f24e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2024
0c984d9
./runtests.sh --autofix
Han123su Jul 20, 2024
995aea0
./runtests.sh --autofix
Han123su Jul 20, 2024
b62c4d5
modify
Han123su Jul 20, 2024
707c498
modify
Han123su Jul 20, 2024
bf5b04d
./runtests.sh --autofix
Han123su Jul 20, 2024
0e15c00
Merge remote-tracking branch 'origin/Fix-issue-6375' into Fix-issue-6375
Han123su Jul 20, 2024
5798d72
modify save_onnx
Han123su Jul 20, 2024
cd5c958
Merge branch 'dev' into Fix-issue-6375
ericspod Jul 23, 2024
d94e1ff
Merge branch 'dev' into Fix-issue-6375
Han123su Aug 2, 2024
84355a2
saver include argument meta_values
Han123su Aug 7, 2024
50bda8c
Merge remote-tracking branch 'upstream/dev' into Fix-issue-6375
Han123su Aug 7, 2024
cd72d2a
Merge branch 'Fix-issue-6375' of https://github.com/Han123su/MONAI in…
Han123su Aug 11, 2024
e84dbb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 11, 2024
0959b08
./runtests.sh --autofix
Han123su Aug 11, 2024
77f2fd3
Merge branch 'Fix-issue-6375' of https://github.com/Han123su/MONAI in…
Han123su Aug 11, 2024
4708007
Merge branch 'dev' into Fix-issue-6375
Han123su Aug 11, 2024
379728b
restore parser args
Han123su Aug 11, 2024
af4eb0f
Merge branch 'Fix-issue-6375' of https://github.com/Han123su/MONAI in…
Han123su Aug 11, 2024
1fceaf7
Merge branch 'dev' into Fix-issue-6375
Han123su Aug 20, 2024
8d8d250
Merge branch 'dev' into Fix-issue-6375
Han123su Aug 22, 2024
089a27d
Merge branch 'dev' into Fix-issue-6375
KumoLiu Aug 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
import warnings
import zipfile
from collections.abc import Mapping, Sequence
from functools import partial
from pathlib import Path
from pydoc import locate
from shutil import copyfile
from textwrap import dedent
from typing import Any, Callable
from typing import IO, Any, Callable

import torch
from torch.cuda import is_available
Expand Down Expand Up @@ -1159,6 +1160,7 @@ def verify_net_in_out(

def _export(
converter: Callable,
saver: Callable,
parser: ConfigParser,
net_id: str,
filepath: str,
Expand All @@ -1173,6 +1175,8 @@ def _export(
Args:
converter: a callable object that takes a torch.nn.module and kwargs as input and
converts the module to another type.
saver: a callable object that takes the converted model and a filepath as input and
Han123su marked this conversation as resolved.
Show resolved Hide resolved
saves the model to the specified location.
parser: a ConfigParser of the bundle to be converted.
net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
filepath: filepath to export, if filename has no extension, it becomes `.ts`.
Expand Down Expand Up @@ -1212,14 +1216,8 @@ def _export(
# add .json extension to all extra files which are always encoded as JSON
extra_files = {k + ".json": v for k, v in extra_files.items()}
Han123su marked this conversation as resolved.
Show resolved Hide resolved

save_net_with_metadata(
jit_obj=net,
filename_prefix_or_stream=filepath,
include_config_vals=False,
append_timestamp=False,
meta_values=parser.get().pop("_meta_", None),
more_extra_files=extra_files,
)
saver(net, filepath, more_extra_files=extra_files)
Han123su marked this conversation as resolved.
Show resolved Hide resolved

logger.info(f"exported to file: {filepath}.")


Expand Down Expand Up @@ -1318,17 +1316,23 @@ def onnx_export(
input_shape_ = _get_fake_input_shape(parser=parser)

inputs_ = [torch.rand(input_shape_)]
net = parser.get_parsed_content(net_id_)
if has_ignite:
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
else:
ckpt = torch.load(ckpt_file_)
copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])

converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
onnx_model = convert_to_onnx(model=net, **converter_kwargs_)
onnx.save(onnx_model, filepath_)

def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str | IO[Any], **kwargs: Any) -> None:
Han123su marked this conversation as resolved.
Show resolved Hide resolved
onnx.save(onnx_obj, filename_prefix_or_stream)

_export(
convert_to_onnx,
save_onnx,
parser,
net_id=net_id_,
filepath=filepath_,
ckpt_file=ckpt_file_,
config_file=config_file_,
key_in_ckpt=key_in_ckpt_,
**converter_kwargs_,
)


def ckpt_export(
Expand Down Expand Up @@ -1449,8 +1453,17 @@ def ckpt_export(

converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
# Use the given converter to convert a model and save with metadata, config content

save_ts = partial(
save_net_with_metadata,
include_config_vals=False,
append_timestamp=False,
meta_values=parser.get().pop("_meta_", None),
)

_export(
convert_to_torchscript,
save_ts,
parser,
net_id=net_id_,
filepath=filepath_,
Expand Down Expand Up @@ -1620,8 +1633,16 @@ def trt_export(
}
converter_kwargs_.update(trt_api_parameters)

save_ts = partial(
save_net_with_metadata,
include_config_vals=False,
append_timestamp=False,
meta_values=parser.get().pop("_meta_", None),
)

_export(
convert_to_trt,
save_ts,
parser,
net_id=net_id_,
filepath=filepath_,
Expand Down
Loading