Skip to content

Commit

Permalink
Changed the user API to include inputs, arg_inputs, kwarg_inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
cehongwang committed Jul 18, 2024
1 parent 8ecef1c commit 2a17ebd
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 30 deletions.
89 changes: 70 additions & 19 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ def compile(
module: Any,
ir: str = "default",
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
kwarg_inputs: Optional[dict[Any, Any]] = None,
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
**kwargs: Any,
) -> (
Expand Down Expand Up @@ -180,14 +182,16 @@ def compile(
), # Dynamic input shape for input #2
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)
Returns:
torch.nn.Module: Compiled Module, when run it will execute via TensorRT
"""

input_list = inputs if inputs is not None else []
enabled_precisions_set: Set[dtype | torch.dtype] = (
enabled_precisions
Expand Down Expand Up @@ -238,17 +242,33 @@ def compile(
return compiled_fx_module
elif target_ir == _IRType.dynamo:
# Prepare torch and torchtrt inputs
if not arg_inputs and not inputs:
raise AssertionError("'arg_input' or 'input' should not be None.")

elif arg_inputs and inputs:
raise AssertionError(
"'arg_input' and 'input' should not be used at the same time."
)
arg_inputs = inputs or arg_inputs

if kwarg_inputs is None:
kwarg_inputs = {}

from torch_tensorrt.dynamo.utils import prepare_inputs

if not isinstance(input_list, collections.abc.Sequence):
input_list = [input_list]
if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs] # type: ignore

# Export the module
torchtrt_inputs = prepare_inputs(input_list)
exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs)
torchtrt_arg_inputs = prepare_inputs(arg_inputs)
torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)

exp_program = dynamo_trace(
module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
)
trt_graph_module = dynamo_compile(
exp_program,
inputs=torchtrt_inputs,
arg_inputs=torchtrt_arg_inputs,
enabled_precisions=enabled_precisions_set,
**kwargs,
)
Expand Down Expand Up @@ -280,7 +300,9 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
def convert_method_to_trt_engine(
module: Any,
method_name: str = "forward",
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
kwarg_inputs: Optional[dict[Any, Any]] = None,
ir: str = "default",
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
**kwargs: Any,
Expand Down Expand Up @@ -309,6 +331,8 @@ def convert_method_to_trt_engine(
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path)
**kwargs: Additional settings for the specific requested strategy (See submodules for more info)
Expand All @@ -330,7 +354,7 @@ def convert_method_to_trt_engine(
ts_mod = torch.jit.script(module)
serialized_engine: bytes = ts_convert_method_to_trt_engine(
ts_mod,
inputs=inputs,
inputs=arg_inputs,
method_name=method_name,
enabled_precisions=enabled_precisions_set,
**kwargs,
Expand All @@ -342,18 +366,35 @@ def convert_method_to_trt_engine(
)
elif target_ir == _IRType.dynamo:
# Prepare torch and torchtrt inputs
if not arg_inputs and not inputs:
raise AssertionError("'arg_input' or 'input' should not be None.")

elif arg_inputs and inputs:
raise AssertionError(
"'arg_input' and 'input' should not be used at the same time."
)
arg_inputs = arg_inputs or inputs

if kwarg_inputs is None:
kwarg_inputs = {}

from torch_tensorrt.dynamo.utils import prepare_inputs

if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]
if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs] # type: ignore

# Export the module
torchtrt_inputs = prepare_inputs(inputs)
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
torchtrt_arg_inputs = prepare_inputs(arg_inputs)
torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)

exp_program = torch_tensorrt.dynamo.trace(
module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs**kwargs
)

return dynamo_convert_module_to_trt_engine(
exp_program,
inputs=tuple(inputs),
arg_inputs=tuple(arg_inputs),
kwarg_inputs=torchtrt_kwarg_inputs,
enabled_precisions=enabled_precisions_set,
**kwargs,
)
Expand Down Expand Up @@ -408,6 +449,7 @@ def save(
*,
output_format: str = "exported_program",
inputs: Optional[Sequence[torch.Tensor]] = None,
arg_inputs: Optional[Sequence[torch.Tensor]] = None,
kwargs_inputs: Optional[dict[str, Any]] = None,
retrace: bool = False,
) -> None:
Expand All @@ -417,18 +459,27 @@ def save(
Arguments:
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
inputs (torch.Tensor): Torch input tensors
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
output_format (str): Format to save the model. Options include exported_program | torchscript.
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
This flag is experimental for now.
"""
module_type = _parse_module_type(module)
accepted_formats = {"exported_program", "torchscript"}
if inputs is not None and not all(
isinstance(input, torch.Tensor) for input in inputs
if arg_inputs is not None and not all(
isinstance(input, torch.Tensor) for input in arg_inputs
):
raise ValueError(
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
)
if arg_inputs and inputs:
raise AssertionError(
"'arg_input' and 'input' should not be used at the same time."
)

arg_inputs = inputs or arg_inputs

if kwargs_inputs is None:
kwargs_inputs = {}

Expand Down Expand Up @@ -460,27 +511,27 @@ def save(
else:
torch.export.save(module, file_path)
elif module_type == _ModuleType.fx:
if inputs is None:
if arg_inputs is None:
raise ValueError(
"Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
)
# The module type is torch.fx.GraphModule
if output_format == "torchscript":
module_ts = torch.jit.trace(
module, inputs, example_kwarg_inputs=kwargs_inputs
module, arg_inputs, example_kwarg_inputs=kwargs_inputs
)
torch.jit.save(module_ts, file_path)
else:
if not retrace:
from torch_tensorrt.dynamo._exporter import export

exp_program = export(module, inputs, kwargs_inputs)
exp_program = export(module, arg_inputs, kwargs_inputs)
torch.export.save(exp_program, file_path)
else:
from torch._higher_order_ops.torchbind import enable_torchbind_tracing

with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(inputs), kwargs=kwargs_inputs, strict=False
module, tuple(arg_inputs), kwargs=kwargs_inputs, strict=False
)
torch.export.save(exp_program, file_path)
33 changes: 27 additions & 6 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@

def compile(
exported_program: ExportedProgram,
inputs: Sequence[Any],
inputs: Optional[Sequence[Sequence[Any]]] = None,
*,
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
kwarg_inputs: Optional[dict[Any, Any]] = None,
device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
disable_tf32: bool = _defaults.DISABLE_TF32,
Expand Down Expand Up @@ -111,6 +112,8 @@ def compile(
]
Keyword Arguments:
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
Expand Down Expand Up @@ -183,13 +186,21 @@ def compile(
)

# Aliasing inputs to arg_inputs for better understanding
arg_inputs = inputs
if not arg_inputs and not inputs:
raise AssertionError("'arg_input' or 'input' should not be None.")

elif arg_inputs and inputs:
raise AssertionError(
"'arg_input' and 'input' should not be used at the same time."
)

arg_inputs = inputs or arg_inputs

if kwarg_inputs is None:
kwarg_inputs = {}

if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs]
arg_inputs = [arg_inputs] # type: ignore

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
Expand Down Expand Up @@ -481,9 +492,10 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

def convert_module_to_trt_engine(
exported_program: ExportedProgram,
inputs: Sequence[Any],
kwarg_inputs: Optional[dict[str, Any]] = None,
inputs: Optional[Sequence[Sequence[Any]]] = None,
*,
arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
kwarg_inputs: Optional[dict[Any, Any]] = None,
enabled_precisions: (
Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
) = _defaults.ENABLED_PRECISIONS,
Expand Down Expand Up @@ -595,8 +607,17 @@ def convert_module_to_trt_engine(
DeprecationWarning,
stacklevel=2,
)
if not arg_inputs and not inputs:
raise AssertionError("'arg_input' or 'input' should not be None.")

elif arg_inputs and inputs:
raise AssertionError(
"'arg_input' and 'input' should not be used at the same time."
)

arg_inputs = inputs or arg_inputs

arg_input_list = list(inputs) if inputs is not None else []
arg_input_list = list(arg_inputs) if arg_inputs is not None else []
torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
if kwarg_inputs is None:
kwarg_inputs = {}
Expand Down
33 changes: 28 additions & 5 deletions py/torch_tensorrt/dynamo/_tracer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import Any, Tuple
from typing import Any, Optional, Tuple

import torch
from torch.export import Dim, export
Expand All @@ -14,7 +14,10 @@

def trace(
mod: torch.nn.Module | torch.fx.GraphModule,
inputs: Tuple[Any, ...],
inputs: Optional[Tuple[Any, ...]] = None,
*,
arg_inputs: Optional[Tuple[Any, ...]] = None,
kwarg_inputs: Optional[dict[Any, Any]] = None,
**kwargs: Any,
) -> torch.export.ExportedProgram:
"""Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT
Expand All @@ -40,6 +43,8 @@ def trace(
torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
]
Keyword Arguments:
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
device (Union(torch.device, dict)): Target device for TensorRT engines to run on ::
device=torch.device("cuda:0")
Expand All @@ -52,14 +57,27 @@ def trace(
"""

# Set log level at the top of compilation (torch_tensorrt.dynamo)
if not arg_inputs and not inputs:
raise AssertionError("'arg_input' or 'input' should not be None.")

elif arg_inputs and inputs:
raise AssertionError(
"'arg_input' and 'input' should not be used at the same time."
)
arg_inputs = inputs or arg_inputs

if kwarg_inputs is None:
kwarg_inputs = {}

debug = kwargs.get("debug", DEBUG)
if debug:
set_log_level(logger.parent, logging.DEBUG)

device = to_torch_device(kwargs.get("device", default_device()))
torch_inputs = get_torch_inputs(inputs, device)
torch_arg_inputs = get_torch_inputs(arg_inputs, device)
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
dynamic_shapes = []
for input in inputs:
for input in arg_inputs: # type: ignore
if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC:
min_shape = input.shape["min_shape"]
opt_shape = input.shape["opt_shape"]
Expand All @@ -78,6 +96,11 @@ def trace(

dynamic_shapes.append(dynamic_dims)

exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes))
exp_program = export(
mod,
tuple(torch_arg_inputs),
kwargs=torch_kwarg_inputs,
dynamic_shapes=tuple(dynamic_shapes),
)

return exp_program

0 comments on commit 2a17ebd

Please sign in to comment.