From 074736ba445e6cf84316a7f8016b8681f27d394a Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Wed, 13 Nov 2024 22:54:33 +0400 Subject: [PATCH] [Image generation] Fixed SDXL with LCM's Unet (#1210) --- .github/workflows/lcm_dreamshaper_cpp.yml | 4 +- samples/python/text2image/README.md | 4 +- .../image_generation/diffusion_pipeline.hpp | 21 ++++++++++ .../src/image_generation/schedulers/lcm.cpp | 2 +- .../stable_diffusion_pipeline.hpp | 39 +++---------------- .../stable_diffusion_xl_pipeline.hpp | 15 +++---- 6 files changed, 39 insertions(+), 46 deletions(-) diff --git a/.github/workflows/lcm_dreamshaper_cpp.yml b/.github/workflows/lcm_dreamshaper_cpp.yml index 9da36e49c1..f67a4e506e 100644 --- a/.github/workflows/lcm_dreamshaper_cpp.yml +++ b/.github/workflows/lcm_dreamshaper_cpp.yml @@ -76,7 +76,7 @@ jobs: run: | source openvino_lcm_cpp/bin/activate source ./ov/setupvars.sh - python ./samples/python/text2image/main.py ./models/lcm_dreamshaper_v7/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting" + python ./samples/python/text2image/heterogeneous_stable_diffusion.py ./models/lcm_dreamshaper_v7/FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting" env: PYTHONPATH: ${{ env.build_dir }} @@ -137,7 +137,7 @@ jobs: . "./openvino_lcm_cpp/Scripts/Activate.ps1" . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1" $env:Path += "${{ env.build_dir }}\openvino_genai" - python .\samples\python\text2image\main.py .\models\lcm_dreamshaper_v7\FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting" + python .\samples\python\text2image\heterogeneous_stable_diffusion.py .\models\lcm_dreamshaper_v7\FP16 "cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting" env: PYTHONPATH: ${{ env.build_dir }} diff --git a/samples/python/text2image/README.md b/samples/python/text2image/README.md index d8dc23d0fa..675d39d9a5 100644 --- a/samples/python/text2image/README.md +++ b/samples/python/text2image/README.md @@ -63,6 +63,6 @@ With adapter | Without adapter ![](./lora.bmp) | ![](./baseline.bmp) -# Fuse LoRA adapters into model weights +## Fuse LoRA adapters into model weights -To maximize inference performance using a LoRA adapter, refer to `lora_fuse.py`, which demonstrates fusing the adapter into the model weights. This approach achieves the same performance as the base model without a LoRA adapter but removes the flexibility to switch adapters between generate calls. This mode is ideal when performing multiple generations with the same LoRA adapters and blending alpha parameters, and when model recompilation on adapter changes is feasible. The example outputs the resulting image as `lora.bmp`. \ No newline at end of file +To maximize inference performance using a LoRA adapter, refer to `lora_fuse.py`, which demonstrates fusing the adapter into the model weights. This approach achieves the same performance as the base model without a LoRA adapter but removes the flexibility to switch adapters between generate calls. This mode is ideal when performing multiple generations with the same LoRA adapters and blending alpha parameters, and when model recompilation on adapter changes is feasible. The example outputs the resulting image as `lora.bmp`. diff --git a/src/cpp/src/image_generation/diffusion_pipeline.hpp b/src/cpp/src/image_generation/diffusion_pipeline.hpp index 2213f3711e..af7459a2da 100644 --- a/src/cpp/src/image_generation/diffusion_pipeline.hpp +++ b/src/cpp/src/image_generation/diffusion_pipeline.hpp @@ -36,6 +36,27 @@ const std::string get_class_name(const std::filesystem::path& root_dir) { return data["_class_name"].get(); } +ov::Tensor get_guidance_scale_embedding(float guidance_scale, uint32_t embedding_dim) { + float w = guidance_scale * 1000; + uint32_t half_dim = embedding_dim / 2; + float emb = std::log(10000) / (half_dim - 1); + + ov::Shape embedding_shape = {1, embedding_dim}; + ov::Tensor w_embedding(ov::element::f32, embedding_shape); + float* w_embedding_data = w_embedding.data(); + + for (size_t i = 0; i < half_dim; ++i) { + float temp = std::exp((i * (-emb))) * w; + w_embedding_data[i] = std::sin(temp); + w_embedding_data[i + half_dim] = std::cos(temp); + } + + if (embedding_dim % 2 == 1) + w_embedding_data[embedding_dim - 1] = 0; + + return w_embedding; +} + } // namespace diff --git a/src/cpp/src/image_generation/schedulers/lcm.cpp b/src/cpp/src/image_generation/schedulers/lcm.cpp index 7d22520639..d3afcd6300 100644 --- a/src/cpp/src/image_generation/schedulers/lcm.cpp +++ b/src/cpp/src/image_generation/schedulers/lcm.cpp @@ -99,7 +99,7 @@ void LCMScheduler::set_timesteps(size_t num_inference_steps, float strength) { assert(skipping_step >= 1 && "The combination of `original_steps x strength` is smaller than `num_inference_steps`"); // LCM Inference Steps Schedule - std::reverse(lcm_origin_timesteps.begin(),lcm_origin_timesteps.end()); + std::reverse(lcm_origin_timesteps.begin(), lcm_origin_timesteps.end()); using numpy_utils::linspace; // v1. based on https://github.com/huggingface/diffusers/blame/2a7f43a73bda387385a47a15d7b6fe9be9c65eb2/src/diffusers/schedulers/scheduling_lcm.py#L387 diff --git a/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp b/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp index 92367d82a2..b8517a1476 100644 --- a/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp +++ b/src/cpp/src/image_generation/stable_diffusion_pipeline.hpp @@ -18,31 +18,6 @@ namespace ov { namespace genai { -namespace { - -ov::Tensor get_guidance_scale_embedding(float guidance_scale, uint32_t embedding_dim) { - float w = guidance_scale * 1000; - uint32_t half_dim = embedding_dim / 2; - float emb = std::log(10000) / (half_dim - 1); - - ov::Shape embedding_shape = {1, embedding_dim}; - ov::Tensor w_embedding(ov::element::f32, embedding_shape); - float* w_embedding_data = w_embedding.data(); - - for (size_t i = 0; i < half_dim; ++i) { - float temp = std::exp((i * (-emb))) * w; - w_embedding_data[i] = std::sin(temp); - w_embedding_data[i + half_dim] = std::cos(temp); - } - - if (embedding_dim % 2 == 1) - w_embedding_data[embedding_dim - 1] = 0; - - return w_embedding; -} - -} // namespace - class StableDiffusionPipeline : public DiffusionPipeline { public: StableDiffusionPipeline(PipelineType pipeline_type, const std::filesystem::path& root_dir) : @@ -148,7 +123,7 @@ class StableDiffusionPipeline : public DiffusionPipeline { void reshape(const int num_images_per_prompt, const int height, const int width, const float guidance_scale) override { check_image_size(height, width); - const size_t batch_size_multiplier = do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG + const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG m_clip_text_encoder->reshape(batch_size_multiplier); m_unet->reshape(num_images_per_prompt * batch_size_multiplier, height, width, m_clip_text_encoder->get_config().max_position_embeddings); m_vae->reshape(num_images_per_prompt, height, width); @@ -203,7 +178,7 @@ class StableDiffusionPipeline : public DiffusionPipeline { // see https://huggingface.co/docs/diffusers/using-diffusers/write_own_pipeline#deconstruct-the-stable-diffusion-pipeline const auto& unet_config = m_unet->get_config(); - const size_t batch_size_multiplier = do_classifier_free_guidance(generation_config.guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG + const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(generation_config.guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG const size_t vae_scale_factor = m_vae->get_vae_scale_factor(); if (generation_config.height < 0) @@ -245,8 +220,8 @@ class StableDiffusionPipeline : public DiffusionPipeline { } if (unet_config.time_cond_proj_dim >= 0) { // LCM - ov::Tensor guidance_scale_embedding = get_guidance_scale_embedding(generation_config.guidance_scale, unet_config.time_cond_proj_dim); - m_unet->set_hidden_states("timestep_cond", guidance_scale_embedding); + ov::Tensor timestep_cond = get_guidance_scale_embedding(generation_config.guidance_scale - 1.0f, unet_config.time_cond_proj_dim); + m_unet->set_hidden_states("timestep_cond", timestep_cond); } m_scheduler->set_timesteps(generation_config.num_inference_steps, generation_config.strength); @@ -304,10 +279,6 @@ class StableDiffusionPipeline : public DiffusionPipeline { } private: - bool do_classifier_free_guidance(float guidance_scale) const { - return m_unet->do_classifier_free_guidance(guidance_scale); - } - void initialize_generation_config(const std::string& class_name) override { assert(m_unet != nullptr); assert(m_vae != nullptr); @@ -341,7 +312,7 @@ class StableDiffusionPipeline : public DiffusionPipeline { void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const override { check_image_size(generation_config.width, generation_config.height); - const bool is_classifier_free_guidance = do_classifier_free_guidance(generation_config.guidance_scale); + const bool is_classifier_free_guidance = m_unet->do_classifier_free_guidance(generation_config.guidance_scale); const bool is_lcm = m_unet->get_config().time_cond_proj_dim > 0; const char * const pipeline_name = is_lcm ? "Latent Consistency Model" : "Stable Diffusion"; diff --git a/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp b/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp index 42ee49a19d..b709c58f47 100644 --- a/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp +++ b/src/cpp/src/image_generation/stable_diffusion_xl_pipeline.hpp @@ -152,7 +152,7 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { void reshape(const int num_images_per_prompt, const int height, const int width, const float guidance_scale) override { check_image_size(height, width); - const size_t batch_size_multiplier = do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG + const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG m_clip_text_encoder->reshape(batch_size_multiplier); m_clip_text_encoder_with_projection->reshape(batch_size_multiplier); m_unet->reshape(num_images_per_prompt * batch_size_multiplier, height, width, m_clip_text_encoder->get_config().max_position_embeddings); @@ -201,7 +201,7 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { // see https://huggingface.co/docs/diffusers/using-diffusers/write_own_pipeline#deconstruct-the-stable-diffusion-pipeline const auto& unet_config = m_unet->get_config(); - const size_t batch_size_multiplier = do_classifier_free_guidance(generation_config.guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG + const size_t batch_size_multiplier = m_unet->do_classifier_free_guidance(generation_config.guidance_scale) ? 2 : 1; // Unet accepts 2x batch in case of CFG const size_t vae_scale_factor = m_vae->get_vae_scale_factor(); if (generation_config.height < 0) @@ -376,6 +376,11 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { m_unet->set_hidden_states("time_ids", add_time_ids_repeated); } + if (unet_config.time_cond_proj_dim >= 0) { // LCM + ov::Tensor timestep_cond = get_guidance_scale_embedding(generation_config.guidance_scale - 1.0f, unet_config.time_cond_proj_dim); + m_unet->set_hidden_states("timestep_cond", timestep_cond); + } + m_scheduler->set_timesteps(generation_config.num_inference_steps, generation_config.strength); std::vector timesteps = m_scheduler->get_timesteps(); @@ -430,10 +435,6 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { } private: - bool do_classifier_free_guidance(float guidance_scale) const { - return guidance_scale > 1.0f && m_unet->get_config().time_cond_proj_dim < 0; - } - void initialize_generation_config(const std::string& class_name) override { assert(m_unet != nullptr); assert(m_vae != nullptr); @@ -463,7 +464,7 @@ class StableDiffusionXLPipeline : public DiffusionPipeline { void check_inputs(const ImageGenerationConfig& generation_config, ov::Tensor initial_image) const override { check_image_size(generation_config.width, generation_config.height); - const bool is_classifier_free_guidance = do_classifier_free_guidance(generation_config.guidance_scale); + const bool is_classifier_free_guidance = m_unet->do_classifier_free_guidance(generation_config.guidance_scale); const char * const pipeline_name = "Stable Diffusion XL"; OPENVINO_ASSERT(generation_config.prompt_3 == std::nullopt, "Prompt 3 is not used by ", pipeline_name);