🐛 [Bug] Wrong function call for torch._C._jit_to_tensorrt on spec in tutorial code. #286

Closed
lazykyama opened this issue Jan 15, 2021 · 3 comments · Fixed by #288
Labels
bug: triaged [verified] We can replicate the bug bug Something isn't working

Comments

@lazykyama

Bug Description

This is a bug in the tutorial code.

To compile a model, the tutorial calls torch._C._jit_to_tensorrt() with two arguments, script_model._c and spec.
However, the following error occurred:

Traceback (most recent call last):
  File "simple.py", line 34, in <module>
    main()
  File "simple.py", line 27, in main
    trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<string>", line 6, in __setstate__
                self.__processed_module = state[1]
                self.__create_backend()
                self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
                                 ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

RuntimeError: isGenericDict() INTERNAL ASSERT FAILED at "/opt/python/cp36-cp36m/lib/python3.6/site-packages/torch/include/ATen/core/ivalue_inl.h":830, please report a bug to PyTorch. Expected GenericDict but got Object

This error seems to be caused by a wrong argument. In the test code (https://github.com/NVIDIA/TRTorch/blob/b93627ecc7bb7123d2e32aff9d762dd0e6bc3166/tests/py/test_to_backend_api.py), this API is called with a different argument, as shown below. It seems that an additional dict with a forward key is necessary.

trt_model = torch._C._jit_to_tensorrt(script_model._c, {'forward': spec})

To Reproduce

Steps to reproduce the behavior:

  1. Launch NGC TRT container, nvcr.io/nvidia/tensorrt:20.03-py3.
  2. Install required libraries
    • pip install https://github.com/NVIDIA/TRTorch/releases/download/v0.1.0/trtorch-0.1.0-cp36-cp36m-linux_x86_64.whl torchvision==0.7.0
  3. Run each step described in the tutorial.

My repro code is below. Note that I modified the tutorial code to use FP32 for testing.

import torch
import trtorch
import torchvision.models as models

def main():
    model = models.mobilenet_v2(pretrained=True)
    script_model = torch.jit.script(model)

    spec = {
        "forward": trtorch.TensorRTCompileSpec({
            "input_shapes": [[1, 3, 300, 300]],
            "op_precision": torch.float32,
            "refit": False,
            "debug": False,
            "strict_types": False,
            "allow_gpu_fallback": True,
            "device_type": "gpu",
            "capability": trtorch.EngineCapability.default,
            "num_min_timing_iters": 2,
            "num_avg_timing_iters": 1,
            "max_batch_size": 0,
        })
    }
    # trt_model = torch._C._jit_to_tensorrt(script_model._c, {'forward': spec})
    trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)

    x = torch.randn((1, 3, 300, 300)).to('cuda').to(torch.float32)
    print(trt_model.forward(x))


if __name__ == '__main__':
    main()

Expected behavior

No exception.

Environment

Build information about the TRTorch compiler can be found by turning on debug messages

  • PyTorch Version (e.g., 1.0): 1.6.0
  • CPU Architecture: Intel(R) Core(TM) i7-6850K
  • OS (e.g., Linux): Ubuntu 18.04.5 LTS (host) / Ubuntu 18.04.4 LTS (container)
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): N/A
  • Are you using local sources or building from archives: N/A
  • Python version: 3.6.9
  • CUDA version: 10.2.89
  • GPU models and configuration: TitanX (Pascal)
  • Any other relevant information: N/A

Additional context

N/A

@lazykyama lazykyama added the bug Something isn't working label Jan 15, 2021
@narendasan
Collaborator

It seems to me that there might be some confusing variable names. spec in the test is of type trtorch.TensorRTCompileSpec, and torch._C._jit_to_tensorrt takes a torch._C.ScriptModule and a Dictionary { str : torch.TensorRTCompileSpec }. So in the test we need to make a dictionary {'forward': spec}, hence:

trt_model = torch._C._jit_to_tensorrt(script_model._c, {'forward': spec})

spec in the docs, unless I am misreading it, is itself a Dictionary { str : trtorch.TensorRTCompileSpec }, so we should be able to call it with

trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)

as shown in the documentation (https://nvidia.github.io/TRTorch/tutorials/use_from_pytorch.html).

It would probably be a good idea to change the name of one of these to avoid confusion.
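To make the naming distinction concrete, here is a minimal pure-Python sketch. It assumes a hypothetical helper `as_method_spec_dict` and a hypothetical stand-in class `FakeCompileSpec` in place of the object returned by `trtorch.TensorRTCompileSpec`, so it runs without TRTorch or a GPU. It accepts either a bare compile spec (the test-code meaning of `spec`) or a dictionary already keyed by method name (the tutorial meaning) and always returns `{ method_name : compile_spec }`, the shape described in the comment above.

```python
from typing import Any, Dict


class FakeCompileSpec:
    """Hypothetical stand-in for the object returned by trtorch.TensorRTCompileSpec."""

    def __init__(self, settings: Dict[str, Any]) -> None:
        self.settings = settings


def as_method_spec_dict(spec: Any, method: str = "forward") -> Dict[str, Any]:
    """Return a { method_name : compile_spec } dict (hypothetical helper).

    A bare compile spec is wrapped under `method`; a dict is assumed to be
    keyed by method name already and is returned as a shallow copy.
    """
    if isinstance(spec, dict):
        return dict(spec)
    return {method: spec}


# Test-code naming: `spec` is a single compile spec, so it gets wrapped.
single = FakeCompileSpec({"op_precision": "float32"})
# Tutorial naming: `spec` is already a dict keyed by method name.
keyed = {"forward": single}

wrapped_from_single = as_method_spec_dict(single)  # {"forward": single}
wrapped_from_keyed = as_method_spec_dict(keyed)    # {"forward": single}
```

Renaming one of the two variables (for example to something like `method_specs`, a hypothetical name) removes the ambiguity just as well; the helper only shows that both spellings in the thread describe the same `{ str : spec }` argument.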

@lazykyama
Author

Let me add one comment about spec in the test code.

It is defined as shown below, so the current torch._C._jit_to_tensorrt seems to take a Dictionary { str : { str : torch.TensorRTCompileSpec }} instead of a simple { str : torch.TensorRTCompileSpec }.
https://github.com/NVIDIA/TRTorch/blob/b93627ecc7bb7123d2e32aff9d762dd0e6bc3166/tests/py/test_to_backend_api.py#L14-L29

Is that expected?
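The two argument shapes can be told apart by counting dictionary levels before reaching a non-dict value. A minimal sketch, assuming a hypothetical helper `dict_depth` and a plain string as a placeholder for the compile-spec object: the tutorial's `spec` has depth 1 (`{ str : spec }`), while wrapping it again as in the test code gives depth 2 (`{ str : { str : spec }}`).

```python
from typing import Any


def dict_depth(value: Any) -> int:
    """Count nested dict levels, following the first value at each level.

    Hypothetical helper: any non-dict (e.g. a compile-spec object) has
    depth 0, and an empty dict counts as one level.
    """
    depth = 0
    while isinstance(value, dict):
        depth += 1
        if not value:
            break
        value = next(iter(value.values()))
    return depth


compile_spec = "<TensorRTCompileSpec placeholder>"  # stands in for the real object
tutorial_spec = {"forward": compile_spec}           # { str : spec }
test_arg = {"forward": tutorial_spec}               # { str : { str : spec }}

print(dict_depth(tutorial_spec))  # 1
print(dict_depth(test_arg))       # 2
```

This lines up with the "Expected GenericDict but got Object" assertion in the report: the backend at that commit appears to have looked one level deeper for a dict and found the compile-spec object instead.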

narendasan added a commit that referenced this issue Jan 22, 2021
to_backend

Also fixes nested dictionary bug reported in #286

Signed-off-by: Naren Dasan <[email protected]>
@narendasan narendasan added the bug: triaged [verified] We can replicate the bug label Jan 22, 2021
@narendasan
Collaborator

Yeah, there was an issue with an expected nested dict. It should now be fixed in master. Note that the API in PyTorch has changed in 1.7 to:

trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)
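Code that has to run on both PyTorch 1.6 and 1.7+ could choose the entry point from the version string. A minimal sketch, assuming a hypothetical helper `compile_to_trt`; the two calls are the ones quoted in this thread, and the torch module is passed in as a parameter so the dispatch can be exercised with a stub instead of a real TRTorch install.

```python
from typing import Any


def compile_to_trt(torch_mod: Any, script_model: Any, spec: Any) -> Any:
    """Hypothetical helper: call the TensorRT backend entry point for this PyTorch version.

    `spec` is the { method_name : compile_spec } dict, as in the tutorial.
    """
    # "1.7.0+cu101" -> (1, 7); any local suffix after "+" is dropped first.
    version = torch_mod.__version__.split("+")[0]
    major, minor = (int(part) for part in version.split(".")[:2])
    if (major, minor) >= (1, 7):
        # PyTorch >= 1.7: generic backend API, takes the ScriptModule itself.
        return torch_mod._C._jit_to_backend("tensorrt", script_model, spec)
    # PyTorch 1.6: TensorRT-specific binding, takes the underlying C++ module.
    return torch_mod._C._jit_to_tensorrt(script_model._c, spec)
```

With the real packages this would be called as `trt_model = compile_to_trt(torch, script_model, spec)`. Passing the module explicitly is only a testing convenience; a real project would more likely import `torch` directly.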
