From f05ac99eda15d5dee50b7911b2237676707f3239 Mon Sep 17 00:00:00 2001 From: arda-argmax Date: Tue, 13 Aug 2024 15:32:07 -0700 Subject: [PATCH] remove FIXMEs --- python/src/diffusionkit/mlx/__init__.py | 2 +- python/src/diffusionkit/mlx/mmdit.py | 5 +---- python/src/diffusionkit/mlx/tokenizer.py | 2 -- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/python/src/diffusionkit/mlx/__init__.py b/python/src/diffusionkit/mlx/__init__.py index 300576c..8e066bf 100644 --- a/python/src/diffusionkit/mlx/__init__.py +++ b/python/src/diffusionkit/mlx/__init__.py @@ -420,7 +420,7 @@ def generate_image( logger.info( f"Pre decode active memory: {log['decoding']['pre']['active_memory']}GB" ) - latents = latents.astype(mx.float32) # FIXME + latents = latents.astype(mx.float32) decoded = self.decode_latents_to_image(latents) mx.eval(decoded) diff --git a/python/src/diffusionkit/mlx/mmdit.py b/python/src/diffusionkit/mlx/mmdit.py index bc39098..c383b43 100644 --- a/python/src/diffusionkit/mlx/mmdit.py +++ b/python/src/diffusionkit/mlx/mmdit.py @@ -422,7 +422,6 @@ def rearrange_for_norm(t): if self.config.use_qk_norm: q, k = self.qk_norm(q, k) - # FIXME(arda): if not flux if self.config.depth_unified == 0: q = q.transpose(0, 2, 1, 3).reshape(batch, -1, 1, self.config.hidden_size) k = k.transpose(0, 2, 1, 3).reshape(batch, -1, 1, self.config.hidden_size) @@ -524,7 +523,6 @@ def rearrange_for_sdpa(t): batch, -1, self.config.num_heads, self.per_head_dim ).transpose(0, 2, 1, 3) - # FIXME(arda): if flux if self.config.depth_unified > 0: multimodal_sdpa_inputs = { "q": mx.concatenate( @@ -583,7 +581,6 @@ def rearrange_for_sdpa(t): img_seq_len = latent_image_embeddings.shape[1] txt_seq_len = token_level_text_embeddings.shape[1] - # FIXME(arda): if flux if self.config.depth_unified > 0: text_sdpa_output = sdpa_outputs[:, :txt_seq_len, :, :] image_sdpa_output = sdpa_outputs[:, txt_seq_len:, :, :] @@ -672,7 +669,7 @@ def rearrange_for_sdpa(t): .reshape(batch, -1, 1, 
self.config.hidden_size) ) - # FIXME(arda): update state dict later + # o_proj and mlp.fc2 use the same bias; remove mlp.fc2 bias self.transformer_block.mlp.fc2.bias = self.transformer_block.mlp.fc2.bias * 0.0 # Post-SDPA layers diff --git a/python/src/diffusionkit/mlx/tokenizer.py b/python/src/diffusionkit/mlx/tokenizer.py index 3fb61de..c5fb2ba 100644 --- a/python/src/diffusionkit/mlx/tokenizer.py +++ b/python/src/diffusionkit/mlx/tokenizer.py @@ -19,7 +19,6 @@ def __init__(self, bpe_ranks, vocab, pad_with_eos=False): regex.IGNORECASE, ) - # FIXME(arda): Make these configurable self.pad_to_max_length = True self.max_length = 77 @@ -119,7 +118,6 @@ def __init__(self, config: T5Config): model_max_length=getattr(config, "n_positions", 512), ) - # FIXME(arda): Make these configurable self.pad_to_max_length = True self.max_length = 77 self.pad_with_eos = False