fix: Repair integer inputs in dynamic shape cases
- Generative inference with HF text-generation models such as gpt2 can fail when graph segmentation causes a symbolic integer to be passed from a Torch subgraph to a TRT engine: the Torch subgraph emits a plain integer, while TRT expects a tensor
- Added logic to both runtime modules to cast such integer inputs to tensors (see the sketch below)
- Added test cases validating generation with both the Python and C++ runtimes
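A minimal sketch of the casting logic added to both runtime modules (the helper name below is illustrative, not part of the commit):

```python
import torch

def _to_cuda_tensors(inputs):
    # Graph segmentation can hand a TRT submodule a plain Python/symbolic int
    # produced by a preceding Torch subgraph; wrap any non-Tensor input in a
    # CUDA tensor so the TRT runtime only ever sees tensors.
    return [
        i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()
        for i in inputs
    ]
```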
gs-olive committed Jun 7, 2024
1 parent 6d8dbfd commit 5906ab4
Showing 5 changed files with 159 additions and 20 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-test-linux.yml
@@ -141,6 +141,7 @@ jobs:
cd tests/py/dynamo
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
+${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_generation_compile.xml models/test_hf_generate_dynamic.py
popd
tests-py-dynamo-serde:
1 change: 1 addition & 0 deletions .github/workflows/build-test-windows.yml
@@ -99,6 +99,7 @@ jobs:
cd tests/py/dynamo
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_fe_test_results.xml --ir dynamo models/test_models_export.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/test_dyn_models.py
+${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_generation_compile.xml models/test_hf_generate_dynamic.py
popd
tests-py-torch-compile-be:
19 changes: 14 additions & 5 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -128,6 +128,11 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self.context = self.engine.create_execution_context()

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+# Ensure inputs are available in all scopes and cast symbolic integers to Tensors
+contiguous_inputs: List[torch.Tensor] = [
+(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
+for i in inputs
+]
with (
torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
if self.profiling_enabled
@@ -174,7 +179,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
self.input_names
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

-contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
for i, input_name in enumerate(self.input_names):
if not contiguous_inputs[i].is_cuda:
logger.warning(
@@ -193,12 +197,17 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
contiguous_inputs[i].dtype == self.input_dtypes[i]
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

+# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
+# as per TensorRT requirements
if self.engine.is_shape_inference_io(input_name):
-# Shape tensor inputs are casted to int32 explicitly.
-# Refer to https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
-inputs_cpu = contiguous_inputs[i].cpu().to(torch.int32)
+# Shape tensor inputs are casted to int64 explicitly
+# Currently Torch CPU pointers are not working; numpy pointers are used instead
+# to refer to underlying memory
+inputs_cpu = (
+contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
+)
self.context.set_tensor_address(
-input_name, inputs_cpu.data_ptr()
+input_name, inputs_cpu.ctypes.data
)
else:
self.context.set_input_shape(
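A condensed sketch of the new shape-tensor path above, with the `self.` prefixes dropped for brevity; `engine`, `context`, `contiguous_inputs`, `i`, and `input_name` are assumed from the surrounding method. The NumPy copy supplies a stable host pointer, since Torch CPU pointers were not usable here:

```python
if engine.is_shape_inference_io(input_name):
    # Shape tensors require a host (CPU) pointer; cast to int64 and copy into
    # a NumPy array, keeping the array referenced so the pointer stays valid
    # until execution.
    inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
    context.set_tensor_address(input_name, inputs_cpu.ctypes.data)
else:
    # Ordinary data tensors keep the GPU path (shapes and device pointers are
    # bound elsewhere in the method).
    ...
```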
26 changes: 11 additions & 15 deletions py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -146,7 +146,7 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
"""Implementation of the forward pass for a TensorRT engine
Args:
-*inputs (torch.Tensor): Inputs to the forward function, must all be ``torch.Tensor``
+*inputs (Union[torch.Tensor, int]): Inputs to the forward function
Returns:
torch.Tensor or Tuple(torch.Tensor): Result of the engine computation
@@ -158,22 +158,18 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
self.input_binding_names
), f"Wrong number of inputs, expected {len(self.input_binding_names)} got {len(inputs)}."

-types: List[bool] = [issubclass(type(i), torch.Tensor) for i in inputs]
-
-try:
-assert all(types)
-except AssertionError:
-
-def is_non_tensor(i: Tuple[Any, bool]) -> bool:
-return not i[1]
-
-non_tensors = [i[0] for i in filter(is_non_tensor, zip(inputs, types))]
-raise RuntimeError(
-f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
-)
+# If the inputs are not Torch Tensors, which can occur in scenarios such as shape tensors
+# which are outputs of a preceding Torch subgraph (where the Dynamic input may be an integer)
+# directly cast the input to a Torch Tensor.
+#
+# This also avoids the need for type-checking inputs, since they are now explicitly casted to Torch tensors
+input_tensors: List[torch.Tensor] = [
+(i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
+for i in inputs
+]

outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
-list(inputs), self.engine
+list(input_tensors), self.engine
)

if len(outputs) == 1:
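A hypothetical call pattern this change enables (`trt_module` and `seq_len` are illustrative names, not from the commit):

```python
# seq_len may arrive as a plain Python/symbolic int emitted by an upstream
# Torch subgraph; forward() now casts it to a CUDA tensor instead of raising
# the old "expects a flattened list of tensors" RuntimeError.
outputs = trt_module(input_ids, seq_len)
```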
132 changes: 132 additions & 0 deletions tests/py/dynamo/models/test_hf_generate_dynamic.py
@@ -0,0 +1,132 @@
import pytest
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from transformers.generation.stopping_criteria import (
EosTokenCriteria,
MaxLengthCriteria,
)


@pytest.mark.unit
def test_dynamic_generation_python_rt():
"""
Tests HuggingFace Generate Code with dynamic shapes
Code Credit: @peri044
"""
# Define tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = (
AutoModelForCausalLM.from_pretrained(
"gpt2", pad_token_id=tokenizer.eos_token_id, use_cache=False
)
.eval()
.to("cuda")
)

# Input prompt
model_inputs = tokenizer(("Repeat " * 128)[:-1], return_tensors="pt").to("cuda")
input_ids = model_inputs["input_ids"]
max_tokens = 40

# Pyt model outputs
greedy_output = model.generate(**model_inputs, max_new_tokens=max_tokens)
print(
"Pytorch model generated text: ",
tokenizer.decode(greedy_output[0], skip_special_tokens=True),
)

# Compile Torch-TRT model
torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
model.forward = torch.compile(
model.forward,
backend="tensorrt",
dynamic=None,
options={
"enabled_precisions": {torch.float},
"torch_executed_ops": {"torch.ops.aten.slice.Tensor"},
"use_python_runtime": True,
"optimization_level": 0,
"min_block_size": 29,
},
)

# Auto-regressive generation loop for greedy search
stopping_criteria = StoppingCriteriaList(
[
MaxLengthCriteria(max_length=max_tokens),
EosTokenCriteria(eos_token_id=tokenizer.eos_token_id),
]
)
while True:
trt_outputs = model(input_ids)
logits = trt_outputs.logits
next_token_logits = logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if stopping_criteria(input_ids, logits).item():
break

# TODO: Add test for correctness


@pytest.mark.unit
def test_dynamic_generation_cpp_rt():
"""
Tests HuggingFace Generate Code with dynamic shapes
Code Credit: @peri044
"""
# Define tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = (
AutoModelForCausalLM.from_pretrained(
"gpt2", pad_token_id=tokenizer.eos_token_id, use_cache=False
)
.eval()
.to("cuda")
)

# Input prompt
model_inputs = tokenizer(("Repeat " * 128)[:-1], return_tensors="pt").to("cuda")
input_ids = model_inputs["input_ids"]
max_tokens = 40

# Pyt model outputs
greedy_output = model.generate(**model_inputs, max_new_tokens=max_tokens)
print(
"Pytorch model generated text: ",
tokenizer.decode(greedy_output[0], skip_special_tokens=True),
)

# Compile Torch-TRT model
torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
model.forward = torch.compile(
model.forward,
backend="tensorrt",
dynamic=None,
options={
"enabled_precisions": {torch.float},
"torch_executed_ops": {"torch.ops.aten.slice.Tensor"},
"use_python_runtime": False,
"optimization_level": 0,
"min_block_size": 29,
},
)

# Auto-regressive generation loop for greedy search
stopping_criteria = StoppingCriteriaList(
[
MaxLengthCriteria(max_length=max_tokens),
EosTokenCriteria(eos_token_id=tokenizer.eos_token_id),
]
)
while True:
trt_outputs = model(input_ids)
logits = trt_outputs.logits
next_token_logits = logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if stopping_criteria(input_ids, logits).item():
break

# TODO: Add test for correctness
