diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 6ac66db73026..134a127d4320 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -429,7 +429,7 @@ - local: api/pipelines/ledits_pp title: LEDITS++ - local: api/pipelines/ltx_video - title: LTX + title: LTXVideo - local: api/pipelines/lumina title: Lumina-T2X - local: api/pipelines/marigold diff --git a/docs/source/en/api/models/autoencoderkl_ltx_video.md b/docs/source/en/api/models/autoencoderkl_ltx_video.md index 694b5ace6fdf..fbdb11e29cdd 100644 --- a/docs/source/en/api/models/autoencoderkl_ltx_video.md +++ b/docs/source/en/api/models/autoencoderkl_ltx_video.md @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet. ```python from diffusers import AutoencoderKLLTXVideo -vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda") +vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda") ``` ## AutoencoderKLLTXVideo diff --git a/docs/source/en/api/models/ltx_video_transformer3d.md b/docs/source/en/api/models/ltx_video_transformer3d.md index 8a60bc0432c6..fe2664cf685c 100644 --- a/docs/source/en/api/models/ltx_video_transformer3d.md +++ b/docs/source/en/api/models/ltx_video_transformer3d.md @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet. ```python from diffusers import LTXVideoTransformer3DModel -transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") +transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") ``` ## LTXVideoTransformer3DModel diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index ff202b980b95..a6cb943e09cc 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -28,7 +28,7 @@ from .vae import DecoderOutput, DiagonalGaussianDistribution -class LTXCausalConv3d(nn.Module): +class LTXVideoCausalConv3d(nn.Module): def __init__( self, in_channels: int, @@ -79,9 +79,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class LTXResnetBlock3d(nn.Module): +class LTXVideoResnetBlock3d(nn.Module): r""" - A 3D ResNet block used in the LTX model. + A 3D ResNet block used in the LTXVideo model. Args: in_channels (`int`): @@ -117,13 +117,13 @@ def __init__( self.nonlinearity = get_activation(non_linearity) self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) - self.conv1 = LTXCausalConv3d( + self.conv1 = LTXVideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal ) self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) self.dropout = nn.Dropout(dropout) - self.conv2 = LTXCausalConv3d( + self.conv2 = LTXVideoCausalConv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal ) @@ -131,7 +131,7 @@ def __init__( self.conv_shortcut = None if in_channels != out_channels: self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) - self.conv_shortcut = LTXCausalConv3d( + self.conv_shortcut = LTXVideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal ) @@ -157,7 +157,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return hidden_states -class LTXUpsampler3d(nn.Module): +class LTXVideoUpsampler3d(nn.Module): def __init__( self, in_channels: int, @@ -170,7 +170,7 @@ def __init__( out_channels = in_channels * stride[0] * stride[1] * stride[2] - self.conv = LTXCausalConv3d( + self.conv = LTXVideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, @@ -191,9 +191,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class LTXDownBlock3D(nn.Module): +class LTXVideoDownBlock3D(nn.Module): r""" - Down block used in the LTX model. + Down block used in the LTXVideo model. Args: in_channels (`int`): @@ -235,7 +235,7 @@ def __init__( resnets = [] for _ in range(num_layers): resnets.append( - LTXResnetBlock3d( + LTXVideoResnetBlock3d( in_channels=in_channels, out_channels=in_channels, dropout=dropout, @@ -250,7 +250,7 @@ def __init__( if spatio_temporal_scale: self.downsamplers = nn.ModuleList( [ - LTXCausalConv3d( + LTXVideoCausalConv3d( in_channels=in_channels, out_channels=in_channels, kernel_size=3, @@ -262,7 +262,7 @@ def __init__( self.conv_out = None if in_channels != out_channels: - self.conv_out = LTXResnetBlock3d( + self.conv_out = LTXVideoResnetBlock3d( in_channels=in_channels, out_channels=out_channels, dropout=dropout, @@ -300,9 +300,9 @@ def create_forward(*inputs): # Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d -class LTXMidBlock3d(nn.Module): +class LTXVideoMidBlock3d(nn.Module): r""" - A middle block used in the LTX model. + A middle block used in the LTXVideo model. Args: in_channels (`int`): @@ -335,7 +335,7 @@ def __init__( resnets = [] for _ in range(num_layers): resnets.append( - LTXResnetBlock3d( + LTXVideoResnetBlock3d( in_channels=in_channels, out_channels=in_channels, dropout=dropout, @@ -367,9 +367,9 @@ def create_forward(*inputs): return hidden_states -class LTXUpBlock3d(nn.Module): +class LTXVideoUpBlock3d(nn.Module): r""" - Up block used in the LTX model. + Up block used in the LTXVideo model. Args: in_channels (`int`): @@ -410,7 +410,7 @@ def __init__( self.conv_in = None if in_channels != out_channels: - self.conv_in = LTXResnetBlock3d( + self.conv_in = LTXVideoResnetBlock3d( in_channels=in_channels, out_channels=out_channels, dropout=dropout, @@ -421,12 +421,12 @@ def __init__( self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)]) + self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)]) resnets = [] for _ in range(num_layers): resnets.append( - LTXResnetBlock3d( + LTXVideoResnetBlock3d( in_channels=out_channels, out_channels=out_channels, dropout=dropout, @@ -463,9 +463,9 @@ def create_forward(*inputs): return hidden_states -class LTXEncoder3d(nn.Module): +class LTXVideoEncoder3d(nn.Module): r""" - The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent + The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent representation. Args: @@ -509,7 +509,7 @@ def __init__( output_channel = block_out_channels[0] - self.conv_in = LTXCausalConv3d( + self.conv_in = LTXVideoCausalConv3d( in_channels=self.in_channels, out_channels=output_channel, kernel_size=3, @@ -524,7 +524,7 @@ def __init__( input_channel = output_channel output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] - down_block = LTXDownBlock3D( + down_block = LTXVideoDownBlock3D( in_channels=input_channel, out_channels=output_channel, num_layers=layers_per_block[i], @@ -536,7 +536,7 @@ def __init__( self.down_blocks.append(down_block) # mid block - self.mid_block = LTXMidBlock3d( + self.mid_block = LTXVideoMidBlock3d( in_channels=output_channel, num_layers=layers_per_block[-1], resnet_eps=resnet_norm_eps, @@ -546,14 +546,14 @@ def __init__( # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXCausalConv3d( + self.conv_out = LTXVideoCausalConv3d( in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal ) self.gradient_checkpointing = False def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - r"""The forward method of the `LTXEncoder3D` class.""" + r"""The forward method of the `LTXVideoEncoder3d` class.""" p = self.patch_size p_t = self.patch_size_t @@ -599,9 +599,10 @@ def create_forward(*inputs): return hidden_states -class LTXDecoder3d(nn.Module): +class LTXVideoDecoder3d(nn.Module): r""" - The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample. + The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output + sample. Args: in_channels (`int`, defaults to 128): @@ -647,11 +648,11 @@ def __init__( layers_per_block = tuple(reversed(layers_per_block)) output_channel = block_out_channels[0] - self.conv_in = LTXCausalConv3d( + self.conv_in = LTXVideoCausalConv3d( in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal ) - self.mid_block = LTXMidBlock3d( + self.mid_block = LTXVideoMidBlock3d( in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal ) @@ -662,7 +663,7 @@ def __init__( input_channel = output_channel output_channel = block_out_channels[i] - up_block = LTXUpBlock3d( + up_block = LTXVideoUpBlock3d( in_channels=input_channel, out_channels=output_channel, num_layers=layers_per_block[i + 1], @@ -676,7 +677,7 @@ def __init__( # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXCausalConv3d( + self.conv_out = LTXVideoCausalConv3d( in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal ) @@ -777,7 +778,7 @@ def __init__( ) -> None: super().__init__() - self.encoder = LTXEncoder3d( + self.encoder = LTXVideoEncoder3d( in_channels=in_channels, out_channels=latent_channels, block_out_channels=block_out_channels, @@ -788,7 +789,7 @@ def __init__( resnet_norm_eps=resnet_norm_eps, is_causal=encoder_causal, ) - self.decoder = LTXDecoder3d( + self.decoder = LTXVideoDecoder3d( in_channels=latent_channels, out_channels=out_channels, block_out_channels=block_out_channels, @@ -837,7 +838,7 @@ def __init__( self.tile_sample_stride_width = 448 def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LTXEncoder3d, LTXDecoder3d)): + if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)): module.gradient_checkpointing = value def enable_tiling( diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 2ed8520a5d75..a895340bd124 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class LTXAttentionProcessor2_0: +class LTXVideoAttentionProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector. @@ -44,7 +44,7 @@ class LTXAttentionProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( - "LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." ) def __call__( @@ -92,7 +92,7 @@ def __call__( return hidden_states -class LTXRotaryPosEmbed(nn.Module): +class LTXVideoRotaryPosEmbed(nn.Module): def __init__( self, dim: int, @@ -164,7 +164,7 @@ def forward( @maybe_allow_in_graph -class LTXTransformerBlock(nn.Module): +class LTXVideoTransformerBlock(nn.Module): r""" Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video). @@ -208,7 +208,7 @@ def __init__( cross_attention_dim=None, out_bias=attention_out_bias, qk_norm=qk_norm, - processor=LTXAttentionProcessor2_0(), + processor=LTXVideoAttentionProcessor2_0(), ) self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) @@ -221,7 +221,7 @@ def __init__( bias=attention_bias, out_bias=attention_out_bias, qk_norm=qk_norm, - processor=LTXAttentionProcessor2_0(), + processor=LTXVideoAttentionProcessor2_0(), ) self.ff = FeedForward(dim, activation_fn=activation_fn) @@ -327,7 +327,7 @@ def __init__( self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) - self.rope = LTXRotaryPosEmbed( + self.rope = LTXVideoRotaryPosEmbed( dim=inner_dim, base_num_frames=20, base_height=2048, @@ -339,7 +339,7 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - LTXTransformerBlock( + LTXVideoTransformerBlock( dim=inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim,