From 498bc2cdc962482bd0324074050ae706d9ed9a5f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 28 Oct 2024 16:06:53 +0100 Subject: [PATCH] Release the mmdit model earlier to reduce memory usage. (#2581) * Stable diffusion 3.5 support. * Clippy fixes. * CFG fix. * Remove some unnecessary clones. * Avoid duplicating some of the code. * Release the mmdit model earlier to reduce memory usage. --- .../examples/stable-diffusion-3/main.rs | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs index 702d8eec16..01b09101a1 100644 --- a/candle-examples/examples/stable-diffusion-3/main.rs +++ b/candle-examples/examples/stable-diffusion-3/main.rs @@ -183,26 +183,27 @@ fn main() -> Result<()> { let context = Tensor::cat(&[context, context_uncond], 0)?; let y = Tensor::cat(&[y, y_uncond], 0)?; - let mmdit = MMDiT::new( - &mmdit_config, - use_flash_attn, - vb.pp("model.diffusion_model"), - )?; - if let Some(seed) = seed { device.set_seed(seed)?; } let start_time = std::time::Instant::now(); - let x = sampling::euler_sample( - &mmdit, - &y, - &context, - num_inference_steps, - cfg_scale, - time_shift, - height, - width, - )?; + let x = { + let mmdit = MMDiT::new( + &mmdit_config, + use_flash_attn, + vb.pp("model.diffusion_model"), + )?; + sampling::euler_sample( + &mmdit, + &y, + &context, + num_inference_steps, + cfg_scale, + time_shift, + height, + width, + )? + }; let dt = start_time.elapsed().as_secs_f32(); println!( "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",