From 17ed0371bc847508248596e7f0ebe95e0706c074 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 25 Aug 2022 15:47:05 +0200 Subject: [PATCH] Reproducible images by supplying latents to pipeline (#247) * Accept latents as input for StableDiffusionPipeline. * Notebook to demonstrate reusable seeds (latents). * More accurate type annotation Co-authored-by: Suraj Patil * Review comments: move to device, raise instead of assert. * Actually commit the test notebook. I had mistakenly pushed an empty file instead. * Adapt notebook to Colab. * Update examples readme. * Move notebook to personal repo. Co-authored-by: Suraj Patil --- .../pipeline_stable_diffusion.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 550513b5c943..f0b353d931d4 100644 --- a/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -46,6 +46,7 @@ def __call__( guidance_scale: Optional[float] = 7.5, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", **kwargs, ): @@ -98,12 +99,18 @@ def __call__( # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - # get the intial random noise - latents = torch.randn( - (batch_size, self.unet.in_channels, height // 8, width // 8), - generator=generator, - device=self.device, - ) + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + if latents is None: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())