From 2a17ebd0af15f4e94f55ff10a2948b1463cae42c Mon Sep 17 00:00:00 2001 From: cehongwang Date: Wed, 17 Jul 2024 17:17:57 -0700 Subject: [PATCH] Changed the user API to include inputs, arg_inputs, kwarg_inputs --- py/torch_tensorrt/_compile.py | 89 +++++++++++++++++++++------ py/torch_tensorrt/dynamo/_compiler.py | 33 ++++++++-- py/torch_tensorrt/dynamo/_tracer.py | 33 ++++++++-- 3 files changed, 125 insertions(+), 30 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index ff7dfebb88..c62211db22 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -148,6 +148,8 @@ def compile( module: Any, ir: str = "default", inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None, + arg_inputs: Optional[Sequence[Sequence[Any]]] = None, + kwarg_inputs: Optional[dict[Any, Any]] = None, enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, **kwargs: Any, ) -> ( @@ -180,7 +182,8 @@ def compile( ), # Dynamic input shape for input #2 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] - + arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs. + kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels ir (str): The requested strategy to compile. 
(Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path) **kwargs: Additional settings for the specific requested strategy (See submodules for more info) @@ -188,6 +191,7 @@ def compile( Returns: torch.nn.Module: Compiled Module, when run it will execute via TensorRT """ + input_list = inputs if inputs is not None else [] enabled_precisions_set: Set[dtype | torch.dtype] = ( enabled_precisions @@ -238,17 +242,33 @@ def compile( return compiled_fx_module elif target_ir == _IRType.dynamo: # Prepare torch and torchtrt inputs + if not arg_inputs and not inputs: + raise AssertionError("'arg_input' or 'input' should not be None.") + + elif arg_inputs and inputs: + raise AssertionError( + "'arg_input' and 'input' should not be used at the same time." + ) + arg_inputs = inputs or arg_inputs + + if kwarg_inputs is None: + kwarg_inputs = {} + from torch_tensorrt.dynamo.utils import prepare_inputs - if not isinstance(input_list, collections.abc.Sequence): - input_list = [input_list] + if not isinstance(arg_inputs, collections.abc.Sequence): + arg_inputs = [arg_inputs] # type: ignore # Export the module - torchtrt_inputs = prepare_inputs(input_list) - exp_program = dynamo_trace(module, torchtrt_inputs, **kwargs) + torchtrt_arg_inputs = prepare_inputs(arg_inputs) + torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs) + + exp_program = dynamo_trace( + module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs + ) trt_graph_module = dynamo_compile( exp_program, - inputs=torchtrt_inputs, + arg_inputs=torchtrt_arg_inputs, enabled_precisions=enabled_precisions_set, **kwargs, ) @@ -280,7 +300,9 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any: def convert_method_to_trt_engine( module: Any, method_name: str = "forward", - inputs: Optional[Sequence[Input | torch.Tensor]] = None, + inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None, + arg_inputs: Optional[Sequence[Sequence[Any]]] = None, + kwarg_inputs: 
Optional[dict[Any, Any]] = None, ir: str = "default", enabled_precisions: Optional[Set[torch.dtype | dtype]] = None, **kwargs: Any, @@ -309,6 +331,8 @@ def convert_method_to_trt_engine( torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] + arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs. + kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels ir (str): The requested strategy to compile. (Options: default - Let Torch-TensorRT decide, ts - TorchScript with scripting path) **kwargs: Additional settings for the specific requested strategy (See submodules for more info) @@ -330,7 +354,7 @@ def convert_method_to_trt_engine( ts_mod = torch.jit.script(module) serialized_engine: bytes = ts_convert_method_to_trt_engine( ts_mod, - inputs=inputs, + inputs=arg_inputs or inputs, method_name=method_name, enabled_precisions=enabled_precisions_set, **kwargs, @@ -342,18 +366,35 @@ def convert_method_to_trt_engine( ) elif target_ir == _IRType.dynamo: # Prepare torch and torchtrt inputs + if not arg_inputs and not inputs: + raise AssertionError("'arg_input' or 'input' should not be None.") + + elif arg_inputs and inputs: + raise AssertionError( + "'arg_input' and 'input' should not be used at the same time." 
+ ) + arg_inputs = arg_inputs or inputs + + if kwarg_inputs is None: + kwarg_inputs = {} + from torch_tensorrt.dynamo.utils import prepare_inputs - if not isinstance(inputs, collections.abc.Sequence): - inputs = [inputs] + if not isinstance(arg_inputs, collections.abc.Sequence): + arg_inputs = [arg_inputs] # type: ignore # Export the module - torchtrt_inputs = prepare_inputs(inputs) - exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs) + torchtrt_arg_inputs = prepare_inputs(arg_inputs) + torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs) + + exp_program = torch_tensorrt.dynamo.trace( + module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs + ) return dynamo_convert_module_to_trt_engine( exp_program, - inputs=tuple(inputs), + arg_inputs=tuple(arg_inputs), + kwarg_inputs=torchtrt_kwarg_inputs, enabled_precisions=enabled_precisions_set, **kwargs, ) @@ -408,6 +449,7 @@ def save( *, output_format: str = "exported_program", inputs: Optional[Sequence[torch.Tensor]] = None, + arg_inputs: Optional[Sequence[torch.Tensor]] = None, kwargs_inputs: Optional[dict[str, Any]] = None, retrace: bool = False, ) -> None: @@ -417,18 +459,27 @@ def save( Arguments: module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module inputs (torch.Tensor): Torch input tensors + arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs. + kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. output_format (str): Format to save the model. Options include exported_program | torchscript. retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it. This flag is experimental for now. 
""" module_type = _parse_module_type(module) accepted_formats = {"exported_program", "torchscript"} - if inputs is not None and not all( - isinstance(input, torch.Tensor) for input in inputs + if arg_inputs is not None and not all( + isinstance(input, torch.Tensor) for input in arg_inputs ): raise ValueError( "Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs" ) + if arg_inputs and inputs: + raise AssertionError( + "'arg_input' and 'input' should not be used at the same time." + ) + + arg_inputs = inputs or arg_inputs + if kwargs_inputs is None: kwargs_inputs = {} @@ -460,27 +511,27 @@ def save( else: torch.export.save(module, file_path) elif module_type == _ModuleType.fx: - if inputs is None: + if arg_inputs is None: raise ValueError( "Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model" ) # The module type is torch.fx.GraphModule if output_format == "torchscript": module_ts = torch.jit.trace( - module, inputs, example_kwarg_inputs=kwargs_inputs + module, arg_inputs, example_kwarg_inputs=kwargs_inputs ) torch.jit.save(module_ts, file_path) else: if not retrace: from torch_tensorrt.dynamo._exporter import export - exp_program = export(module, inputs, kwargs_inputs) + exp_program = export(module, arg_inputs, kwargs_inputs) torch.export.save(exp_program, file_path) else: from torch._higher_order_ops.torchbind import enable_torchbind_tracing with enable_torchbind_tracing(): exp_program = torch.export.export( - module, tuple(inputs), kwargs=kwargs_inputs, strict=False + module, tuple(arg_inputs), kwargs=kwargs_inputs, strict=False ) torch.export.save(exp_program, file_path) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 66ecbcc6f7..a583b57f7a 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -48,8 +48,9 @@ def compile( exported_program: 
ExportedProgram, - inputs: Sequence[Any], + inputs: Optional[Sequence[Sequence[Any]]] = None, *, + arg_inputs: Optional[Sequence[Sequence[Any]]] = None, kwarg_inputs: Optional[dict[Any, Any]] = None, device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE, disable_tf32: bool = _defaults.DISABLE_TF32, @@ -111,6 +112,8 @@ def compile( ] Keyword Arguments: + arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs. + kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on :: device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True) @@ -183,13 +186,21 @@ def compile( ) # Aliasing inputs to arg_inputs for better understanding - arg_inputs = inputs + if not arg_inputs and not inputs: + raise AssertionError("'arg_input' or 'input' should not be None.") + + elif arg_inputs and inputs: + raise AssertionError( + "'arg_input' and 'input' should not be used at the same time." 
+ ) + + arg_inputs = inputs or arg_inputs if kwarg_inputs is None: kwarg_inputs = {} if not isinstance(arg_inputs, collections.abc.Sequence): - arg_inputs = [arg_inputs] + arg_inputs = [arg_inputs] # type: ignore # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) @@ -481,9 +492,10 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: def convert_module_to_trt_engine( exported_program: ExportedProgram, - inputs: Sequence[Any], - kwarg_inputs: Optional[dict[str, Any]] = None, + inputs: Optional[Sequence[Sequence[Any]]] = None, *, + arg_inputs: Optional[Sequence[Sequence[Any]]] = None, + kwarg_inputs: Optional[dict[Any, Any]] = None, enabled_precisions: ( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] ) = _defaults.ENABLED_PRECISIONS, @@ -595,8 +607,17 @@ def convert_module_to_trt_engine( DeprecationWarning, stacklevel=2, ) + if not arg_inputs and not inputs: + raise AssertionError("'arg_input' or 'input' should not be None.") + + elif arg_inputs and inputs: + raise AssertionError( + "'arg_input' and 'input' should not be used at the same time." 
+ ) + + arg_inputs = inputs or arg_inputs - arg_input_list = list(inputs) if inputs is not None else [] + arg_input_list = list(arg_inputs) if arg_inputs is not None else [] torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set() if kwarg_inputs is None: kwarg_inputs = {} diff --git a/py/torch_tensorrt/dynamo/_tracer.py b/py/torch_tensorrt/dynamo/_tracer.py index e1b89886ca..70a0f0752f 100644 --- a/py/torch_tensorrt/dynamo/_tracer.py +++ b/py/torch_tensorrt/dynamo/_tracer.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Tuple +from typing import Any, Optional, Tuple import torch from torch.export import Dim, export @@ -14,7 +14,10 @@ def trace( mod: torch.nn.Module | torch.fx.GraphModule, - inputs: Tuple[Any, ...], + inputs: Optional[Tuple[Any, ...]] = None, + *, + arg_inputs: Optional[Tuple[Any, ...]] = None, + kwarg_inputs: Optional[dict[Any, Any]] = None, **kwargs: Any, ) -> torch.export.ExportedProgram: """Exports a ``torch.export.ExportedProgram`` from a ``torch.nn.Module`` or ``torch.fx.GraphModule`` specifically targeting being compiled with Torch-TensorRT @@ -40,6 +43,8 @@ def trace( torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] Keyword Arguments: + arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs. + kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. device (Union(torch.device, dict)): Target device for TensorRT engines to run on :: device=torch.device("cuda:0") @@ -52,14 +57,27 @@ def trace( """ # Set log level at the top of compilation (torch_tensorrt.dynamo) + if not arg_inputs and not inputs: + raise AssertionError("'arg_input' or 'input' should not be None.") + + elif arg_inputs and inputs: + raise AssertionError( + "'arg_input' and 'input' should not be used at the same time." 
+ ) + arg_inputs = inputs or arg_inputs + + if kwarg_inputs is None: + kwarg_inputs = {} + debug = kwargs.get("debug", DEBUG) if debug: set_log_level(logger.parent, logging.DEBUG) device = to_torch_device(kwargs.get("device", default_device())) - torch_inputs = get_torch_inputs(inputs, device) + torch_arg_inputs = get_torch_inputs(arg_inputs, device) + torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) dynamic_shapes = [] - for input in inputs: + for input in arg_inputs: # type: ignore if isinstance(input, Input) and input.shape_mode == Input._ShapeMode.DYNAMIC: min_shape = input.shape["min_shape"] opt_shape = input.shape["opt_shape"] @@ -78,6 +96,11 @@ def trace( dynamic_shapes.append(dynamic_dims) - exp_program = export(mod, tuple(torch_inputs), dynamic_shapes=tuple(dynamic_shapes)) + exp_program = export( + mod, + tuple(torch_arg_inputs), + kwargs=torch_kwarg_inputs, + dynamic_shapes=tuple(dynamic_shapes), + ) return exp_program