remove FIXMEs
arda-argmax committed Aug 13, 2024
1 parent ca313d7 commit f05ac99
Showing 3 changed files with 2 additions and 7 deletions.
python/src/diffusionkit/mlx/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -420,7 +420,7 @@ def generate_image(
 logger.info(
     f"Pre decode active memory: {log['decoding']['pre']['active_memory']}GB"
 )
-latents = latents.astype(mx.float32)  # FIXME
+latents = latents.astype(mx.float32)
 decoded = self.decode_latents_to_image(latents)
 mx.eval(decoded)

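Note: the line this hunk keeps upcasts the denoised latents before VAE decoding; dropping the trailing FIXME marks the cast as intentional. A minimal sketch of the pattern, with an assumed latent shape (denoising often runs in float16, while decoding is more numerically sensitive, so latents are upcast first):

    import mlx.core as mx

    # Hypothetical latent shape; the real one depends on the model config.
    latents = mx.random.normal((1, 64, 64, 16)).astype(mx.float16)

    latents = latents.astype(mx.float32)  # upcast before decoding, as above
    mx.eval(latents)                      # force MLX's lazy evaluation
    print(latents.dtype)                  # mlx.core.float32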
python/src/diffusionkit/mlx/mmdit.py (5 changes: 1 addition & 4 deletions)
@@ -422,7 +422,6 @@ def rearrange_for_norm(t):
 if self.config.use_qk_norm:
     q, k = self.qk_norm(q, k)

-# FIXME(arda): if not flux
 if self.config.depth_unified == 0:
     q = q.transpose(0, 2, 1, 3).reshape(batch, -1, 1, self.config.hidden_size)
     k = k.transpose(0, 2, 1, 3).reshape(batch, -1, 1, self.config.hidden_size)
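The removed comment read "if not flux"; the branch is now keyed on config.depth_unified == 0, the non-unified (SD3-style) path. The transpose/reshape folds per-head tensors back into a single hidden-size axis: (batch, num_heads, seq_len, head_dim) becomes (batch, seq_len, 1, hidden_size). A sketch with assumed sizes (the head count, head dim, and input layout are illustrative, not the model's actual config):

    import mlx.core as mx

    batch, num_heads, seq_len, head_dim = 2, 24, 77, 64  # illustrative only
    hidden_size = num_heads * head_dim                   # 1536 here

    q = mx.zeros((batch, num_heads, seq_len, head_dim))
    q = q.transpose(0, 2, 1, 3).reshape(batch, -1, 1, hidden_size)
    print(q.shape)  # (2, 77, 1, 1536)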
@@ -524,7 +523,6 @@ def rearrange_for_sdpa(t):
     batch, -1, self.config.num_heads, self.per_head_dim
 ).transpose(0, 2, 1, 3)

-# FIXME(arda): if flux
 if self.config.depth_unified > 0:
     multimodal_sdpa_inputs = {
         "q": mx.concatenate(
@@ -583,7 +581,6 @@ def rearrange_for_sdpa(t):
 img_seq_len = latent_image_embeddings.shape[1]
 txt_seq_len = token_level_text_embeddings.shape[1]

-# FIXME(arda): if flux
 if self.config.depth_unified > 0:
     text_sdpa_output = sdpa_outputs[:, :txt_seq_len, :, :]
     image_sdpa_output = sdpa_outputs[:, txt_seq_len:, :, :]
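These two flux-path hunks ("if flux" in the removed comments, now config.depth_unified > 0) look like mirror images: before attention, the text and image streams are concatenated so a single SDPA call covers both modalities, and afterwards the output is sliced back apart at txt_seq_len. A hedged sketch of the round trip (the shapes, lengths, and concatenation axis are assumptions, since the first hunk is truncated here):

    import mlx.core as mx

    txt_seq_len, img_seq_len = 77, 4096  # illustrative lengths
    heads, head_dim = 24, 64             # illustrative sizes

    # Join the modalities along the sequence axis before one attention call.
    txt = mx.zeros((2, txt_seq_len, heads, head_dim))
    img = mx.zeros((2, img_seq_len, heads, head_dim))
    joint = mx.concatenate([txt, img], axis=1)

    # ... attention over `joint` would happen here ...

    # Split the result back apart, as the second hunk does.
    text_sdpa_output = joint[:, :txt_seq_len, :, :]
    image_sdpa_output = joint[:, txt_seq_len:, :, :]
    print(text_sdpa_output.shape, image_sdpa_output.shape)
    # (2, 77, 24, 64) (2, 4096, 24, 64)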
@@ -672,7 +669,7 @@ def rearrange_for_sdpa(t):
     .reshape(batch, -1, 1, self.config.hidden_size)
 )

-# FIXME(arda): update state dict later
+# o_proj and mlp.fc2 use the same bias, remove mlp.fc2 bias
 self.transformer_block.mlp.fc2.bias = self.transformer_block.mlp.fc2.bias * 0.0

 # Post-SDPA layers
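The FIXME ("update state dict later") is gone and the workaround itself is now documented: in the source checkpoint, o_proj and mlp.fc2 carry the same bias, so the mlp.fc2 copy is zeroed, presumably to avoid applying the shared bias twice. A hedged illustration with toy layer sizes (not the actual state dict):

    import mlx.core as mx
    import mlx.nn as nn

    o_proj = nn.Linear(8, 8)
    fc2 = nn.Linear(8, 8)

    shared_bias = mx.ones((8,))  # stand-in for the checkpoint's shared bias
    o_proj.bias = shared_bias
    fc2.bias = shared_bias

    fc2.bias = fc2.bias * 0.0  # same trick as the diff: apply the bias only once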
python/src/diffusionkit/mlx/tokenizer.py (2 changes: 0 additions & 2 deletions)
@@ -19,7 +19,6 @@ def __init__(self, bpe_ranks, vocab, pad_with_eos=False):
     regex.IGNORECASE,
 )

-# FIXME(arda): Make these configurable
 self.pad_to_max_length = True
 self.max_length = 77

@@ -119,7 +118,6 @@ def __init__(self, config: T5Config):
     model_max_length=getattr(config, "n_positions", 512),
 )

-# FIXME(arda): Make these configurable
 self.pad_to_max_length = True
 self.max_length = 77
 self.pad_with_eos = False
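Both tokenizer hunks (the CLIP-style BPE tokenizer above and this T5 wrapper) hardcode the same padding policy: every prompt is truncated or padded to a fixed 77 tokens, CLIP's context length. The removed FIXMEs asked for these to be configurable; the values stay hardcoded for now. A sketch of what the policy does, using a hypothetical pad_tokens helper (the pad and EOS token ids are placeholders, not the real vocabularies'):

    def pad_tokens(tokens, max_length=77, pad_id=0, eos_id=2, pad_with_eos=False):
        # Truncate to the fixed context length, then pad the remainder.
        tokens = tokens[:max_length]
        pad = eos_id if pad_with_eos else pad_id
        return tokens + [pad] * (max_length - len(tokens))

    print(len(pad_tokens([101, 102, 103])))  # 77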
