From c28d3c82ce6f56c4b373a8260c56357d13db900a Mon Sep 17 00:00:00 2001
From: Ilmari Heikkinen
Date: Tue, 29 Nov 2022 20:28:14 +0800
Subject: [PATCH] StableDiffusion: Decode latents separately to run larger batches (#1150)

* StableDiffusion: Decode latents separately to run larger batches
* Move VAE sliced decode under enable_vae_sliced_decode and vae.enable_sliced_decode
* Rename sliced_decode to slicing
* fix whitespace
* fix quality check and repository consistency
* VAE slicing tests and documentation
* API doc hooks for VAE slicing
* reformat vae slicing tests
* Skip VAE slicing for one-image batches
* Documentation tweaks for VAE slicing

Co-authored-by: Ilmari Heikkinen
---
 .../source/api/pipelines/stable_diffusion.mdx |  2 +
 docs/source/optimization/fp16.mdx             | 28 +++++++
 src/diffusers/models/vae.py                   | 31 +++++++-
 .../alt_diffusion/pipeline_alt_diffusion.py   | 16 ++++
 .../pipeline_stable_diffusion.py              | 16 ++++
 .../stable_diffusion/test_stable_diffusion.py | 79 +++++++++++++++++++
 6 files changed, 171 insertions(+), 1 deletion(-)

diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx
index afa72775f0..70c4abaaf6 100644
--- a/docs/source/api/pipelines/stable_diffusion.mdx
+++ b/docs/source/api/pipelines/stable_diffusion.mdx
@@ -76,6 +76,8 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
     - __call__
     - enable_attention_slicing
     - disable_attention_slicing
+    - enable_vae_slicing
+    - disable_vae_slicing
 
 ## StableDiffusionImg2ImgPipeline
 [[autodoc]] StableDiffusionImg2ImgPipeline

diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx
index 4371daacc9..49fe3876bd 100644
--- a/docs/source/optimization/fp16.mdx
+++ b/docs/source/optimization/fp16.mdx
@@ -117,6 +117,34 @@ image = pipe(prompt).images[0]
 
 There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
 
+
+## Sliced VAE decode for larger batches
+
+To decode large batches of images with limited VRAM, or batches of 32 images or more, you can use sliced VAE decoding, which decodes the batch latents one image at a time.
+
+You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
+
+To perform the VAE decode one image at a time, call [`~StableDiffusionPipeline.enable_vae_slicing`] on your pipeline before inference. For example:
+
+```Python
+import torch
+from diffusers import StableDiffusionPipeline
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    revision="fp16",
+    torch_dtype=torch.float16,
+)
+pipe = pipe.to("cuda")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+pipe.enable_vae_slicing()
+images = pipe([prompt] * 32).images
+```
+
+You may see a small performance boost in VAE decoding on multi-image batches; there should be no performance impact on single-image batches.
+
+
 ## Offloading to CPU with accelerate for memory savings
 
 For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass.
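As a quick illustration of how the documented toggle composes with the existing attention-slicing option, here is a minimal sketch (not part of the patch; the model id, prompt, and batch size are placeholders mirroring the docs above):

```python
import torch
from diffusers import StableDiffusionPipeline

# Sketch: combine VAE slicing with attention slicing to keep peak memory low
# on a large batch. Model id, prompt, and batch size mirror the docs above.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

pipe.enable_attention_slicing()  # slice attention during the denoising loop
pipe.enable_vae_slicing()        # decode the final latents one image at a time

images = pipe(["a photo of an astronaut riding a horse on mars"] * 32).images

pipe.disable_vae_slicing()  # revert to single-pass decoding if preferred
```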
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index 30de343d08..e29f4e8afa 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -565,6 +565,7 @@ def __init__(
         self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
         self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+        self.use_slicing = False
 
     def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
         h = self.encoder(x)
@@ -576,7 +577,7 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderK
 
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         z = self.post_quant_conv(z)
         dec = self.decoder(z)
 
@@ -585,6 +586,34 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
 
         return DecoderOutput(sample=dec)
 
+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
+        if not return_dict:
+            return (decoded,)
+
+        return DecoderOutput(sample=decoded)
+
     def forward(
         self,
         sample: torch.FloatTensor,

diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index 4a80f2e689..9146d45bd3 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -216,6 +216,22 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
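The core of the change is the new `decode` wrapper above: when `use_slicing` is set and the batch holds more than one latent, the batch is split into single-image slices, each slice is decoded on its own, and the results are concatenated. Below is a self-contained sketch of that pattern with a toy stand-in for the real `post_quant_conv` + decoder stack (the stand-in and shapes are illustrative, not the actual model):

```python
import torch

def toy_decoder(z: torch.Tensor) -> torch.Tensor:
    # Stand-in for post_quant_conv + decoder: upsample 4-channel latents
    # into 3-channel "images" (nearest-neighbor, purely illustrative).
    return torch.nn.functional.interpolate(z[:, :3], scale_factor=8)

def decode_sliced(z: torch.Tensor, use_slicing: bool = True) -> torch.Tensor:
    if use_slicing and z.shape[0] > 1:
        # Decode one latent at a time; peak memory scales with a single image
        # instead of the whole batch.
        return torch.cat([toy_decoder(z_slice) for z_slice in z.split(1)])
    return toy_decoder(z)

latents = torch.randn(8, 4, 64, 64)
# Slicing changes the memory profile, not the result.
assert torch.allclose(decode_sliced(latents, True), decode_sliced(latents, False))
```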
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index c6ff904d24..afaef6f481 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -215,6 +215,22 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 229e18c69f..8dce61c3a4 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -557,6 +557,46 @@ def test_stable_diffusion_attention_chunk(self):
 
         assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
 
+    def test_stable_diffusion_vae_slicing(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        image_count = 4
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output_1 = sd_pipe(
+            [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
+        )
+
+        # make sure sliced vae decode yields the same result
+        sd_pipe.enable_vae_slicing()
+        generator = torch.Generator(device=device).manual_seed(0)
+        output_2 = sd_pipe(
+            [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
+        )
+
+        # there is a small discrepancy at the image borders vs. a full-batch decode
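Since the pipeline methods above simply forward to `AutoencoderKL.enable_slicing`/`disable_slicing`, the same toggle can be flipped directly on the VAE when decoding latents by hand. A sketch (the model id, latent shape, and 0.18215 scaling factor follow the usual Stable Diffusion conventions and are assumptions here, not taken from the patch):

```python
import torch
from diffusers import AutoencoderKL

# Load only the VAE component and turn on sliced decoding directly.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to("cuda")
vae.enable_slicing()  # same flag the pipeline's enable_vae_slicing() sets

latents = torch.randn(16, 4, 64, 64, device="cuda")  # assumed 512x512 latent shape
with torch.no_grad():
    # decode() now loops over the batch one latent at a time
    images = vae.decode(latents / 0.18215).sample

vae.disable_slicing()  # back to single-pass decoding
```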
+        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3
+
     def test_stable_diffusion_negative_prompt(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
@@ -886,6 +926,45 @@ def test_stable_diffusion_memory_chunking(self):
         assert mem_bytes > 3.75 * 10**9
         assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
 
+    def test_stable_diffusion_vae_slicing(self):
+        torch.cuda.reset_peak_memory_stats()
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        prompt = "a photograph of an astronaut riding a horse"
+
+        # enable vae slicing
+        pipe.enable_vae_slicing()
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output_chunked = pipe(
+                [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+            )
+            image_chunked = output_chunked.images
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+        # make sure that less than 4 GB is allocated
+        assert mem_bytes < 4e9
+
+        # disable vae slicing
+        pipe.disable_vae_slicing()
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast(torch_device):
+            output = pipe(
+                [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+            )
+            image = output.images
+
+        # make sure that more than 4 GB is allocated
+        mem_bytes = torch.cuda.max_memory_allocated()
+        assert mem_bytes > 4e9
+        # There is a small discrepancy at the image borders vs. a fully batched version.
+        assert np.abs(image_chunked.flatten() - image.flatten()).max() < 3e-3
+
     def test_stable_diffusion_text2img_pipeline_fp16(self):
         torch.cuda.reset_peak_memory_stats()
         model_id = "CompVis/stable-diffusion-v1-4"
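For reference, the GPU test above hinges on resetting and reading CUDA's peak-memory counter between the sliced and unsliced runs. A sketch of that measurement pattern (the helper and the commented usage are illustrative, not the test's actual code):

```python
import torch

def peak_memory_bytes(fn) -> int:
    # Reset the allocator's high-water mark, run the workload, then read the peak.
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

# Usage, assuming `pipe` and `prompt` are set up as in the test:
# pipe.enable_vae_slicing()
# sliced = peak_memory_bytes(lambda: pipe([prompt] * 4, num_inference_steps=10))
# pipe.disable_vae_slicing()
# full = peak_memory_bytes(lambda: pipe([prompt] * 4, num_inference_steps=10))
# assert sliced < full
```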