StableDiffusion: Decode latents separately to run larger batches (open-mmlab#1150)

* StableDiffusion: Decode latents separately to run larger batches

* Move VAE sliced decode under enable_vae_sliced_decode and vae.enable_sliced_decode

* Rename sliced_decode to slicing

* fix whitespace

* fix quality check and repository consistency

* VAE slicing tests and documentation

* API doc hooks for VAE slicing

* reformat vae slicing tests

* Skip VAE slicing for one-image batches

* Documentation tweaks for VAE slicing

Co-authored-by: Ilmari Heikkinen <[email protected]>
kig and Ilmari Heikkinen authored Nov 29, 2022
1 parent bcb6cc1 commit c28d3c8
Showing 6 changed files with 171 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/api/pipelines/stable_diffusion.mdx
@@ -76,6 +76,8 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_vae_slicing
- disable_vae_slicing

## StableDiffusionImg2ImgPipeline
[[autodoc]] StableDiffusionImg2ImgPipeline
28 changes: 28 additions & 0 deletions docs/source/optimization/fp16.mdx
@@ -117,6 +117,34 @@ image = pipe(prompt).images[0]

There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!


## Sliced VAE decode for larger batches

To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time.

You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.

To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
images = pipe([prompt] * 32).images
```

You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.


## Offloading to CPU with accelerate for memory savings

For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass.
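For reference, here is a minimal sketch (not part of the commit) that couples the new `enable_vae_slicing` call with attention slicing, as the fp16.mdx section above suggests; the model id, fp16 revision, and 32-image batch simply mirror the documented example and are assumptions rather than requirements.

```python
import torch
from diffusers import StableDiffusionPipeline

# Sketch: combine sliced VAE decode with attention slicing so that both the
# UNet attention and the final decode stay within a modest VRAM budget.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

pipe.enable_attention_slicing()  # chunk the attention computation
pipe.enable_vae_slicing()        # decode latents one image at a time

prompt = "a photo of an astronaut riding a horse on mars"
images = pipe([prompt] * 32).images  # the 32 latents are decoded slice by slice
```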
31 changes: 30 additions & 1 deletion src/diffusers/models/vae.py
@@ -565,6 +565,7 @@ def __init__(

        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
        self.use_slicing = False

    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.encoder(x)
@@ -576,7 +577,7 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderK

        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self.post_quant_conv(z)
        dec = self.decoder(z)

@@ -585,6 +586,34 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode

        return DecoderOutput(sample=dec)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding.
        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.FloatTensor,
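To see the new `AutoencoderKL` hooks in isolation, here is a hedged sketch (not part of the diff); the checkpoint, `subfolder` layout, and latent shape are illustrative assumptions for a 512x512 Stable Diffusion model.

```python
import torch
from diffusers import AutoencoderKL

# Load only the VAE from a Stable Diffusion checkpoint (illustrative choice).
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
vae.enable_slicing()  # decode() now loops over the batch via z.split(1)

# A dummy batch of 8 latents for 512x512 images: 4 channels, 64x64 spatial.
latents = torch.randn(8, 4, 64, 64)
with torch.no_grad():
    images = vae.decode(latents).sample  # decoded one latent at a time, then concatenated

vae.disable_slicing()  # back to a single full-batch decode
```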
16 changes: 16 additions & 0 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -216,6 +216,22 @@ def disable_attention_slicing(self):
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.
        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
16 changes: 16 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -215,6 +215,22 @@ def disable_attention_slicing(self):
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.
        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
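Both pipeline hunks above add the same thin wrappers around `vae.enable_slicing()` and `vae.disable_slicing()`, so slicing can be toggled between calls. A brief sketch under assumed settings (the model id, dtype, and batch sizes are illustrative, not taken from the commit):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

prompt = "a photograph of an astronaut riding a horse"

# Large batch: enable slicing so the decode peaks at roughly one image's worth of memory.
pipe.enable_vae_slicing()
batch_images = pipe([prompt] * 16).images

# Single image: slicing is bypassed anyway for a batch of one, but it can
# also be disabled explicitly to restore the original single-pass decode.
pipe.disable_vae_slicing()
single_image = pipe(prompt).images[0]
```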
79 changes: 79 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -557,6 +557,46 @@ def test_stable_diffusion_attention_chunk(self):

        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4

    def test_stable_diffusion_vae_slicing(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # make sure here that pndm scheduler skips prk
        sd_pipe = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"

        image_count = 4

        generator = torch.Generator(device=device).manual_seed(0)
        output_1 = sd_pipe(
            [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
        )

        # make sure sliced vae decode yields the same result
        sd_pipe.enable_vae_slicing()
        generator = torch.Generator(device=device).manual_seed(0)
        output_2 = sd_pipe(
            [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
        )

        # there is a small discrepancy at image borders vs. full batch decode
        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3

    def test_stable_diffusion_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
@@ -886,6 +926,45 @@ def test_stable_diffusion_memory_chunking(self):
        assert mem_bytes > 3.75 * 10**9
        assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3

    def test_stable_diffusion_vae_slicing(self):
        torch.cuda.reset_peak_memory_stats()
        model_id = "CompVis/stable-diffusion-v1-4"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "a photograph of an astronaut riding a horse"

        # enable vae slicing
        pipe.enable_vae_slicing()
        generator = torch.Generator(device=torch_device).manual_seed(0)
        with torch.autocast(torch_device):
            output_chunked = pipe(
                [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
            )
            image_chunked = output_chunked.images

        mem_bytes = torch.cuda.max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()
        # make sure that less than 4 GB is allocated
        assert mem_bytes < 4e9

        # disable vae slicing
        pipe.disable_vae_slicing()
        generator = torch.Generator(device=torch_device).manual_seed(0)
        with torch.autocast(torch_device):
            output = pipe(
                [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
            )
            image = output.images

        # make sure that more than 4 GB is allocated
        mem_bytes = torch.cuda.max_memory_allocated()
        assert mem_bytes > 4e9
        # There is a small discrepancy at the image borders vs. a fully batched version.
        assert np.abs(image_chunked.flatten() - image.flatten()).max() < 3e-3

    def test_stable_diffusion_text2img_pipeline_fp16(self):
        torch.cuda.reset_peak_memory_stats()
        model_id = "CompVis/stable-diffusion-v1-4"
