Implemented basic pipeline for Refitting #2886

Merged: 71 commits, Jul 2, 2024
Changes from 55 commits

Commits
dc98f23
Implemented basic pipeline for Refitting
cehongwang Jun 4, 2024
74d458e
Organized code for refitting
cehongwang Jun 4, 2024
c47bef3
Renamed function
cehongwang Jun 4, 2024
869aaad
Supported multi-engine
cehongwang Jun 5, 2024
e4cb669
Support both TRTModules and return a new copy
cehongwang Jun 5, 2024
388dadc
Enabled module saving with settings
cehongwang Jun 5, 2024
f822b28
Enabled three types of runtime. Build an interface for user to easy r…
cehongwang Jun 5, 2024
56eb549
Reorganized the code
cehongwang Jun 5, 2024
94483a8
Added weight type check and number check
cehongwang Jun 6, 2024
4ba84b7
Deleted the save in compilation
cehongwang Jun 6, 2024
ee6f123
deleted more compilation save
cehongwang Jun 6, 2024
bc23ddb
Supported different dtypes. Support all possible layers. Support deep…
cehongwang Jun 7, 2024
501e5d9
Delete the outdated file
cehongwang Jun 7, 2024
bd5fb55
Deleted setting loading
cehongwang Jun 7, 2024
578927c
Fixed bugs when handling multiple engines. Tested with custom module …
cehongwang Jun 7, 2024
82cd252
Fixed dtype bugs
cehongwang Jun 10, 2024
e3cf823
Made a note to add INormalization Layer
cehongwang Jun 10, 2024
c3b0862
Update the unsupported torch module weights
cehongwang Jun 11, 2024
400bcac
Cleaned up the code. Added refitting outcome check
cehongwang Jun 12, 2024
6f08664
Use enums to change dtype from np to trt
cehongwang Jun 12, 2024
2250239
Moved check output to util and added a gated flag
cehongwang Jun 12, 2024
c906d0e
fixed a bug in check_output function. Changed to only check once afte…
cehongwang Jun 13, 2024
51cba6f
reverse the main function
cehongwang Jun 13, 2024
e3576fa
Added support to inline module w/ or w/o graph break
cehongwang Jun 13, 2024
cde8fe9
Added an extra attribute to TRT engine in cpp
cehongwang Jun 13, 2024
b575105
Added an attribute in TorchTRTModule in python.
cehongwang Jun 13, 2024
9923125
Fixed a typo
cehongwang Jun 13, 2024
e25941e
Fixed a bug for inline_module refit
cehongwang Jun 14, 2024
646da9e
Added refit example documentation
cehongwang Jun 14, 2024
924f4a8
Added backward compatibility
cehongwang Jun 14, 2024
3c25a3a
Rename the setting enum
cehongwang Jun 14, 2024
0c9637d
Cleaned up cpp constructors
cehongwang Jun 14, 2024
bb5fdba
Fixed a type of setting storage checking
cehongwang Jun 14, 2024
e47bcb2
Renamed settings to metadata
cehongwang Jun 14, 2024
e6e71ca
Added refit to __init__ of dynamo
cehongwang Jun 14, 2024
cf43a79
Added docstring. Added support for dynamic shape
cehongwang Jun 14, 2024
cfeb6bf
Changed the check_output function to return a boolean
cehongwang Jun 14, 2024
a092229
Changed get_settings to a static method in TorchTensorRTModule
cehongwang Jun 14, 2024
4819a6d
Simplified the code
cehongwang Jun 14, 2024
bd77f22
Added three testcases
cehongwang Jun 14, 2024
1b3a769
Supported torch ops in settings
cehongwang Jun 14, 2024
1456ad9
Updated the example
cehongwang Jun 15, 2024
1acfe31
Wrote 6 test cases for refitting feature to cover different scenarios
cehongwang Jun 15, 2024
d38e422
Fixed a bug in tests
cehongwang Jun 15, 2024
880afde
Delete settings check
cehongwang Jun 15, 2024
2dc5bfa
Fixed a bug of modifying settings in place
cehongwang Jun 15, 2024
410689c
added it to //docsrc/py_api/dynamo.rst so that it gets rendered in th…
cehongwang Jun 17, 2024
381f14a
Added reference to doc
cehongwang Jun 18, 2024
eebe883
Changed the default outputcheck to false
cehongwang Jun 18, 2024
2a3d567
Changed the assertion
cehongwang Jun 18, 2024
003380a
Renamed the imported name
cehongwang Jun 18, 2024
323db97
Renamed
cehongwang Jun 18, 2024
de0ab94
Fixed a bug of serialized info signature
cehongwang Jun 18, 2024
de3da26
Changed the refit condition check
cehongwang Jun 18, 2024
91c6036
Changed the file path in test file
cehongwang Jun 18, 2024
bd43882
Fixed minor format
cehongwang Jun 19, 2024
5ef9af7
Deleted setting repetitions
cehongwang Jun 19, 2024
8882425
Changed min_block_size to 1
cehongwang Jun 19, 2024
7f1f958
Added comments
cehongwang Jun 19, 2024
b8e023d
Merged two if statements
cehongwang Jun 24, 2024
df9cd39
Changed the weight type
cehongwang Jun 26, 2024
b33fa0f
Fixed hardcoded index
cehongwang Jun 26, 2024
911984d
Fixed a typo causing extra overhead
cehongwang Jun 27, 2024
0a1c8ca
Added comments and replaced the index with an enum
cehongwang Jun 27, 2024
257db26
Fixed inline module check
cehongwang Jun 27, 2024
fef6766
Added deprecation warning. Renamed refit flag to make_refitable
cehongwang Jun 28, 2024
d6dbdd4
Merge branch 'main' into refitter-support
cehongwang Jul 1, 2024
7381221
Updated lowering process to conform with latest main branch
cehongwang Jul 1, 2024
e7768f7
Handled default setting usecases
cehongwang Jul 1, 2024
51a03c9
Fixed circular import bugs
cehongwang Jul 1, 2024
33bde0f
Changed deprecated behavior
cehongwang Jul 2, 2024
14 changes: 9 additions & 5 deletions core/runtime/TRTEngine.cpp
@@ -33,14 +33,16 @@ TRTEngine::TRTEngine(
     const RTDevice& cuda_device,
     const std::vector<std::string>& _in_binding_names,
     const std::vector<std::string>& _out_binding_names,
-    bool hardware_compatible)
+    bool hardware_compatible,
+    const std::string& serialized_metadata)
     : TRTEngine(
           "deserialized_trt",
           serialized_engine,
           cuda_device,
           _in_binding_names,
           _out_binding_names,
-          hardware_compatible) {}
+          hardware_compatible,
+          serialized_metadata) {}

 TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
     : TRTEngine(
@@ -49,17 +51,19 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
           RTDevice(serialized_info[DEVICE_IDX]),
           split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
           split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
-          static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) {}
+          static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
+          serialized_info[SERIALIZED_METADATA_IDX]) {}
 
 TRTEngine::TRTEngine(
     const std::string& mod_name,
     const std::string& serialized_engine,
     const RTDevice& cuda_device,
     const std::vector<std::string>& _in_binding_names,
     const std::vector<std::string>& _out_binding_names,
-    bool hardware_compatible) {
+    bool hardware_compatible,
+    const std::string& serialized_metadata) {
   this->hardware_compatible = hardware_compatible;
-
+  this->serialized_metadata = serialized_metadata;
   auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
   TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
   device_info = most_compatible_device.value();
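The delegating constructor above rebuilds an engine from the serialized_info strings, decoding the hardware_compatible flag with std::stoi and now passing the metadata string through untouched. A small Python analogue of that decoding step (index values and the BINDING_DELIM value are illustrative; the real constants live in runtime.h, shown below):

def decode_serialized_info(serialized_info: list) -> dict:
    # Mirrors the delegation above: raw strings in, typed fields out
    DEVICE_IDX, INPUT_IDX, OUTPUT_IDX, HW_IDX, META_IDX = 0, 1, 2, 3, 4
    BINDING_DELIM = "%"  # assumption: the actual delimiter is not shown in this diff
    return {
        "device": serialized_info[DEVICE_IDX],
        "in_binding_names": serialized_info[INPUT_IDX].split(BINDING_DELIM),
        "out_binding_names": serialized_info[OUTPUT_IDX].split(BINDING_DELIM),
        "hardware_compatible": bool(int(serialized_info[HW_IDX])),  # std::stoi then cast
        "serialized_metadata": serialized_info[META_IDX],  # passed through as-is
    }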
8 changes: 6 additions & 2 deletions core/runtime/TRTEngine.h
@@ -35,22 +35,26 @@ struct TRTEngine : torch::CustomClassHolder {
   std::vector<std::string> out_binding_names = {}; // ITO: PYT IDX
 
   bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
+  std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used
+                                   // in compilation
 
   ~TRTEngine();
   TRTEngine(
       const std::string& serialized_engine,
       const RTDevice& cuda_device,
       const std::vector<std::string>& in_binding_names,
       const std::vector<std::string>& out_binding_names,
-      bool hardware_compatible = false);
+      bool hardware_compatible = false,
+      const std::string& serialized_metadata = "");
   TRTEngine(std::vector<std::string> serialized_info);
   TRTEngine(
       const std::string& mod_name,
       const std::string& serialized_engine,
       const RTDevice& cuda_device,
       const std::vector<std::string>& in_binding_names,
       const std::vector<std::string>& out_binding_names,
-      bool hardware_compatible = false);
+      bool hardware_compatible = false,
+      const std::string& serialized_metadata = "");
   TRTEngine& operator=(const TRTEngine& other);
   std::string to_str() const;
   static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
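The header comment describes serialized_metadata as a base64-encoded pickle of the compilation settings. A minimal Python sketch of that encoding, assuming a plain dict stands in for the real settings object (encode_metadata and decode_metadata are illustrative names, not the library's API):

import base64
import pickle

def encode_metadata(settings: dict) -> str:
    # Pickle the settings, then base64-encode so the bytes can live in a
    # plain string field like TRTEngine::serialized_metadata
    return base64.b64encode(pickle.dumps(settings)).decode("utf-8")

def decode_metadata(serialized_metadata: str) -> dict:
    # Reverse the encoding to recover the settings used at compile time
    return pickle.loads(base64.b64decode(serialized_metadata.encode("utf-8")))

settings = {"hardware_compatible": False, "refit": True}
assert decode_metadata(encode_metadata(settings)) == settings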
2 changes: 1 addition & 1 deletion core/runtime/register_jit_hooks.cpp
@@ -102,7 +102,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
               serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
               serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
               serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
-
+              serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
               LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
 
               return serialize_info;
1 change: 1 addition & 0 deletions core/runtime/runtime.h
@@ -25,6 +25,7 @@ typedef enum {
   INPUT_BINDING_NAMES_IDX,
   OUTPUT_BINDING_NAMES_IDX,
   HW_COMPATIBLE_IDX,
+  SERIALIZED_METADATA_IDX,
   SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
 } SerializedInfoIndex;

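With the new SERIALIZED_METADATA_IDX slot, the serialized_info bundle exchanged with the runtime carries one more string. A rough Python mirror of the enum's tail and its use as a length check (members above INPUT_BINDING_NAMES_IDX are collapsed in this hunk, so the absolute values here are illustrative):

from enum import IntEnum, auto

class SerializedInfoIndex(IntEnum):
    INPUT_BINDING_NAMES_IDX = 0  # illustrative base value; earlier members not shown
    OUTPUT_BINDING_NAMES_IDX = auto()
    HW_COMPATIBLE_IDX = auto()
    SERIALIZED_METADATA_IDX = auto()
    SERIALIZATION_LEN = auto()  # never holds data; marks the expected length

serialized_info = [""] * SerializedInfoIndex.SERIALIZATION_LEN
serialized_info[SerializedInfoIndex.HW_COMPATIBLE_IDX] = "1"  # bool stored as "0"/"1"
serialized_info[SerializedInfoIndex.SERIALIZED_METADATA_IDX] = "<base64-pickled settings>"
assert len(serialized_info) == SerializedInfoIndex.SERIALIZATION_LEN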
2 changes: 1 addition & 1 deletion docsrc/py_api/dynamo.rst
@@ -24,7 +24,7 @@ Functions
 
 .. autofunction:: convert_module_to_trt_engine
 
-
+.. autofunction:: refit_module_weights
 
 Classes
 --------
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
@@ -11,3 +11,4 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
 * :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
 * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
+* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
98 changes: 98 additions & 0 deletions examples/dynamo/refit_engine_example.py
@@ -0,0 +1,98 @@
"""
.. _refit_engine_example:

Refit TensorRT Graph Module with Torch-TensorRT
===================================================================

We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights.

In many cases, model weights are updated frequently, such as when applying LoRAs to Stable Diffusion or when constantly A/B testing AI products.
That poses a challenge for TensorRT inference optimization: compiling TensorRT engines takes significant time, so repeated recompilation is highly inefficient.
Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow.

In this tutorial, we are going to walk through:

1. Compiling a PyTorch model to a TensorRT Graph Module
2. Saving and loading a graph module
3. Refitting the graph module
"""

# %%
# Standard Workflow
# -----------------------------

# %%
# Imports and model definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights

np.random.seed(0)
torch.manual_seed(0)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]


# %%
# Compile the module for the first time and save it.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

model = models.resnet18(pretrained=False).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
enabled_precisions = {torch.float}
debug = False
workspace_size = 20 << 30
min_block_size = 0
use_python_runtime = False
torch_executed_ops = {}
trt_gm = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=use_python_runtime,
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
refit=True,
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
# This is only supported when use_python_runtime = False
torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)


# %%
# Refit the module with updated model weights
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Create and export the updated model
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))


compiled_trt_ep = torch_trt.load("./compiled.ep")

# This returns a new module with updated weights
new_trt_gm = refit_module_weights(
compiled_module=compiled_trt_ep,
new_weight_module=exp_program2,
inputs=inputs,
)

# Check the output
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assert torch.allclose(
expected_output, refitted_output, 1e-2, 1e-2
), "Refit Result is not correct. Refit failed"

print("Refit successfully!")

# %%
# Alternative Workflow using Python Runtime
# -----------------------------

# The Python runtime does not currently support engine serialization, so the
# refit has to happen within the same runtime session, as sketched below.
# This use case is most useful when you need to switch between different sets
# of weights in the same runtime, for example when swapping LoRA weights in
# Stable Diffusion.
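
# A minimal sketch of that in-memory flow, reusing exp_program, exp_program2,
# inputs, and min_block_size from above (illustrative, not a separate API):

trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=True,  # engine stays in this process; nothing is saved
    min_block_size=min_block_size,
    refit=True,
)

# Refit in the same runtime, skipping the save/load round trip
new_trt_gm = refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    inputs=inputs,
)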
4 changes: 2 additions & 2 deletions py/torch_tensorrt/_compile.py
@@ -351,9 +351,9 @@ def convert_method_to_trt_engine(
     torchtrt_inputs = prepare_inputs(inputs)
     exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
 
-    return dynamo_convert_module_to_trt_engine(  # type: ignore[no-any-return]
+    return dynamo_convert_module_to_trt_engine(
         exp_program,
-        inputs=inputs,
+        inputs=tuple(inputs),
         enabled_precisions=enabled_precisions_set,
         **kwargs,
     )

Review thread on the inputs=tuple(inputs) line:

Collaborator: Any reason this needs to be a tuple?

Collaborator (Author): The function signature defines it as a Tuple:

    inputs: Tuple[Any, ...],

If I do not change it, I am not allowed to commit.
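As a small illustration of the exchange above: a parameter annotated Tuple[Any, ...] makes mypy reject a list argument, which is why the call site converts. Here takes_tuple is a hypothetical stand-in for dynamo_convert_module_to_trt_engine:

from typing import Any, Tuple

def takes_tuple(inputs: Tuple[Any, ...]) -> int:
    # Same annotation style as the quoted signature: inputs: Tuple[Any, ...]
    return len(inputs)

user_inputs = [1, 2, 3]          # call sites often hold a list
takes_tuple(tuple(user_inputs))  # convert to satisfy the declared tuple type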
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/__init__.py
@@ -9,6 +9,7 @@
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from ._compiler import compile, convert_module_to_trt_engine
     from ._exporter import export
+    from ._refit import refit_module_weights
     from ._settings import CompilationSettings
     from ._SourceIR import SourceIR
     from ._tracer import trace