Skip to content

Commit

Permalink
Rebased to main after refit acceleration is merged.
Browse files Browse the repository at this point in the history
  • Loading branch information
cehongwang committed Aug 15, 2024
1 parent bd685ae commit 83d7276
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
10 changes: 5 additions & 5 deletions examples/dynamo/mutable_torchtrt_module_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import torch_tensorrt as torch_trt
import torchvision.models as models

np.random.seed(0)
torch.manual_seed(0)
np.random.seed(5)
torch.manual_seed(5)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# %%
Expand Down Expand Up @@ -76,7 +76,7 @@
from diffusers import DiffusionPipeline

with torch.no_grad():
kwargs = {
settings = {
"use_python_runtime": True,
"enabled_precisions": {torch.float16},
"debug": True,
Expand All @@ -86,7 +86,7 @@
model_id = "runwayml/stable-diffusion-v1-5"
device = "cuda:0"

prompt = "portrait of a woman standing, shuimobysim, wuchangshuo, best quality"
prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, skin spots, acnes, skin blemishes, age spot, glans, (watermark:2),"

pipe = DiffusionPipeline.from_pretrained(
Expand All @@ -95,7 +95,7 @@
pipe.to(device)

# The only extra line you need
pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **kwargs)
pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
image.save("./without_LoRA_mutable.jpg")
Expand Down
16 changes: 11 additions & 5 deletions py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class RefitFlag(Enum):


class RefitState:
_state: RefitFlag = RefitFlag.UNKNOWN
_state: RefitFlag = RefitFlag.NEEDS_RECOMPILE

def set_state(self, state: RefitFlag) -> None:
if isinstance(state, RefitFlag):
Expand Down Expand Up @@ -267,12 +267,14 @@ def refit_gm(self) -> None:
self.original_model.state_dict()
)
)
self.gm = refit_module_weights(self.gm, self.exp_program)
self.gm = refit_module_weights(
self.gm, self.exp_program, use_weight_map_cache=True, in_place=True
)

self.original_model.cpu()
torch.cuda.empty_cache()

def _compile(self) -> None:
def compile(self) -> None:
"""
(Re)compile the TRT graph module using the PyTorch module.
This function should be called whenever the weight structure get changed (shape, more layers...)
Expand Down Expand Up @@ -349,7 +351,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
# Step 3: Refit/recompile accordingly
if self.refit_state.get_state() == RefitFlag.NEEDS_RECOMPILE:
logger.info("(Re)Compiling the engine...")
self._compile()
self.compile()
self.store_state_dict_metadata()
self.refit_state.set_state(RefitFlag.LIVE)

Expand All @@ -360,7 +362,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
except Exception as e:
logger.error(e)
logger.error("Model refit failed. Recompiling the graph module.")
self._compile()
self.compile()
self.store_state_dict_metadata()
self.refit_state.set_state(RefitFlag.LIVE)

Expand All @@ -369,6 +371,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
self.run_info = (args, kwargs, result)
return result

def to(self, device: str):
logger.warning("Original PyTorch model is moved. CPU offload may failed.")
self.orignial_model.to(device)

def __deepcopy__(self, memo: Any) -> Any:
cls = self.__class__
result = cls.__new__(cls)
Expand Down

0 comments on commit 83d7276

Please sign in to comment.