diff --git a/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 550513b5c943..f0b353d931d4 100644 --- a/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -46,6 +46,7 @@ def __call__( guidance_scale: Optional[float] = 7.5, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", **kwargs, ): @@ -98,12 +99,18 @@ def __call__( # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - # get the intial random noise - latents = torch.randn( - (batch_size, self.unet.in_channels, height // 8, width // 8), - generator=generator, - device=self.device, - ) + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + if latents is None: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())