Merge pull request #15 from argmaxinc/atila/flux_cache
FLUX memory optimizations
atiorh authored Aug 14, 2024
2 parents 730a876 + b332a04 commit 76405ec
Showing 5 changed files with 205 additions and 115 deletions.
3 changes: 3 additions & 0 deletions python/src/diffusionkit/__init__.py
@@ -0,0 +1,3 @@
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
92 changes: 59 additions & 33 deletions python/src/diffusionkit/mlx/__init__.py
@@ -8,11 +8,13 @@
import gc
import math
import time
from pprint import pprint
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from argmaxtools.test_utils import AppleSiliconContextMixin, InferenceContextSpec
from argmaxtools.utils import get_logger
from diffusionkit.utils import bytes2gigabytes
from PIL import Image
@@ -39,6 +41,14 @@
}


class DiffusionKitInferenceContext(AppleSiliconContextMixin, InferenceContextSpec):
def code_spec(self):
return {}

def model_spec(self):
return {}


class DiffusionPipeline:
def __init__(
self,
@@ -292,6 +302,11 @@ def generate_image(
logger.info(
f"Pre text encoding active memory: {log['text_encoding']['pre']['active_memory']}GB"
)

# FIXME(arda): Need the same for CLIP models (low memory mode will not succeed a second time otherwise)
if not hasattr(self, "t5"):
self.set_up_t5()
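The guard above protects a second generation in low memory mode, where the T5 encoder is presumably unloaded after text encoding (hence the hasattr check before reuse). A minimal sketch of that unload-then-lazily-reload pattern; set_up_t5 here only stands in for the real loader, and unload_t5 is a hypothetical counterpart for illustration:

class LowMemoryTextEncodingSketch:
    """Illustration of the lazy-reload pattern used when weights are freed after each call."""

    def __init__(self, low_memory_mode: bool = True):
        self.low_memory_mode = low_memory_mode

    def set_up_t5(self):
        # Stand-in for loading the real T5 encoder weights.
        self.t5 = object()

    def unload_t5(self):
        # Hypothetical counterpart: drop the reference so the weights can be freed.
        if hasattr(self, "t5"):
            del self.t5

    def encode(self, text: str):
        if not hasattr(self, "t5"):  # same guard as in generate_image above
            self.set_up_t5()
        result = f"embeddings for {text!r}"  # placeholder for the real forward pass
        if self.low_memory_mode:
            self.unload_t5()  # freed here, which is why the guard is needed next time
        return result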

conditioning, pooled_conditioning = self.encode_text(
text, cfg_weight, negative_text
)
@@ -442,8 +457,19 @@ def generate_image(
logger.info(
f"Post decode active memory: {log['decoding']['post']['active_memory']}GB"
)
logger.info(f"Decoding time: {log['decoding']['time']}s")
logger.info(f"Peak memory: {log['peak_memory']}GB")

logger.info("============= Summary =============")
logger.info(f"Text encoder: {log['text_encoding']['time']:.1f}s")
logger.info(f"Denoising: {log['denoising']['time']:.1f}s")
logger.info(f"Image decoder: {log['decoding']['time']:.1f}s")
logger.info(f"Peak memory: {log['peak_memory']:.1f}GB")

logger.info("============= Inference Context =============")
ic = DiffusionKitInferenceContext()
logger.info("Operating System:")
pprint(ic.os_spec())
logger.info("Device:")
pprint(ic.device_spec())

# unload VAE Decoder model after decoding in low memory mode
if self.low_memory_mode:
@@ -462,24 +488,14 @@

return Image.fromarray(np.array(x)), log

def generate_ids(self, latent_size: Tuple[int]):
h, w = latent_size
img_ids = mx.zeros((h // 2, w // 2, 3))
img_ids[..., 1] = img_ids[..., 1] + mx.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + mx.arange(w // 2)[None, :]
img_ids = img_ids.reshape(1, -1, 3)

txt_ids = mx.zeros((1, 256, 3)) # Hardcoded to context length of T5
return img_ids, txt_ids
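generate_ids builds the per-patch positional indices used downstream (in FLUX these feed the rotary position embeddings): channel 0 stays zero, channel 1 holds the patch row, channel 2 the patch column, and the text IDs are all zeros for the fixed 256-token T5 context. A small NumPy rendering of the same layout for a 4x4 latent, purely for illustration:

import numpy as np

h, w = 4, 4  # example latent height/width; patches are 2x2, hence the // 2 below
img_ids = np.zeros((h // 2, w // 2, 3))
img_ids[..., 1] += np.arange(h // 2)[:, None]  # patch row index
img_ids[..., 2] += np.arange(w // 2)[None, :]  # patch column index
img_ids = img_ids.reshape(1, -1, 3)

print(img_ids[0])
# [[0. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]
#  [0. 1. 1.]]

txt_ids = np.zeros((1, 256, 3))  # text positions are all zeros (fixed T5 context length)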

def read_image(self, image_path: str):
# Read the image
img = Image.open(image_path)

# Make sure image shape is divisible by 64
W, H = (dim - dim % 64 for dim in (img.width, img.height))
if W != img.width or H != img.height:
print(
logger.warning(
f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}"
)
img = img.resize((W, H), Image.LANCZOS) # use desired downsampling filter
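The expression above simply floors each dimension to the nearest multiple of 64 before resizing. A scalar example of the arithmetic, for illustration only:

width, height = 1000, 768
W, H = (dim - dim % 64 for dim in (width, height))
print(W, H)  # 960 768 -> only the width gets downsampled in this case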
@@ -557,9 +573,6 @@ def __init__(
self.decoder = load_vae_decoder(float16=w16, key=mmdit_ckpt)
self.encoder = load_vae_encoder(float16=False, key=mmdit_ckpt)
self.latent_format = FluxLatentFormat()

if not use_t5:
logger.warning("FLUX can not be used without T5. Loading T5..")
self.use_t5 = True

self.clip_l = load_text_encoder(
@@ -615,8 +628,22 @@ def __init__(self, model: DiffusionPipeline):
super().__init__()
self.model = model

def cache_modulation_params(self, pooled_text_embeddings, sigmas):
self.model.mmdit.cache_modulation_params(
pooled_text_embeddings, sigmas.astype(self.model.activation_dtype)
)

def clear_cache(self):
self.model.mmdit.clear_modulation_params_cache()
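This caching pair is the core of the optimization: the adaLN modulation parameters depend only on the pooled text embedding and the sampling timesteps, which are all known before the denoising loop starts, so they can presumably be computed once up front and looked up per step instead of being re-derived inside every MMDiT forward pass. A minimal sketch of that pattern under the assumption of a modulation_fn(pooled, timestep) submodule; the names below are illustrative, not DiffusionKit's internals:

class ModulationCacheSketch:
    """Illustrative per-timestep cache for adaLN modulation parameters."""

    def __init__(self, modulation_fn):
        # modulation_fn: (pooled_text_embeddings, timestep) -> modulation params (e.g. shift/scale/gate)
        self.modulation_fn = modulation_fn
        self._cache = {}

    def cache_modulation_params(self, pooled_text_embeddings, timesteps):
        # Precompute once for every timestep the sampler will visit.
        for i, t in enumerate(timesteps):
            self._cache[i] = self.modulation_fn(pooled_text_embeddings, t)

    def modulation_at(self, step_index):
        # Constant-time lookup inside the denoising loop instead of re-running the projections.
        return self._cache[step_index]

    def clear_modulation_params_cache(self):
        # Drop the cached tensors after sampling (relevant in low memory mode).
        self._cache = {}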

def __call__(
self, x_t, t, conditioning, cfg_weight: float = 7.5, pooled_conditioning=None
self,
x_t,
timestep,
sigma,
conditioning,
cfg_weight: float = 7.5,
pooled_conditioning=None,
):
if cfg_weight <= 0:
logger.debug("CFG Weight disabled")
@@ -625,23 +652,14 @@ def __call__(
x_t_mmdit = mx.concatenate([x_t] * 2, axis=0).astype(
self.model.activation_dtype
)
t_mmdit = mx.broadcast_to(t, [len(x_t_mmdit)])
timestep = self.model.sampler.timestep(t_mmdit).astype(
self.model.activation_dtype
)
mmdit_input = {
"latent_image_embeddings": x_t_mmdit,
"token_level_text_embeddings": mx.expand_dims(conditioning, 2),
"pooled_text_embeddings": mx.expand_dims(
mx.expand_dims(pooled_conditioning, 1), 1
),
"timestep": timestep,
"timestep": mx.broadcast_to(timestep, [len(x_t_mmdit)]),
}

mmdit_output = self.model.mmdit(**mmdit_input)
eps_pred = self.model.sampler.calculate_denoised(
t_mmdit, mmdit_output, x_t_mmdit
)
eps_pred = self.model.sampler.calculate_denoised(sigma, mmdit_output, x_t_mmdit)
if cfg_weight <= 0:
return eps_pred
else:
@@ -691,20 +709,28 @@ def to_d(x, sigma, denoised):
def sample_euler(model: CFGDenoiser, x, sigmas, extra_args=None):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = mx.ones([x.shape[0]])

from tqdm import trange

t = trange(len(sigmas) - 1)

timesteps = model.model.sampler.timestep(sigmas).astype(
model.model.activation_dtype
)
model.cache_modulation_params(extra_args.pop("pooled_conditioning"), timesteps)

iter_time = []
for i in t:
start_time = t.format_dict["elapsed"]
sigma_hat = sigmas[i]
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
dt = sigmas[i + 1] - sigma_hat
denoised = model(x, timesteps[i], sigmas[i], **extra_args)
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
# Euler method
x = x + d * dt
mx.eval(x)
end_time = t.format_dict["elapsed"]
iter_time.append(round((end_time - start_time), 3))

# model.clear_cache()

return x, iter_time
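For reference, the update inside the loop is the plain Euler step from Karras et al. (2022): with d = to_d(x, sigmas[i], denoised) acting as the derivative dx/dsigma implied by the denoised estimate, the state advances by x <- x + d * (sigmas[i + 1] - sigmas[i]). A scalar walk-through, writing to_d out under the usual k-diffusion convention d = (x - denoised) / sigma, which is an assumption about this codebase rather than something shown in the diff:

# Scalar walk-through of one Euler step (illustration only).
x, denoised = 2.0, 0.5          # current sample and the model's denoised estimate
sigma_i, sigma_next = 1.0, 0.8  # consecutive noise levels from the schedule

d = (x - denoised) / sigma_i    # to_d under the usual k-diffusion convention
dt = sigma_next - sigma_i       # negative: the noise level decreases
x = x + d * dt                  # 2.0 + 1.5 * (-0.2) = 1.7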