[core] FreeNoise (#8948)
* initial work draft for freenoise; needs massive cleanup

* fix freeinit bug

* add animatediff controlnet implementation

* revert attention changes

* add freenoise

* remove old helper functions

* add decode batch size param to all pipelines

* make style

* fix copied from comments

* make fix-copies

* make style

* copy animatediff controlnet implementation from #8972

* add experimental support for num_frames not perfectly fitting context length, context stride

* make unet motion model lora work again based on #8995

* copy load video utils from #8972

* copied from AnimateDiff::prepare_latents

* address the case where last batch of frames does not match length of indices in prepare latents

* decode_batch_size->vae_batch_size; batch vae encode support in animatediff vid2vid

* revert sparsectrl and sdxl freenoise changes

* revert pia

* add freenoise tests

* make fix-copies

* improve docstrings

* add freenoise tests to animatediff controlnet

* update tests

* Update src/diffusers/models/unets/unet_motion_model.py

* add freenoise to animatediff pag

* address review comments

* make style

* update tests

* make fix-copies

* fix error message

* remove copied from comment

* fix imports in tests

* update

---------

Co-authored-by: Dhruv Nair <[email protected]>
a-r-r-o-w and DN6 authored Aug 7, 2024
1 parent 2d753b6 commit 16a93f1
Showing 11 changed files with 911 additions and 50 deletions.
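
For orientation, here is a minimal usage sketch of the FreeNoise support this commit adds. It is not taken from the diff: the model IDs, the enable_free_noise(context_length, context_stride) signature on AnimateDiffFreeNoiseMixin, and the export_to_gif helper are assumptions about the surrounding diffusers API.

import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

# Hypothetical sketch: checkpoints and the enable_free_noise() signature are assumptions.
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, beta_schedule="linear", clip_sample=False
)

# FreeNoise: denoise in overlapping temporal windows so num_frames can exceed
# the motion module's training length (see "context length, context stride" above).
pipe.enable_free_noise(context_length=16, context_stride=4)

video = pipe(
    prompt="a panda playing a guitar, high quality",
    num_frames=64,
    num_inference_steps=25,
    guidance_scale=7.5,
    vae_batch_size=16,  # added in this commit: decode latents through the VAE in chunks of 16 frames
).frames[0]
export_to_gif(video, "freenoise.gif")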
326 changes: 325 additions & 1 deletion src/diffusers/models/attention.py

Large diffs are not rendered by default.
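
The attention.py changes, where the FreeNoise transformer-block logic lands, are not rendered here. As a rough mental model only, not the actual implementation, processing frames in overlapping windows and blending the overlaps can be sketched like this; the helper name and the uniform blending weights are illustrative (the paper uses a pyramid-style weighting):

import torch

def blend_over_windows(frames, context_length, context_stride, fn):
    # Apply `fn` (a stand-in for per-window temporal attention) to overlapping
    # windows of frames and average the outputs where windows overlap.
    num_frames = frames.shape[0]
    starts = list(range(0, max(num_frames - context_length, 0) + 1, context_stride))
    if starts[-1] + context_length < num_frames:
        starts.append(num_frames - context_length)  # cover trailing frames that do not fit the stride
    out = torch.zeros_like(frames)
    counts = torch.zeros(num_frames, *([1] * (frames.dim() - 1)))
    for start in starts:
        end = min(start + context_length, num_frames)
        out[start:end] += fn(frames[start:end])
        counts[start:end] += 1
    return out / counts

# Identity `fn` just to show the bookkeeping; shape is (frames, channels, height, width).
smoothed = blend_over_windows(torch.randn(24, 4, 32, 32), context_length=16, context_stride=4, fn=lambda x: x)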

7 changes: 4 additions & 3 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -343,6 +343,7 @@ def custom_forward(*inputs):

else:
hidden_states = resnet(hidden_states, temb)

hidden_states = motion_module(hidden_states, num_frames=num_frames)

output_states = output_states + (hidden_states,)
@@ -536,6 +537,7 @@ def custom_forward(*inputs):
)[0]
else:
hidden_states = resnet(hidden_states, temb)

hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -761,6 +763,7 @@ def custom_forward(*inputs):
)[0]
else:
hidden_states = resnet(hidden_states, temb)

hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -921,9 +924,9 @@ def custom_forward(*inputs):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)

else:
hidden_states = resnet(hidden_states, temb)

hidden_states = motion_module(hidden_states, num_frames=num_frames)

if self.upsamplers is not None:
@@ -1923,7 +1926,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)

# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
@@ -1953,7 +1955,6 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)

# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self) -> None:
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
38 changes: 27 additions & 11 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -42,6 +42,7 @@
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput

@@ -72,6 +73,7 @@ class AnimateDiffPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for text-to-video generation.
@@ -394,15 +396,20 @@ def prepare_ip_adapter_image_embeds(

return ip_adapter_image_embeds

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
def decode_latents(self, latents, vae_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents

batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

image = self.vae.decode(latents).sample
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
video = []
for i in range(0, latents.shape[0], vae_batch_size):
batch_latents = latents[i : i + vae_batch_size]
batch_latents = self.vae.decode(batch_latents).sample
video.append(batch_latents)

video = torch.cat(video)
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
@@ -495,22 +502,28 @@ def check_inputs(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)

# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
if self.free_noise_enabled:
latents = self._prepare_latents_free_noise(
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
)

if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -569,6 +582,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
vae_batch_size: int = 16,
**kwargs,
):
r"""
@@ -637,6 +651,8 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
vae_batch_size (`int`, defaults to `16`):
The number of frames to decode at a time when calling the `decode_latents` method.
Examples:
@@ -808,7 +824,7 @@ def __call__(
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents)
video_tensor = self.decode_latents(latents, vae_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models
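
prepare_latents above defers to _prepare_latents_free_noise (from AnimateDiffFreeNoiseMixin) when FreeNoise is enabled; that helper is not part of this diff. As an illustration only of the noise rescheduling in Equation (7) of the FreeNoise paper referenced in the comment, repeating an initial noise window with local shuffling could look roughly like this; it is a sketch, not the mixin's actual code:

from typing import Optional

import torch

def free_noise_style_noise(
    num_frames: int,
    context_length: int = 16,
    context_stride: int = 4,
    frame_shape=(4, 64, 64),  # (latent channels, latent height, latent width), illustrative
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    # Sample noise for the first context window, then fill later frames with
    # shuffled copies of the preceding window so every window sees correlated noise.
    latents = torch.randn((num_frames, *frame_shape), generator=generator)
    for start in range(context_length, num_frames, context_stride):
        window = latents[start - context_length : start]
        perm = torch.randperm(context_length, generator=generator)
        end = min(start + context_stride, num_frames)
        latents[start:end] = window[perm][: end - start]
    return latents

noise = free_noise_style_noise(num_frames=64)  # shape (64, 4, 64, 64)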
src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
@@ -30,6 +30,7 @@
from ...video_processor import VideoProcessor
from ..controlnet.multicontrolnet import MultiControlNetModel
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput

@@ -109,6 +110,7 @@ class AnimateDiffControlNetPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for text-to-video generation with ControlNet guidance.
@@ -432,15 +434,16 @@ def prepare_ip_adapter_image_embeds(

return ip_adapter_image_embeds

def decode_latents(self, latents, decode_batch_size: int = 16):
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
def decode_latents(self, latents, vae_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents

batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

video = []
for i in range(0, latents.shape[0], decode_batch_size):
batch_latents = latents[i : i + decode_batch_size]
for i in range(0, latents.shape[0], vae_batch_size):
batch_latents = latents[i : i + vae_batch_size]
batch_latents = self.vae.decode(batch_latents).sample
video.append(batch_latents)

@@ -608,22 +611,29 @@ def check_inputs(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
# If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
if self.free_noise_enabled:
latents = self._prepare_latents_free_noise(
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
)

if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -718,7 +728,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
decode_batch_size: int = 16,
vae_batch_size: int = 16,
):
r"""
The call function to the pipeline for generation.
@@ -1054,7 +1064,7 @@ def __call__(
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents, decode_batch_size)
video_tensor = self.decode_latents(latents, vae_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models
src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -35,6 +35,7 @@
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..free_noise_utils import AnimateDiffFreeNoiseMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput

@@ -176,6 +177,7 @@ class AnimateDiffVideoToVideoPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
):
r"""
Pipeline for video-to-video generation.
@@ -498,15 +500,29 @@ def prepare_ip_adapter_image_embeds(

return ip_adapter_image_embeds

# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
def encode_video(self, video, generator, vae_batch_size: int = 16) -> torch.Tensor:
latents = []
for i in range(0, len(video), vae_batch_size):
batch_video = video[i : i + vae_batch_size]
batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator)
latents.append(batch_video)
return torch.cat(latents)

# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents
def decode_latents(self, latents, vae_batch_size: int = 16):
latents = 1 / self.vae.config.scaling_factor * latents

batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

image = self.vae.decode(latents).sample
video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4)
video = []
for i in range(0, latents.shape[0], vae_batch_size):
batch_latents = latents[i : i + vae_batch_size]
batch_latents = self.vae.decode(batch_latents).sample
video.append(batch_latents)

video = torch.cat(video)
video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
@@ -622,6 +638,7 @@ def prepare_latents(
device,
generator,
latents=None,
vae_batch_size: int = 16,
):
if latents is None:
num_frames = video.shape[1]
@@ -656,13 +673,10 @@
)

init_latents = [
retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
for i in range(batch_size)
self.encode_video(video[i], generator[i], vae_batch_size).unsqueeze(0) for i in range(batch_size)
]
else:
init_latents = [
retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
]
init_latents = [self.encode_video(vid, generator, vae_batch_size).unsqueeze(0) for vid in video]

init_latents = torch.cat(init_latents, dim=0)

@@ -747,6 +761,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
vae_batch_size: int = 16,
):
r"""
The call function to the pipeline for generation.
@@ -822,6 +837,8 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
vae_batch_size (`int`, defaults to `16`):
The number of frames to decode at a time when calling the `decode_latents` method.
Examples:
@@ -923,6 +940,7 @@ def __call__(
device=device,
generator=generator,
latents=latents,
vae_batch_size=vae_batch_size,
)

# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -990,7 +1008,7 @@ def __call__(
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents)
video_tensor = self.decode_latents(latents, vae_batch_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models
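
The video-to-video pipeline now also batches the VAE encode of the input frames (encode_video) as well as the decode, both bounded by vae_batch_size. A hedged usage sketch follows; the checkpoints and the load_video helper (mentioned in the commit message as copied from #8972) are assumptions, not shown in this diff.

import torch
from diffusers import AnimateDiffVideoToVideoPipeline, MotionAdapter
from diffusers.utils import export_to_gif, load_video

# Hypothetical sketch: model IDs and load_video are assumptions.
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

frames = load_video("input.gif")

output = pipe(
    prompt="an astronaut riding a horse, cinematic lighting",
    video=frames,
    strength=0.6,
    vae_batch_size=8,  # bounds memory for both encode_video and decode_latents
).frames[0]
export_to_gif(output, "vid2vid.gif")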