diff --git a/mmagic/models/editors/stable_diffusion/stable_diffusion.py b/mmagic/models/editors/stable_diffusion/stable_diffusion.py index 71e311b729..308680339d 100644 --- a/mmagic/models/editors/stable_diffusion/stable_diffusion.py +++ b/mmagic/models/editors/stable_diffusion/stable_diffusion.py @@ -219,8 +219,8 @@ def infer(self, negative_prompt) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps) - timesteps = self.scheduler.timesteps + self.test_scheduler.set_timesteps(num_inference_steps) + timesteps = self.test_scheduler.timesteps # 5. Prepare latent variables if hasattr(self.unet, 'module'): @@ -249,7 +249,7 @@ def infer(self, # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat( [latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input( + latent_model_input = self.test_scheduler.scale_model_input( latent_model_input, t) latent_model_input = latent_model_input.to(latent_dtype) @@ -266,7 +266,7 @@ def infer(self, noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( + latents = self.test_scheduler.step( noise_pred, t, latents, **extra_step_kwargs)['prev_sample'] # 8. Post-processing