Skip to content

Commit

Permalink
add sdxl and lora-lcm optimization (#12444)
Browse files (browse the repository at this point in the history)
* add sdxl and lora-lcm optimization

* fix openjourney speed drop
  • Loading branch information
JinheTang authored Nov 26, 2024
1 parent 0e23bd7 commit 66bd7ab
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@

import torch
from diffusers import DiffusionPipeline, LCMScheduler
- import ipex_llm
+ from ipex_llm import optimize_model
import argparse
import time


def main(args):
pipe = DiffusionPipeline.from_pretrained(
args.repo_id_or_model_path,
- torch_dtype=torch.bfloat16,
- ).to("xpu")
+ torch_dtype=torch.float16,
+ )
+ pipe = optimize_model(pipe, low_bit=None)
+ pipe.to("xpu")

# set scheduler
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from diffusers import AutoPipelineForText2Image
import torch
- import ipex_llm
+ from ipex_llm import optimize_model
import numpy as np
from PIL import Image
import argparse
Expand All @@ -27,9 +27,11 @@
def main(args):
pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
args.repo_id_or_model_path,
- torch_dtype=torch.bfloat16,
+ torch_dtype=torch.float16,
use_safetensors=True
- ).to("xpu")
+ )
+ pipeline_text2image = optimize_model(pipeline_text2image, low_bit=None)
+ pipeline_text2image.to("xpu")

with torch.inference_mode():
# warmup
Expand Down
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __call__(
# padding head_dim 40 to 64
query, key, value = padding_qkv_hd(query, key, value, 40, 64)

- if use_sdp_non_causal(head_dim, query.device, query.dtype):
+ if use_sdp_non_causal(query.size(-1), query.device, query.dtype):
import xe_addons
hidden_states = xe_addons.sdp_non_causal(query, key.contiguous(),
value.contiguous(), attention_mask)
Expand Down

0 comments on commit 66bd7ab

Please sign in to comment.