Skip to content

Commit

Permalink
fixed nested if
Browse files Browse the repository at this point in the history
  • Loading branch information
arda-argmax committed Aug 27, 2024
1 parent 294ea0f commit d6d85bf
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions python/src/diffusionkit/mlx/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,12 +940,11 @@ def affine_transform(
norm_module: nn.Module = None,
) -> mx.array:
"""Affine transformation (Used for Adaptive LayerNorm Modulation)"""
if x.shape[0] == 1:
if norm_module is not None:
return mx.fast.layer_norm(
x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps
)
if norm_module is not None:
if x.shape[0] == 1 and norm_module is not None:
return mx.fast.layer_norm(
x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps
)
elif norm_module is not None:
return norm_module(x) * (1.0 + residual_scale) + shift
else:
return x * (1.0 + residual_scale) + shift
Expand Down

0 comments on commit d6d85bf

Please sign in to comment.