From 16a93f1a25675eb9d1308f2fa1f7a3023240ff85 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 7 Aug 2024 10:35:18 +0530 Subject: [PATCH] [core] FreeNoise (#8948) * initial work draft for freenoise; needs massive cleanup * fix freeinit bug * add animatediff controlnet implementation * revert attention changes * add freenoise * remove old helper functions * add decode batch size param to all pipelines * make style * fix copied from comments * make fix-copies * make style * copy animatediff controlnet implementation from #8972 * add experimental support for num_frames not perfectly fitting context length, ocntext stride * make unet motion model lora work again based on #8995 * copy load video utils from #8972 * copied from AnimateDiff::prepare_latents * address the case where last batch of frames does not match length of indices in prepare latents * decode_batch_size->vae_batch_size; batch vae encode support in animatediff vid2vid * revert sparsectrl and sdxl freenoise changes * revert pia * add freenoise tests * make fix-copies * improve docstrings * add freenoise tests to animatediff controlnet * update tests * Update src/diffusers/models/unets/unet_motion_model.py * add freenoise to animatediff pag * address review comments * make style * update tests * make fix-copies * fix error message * remove copied from comment * fix imports in tests * update --------- Co-authored-by: Dhruv Nair --- src/diffusers/models/attention.py | 326 +++++++++++++++++- .../models/unets/unet_motion_model.py | 7 +- .../animatediff/pipeline_animatediff.py | 38 +- .../pipeline_animatediff_controlnet.py | 32 +- .../pipeline_animatediff_video2video.py | 38 +- src/diffusers/pipelines/free_noise_utils.py | 236 +++++++++++++ .../pag/pipeline_pag_sd_animatediff.py | 38 +- .../pipelines/animatediff/test_animatediff.py | 59 ++++ .../test_animatediff_controlnet.py | 59 ++++ .../test_animatediff_video2video.py | 69 +++- tests/pipelines/pag/test_pag_animatediff.py | 59 ++++ 11 files changed, 911 insertions(+), 50 deletions(-) create mode 100644 src/diffusers/pipelines/free_noise_utils.py diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 74a9d9efc6bd..e6858d842cbb 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Tuple import torch import torch.nn.functional as F @@ -272,6 +272,17 @@ def __init__( attention_out_bias: bool = True, ): super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings self.only_cross_attention = only_cross_attention # We keep these boolean flags for backward-compatibility. @@ -782,6 +793,319 @@ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): return hidden_states +@maybe_allow_in_graph +class FreeNoiseTransformerBlock(nn.Module): + r""" + A FreeNoise Transformer block. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + cross_attention_dim (`int`, *optional*): + The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (`int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (`bool`, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, defaults to `False`): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, defaults to `False`): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, defaults to `False`): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + ff_inner_dim (`int`, *optional*): + Hidden dimension of feed-forward MLP. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in feed-forward MLP. + attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in attention output project layer. + context_length (`int`, defaults to `16`): + The maximum number of frames that the FreeNoise block processes at once. + context_stride (`int`, defaults to `4`): + The number of frames to be skipped before starting to process a new batch of `context_length` frames. + weighting_scheme (`str`, defaults to `"pyramid"`): + The weighting scheme to use for weighting averaging of processed latent frames. As described in the + Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting + used. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-5, + final_dropout: bool = False, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + context_length: int = 16, + context_stride: int = 4, + weighting_scheme: str = "pyramid", + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.dropout = dropout + self.cross_attention_dim = cross_attention_dim + self.activation_fn = activation_fn + self.attention_bias = attention_bias + self.double_self_attention = double_self_attention + self.norm_elementwise_affine = norm_elementwise_affine + self.positional_embeddings = positional_embeddings + self.num_positional_embeddings = num_positional_embeddings + self.only_cross_attention = only_cross_attention + + self.set_free_noise_properties(context_length, context_stride, weighting_scheme) + + # We keep these boolean flags for backward-compatibility. + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]: + frame_indices = [] + for i in range(0, num_frames - self.context_length + 1, self.context_stride): + window_start = i + window_end = min(num_frames, i + self.context_length) + frame_indices.append((window_start, window_end)) + return frame_indices + + def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]: + if weighting_scheme == "pyramid": + if num_frames % 2 == 0: + # num_frames = 4 => [1, 2, 2, 1] + weights = list(range(1, num_frames // 2 + 1)) + weights = weights + weights[::-1] + else: + # num_frames = 5 => [1, 2, 3, 2, 1] + weights = list(range(1, num_frames // 2 + 1)) + weights = weights + [num_frames // 2 + 1] + weights[::-1] + else: + raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}") + + return weights + + def set_free_noise_properties( + self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid" + ) -> None: + self.context_length = context_length + self.context_stride = context_stride + self.weighting_scheme = weighting_scheme + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None: + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + # hidden_states: [B x H x W, F, C] + device = hidden_states.device + dtype = hidden_states.dtype + + num_frames = hidden_states.size(1) + frame_indices = self._get_frame_indices(num_frames) + frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme) + frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1) + is_last_frame_batch_complete = frame_indices[-1][1] == num_frames + + # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length + # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges: + # [(0, 16), (4, 20), (8, 24), (10, 26)] + if not is_last_frame_batch_complete: + if num_frames < self.context_length: + raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}") + last_frame_batch_length = num_frames - frame_indices[-1][1] + frame_indices.append((num_frames - self.context_length, num_frames)) + + num_times_accumulated = torch.zeros((1, num_frames, 1), device=device) + accumulated_values = torch.zeros_like(hidden_states) + + for i, (frame_start, frame_end) in enumerate(frame_indices): + # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle + # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or + # essentially a non-multiple of `context_length`. + weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end]) + weights *= frame_weights + + hidden_states_chunk = hidden_states[:, frame_start:frame_end] + + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states_chunk) + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + hidden_states_chunk = attn_output + hidden_states_chunk + if hidden_states_chunk.ndim == 4: + hidden_states_chunk = hidden_states_chunk.squeeze(1) + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states_chunk) + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states_chunk = attn_output + hidden_states_chunk + + if i == len(frame_indices) - 1 and not is_last_frame_batch_complete: + accumulated_values[:, -last_frame_batch_length:] += ( + hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:] + ) + num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length] + else: + accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights + num_times_accumulated[:, frame_start:frame_end] += weights + + hidden_states = torch.where( + num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values + ).to(dtype) + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + class FeedForward(nn.Module): r""" A feed-forward layer. diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py index e96867bc3ed0..7a9b7f2d5afe 100644 --- a/src/diffusers/models/unets/unet_motion_model.py +++ b/src/diffusers/models/unets/unet_motion_model.py @@ -343,6 +343,7 @@ def custom_forward(*inputs): else: hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, num_frames=num_frames) output_states = output_states + (hidden_states,) @@ -536,6 +537,7 @@ def custom_forward(*inputs): )[0] else: hidden_states = resnet(hidden_states, temb) + hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -761,6 +763,7 @@ def custom_forward(*inputs): )[0] else: hidden_states = resnet(hidden_states, temb) + hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -921,9 +924,9 @@ def custom_forward(*inputs): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) - else: hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, num_frames=num_frames) if self.upsamplers is not None: @@ -1923,7 +1926,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: """ Sets the attention processor to use [feed forward @@ -1953,7 +1955,6 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int for module in self.children(): fn_recursive_feed_forward(module, chunk_size, dim) - # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking def disable_forward_chunking(self) -> None: def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): if hasattr(module, "set_chunk_feed_forward"): diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 30d9eccf1c32..4c2ba7238cde 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -42,6 +42,7 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -72,6 +73,7 @@ class AnimateDiffPipeline( IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, + AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation. @@ -394,15 +396,20 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - image = self.vae.decode(latents).sample - video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + video = [] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video @@ -495,10 +502,21 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + shape = ( batch_size, num_channels_latents, @@ -506,11 +524,6 @@ def prepare_latents( height // self.vae_scale_factor, width // self.vae_scale_factor, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -569,6 +582,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + vae_batch_size: int = 16, **kwargs, ): r""" @@ -637,6 +651,8 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + vae_batch_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -808,7 +824,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 5574556a9b96..e30b29778129 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -30,6 +30,7 @@ from ...video_processor import VideoProcessor from ..controlnet.multicontrolnet import MultiControlNetModel from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -109,6 +110,7 @@ class AnimateDiffControlNetPipeline( IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, + AnimateDiffFreeNoiseMixin, ): r""" Pipeline for text-to-video generation with ControlNet guidance. @@ -432,15 +434,16 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - def decode_latents(self, latents, decode_batch_size: int = 16): + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) video = [] - for i in range(0, latents.shape[0], decode_batch_size): - batch_latents = latents[i : i + decode_batch_size] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] batch_latents = self.vae.decode(batch_latents).sample video.append(batch_latents) @@ -608,10 +611,22 @@ def check_inputs( if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + shape = ( batch_size, num_channels_latents, @@ -619,11 +634,6 @@ def prepare_latents( height // self.vae_scale_factor, width // self.vae_scale_factor, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -718,7 +728,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - decode_batch_size: int = 16, + vae_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -1054,7 +1064,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents, decode_batch_size) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 8129b88dc408..190170abd868 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -35,6 +35,7 @@ from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -176,6 +177,7 @@ class AnimateDiffVideoToVideoPipeline( IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, + AnimateDiffFreeNoiseMixin, ): r""" Pipeline for video-to-video generation. @@ -498,15 +500,29 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): + def encode_video(self, video, generator, vae_batch_size: int = 16) -> torch.Tensor: + latents = [] + for i in range(0, len(video), vae_batch_size): + batch_video = video[i : i + vae_batch_size] + batch_video = retrieve_latents(self.vae.encode(batch_video), generator=generator) + latents.append(batch_video) + return torch.cat(latents) + + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - image = self.vae.decode(latents).sample - video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + video = [] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video @@ -622,6 +638,7 @@ def prepare_latents( device, generator, latents=None, + vae_batch_size: int = 16, ): if latents is None: num_frames = video.shape[1] @@ -656,13 +673,10 @@ def prepare_latents( ) init_latents = [ - retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0) - for i in range(batch_size) + self.encode_video(video[i], generator[i], vae_batch_size).unsqueeze(0) for i in range(batch_size) ] else: - init_latents = [ - retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video - ] + init_latents = [self.encode_video(vid, generator, vae_batch_size).unsqueeze(0) for vid in video] init_latents = torch.cat(init_latents, dim=0) @@ -747,6 +761,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + vae_batch_size: int = 16, ): r""" The call function to the pipeline for generation. @@ -822,6 +837,8 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + vae_batch_size (`int`, defaults to `16`): + The number of frames to decode at a time when calling `decode_latents` method. Examples: @@ -923,6 +940,7 @@ def __call__( device=device, generator=generator, latents=latents, + vae_batch_size=vae_batch_size, ) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -990,7 +1008,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py new file mode 100644 index 000000000000..f8128abb9b58 --- /dev/null +++ b/src/diffusers/pipelines/free_noise_utils.py @@ -0,0 +1,236 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch + +from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock +from ..models.unets.unet_motion_model import ( + CrossAttnDownBlockMotion, + DownBlockMotion, + UpBlockMotion, +) +from ..utils import logging +from ..utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class AnimateDiffFreeNoiseMixin: + r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169).""" + + def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): + r"""Helper function to enable FreeNoise in transformer blocks.""" + + for motion_module in block.motion_modules: + num_transformer_blocks = len(motion_module.transformer_blocks) + + for i in range(num_transformer_blocks): + if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock): + motion_module.transformer_blocks[i].set_free_noise_properties( + self._free_noise_context_length, + self._free_noise_context_stride, + self._free_noise_weighting_scheme, + ) + else: + assert isinstance(motion_module.transformer_blocks[i], BasicTransformerBlock) + basic_transfomer_block = motion_module.transformer_blocks[i] + + motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock( + dim=basic_transfomer_block.dim, + num_attention_heads=basic_transfomer_block.num_attention_heads, + attention_head_dim=basic_transfomer_block.attention_head_dim, + dropout=basic_transfomer_block.dropout, + cross_attention_dim=basic_transfomer_block.cross_attention_dim, + activation_fn=basic_transfomer_block.activation_fn, + attention_bias=basic_transfomer_block.attention_bias, + only_cross_attention=basic_transfomer_block.only_cross_attention, + double_self_attention=basic_transfomer_block.double_self_attention, + positional_embeddings=basic_transfomer_block.positional_embeddings, + num_positional_embeddings=basic_transfomer_block.num_positional_embeddings, + context_length=self._free_noise_context_length, + context_stride=self._free_noise_context_stride, + weighting_scheme=self._free_noise_weighting_scheme, + ).to(device=self.device, dtype=self.dtype) + + motion_module.transformer_blocks[i].load_state_dict( + basic_transfomer_block.state_dict(), strict=True + ) + + def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]): + r"""Helper function to disable FreeNoise in transformer blocks.""" + + for motion_module in block.motion_modules: + num_transformer_blocks = len(motion_module.transformer_blocks) + + for i in range(num_transformer_blocks): + if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock): + free_noise_transfomer_block = motion_module.transformer_blocks[i] + + motion_module.transformer_blocks[i] = BasicTransformerBlock( + dim=free_noise_transfomer_block.dim, + num_attention_heads=free_noise_transfomer_block.num_attention_heads, + attention_head_dim=free_noise_transfomer_block.attention_head_dim, + dropout=free_noise_transfomer_block.dropout, + cross_attention_dim=free_noise_transfomer_block.cross_attention_dim, + activation_fn=free_noise_transfomer_block.activation_fn, + attention_bias=free_noise_transfomer_block.attention_bias, + only_cross_attention=free_noise_transfomer_block.only_cross_attention, + double_self_attention=free_noise_transfomer_block.double_self_attention, + positional_embeddings=free_noise_transfomer_block.positional_embeddings, + num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings, + ).to(device=self.device, dtype=self.dtype) + + motion_module.transformer_blocks[i].load_state_dict( + free_noise_transfomer_block.state_dict(), strict=True + ) + + def _prepare_latents_free_noise( + self, + batch_size: int, + num_channels_latents: int, + num_frames: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + context_num_frames = ( + self._free_noise_context_length if self._free_noise_context_length == "repeat_context" else num_frames + ) + + shape = ( + batch_size, + num_channels_latents, + context_num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if self._free_noise_noise_type == "random": + return latents + else: + if latents.size(2) == num_frames: + return latents + elif latents.size(2) != self._free_noise_context_length: + raise ValueError( + f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(2)}" + ) + latents = latents.to(device) + + if self._free_noise_noise_type == "shuffle_context": + for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride): + # ensure window is within bounds + window_start = max(0, i - self._free_noise_context_length) + window_end = min(num_frames, window_start + self._free_noise_context_stride) + window_length = window_end - window_start + + if window_length == 0: + break + + indices = torch.LongTensor(list(range(window_start, window_end))) + shuffled_indices = indices[torch.randperm(window_length, generator=generator)] + + current_start = i + current_end = min(num_frames, current_start + window_length) + if current_end == current_start + window_length: + # batch of frames perfectly fits the window + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + else: + # handle the case where the last batch of frames does not fit perfectly with the window + prefix_length = current_end - current_start + shuffled_indices = shuffled_indices[:prefix_length] + latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices] + + elif self._free_noise_noise_type == "repeat_context": + num_repeats = (num_frames + self._free_noise_context_length - 1) // self._free_noise_context_length + latents = torch.cat([latents] * num_repeats, dim=2) + + latents = latents[:, :, :num_frames] + return latents + + def enable_free_noise( + self, + context_length: Optional[int] = 16, + context_stride: int = 4, + weighting_scheme: str = "pyramid", + noise_type: str = "shuffle_context", + ) -> None: + r""" + Enable long video generation using FreeNoise. + + Args: + context_length (`int`, defaults to `16`, *optional*): + The number of video frames to process at once. It's recommended to set this to the maximum frames the + Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion + adapter config is used. + context_stride (`int`, *optional*): + Long videos are generated by processing many frames. FreeNoise processes these frames in sliding + windows of size `context_length`. Context stride allows you to specify how many frames to skip between + each window. For example, a context length of 16 and context stride of 4 would process 24 frames as: + [0, 15], [4, 19], [8, 23] (0-based indexing) + weighting_scheme (`str`, defaults to `pyramid`): + Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting + schemes are supported currently: + - "pyramid" + Peforms weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1]. + noise_type (`str`, defaults to "shuffle_context"): + TODO + """ + + allowed_weighting_scheme = ["pyramid"] + allowed_noise_type = ["shuffle_context", "repeat_context", "random"] + + if context_length > self.motion_adapter.config.motion_max_seq_length: + logger.warning( + f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results." + ) + if weighting_scheme not in allowed_weighting_scheme: + raise ValueError( + f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}" + ) + if noise_type not in allowed_noise_type: + raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}") + + self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length + self._free_noise_context_stride = context_stride + self._free_noise_weighting_scheme = weighting_scheme + self._free_noise_noise_type = noise_type + + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + self._enable_free_noise_in_block(block) + + def disable_free_noise(self) -> None: + self._free_noise_context_length = None + + blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks] + for block in blocks: + self._disable_free_noise_in_block(block) + + @property + def free_noise_enabled(self): + return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py index b3b103742061..dbafb32b3a28 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -35,6 +35,7 @@ from ...video_processor import VideoProcessor from ..animatediff.pipeline_output import AnimateDiffPipelineOutput from ..free_init_utils import FreeInitMixin +from ..free_noise_utils import AnimateDiffFreeNoiseMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pag_utils import PAGMixin @@ -83,6 +84,7 @@ class AnimateDiffPAGPipeline( IPAdapterMixin, StableDiffusionLoraLoaderMixin, FreeInitMixin, + AnimateDiffFreeNoiseMixin, PAGMixin, ): r""" @@ -404,15 +406,21 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents - def decode_latents(self, latents): + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.decode_latents + def decode_latents(self, latents, vae_batch_size: int = 16): latents = 1 / self.vae.config.scaling_factor * latents batch_size, channels, num_frames, height, width = latents.shape latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) - image = self.vae.decode(latents).sample - video = image[None, :].reshape((batch_size, num_frames, -1) + image.shape[2:]).permute(0, 2, 1, 3, 4) + video = [] + for i in range(0, latents.shape[0], vae_batch_size): + batch_latents = latents[i : i + vae_batch_size] + batch_latents = self.vae.decode(batch_latents).sample + video.append(batch_latents) + + video = torch.cat(video) + video = video[None, :].reshape((batch_size, num_frames, -1) + video.shape[2:]).permute(0, 2, 1, 3, 4) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 video = video.float() return video @@ -499,10 +507,22 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.AnimateDiffPipeline.prepare_latents def prepare_latents( self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None ): + # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169) + if self.free_noise_enabled: + latents = self._prepare_latents_free_noise( + batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + shape = ( batch_size, num_channels_latents, @@ -510,11 +530,6 @@ def prepare_latents( height // self.vae_scale_factor, width // self.vae_scale_factor, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -573,6 +588,7 @@ def __call__( clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + vae_batch_size: int = 16, pag_scale: float = 3.0, pag_adaptive_scale: float = 0.0, ): @@ -831,7 +847,7 @@ def __call__( if output_type == "latent": video = latents else: - video_tensor = self.decode_latents(latents) + video_tensor = self.decode_latents(latents, vae_batch_size) video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index dd636b7ce669..1354ac9ff1a8 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -17,6 +17,7 @@ UNet2DConditionModel, UNetMotionModel, ) +from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available, logging from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device @@ -401,6 +402,64 @@ def test_free_init_with_schedulers(self): "Enabling of FreeInit should lead to results different from the default pipeline results", ) + def test_free_noise_blocks(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertTrue( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + pipe.disable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertFalse( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.", + ) + + def test_free_noise(self): + components = self.get_dummy_components() + pipe: AnimateDiffPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + for context_length in [8, 9]: + for context_stride in [4, 6]: + pipe.enable_free_noise(context_length, context_stride) + + inputs_enable_free_noise = self.get_dummy_inputs(torch_device) + frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] + + pipe.disable_free_noise() + + inputs_disable_free_noise = self.get_dummy_inputs(torch_device) + frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max() + self.assertGreater( + sum_enabled, + 1e1, + "Enabling of FreeNoise should lead to results different from the default pipeline results", + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeNoise should lead to results similar to the default pipeline results", + ) + @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed", diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py index f46e514f9ac0..72315bd0c965 100644 --- a/tests/pipelines/animatediff/test_animatediff_controlnet.py +++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py @@ -18,6 +18,7 @@ UNet2DConditionModel, UNetMotionModel, ) +from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import logging from diffusers.utils.testing_utils import torch_device @@ -409,6 +410,64 @@ def test_free_init_with_schedulers(self): "Enabling of FreeInit should lead to results different from the default pipeline results", ) + def test_free_noise_blocks(self): + components = self.get_dummy_components() + pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertTrue( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + pipe.disable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertFalse( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.", + ) + + def test_free_noise(self): + components = self.get_dummy_components() + pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16) + frames_normal = pipe(**inputs_normal).frames[0] + + for context_length in [8, 9]: + for context_stride in [4, 6]: + pipe.enable_free_noise(context_length, context_stride) + + inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) + frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] + + pipe.disable_free_noise() + + inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) + frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max() + self.assertGreater( + sum_enabled, + 1e1, + "Enabling of FreeNoise should lead to results different from the default pipeline results", + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeNoise should lead to results similar to the default pipeline results", + ) + def test_vae_slicing(self, video_count=2): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index ced042b4a702..cd33bf0891a5 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -17,6 +17,7 @@ UNet2DConditionModel, UNetMotionModel, ) +from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available, logging from diffusers.utils.testing_utils import torch_device @@ -114,7 +115,7 @@ def get_dummy_components(self): } return components - def get_dummy_inputs(self, device, seed=0): + def get_dummy_inputs(self, device, seed=0, num_frames: int = 2): if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: @@ -122,8 +123,7 @@ def get_dummy_inputs(self, device, seed=0): video_height = 32 video_width = 32 - video_num_frames = 2 - video = [Image.new("RGB", (video_width, video_height))] * video_num_frames + video = [Image.new("RGB", (video_width, video_height))] * num_frames inputs = { "video": video, @@ -428,3 +428,66 @@ def test_free_init_with_schedulers(self): 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results", ) + + def test_free_noise_blocks(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertTrue( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + pipe.disable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertFalse( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.", + ) + + def test_free_noise(self): + components = self.get_dummy_components() + pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_normal["num_inference_steps"] = 2 + inputs_normal["strength"] = 0.5 + frames_normal = pipe(**inputs_normal).frames[0] + + for context_length in [8, 9]: + for context_stride in [4, 6]: + pipe.enable_free_noise(context_length, context_stride) + + inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_enable_free_noise["num_inference_steps"] = 2 + inputs_enable_free_noise["strength"] = 0.5 + frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] + + pipe.disable_free_noise() + inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16) + inputs_disable_free_noise["num_inference_steps"] = 2 + inputs_disable_free_noise["strength"] = 0.5 + frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max() + self.assertGreater( + sum_enabled, + 1e1, + "Enabling of FreeNoise should lead to results different from the default pipeline results", + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeNoise should lead to results similar to the default pipeline results", + ) diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py index 2cd80921e932..6854fb8b9a2e 100644 --- a/tests/pipelines/pag/test_pag_animatediff.py +++ b/tests/pipelines/pag/test_pag_animatediff.py @@ -17,6 +17,7 @@ UNet2DConditionModel, UNetMotionModel, ) +from diffusers.models.attention import FreeNoiseTransformerBlock from diffusers.utils import is_xformers_available from diffusers.utils.testing_utils import torch_device @@ -347,6 +348,64 @@ def test_free_init_with_schedulers(self): "Enabling of FreeInit should lead to results different from the default pipeline results", ) + def test_free_noise_blocks(self): + components = self.get_dummy_components() + pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + pipe.enable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertTrue( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.", + ) + + pipe.disable_free_noise() + for block in pipe.unet.down_blocks: + for motion_module in block.motion_modules: + for transformer_block in motion_module.transformer_blocks: + self.assertFalse( + isinstance(transformer_block, FreeNoiseTransformerBlock), + "Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.", + ) + + def test_free_noise(self): + components = self.get_dummy_components() + pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + inputs_normal = self.get_dummy_inputs(torch_device) + frames_normal = pipe(**inputs_normal).frames[0] + + for context_length in [8, 9]: + for context_stride in [4, 6]: + pipe.enable_free_noise(context_length, context_stride) + + inputs_enable_free_noise = self.get_dummy_inputs(torch_device) + frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0] + + pipe.disable_free_noise() + + inputs_disable_free_noise = self.get_dummy_inputs(torch_device) + frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0] + + sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum() + max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max() + self.assertGreater( + sum_enabled, + 1e1, + "Enabling of FreeNoise should lead to results different from the default pipeline results", + ) + self.assertLess( + max_diff_disabled, + 1e-4, + "Disabling of FreeNoise should lead to results similar to the default pipeline results", + ) + @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed",