Skip to content

Commit

Permalink
fix: Repair integer inputs in dynamic shape cases (#2876)
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive authored Jun 25, 2024
1 parent fbc72d5 commit caf3a92
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,4 @@ jobs:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
cancel-in-progress: true
cancel-in-progress: true
19 changes: 14 additions & 5 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self.context = self.engine.create_execution_context()

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
# Ensure inputs are available in all scopes and cast symbolic integers to Tensors
contiguous_inputs: List[torch.Tensor] = [
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]
with (
torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
if self.profiling_enabled
Expand Down Expand Up @@ -174,7 +179,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
for i, input_name in enumerate(self.input_names):
if not contiguous_inputs[i].is_cuda:
logger.warning(
Expand All @@ -193,12 +197,17 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
contiguous_inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
# as per TensorRT requirements
if self.engine.is_shape_inference_io(input_name):
# Shape tensor inputs are casted to int32 explicitly.
# Refer to https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
inputs_cpu = contiguous_inputs[i].cpu().to(torch.int32)
# Shape tensor inputs are casted to int64 explicitly
# Currently Torch CPU pointers are not working; numpy pointers are used instead
# to refer to underlying memory
inputs_cpu = (
contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
)
self.context.set_tensor_address(
input_name, inputs_cpu.data_ptr()
input_name, inputs_cpu.ctypes.data
)
else:
self.context.set_input_shape(
Expand Down
26 changes: 11 additions & 15 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
"""Implementation of the forward pass for a TensorRT engine
Args:
*inputs (torch.Tensor): Inputs to the forward function, must all be ``torch.Tensor``
*inputs (Union[torch.Tensor, int]): Inputs to the forward function
Returns:
torch.Tensor or Tuple(torch.Tensor): Result of the engine computation
Expand All @@ -158,22 +158,18 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
self.input_binding_names
), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}."

types: List[bool] = [issubclass(type(i), torch.Tensor) for i in inputs]

try:
assert all(types)
except AssertionError:

def is_non_tensor(i: Tuple[Any, bool]) -> bool:
return not i[1]

non_tensors = [i[0] for i in filter(is_non_tensor, zip(inputs, types))]
raise RuntimeError(
f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
)
# If the inputs are not Torch Tensors, which can occur in scenarios such as shape tensors
# which are outputs of a preceding Torch subgraph (where the Dynamic input may be an integer)
# directly cast the input to a Torch Tensor.
#
# This also avoids the need for type-checking inputs, since they are now explicitly casted to Torch tensors
input_tensors: List[torch.Tensor] = [
(i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]

outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
list(inputs), self.engine
list(input_tensors), self.engine
)

if len(outputs) == 1:
Expand Down
55 changes: 55 additions & 0 deletions tests/py/dynamo/models/test_dyn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,3 +310,58 @@ def forward(self, x):
cos_sim > COSINE_THRESHOLD,
msg=f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
def test_dynamic_with_fallback_shape_tensor_pass_through(ir):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()

def forward(self, x):
out = self.conv(x)
x = x + 2
x = x * 2
out = torch.reshape(x, (-1, 224 * 224))
out = self.relu(out)
return out

model = MyModule().eval().cuda()
input_bs4 = torch.randn((4, 3, 224, 224)).to("cuda")

compile_spec = {
"device": torchtrt.Device("cuda:0"),
"enabled_precisions": {torch.float},
"ir": ir,
"pass_through_build_failures": True,
"min_block_size": 1,
"torch_executed_ops": {"torch.ops.aten.add.Tensor"},
}

# Compile the model
if ir == "torch_compile":
torch._dynamo.mark_dynamic(input_bs4, 0, min=4, max=1024)
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
trt_model(input_bs4)
elif ir == "dynamo":
compile_spec["inputs"] = [
torchtrt.Input(
min_shape=(1, 3, 224, 224),
opt_shape=(4, 3, 224, 224),
max_shape=(1024, 3, 224, 224),
dtype=torch.float32,
name="x",
)
]
trt_model = torchtrt.compile(model, **compile_spec)

trt_model(input_bs4)

input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_dynamic_with_fallback_shape_tensor_pass_through model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
2 changes: 1 addition & 1 deletion tests/py/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ pytest-xdist>=3.6.1
pyyaml
tensorrt==10.0.1
timm>=1.0.3
transformers==4.39.3
transformers==4.40.2
--extra-index-url https://pypi.nvidia.com

0 comments on commit caf3a92

Please sign in to comment.