
🐛 [Bug] Cannot compile scripted model with torch-tensorrt #1536

Closed · Shahrullo opened this issue Dec 7, 2022 · 9 comments

@Shahrullo

Bug Description

I successfully scripted a custom model, but I cannot compile it with torch-tensorrt. The model takes two stereo images and outputs a single estimated depth map. The scripted model runs fine on its own; only compilation fails. The official PyTorch samples script and compile fine on my setup, so I don't think the environment is the problem (which is why I haven't described it here).

trt_module = torch_tensorrt.compile(
    model,
    inputs=[
        torch_tensorrt.Input(shape=[1, 3, 192, 576], dtype=torch.float),
        torch_tensorrt.Input(shape=[1, 3, 192, 576], dtype=torch.float),
    ],
    min_block_size=1,
    enabled_precisions={torch.float32},
)

It results in the following error:

WARNING: [Torch-TensorRT] - For input img0.1, found user specified input dtype as Float32. The compiler is going to use the user setting Float32
WARNING: [Torch-TensorRT] - For input img1.1, found user specified input dtype as Float32. The compiler is going to use the user setting Float32

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_1574378/1465320287.py in <module>
      7             ],
      8             min_block_size=1,
----> 9             enabled_precisions={torch.float32},
     10         )

~/anaconda3/envs/unimatch/lib/python3.7/site-packages/torch_tensorrt/_compile.py in compile(module, ir, inputs, enabled_precisions, **kwargs)
    124             ts_mod = torch.jit.script(module)
    125         return torch_tensorrt.ts.compile(
--> 126             ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
    127         )
    128     elif target_ir == _IRType.fx:

~/anaconda3/envs/unimatch/lib/python3.7/site-packages/torch_tensorrt/ts/_compiler.py in compile(module, inputs, input_signature, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, capability, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, calibrator, truncate_long_and_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules)
    134     }
    135 
--> 136     compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
    137     compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
    138     return compiled_module

RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:93] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 429 produced from %429 : Tensor[] = prim::ListConstruct() in lowering graph for mini graph input.

Additional context

For example, the scripted model runs fine on the following sample data:

ex1 = torch.rand([1, 3, 192, 576]).to('cuda')
ex2 = torch.rand([1, 3, 192, 576]).to('cuda')

model(ex1, ex2)
[tensor([[[35.2793, 36.9096, 32.8977,  ..., 28.7939, 27.1365, 29.9576],
          [28.3197, 32.1967, 33.1406,  ..., 30.8387, 26.2071, 26.8993],
          [33.6014, 34.1282, 27.7217,  ..., 27.7997, 28.2945, 33.2738],
          ...,
          [32.7583, 32.4200, 34.9346,  ..., 28.7729, 25.6432, 28.6702],
          [32.8773, 36.2377, 32.7644,  ..., 32.9409, 28.6909, 27.9045],
          [33.1214, 32.1883, 31.3315,  ..., 27.2826, 27.1451, 31.5295]]],
        device='cuda:0', grad_fn=<SqueezeBackward1>)]
Shahrullo added the bug label Dec 7, 2022
@narendasan
Collaborator

@Shahrullo Can you provide information on what your system config is and what version you are using?

@Shahrullo
Author

@narendasan Here is the config info (see the version-check sketch after the list):

  • Torch-TensorRT Version: 1.2.0
  • PyTorch Version: 1.12.1
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): conda
  • Python version: 3.7
  • CUDA version: 11.3
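
For reference, a quick runtime check to confirm these versions (a generic sketch, not from the original thread):

import torch
import torch_tensorrt
import tensorrt

# Print the library versions relevant to this issue.
print("torch:", torch.__version__)
print("torch_tensorrt:", torch_tensorrt.__version__)
print("tensorrt:", tensorrt.__version__)
print("CUDA (as built into torch):", torch.version.cuda)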

@bowang007
Collaborator

@Shahrullo can you try the latest version, 1.3?
This error is the same as this one: #922, which was fixed right after version 1.2.
Users got it resolved by using 1.3: #922 (comment)

@Shahrullo
Author

@bowang007 Thank you very much for the response. Indeed, I'll try it with 1.3.

@Shahrullo
Author

Dear @bowang007, I tried with the latest Docker image, 22.11-py3. Now it is giving me a different error:

WARNING: [Torch-TensorRT] - For input img0.2, found user specified input dtype as Float32. The compiler is going to use the user setting Float32
WARNING: [Torch-TensorRT] - For input img1.1, found user specified input dtype as Float32. The compiler is going to use the user setting Float32

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In [9], line 1
----> 1 torch_tensorrt.compile(model,
      2                        inputs = [
      3                 torch_tensorrt.Input(shape=[1, 3, 192, 576], dtype=torch.float),
      4                 torch_tensorrt.Input(shape=[1, 3, 192, 576], dtype=torch.float),
      5             ],)

File /usr/local/lib/python3.8/dist-packages/torch_tensorrt/_compile.py:125, in compile(module, ir, inputs, enabled_precisions, **kwargs)
    120         logging.log(
    121             logging.Level.Info,
    122             "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
    123         )
    124         ts_mod = torch.jit.script(module)
--> 125     return torch_tensorrt.ts.compile(
    126         ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
    127     )
    128 elif target_ir == _IRType.fx:
    129     if (
    130         torch.float16 in enabled_precisions
    131         or torch_tensorrt.dtype.half in enabled_precisions
    132     ):

File /usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py:136, in compile(module, inputs, input_signature, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, capability, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, calibrator, truncate_long_and_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules)
    110     raise ValueError(
    111         f"require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: {torch_executed_ops}, torch_executed_modules: {torch_executed_modules}"
    112     )
    114 spec = {
    115     "inputs": inputs,
    116     "input_signature": input_signature,
   (...)
    133     },
    134 }
--> 136 compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
    137 compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
    138 return compiled_module

RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:156] Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled

Further debugging

When I run with

spec = {
    "forward": torch_tensorrt.ts.TensorRTCompileSpec(
        inputs=[
            torch_tensorrt.Input(shape=[1, 3, 192, 576]),
            torch_tensorrt.Input(shape=[1, 3, 192, 576]),
        ],
        enabled_precisions={torch.float, torch.half},
        refit=False,
        debug=False,
        device={
            "device_type": torch_tensorrt.DeviceType.GPU,
            "gpu_id": 0,
            "dla_core": 0,
            "allow_gpu_fallback": True,
        },
        capability=torch_tensorrt.EngineCapability.default,
        num_avg_timing_iters=1,
    )
}
trt_model = torch._C._jit_to_backend('tensorrt', model, spec)

It gives an unsupported-operators error:

ERROR: [Torch-TensorRT] - Method requested cannot be compiled end to end by Torch-TensorRT.TorchScript.
Unsupported operators listed below:
  - aten::update.str(Dict(str, t)(a!) self, Dict(str, t)(a!) to_add) -> ()
  - aten::len.Tensor(Tensor t) -> int
  - aten::div(Scalar a, Scalar b) -> float
  - aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor
  - %6181 : Dict(str, Tensor[]) = prim::DictConstruct(%48, %187)
  - %results_dict.1 : Dict(str, Tensor[]) = prim::DictConstruct()
  - aten::mul.left_t(t[] l, int n) -> t[]
  - prim::device(Tensor a) -> Device
  - aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor
  - aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
  - aten::_set_item.t(t[](a!) l, int idx, t(b -> *) el) -> t[](a!)
  - aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
  - aten::triu(Tensor self, int diagonal=0) -> Tensor
  - aten::im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
  - aten::Bool.Tensor(Tensor a) -> bool
  - aten::tensor.float(float t, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor
  - aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
  - aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
  - aten::meshgrid(Tensor[] tensors) -> Tensor[]
  - aten::cuda(Tensor(a) self) -> Tensor(b|a)
  - aten::zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
You can either implement converters for these ops in your application or request implementation
.....
RuntimeError: [Error thrown at /opt/pytorch/torch_tensorrt/py/torch_tensorrt/csrc/tensorrt_backend.cpp:67] Expected core::CheckMethodOperatorSupport(mod, it->key().toStringRef()) to be true but got false
Method forward cannot be compiled by Torch-TensorRT

Final thought:

torch.jit.script compiles my model successfully. Aren't the operators supported by torch-tensorrt the same as those supported by torch.jit?

@bowang007
Collaborator

Hey @Shahrullo, for the first bug:
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:156] Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled
As the error message indicates, you should set truncate_long_and_double to True.
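
A minimal sketch of what that looks like, based on the reporter's original compile call (scripted_model is a hypothetical handle to their scripted module):

trt_module = torch_tensorrt.compile(
    scripted_model,  # hypothetical handle to the scripted stereo-depth model
    inputs=[
        torch_tensorrt.Input(shape=[1, 3, 192, 576], dtype=torch.float),
        torch_tensorrt.Input(shape=[1, 3, 192, 576], dtype=torch.float),
    ],
    enabled_precisions={torch.float32},
    # Cast int64/float64 graph inputs and weights down to int32/float32,
    # since TensorRT does not support the wider types.
    truncate_long_and_double=True,
)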

For the second bug, it shouldn't happen if you enable operator fallback; please refer to this page:
https://github.com/pytorch/TensorRT/tree/master/core/partitioning
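
A hedged sketch of what enabling fallback looks like: with partial compilation, unsupported ops stay in TorchScript instead of aborting the build. The op names below are taken from the unsupported list above and are only an illustrative subset:

trt_module = torch_tensorrt.compile(
    scripted_model,
    inputs=[
        torch_tensorrt.Input(shape=[1, 3, 192, 576], dtype=torch.float),
        torch_tensorrt.Input(shape=[1, 3, 192, 576], dtype=torch.float),
    ],
    truncate_long_and_double=True,
    # Partial compilation: leave unsupported ops running in Torch. Note that
    # the torch._C._jit_to_backend / TensorRTCompileSpec path used above
    # requires full compilation, so fallback only works through this API.
    require_full_compilation=False,
    min_block_size=1,
    torch_executed_ops=["aten::grid_sampler", "aten::im2col"],  # example subset
)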

@Shahrullo
Author

Shahrullo commented Dec 15, 2022

@bowang007, I set truncate_long_and_double to True; now it shows a RuntimeError from the interpolate method:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/TORCHSCRIPT1/unimatch.py", line 75, in forward
      feature1 = feature1_list[scale_idx]
      if torch.gt(scale_idx, 0):
        _19 = _4(flow0, None, 2., "bilinear", True, None, False, )
              ~~ <--- HERE
        flow2 = torch.mul(_19, 2)
        flow3 = torch.detach(flow2)
  File "code/__torch__/torch/nn/functional/___torch_mangle_26.py", line 284, in interpolate
                          _95 = _96
                        else:
                          _97 = torch.upsample_bilinear2d(input, output_size3, align_corners6, scale_factors3)
                                ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                          _95 = _97
                        _93 = _95

Traceback of TorchScript, original code (most recent call last):
  File "path/to/model/unimatch.py", line 104, in forward
    
            if scale_idx > 0:
                flow = F.interpolate(flow, scale_factor=2.0, mode='bilinear', align_corners=True) * 2
                       ~~~~~~~~~~~~~ <--- HERE
    
            # if flow is not None:
  File "/home/user/anaconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 3938, in interpolate
        if antialias:
            return torch._C._nn._upsample_bilinear2d_aa(input, output_size, align_corners, scale_factors)
        return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors)
               ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    if input.dim() == 5 and mode == "trilinear":
        assert align_corners is not None
RuntimeError: Expected static_cast<int64_t>(scale_factors->size()) == spatial_dimensions to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

@bowang007
Collaborator

bowang007 commented Dec 15, 2022

Hey @Shahrullo, can you print the logs and post them here? Something like:
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)
This would give us a more detailed message about where it fails.
From the error message you provided, I'm not sure whether it comes from TorchScript or Torch-TensorRT.
If possible, could you share a small reproducer so I can test your model locally?
Thanks.
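
A minimal sketch of that logging request, assuming the 1.3 Python logging API (torch_tensorrt.logging also ships context managers as a scoped alternative to the global setter):

import torch
import torch_tensorrt

# Global: everything from here on is logged at Debug verbosity.
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)

# Scoped alternative: only calls inside the block log at Debug verbosity.
with torch_tensorrt.logging.debug():
    trt_module = torch_tensorrt.compile(
        scripted_model,  # hypothetical handle to the scripted model
        inputs=[
            torch_tensorrt.Input(shape=[1, 3, 192, 576], dtype=torch.float),
            torch_tensorrt.Input(shape=[1, 3, 192, 576], dtype=torch.float),
        ],
        truncate_long_and_double=True,
    )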

@github-actions

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
