
🐛 [Bug] Compilation causes error: RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input. #922

Closed
chaoz-dev opened this issue Mar 9, 2022 · 48 comments · Fixed by #1024
Labels: bug (Something isn't working), component: partitioning

@chaoz-dev
Contributor

Bug Description

Compiling the graph throws the following error:

RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input.

Looking at the output TorchScript graph, %47 is defined in a prior node; however, it does not appear to be visible in the current node.

To Reproduce

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torch_tensorrt.logging as logging

logging.set_reportable_log_level(logging.Level.Graph)

torch.manual_seed(0)

DEVICE = torch.device("cuda:0")
SHAPE = (1, 1)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)

    def forward(self, a):
        out = self.lin(a)

        tril = torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)
        indices = torch.tril_indices(1, 1)
        tril[:, indices[0], indices[1]] = out

        return tril


if __name__ == "__main__":
    tensor = torch.randn(SHAPE, dtype=torch.float32, device=DEVICE)

    model = Model().eval().to(DEVICE)
    out = model(tensor)
    print(f"Model: {out}")

    model_trt = torchtrt.compile(
        model,
        inputs=[
            torchtrt.Input(shape=SHAPE),
        ],
        enabled_precisions={torch.float},
        truncate_long_and_double=True,
    )
    out_trt = model_trt(tensor)
    print(f"Model TRT: {out_trt}")

    assert torch.max(torch.abs(out - out_trt)) < 1e-6

Throws the following error:

Traceback (most recent call last):
  File "/scripts/tril.py", line 39, in <module>
    model_trt = torchtrt.compile(
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 97, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 119, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:66] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 47 produced from %47 : int = prim::dtype(%52) in lowering graph for mini graph input.

Expected behavior

Compilation should not fail, and should produce the following output when run:

Model: tensor([[[0.5434]]], device='cuda:0', grad_fn=<CopySlices>)

Environment

Ubuntu 18.04 x86-64, run with NGC 21.11-py3 and 22.02-py3.

Additional context

See output.txt for full torchscript graph output.

@chaoz-dev added the bug (Something isn't working) label Mar 9, 2022
@chaoz-dev
Contributor Author

chaoz-dev commented Mar 9, 2022

This looks similar to issue #756, with fix #757.

Looking at the sources, it looks like this fix may not have made it into /release/ngc/22.02 but should be present in /release/ngc/22.03 and afterwards. At the time of this writing, only 22.02-py3 is available. I'll close this pending testing and availability of 22.03-py3.

@chaoz-dev
Contributor Author

Just tried fix #757 with master commit 11bcb98d3cd680c3c34e6cc4c4efdc7512c144cc built in NGC container nvcr.io/nvidia/tensorrt:22.02-py3 and PyTorch 1.10, and the error persists, so it's likely that #757 does not address this issue.

@chaoz-dev
Contributor Author

Looks like the issue is in this line:

          tril = torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)    

as changing this line to:

          tril = torch.zeros(1, 1, 1).cuda()

appears to bypass the issue.
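For reference, a minimal sketch of the full bypassed forward, following the repro's Model above (note it hard-codes the device instead of deriving it from the input, and relies on the default float dtype matching out):

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)

    def forward(self, a):
        out = self.lin(a)
        # Avoids device=a.device / dtype=out.dtype, which lower to
        # prim::device / prim::dtype nodes in the TorchScript graph.
        tril = torch.zeros(1, 1, 1).cuda()
        indices = torch.tril_indices(1, 1)
        tril[:, indices[0], indices[1]] = out
        return tril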

@narendasan
Collaborator

@peri044 Can you take a look at this? It looks related to your past work on dtype.

@khazamaa

@chaoz-dev I had the same error

[Error thrown at core/partitioning/shape_analysis.cpp:67] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 12065 produced from %12065 : int = aten::size(%out.29, %11826), scope: __module.stylegan_decoder/__module.stylegan_decoder.style_conv1/__module.stylegan_decoder.style_conv1.modulated_conv

and your method of bypassing is not applicable in my case.
Could you suggest something else?

@chaoz-dev
Contributor Author

You're probably facing a different but possibly related issue. Can you file a new bug report with the above information?

@bowang007
Collaborator

bowang007 commented Apr 29, 2022

I took a look into this issue. This is caused by the resolveNonTensorInput function.
What happens here is that when you use:

tril = torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)

it introduces a non-tensor input (for a.device) to the mini-graph, which induces the resolveNonTensorInput function to segment this subgraph again. This explains why it's fine when you change it to

tril = torch.zeros(1, 1, 1).cuda()

Let me see if we can refactor this function, since it's making a mess here.
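One way to see the offending nodes is to print the scripted graph; a minimal sketch along the lines of the repro:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(1, 1)

    def forward(self, a):
        out = self.lin(a)
        # device=a.device becomes a prim::device node and
        # dtype=out.dtype becomes a prim::dtype node in the graph
        return torch.zeros(1, 1, 1, device=a.device, dtype=out.dtype)

# The printed graph contains lines like `%9 : int = prim::dtype(%out.1)`;
# those non-tensor values are what the partitioner fails to resolve.
print(torch.jit.script(Model()).graph)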

@bowang007
Collaborator

bowang007 commented May 4, 2022

@chaoz-dev I just raised a PR for this bug: #1024.
The model you provided should be supported now.
Please take a look and ping me if there are any other issues.

@BrettRyland

Sorry to re-raise this issue, but I'm still getting the same runtime error for deformable convolutions on the latest build of master (commit 91a92ca), which includes PR #1024.

Expanding upon the original reproduction code above in trt_bug.py, I'm getting:

RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 62 produced from %62 : int = aten::__getitem__(%61, %54) # /home/brett/.local/lib/python3.10/site-packages/torchvision/ops/deform_conv.py:71:28 in lowering graph for mini graph input.

ENV info:

Collecting environment information...
PyTorch version: 1.11.0a0+gitbc2c6ed
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04 LTS (x86_64)
GCC version: (Ubuntu 11.2.0-19ubuntu1) 11.2.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.4 (main, Apr  2 2022, 09:04:19) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-33-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.64
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080 Ti
Nvidia driver version: 515.43.04
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.3.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.3.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.21.5
[pip3] torch==1.11.0
[pip3] torch-tensorrt==1.2.0a0+91a92ca4
[pip3] torchvision==0.12.0
[conda] Could not collect

@Hodapp87

Hodapp87 commented Jun 2, 2022

I am still getting this runtime error on another model; I don't believe I'm using deformable convolution here. I'm in the process of trying to clean up source code to show, but the error is something like:

RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 1442 produced from %x.2 : Tensor, %1442 : Float(32, 12, 3, 3, ...

with the ... replaced by about 25K characters of a dump of the graph.

Note also that this error did not occur with float16, only with int8.

This is with Torch-TensorRT v1.1.0, so PR #1024 is included.

@bowang007 reopened this Jun 2, 2022
@bowang007
Collaborator

bowang007 commented Jun 3, 2022

@BrettRyland Could you please try this PR: #1067?
We refactored this part recently, and when I tried your model with this PR I didn't see the issue you had.

@bowang007
Collaborator

bowang007 commented Jun 3, 2022


@Hodapp87 Could you please try #1067 as well? Or could you please provide a reproducer if you still hit this issue?

@BrettRyland

@BrettRyland Could you please try this PR: #1067? We refactored this part recently, and when I tried your model with this PR I didn't see the issue you had.

I still get this error

RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 62 produced from %62 : int = aten::__getitem__(%61, %54) # /home/brett/.local/lib/python3.10/site-packages/torchvision/ops/deform_conv.py:71:28 in lowering graph for mini graph input.

(full log: trt_bug_log.txt) with PR #1067 merged in:

brett@br-workhorse:~/github/TensorRT/py$ git log --oneline --graph 
* 22d91f5e (HEAD -> master) fix: fix the bug that tag Constant node as fallback node
*   ccb826e7 Merge remote-tracking branch 'origin' into refactor_segmentation
|\  
| * 91a92ca4 (origin/master, origin/HEAD) docs: [Automated] Regenerating documenation for dcf3386
brett@br-workhorse:/tmp$ python3 -c 'import torch_tensorrt; print(torch_tensorrt.__version__)'
1.2.0a0+22d91f5e

Side note: the trt_bug.py script has a typo on line 93: it should be out_trt2 = model2(tensor2), not out_trt2 = model2(tensor), but I guess you saw that if you got it to run without issues.

Another side note: I don't think it's relevant to this issue, but I get the warning

WARNING: [Torch-TensorRT TorchScript Conversion Context] - TensorRT was linked against cuBLAS/cuBLAS LT 11.8.0 but loaded cuBLAS/cuBLAS LT 111.0.1

despite not having cublas 11.8.0 on my system

brett@br-workhorse:/storage/github/TensorRT$ sudo updatedb && locate -i libcublas
/home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/execroot/Torch-TensorRT/bazel-out/k8-opt/bin/_solib_k8/_U@cuda_S_S_Ccublas___Ulib64/libcublas.so
/home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/execroot/Torch-TensorRT/bazel-out/k8-opt/bin/_solib_k8/_U@cuda_S_S_Ccublas___Ulib64_Sstubs/libcublas.so
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublas.so
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublas.so.11
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublas.so.11.10.1.25
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublasLt.so
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublasLt.so.11
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublasLt.so.11.10.1.25
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublasLt_static.a
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcublas_static.a
/usr/local/cuda-11.7/targets/x86_64-linux/lib/stubs/libcublas.so
/usr/local/cuda-11.7/targets/x86_64-linux/lib/stubs/libcublasLt.so
/usr/share/doc/libcublas-11-7
/usr/share/doc/libcublas-dev-11-7
/usr/share/doc/libcublas-11-7/changelog.Debian.gz
/usr/share/doc/libcublas-11-7/copyright
/usr/share/doc/libcublas-dev-11-7/changelog.Debian.gz
/usr/share/doc/libcublas-dev-11-7/copyright
/var/lib/dpkg/info/libcublas-11-7.list
/var/lib/dpkg/info/libcublas-11-7.md5sums
/var/lib/dpkg/info/libcublas-dev-11-7.list
/var/lib/dpkg/info/libcublas-dev-11-7.md5sums
brett@br-workhorse:/storage/github/TensorRT$ ls -l /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/execroot/Torch-TensorRT/bazel-out/k8-opt/bin/_solib_k8/_U@cuda_S_S_Ccublas___Ulib64/libcublas.so
lrwxrwxrwx 1 brett brett 103 May 25 16:55 /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/execroot/Torch-TensorRT/bazel-out/k8-opt/bin/_solib_k8/_U@cuda_S_S_Ccublas___Ulib64/libcublas.so -> /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so
brett@br-workhorse:/storage/github/TensorRT$ ls -l /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so
lrwxrwxrwx 1 root root 15 Apr  6 04:07 /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so -> libcublas.so.11
brett@br-workhorse:/storage/github/TensorRT$ ls -l /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so.11
lrwxrwxrwx 1 root root 23 Apr  6 04:07 /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so.11 -> libcublas.so.11.10.1.25
brett@br-workhorse:/storage/github/TensorRT$ ls -l /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so.11.10.1.25 
-rw-r--r-- 1 root root 156720544 Apr  6 04:07 /home/brett/.cache/bazel/_bazel_brett/f5a9c3674f5e1909bfdcbf1248e5df36/external/cuda/lib64/libcublas.so.11.10.1.25

@Hodapp87

Hodapp87 commented Jun 3, 2022

@Hodapp87 Could you please try #1067 as well? Or could you please provide a reproducer if you still hit this issue?

It still occurs for me too. I am trying to provide code that can reproduce it, but much of this is proprietary in nature, so it may take some time to disentangle.

@bowang007
Collaborator

bowang007 commented Jun 3, 2022


@BrettRyland Did you clear your cache? I could get your model working after applying that PR.
You could do:
pip3 uninstall torch_tensorrt
to uninstall the previously installed torch_tensorrt. Do it multiple times to ensure there aren't any library copies left.
Then:
python3 setup.py clean
python3 setup.py install
Could you please print some logs to make sure that the PR works? I also ran into issues where I found I wasn't in fact using the merged code because I hadn't cleared the cache.
@Hodapp87 Can you try it as well?

@Hodapp87

Hodapp87 commented Jun 6, 2022

For testing this, I used the Dockerfile straight out of the repository and then ran inside this container. Unless this caches something I'm unaware of, this should have been a clean build.

Here's the truncated output of my run, which shows versions as well (22d91f5e should be your PR's commit):

------------------------------------------------------------
torch version: 1.11.0+cu102
torch_tensorrt version: 1.2.0a0+22d91f5e
------------------------------------------------------------
WARNING: [Torch-TensorRT] - Cannot infer input type from calcuations in graph for input x.2. Assuming it is Float32. If not, specify input type explicity
WARNING: [Torch-TensorRT] - Input type for doing shape analysis could not be determined, defaulting to F32
Traceback (most recent call last):
  File "./compile.py", line 119, in <module>
    compiled = torch_tensorrt.compile(mdl_ts, **compile_spec)
  File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/_compile.py", line 115, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py", line 113, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 1935 produced from %x.2 : Tensor, %1935 : Float(32, 12, 3, 3, strides=[108, 9, 3, 1], ...

If I get a chance soon, I'll see if I can extract a simpler model out of this that I can send to try.

@BrettRyland


Clearing the cache (I also ran bazel clean --expunge in the top-level directory and removed ~/.cache/bazel) allowed the test model to compile and run without problems, but my full model still gives the same RuntimeError, just in a different place.

RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 1162 produced from %x.6 : Tensor, %1162 : Long(requires_grad=0, device=cpu) = prim::Param() in lowering graph for mini graph input.

I'll need to try to isolate where in my full model it's happening now to see what's triggering it.

@BrettRyland

OK, I've reduced my model to a smaller repro script, trt_bug.py, which still gives:

RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 61 produced from %x.1 : Tensor, %61 : Long(requires_grad=0, device=cpu) = prim::Param() in lowering graph for mini graph input.

It appears to be caused by using a single-valued int64 index tensor in an aten::index_put_ operation:

scores[:, self.anchor_always_index, :] = self.false_anchor_score

where self.false_anchor_score is a registered buffer.
Replacing the index tensor with the value (using .item()) causes a different error:

RuntimeError: [Error thrown at ./core/conversion/var/Var_inl.h:37] Expected isIValue() to be true but got false
Requested unwrapping of arg assuming it was an IValue, however arg type is nvinfer1::ITensor

which can also be avoided by using an explicit tensor instead of the self.false_anchor_score buffer.
Note that torch will happily script or trace the model with

scripted_model = torch.jit.script(model)

or

traced_model = torch.jit.trace(model, torch.rand((1, *model.input_size), device=DEVICE))
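For anyone trying to reproduce this pattern, here is a self-contained sketch of it (names, shapes, and values are illustrative, not the actual model's):

import torch
import torch.nn as nn

class AnchorModel(nn.Module):
    def __init__(self):
        super().__init__()
        # single-valued int64 index tensor, as described above
        self.register_buffer("anchor_always_index", torch.tensor([0], dtype=torch.int64))
        # registered buffer used as the assigned value
        self.register_buffer("false_anchor_score", torch.tensor(0.0))

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        # lowers to aten::index_put_ with a Long index tensor
        scores[:, self.anchor_always_index, :] = self.false_anchor_score
        return scores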

@bowang007
Collaborator

bowang007 commented Jun 15, 2022


Hi @BrettRyland, I took a look into your model. This line

auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());

introduces another input, %61, to the whole model's inputs, which might have something to do with the underlying function in PyTorch here: https://github.com/pytorch/pytorch/blob/6114b0f921d5568c582d0168501f780df7a66d0d/torch/csrc/jit/passes/lower_graph.cpp#L149. This doesn't seem to be an issue related to fallback; I'm now looking into it to figure out what happened.

@bowang007
Collaborator

bowang007 commented Jun 28, 2022


Can I get more details? Thanks @Hodapp87

@bowang007
Collaborator

bowang007 commented Jun 30, 2022


Hey @BrettRyland, did you bypass the issue?
After detailed investigation, it seems that this error comes from PyTorch, and we can bypass it by using explicit tensors when building models, though we could also do some work like post-processing the graph produced by PyTorch to erase the introduced missing input.
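As a rough illustration of the explicit-tensor bypass (the index and value here are placeholders, not the real model's):

import torch
import torch.nn as nn

class AnchorModelBypassed(nn.Module):
    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        # Instead of indexing with a buffer tensor and assigning a buffer value,
        # use a literal index and an explicit 0-dim tensor built in forward
        # (0-dim CPU tensors broadcast across devices in eager mode).
        scores[:, 0, :] = torch.tensor(0.0)
        return scores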

@Belval

Belval commented Jul 6, 2022

Still working on the repro, but I just built torch_tensorrt from source (using master) to see if it helped, and it did partially solve the issue above. Unfortunately, I get a different Expected ivalues_maps.count(input) to be true but got false error.

Here is the stacktrace for reference:

Traceback (most recent call last):
  File "scripts/evaluation/infer_tables_tensorrt.py", line 44, in <module>
    main()
  File "scripts/evaluation/infer_tables_tensorrt.py", line 22, in main
    torch_tensorrt.compile(model.backbone, inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))])
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 111, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 113, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 17979 produced from %17979 : int[] = prim::ListConstruct(%batch.8, %17960) in lowering graph for mini graph input.

@bowang007
Collaborator

bowang007 commented Jul 11, 2022


Hey @Belval, could you please try either adding this line:
"torch_executed_ops": ["prim::ListConstruct"]
or setting:
"min_block_size": 1

Details about why this happens can be found here: #1173
I will raise a PR soon to cover these cases.

@Belval

Belval commented Jul 12, 2022

torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))], torch_executed_ops=["prim::ListConstruct"]) returns a very similar stack trace to before:

RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got falseCould not find torch::jit::Value* 1968 produced from %x.1 : Tensor, %1939 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %1940 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %1941 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %1942 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1943 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1944 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1945 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1946 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1947 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1948 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1949 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1950 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1951 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1952 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1953 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1954 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1955 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1956 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1957 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1958 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1959 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1960 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1961 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1962 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1963 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1964 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %1965 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %1966 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %1967 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %1968 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0) = prim::Param() in lowering graph for mini graph input.

Interestingly, torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))], min_block_size=1) works to some extent, but instead of an exception, the compilation exits with error code 1 without printing anything:

  %15770 : int[] = prim::ListConstruct(%15715, %15769)
  %15771 : int = prim::min(%15770) # /code/models/d2/detectron2/modeling/backbone/fpn.py:151:15
   = prim::Loop(%15771, %self.bottom_up.stages_and_names.res2.2.conv2.use_bn.1) # /code/models/d2/detectron2/modeling/backbone/fpn.py:151:15
    block0(%1936 : int):
      %1938 : Tensor = aten::__getitem__(%1923, %1936) # /code/models/d2/detectron2/modeling/backbone/fpn.py:151:15
      %10295 : str = aten::__getitem__(%self._out_features.1, %1936) # /code/models/d2/detectron2/modeling/backbone/fpn.py:151:15
       = aten::_set_item(%1932, %10295, %1938) # /code/models/d2/detectron2/modeling/backbone/fpn.py:151:15
      -> (%self.bottom_up.stages_and_names.res2.2.conv2.use_bn.1)
  return (%1932)


Any ideas?

@bowang007
Collaborator

@Belval I was trying to reproduce the error on NGC 22.06. However, I kept getting library loading errors when building torch_tensorrt from source to reproduce your error. Did I miss anything?

@Belval

Belval commented Jul 13, 2022

I am not sure that I understand your question. If you are referring to the repro package I sent you, then it could be the torch.ops.load_library in repro.py that does not have the correct path. libcustom_deform_conv.so contains the compiled TorchScript operators.

@bowang007
Collaborator

@Belval may I ask how you built torch-tensorrt from source on NGC 22.06? I tried to build the library from source but kept getting this error when I import torch_tensorrt:
ImportError: /workspace/TensorRT/py/torch_tensorrt/lib/libtorchtrt.so: undefined symbol: _ZN3c106detail23torchInternalAssertFailEPKcS2_jS2_RKSs
Did you reinstall some dependent libraries? I used the ngc-22.06 branch for the build.
Thanks!

@Belval

Belval commented Jul 14, 2022

From the error, it could be that the --use-cxx11-abi flag is missing. I used this as a reference: https://pytorch.org/TensorRT/tutorials/installation.html#installation (see "Choosing the right ABI").

In a container based on nvcr.io/nvidia/pytorch:22.06-py3

  1. wget https://github.com/bazelbuild/bazelisk/releases/download/v1.12.0/bazelisk-linux-amd64
  2. chmod +x bazelisk-linux-amd64
  3. mv bazelisk-linux-amd64 bazel
  4. export PATH=$(pwd):$PATH
  5. cd /opt/tensorrt
  6. git clone https://github.com/pytorch/TensorRT.git
  7. vim WORKSPACE
  8. Comment out lines 56 to 70 and 76 to 94 (the HTTP requests).
  9. Change the CUDA path at line 44 to just /usr/local/cuda, since the symlinked version is the right one anyway.
  10. Uncomment lines 107 to 129 and update the paths.
  11. python3 setup.py bdist_wheel --use-cxx11-abi

I just tried it and I got an exception while building; unfortunately, I built it in a running container that I stopped a few days ago, so I can't send you the exact config, but I did something along the lines of the above. I don't remember installing additional dependencies.

@bowang007
Collaborator

Hey @Belval, I was able to reproduce your error, and the error msg is in fact shown before the graph.
However, one thing I'm curious about: are there Float inputs to your graph? I see this graph right after lowering:

INFO: [Torch-TensorRT] - Lowered Graph: graph(%x.1 : Tensor,
      %2374 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2375 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2376 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2377 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2378 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2379 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2380 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2381 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2382 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2383 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2384 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2385 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2386 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2387 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2388 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2389 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2390 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2391 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2392 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2393 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2394 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2395 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2396 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2397 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2398 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2399 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2400 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2401 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2402 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0),
      %2403 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0)):

This explains why you are getting the missing value error when you use
torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))], torch_executed_ops=["prim::ListConstruct"])
These Float values refer to the inputs of this model.
I think there should be only one input for your model; if that's the case, then this kind of graph modification is introduced by this function from PyTorch:

auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());

This looks like what happens here: #922 (comment). In that model, the LowerGraph function also modified the graph and introduced other inputs, and it turned out this happened because a Tensor was used as an Int for indexing.

Could you please check why these Float inputs are introduced in your model? I took a look and found all the introduced Float values are used here:

custom_deform_conv::modulated_deform_conv_forward(%2539, %2374, %2238, %2257, %offset.1, %mask0.1, %x0.6, %2258, %19915, %19917, %self.bottom_up.stages_and_names.res3.0.conv2.stride.21, %self.bottom_up.stages_and_names.res3.0.conv2.stride.21, %self.bottom_up.stages_and_names.res3.0.conv2.stride.21, %self.bottom_up.stages_and_names.res3.0.conv2.stride.21, %self.bottom_up.stages_and_names.res3.0.conv2.stride.21, %self.bottom_up.stages_and_names.res3.0.conv2.stride.21, %self.bottom_up.stages_and_names.res3.0.conv2.stride.21, %self.bottom_up.stages_and_names.res2.0.radix.3, %self.bottom_up.stages_and_names.res3.0.conv2.stride.21)

So in fact all the input Float values are used as the second input parameter of the custom_deform_conv node. Could you please take a look to see if there is some kind of abnormal use around that part? I don't think this kind of error comes from torch_tensorrt, since this is an internal function from torch.

@bowang007
Collaborator

Hi @Belval, any update?
I raised a feature request here: #1190
We will try to support this feature ASAP, starting from some simple cases like the one here: #922 (comment)
Since I don't have the detailed source code of your model, I'm not sure what happens to these custom_deform_conv nodes. But once this feature is supported on some simple models, I will run yours and try to support it ASAP.
Thanks!

@bowang007
Collaborator

@Belval
I did some tests locally, and this bug can be fixed by applying the PyTorch APIs correctly. As we can see, the Float values in the input have requires_grad=1, which means it's a training graph; that's the reason why LowerGraph() introduces many Float-typed inputs. So could you please retrieve your model this way?

model = myModel().eval().to(DEVICE)
scripted_model = torch.jit.freeze(torch.jit.script(model))

Then you can use torch_tensorrt.compile to compile the retrieved JIT graph.
I tested on the model from #922 (comment) and it worked fine. I tried to do the same on yours, but I don't have the original model source code, so that failed. Could you please give it a try? Thanks!
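Putting the suggestion together, the full flow would look roughly like this (a sketch; MyModel and the input shape are placeholders):

import torch
import torch_tensorrt

DEVICE = torch.device("cuda:0")

model = MyModel().eval().to(DEVICE)  # MyModel stands in for your module
frozen = torch.jit.freeze(torch.jit.script(model))
trt_mod = torch_tensorrt.compile(
    frozen,
    inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))],  # shape from this thread
)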

@ncomly-nvidia added the release: v1.2 (Tagged to be included in v1.2) label Jul 26, 2022
@Belval

Belval commented Aug 16, 2022

Sorry for not getting back to you earlier.

I tried your suggestion, but it does not seem to remove the requires_grad=1. Here's what I am doing:

backbone = model.backbone.eval().cuda()
backbone = torch.jit.freeze(torch.jit.script(backbone))

with torch_tensorrt.logging.debug():
    torch_tensorrt.compile(
        backbone,
        inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))],
        min_block_size=1
    )

I get the same error as described here. The model is in eval mode and backbone.training does report False.

@ncomly-nvidia
Contributor

@bowang007 @narendasan ^

@bowang007
Collaborator

bowang007 commented Aug 23, 2022

@Belval could you please try this fix? #1298
I tested locally and it works for the models that I have now.
I tried to test on your model, but I only have the model file based on NGC 22.06, and I got some errors because of API changes when I tried to build this new branch on NGC 22.06.

@BrettRyland

BrettRyland commented Sep 1, 2022

FYI @bowang007, I had this bug show up again in my complicated proprietary model (not entirely sure why though). Applying #1298 on top of the current release/1.2 branch fixed it for me. I'm currently stuck on #1296 and to a lesser extent #1157, though.

@Belval

Belval commented Sep 6, 2022

Still getting this error after checking out param_input:

    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 2058 produced from %x.1 : Tensor, %2029 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %2030 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %2031 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %2032 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2033 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2034 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2035 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2036 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2037 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2038 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2039 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2040 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2041 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2042 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2043 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2044 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2045 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2046 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2047 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2048 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2049 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2050 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2051 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2052 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2053 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2054 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2055 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %2056 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %2057 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0), %2058 : Float(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0) = prim::Param() in lowering graph for mini graph input.

Here is my Dockerfile:

FROM nvcr.io/nvidia/pytorch:22.08-py3

# Install system deps
RUN DEBIAN_FRONTEND=noninteractive apt-get update && DEBIAN_FRONTEND=noninteractive apt-get upgrade -y && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends tzdata && apt-get install sudo ffmpeg libsm6 libxext6  -y

# Install python deps, copy packages
## INSTALL DEPS ##

RUN apt-get remove ninja-build

# Rebuild TensorRT
RUN cd /opt/tensorrt && \
    wget https://github.com/bazelbuild/bazelisk/releases/download/v1.12.0/bazelisk-linux-amd64 && \
    chmod +x bazelisk-linux-amd64 && \
    mv bazelisk-linux-amd64 bazel && \
    git clone https://github.com/pytorch/TensorRT.git && \
    cd TensorRT && \
    git checkout param_input

ENV PATH=/opt/tensorrt:$PATH
# Modified WORKSPACE with correct paths
COPY ./WORKSPACE /opt/tensorrt/TensorRT
RUN cd /opt/tensorrt/TensorRT/py/ && python3 setup.py bdist_wheel --use-cxx11-abi

RUN cd /compiled_deps/extensions/modulated_deform_conv/ && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" .. && make
RUN cd /compiled_deps/extensions/ms_deform_attn/ && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')" .. && make

My script to reproduce the issue:

import torch
import torch_tensorrt

torch.ops.load_library("/compiled_deps/extensions/modulated_deform_conv/build/libcustom_deform_conv.so")
with torch.no_grad():
    with open("repro.ts", "rb") as f:
        backbone = torch.jit.load(f)

    with torch_tensorrt.logging.debug():
        torch_tensorrt.compile(
            backbone,
            inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))],
            torch_executed_ops=["prim::ListConstruct"],
            min_block_size=1
        )

I'll send the torchscripted backbone file so that you can try it on your side.

WORKSPACE file:

workspace(name = "Torch-TensorRT")

load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
    name = "rules_python",
    sha256 = "778197e26c5fbeb07ac2a2c5ae405b30f6cb7ad1f5510ea6fdac03bded96cc6f",
    url = "https://github.com/bazelbuild/rules_python/releases/download/0.2.0/rules_python-0.2.0.tar.gz",
)

load("@rules_python//python:pip.bzl", "pip_install")

http_archive(
    name = "rules_pkg",
    sha256 = "038f1caa773a7e35b3663865ffb003169c6a71dc995e39bf4815792f385d837d",
    urls = [
        "https://mirror.bazel.build/github.com/bazelbuild/rules_pkg/releases/download/0.4.0/rules_pkg-0.4.0.tar.gz",
        "https://github.com/bazelbuild/rules_pkg/releases/download/0.4.0/rules_pkg-0.4.0.tar.gz",
    ],
)

load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")

rules_pkg_dependencies()

git_repository(
    name = "googletest",
    commit = "703bd9caab50b139428cea1aaff9974ebee5742e",
    remote = "https://github.com/google/googletest",
    shallow_since = "1570114335 -0400",
)

# External dependency for torch_tensorrt if you already have precompiled binaries.
local_repository(
    name = "torch_tensorrt",
    path = "/opt/conda/lib/python3.8/site-packages/torch_tensorrt",
)

# CUDA should be installed on the system locally
new_local_repository(
    name = "cuda",
    build_file = "@//third_party/cuda:BUILD",
    path = "/usr/local/cuda/",
)

new_local_repository(
    name = "cublas",
    build_file = "@//third_party/cublas:BUILD",
    path = "/usr",
)
#############################################################################################################
# Tarballs and fetched dependencies (default - use in cases when building from precompiled bin and tarballs)
#############################################################################################################

#http_archive(
#    name = "libtorch",
#    build_file = "@//third_party/libtorch:BUILD",
#    sha256 = "80f089939de20e68e3fcad4dfa72a26c8bf91b5e77b11042f671f39ebac35865",
#    strip_prefix = "libtorch",
#    urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu113.zip"],
#)
#
#http_archive(
#    name = "libtorch_pre_cxx11_abi",
#    build_file = "@//third_party/libtorch:BUILD",
#    sha256 = "8e35371403f7052d9e9b43bcff383980dbde4df028986dc1dab539953481d55f",
#    strip_prefix = "libtorch",
#    urls = ["https://download.pytorch.org/libtorch/cu113/libtorch-shared-with-deps-1.12.0%2Bcu113.zip"],
#)

# Download these tarballs manually from the NVIDIA website
# Either place them in the distdir directory in third_party and use the --distdir flag
# or modify the urls to "file:///<PATH TO TARBALL>/<TARBALL NAME>.tar.gz

#http_archive(
#    name = "cudnn",
#    build_file = "@//third_party/cudnn/archive:BUILD",
#    sha256 = "7f3fbe6201708de409532a32d647af6b4bdb10d7f045d557270549e286487289",
#    strip_prefix = "cudnn-linux-x86_64-8.4.1.114_cuda11.4-archive",
#    urls = [
#        "https://developer.nvidia.com/compute/cudnn/secure/8.4.1/local_installers/11.6/cudnn-linux-x86_64-8.4.1.50_cuda11.6-archive.tar.xz",
#    ],
#)
#
#http_archive(
#    name = "tensorrt",
#    build_file = "@//third_party/tensorrt/archive:BUILD",
#    sha256 = "8107861af218694130f170e071f49814fa3e27f1386ce7cb6d807ac05a7fcf0e",
#    strip_prefix = "TensorRT-8.4.1.5",
#    urls = [
#        "https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.4.1/tars/tensorrt-8.4.1.5.linux.x86_64-gnu.cuda-11.6.cudnn8.4.tar.gz",
#    ],
#)

####################################################################################
# Locally installed dependencies (use in cases of custom dependencies or aarch64)
####################################################################################

# NOTE: In the case you are using just the pre-cxx11-abi path or just the cxx11 abi path
# with your local libtorch, just point deps at the same path to satisfy bazel.

# NOTE: NVIDIA's aarch64 PyTorch (python) wheel file uses the CXX11 ABI unlike PyTorch's standard
# x86_64 python distribution. If using NVIDIA's version just point to the root of the package
# for both versions here and do not use --config=pre-cxx11-abi

new_local_repository(
    name = "libtorch",
    path = "/opt/conda/lib/python3.8/site-packages/torch",
    build_file = "third_party/libtorch/BUILD"
)

new_local_repository(
    name = "libtorch_pre_cxx11_abi",
    path = "/opt/conda/lib/python3.8/site-packages/torch",
    build_file = "third_party/libtorch/BUILD"
)

new_local_repository(
    name = "cudnn",
    path = "/usr/",
    build_file = "@//third_party/cudnn/local:BUILD"
)

new_local_repository(
   name = "tensorrt",
   path = "/usr/",
   build_file = "@//third_party/tensorrt/local:BUILD"
)

#########################################################################
# Development Dependencies (optional - comment out on aarch64)
#########################################################################

pip_install(
    name = "devtools_deps",
    requirements = "//:requirements-dev.txt",
)

@bowang007
Collaborator

bowang007 commented Sep 9, 2022

Hey @Belval, I was finally able to reproduce the error.
I tested locally, and the reason we still have the error after using PR #1298 is that the input Float is a Tensor with gradient; when we try to use this function https://github.com/pytorch/pytorch/blob/9b8e0a38a6674c78a8f2729ce15393859aa2bd3d/torch/csrc/jit/ir/constants.cpp#L62 to insert it as a constant, it fails at this line https://github.com/pytorch/pytorch/blob/9b8e0a38a6674c78a8f2729ce15393859aa2bd3d/torch/csrc/jit/ir/constants.cpp#L15. This can be fixed by setting the gradient of these input Floats to 0. The missing value error is fixed now; I will upload the fix to that PR later.
However, I'm now seeing other errors, which I'm debugging to support this model. Thanks!
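Until that fix lands, a possible user-side mitigation (an untested sketch, assuming the model is still available as an nn.Module rather than only as a saved TorchScript file) is to drop gradient requirements before scripting and freezing:

import torch

def prepare_for_trt(model: torch.nn.Module) -> torch.jit.ScriptModule:
    model = model.eval()
    for p in model.parameters():
        # mirrors the fix described above: parameters must not require grad,
        # otherwise LowerGraph() lifts them into extra graph inputs
        p.requires_grad_(False)
    return torch.jit.freeze(torch.jit.script(model))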

@bowang007
Collaborator

Hey @Belval, I have your model supported using this branch:
https://github.com/pytorch/TensorRT/tree/amazon_model_support

Could you try it as well?
Btw, I'm getting NaN tensor output when I run your model with some randomly generated input through the TorchScript or Torch-TensorRT converted model. Not sure if this is expected.

@Belval

Belval commented Sep 13, 2022

Did you change anything else? Reusing the Dockerfile I sent (but with amazon_model_support), I get what seems to be the same error.

Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
[Error thrown at core/partitioning/shape_analysis.cpp:68] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 2058 produced from %x.1 : Tensor, %2029 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %2030 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %2031 : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0), %2032 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2033 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2034 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2035 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2036 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2037 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2038 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2039 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2040 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2041 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2042 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2043 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2044 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2045 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2046 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2047 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2048 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2049 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2050 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2051 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2052 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0), %2053 :

@bowang007
Collaborator

No. Did you uninstall the pre-installed torch_tensorrt?

pip3 uninstall torch_tensorrt
python3 setup.py clean
python3 setup.py install --use-cxx11-abi

@bowang007
Collaborator

bowang007 commented Sep 14, 2022

@Belval I set up a clean container and tested again, here is the detailed steps:

# run a clean container
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:22.08-py3
# clone the repo
git clone https://github.com/pytorch/TensorRT.git
# change the branch
git checkout amazon_model_support
# change the WORKSPACE file
# use the file you shared here: https://github.com/pytorch/TensorRT/issues/922#issuecomment-1238263094
# download bazel
wget https://github.com/bazelbuild/bazelisk/releases/download/v1.12.0/bazelisk-linux-amd64
# set up bazel
chmod +x bazelisk-linux-amd64
mv bazelisk-linux-amd64 /usr/bin/bazel
# start compiling:
pip3 uninstall torch_tensorrt
cd py
python3 setup.py install --use-cxx11-abi

Then run this script:

import torch
import torch_tensorrt
torch.ops.load_library("libcustom_deform_conv.so")
with torch.no_grad():
    with open("repro.ts", "rb") as f:
        backbone = torch.jit.load(f)

    with torch_tensorrt.logging.debug():
        torch_tensorrt.compile(
            backbone,
            inputs=[torch_tensorrt.Input((1, 3, 1440, 1440))],
            min_block_size=3
        )

@SM1991CODES

I think doing model.eval() did the trick for me.

model.load_state_dict(torch.load(path_trained_pth, map_location="cuda:0"))
model.eval()
sample_input = torch.randn((2, 4, 384, 384)).cuda().float()
traced_model = torch.jit.trace(model, sample_input)

trt_ts_module = trt.compile(traced_model, inputs=[sample_input], enabled_precisions={torch.float32})

@ntakouris

I get the same error trying to convert torchvision fcos with resnet50.

Relevant issue: pytorch/vision#6200

@samedii

samedii commented Dec 4, 2022

Also seeing this on stable diffusion autoencoder.encode

Edit: Resolved for me by updating to 1.3.0
