Skip to content

Commit

Permalink
[Image generation] Fixed SDXL with LCM's Unet (openvinotoolkit#1210)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Nov 20, 2024
1 parent f02e272 commit 074736b
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 46 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lcm_dreamshaper_cpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ jobs:
run: |
source openvino_lcm_cpp/bin/activate
source ./ov/setupvars.sh
python ./samples/python/text2image/main.py ./models/lcm_dreamshaper_v7/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
python ./samples/python/text2image/heterogeneous_stable_diffusion.py ./models/lcm_dreamshaper_v7/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
env:
PYTHONPATH: ${{ env.build_dir }}

Expand Down Expand Up @@ -137,7 +137,7 @@ jobs:
. "./openvino_lcm_cpp/Scripts/Activate.ps1"
. "${{ env.OV_INSTALL_DIR }}/setupvars.ps1"
$env:Path += "${{ env.build_dir }}\openvino_genai"
python .\samples\python\text2image\main.py .\models\lcm_dreamshaper_v7\FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
python .\samples\python\text2image\heterogeneous_stable_diffusion.py .\models\lcm_dreamshaper_v7\FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting"
env:
PYTHONPATH: ${{ env.build_dir }}

Expand Down
4 changes: 2 additions & 2 deletions samples/python/text2image/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,6 @@ With adapter | Without adapter
![](./lora.bmp) | ![](./baseline.bmp)


# Fuse LoRA adapters into model weights
## Fuse LoRA adapters into model weights

To maximize inference performance using a LoRA adapter, refer to `lora_fuse.py`, which demonstrates fusing the adapter into the model weights. This approach achieves the same performance as the base model without a LoRA adapter but removes the flexibility to switch adapters between generate calls. This mode is ideal when performing multiple generations with the same LoRA adapters and blending alpha parameters, and when model recompilation on adapter changes is feasible. The example outputs the resulting image as `lora.bmp`.
To maximize inference performance using a LoRA adapter, refer to `lora_fuse.py`, which demonstrates fusing the adapter into the model weights. This approach achieves the same performance as the base model without a LoRA adapter but removes the flexibility to switch adapters between generate calls. This mode is ideal when performing multiple generations with the same LoRA adapters and blending alpha parameters, and when model recompilation on adapter changes is feasible. The example outputs the resulting image as `lora.bmp`.
21 changes: 21 additions & 0 deletions src/cpp/src/image_generation/diffusion_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,27 @@ const std::string get_class_name(const std::filesystem::path& root_dir) {
return data["_class_name"].get<std::string>();
}

ov::Tensor get_guidance_scale_embedding(float guidance_scale, uint32_t embedding_dim) {
float w = guidance_scale * 1000;
uint32_t half_dim = embedding_dim / 2;
float emb = std::log(10000) / (half_dim - 1);

ov::Shape embedding_shape = {1, embedding_dim};
ov::Tensor w_embedding(ov::element::f32, embedding_shape);
float* w_embedding_data = w_embedding.data<float>();

for (size_t i = 0; i < half_dim; ++i) {
float temp = std::exp((i * (-emb))) * w;
w_embedding_data[i] = std::sin(temp);
w_embedding_data[i + half_dim] = std::cos(temp);
}

if (embedding_dim % 2 == 1)
w_embedding_data[embedding_dim - 1] = 0;

return w_embedding;
}

} // namespace


Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/image_generation/schedulers/lcm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void LCMScheduler::set_timesteps(size_t num_inference_steps, float strength) {
assert(skipping_step >= 1 && "The combination of `original_steps x strength` is smaller than `num_inference_steps`");

// LCM Inference Steps Schedule
std::reverse(lcm_origin_timesteps.begin(),lcm_origin_timesteps.end());
std::reverse(lcm_origin_timesteps.begin(), lcm_origin_timesteps.end());

using numpy_utils::linspace;
// v1. based on https://github.com/huggingface/diffusers/blame/2a7f43a73bda387385a47a15d7b6fe9be9c65eb2/src/diffusers/schedulers/scheduling_lcm.py#L387
Expand Down
39 changes: 5 additions & 34 deletions src/cpp/src/image_generation/stable_diffusion_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,6 @@
namespace ov {
namespace genai {

namespace {

ov::Tensor get_guidance_scale_embedding(float guidance_scale, uint32_t embedding_dim) {
float w = guidance_scale * 1000;
uint32_t half_dim = embedding_dim / 2;
float emb = std::log(10000) / (half_dim - 1);

ov::Shape embedding_shape = {1, embedding_dim};
ov::Tensor w_embedding(ov::element::f32, embedding_shape);
float* w_embedding_data = w_embedding.data<float>();

for (size_t i = 0; i < half_dim; ++i) {
float temp = std::exp((i * (-emb))) * w;
w_embedding_data[i] = std::sin(temp);
w_embedding_data[i + half_dim] = std::cos(temp);
}

if (embedding_dim % 2 == 1)
w_embedding_data[embedding_dim - 1] = 0;

return w_embedding;
}

} // namespace

class StableDiffusionPipeline : public DiffusionPipeline {
public:
StableDiffusionPipeline(PipelineType pipeline_type, const std::filesystem::path& root_dir) :
Expand Down Expand Up @@ -148,7 +123,7 @@ class StableDiffusionPipeline : public DiffusionPipeline {
void reshape(const int num_images_per_prompt, const int height, const int width, const float guidance_scale) override {
check_image_size(height, width);

const size_t batch_size_multiplier = do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG
const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG
m_clip_text_encoder->reshape(batch_size_multiplier);
m_unet->reshape(num_images_per_prompt * batch_size_multiplier, height, width, m_clip_text_encoder->get_config().max_position_embeddings);
m_vae->reshape(num_images_per_prompt, height, width);
Expand Down Expand Up @@ -203,7 +178,7 @@ class StableDiffusionPipeline : public DiffusionPipeline {
// see https://huggingface.co/docs/diffusers/using-diffusers/write_own_pipeline#deconstruct-the-stable-diffusion-pipeline

const auto& unet_config = m_unet->get_config();
const size_t batch_size_multiplier = do_classifier_free_guidance(generation_config.guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG
const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(generation_config.guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG
const size_t vae_scale_factor = m_vae->get_vae_scale_factor();

if (generation_config.height < 0)
Expand Down Expand Up @@ -245,8 +220,8 @@ class StableDiffusionPipeline : public DiffusionPipeline {
}

if (unet_config.time_cond_proj_dim >= 0) { // LCM
ov::Tensor guidance_scale_embedding = get_guidance_scale_embedding(generation_config.guidance_scale, unet_config.time_cond_proj_dim);
m_unet->set_hidden_states("timestep_cond", guidance_scale_embedding);
ov::Tensor timestep_cond = get_guidance_scale_embedding(generation_config.guidance_scale - 1.0f, unet_config.time_cond_proj_dim);
m_unet->set_hidden_states("timestep_cond", timestep_cond);
}

m_scheduler->set_timesteps(generation_config.num_inference_steps, generation_config.strength);
Expand Down Expand Up @@ -304,10 +279,6 @@ class StableDiffusionPipeline : public DiffusionPipeline {
}

private:
bool do_classifier_free_guidance(float guidance_scale) const {
return m_unet->do_classifier_free_guidance(guidance_scale);
}

void initialize_generation_config(const std::string& class_name) override {
assert(m_unet != nullptr);
assert(m_vae != nullptr);
Expand Down Expand Up @@ -341,7 +312,7 @@ class StableDiffusionPipeline : public DiffusionPipeline {
void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const override {
check_image_size(generation_config.width, generation_config.height);

const bool is_classifier_free_guidance = do_classifier_free_guidance(generation_config.guidance_scale);
const bool is_classifier_free_guidance = m_unet->do_classifier_free_guidance(generation_config.guidance_scale);
const bool is_lcm = m_unet->get_config().time_cond_proj_dim > 0;
const char * const pipeline_name = is_lcm ? "Latent Consistency Model" : "Stable Diffusion";

Expand Down
15 changes: 8 additions & 7 deletions src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class StableDiffusionXLPipeline : public DiffusionPipeline {
void reshape(const int num_images_per_prompt, const int height, const int width, const float guidance_scale) override {
check_image_size(height, width);

const size_t batch_size_multiplier = do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG
const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG
m_clip_text_encoder->reshape(batch_size_multiplier);
m_clip_text_encoder_with_projection->reshape(batch_size_multiplier);
m_unet->reshape(num_images_per_prompt * batch_size_multiplier, height, width, m_clip_text_encoder->get_config().max_position_embeddings);
Expand Down Expand Up @@ -201,7 +201,7 @@ class StableDiffusionXLPipeline : public DiffusionPipeline {
// see https://huggingface.co/docs/diffusers/using-diffusers/write_own_pipeline#deconstruct-the-stable-diffusion-pipeline

const auto& unet_config = m_unet->get_config();
const size_t batch_size_multiplier = do_classifier_free_guidance(generation_config.guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG
const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(generation_config.guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG
const size_t vae_scale_factor = m_vae->get_vae_scale_factor();

if (generation_config.height < 0)
Expand Down Expand Up @@ -376,6 +376,11 @@ class StableDiffusionXLPipeline : public DiffusionPipeline {
m_unet->set_hidden_states("time_ids", add_time_ids_repeated);
}

if (unet_config.time_cond_proj_dim >= 0) { // LCM
ov::Tensor timestep_cond = get_guidance_scale_embedding(generation_config.guidance_scale - 1.0f, unet_config.time_cond_proj_dim);
m_unet->set_hidden_states("timestep_cond", timestep_cond);
}

m_scheduler->set_timesteps(generation_config.num_inference_steps, generation_config.strength);
std::vector<std::int64_t> timesteps = m_scheduler->get_timesteps();

Expand Down Expand Up @@ -430,10 +435,6 @@ class StableDiffusionXLPipeline : public DiffusionPipeline {
}

private:
bool do_classifier_free_guidance(float guidance_scale) const {
return guidance_scale > 1.0f && m_unet->get_config().time_cond_proj_dim < 0;
}

void initialize_generation_config(const std::string& class_name) override {
assert(m_unet != nullptr);
assert(m_vae != nullptr);
Expand Down Expand Up @@ -463,7 +464,7 @@ class StableDiffusionXLPipeline : public DiffusionPipeline {
void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const override {
check_image_size(generation_config.width, generation_config.height);

const bool is_classifier_free_guidance = do_classifier_free_guidance(generation_config.guidance_scale);
const bool is_classifier_free_guidance = m_unet->do_classifier_free_guidance(generation_config.guidance_scale);
const char * const pipeline_name = "Stable Diffusion XL";

OPENVINO_ASSERT(generation_config.prompt_3 == std::nullopt, "Prompt 3 is not used by ", pipeline_name);
Expand Down

0 comments on commit 074736b

Please sign in to comment.