From caf3a92ba3b2325dd05658d2eba4951442399461 Mon Sep 17 00:00:00 2001
From: George S <113141689+gs-olive@users.noreply.github.com>
Date: Mon, 24 Jun 2024 18:23:34 -0700
Subject: [PATCH] fix: Repair integer inputs in dynamic shape cases (#2876)

---
 .github/workflows/build-test-linux.yml        |  2 +-
 .../runtime/_PythonTorchTensorRTModule.py     | 19 +++++--
 .../dynamo/runtime/_TorchTensorRTModule.py    | 26 ++++-----
 tests/py/dynamo/models/test_dyn_models.py     | 55 +++++++++++++++++++
 tests/py/requirements.txt                     |  2 +-
 5 files changed, 82 insertions(+), 22 deletions(-)

diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml
index 0bd570155e..22f03e6218 100644
--- a/.github/workflows/build-test-linux.yml
+++ b/.github/workflows/build-test-linux.yml
@@ -264,4 +264,4 @@ jobs:
     concurrency:
       group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
-      cancel-in-progress: true
\ No newline at end of file
+      cancel-in-progress: true
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 471e3ef913..0daf75b091 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -128,6 +128,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
         self.context = self.engine.create_execution_context()

     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
+        contiguous_inputs: List[torch.Tensor] = [
+            (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
+            for i in inputs
+        ]
         with (
             torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
             if self.profiling_enabled
@@ -174,7 +179,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 self.input_names
             ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

-            contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
             for i, input_name in enumerate(self.input_names):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
@@ -193,12 +197,17 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     contiguous_inputs[i].dtype == self.input_dtypes[i]
                 ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

+                # For shape tensors, we use CPU pointers; for data tensors, we use GPU pointers,
+                # as per TensorRT requirements
                 if self.engine.is_shape_inference_io(input_name):
-                    # Shape tensor inputs are casted to int32 explicitly.
-                    # Refer to https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
-                    inputs_cpu = contiguous_inputs[i].cpu().to(torch.int32)
+                    # Shape tensor inputs are cast to int64 explicitly.
+                    # Torch CPU data pointers do not currently work here, so numpy
+                    # pointers are used instead to reference the underlying memory.
+                    inputs_cpu = (
+                        contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+                    )
                     self.context.set_tensor_address(
                         input_name, inputs_cpu.ctypes.data
                     )
                 else:
                     self.context.set_input_shape(
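Context for the shape-tensor change above: TensorRT reads shape-tensor values from host (CPU) memory, so the address passed to set_tensor_address must point at a host buffer that stays valid until the engine has consumed it. A minimal sketch of that binding pattern, assuming the TensorRT 10 Python API; bind_input and its arguments are illustrative names, not code from this patch:

    import torch

    def bind_input(context, engine, name, tensor):
        if engine.is_shape_inference_io(name):
            # Shape tensors: TensorRT expects a host pointer; the caller must
            # hold the returned numpy buffer until execution completes.
            host = tensor.cpu().to(torch.int64).numpy().copy()
            context.set_tensor_address(name, host.ctypes.data)
            return host
        # Data tensors: TensorRT expects a device pointer.
        context.set_input_shape(name, tuple(tensor.shape))
        context.set_tensor_address(name, tensor.data_ptr())
        return tensor
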
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 709c10b36e..3efa04413f 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -146,7 +146,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         """Implementation of the forward pass for a TensorRT engine

         Args:
-            *inputs (torch.Tensor): Inputs to the forward function, must all be ``torch.Tensor``
+            *inputs (Union[torch.Tensor, int]): Inputs to the forward function

         Returns:
             torch.Tensor or Tuple(torch.Tensor): Result of the engine computation
@@ -158,22 +158,18 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 self.input_binding_names
             ), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}."

-            types: List[bool] = [issubclass(type(i), torch.Tensor) for i in inputs]
-
-            try:
-                assert all(types)
-            except AssertionError:
-
-                def is_non_tensor(i: Tuple[Any, bool]) -> bool:
-                    return not i[1]
-
-                non_tensors = [i[0] for i in filter(is_non_tensor, zip(inputs, types))]
-                raise RuntimeError(
-                    f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
-                )
+            # If an input is not a Torch Tensor, cast it to one directly. This can
+            # occur when a shape tensor is the output of a preceding Torch subgraph,
+            # in which case a dynamic input may arrive as a plain integer.
+            #
+            # This also removes the need to type-check inputs, since all inputs are now explicitly cast to Torch Tensors
+            input_tensors: List[torch.Tensor] = [
+                (i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
+                for i in inputs
+            ]

             outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
-                list(inputs), self.engine
+                list(input_tensors), self.engine
             )

             if len(outputs) == 1:
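The replacement above trades the old type-check-and-raise behavior for normalization. As a standalone sketch (the helper name is illustrative and a CUDA device is assumed; only the comprehension mirrors the patch):

    import torch
    from typing import Any, List, Tuple

    def normalize_inputs(inputs: Tuple[Any, ...]) -> List[torch.Tensor]:
        # Tensors pass through untouched; plain Python ints (e.g. symbolic
        # shape values produced by an upstream Torch subgraph) become 0-d
        # int64 tensors on the GPU.
        return [
            i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()
            for i in inputs
        ]

    tensors = normalize_inputs((torch.randn(4, 3).cuda(), 4))
    assert tensors[1].dtype == torch.int64 and tensors[1].ndim == 0
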
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index 3fd34de2ea..50fa9a2f50 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -310,3 +310,58 @@ def forward(self, x):
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
+
+
+@pytest.mark.unit
+def test_dynamic_with_fallback_shape_tensor_pass_through(ir):
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, x):
+            out = self.conv(x)
+            x = x + 2
+            x = x * 2
+            out = torch.reshape(x, (-1, 224 * 224))
+            out = self.relu(out)
+            return out
+
+    model = MyModule().eval().cuda()
+    input_bs4 = torch.randn((4, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "pass_through_build_failures": True,
+        "min_block_size": 1,
+        "torch_executed_ops": {"torch.ops.aten.add.Tensor"},
+    }
+
+    # Compile the model
+    if ir == "torch_compile":
+        torch._dynamo.mark_dynamic(input_bs4, 0, min=4, max=1024)
+        trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
+        trt_model(input_bs4)
+    elif ir == "dynamo":
+        compile_spec["inputs"] = [
+            torchtrt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(4, 3, 224, 224),
+                max_shape=(1024, 3, 224, 224),
+                dtype=torch.float32,
+                name="x",
+            )
+        ]
+        trt_model = torchtrt.compile(model, **compile_spec)
+
+    trt_model(input_bs4)
+
+    input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
+    cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_dynamic_with_fallback_shape_tensor_pass_through model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
diff --git a/tests/py/requirements.txt b/tests/py/requirements.txt
index bdae578713..36dbda57bc 100644
--- a/tests/py/requirements.txt
+++ b/tests/py/requirements.txt
@@ -9,5 +9,5 @@ pytest-xdist>=3.6.1
 pyyaml
 tensorrt==10.0.1
 timm>=1.0.3
-transformers==4.39.3
+transformers==4.40.2
 --extra-index-url https://pypi.nvidia.com
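Background on why the new test exercises this path (a sketch under assumed partitioning behavior, not code from the patch): forcing aten.add.Tensor to execute in Torch splits the graph, and with a dynamic batch dimension the reshape's target shape depends on a symbolic integer, so a downstream TensorRT block can receive that value as a plain int rather than a Tensor:

    import torch

    class ReshapeConsumer(torch.nn.Module):
        # Illustrative module: with dim 0 marked dynamic, x.shape[0] traces
        # to a SymInt, which a partitioned backend may receive as an integer
        # input rather than a Tensor.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            batch = x.shape[0]
            return x.reshape(batch, -1)

    x = torch.randn(4, 3, 8, 8)
    torch._dynamo.mark_dynamic(x, 0, min=2, max=64)
    out = torch.compile(ReshapeConsumer())(x)
    assert out.shape == (4, 3 * 8 * 8)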