
🐛 [Bug] KeyError after resolveNonTensorInputs #1018

Closed
mfeliz-cruise opened this issue May 3, 2022 · 4 comments · Fixed by #1032
Labels
bug Something isn't working

Comments

@mfeliz-cruise
Contributor

Bug Description

Torch-TensorRT attempts to resolve all non-tensor inputs of a Torch block if any of those inputs are generated by TensorRT blocks. This leads to a failed attempt to resolve a dictionary input to a Torch block that is generated by another Torch block: getDependencyNodes fails to identify the aten::_set_item node as a dependency, which results in a KeyError.
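The failure mode can be modeled with a small sketch. This is illustrative Python, not the actual C++ getDependencyNodes, and all data structures below are hypothetical: a dependency walk that only follows the producers of a node's input values will never visit aten::_set_item, because _set_item mutates the dict in place and produces no output value.

```python
# Illustrative model of a producer-only dependency walk (hypothetical
# data structures; the real getDependencyNodes is C++ inside Torch-TensorRT).
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    kind: str
    inputs: tuple = ()
    outputs: tuple = ()

# Mini IR for: out_dict = {}; _set_item(out_dict, "INS", x); z = out_dict["INS"]
dict_construct = Node("prim::DictConstruct", outputs=("out_dict",))
set_item = Node("aten::_set_item", inputs=("out_dict", "INS", "x"))  # no outputs!
get_item = Node("aten::__getitem__", inputs=("out_dict", "INS"), outputs=("z",))
nodes = [dict_construct, set_item, get_item]

producer = {v: n for n in nodes for v in n.outputs}

def dependency_nodes(node):
    """Walk the producers of each input value; in-place mutators with no
    outputs (like aten::_set_item) are invisible to this walk."""
    deps, stack = set(), list(node.inputs)
    while stack:
        v = stack.pop()
        p = producer.get(v)
        if p is not None and p not in deps:
            deps.add(p)
            stack.extend(p.inputs)
    return deps

deps = dependency_nodes(get_item)
assert dict_construct in deps  # the DictConstruct is found...
assert set_item not in deps    # ...but the _set_item that fills it is not
```

Under this model, copying `deps` into the target block reproduces the dict without its contents, which matches the segmentation shown below.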

This is the original graph (a small artificial test case intended only to reproduce this issue):

graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %3 : str = prim::Constant[value="INS"]() 
  %4 : int = prim::Constant[value=-1]() 
  %5 : bool = prim::Constant[value=0]()
  %6 : str = prim::Constant[value="OUTS"]()
  %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
   = aten::_set_item(%out_dict.1, %3, %x.1) 
  %z.1 : Tensor = aten::__getitem__(%out_dict.1, %3)
  %9 : Tensor = aten::lt(%z.1, %y.1) 
  %13 : Tensor?[] = prim::ListConstruct(%9)
  %45 : int = prim::dtype(%z.1)
  %46 : Device = prim::device(%z.1)
  %49 : Tensor = aten::tensor(%4, %45, %46, %5)
  %14 : Tensor = aten::index_put_(%z.1, %13, %49, %5) 
   = aten::_set_item(%out_dict.1, %6, %z.1) 
  %15 : Tensor = aten::__getitem__(%out_dict.1, %3) 
  %16 : Tensor = aten::__getitem__(%out_dict.1, %6) 
  return (%15, %16)

It is segmented as follows. The Tensor?[] input to @2 from @1 will need to be resolved, triggering resolution of all @2 inputs, including %out_dict.1, a dictionary created in @0.

DEBUG: [Torch-TensorRT - Debug Build] - Segment Block @0:
    Target: Torch

    Graph: graph(%x.1 : Tensor):
  %1 : str = prim::Constant[value="INS"]()
  %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
   = aten::_set_item(%out_dict.1, %1, %x.1)
  %z.1 : Tensor = aten::__getitem__(%out_dict.1, %1)
  return ()

DEBUG: [Torch-TensorRT - Debug Build] - Segment Block @1:
    Target: TensorRT

    Graph: graph(%z.1 : Tensor,
      %y.1 : Tensor):
  %0 : Tensor = aten::lt(%z.1, %y.1)
  %3 : Tensor?[] = prim::ListConstruct(%0)
  %4 : int = prim::dtype(%z.1)
  return ()

DEBUG: [Torch-TensorRT - Debug Build] - Segment Block @2:
    Target: Torch

    Graph: graph(%z.1 : Tensor,
      %4 : int,
      %7 : Tensor?[],
      %out_dict.1 : Dict(str, Tensor)):
  %11 : str = prim::Constant[value="INS"]()
  %9 : str = prim::Constant[value="OUTS"]()
  %5 : bool = prim::Constant[value=0]()
  %3 : int = prim::Constant[value=-1]()
  %0 : Device = prim::device(%z.1)
  %2 : Tensor = aten::tensor(%3, %4, %0, %5)
  %6 : Tensor = aten::index_put_(%z.1, %7, %2, %5)
   = aten::_set_item(%out_dict.1, %9, %z.1)
  %10 : Tensor = aten::__getitem__(%out_dict.1, %11)
  %12 : Tensor = aten::__getitem__(%out_dict.1, %9)
  return ()

After resolveNonTensorInputs we can see that the prim::DictConstruct() node for %out_dict.1 is copied into @2 without the following aten::_set_item node.

Segment Block @0:
    Target: Torch

    Graph: graph(%x.1 : Tensor):
  %1 : str = prim::Constant[value="INS"]()
  %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
   = aten::_set_item(%out_dict.1, %1, %x.1)
  %z.1 : Tensor = aten::__getitem__(%out_dict.1, %1)
  return ()


Segment Block @1:
    Target: TensorRT

    Graph: graph(%z.1 : Tensor,
      %y.1 : Tensor):
  %0 : Tensor = aten::lt(%z.1, %y.1)
  %3 : Tensor?[] = prim::ListConstruct(%0)
  %4 : int = prim::dtype(%z.1)
  return ()


Segment Block @2:
    Target: Torch

    Graph: graph(%2 : Tensor,
      %z.1 : Tensor):
  %12 : str = prim::Constant[value="INS"]() 
  %10 : str = prim::Constant[value="OUTS"]()
  %8 : bool = prim::Constant[value=0]()
  %7 : int = prim::Constant[value=-1]()
  %out_dict.1 : Dict(str, Tensor) = prim::DictConstruct()
  %1 : Tensor?[] = prim::ListConstruct(%2)
  %3 : int = prim::dtype(%z.1)
  %5 : Device = prim::device(%z.1)
  %6 : Tensor = aten::tensor(%7, %3, %5, %8)
  %9 : Tensor = aten::index_put_(%z.1, %1, %6, %8)
   = aten::_set_item(%out_dict.1, %10, %z.1)
  %11 : Tensor = aten::__getitem__(%out_dict.1, %12)
  %13 : Tensor = aten::__getitem__(%out_dict.1, %10)
  return ()
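The resulting @2 behaves like the minimal Python below: the copied prim::DictConstruct creates a fresh, empty dict, so the lookup of "INS" raises the reported KeyError (a sketch of the effective semantics, not Torch-TensorRT code).

```python
# What segment @2 effectively executes after the bad copy: the dict is
# reconstructed empty because the "INS" aten::_set_item stayed behind in @0.
z = 0.0
out_dict = {}            # the copied prim::DictConstruct, minus its _set_item
out_dict["OUTS"] = z     # the only set_item that was copied into @2
try:
    ins = out_dict["INS"]        # aten::__getitem__(%out_dict.1, "INS")
    failed = False
except KeyError as e:
    failed = True
    print(f"KeyError: {e}")      # mirrors the reported "RuntimeError: KeyError: INS"
```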

To Reproduce

Steps to reproduce the behavior:

  1. Run the Python below with the latest version of Torch-TensorRT:
# Third-party imports
import torch
import torch.nn as nn
import torch_tensorrt

class Reproducer(nn.Module):
    def __init__(self):
        super(Reproducer, self).__init__()
    def forward(self, x, y):
        out_dict = {}
        out_dict["INS"] = x
        z = out_dict["INS"]
        z[z < y] = -1
        out_dict["OUTS"] = z
        return out_dict["INS"], out_dict["OUTS"]

def reproduce_error():
    torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)
    model = Reproducer().eval().cuda()

    x = torch.randn(20, 16, 50, 32).cuda()
    y = torch.randn(20, 16, 50, 32).cuda()
    
    trt_model = torch_tensorrt.compile(model, inputs=[x, y],  **{
            "truncate_long_and_double": True,
        })
    #print(trt_model.forward(x, y))


reproduce_error()
  2. Note the error "RuntimeError: KeyError: INS".

Expected behavior

Torch-TensorRT should not attempt to resolve non-tensor inputs of Torch blocks that are generated by other Torch blocks. If it does choose to resolve a dictionary input, it should include aten::_set_item as a dependency.
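A minimal sketch of that guard (illustrative Python; the real resolveNonTensorInputs is C++ inside Torch-TensorRT's partitioning pass, and the `Block`, `inputs_to_resolve`, `producer_of`, and `is_trt_block` names below are all hypothetical):

```python
# Hypothetical guard: only resolve non-tensor inputs whose producing segment
# is a TensorRT block; dict inputs produced by other Torch blocks are left alone.
def inputs_to_resolve(torch_block, producer_of, is_trt_block):
    """Return the non-tensor inputs of `torch_block` that actually need
    resolution, i.e. those produced by a TensorRT segment."""
    resolve = []
    for value in torch_block.nontensor_inputs:
        producer_block = producer_of(value)
        if producer_block is not None and is_trt_block(producer_block):
            resolve.append(value)
    return resolve

class Block:
    def __init__(self, target, nontensor_inputs=()):
        self.target = target
        self.nontensor_inputs = list(nontensor_inputs)

# Segment @2 from the issue: "%7 : Tensor?[]" comes from the TRT block @1,
# while "%out_dict.1" comes from the Torch block @0.
producers = {"list_input": Block("TensorRT"), "out_dict.1": Block("Torch")}
block2 = Block("Torch", ["list_input", "out_dict.1"])

resolved = inputs_to_resolve(block2,
                             producer_of=producers.get,
                             is_trt_block=lambda b: b.target == "TensorRT")
assert resolved == ["list_input"]   # out_dict.1 is no longer touched
```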

@mfeliz-cruise added the bug (Something isn't working) label on May 3, 2022
@mfeliz-cruise
Contributor Author

I have a fix for this that modifies resolveNonTensorInputs to only resolve non-tensor inputs to torch blocks that are generated by trt blocks. Happy to put up a PR if I can get access as a contributor.

@bowang007
Collaborator

Hi @mfeliz-cruise , thank you for your help!
I was also looking into this issue last week, and I was wondering if it's possible to clone the dictionary node with inserted keys.
Hi @narendasan, could you please help him raise a PR?

@narendasan
Collaborator

narendasan commented May 3, 2022

You should be able to open a PR without any special privileges. Create a fork of this repo on your GitHub account and push your patch to that. GitHub will guide you in opening a PR to upstream.

@mfeliz-cruise
Contributor Author

mfeliz-cruise commented May 3, 2022

@bowang007: "I was wondering if it's possible to clone the dictionary node with inserted keys"
Would this clone resolve the dictionary input to the second block? I guess we would need to create a Dict in the target block with all of the keys and values baked in. Either way, it seems like we would need to identify the relevant sets, which may be non-trivial with control flow. This might be as difficult as enhancing getDependencyNodes to identify the sets as dependencies.
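For concreteness, "baking the keys in" would amount to something like the sketch below: instead of copying only the empty DictConstruct, the clone would have to replay every aten::_set_item whose effect is visible at the cloned use (illustrative Python; `clone_dict_with_sets` is a hypothetical helper, and identifying the right set_items statically is exactly the hard part under control flow).

```python
# Hypothetical "clone with inserted keys": rebuild the dict from the
# (key, value) pairs of the _set_item nodes that dominate the cloned use.
def clone_dict_with_sets(set_items):
    """set_items: (key, value) pairs from the aten::_set_item nodes whose
    effects reach the target block; collecting these is the non-trivial step."""
    cloned = {}
    for key, value in set_items:
        cloned[key] = value
    return cloned

x = 1.0
cloned = clone_dict_with_sets([("INS", x)])   # the set performed in segment @0
assert cloned["INS"] == x                     # no KeyError in the clone
```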

@narendasan thanks I'll try it.
