Rename LTX blocks and docs title #10213

Merged
merged 5 commits on Dec 23, 2024
Changes from all commits
docs/source/en/_toctree.yml (2 changes: 1 addition & 1 deletion)
@@ -429,7 +429,7 @@
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/ltx_video
-title: LTX
+title: LTXVideo
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
docs/source/en/api/models/autoencoderkl_ltx_video.md (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
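For orientation (not part of this diff), here is a minimal round-trip sketch of using the loaded VAE. The snippet above relies on `import torch` being in scope, so the sketch adds it explicitly; the input shape and frame count are illustrative assumptions following the usual diffusers VAE interface:

```python
import torch
from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32
).to("cuda")

# Illustrative input: one RGB video, 9 frames (the causal conv expects 8k + 1), 512x512.
video = torch.randn(1, 3, 9, 512, 512, dtype=torch.float32, device="cuda")

with torch.no_grad():
    latent_dist = vae.encode(video).latent_dist  # DiagonalGaussianDistribution
    latents = latent_dist.sample()
    reconstruction = vae.decode(latents).sample  # back to pixel space
```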

## AutoencoderKLLTXVideo
docs/source/en/api/models/ltx_video_transformer3d.md (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
from diffusers import LTXVideoTransformer3DModel

-transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
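In practice the transformer is driven through the video pipeline rather than standalone. A hedged end-to-end sketch, assuming the `LTXPipeline` text-to-video entry point from the same release (prompt, resolution, and frame count are illustrative):

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

# Illustrative settings: LTX-Video works with resolutions divisible by 32
# and frame counts of the form 8k + 1.
video = pipe(
    prompt="A woman with long brown hair walks along a beach at sunset",
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```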

## LTXVideoTransformer3DModel
src/diffusers/models/autoencoders/autoencoder_kl_ltx.py (75 changes: 38 additions & 37 deletions)
@@ -28,7 +28,7 @@
from .vae import DecoderOutput, DiagonalGaussianDistribution


-class LTXCausalConv3d(nn.Module):
+class LTXVideoCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
@@ -79,9 +79,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


-class LTXResnetBlock3d(nn.Module):
+class LTXVideoResnetBlock3d(nn.Module):
r"""
-A 3D ResNet block used in the LTX model.
+A 3D ResNet block used in the LTXVideo model.

Args:
in_channels (`int`):
@@ -117,21 +117,21 @@ def __init__(
self.nonlinearity = get_activation(non_linearity)

self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
-self.conv1 = LTXCausalConv3d(
+self.conv1 = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)

self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.dropout = nn.Dropout(dropout)
-self.conv2 = LTXCausalConv3d(
+self.conv2 = LTXVideoCausalConv3d(
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)

self.norm3 = None
self.conv_shortcut = None
if in_channels != out_channels:
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
-self.conv_shortcut = LTXCausalConv3d(
+self.conv_shortcut = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
)

@@ -157,7 +157,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return hidden_states


-class LTXUpsampler3d(nn.Module):
+class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
@@ -170,7 +170,7 @@ def __init__(

out_channels = in_channels * stride[0] * stride[1] * stride[2]

-self.conv = LTXCausalConv3d(
+self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
@@ -191,9 +191,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


-class LTXDownBlock3D(nn.Module):
+class LTXVideoDownBlock3D(nn.Module):
r"""
-Down block used in the LTX model.
+Down block used in the LTXVideo model.

Args:
in_channels (`int`):
@@ -235,7 +235,7 @@ def __init__(
resnets = []
for _ in range(num_layers):
resnets.append(
-LTXResnetBlock3d(
+LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
@@ -250,7 +250,7 @@ def __init__(
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList(
[
-LTXCausalConv3d(
+LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
@@ -262,7 +262,7 @@

self.conv_out = None
if in_channels != out_channels:
-self.conv_out = LTXResnetBlock3d(
+self.conv_out = LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
@@ -300,9 +300,9 @@ def create_forward(*inputs):


# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
-class LTXMidBlock3d(nn.Module):
+class LTXVideoMidBlock3d(nn.Module):
r"""
-A middle block used in the LTX model.
+A middle block used in the LTXVideo model.

Args:
in_channels (`int`):
@@ -335,7 +335,7 @@ def __init__(
resnets = []
for _ in range(num_layers):
resnets.append(
-LTXResnetBlock3d(
+LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
@@ -367,9 +367,9 @@ def create_forward(*inputs):
return hidden_states


-class LTXUpBlock3d(nn.Module):
+class LTXVideoUpBlock3d(nn.Module):
r"""
-Up block used in the LTX model.
+Up block used in the LTXVideo model.

Args:
in_channels (`int`):
@@ -410,7 +410,7 @@ def __init__(

self.conv_in = None
if in_channels != out_channels:
-self.conv_in = LTXResnetBlock3d(
+self.conv_in = LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
@@ -421,12 +421,12 @@ def __init__(

self.upsamplers = None
if spatio_temporal_scale:
-self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
+self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])

resnets = []
for _ in range(num_layers):
resnets.append(
-LTXResnetBlock3d(
+LTXVideoResnetBlock3d(
in_channels=out_channels,
out_channels=out_channels,
dropout=dropout,
@@ -463,9 +463,9 @@ def create_forward(*inputs):
return hidden_states


-class LTXEncoder3d(nn.Module):
+class LTXVideoEncoder3d(nn.Module):
r"""
-The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
+The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
representation.

Args:
@@ -509,7 +509,7 @@ def __init__(

output_channel = block_out_channels[0]

-self.conv_in = LTXCausalConv3d(
+self.conv_in = LTXVideoCausalConv3d(
in_channels=self.in_channels,
out_channels=output_channel,
kernel_size=3,
@@ -524,7 +524,7 @@ def __init__(
input_channel = output_channel
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]

-down_block = LTXDownBlock3D(
+down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
@@ -536,7 +536,7 @@ def __init__(
self.down_blocks.append(down_block)

# mid block
-self.mid_block = LTXMidBlock3d(
+self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel,
num_layers=layers_per_block[-1],
resnet_eps=resnet_norm_eps,
@@ -546,14 +546,14 @@ def __init__(
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
-self.conv_out = LTXCausalConv3d(
+self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
)

self.gradient_checkpointing = False

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `LTXEncoder3D` class."""
r"""The forward method of the `LTXVideoEncoder3d` class."""

p = self.patch_size
p_t = self.patch_size_t
@@ -599,9 +599,10 @@ def create_forward(*inputs):
return hidden_states


-class LTXDecoder3d(nn.Module):
+class LTXVideoDecoder3d(nn.Module):
r"""
-The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
+The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
+sample.

Args:
in_channels (`int`, defaults to 128):
@@ -647,11 +648,11 @@ def __init__(
layers_per_block = tuple(reversed(layers_per_block))
output_channel = block_out_channels[0]

-self.conv_in = LTXCausalConv3d(
+self.conv_in = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
)

-self.mid_block = LTXMidBlock3d(
+self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
)

@@ -662,7 +663,7 @@ def __init__(
input_channel = output_channel
output_channel = block_out_channels[i]

-up_block = LTXUpBlock3d(
+up_block = LTXVideoUpBlock3d(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i + 1],
@@ -676,7 +677,7 @@ def __init__(
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
-self.conv_out = LTXCausalConv3d(
+self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
)

@@ -777,7 +778,7 @@ def __init__(
) -> None:
super().__init__()

-self.encoder = LTXEncoder3d(
+self.encoder = LTXVideoEncoder3d(
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
@@ -788,7 +789,7 @@ def __init__(
resnet_norm_eps=resnet_norm_eps,
is_causal=encoder_causal,
)
-self.decoder = LTXDecoder3d(
+self.decoder = LTXVideoDecoder3d(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
@@ -837,7 +838,7 @@ def __init__(
self.tile_sample_stride_width = 448

def _set_gradient_checkpointing(self, module, value=False):
-if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
+if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
module.gradient_checkpointing = value

def enable_tiling(
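Because the rename above drops the old `LTX*` names without aliases, any downstream code importing them directly will break. If backward compatibility were wanted, a deprecation shim along these lines could be added; this is a hypothetical sketch, not part of this PR:

```python
import warnings

# Hypothetical shim (not in this PR): keep the old name importable while
# steering callers toward the renamed class.
class LTXCausalConv3d(LTXVideoCausalConv3d):
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "`LTXCausalConv3d` is deprecated; use `LTXVideoCausalConv3d` instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)
```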
src/diffusers/models/transformers/transformer_ltx.py (16 changes: 8 additions & 8 deletions)
@@ -35,7 +35,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


-class LTXAttentionProcessor2_0:
+class LTXVideoAttentionProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
@@ -44,7 +44,7 @@ class LTXAttentionProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

def __call__(
@@ -92,7 +92,7 @@ def __call__(
return hidden_states


-class LTXRotaryPosEmbed(nn.Module):
+class LTXVideoRotaryPosEmbed(nn.Module):
def __init__(
self,
dim: int,
@@ -164,7 +164,7 @@ def forward(


@maybe_allow_in_graph
-class LTXTransformerBlock(nn.Module):
+class LTXVideoTransformerBlock(nn.Module):
r"""
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

@@ -208,7 +208,7 @@ def __init__(
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
-processor=LTXAttentionProcessor2_0(),
+processor=LTXVideoAttentionProcessor2_0(),
)

self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -221,7 +221,7 @@ def __init__(
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
-processor=LTXAttentionProcessor2_0(),
+processor=LTXVideoAttentionProcessor2_0(),
)

self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -327,7 +327,7 @@ def __init__(

self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

-self.rope = LTXRotaryPosEmbed(
+self.rope = LTXVideoRotaryPosEmbed(
dim=inner_dim,
base_num_frames=20,
base_height=2048,
@@ -339,7 +339,7 @@

self.transformer_blocks = nn.ModuleList(
[
-LTXTransformerBlock(
+LTXVideoTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
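A quick way to sanity-check a rename like this from a source checkout is an import smoke test. The module paths below come from the file headers in this diff, but they are internal and may move in later releases, so treat this as a sketch:

```python
# Smoke test: every renamed LTXVideo* class should import cleanly.
from diffusers.models.autoencoders.autoencoder_kl_ltx import (
    LTXVideoCausalConv3d,
    LTXVideoDecoder3d,
    LTXVideoEncoder3d,
    LTXVideoMidBlock3d,
    LTXVideoResnetBlock3d,
)
from diffusers.models.transformers.transformer_ltx import (
    LTXVideoAttentionProcessor2_0,
    LTXVideoRotaryPosEmbed,
    LTXVideoTransformerBlock,
)

print("All renamed LTXVideo* classes import correctly.")
```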