Exporting SwinUNETR to ONNX does not work #5125
Could you please elaborate on the floating point exception? Indeed, we have a recent patch (#4913); I'm not sure if it's related.
@wyli I tried the change in that patch and it resulted in the same problem. The exception happens in the following:
It seems it's possible to export with these parameter changes:

```diff
 model = SwinUNETR(img_size=(96, 96, 96),
     in_channels=1,
     out_channels=5,
     feature_size=48,
+    norm_name=("instance", {"affine": True}),
     drop_rate=0.0,
     attn_drop_rate=0.0,
     dropout_path_rate=0.0,
-    use_checkpoint=True,
+    use_checkpoint=False,
 )
```

I don't think
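For reference, the working configuration from the diff above can be gathered into one keyword dict (a sketch; the keyword names are the `SwinUNETR` constructor arguments shown in the diff, and the commented-out line at the end assumes MONAI is installed):

```python
# Export-friendly SwinUNETR settings from the diff above, in one place.
# The two changed lines are marked; everything else is unchanged.
export_friendly_kwargs = dict(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=5,
    feature_size=48,
    norm_name=("instance", {"affine": True}),  # added: instance norm with affine params
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=False,  # changed: gradient checkpointing off for export
)

# With MONAI installed this would be applied as:
# model = SwinUNETR(**export_friendly_kwargs)
print(export_friendly_kwargs["use_checkpoint"])  # False
```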
@wyli I tried the changes you suggested and indeed the model was successfully exported to ONNX. Unfortunately, during inference I get the following error:
Hi @wyli, if you happen to see this in this closed issue: do you have any idea whether we need to make the Swin UNETR model scriptable? Thank you.
I looked into that some time ago, and
Thanks, I agree. I was trying to rewrite this model recently to support torch.jit.script, but got stuck with the
Sure, the point is whether we need
I think the Thanks.
Also, if it's for inference only, the gradient checkpointing option is not needed.
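To illustrate that point: gradient checkpointing trades memory for compute by recomputing activations during the backward pass, so with no backward pass it is pure overhead. A toy stdlib-only sketch of the idea (`checkpointed` is a hypothetical helper, not MONAI or PyTorch API):

```python
# Toy model of gradient checkpointing: instead of storing an intermediate
# activation for the backward pass, keep a thunk that can recompute it.
# At inference there is no backward pass, so this machinery buys nothing --
# and its extra control flow is also what tracing-based exporters
# (e.g. torch.onnx.export) tend to struggle with.

def checkpointed(fn, x):
    """Return fn(x) plus a thunk that re-derives it on demand."""
    out = fn(x)
    recompute = lambda: fn(x)  # "backward" would call this instead of reading a stored value
    return out, recompute

square = lambda v: v * v
y, recompute = checkpointed(square, 7)
print(y)            # 49
print(recompute())  # 49: same value, recomputed rather than stored
```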
Hello @tangy5, @wyli, and others. I have tried and failed to use the above-mentioned workarounds to convert SwinUNETR to the TensorRT format. I'm using MONAI version 1.2.0rc3. I am attempting two routes to TensorRT: (1) through
The
The
Describe the bug
When trying to export the SwinUNETR model from MONAI, I get the error:
In a different issue, I read that this might be fixed by changing `x_shape = x.size()` to `x_shape = [int(s) for s in x.size()]` in the problematic code -- I found out that the problem manifests at `proj_out()`. Doing this, though, results in a different error. Making this change in all places where I find `x_shape = x.size()` results in a floating point exception!

To Reproduce
Here is a minimal example demonstrating the issue:
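As a stdlib-only sketch of what the `int(s)` workaround does during tracing (`SymbolicDim` here is a hypothetical stand-in; during `torch.onnx.export` the sizes are traced as 0-d tensors, and `int(s)` collapses each one into a constant baked into the exported graph):

```python
class SymbolicDim:
    """Stand-in for one dimension of x.size() as seen while tracing."""

    def __init__(self, value):
        self.value = value

    def __int__(self):
        # int(s) forces the symbolic dimension to a concrete Python int,
        # which is why the workaround changes what the exporter records.
        return self.value

    def __repr__(self):
        return f"SymbolicDim({self.value})"


def freeze_shape(shape):
    """Mirror of the workaround: x_shape = [int(s) for s in x.size()]."""
    return [int(s) for s in shape]


traced = [SymbolicDim(d) for d in (1, 1, 96, 96, 96)]
frozen = freeze_shape(traced)
print(frozen)  # [1, 1, 96, 96, 96]
print(all(type(d) is int for d in frozen))  # True
```

This freezes the shape to constants, which lets tracing proceed but can break inference at other input sizes, consistent with the errors reported above.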
Environment
================================
MONAI version: 0.9.1
Numpy version: 1.23.2
Pytorch version: 1.12.1.post200
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 356d2d2
MONAI file: .../envs/temp_env/lib/python3.10/site-packages/monai/__init__.py
Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: NOT INSTALLED or UNKNOWN VERSION.
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: NOT INSTALLED or UNKNOWN VERSION.
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: 0.4.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
================================
Printing system config...
`psutil` required for `print_system_info`
================================
Printing GPU config...
Num GPUs: 1
Has CUDA: True
CUDA version: 11.2
cuDNN enabled: True
cuDNN version: 8401
Current device: 0
Library compiled for CUDA architectures: ['sm_35', 'sm_50', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'compute_86']
GPU 0 Name: Tesla T4
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 40
GPU 0 Total memory (GB): 14.6
GPU 0 CUDA capability (maj.min): 7.5
Additional context
I have also filed an issue with PyTorch, as I'm not certain on which side the bug should be resolved.