Skip to content

Commit

Permalink
[Fix] Fix set timestep in stable diffusion's infer (#1869)
Browse files Browse the repository at this point in the history
fix set timestep in stable diffusion's infer
  • Loading branch information
LeoXing1996 authored May 29, 2023
1 parent ba4732c commit 47a54de
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions mmagic/models/editors/stable_diffusion/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def infer(self,
negative_prompt)

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
self.test_scheduler.set_timesteps(num_inference_steps)
timesteps = self.test_scheduler.timesteps

# 5. Prepare latent variables
if hasattr(self.unet, 'module'):
Expand Down Expand Up @@ -249,7 +249,7 @@ def infer(self,
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat(
[latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(
latent_model_input = self.test_scheduler.scale_model_input(
latent_model_input, t)

latent_model_input = latent_model_input.to(latent_dtype)
Expand All @@ -266,7 +266,7 @@ def infer(self,
noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
latents = self.test_scheduler.step(
noise_pred, t, latents, **extra_step_kwargs)['prev_sample']

# 8. Post-processing
Expand Down

0 comments on commit 47a54de

Please sign in to comment.