diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs index 702d8eec16..01b09101a1 100644 --- a/candle-examples/examples/stable-diffusion-3/main.rs +++ b/candle-examples/examples/stable-diffusion-3/main.rs @@ -183,26 +183,27 @@ fn main() -> Result<()> { let context = Tensor::cat(&[context, context_uncond], 0)?; let y = Tensor::cat(&[y, y_uncond], 0)?; - let mmdit = MMDiT::new( - &mmdit_config, - use_flash_attn, - vb.pp("model.diffusion_model"), - )?; - if let Some(seed) = seed { device.set_seed(seed)?; } let start_time = std::time::Instant::now(); - let x = sampling::euler_sample( - &mmdit, - &y, - &context, - num_inference_steps, - cfg_scale, - time_shift, - height, - width, - )?; + let x = { + let mmdit = MMDiT::new( + &mmdit_config, + use_flash_attn, + vb.pp("model.diffusion_model"), + )?; + sampling::euler_sample( + &mmdit, + &y, + &context, + num_inference_steps, + cfg_scale, + time_shift, + height, + width, + )? + }; let dt = start_time.elapsed().as_secs_f32(); println!( "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",