Skip to content

Commit

Permalink
Release the mmdit model earlier to reduce memory usage. (#2581)
Browse files Browse the repository at this point in the history
* Stable diffusion 3.5 support.

* Clippy fixes.

* CFG fix.

* Remove some unnecessary clones.

* Avoid duplicating some of the code.

* Release the mmdit model earlier to reduce memory usage.
  • Loading branch information
LaurentMazare authored Oct 28, 2024
1 parent 0e2c8c1 commit 498bc2c
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions candle-examples/examples/stable-diffusion-3/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,26 +183,27 @@ fn main() -> Result<()> {
let context = Tensor::cat(&[context, context_uncond], 0)?;
let y = Tensor::cat(&[y, y_uncond], 0)?;

let mmdit = MMDiT::new(
&mmdit_config,
use_flash_attn,
vb.pp("model.diffusion_model"),
)?;

if let Some(seed) = seed {
device.set_seed(seed)?;
}
let start_time = std::time::Instant::now();
let x = sampling::euler_sample(
&mmdit,
&y,
&context,
num_inference_steps,
cfg_scale,
time_shift,
height,
width,
)?;
let x = {
let mmdit = MMDiT::new(
&mmdit_config,
use_flash_attn,
vb.pp("model.diffusion_model"),
)?;
sampling::euler_sample(
&mmdit,
&y,
&context,
num_inference_steps,
cfg_scale,
time_shift,
height,
width,
)?
};
let dt = start_time.elapsed().as_secs_f32();
println!(
"Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
Expand Down

0 comments on commit 498bc2c

Please sign in to comment.