diff --git a/docsrc/py_api/dynamo.rst b/docsrc/py_api/dynamo.rst index fce5372d0e..6b4a527663 100644 --- a/docsrc/py_api/dynamo.rst +++ b/docsrc/py_api/dynamo.rst @@ -22,6 +22,8 @@ Functions .. autofunction:: export +.. autofunction:: convert_module_to_trt_engine + Classes diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 1381971047..170e97ca68 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -1,5 +1,6 @@ from __future__ import annotations +import collections.abc import logging from enum import Enum from typing import Any, Callable, List, Optional, Sequence, Set @@ -240,8 +241,6 @@ def compile( return compiled_fx_module elif target_ir == _IRType.dynamo: # Prepare torch and torchtrt inputs - import collections.abc - from torch_tensorrt.dynamo.utils import prepare_inputs if not isinstance(input_list, collections.abc.Sequence): @@ -345,10 +344,19 @@ def convert_method_to_trt_engine( "convert_method_to_trt_engine call is not supported for ir=fx" ) elif target_ir == _IRType.dynamo: + # Prepare torch and torchtrt inputs + from torch_tensorrt.dynamo.utils import prepare_inputs + + if not isinstance(inputs, collections.abc.Sequence): + inputs = [inputs] + + # Export the module + torchtrt_inputs = prepare_inputs(inputs) + exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs) + return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return] - module, + exp_program, inputs=inputs, - method_name=method_name, enabled_precisions=enabled_precisions_set, **kwargs, ) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index ed9a0bb7ae..33364a2897 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -416,8 +416,7 @@ def compile_module( def convert_module_to_trt_engine( - module: torch.fx.GraphModule, - method_name: str = "forward", + exported_program: ExportedProgram, inputs: Optional[Sequence[Input | torch.Tensor]] = None, enabled_precisions: ( Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] @@ -447,15 +446,15 @@ def convert_module_to_trt_engine( calibrator: object = None, allow_shape_tensors: bool = False, ) -> bytes: - """Convert a GraphModule module method to a serialized TensorRT engine + """Convert an ExportedProgram to a serialized TensorRT engine - Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings + Converts an ExportedProgram to a serialized TensorRT engine given a dictionary of conversion settings Arguments: - module (torch.fx.GraphModule): Source module + exported_program (torch.export.ExportedProgram): Source module Keyword Args: - inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using + inputs (Optional[Sequence[torch_tensorrt.Input | torch.Tensor]]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. :: @@ -470,30 +469,11 @@ def convert_module_to_trt_engine( ), # Dynamic input shape for input #2 torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings ] - - method_name (str): Name of method to convert - input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using - torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** :: - - input_signature=([ - torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1 - torch_tensorrt.Input( - min_shape=(1, 224, 224, 3), - opt_shape=(1, 512, 512, 3), - max_shape=(1, 1024, 1024, 3), - dtype=torch.int32 - format=torch.channel_last - ), # Dynamic input shape for input #2 - ], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3 - - device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on :: - - device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True) - + enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use debug (bool): Whether to print out verbose debugging information workspace_size (int): Workspace TRT is allowed to use for the module (0 is default) min_block_size (int): Minimum number of operators per TRT-Engine Block - torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage + torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False) max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine version_compatible (bool): Provide version forward-compatibility for engine plan files @@ -560,13 +540,25 @@ def convert_module_to_trt_engine( "dla_global_dram_size": dla_global_dram_size, } + # Decompose the exported program + exported_program = exported_program.run_decompositions( + get_decompositions(enable_experimental_decompositions) + ) + gm = exported_program.module() + logger.debug("Input graph: " + str(gm.graph)) + + # Apply lowering on the graph module + torch_inputs = get_torch_inputs(input_list, device) + gm = apply_lowering_passes(gm, torch_inputs) + logger.debug("Lowered Input graph: " + str(gm.graph)) + settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) try: - interpreter_result = interpret_module_to_result(module, input_list, settings) + interpreter_result = interpret_module_to_result(gm, input_list, settings) except UnsupportedOperatorException: logger.error( - f"Conversion of module {module} not currently fully supported or convertible!", + f"Conversion of module {gm} not currently fully supported or convertible!", exc_info=True, ) except Exception as e: diff --git a/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py b/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py similarity index 81% rename from tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py rename to tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py index b10cae23fa..00b5dd8b31 100644 --- a/tests/py/dynamo/runtime/test_convert_method_to_trt_engine.py +++ b/tests/py/dynamo/runtime/test_convert_module_to_trt_engine.py @@ -7,7 +7,7 @@ from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity -class TestConvertMethodToTrtEngine(unittest.TestCase): +class TestConvertModuleToTrtEngine(unittest.TestCase): def test_convert_module(self): class Test(torch.nn.Module): def forward(self, a, b): @@ -18,11 +18,11 @@ def forward(self, a, b): # Create a model model = Test() - symbolic_traced_gm = torch.fx.symbolic_trace(model) + exp_program = torch.export.export(model, (input_data_0, input_data_1)) # Convert to TensorRT engine trt_engine_str = torch_tensorrt.dynamo.convert_module_to_trt_engine( - symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1] + exp_program, inputs=(input_data_0, input_data_1) ) # Deserialize the TensorRT engine @@ -30,7 +30,9 @@ def forward(self, a, b): engine = runtime.deserialize_cuda_engine(trt_engine_str) # Inference on TRT Engine - py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"]) + py_trt_module = PythonTorchTensorRTModule( + engine, ["arg0_1", "arg1_1"], ["output0"] + ) trt_output = py_trt_module(input_data_0, input_data_1).cpu() # Inference on PyTorch model