From ff95381c5ff48ffd57423de621906d48267a42fe Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Mon, 25 Mar 2024 11:57:39 -0700
Subject: [PATCH] parent f39e89e3964bc3d6ea3a6989b1e4099e1bb3e6dd
author Dheeraj Peri 1711393059 -0700
committer Dheeraj Peri 1711393072 -0700

chore: minor updates

chore: Fix save failures

chore: minor fixes

chore: remove duplicate bert test case

chore: remove comments

chore: add load api

chore: minor updates

chore: minor updates
---
 .github/workflows/build-test.yml              |   1 +
 core/runtime/TRTEngine.cpp                    |   2 +-
 docsrc/user_guide/saving_models.rst           |  71 +++++----
 py/torch_tensorrt/_compile.py                 | 101 ++++++++++++-
 py/torch_tensorrt/dynamo/_compiler.py         |   9 +-
 py/torch_tensorrt/dynamo/_defaults.py         |   1 -
 py/torch_tensorrt/dynamo/_exporter.py         | 142 +++++++++++-------
 py/torch_tensorrt/dynamo/_settings.py         |   3 -
 .../lowering/test_aten_lowering_passes.py     |   4 +-
 tests/py/dynamo/models/test_export_serde.py   |  64 ++++----
 tests/py/dynamo/models/test_models_export.py  |  63 +-------
 tests/py/dynamo/models/test_output_format.py  |  62 --------
 tests/py/dynamo/models/test_save_load.py      |  50 ++++++
 13 files changed, 312 insertions(+), 261 deletions(-)
 delete mode 100644 tests/py/dynamo/models/test_output_format.py
 create mode 100644 tests/py/dynamo/models/test_save_load.py

diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index 5573bb8d28..839982e295 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -169,6 +169,7 @@ jobs:
           cd tests/py/dynamo
           ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
           ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
+          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/save_load_test_results.xml --ir dynamo models/test_save_load.py
           popd

   tests-py-torch-compile-be:
diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 92e5d7a8ff..7a046f6d94 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -241,7 +241,7 @@ std::string TRTEngine::to_str() const {
            exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
        << std::endl;
   }
-  ss << "  }" << std::endl;
+  ss << "  ]" << std::endl;
   ss << "  Device: " << device_info << std::endl;
   ss << "  Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
   // clang-format on
diff --git a/docsrc/user_guide/saving_models.rst b/docsrc/user_guide/saving_models.rst
index 8379b44f0f..c081438b09 100644
--- a/docsrc/user_guide/saving_models.rst
+++ b/docsrc/user_guide/saving_models.rst
@@ -9,23 +9,22 @@ Saving models compiled with Torch-TensorRT
    :undoc-members:
    :show-inheritance:

-Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.
+Saving models compiled with Torch-TensorRT can be done using the `torch_tensorrt.save` API.

 Dynamo IR
 -------------

-The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.export.ExportedProgram` object by default.
-In addition, we provide a new parameter `output_format` in the `CompilationSetting` object provided before compilation.
-The `output_format` can take the following options
+The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.fx.GraphModule` object by default.
+We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) format by
+specifying the `output_format` flag. Here are the options `output_format` will accept:

-* `exported_program` (or) `ep` : This is the default. Returns an ExportedProgram
-* `torchscript` (or) `ts` : This returns a TorchScript module
-* `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into Torchscript to save to disk.
+* `exported_program` : This is the default. We perform transformations on the graph module first and use `torch.export.save` to save the module.
+* `torchscript` : We trace the graph module via `torch.jit.trace` and save it via `torch.jit.save`.

-a) Torchscript
+a) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-If you set the `output_format="torchscript"`, this will return a `ScriptModule` which can be serialized via torch.jit.save
+Here's an example usage:

 .. code-block:: python

@@ -34,19 +33,17 @@ If you set the `output_format="torchscript"`, this will return a `ScriptModule`
    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-   # trt_ts is a torch.jit.ScriptModule object
-   trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="torchscript")
-   torch.jit.save(trt_ts, "trt_model.ts")
+   # trt_gm is a torch.fx.GraphModule object
+   trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+   torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

    # Later, you can load it and run inference
-   model = torch.jit.load("trt_model.ts").cuda()
+   model = torch.export.load("trt.ep").module()
    model(*inputs)
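+
+The `save` API also exposes an experimental `retrace` flag for `torch.fx.GraphModule` inputs.
+The following is an illustrative sketch based on the `save` docstring in this patch: with
+`retrace=True`, the graph module is re-exported via `torch.export.export(strict=False)`
+instead of the Torch-TensorRT exporter before serialization.
+
+.. code-block:: python
+
+   # Illustrative sketch, reusing trt_gm and inputs from the example above
+   torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs, retrace=True)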
-b) ExportedProgram
+b) Torchscript
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-`torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X is the default return type of Torch-TensorRT compilation.
-
 .. code-block:: python

@@ -54,30 +51,14 @@ b) ExportedProgram
    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-   # trt_ep is a torch.export.ExportedProgram object
-   trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs)
-   torch.export.save(trt_ep, "trt_model.ep")
+   # trt_gm is a torch.fx.GraphModule object
+   trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+   torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)

    # Later, you can load it and run inference
-   model = torch.export.load("trt_model.ep")
+   model = torch.jit.load("trt.ts").cuda()
    model(*inputs)

-c) GraphModule
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
-Internally, partitioning, lowering, conversion phases operate using GraphModule objects. These can be either traced into a Torchscript modules or
-exported into `ExportedProgram` objects
-
-.. code-block:: python
-
-   import torch
-   import torch_tensorrt
-
-   model = MyModel().eval().cuda()
-   inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-   # trt_gm is a torch.fx.GraphModule object
-   trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="graph_module")

 Torchscript IR
 -------------
@@ -99,3 +80,21 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.

    model = torch.jit.load("trt_model.ts").cuda()
    model(*inputs)
+
+Loading the models
+--------------------
+
+We can load torchscript or exported_program models using the `torch.jit.load` and `torch.export.load` APIs from PyTorch directly.
+Alternatively, we provide a light wrapper `torch_tensorrt.load(file_path)` which can load either of the above model types.
+
+Here's an example usage:
+
+.. code-block:: python
+
+   import torch
+   import torch_tensorrt
+
+   # file_path can be a trt.ep or trt.ts file obtained by saving the model (refer to the above section)
+   inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+   model = torch_tensorrt.load("trt.ep").module()
+   model(*inputs)
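+
+Since `torch_tensorrt.load` autodetects the serialization format, the returned object is either a
+`torch.jit.ScriptModule` or a `torch.export.ExportedProgram`. A minimal sketch of handling both
+cases (assuming the `trt.ep` / `trt.ts` files saved in the sections above):
+
+.. code-block:: python
+
+   import torch
+   import torch_tensorrt
+
+   obj = torch_tensorrt.load("trt.ep")  # or "trt.ts"
+   if isinstance(obj, torch.export.ExportedProgram):
+       model = obj.module()  # unwrap the runnable GraphModule
+   else:
+       model = obj  # already a torch.jit.ScriptModule
+   model(*inputs)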
\ No newline at end of file
diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index 9dd816e633..b5b43eb28d 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -6,6 +6,7 @@

 import torch
 import torch.fx
+import torch_tensorrt.dynamo
 import torch_tensorrt.ts
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
@@ -26,10 +27,7 @@

 logger = logging.getLogger(__name__)

-__all__ = [
-    "compile",
-    "convert_method_to_trt_engine",
-]
+__all__ = ["compile", "convert_method_to_trt_engine", "save", "load"]


 def _non_fx_input_interface(
@@ -332,3 +330,98 @@ def convert_method_to_trt_engine(
         )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
+
+
+def load(file_path: str = "") -> Any:
+    """
+    Load either a Torchscript model or an ExportedProgram, autodetecting the type
+    via try/except.
+    """
+    try:
+        ts_module = torch.jit.load(file_path)
+        return ts_module
+    except Exception:
+        pass
+
+    try:
+        exp_program = torch.export.load(file_path)
+        return exp_program
+    except Exception:
+        raise ValueError(
+            "The file doesn't correspond to a Torchscript module or an ExportedProgram. Please verify the file path."
+        )
+
+
+def save(
+    module: Any,
+    file_path: str = "",
+    *,
+    output_format: str = "exported_program",
+    inputs: Optional[Sequence[torch.Tensor]] = None,
+    retrace: bool = False,
+) -> None:
+    """
+    Save the model to disk in the specified output format.
+
+    Arguments:
+        module : Compiled Torch-TensorRT module (options include torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)
+        inputs (torch.Tensor): Torch input tensors
+        output_format: Format to save the model. Options include exported_program | torchscript.
+        retrace: When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
+                 This flag is experimental for now.
+    """
+    module_type = _parse_module_type(module)
+    accepted_formats = {"exported_program", "torchscript"}
+    if inputs is not None and not all(
+        isinstance(input, torch.Tensor) for input in inputs
+    ):
+        raise ValueError(
+            "Not all inputs provided are torch.Tensors. Please provide torch.Tensors as inputs"
+        )
+    if output_format not in accepted_formats:
+        raise ValueError(
+            f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
+        )
+    if not file_path:
+        raise ValueError("File path cannot be empty. Please provide a valid file path")
+
+    if module_type == _ModuleType.nn:
+        raise ValueError(
+            "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types are torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
+        )
+    elif module_type == _ModuleType.ts:
+        if output_format == "exported_program":
+            raise ValueError(
+                "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
+            )
+        else:
+            torch.jit.save(module, file_path)
+    elif module_type == _ModuleType.ep:
+        if output_format == "torchscript":
+            raise ValueError(
+                "Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
+            )
+        else:
+            torch.export.save(module, file_path)
+    elif module_type == _ModuleType.fx:
+        if inputs is None:
+            raise ValueError(
+                "Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.Tensors as inputs to trace and save the model"
+            )
+        # The module type is torch.fx.GraphModule
+        if output_format == "torchscript":
+            module_ts = torch.jit.trace(module, inputs)
+            torch.jit.save(module_ts, file_path)
+        else:
+            if not retrace:
+                from torch_tensorrt.dynamo._exporter import export
+
+                exp_program = export(module, inputs)
+                torch.export.save(exp_program, file_path)
+            else:
+                from torch._higher_order_ops.torchbind import enable_torchbind_tracing
+
+                with enable_torchbind_tracing():
+                    exp_program = torch.export.export(
+                        module, tuple(inputs), strict=False
+                    )
+                    torch.export.save(exp_program, file_path)
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 6312532f1c..b321eabcb2 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -30,7 +30,6 @@
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
-    OUTPUT_FORMAT,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     REFIT,
@@ -48,7 +47,6 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
-from torch_tensorrt.dynamo._exporter import export
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,
@@ -102,9 +100,8 @@ def compile(
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     dryrun: bool = DRYRUN,
     hardware_compatible: bool = HARDWARE_COMPATIBLE,
-    output_format: str = OUTPUT_FORMAT,
     **kwargs: Any,
-) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]:
+) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT

     Takes a existing TorchScript module and a set of settings to configure the compiler
@@ -246,14 +243,12 @@ def compile(
         "dla_global_dram_size": dla_global_dram_size,
         "dryrun": dryrun,
         "hardware_compatible": hardware_compatible,
-        "output_format": output_format,
     }

     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     trt_gm = compile_module(gm, inputs, settings)
-    trt_result = export(trt_gm, torch_inputs, output_format)
-    return trt_result
+    return trt_gm


 def compile_module(
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index ec038c0dba..3d48ab3def 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -26,7 +26,6 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-OUTPUT_FORMAT = "exported_program"


 def default_device() -> Device:
diff --git a/py/torch_tensorrt/dynamo/_exporter.py b/py/torch_tensorrt/dynamo/_exporter.py
index c7e2f37795..e9d166a1cc 100644
--- a/py/torch_tensorrt/dynamo/_exporter.py
+++ b/py/torch_tensorrt/dynamo/_exporter.py
@@ -1,3 +1,4 @@
+import copy
 import operator
 from typing import Any, Dict, Sequence, Tuple, cast

@@ -6,8 +7,11 @@
 from torch._subclasses.fake_tensor import FakeTensor
 from torch.export import ExportedProgram, ExportGraphSignature
 from torch.export.exported_program import (
+    CustomObjArgument,
     InputKind,
     InputSpec,
+    ModuleCallEntry,
+    ModuleCallSignature,
     OutputKind,
     OutputSpec,
     TensorArgument,
 )
@@ -18,27 +22,16 @@
 def export(
     gm: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
-    output_format: str,
 ) -> ExportedProgram:
     """Export the result of TensorRT compilation into the desired output format.

     Arguments:
         gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
         inputs (torch.Tensor): Torch input tensors
-        output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
     """
-    if output_format == "torchscript" or output_format == "ts":
-        return torch.jit.trace(gm, inputs)
-    elif output_format == "exported_program" or output_format == "ep":
-        patched_module = transform(gm, inputs)
-        exp_program = create_trt_exp_program(patched_module)
-        return exp_program
-    elif output_format == "graph_module" or output_format == "fx":
-        return gm
-    else:
-        raise ValueError(
-            f"Invalid output format {output_format} specified. Supported options include exported_program (or) ep | torchscript (or) ts | graph_module (or) fx"
-        )
+    patched_module = transform(gm, inputs)
+    exp_program = create_trt_exp_program(patched_module)
+    return exp_program


 def transform(
@@ -55,6 +48,10 @@ def transform(
     Returns an inlined torch.fx.GraphModule
     """
+    # Make a copy of the graph since this function transforms the input graph and changes its attributes.
+    # This transformed graph is meant to be consumed by `create_trt_exp_program`
+    gm = copy.deepcopy(gm)
+
     # Run shape analysis
     _, outputs_map = partitioning.run_shape_analysis(gm, inputs)

@@ -72,7 +69,9 @@ def transform(
     return gm


-def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule:
+def lift(
+    gm: torch.fx.GraphModule, graph_signature: Any
+) -> Tuple[torch.fx.GraphModule, ExportGraphSignature, Dict[str, Any], Dict[str, Any]]:
     """
     Given an unlifted fx.GraphModule, lift all parameters, buffers into placeholders.
     Arguments:
@@ -86,6 +85,7 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
     # exp_program.state_dict contains parameters and buffers whereas a graph_module's state_dict
     # has all parameters registered as torch.tensors.
     state_dict = gm.state_dict()
+    constants = {}

     fake_mode = detect_fake_mode(
         tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
@@ -100,52 +100,69 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any) -> torch.fx.GraphModule
             break

     # At first the user_inputs are only present in the graph_signature.input_specs and hence non_user_input_idx=0
-    # The input_specs should be of the form [params, buffers, constant_tensors, user_inputs]
+    # The input_specs should be of the form [params, buffers, constant_tensors, custom_obj, user_inputs]
     non_user_input_idx = 0
     for node in gm.graph.nodes:
         if node.op == "get_attr":
-            if node.target not in state_dict:
-                raise ValueError(
-                    f"The get_attr node : {node.name} with target: {node.target} value could not be found in state_dict. Please check the input exported_program's graphmodule parameters."
-                )
-            constant_tensor = state_dict[node.target]
-            input_kind = InputKind.CONSTANT_TENSOR
+            lift_val = None
+            input_kind = None

-            # state_dict has these parameters/buffers as torch.Tensors. We override them as torch.nn.Parameter/torch.Tensors respectively.
-            for name, _ in gm.named_parameters():
-                if node.target == name:
-                    input_kind = InputKind.PARAMETER
-                    state_dict[name] = torch.nn.Parameter(state_dict[name])
-                    break
-            for name, _ in gm.named_buffers():
-                if node.target == name:
-                    input_kind = InputKind.BUFFER
-                    break
+            if node.target not in state_dict:
+                constants[node.target] = getattr(gm, node.target)
+                input_kind = InputKind.CUSTOM_OBJ
+                lift_val = constants[node.target]
+            else:
+                lift_val = state_dict[node.target]
+
+                input_kind = InputKind.CONSTANT_TENSOR
+
+                # state_dict has these parameters/buffers as torch.Tensors. We override them as torch.nn.Parameter/torch.Tensors respectively.
+                for name, _ in gm.named_parameters():
+                    if node.target == name:
+                        input_kind = InputKind.PARAMETER
+                        state_dict[name] = torch.nn.Parameter(state_dict[name])
+                        break
+                for name, _ in gm.named_buffers():
+                    if node.target == name:
+                        input_kind = InputKind.BUFFER
+                        break
+
+            assert lift_val is not None and input_kind is not None

             # Replace get_attr nodes with placeholder nodes and copy metadata.
             with gm.graph.inserting_before(first_user_input):
-                const_placeholder_node = gm.graph.placeholder(node.target)
+                # Ensure name doesn't contain period as it is used for submodules
+                const_placeholder_node = gm.graph.placeholder(
+                    node.target.replace(".", "_")
+                )
                 # Copy the node meta into this new placeholder node
                 const_placeholder_node.meta = node.meta
-                const_placeholder_node.meta["val"] = cast(
-                    FakeTensor,
-                    torch.empty_strided(
-                        tuple(constant_tensor.shape),
-                        tuple([1] * len(constant_tensor.shape)),
-                    ),
-                )
+
+                if isinstance(lift_val, torch.Tensor):
+                    const_placeholder_node.meta["val"] = cast(
+                        FakeTensor,
+                        torch.empty_strided(
+                            tuple(lift_val.shape),
+                            tuple([1] * len(lift_val.shape)),
+                        ),
+                    )

                 node.replace_all_uses_with(const_placeholder_node)
                 gm.graph.erase_node(node)

                 # Add these parameters/buffers/constants to the existing graph signature
                 # before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
+                input_spec_arg = TensorArgument(name=const_placeholder_node.name)
+                if input_kind == InputKind.CUSTOM_OBJ:
+                    input_spec_arg = CustomObjArgument(
+                        name=const_placeholder_node.name, class_fqn=""
+                    )
                 graph_signature.input_specs.insert(
                     non_user_input_idx,
                     InputSpec(
                         kind=input_kind,
-                        arg=TensorArgument(name=const_placeholder_node.name),
+                        arg=input_spec_arg,
                         target=node.target,
                     ),
                 )
@@ -154,7 +171,7 @@ def lift(gm: torch.fx.GraphModule, graph_signature: Any)
     gm.graph.eliminate_dead_code()
     gm.graph.lint()

-    return gm, graph_signature, state_dict
+    return gm, graph_signature, state_dict, constants


 def get_duplicate_nodes(
@@ -292,18 +309,30 @@ def create_trt_exp_program(
         input_specs=input_specs, output_specs=output_specs
     )

+    module_call_graph = [
+        ModuleCallEntry(
+            "",
+            ModuleCallSignature(
+                inputs=[],
+                outputs=[],
+                in_spec=gm.graph._codegen.pytree_info.in_spec,
+                out_spec=gm.graph._codegen.pytree_info.out_spec,
+            ),
+        )
+    ]
+
     # Lift parameters/buffers/constants in the graph
     # torch.export serialization expects them to be lifted
-    gm, trt_graph_signature, state_dict = lift(gm, trt_graph_signature)
+    gm, trt_graph_signature, state_dict, constants = lift(gm, trt_graph_signature)

     trt_exp_program = ExportedProgram(
-        gm,
-        gm.graph,
-        trt_graph_signature,
-        state_dict,
-        {},
-        [],
-        [],
+        root=gm,
+        graph=gm.graph,
+        graph_signature=trt_graph_signature,
+        state_dict=state_dict,
+        range_constraints={},
+        module_call_graph=module_call_graph,
+        constants=constants,
     )

     return trt_exp_program
@@ -330,9 +359,13 @@ def inline_trt_modules(
         num_outputs = len(outputs_map[trt_module_node.name])
         # Insert a call_function node to perform inference on TRT engine
         with gm.graph.inserting_before(trt_module_node):
+            engine_name = f"{name}_engine"
+            setattr(gm, engine_name, trt_module.engine)
+            engine_node = gm.graph.get_attr(engine_name)
+
             trt_node = gm.graph.call_function(
                 torch.ops.tensorrt.execute_engine.default,
-                (trt_module_node.args, trt_module.engine),
+                (trt_module_node.args, engine_node),
             )
             trt_node.meta["val"] = []
             assert num_outputs > 0
@@ -348,6 +381,13 @@ def inline_trt_modules(
                     )
                 )

+            # meta["val"] should be a lighter version of a tensor, e.g. a FakeTensor (with output shape and dtype properties)
+            # Lighter version of a custom_obj is not defined clearly. meta["val"] does not have any type expectations but
+            # for custom object nodes, it should be CustomObjArgument
+            engine_node.meta["val"] = CustomObjArgument(
+                name=engine_node.name, class_fqn=""
+            )
+
         if num_outputs == 1:
             # Insert getitem nodes as outputs (for export serialization to work)
             with gm.graph.inserting_after(trt_node):
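The `lift` pass above rewrites `get_attr` nodes (parameters, buffers, and now TRT engine custom
objects) into placeholder inputs, since `torch.export` serialization expects a lifted graph. A
minimal standalone sketch of the mechanism on a toy module (hypothetical code, for illustration
only; the real pass also fixes up the graph signature, state_dict/constants tables, and metadata):

.. code-block:: python

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("scale", torch.ones(2))

        def forward(self, x):
            return x * self.scale

    gm = torch.fx.symbolic_trace(M())

    # Rewrite every get_attr node into a placeholder: the buffer stops being
    # a module attribute and becomes an explicit graph input.
    for node in list(gm.graph.nodes):
        if node.op == "get_attr":
            with gm.graph.inserting_before(node):
                placeholder = gm.graph.placeholder(node.target.replace(".", "_"))
            node.replace_all_uses_with(placeholder)
            gm.graph.erase_node(node)

    gm.graph.lint()
    gm.recompile()

    # The lifted buffer is now passed in explicitly by the caller.
    out = gm(torch.randn(2), torch.ones(2))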
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index c00b049f45..2420a227d8 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -19,7 +19,6 @@
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
-    OUTPUT_FORMAT,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     REFIT,
@@ -71,7 +70,6 @@ class CompilationSettings:
             TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the ouptut to a file if a string path is specified
         hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
-        output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
     """

     precision: torch.dtype = PRECISION
@@ -99,4 +97,3 @@ class CompilationSettings:
     dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
     dryrun: Union[bool, str] = DRYRUN
     hardware_compatible: bool = HARDWARE_COMPATIBLE
-    output_format: str = OUTPUT_FORMAT
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index bc75a8aa3d..b7c895ec11 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -1,7 +1,6 @@
 import torch
-from torch.testing._internal.common_utils import TestCase, run_tests
-
 import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests

 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -444,6 +443,7 @@ def forward(self, input, weight, bias):
         max_diff = float(
             torch.max(torch.abs(optimized_model_results - torch_model_results))
         )
+
         self.assertAlmostEqual(
             max_diff,
             0,
diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index efa593890e..1a54a28e60 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -42,18 +42,17 @@ def forward(self, x):
     }

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
-    deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
-
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
+    deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
     # Check Pyt and TRT exported program outputs
-    cos_sim = cosine_similarity(model(input), trt_exp_program(input)[0])
+    cos_sim = cosine_similarity(model(input), trt_module(input)[0])
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
     # Check Pyt and deserialized TRT exported program outputs
-    cos_sim = cosine_similarity(model(input), deser_trt_exp_program(input)[0])
+    cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0])
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
@@ -93,12 +92,12 @@ def forward(self, x):
     }

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
-    deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
+    deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
     # Check Pyt and TRT exported program outputs
     outputs_pyt = model(input)
-    outputs_trt = trt_exp_program(input)
+    outputs_trt = trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
         assertions.assertTrue(
@@ -107,7 +106,7 @@ def forward(self, x):
         )

     # Check Pyt and deserialized TRT exported program outputs
-    outputs_trt_deser = deser_trt_exp_program(input)
+    outputs_trt_deser = deser_trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
         assertions.assertTrue(
@@ -145,16 +144,15 @@ def forward(self, x):
             )
         ],
         "ir": ir,
-        "debug": True,
     }

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
-    deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
+    deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
     # Check Pyt and TRT exported program outputs
     outputs_pyt = model(input)
-    outputs_trt = trt_exp_program(input)
+    outputs_trt = trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
         assertions.assertTrue(
@@ -163,7 +161,7 @@ def forward(self, x):
         )

     # Check Pyt and deserialized TRT exported program outputs
-    outputs_trt_deser = deser_trt_exp_program(input)
+    outputs_trt_deser = deser_trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
         assertions.assertTrue(
@@ -207,12 +205,11 @@ def forward(self, x):
     }

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
-    deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
-
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
+    deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
     outputs_pyt = model(input)
-    outputs_trt = trt_exp_program(input)
+    outputs_trt = trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
             msg=f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )

-    outputs_trt_deser = deser_trt_exp_program(input)
+    outputs_trt_deser = deser_trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
         assertions.assertTrue(
@@ -248,19 +245,18 @@ def test_resnet18(ir):
     }

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
-    deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
-
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
+    deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
     outputs_pyt = model(input)
-    outputs_trt = trt_exp_program(input)
+    outputs_trt = trt_module(input)
     cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )

-    outputs_trt_deser = deser_trt_exp_program(input)
+    outputs_trt_deser = deser_trt_module(input)
     cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
     assertions.assertTrue(
@@ -303,12 +299,12 @@ def forward(self, x):
     }

     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
-    trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torch.export.save(trt_exp_program, "/tmp/trt.ep")
-    deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
+    deser_trt_module = torchtrt.load("/tmp/trt.ep").module()

     outputs_pyt = model(input)
-    outputs_trt = trt_exp_program(input)
+    outputs_trt = trt_module(input)

     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
             msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )

-    outputs_trt_deser = deser_trt_exp_program(input)
+    outputs_trt_deser = deser_trt_module(input)
     for idx in range(len(outputs_pyt)):
         cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
         assertions.assertTrue(
diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index fd7b40592a..84f6bf7a36 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -110,11 +110,6 @@ def test_bert_base_uncased(ir):
     model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
     input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
     input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
-    model = (
-        transformers_trace(model, input_names=["input_ids", "attention_mask"])
-        .eval()
-        .cuda()
-    )

     compile_spec = {
         "inputs": [
             torchtrt.Input(
                 input.shape,
                 dtype=input.dtype,
                 format=torch.contiguous_format,
             ),
             torchtrt.Input(
                 input.shape,
                 dtype=input.dtype,
                 format=torch.contiguous_format,
             ),
         ],
         "device": torchtrt.Device("cuda:0"),
         "enabled_precisions": {torch.float},
         "truncate_long_and_double": True,
         "ir": ir,
-        "min_block_size": 10,
+        "min_block_size": 15,
     }
     trt_mod = torchtrt.compile(model, **compile_spec)
     model_outputs = model(input, input2)
     trt_model_outputs = trt_mod(input, input2)
     assertions.assertTrue(
         len(model_outputs) == len(trt_model_outputs),
         msg=f"Number of outputs for BERT model compilation is different with Pytorch {len(model_outputs)} and TensorRT {len(trt_model_outputs)}. Please check the compilation.",
     )
-    for index, key in enumerate(model_outputs):
-        out, trt_out = model_outputs[key], trt_model_outputs[index]
-        cos_sim = cosine_similarity(out, trt_out)
-        assertions.assertTrue(
-            cos_sim > COSINE_THRESHOLD,
-            msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
-        )
-
-    # Clean up model env
-    torch._dynamo.reset()
-
-
-@pytest.mark.unit
-def test_bert_base_uncased(ir):
-    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
-    input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
-    input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda")
-    model = (
-        transformers_trace(model, input_names=["input_ids", "attention_mask"])
-        .eval()
-        .cuda()
-    )
-
-    compile_spec = {
-        "inputs": [
-            torchtrt.Input(
-                input.shape,
-                dtype=input.dtype,
-                format=torch.contiguous_format,
-            ),
-            torchtrt.Input(
-                input.shape,
-                dtype=input.dtype,
-                format=torch.contiguous_format,
-            ),
-        ],
-        "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
-        "truncate_long_and_double": True,
-        "ir": ir,
-        "min_block_size": 10,
-        "torch_executed_ops": {"torch.ops.aten.gelu.default"},
-    }
-    trt_mod = torchtrt.compile(model, **compile_spec)
-    model_outputs = model(input, input2)
-    trt_model_outputs = trt_mod(input, input2)
-    assertions.assertTrue(
-        len(model_outputs) == len(trt_model_outputs),
-        msg=f"Number of outputs for BERT model compilation is different with Pytorch {len(model_outputs)} and TensorRT {len(trt_model_outputs)}. Please check the compilation.",
-    )
-    for index, key in enumerate(model_outputs):
-        out, trt_out = model_outputs[key], trt_model_outputs[index]
+    for key, _ in model_outputs.items():
+        out, trt_out = model_outputs[key], trt_model_outputs[key]
         cos_sim = cosine_similarity(out, trt_out)
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
@@ -203,9 +149,6 @@ def test_bert_base_uncased(ir):

     # Clean up model env
     torch._dynamo.reset()

-    with torch.no_grad():
-        torch.cuda.empty_cache()
-

 @pytest.mark.unit
 def test_resnet18_half(ir):
diff --git a/tests/py/dynamo/models/test_output_format.py b/tests/py/dynamo/models/test_output_format.py
deleted file mode 100644
index 3d2e747ceb..0000000000
--- a/tests/py/dynamo/models/test_output_format.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import unittest
-
-import pytest
-import timm
-import torch
-import torch_tensorrt as torchtrt
-import torchvision.models as models
-from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
-
-assertions = unittest.TestCase()
-
-
-@pytest.mark.unit
-def test_output_format(ir):
-    """
-    This tests output_format type in the compilation setting
-    """
-
-    class MyModule(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
-            self.relu = torch.nn.ReLU()
-
-        def forward(self, x):
-            conv = self.conv(x)
-            relu = self.relu(conv)
-            mul = relu * 0.5
-            return mul
-
-    model = MyModule().eval().cuda()
-    input = torch.randn((1, 3, 224, 224)).to("cuda")
-
-    trt_ep = torchtrt.compile(model, ir="dynamo", inputs=[input], min_block_size=1)
-    assertions.assertTrue(
-        isinstance(trt_ep, torch.export.ExportedProgram),
-        msg=f"test_output_format output type does not match with torch.export.ExportedProgram",
-    )
-
-    trt_ts = torchtrt.compile(
-        model,
-        ir="dynamo",
-        inputs=[input],
-        min_block_size=1,
-        output_format="torchscript",
-    )
-    assertions.assertTrue(
-        isinstance(trt_ts, torch.jit.ScriptModule),
msg=f"test_output_format output type does not match with torch.jit.ScriptModule", - ) - - trt_gm = torchtrt.compile( - model, - ir="dynamo", - inputs=[input], - min_block_size=1, - output_format="graph_module", - ) - assertions.assertTrue( - isinstance(trt_gm, torch.fx.GraphModule), - msg=f"test_output_format output type does not match with torch.fx.GraphModule", - ) diff --git a/tests/py/dynamo/models/test_save_load.py b/tests/py/dynamo/models/test_save_load.py new file mode 100644 index 0000000000..f1cb72684b --- /dev/null +++ b/tests/py/dynamo/models/test_save_load.py @@ -0,0 +1,50 @@ +import unittest + +import pytest +import timm +import torch +import torch_tensorrt as torchtrt +import torchvision.models as models +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = unittest.TestCase() + + +@pytest.mark.unit +def test_save_load_ts(ir): + """ + This tests save/load API on Torchscript format (model still compiled using dynamo workflow) + """ + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + relu = self.relu(conv) + mul = relu * 0.5 + return mul + + model = MyModule().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") + + trt_gm = torchtrt.compile(model, ir=ir, inputs=[input], min_block_size=1) + assertions.assertTrue( + isinstance(trt_gm, torch.fx.GraphModule), + msg=f"test_save_load_ts output type does not match with torch.fx.GraphModule", + ) + outputs_trt = trt_gm(input) + # Save it as torchscript representation + torchtrt.save(trt_gm, "/tmp/trt.ts", output_format="torchscript", inputs=[input]) + + trt_ts_module = torchtrt.load("/tmp/trt.ts") + outputs_trt_deser = trt_ts_module(input) + + cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_load_ts TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + )