
Commit

refactor: Use HF to download the LoRA (#3089)
Signed-off-by: Naren Dasan <[email protected]>
narendasan authored and cehongwang committed Aug 15, 2024
1 parent 83d7276 commit 86891e8
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -87,7 +87,7 @@
device = "cuda:0"

prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, skin spots, acnes, skin blemishes, age spot, glans, (watermark:2),"
negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"

pipe = DiffusionPipeline.from_pretrained(
    model_id, revision="fp16", torch_dtype=torch.float16
@@ -101,7 +101,11 @@
image.save("./without_LoRA_mutable.jpg")

# Standard Huggingface LoRA loading procedure
pipe.load_lora_weights("./moxin.safetensors", adapter_name="lora1")
pipe.load_lora_weights(
"stablediffusionapi/load_lora_embeddings",
weight_name="moxin.safetensors",
adapter_name="lora1",
)
pipe.set_adapters(["lora1"], adapter_weights=[1])
pipe.fuse_lora()
pipe.unload_lora_weights()
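
For reference, here is a minimal, self-contained sketch of the Hub-based LoRA flow this hunk switches to, using only the diffusers calls visible in the diff. The model_id value, the output filename, and the omission of the Torch-TensorRT compilation step from the full example are assumptions made for brevity.

```python
import torch
from diffusers import DiffusionPipeline

# Assumed checkpoint id for illustration; the example script defines its own
# model_id earlier in the file.
model_id = "runwayml/stable-diffusion-v1-5"
device = "cuda:0"

pipe = DiffusionPipeline.from_pretrained(
    model_id, revision="fp16", torch_dtype=torch.float16
)
pipe.to(device)

# LoRA weights now come from a Hugging Face Hub repo instead of a local
# ./moxin.safetensors file; diffusers downloads and caches them automatically.
pipe.load_lora_weights(
    "stablediffusionapi/load_lora_embeddings",
    weight_name="moxin.safetensors",
    adapter_name="lora1",
)
pipe.set_adapters(["lora1"], adapter_weights=[1])
pipe.fuse_lora()
pipe.unload_lora_weights()

prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"

image = pipe(prompt, negative_prompt=negative).images[0]
image.save("./with_LoRA_from_hub.jpg")  # assumed filename, for illustration only
```

Pulling the LoRA from the Hub keeps the example reproducible without requiring a manually downloaded .safetensors file next to the script.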
4 changes: 2 additions & 2 deletions
@@ -371,10 +371,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
        self.run_info = (args, kwargs, result)
        return result

-    def to(self, device: str):
+    def to(self, device: str) -> None:
        logger.warning("Original PyTorch model is moved. CPU offload may fail.")
        self.orignial_model.to(device)

    def __deepcopy__(self, memo: Any) -> Any:
        cls = self.__class__
        result = cls.__new__(cls)
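
The only functional change in this hunk is the added `-> None` return annotation on `to()`. Below is a rough, self-contained illustration of the pattern being annotated; the class and attribute names are invented stand-ins, not the real MutableTorchTensorRTModule implementation.

```python
import logging

import torch

logger = logging.getLogger(__name__)


class _ModelWrapperSketch:
    """Illustrative stand-in for a wrapper that keeps a reference to the
    original PyTorch model (not the real torch_tensorrt class)."""

    def __init__(self, model: torch.nn.Module) -> None:
        self.original_model = model

    def to(self, device: str) -> None:
        # Unlike nn.Module.to(), this mutates in place and returns nothing,
        # which is exactly what the new "-> None" annotation documents.
        logger.warning("Original PyTorch model is moved. CPU offload may fail.")
        self.original_model.to(device)


# Usage: move the wrapped model; note there is no return value to chain on.
wrapper = _ModelWrapperSketch(torch.nn.Linear(4, 4))
wrapper.to("cpu")
```

Annotating `-> None` makes explicit that, unlike `torch.nn.Module.to()`, the wrapper's `to()` is not chainable.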
