Commit

Rebased onto main after refit acceleration was merged.
cehongwang committed Aug 14, 2024
1 parent bd685ae commit 86793b2
Showing 2 changed files with 13 additions and 10 deletions.
11 changes: 6 additions & 5 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -21,8 +21,8 @@
 import torch_tensorrt as torch_trt
 import torchvision.models as models

-np.random.seed(0)
-torch.manual_seed(0)
+np.random.seed(5)
+torch.manual_seed(5)
 inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

 # %%
@@ -76,7 +76,7 @@
 from diffusers import DiffusionPipeline

 with torch.no_grad():
-    kwargs = {
+    settings = {
         "use_python_runtime": True,
         "enabled_precisions": {torch.float16},
         "debug": True,
@@ -86,7 +86,7 @@
     model_id = "runwayml/stable-diffusion-v1-5"
     device = "cuda:0"

-    prompt = "portrait of a woman standing, shuimobysim, wuchangshuo, best quality"
+    prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
     negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, skin spots, acnes, skin blemishes, age spot, glans, (watermark:2),"

     pipe = DiffusionPipeline.from_pretrained(
@@ -95,7 +95,7 @@
     pipe.to(device)

     # The only extra line you need
-    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **kwargs)
+    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

     image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
     image.save("./without_LoRA_mutable.jpg")
@@ -107,5 +107,6 @@
     pipe.unload_lora_weights()

     # Refit triggered
+    pipe.to(device)
     image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
     image.save("./with_LoRA_mutable.jpg")
12 changes: 7 additions & 5 deletions py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -25,7 +25,7 @@ class RefitFlag(Enum):


 class RefitState:
-    _state: RefitFlag = RefitFlag.UNKNOWN
+    _state: RefitFlag = RefitFlag.NEEDS_RECOMPILE

     def set_state(self, state: RefitFlag) -> None:
         if isinstance(state, RefitFlag):
@@ -267,12 +267,14 @@ def refit_gm(self) -> None:
                     self.original_model.state_dict()
                 )
             )
-        self.gm = refit_module_weights(self.gm, self.exp_program)
+        self.gm = refit_module_weights(
+            self.gm, self.exp_program, use_weight_map_cache=True, in_place=True
+        )

         self.original_model.cpu()
         torch.cuda.empty_cache()

-    def _compile(self) -> None:
+    def compile(self) -> None:
         """
         (Re)compile the TRT graph module using the PyTorch module.
         This function should be called whenever the weight structure get changed (shape, more layers...)
@@ -349,7 +351,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         # Step 3: Refit/recompile accordingly
         if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE:
             logger.info("(Re)Compiling the engine...")
-            self._compile()
+            self.compile()
             self.store_state_dict_metadata()
             self.refit_state.set_state(RefitFlag.LIVE)

@@ -360,7 +362,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
             except Exception as e:
                 logger.error(e)
                 logger.error("Model refit failed. Recompiling the graph module.")
-                self._compile()
+                self.compile()
                 self.store_state_dict_metadata()
                 self.refit_state.set_state(RefitFlag.LIVE)

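Taken together, the hunks above touch the refit-or-recompile state machine: the module now starts in NEEDS_RECOMPILE (the new default), forward() rebuilds the engine when recompilation is flagged, and otherwise attempts an in-place weight refit, falling back to a full recompile if the refit fails. The toy sketch below illustrates that dispatch; it is not the real runtime class, and the NEEDS_REFIT flag plus the trigger for the refit branch are assumptions, since only the recompile branch and the failure fallback are visible in this diff.

import logging
from enum import Enum, auto

logger = logging.getLogger(__name__)


class RefitFlag(Enum):
    # UNKNOWN, NEEDS_RECOMPILE, and LIVE appear in the diff; NEEDS_REFIT is assumed.
    UNKNOWN = auto()
    NEEDS_REFIT = auto()
    NEEDS_RECOMPILE = auto()
    LIVE = auto()


class ToyMutableModule:
    """Toy illustration of the dispatch in forward(); not the real runtime class."""

    def __init__(self) -> None:
        # Mirrors the new default in this commit: start in NEEDS_RECOMPILE so the
        # first call always builds an engine.
        self._state = RefitFlag.NEEDS_RECOMPILE

    def compile(self) -> None:
        print("(Re)Compiling the engine...")

    def refit_gm(self) -> None:
        print("Refitting engine weights in place...")

    def forward(self) -> None:
        if self._state == RefitFlag.NEEDS_RECOMPILE:
            self.compile()
        elif self._state in (RefitFlag.NEEDS_REFIT, RefitFlag.UNKNOWN):
            try:
                self.refit_gm()
            except Exception as e:
                logger.error(e)
                logger.error("Model refit failed. Recompiling the graph module.")
                self.compile()
        self._state = RefitFlag.LIVE
        print("Running the compiled engine.")


if __name__ == "__main__":
    module = ToyMutableModule()
    module.forward()                      # first call: (re)compile
    module._state = RefitFlag.NEEDS_REFIT
    module.forward()                      # weight-value change: refit in place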
