Skip to content

Commit

Permalink
Rename LTX blocks and docs title (#10213)
Browse files Browse the repository at this point in the history
* rename blocks and docs

* fix docs

---------

Co-authored-by: Dhruv Nair <[email protected]>
  • Loading branch information
a-r-r-o-w and DN6 authored Dec 23, 2024
1 parent 055d955 commit 9d27df8
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 48 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/ltx_video
title: LTX
title: LTXVideo
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/api/models/autoencoderkl_ltx_video.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```

## AutoencoderKLLTXVideo
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/api/models/ltx_video_transformer3d.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import LTXVideoTransformer3DModel

transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```

## LTXVideoTransformer3DModel
Expand Down
75 changes: 38 additions & 37 deletions src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .vae import DecoderOutput, DiagonalGaussianDistribution


class LTXCausalConv3d(nn.Module):
class LTXVideoCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
Expand Down Expand Up @@ -79,9 +79,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class LTXResnetBlock3d(nn.Module):
class LTXVideoResnetBlock3d(nn.Module):
r"""
A 3D ResNet block used in the LTX model.
A 3D ResNet block used in the LTXVideo model.
Args:
in_channels (`int`):
Expand Down Expand Up @@ -117,21 +117,21 @@ def __init__(
self.nonlinearity = get_activation(non_linearity)

self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.conv1 = LTXCausalConv3d(
self.conv1 = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)

self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.dropout = nn.Dropout(dropout)
self.conv2 = LTXCausalConv3d(
self.conv2 = LTXVideoCausalConv3d(
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)

self.norm3 = None
self.conv_shortcut = None
if in_channels != out_channels:
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
self.conv_shortcut = LTXCausalConv3d(
self.conv_shortcut = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
)

Expand All @@ -157,7 +157,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return hidden_states


class LTXUpsampler3d(nn.Module):
class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
Expand All @@ -170,7 +170,7 @@ def __init__(

out_channels = in_channels * stride[0] * stride[1] * stride[2]

self.conv = LTXCausalConv3d(
self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
Expand All @@ -191,9 +191,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class LTXDownBlock3D(nn.Module):
class LTXVideoDownBlock3D(nn.Module):
r"""
Down block used in the LTX model.
Down block used in the LTXVideo model.
Args:
in_channels (`int`):
Expand Down Expand Up @@ -235,7 +235,7 @@ def __init__(
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
Expand All @@ -250,7 +250,7 @@ def __init__(
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList(
[
LTXCausalConv3d(
LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
Expand All @@ -262,7 +262,7 @@ def __init__(

self.conv_out = None
if in_channels != out_channels:
self.conv_out = LTXResnetBlock3d(
self.conv_out = LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
Expand Down Expand Up @@ -300,9 +300,9 @@ def create_forward(*inputs):


# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXMidBlock3d(nn.Module):
class LTXVideoMidBlock3d(nn.Module):
r"""
A middle block used in the LTX model.
A middle block used in the LTXVideo model.
Args:
in_channels (`int`):
Expand Down Expand Up @@ -335,7 +335,7 @@ def __init__(
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
Expand Down Expand Up @@ -367,9 +367,9 @@ def create_forward(*inputs):
return hidden_states


class LTXUpBlock3d(nn.Module):
class LTXVideoUpBlock3d(nn.Module):
r"""
Up block used in the LTX model.
Up block used in the LTXVideo model.
Args:
in_channels (`int`):
Expand Down Expand Up @@ -410,7 +410,7 @@ def __init__(

self.conv_in = None
if in_channels != out_channels:
self.conv_in = LTXResnetBlock3d(
self.conv_in = LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
Expand All @@ -421,12 +421,12 @@ def __init__(

self.upsamplers = None
if spatio_temporal_scale:
self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])

resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
LTXVideoResnetBlock3d(
in_channels=out_channels,
out_channels=out_channels,
dropout=dropout,
Expand Down Expand Up @@ -463,9 +463,9 @@ def create_forward(*inputs):
return hidden_states


class LTXEncoder3d(nn.Module):
class LTXVideoEncoder3d(nn.Module):
r"""
The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
representation.
Args:
Expand Down Expand Up @@ -509,7 +509,7 @@ def __init__(

output_channel = block_out_channels[0]

self.conv_in = LTXCausalConv3d(
self.conv_in = LTXVideoCausalConv3d(
in_channels=self.in_channels,
out_channels=output_channel,
kernel_size=3,
Expand All @@ -524,7 +524,7 @@ def __init__(
input_channel = output_channel
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]

down_block = LTXDownBlock3D(
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
Expand All @@ -536,7 +536,7 @@ def __init__(
self.down_blocks.append(down_block)

# mid block
self.mid_block = LTXMidBlock3d(
self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel,
num_layers=layers_per_block[-1],
resnet_eps=resnet_norm_eps,
Expand All @@ -546,14 +546,14 @@ def __init__(
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
self.conv_out = LTXCausalConv3d(
self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
)

self.gradient_checkpointing = False

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `LTXEncoder3D` class."""
r"""The forward method of the `LTXVideoEncoder3d` class."""

p = self.patch_size
p_t = self.patch_size_t
Expand Down Expand Up @@ -599,9 +599,10 @@ def create_forward(*inputs):
return hidden_states


class LTXDecoder3d(nn.Module):
class LTXVideoDecoder3d(nn.Module):
r"""
The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
sample.
Args:
in_channels (`int`, defaults to 128):
Expand Down Expand Up @@ -647,11 +648,11 @@ def __init__(
layers_per_block = tuple(reversed(layers_per_block))
output_channel = block_out_channels[0]

self.conv_in = LTXCausalConv3d(
self.conv_in = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
)

self.mid_block = LTXMidBlock3d(
self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
)

Expand All @@ -662,7 +663,7 @@ def __init__(
input_channel = output_channel
output_channel = block_out_channels[i]

up_block = LTXUpBlock3d(
up_block = LTXVideoUpBlock3d(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i + 1],
Expand All @@ -676,7 +677,7 @@ def __init__(
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
self.conv_out = LTXCausalConv3d(
self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
)

Expand Down Expand Up @@ -777,7 +778,7 @@ def __init__(
) -> None:
super().__init__()

self.encoder = LTXEncoder3d(
self.encoder = LTXVideoEncoder3d(
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
Expand All @@ -788,7 +789,7 @@ def __init__(
resnet_norm_eps=resnet_norm_eps,
is_causal=encoder_causal,
)
self.decoder = LTXDecoder3d(
self.decoder = LTXVideoDecoder3d(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
Expand Down Expand Up @@ -837,7 +838,7 @@ def __init__(
self.tile_sample_stride_width = 448

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
module.gradient_checkpointing = value

def enable_tiling(
Expand Down
16 changes: 8 additions & 8 deletions src/diffusers/models/transformers/transformer_ltx.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class LTXAttentionProcessor2_0:
class LTXVideoAttentionProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
Expand All @@ -44,7 +44,7 @@ class LTXAttentionProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

def __call__(
Expand Down Expand Up @@ -92,7 +92,7 @@ def __call__(
return hidden_states


class LTXRotaryPosEmbed(nn.Module):
class LTXVideoRotaryPosEmbed(nn.Module):
def __init__(
self,
dim: int,
Expand Down Expand Up @@ -164,7 +164,7 @@ def forward(


@maybe_allow_in_graph
class LTXTransformerBlock(nn.Module):
class LTXVideoTransformerBlock(nn.Module):
r"""
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
Expand Down Expand Up @@ -208,7 +208,7 @@ def __init__(
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXAttentionProcessor2_0(),
processor=LTXVideoAttentionProcessor2_0(),
)

self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
Expand All @@ -221,7 +221,7 @@ def __init__(
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXAttentionProcessor2_0(),
processor=LTXVideoAttentionProcessor2_0(),
)

self.ff = FeedForward(dim, activation_fn=activation_fn)
Expand Down Expand Up @@ -327,7 +327,7 @@ def __init__(

self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

self.rope = LTXRotaryPosEmbed(
self.rope = LTXVideoRotaryPosEmbed(
dim=inner_dim,
base_num_frames=20,
base_height=2048,
Expand All @@ -339,7 +339,7 @@ def __init__(

self.transformer_blocks = nn.ModuleList(
[
LTXTransformerBlock(
LTXVideoTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
Expand Down

0 comments on commit 9d27df8

Please sign in to comment.