Skip to content

Commit

Permalink
Reproducible images by supplying latents to pipeline (huggingface#247)
Browse files Browse the repository at this point in the history
* Accept latents as input for StableDiffusionPipeline.

* Notebook to demonstrate reusable seeds (latents).

* More accurate type annotation

Co-authored-by: Suraj Patil <[email protected]>

* Review comments: move to device, raise instead of assert.

* Actually commit the test notebook.

I had mistakenly pushed an empty file instead.

* Adapt notebook to Colab.

* Update examples readme.

* Move notebook to personal repo.

Co-authored-by: Suraj Patil <[email protected]>
  • Loading branch information
pcuenca and patil-suraj authored Aug 25, 2022
1 parent 6d6e0be commit 17ed037
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions pipelines/stable_diffusion/pipeline_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __call__(
guidance_scale: Optional[float] = 7.5,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
**kwargs,
):
Expand Down Expand Up @@ -98,12 +99,18 @@ def __call__(
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# get the intial random noise
latents = torch.randn(
(batch_size, self.unet.in_channels, height // 8, width // 8),
generator=generator,
device=self.device,
)
# get the initial random noise unless the user supplied it
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None:
latents = torch.randn(
latents_shape,
generator=generator,
device=self.device,
)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)

# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
Expand Down

0 comments on commit 17ed037

Please sign in to comment.