
Exporting SwinUNETR to ONNX does not work #5125

Closed
nicktasios opened this issue Sep 12, 2022 · 12 comments
nicktasios commented Sep 12, 2022

Describe the bug
When trying to export the SwinUNETR model from MONAI, I get the error:

RuntimeError: Failed to export an ONNX attribute 'onnx::Gather', since it's not constant, please try to make things (e.g., kernel size) static if possible.

In a different issue, I read that this might be fixed by changing x_shape = x.size() to x_shape = [int(s) for s in x.size()] in the problematic code; I found that the problem manifests in proj_out(). Doing this, though, results in a different error:

RuntimeError: Unsupported: ONNX export of instance_norm for unknown channel size.

Making this change in all places where I find x_shape = x.size() results in a floating point exception!
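For context, here is a minimal sketch (not MONAI's actual code; the helper name is hypothetical) of the pattern this change applies:

import torch

def project(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper illustrating the workaround: casting each
    # dimension with int() turns the shape into a list of Python
    # constants, which the ONNX tracer can treat as static instead of
    # emitting dynamic onnx::Gather nodes.
    x_shape = [int(s) for s in x.size()]  # was: x_shape = x.size()
    b, c = x_shape[0], x_shape[1]
    return x.view(b, c, -1)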

To Reproduce

Here is a minimal example demonstrating the issue:

from monai.networks.nets import SwinUNETR
import torch

if __name__ == '__main__':

    model = SwinUNETR(img_size=(96, 96, 96),
                      in_channels=1,
                      out_channels=5,
                      feature_size=48,
                      drop_rate=0.0,
                      attn_drop_rate=0.0,
                      dropout_path_rate=0.0,
                      use_checkpoint=True,
                      )
    inputs = [torch.randn([1, 1, 96, 96, 96])]
    input_names = ['input']
    output_names = ['output']

    torch.onnx.export(
        model,
        tuple(inputs), 'model.onnx',
        verbose=False,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=None,
        opset_version=11,
    )

Environment

================================
MONAI version: 0.9.1
Numpy version: 1.23.2
Pytorch version: 1.12.1.post200
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 356d2d2
MONAI file: .../envs/temp_env/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: NOT INSTALLED or UNKNOWN VERSION.
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: NOT INSTALLED or UNKNOWN VERSION.
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: 0.4.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...

psutil required for print_system_info

================================
Printing GPU config...

Num GPUs: 1
Has CUDA: True
CUDA version: 11.2
cuDNN enabled: True
cuDNN version: 8401
Current device: 0
Library compiled for CUDA architectures: ['sm_35', 'sm_50', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'compute_86']
GPU 0 Name: Tesla T4
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 40
GPU 0 Total memory (GB): 14.6
GPU 0 CUDA capability (maj.min): 7.5

Additional context
I have also filed an issue with Pytorch as I'm not certain on which side should the bug be resolved.

@wyli
Contributor

wyli commented Sep 13, 2022

Could you please elaborate on the floating point exception? We have a recent patch #4913; not sure if it's related.

@nicktasios
Author

@wyli I tried the change in that patch and it resulted in the same problem. The exception happens in the following:

Thread 1 "python" received signal SIGFPE, Arithmetic exception.                                                                                                  
0x00002aaae8980949 in torch::jit::(anonymous namespace)::ComputeShapeFromReshape(torch::jit::Node*, c10::SymbolicShape const&, c10::SymbolicShape const&, int) ()
   from /home/014118_emtic_oncology/Pancreas/nick/envs/temp_env/lib/python3.10/site-packages/torch/lib/libtorch_python.so                                        

@wyli
Contributor

wyli commented Oct 3, 2022

It seems that using x_shape = [int(s) for s in x.size()], as you mentioned in the PyTorch thread, partly addresses the issue.

It's possible to export with these parameter changes:

    model = SwinUNETR(img_size=(96, 96, 96),
                      in_channels=1,
                      out_channels=5,
                      feature_size=48,
+                     norm_name=("instance", {"affine": True}),
                      drop_rate=0.0,
                      attn_drop_rate=0.0,
                      dropout_path_rate=0.0,
-                     use_checkpoint=True,
+                     use_checkpoint=False,
                      )

I don't think gradient checkpointing (use_checkpoint=True) could easily be supported here.
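
For reference, a minimal end-to-end sketch, assuming the parameter changes above are applied to the original reproduction script:

from monai.networks.nets import SwinUNETR
import torch

model = SwinUNETR(img_size=(96, 96, 96),
                  in_channels=1,
                  out_channels=5,
                  feature_size=48,
                  norm_name=("instance", {"affine": True}),
                  drop_rate=0.0,
                  attn_drop_rate=0.0,
                  dropout_path_rate=0.0,
                  use_checkpoint=False,
                  )
model.eval()
torch.onnx.export(model, torch.randn(1, 1, 96, 96, 96), 'model.onnx',
                  input_names=['input'], output_names=['output'],
                  opset_version=11)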

@wyli wyli closed this as completed Oct 3, 2022
@nicktasios
Author

@wyli I tried the changes you suggested and indeed the model was successfully exported to ONNX. Unfortunately, during inference I get the following error:

[E:onnxruntime:ONNX_ENGINE, tensorrt_execution_provider.h:51 log] [2022-10-04 14:23:43   ERROR] 10: [optimizer.cpp::computeCosts::2011] Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[Reshape_1863 + Transpose_1864...Add_1962]}.)
Exception caught: TensorRT EP could not build engine for fused node: TensorrtExecutionProvider_TRTKernel_graph_torch-jit-export_18064542951413184515_8_1
terminate called after throwing an instance of 'Ort::Exception'
  what():  TensorRT EP could not build engine for fused node: TensorrtExecutionProvider_TRTKernel_graph_torch-jit-export_18064542951413184515_8_1

@tangy5
Contributor

tangy5 commented Jan 6, 2023

Hi @wyli, if you happen to see this in this closed issue.
Now Swin UNETR can safely be converted with torch.onnx.export or torch.jit.trace.
But it cannot be converted by torch.jit.script, since it contains a lot of dynamic Python logic, including gradient checkpointing (use_checkpoint=True), which is not supported.

Do you have any idea whether we need to make the Swin UNETR model scriptable with torch.jit.script? torch.jit.trace and ONNX have some limitations, such as not supporting branching.

Thank you.
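
For illustration, a minimal tracing sketch, assuming the export-friendly settings discussed above:

from monai.networks.nets import SwinUNETR
import torch

model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=5,
                  feature_size=48, use_checkpoint=False)
model.eval()
# torch.jit.trace records the ops executed for one example input;
# data-dependent Python branching is frozen into the trace, which is
# the limitation mentioned above.
traced = torch.jit.trace(model, torch.randn(1, 1, 96, 96, 96))
traced.save("swin_unetr_traced.pt")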

@wyli
Contributor

wyli commented Jan 6, 2023

I looked into that some time ago; torch.jit.script support is not easy for this model, so perhaps we shouldn't spend more time on this for now. If it's really needed, we may have to write a less flexible version of the model and hard-code some hyperparameters.

@tangy5
Contributor

tangy5 commented Jan 6, 2023

I looked into that some time ago; torch.jit.script support is not easy for this model, so perhaps we shouldn't spend more time on this for now. If it's really needed, we may have to write a less flexible version of the model and hard-code some hyperparameters.

Thanks, I agree. I was trying to rewrite this model to support torch.jit.script, but got stuck on the use_checkpoint option. Maybe we keep it traceable with torch.jit.trace for now. If we need to support TensorRT specifically in the future, we can write another version of this model.

@Nic-Ma
Contributor

Nic-Ma commented Jan 6, 2023

Hi @wyli @tangy5 ,

I think TorchScript and TensorRT support are "must-haves" for the next version's release. Let's discuss it later.

Thanks.

@tangy5
Contributor

tangy5 commented Jan 6, 2023

Hi @wyli @tangy5 ,

I think TorchScript and TensorRT support are "must-haves" for the next version's release. Let's discuss it later.

Thanks.

Sure, the point is whether we need torch.jit.script for the TorchScript model, or whether torch.jit.trace and ONNX are good enough. If torch.jit.script is a must-have, I suggest we write a light version of this model (e.g., swinunetr_lt) and remove the dynamically programmed sections. But we might have other options... Let's discuss. Thank you.

@Nic-Ma
Contributor

Nic-Ma commented Jan 6, 2023

I think TorchScript -> TensorRT is the recommended way now, instead of the previous ONNX -> TensorRT.
Here is an example:
https://github.com/pytorch/TensorRT/blob/main/README.md?plain=1#L88
I think we may need to connect with the TensorRT team for more details later, CC @deepib.

Thanks.
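
For illustration, a minimal sketch of that TorchScript -> TensorRT route via Torch-TensorRT, following the README linked above; `traced` is assumed to be the torch.jit.trace output from the earlier sketch:

import torch
import torch_tensorrt

# Compile a traced TorchScript module directly to a TensorRT engine.
trt_model = torch_tensorrt.compile(
    traced,  # torch.jit.trace output, as in the earlier sketch
    inputs=[torch_tensorrt.Input((1, 1, 96, 96, 96))],
    enabled_precisions={torch.half},
)
out = trt_model(torch.randn(1, 1, 96, 96, 96).cuda())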

@wyli
Contributor

wyli commented Jan 6, 2023

Also, if it's for inference only, the gradient checkpointing option is not needed.

@csheaff

csheaff commented Apr 6, 2023

Hello @tangy5 @wyli and others, I have tried and failed to use the above-mentioned workarounds to convert SwinUNETR to TensorRT format. I'm using MONAI version 1.2.0rc3. I am attempting two routes to TensorRT: (1) through torch_tensorrt.ts.compile after torch.jit.trace, and (2) using torch.onnx.export. Both result in the errors shown below. It seems that torch.jit.trace does not successfully trace the graph, but perhaps I'm doing something wrong here. Any help would be much appreciated. Note that I am working with a model saved using nn.parallel.DataParallel, hence the loading logic. Also, I am not using a model trained with norm_name=("instance", {"affine": True}), so it doesn't look like I can include that argument and still load the weights.

from collections import OrderedDict

import torch
import torch_tensorrt
from monai.networks.nets import SwinUNETR

model = SwinUNETR(
    img_size=(128, 128, 128),
    in_channels=1,
    out_channels=n_classes,    # defined elsewhere
    feature_size=n_features,   # defined elsewhere
    use_checkpoint=False,
)
load_w_data_parallel = False
if load_w_data_parallel:
    model = torch.nn.parallel.DataParallel(model)
    model.load_state_dict(torch.load(model_path))
else:
    # strip the "module." prefix added by DataParallel, see
    # https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686
    state_dict = torch.load(model_path)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace("module.", "")
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

traced_model = torch.jit.trace(model, torch.rand(1, 1, 128, 128, 128))

trt_ts_model = torch_tensorrt.ts.compile(
    traced_model,
    inputs=[torch.rand(1, 1, 128, 128, 128)],
    enabled_precisions={torch.float, torch.half},
)

input_names = ['input']
output_names = ['output']
torch.onnx.export(
    model,
    torch.rand(1, 1, 128, 128, 128),
    'model.onnx',
    verbose=False,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=None,
    opset_version=11,
)

The torch_tensorrt.ts.compile command produces the following error:

*** RuntimeError: 0 INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/torch/csrc/jit/ir/alias_analysis.cpp":615, please report a bug to PyTorch. We don't have an op for aten::constant_pad_nd but it isn't a special case. Argument types: Tensor, int[], NoneType,

Candidates:
aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor
aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!)

The torch.onnx.export command yields another error:

*** torch.onnx.errors.SymbolicValueError: Failed to export a node '%6407 : Long(device=cpu) = onnx::Gather[axis=0](%6404, %6406), scope: monai.networks.nets.swin_unetr.SwinUNETR::/monai.networks.nets.swin_unetr.SwinTransformer::swinViT # /opt/monai/monai/networks/nets/swin_unetr.py:1024:0
' (in list node %6444 : int[] = prim::ListConstruct(%6407), scope: monai.networks.nets.swin_unetr.SwinUNETR::/monai.networks.nets.swin_unetr.SwinTransformer::swinViT
) because it is not constant. Please try to make things (e.g. kernel sizes) static if possible. [Caused by the value '6444 defined in (%6444 : int[] = prim::ListConstruct(%6407), scope: monai.networks.nets.swin_unetr.SwinUNETR::/monai.networks.nets.swin_unetr.SwinTransformer::swinViT
)' (type 'List[int]') in the TorchScript graph. The containing node has kind 'prim::ListConstruct'.]

Inputs:
    #0: 6407 defined in (%6407 : Long(device=cpu) = onnx::Gather[axis=0](%6404, %6406), scope: monai.networks.nets.swin_unetr.SwinUNETR::/monai.networks.nets.swin_unetr.SwinTransformer::swinViT # /opt/monai/monai/networks/nets/swin_unetr.py:1024:0
)  (type 'Tensor')
Outputs:
    #0: 6444 defined in (%6444 : int[] = prim::ListConstruct(%6407), scope: monai.networks.nets.swin_unetr.SwinUNETR::/monai.networks.nets.swin_unetr.SwinTransformer::swinViT
)  (type 'List[int]')
