From 37e0ab8c64eb8219e32cf546ac2aa570ed3d1f82 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 27 Oct 2024 10:01:04 +0100
Subject: [PATCH] Stable diffusion 3.5 support. (#2578)

* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.
---
 .../examples/stable-diffusion-3/clip.rs       |  50 ++++-
 .../examples/stable-diffusion-3/main.rs       | 198 +++++++++++-------
 .../examples/stable-diffusion-3/sampling.rs   |   2 +-
 candle-transformers/src/models/mmdit/model.rs |  14 ++
 .../src/models/mmdit/projections.rs           |  30 ++-
 5 files changed, 209 insertions(+), 85 deletions(-)

diff --git a/candle-examples/examples/stable-diffusion-3/clip.rs b/candle-examples/examples/stable-diffusion-3/clip.rs
index 77263d968c..d198366a83 100644
--- a/candle-examples/examples/stable-diffusion-3/clip.rs
+++ b/candle-examples/examples/stable-diffusion-3/clip.rs
@@ -1,6 +1,7 @@
 use anyhow::{Error as E, Ok, Result};
 use candle::{DType, IndexOp, Module, Tensor, D};
 use candle_transformers::models::{stable_diffusion, t5};
+use std::path::PathBuf;
 use tokenizers::tokenizer::Tokenizer;
 
 struct ClipWithTokenizer {
@@ -130,6 +131,53 @@ pub struct StableDiffusion3TripleClipWithTokenizer {
 }
 
 impl StableDiffusion3TripleClipWithTokenizer {
+    pub fn new_split(
+        clip_g_file: &PathBuf,
+        clip_l_file: &PathBuf,
+        t5xxl_file: &PathBuf,
+        device: &candle::Device,
+    ) -> Result<Self> {
+        let vb_clip_g = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_g_file], DType::F16, device)?
+        };
+        let vb_clip_l = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)?
+        };
+        let vb_t5 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F32, device)?
+        };
+        let max_position_embeddings = 77usize;
+        let clip_l = ClipWithTokenizer::new(
+            vb_clip_l,
+            stable_diffusion::clip::Config::sdxl(),
+            "openai/clip-vit-large-patch14",
+            max_position_embeddings,
+        )?;
+
+        let text_projection =
+            candle_nn::linear_no_bias(1280, 1280, vb_clip_g.pp("text_projection"))?;
+
+        let clip_g = ClipWithTokenizer::new(
+            vb_clip_g,
+            stable_diffusion::clip::Config::sdxl2(),
+            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
+            max_position_embeddings,
+        )?;
+
+        // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5.
+        // This is a temporary workaround until the T5 implementation is updated to support fp16.
+        // Also see:
+        // https://github.com/huggingface/candle/issues/2480
+        // https://github.com/huggingface/candle/pull/2481
+        let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?;
+        Ok(Self {
+            clip_l,
+            clip_g,
+            clip_g_text_projection: text_projection,
+            t5,
+        })
+    }
+
     pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
         let max_position_embeddings = 77usize;
         let clip_l = ClipWithTokenizer::new(
@@ -158,7 +206,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
         // https://github.com/huggingface/candle/issues/2480
         // https://github.com/huggingface/candle/pull/2481
         let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;
-
         Ok(Self {
             clip_l,
             clip_g,
@@ -195,7 +242,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
             .encode_text_to_embedding(prompt, device)?
            .to_dtype(DType::F16)?;
         let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;
-
         Ok((context, y))
     }
 }
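A note on the hunk above: new_split exists because the SD 3.5 repos ship the three text encoders as separate safetensors files rather than one bundled checkpoint, with the T5 weights still loaded as fp32 (see the linked issue). A minimal caller sketch, assuming the three files are already on disk; the paths and prompt here are illustrative only, and main.rs below fetches the real files from the Hub:

    use std::path::PathBuf;

    fn encode(prompt: &str) -> anyhow::Result<(candle::Tensor, candle::Tensor)> {
        let device = candle::Device::cuda_if_available(0)?;
        // Illustrative local paths; not part of the patch.
        let clip_g = PathBuf::from("clip_g.safetensors");
        let clip_l = PathBuf::from("clip_l.safetensors");
        let t5xxl = PathBuf::from("t5xxl_fp16.safetensors");
        let mut triple =
            StableDiffusion3TripleClipWithTokenizer::new_split(&clip_g, &clip_l, &t5xxl, &device)?;
        // Returns (context, y): the CLIP+T5 sequence embedding and the pooled vector.
        triple.encode_text_to_embedding(prompt, &device)
    }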
diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs
index ee467839e8..702d8eec16 100644
--- a/candle-examples/examples/stable-diffusion-3/main.rs
+++ b/candle-examples/examples/stable-diffusion-3/main.rs
@@ -11,6 +11,25 @@ use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
 use anyhow::{Ok, Result};
 use clap::Parser;
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "3-medium")]
+    V3Medium,
+    #[value(name = "3.5-large")]
+    V3_5Large,
+    #[value(name = "3.5-large-turbo")]
+    V3_5LargeTurbo,
+}
+
+impl Which {
+    fn is_3_5(&self) -> bool {
+        match self {
+            Self::V3Medium => false,
+            Self::V3_5Large | Self::V3_5LargeTurbo => true,
+        }
+    }
+}
+
 #[derive(Parser)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -30,10 +49,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// The GPU device ID to use.
-    #[arg(long, default_value_t = 0)]
-    gpu_device_id: usize,
-
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
@@ -50,13 +65,17 @@ struct Args {
     #[arg(long, default_value_t = 1024)]
     width: usize,
 
+    /// The model to use.
+    #[arg(long, default_value = "3-medium")]
+    which: Which,
+
     /// The seed to use when generating random samples.
-    #[arg(long, default_value_t = 28)]
-    num_inference_steps: usize,
+    #[arg(long)]
+    num_inference_steps: Option<usize>,
 
     // CFG scale.
-    #[arg(long, default_value_t = 4.0)]
-    cfg_scale: f64,
+    #[arg(long)]
+    cfg_scale: Option<f64>,
 
     // Time shift factor (alpha).
     #[arg(long, default_value_t = 3.0)]
@@ -68,12 +87,6 @@
 }
 
 fn main() -> Result<()> {
-    let args = Args::parse();
-    // Your main code here
-    run(args)
-}
-
-fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
 
@@ -81,7 +94,6 @@ fn run(args: Args) -> Result<()> {
         prompt,
         uncond_prompt,
         cpu,
-        gpu_device_id,
         tracing,
         use_flash_attn,
         height,
@@ -90,7 +102,8 @@ fn run(args: Args) -> Result<()> {
         cfg_scale,
         time_shift,
         seed,
-    } = args;
+        which,
+    } = Args::parse();
 
     let _guard = if tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
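The ValueEnum derive above is what maps the kebab-case strings on the command line onto the enum variants, and Which then drives the per-model defaults in the next hunk (28 steps and CFG scale 4.0 for the standard models, 4 steps and 1.0 for the turbo distillation). A small round-trip sketch of that mapping, assuming clap's derive API:

    use clap::ValueEnum;

    fn which_from_flag() {
        // The #[value(name = "...")] attribute fixes the string form of each variant.
        let w = Which::from_str("3.5-large-turbo", /* ignore_case= */ true).unwrap();
        assert!(matches!(w, Which::V3_5LargeTurbo));
        assert!(w.is_3_5());
    }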
@@ -100,87 +113,110 @@ fn run(args: Args) -> Result<()> {
         None
     };
 
-    let device = if cpu {
-        candle::Device::Cpu
-    } else if candle::utils::cuda_is_available() {
-        candle::Device::new_cuda(gpu_device_id)?
-    } else if candle::utils::metal_is_available() {
-        candle::Device::new_metal(gpu_device_id)?
-    } else {
-        candle::Device::Cpu
+    let device = candle_examples::device(cpu)?;
+    let default_inference_steps = match which {
+        Which::V3_5Large => 28,
+        Which::V3_5LargeTurbo => 4,
+        Which::V3Medium => 28,
+    };
+    let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
+    let default_cfg_scale = match which {
+        Which::V3_5Large => 4.0,
+        Which::V3_5LargeTurbo => 1.0,
+        Which::V3Medium => 4.0,
     };
+    let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);
 
     let api = hf_hub::api::sync::Api::new()?;
-    let sai_repo = {
-        let name = "stabilityai/stable-diffusion-3-medium";
-        api.repo(hf_hub::Repo::model(name.to_string()))
-    };
-    let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
-    let vb_fp16 = unsafe {
-        candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
-    };
+    let (mmdit_config, mut triple, vb) = if which.is_3_5() {
+        let sai_repo = {
+            let name = match which {
+                Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
+                Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
+                Which::V3Medium => unreachable!(),
+            };
+            api.repo(hf_hub::Repo::model(name.to_string()))
+        };
+        let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?;
+        let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?;
+        let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?;
+        let model_file = {
+            let model_file = match which {
+                Which::V3_5Large => "sd3.5_large.safetensors",
+                Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
+                Which::V3Medium => unreachable!(),
+            };
+            sai_repo.get(model_file)?
+        };
+        let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
+            &clip_g_file,
+            &clip_l_file,
+            &t5xxl_file,
+            &device,
+        )?;
+        let vb = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
+        };
+        (MMDiTConfig::sd3_5_large(), triple, vb)
+    } else {
+        let sai_repo = {
+            let name = "stabilityai/stable-diffusion-3-medium";
+            api.repo(hf_hub::Repo::model(name.to_string()))
+        };
+        let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
+        let vb_fp16 = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
+        };
 
-    let (context, y) = {
         let vb_fp32 = unsafe {
-            candle_nn::VarBuilder::from_mmaped_safetensors(
-                &[model_file.clone()],
-                DType::F32,
-                &device,
-            )?
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)?
         };
-        let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
+        let triple = StableDiffusion3TripleClipWithTokenizer::new(
             vb_fp16.pp("text_encoders"),
             vb_fp32.pp("text_encoders"),
         )?;
-        let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
-        let (context_uncond, y_uncond) =
-            triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
-        (
-            Tensor::cat(&[context, context_uncond], 0)?,
-            Tensor::cat(&[y, y_uncond], 0)?,
-        )
-    };
-
-    let x = {
-        let mmdit = MMDiT::new(
-            &MMDiTConfig::sd3_medium(),
-            use_flash_attn,
-            vb_fp16.pp("model.diffusion_model"),
-        )?;
-
-        if let Some(seed) = seed {
-            device.set_seed(seed)?;
-        }
-        let start_time = std::time::Instant::now();
-        let x = sampling::euler_sample(
-            &mmdit,
-            &y,
-            &context,
-            num_inference_steps,
-            cfg_scale,
-            time_shift,
-            height,
-            width,
-        )?;
-        let dt = start_time.elapsed().as_secs_f32();
-        println!(
-            "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
-            dt,
-            num_inference_steps as f32 / dt
-        );
-        x
+        (MMDiTConfig::sd3_medium(), triple, vb_fp16)
     };
+    let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
+    let (context_uncond, y_uncond) =
+        triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
+    let context = Tensor::cat(&[context, context_uncond], 0)?;
+    let y = Tensor::cat(&[y, y_uncond], 0)?;
+
+    let mmdit = MMDiT::new(
+        &mmdit_config,
+        use_flash_attn,
+        vb.pp("model.diffusion_model"),
+    )?;
+
+    if let Some(seed) = seed {
+        device.set_seed(seed)?;
+    }
+    let start_time = std::time::Instant::now();
+    let x = sampling::euler_sample(
+        &mmdit,
+        &y,
+        &context,
+        num_inference_steps,
+        cfg_scale,
+        time_shift,
+        height,
+        width,
+    )?;
+    let dt = start_time.elapsed().as_secs_f32();
+    println!(
+        "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
+        dt,
+        num_inference_steps as f32 / dt
+    );
 
     let img = {
-        let vb_vae = vb_fp16
-            .clone()
-            .rename_f(sd3_vae_vb_rename)
-            .pp("first_stage_model");
+        let vb_vae = vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
         let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;
         // Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image.
         // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
-        autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
+        autoencoder.decode(&((x / 1.5305)? + 0.0609)?)?
     };
     let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
     candle_examples::save_image(&img.i(0)?, "out.jpg")?;
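For context on the batch layout built above: context and y hold the conditional embeddings in row 0 and the unconditional ones in row 1, so euler_sample can run a single MMDiT forward per step on a batch of two (the Tensor::cat of x with itself in the next hunk) and then blend the two predictions with classifier-free guidance. A sketch of that blend under the same layout; the helper actually used by sampling.rs sits outside this hunk:

    use candle::{Result, Tensor};

    // Row 0: conditional prediction; row 1: unconditional prediction.
    // CFG output = uncond + scale * (cond - uncond)
    //            = scale * cond - (scale - 1) * uncond.
    fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
        let cond = noise_pred.narrow(0, 0, 1)?;
        let uncond = noise_pred.narrow(0, 1, 1)?;
        (cfg_scale * cond)? - ((cfg_scale - 1.0) * uncond)?
    }

With the turbo default of cfg_scale = 1.0 the blend reduces to the conditional prediction alone, which is why the turbo checkpoint effectively runs without guidance.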
Average rate: {:.2} iter/s", + dt, + num_inference_steps as f32 / dt + ); let img = { - let vb_vae = vb_fp16 - .clone() - .rename_f(sd3_vae_vb_rename) - .pp("first_stage_model"); + let vb_vae = vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model"); let autoencoder = build_sd3_vae_autoencoder(vb_vae)?; // Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image. // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723 - autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)? + autoencoder.decode(&((x / 1.5305)? + 0.0609)?)? }; let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?; candle_examples::save_image(&img.i(0)?, "out.jpg")?; diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs index 0efd160eba..cd881b6a2f 100644 --- a/candle-examples/examples/stable-diffusion-3/sampling.rs +++ b/candle-examples/examples/stable-diffusion-3/sampling.rs @@ -30,7 +30,7 @@ pub fn euler_sample( let timestep = (*s_curr) * 1000.0; let noise_pred = mmdit.forward( - &Tensor::cat(&[x.clone(), x.clone()], 0)?, + &Tensor::cat(&[&x, &x], 0)?, &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?, y, context, diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs index 864b662377..5b5c90b0c3 100644 --- a/candle-transformers/src/models/mmdit/model.rs +++ b/candle-transformers/src/models/mmdit/model.rs @@ -36,6 +36,20 @@ impl Config { frequency_embedding_size: 256, } } + + pub fn sd3_5_large() -> Self { + Self { + patch_size: 2, + in_channels: 16, + out_channels: 16, + depth: 38, + head_size: 64, + adm_in_channels: 2048, + pos_embed_max_size: 192, + context_embed_size: 4096, + frequency_embedding_size: 256, + } + } } pub struct MMDiT { diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs index dc1e8ec941..2775328596 100644 --- a/candle-transformers/src/models/mmdit/projections.rs +++ b/candle-transformers/src/models/mmdit/projections.rs @@ -56,6 +56,8 @@ impl QkvOnlyAttnProjections { pub struct AttnProjections { head_dim: usize, qkv: nn::Linear, + ln_k: Option, + ln_q: Option, proj: nn::Linear, } @@ -64,16 +66,42 @@ impl AttnProjections { let head_dim = dim / num_heads; let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?; let proj = nn::linear(dim, dim, vb.pp("proj"))?; + let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") { + let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?; + let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?; + (Some(ln_k), Some(ln_q)) + } else { + (None, None) + }; Ok(Self { head_dim, qkv, proj, + ln_k, + ln_q, }) } pub fn pre_attention(&self, x: &Tensor) -> Result { let qkv = self.qkv.forward(x)?; - split_qkv(&qkv, self.head_dim) + let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?; + let q = match self.ln_q.as_ref() { + None => q, + Some(l) => { + let (b, t, h) = q.dims3()?; + l.forward(&q.reshape((b, t, (), self.head_dim))?)? + .reshape((b, t, h))? + } + }; + let k = match self.ln_k.as_ref() { + None => k, + Some(l) => { + let (b, t, h) = k.dims3()?; + l.forward(&k.reshape((b, t, (), self.head_dim))?)? + .reshape((b, t, h))? + } + }; + Ok(Qkv { q, k, v }) } pub fn post_attention(&self, x: &Tensor) -> Result {