diff --git a/python/src/diffusionkit/mlx/__init__.py b/python/src/diffusionkit/mlx/__init__.py
index af20441..dceee64 100644
--- a/python/src/diffusionkit/mlx/__init__.py
+++ b/python/src/diffusionkit/mlx/__init__.py
@@ -467,7 +467,7 @@ def generate_image(
         logger.info(
             f"Pre decode active memory: {log['decoding']['pre']['active_memory']}GB"
         )
-        latents = latents.astype(mx.float32)
+        latents = latents.astype(self.activation_dtype)
         decoded = self.decode_latents_to_image(latents)
         mx.eval(decoded)
 
diff --git a/python/src/diffusionkit/mlx/mmdit.py b/python/src/diffusionkit/mlx/mmdit.py
index daca926..df65b22 100644
--- a/python/src/diffusionkit/mlx/mmdit.py
+++ b/python/src/diffusionkit/mlx/mmdit.py
@@ -16,7 +16,7 @@
 
 logger = get_logger(__name__)
 
-SDPA_FLASH_ATTN_THRESHOLD = 1000
+SDPA_FLASH_ATTN_THRESHOLD = 1024
 
 
 class MMDiT(nn.Module):
@@ -218,8 +218,6 @@ def __call__(
                 timestep,
                 positional_encodings=positional_encodings,
             )
-        mx.eval(latent_image_embeddings)
-        mx.eval(token_level_text_embeddings)
 
         # UnifiedTransformerBlock layers
         if self.config.depth_unified > 0:
@@ -449,9 +447,10 @@ def pre_sdpa(
         # LayerNorm and modulate before SDPA
         try:
             modulated_pre_attention = affine_transform(
-                self.norm1(tensor),
+                tensor,
                 shift=post_norm1_shift,
                 residual_scale=post_norm1_residual_scale,
+                norm_module=self.norm1,
             )
         except Exception as e:
             logger.error(
@@ -531,9 +530,10 @@ def post_sdpa(
         # Apply separate modulation parameters and LayerNorm across attn and mlp
         mlp_out = self.mlp(
             affine_transform(
-                self.norm2(residual),
+                residual,
                 shift=post_norm2_shift,
                 residual_scale=post_norm2_residual_scale,
+                norm_module=self.norm2,
             )
         )
         return residual + post_mlp_scale * mlp_out
@@ -749,8 +749,9 @@ def __init__(self, head_dim):
         self.k_norm = nn.RMSNorm(head_dim, eps=1e-6)
 
     def __call__(self, q: mx.array, k: mx.array) -> Tuple[mx.array, mx.array]:
-        q = self.q_norm(q.astype(mx.float32))
-        k = self.k_norm(k.astype(mx.float32))
+        # Note: mlx.nn.RMSNorm has high precision accumulation (does not require upcasting)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         return q, k
 
 
@@ -778,9 +779,10 @@ def __call__(
 
         shift, residual_scale = mx.split(modulation_params, 2, axis=-1)
         latent_image_embeddings = affine_transform(
-            self.norm_final(latent_image_embeddings),
+            latent_image_embeddings,
             shift=shift,
             residual_scale=residual_scale,
+            norm_module=self.norm_final,
         )
         return self.linear(latent_image_embeddings)
 
@@ -932,9 +934,16 @@ def apply(q_or_k: mx.array, rope: mx.array) -> mx.array:
 
 
 def affine_transform(
-    x: mx.array, shift: mx.array, residual_scale: mx.array
+    x: mx.array,
+    shift: mx.array,
+    residual_scale: mx.array,
+    norm_module: nn.Module = None,
 ) -> mx.array:
     """Affine transformation (Used for Adaptive LayerNorm Modulation)"""
+    if norm_module is not None:
+        return mx.fast.layer_norm(
+            x, 1.0 + residual_scale.squeeze(), shift.squeeze(), norm_module.eps
+        )
     return x * (1.0 + residual_scale) + shift
 
 
diff --git a/python/src/diffusionkit/mlx/vae.py b/python/src/diffusionkit/mlx/vae.py
index f77fa19..48fd0bb 100644
--- a/python/src/diffusionkit/mlx/vae.py
+++ b/python/src/diffusionkit/mlx/vae.py
@@ -84,17 +84,15 @@ def __init__(
             self.conv_shortcut = nn.Linear(in_channels, out_channels)
 
     def __call__(self, x, temb=None):
-        dtype = x.dtype
-
         if temb is not None:
             temb = self.time_emb_proj(nn.silu(temb))
-        y = self.norm1(x.astype(mx.float32)).astype(dtype)
+        y = self.norm1(x)
         y = nn.silu(y)
         y = self.conv1(y)
 
         if temb is not None:
             y = y + temb[:, None, None, :]
 
-        y = self.norm2(y.astype(mx.float32)).astype(dtype)
+        y = self.norm2(y)
         y = nn.silu(y)
         y = self.conv2(y)
diff --git a/python/src/diffusionkit/mlx/vae.py b/python/src/diffusionkit/mlx/vae.py
--- a/python/src/diffusionkit/mlx/vae.py
+++ b/python/src/diffusionkit/mlx/vae.py
@@ -386,37 +384,19 @@ def __init__(
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
 
     def __call__(self, x):
-        t = x.dtype
         x = self.conv_in(x)
 
         x = self.mid_blocks[0](x)
-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after mid_blocks[0]")
-        x = x.astype(mx.float32)
         x = self.mid_blocks[1](x)
-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after mid_blocks[1]")
-        x = x.astype(t)
         x = self.mid_blocks[2](x)
-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after mid_blocks[2]")
 
         for l in reversed(self.up_blocks):
             x = l(x)
             mx.eval(x)
-            if mx.isnan(x).any():
-                raise ValueError("NaN detected in VAE Decoder after up_blocks")
-
-        x = x.astype(mx.float32)
         x = self.conv_norm_out(x)
-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after conv_norm_out")
-        x = x.astype(t)
         x = nn.silu(x)
         x = self.conv_out(x)
-        if mx.isnan(x).any():
-            raise ValueError("NaN detected in VAE Decoder after conv_out")
 
         return x
 
diff --git a/setup.py b/setup.py
index 246eac1..ce658ba 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 from setuptools import find_packages, setup
 from setuptools.command.install import install
 
-VERSION = "0.3.0"
+VERSION = "0.3.1"
 
 
 class VersionInstallCommand(install):