
Commit ff95381

parent f39e89e
author Dheeraj Peri <[email protected]> 1711393059 -0700
committer Dheeraj Peri <[email protected]> 1711393072 -0700

chore: minor updates

chore: Fix save failures

chore: minor fixes

chore: remove duplicate bert test case

chore: remove comments

chore: add load api

chore: minor updates

chore: minor updates
peri044 committed Mar 25, 2024
1 parent f39e89e commit ff95381
Showing 13 changed files with 312 additions and 261 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-test.yml
@@ -169,6 +169,7 @@ jobs:
         cd tests/py/dynamo
         ${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
         ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
+        ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/save_load_test_results.xml --ir dynamo models/test_save_load.py
         popd
 
   tests-py-torch-compile-be:
2 changes: 1 addition & 1 deletion core/runtime/TRTEngine.cpp
@@ -241,7 +241,7 @@ std::string TRTEngine::to_str() const {
              exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
           << std::endl;
     }
-    ss << "  }" << std::endl;
+    ss << "  ]" << std::endl;
     ss << "  Device: " << device_info << std::endl;
     ss << "  Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
     // clang-format on
71 changes: 35 additions & 36 deletions docsrc/user_guide/saving_models.rst
@@ -9,23 +9,22 @@ Saving models compiled with Torch-TensorRT
    :undoc-members:
    :show-inheritance:
 
-Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.
+Saving models compiled with Torch-TensorRT can be done using the `torch_tensorrt.save` API.
 
 Dynamo IR
 -------------
 
-The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.export.ExportedProgram` object by default.
-In addition, we provide a new parameter `output_format` in the `CompilationSetting` object provided before compilation.
-The `output_format` can take the following options
+The output type of `ir=dynamo` compilation of Torch-TensorRT is a `torch.fx.GraphModule` object by default.
+We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) format by
+specifying the `output_format` flag. Here are the options `output_format` accepts:
 
-* `exported_program` (or) `ep` : This is the default. Returns an ExportedProgram
-* `torchscript` (or) `ts` : This returns a TorchScript module
-* `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into Torchscript to save to disk.
+* `exported_program` : This is the default. We perform transformations on the graph module first and use `torch.export.save` to save the module.
+* `torchscript` : We trace the graph module via `torch.jit.trace` and save it via `torch.jit.save`.
 
-a) Torchscript
+a) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-If you set the `output_format="torchscript"`, this will return a `ScriptModule` which can be serialized via torch.jit.save
 Here's an example usage
 
 .. code-block:: python
@@ -34,50 +33,32 @@
 
    model = MyModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-   # trt_ts is a torch.jit.ScriptModule object
-   trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="torchscript")
-   torch.jit.save(trt_ts, "trt_model.ts")
+   # trt_gm is a torch.fx.GraphModule object
+   trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+   torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)
 
    # Later, you can load it and run inference
-   model = torch.jit.load("trt_model.ts").cuda()
+   model = torch.export.load("trt.ep").module()
    model(*inputs)
 
-b) ExportedProgram
+b) Torchscript
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-`torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X is the default return type of Torch-TensorRT compilation.
-
 .. code-block:: python
 
    import torch
    import torch_tensorrt
 
    model = MyModel().eval().cuda()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-   # trt_ep is a torch.export.ExportedProgram object
-   trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs)
-   torch.export.save(trt_ep, "trt_model.ep")
+   # trt_gm is a torch.fx.GraphModule object
+   trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
+   torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
 
    # Later, you can load it and run inference
-   model = torch.export.load("trt_model.ep")
+   model = torch.jit.load("trt.ts").cuda()
    model(*inputs)
 
-c) GraphModule
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
-Internally, partitioning, lowering, conversion phases operate using GraphModule objects. These can be either traced into a Torchscript modules or
-exported into `ExportedProgram` objects
-
-.. code-block:: python
-
-   import torch
-   import torch_tensorrt
-
-   model = MyModel().eval().cuda()
-   inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-   # trt_gm is a torch.fx.GraphModule object
-   trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="graph_module")
 
 Torchscript IR
 -------------
@@ -99,3 +80,21 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
    model = torch.jit.load("trt_model.ts").cuda()
    model(*inputs)
 
+Loading the models
+--------------------
+
+We can load torchscript or exported_program models using the `torch.jit.load` and `torch.export.load` APIs from PyTorch directly.
+Alternatively, we provide a light wrapper, `torch_tensorrt.load(file_path)`, which can load either of the above model types.
+
+Here's an example usage
+
+.. code-block:: python
+
+   import torch
+   import torch_tensorrt
+
+   # file_path can be a trt.ep or trt.ts file obtained by saving the model (refer to the above section)
+   inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+   model = torch_tensorrt.load(<file_path>).module()
+   model(*inputs)
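
Note that `torch_tensorrt.load` returns whichever object the file contains, so calling `.module()` as in the example above only applies when an `ExportedProgram` comes back (e.g. a `trt.ep` file); a Torchscript file loads as an already-callable `ScriptModule`. A minimal sketch handling both return types, assuming the autodetection behavior of the `load` wrapper added in this commit (the file name is a placeholder):

    import torch
    import torch_tensorrt

    loaded = torch_tensorrt.load("trt.ep")  # ScriptModule or ExportedProgram, autodetected
    if isinstance(loaded, torch.export.ExportedProgram):
        model = loaded.module()  # unwrap the ExportedProgram into a callable GraphModule
    else:
        model = loaded  # a ScriptModule is already callable

    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
    model(*inputs)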
101 changes: 97 additions & 4 deletions py/torch_tensorrt/_compile.py
@@ -6,6 +6,7 @@
 
 import torch
 import torch.fx
+import torch_tensorrt.dynamo
 import torch_tensorrt.ts
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
@@ -26,10 +27,7 @@
 
 logger = logging.getLogger(__name__)
 
-__all__ = [
-    "compile",
-    "convert_method_to_trt_engine",
-]
+__all__ = ["compile", "convert_method_to_trt_engine", "save", "load"]
 
 
 def _non_fx_input_interface(
@@ -332,3 +330,98 @@ def convert_method_to_trt_engine(
         )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
+
+
+def load(file_path: str = "") -> Any:
+    """
+    Load either a Torchscript module or an ExportedProgram. The type is
+    autodetected by attempting each loader in turn.
+    """
+    try:
+        ts_module = torch.jit.load(file_path)
+        return ts_module
+    except Exception:
+        pass
+
+    try:
+        exp_program = torch.export.load(file_path)
+        return exp_program
+    except Exception:
+        raise ValueError(
+            "The file doesn't correspond to a Torchscript module or ExportedProgram. Please verify the file path."
+        )
+
+
+def save(
+    module: Any,
+    file_path: str = "",
+    *,
+    output_format: str = "exported_program",
+    inputs: Optional[Sequence[torch.Tensor]] = None,
+    retrace: bool = False,
+) -> None:
+    """
+    Save the model to disk in the specified output format.
+
+    Arguments:
+        module : Compiled Torch-TensorRT module (options include torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)
+        file_path: Path where the module is saved on disk
+        inputs (torch.Tensor): Torch input tensors
+        output_format: Format to save the model. Options include exported_program | torchscript.
+        retrace: When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
+            This flag is experimental for now.
+    """
+    module_type = _parse_module_type(module)
+    accepted_formats = {"exported_program", "torchscript"}
+    if inputs is not None and not all(
+        isinstance(input, torch.Tensor) for input in inputs
+    ):
+        raise ValueError(
+            "Not all inputs provided are torch.Tensors. Please provide torch.Tensors as inputs"
+        )
+    if output_format not in accepted_formats:
+        raise ValueError(
+            f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
+        )
+    if not file_path:
+        raise ValueError("File path cannot be empty. Please provide a valid file path")
+
+    if module_type == _ModuleType.nn:
+        raise ValueError(
+            "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types are torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
+        )
+    elif module_type == _ModuleType.ts:
+        if output_format == "exported_program":
+            raise ValueError(
+                "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
+            )
+        else:
+            torch.jit.save(module, file_path)
+    elif module_type == _ModuleType.ep:
+        if output_format == "torchscript":
+            raise ValueError(
+                "Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
+            )
+        else:
+            torch.export.save(module, file_path)
+    elif module_type == _ModuleType.fx:
+        # The module type is torch.fx.GraphModule
+        if inputs is None:
+            raise ValueError(
+                "Provided model is a torch.fx.GraphModule but the inputs are empty. Please provide valid torch.Tensors as inputs to trace and save the model"
+            )
+        if output_format == "torchscript":
+            module_ts = torch.jit.trace(module, inputs)
+            torch.jit.save(module_ts, file_path)
+        else:
+            if not retrace:
+                from torch_tensorrt.dynamo._exporter import export
+
+                exp_program = export(module, inputs)
+                torch.export.save(exp_program, file_path)
+            else:
+                from torch._higher_order_ops.torchbind import enable_torchbind_tracing
+
+                with enable_torchbind_tracing():
+                    exp_program = torch.export.export(
+                        module, tuple(inputs), strict=False
+                    )
+                    torch.export.save(exp_program, file_path)
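
A short usage sketch of the `save` API defined above, illustrating which module-type and format pairs it accepts; the behavior is read off the validation logic in this diff, and `MyModel` is a placeholder `nn.Module`:

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # placeholder model
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)  # torch.fx.GraphModule

    # GraphModule + exported_program (default): transformed by the dynamo exporter, then torch.export.save
    torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

    # GraphModule + torchscript: traced with torch.jit.trace, then torch.jit.save
    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)

    # Mismatched pairs raise ValueError, e.g. a ScriptModule saved with
    # output_format="exported_program", or a GraphModule passed without inputs.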
9 changes: 2 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -30,7 +30,6 @@
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
     OPTIMIZATION_LEVEL,
-    OUTPUT_FORMAT,
     PASS_THROUGH_BUILD_FAILURES,
     PRECISION,
     REFIT,
@@ -48,7 +47,6 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
-from torch_tensorrt.dynamo._exporter import export
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,
@@ -102,9 +100,8 @@ def compile(
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     dryrun: bool = DRYRUN,
     hardware_compatible: bool = HARDWARE_COMPATIBLE,
-    output_format: str = OUTPUT_FORMAT,
     **kwargs: Any,
-) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]:
+) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
 
     Takes an existing TorchScript module and a set of settings to configure the compiler
@@ -246,14 +243,12 @@
         "dla_global_dram_size": dla_global_dram_size,
         "dryrun": dryrun,
         "hardware_compatible": hardware_compatible,
-        "output_format": output_format,
     }
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     trt_gm = compile_module(gm, inputs, settings)
-    trt_result = export(trt_gm, torch_inputs, output_format)
-    return trt_result
+    return trt_gm
 
 
 def compile_module(
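
With `output_format` removed from the dynamo compile path above, serialization is now an explicit second step rather than a compile-time option. A sketch of the resulting workflow, assuming the APIs introduced in this commit (`MyModel` is a placeholder):

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # placeholder model
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    # compile() now always returns a torch.fx.GraphModule for ir="dynamo" ...
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
    assert isinstance(trt_gm, torch.fx.GraphModule)

    # ... and the serialization format is chosen at save time instead of compile time.
    torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)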
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -26,7 +26,6 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-OUTPUT_FORMAT = "exported_program"
 
 
 def default_device() -> Device:
