Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FLUX #11

Merged
merged 46 commits into from
Aug 14, 2024
Merged

FLUX #11

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
e0b66c2
editing MMDiT for flux models
arda-argmax Aug 2, 2024
0d55e93
compatibility updates for flux
arda-argmax Aug 2, 2024
bac40c2
add pe_embedder, not wired yet
arda-argmax Aug 2, 2024
c624140
config list -> tuple
arda-argmax Aug 2, 2024
46a78cc
Use RMSNorm for QKNorm
atiorh Aug 3, 2024
ecbe2dd
Patchify via reshape
atiorh Aug 3, 2024
38e8ba6
Remove unused config and model attributes
atiorh Aug 3, 2024
8fdec91
WIP RoPE embeddings, not integrated with UniModalTransformerBlock yet
atiorh Aug 3, 2024
5c3a9f5
wIP: Text encoder scheme differences
atiorh Aug 3, 2024
ee2dc25
mmdit hidden_size is configurable
arda-argmax Aug 5, 2024
94b13e1
update hidden_size config for mmdit
arda-argmax Aug 5, 2024
d5c2165
Fix style and typing
atiorh Aug 5, 2024
dba0ed0
clean up layer idx counts
atiorh Aug 5, 2024
c771a0c
Parallel mlp and attn blocks
atiorh Aug 6, 2024
43c9b02
Fix style
atiorh Aug 6, 2024
5a0feb1
WIP: flux state dict adjustments
arda-argmax Aug 6, 2024
e0a58ad
added text conditioning for flux
arda-argmax Aug 6, 2024
0e98141
small fix for flux model init and black formatting
arda-argmax Aug 6, 2024
5490132
FIXME reminder for flux text encoding
arda-argmax Aug 6, 2024
f95fe3b
Fix forward pass
atiorh Aug 6, 2024
2d093fc
remove o_proj from UnifiedTransformerBlock
arda-argmax Aug 6, 2024
d3fbf2d
o_proj is required in UnifiedTransformerBlock
arda-argmax Aug 7, 2024
028c97d
no skip_text_post_sdpa for flux
arda-argmax Aug 7, 2024
27f818a
Fix modulation params
atiorh Aug 7, 2024
0818847
Fix final slicing
atiorh Aug 7, 2024
b79f930
added load_flux()
arda-argmax Aug 7, 2024
bd0f5da
img pos embed fix
arda-argmax Aug 7, 2024
062ac29
bug fix and unpack latent image for flux
arda-argmax Aug 8, 2024
fbf34dc
fix qk norm
arda-argmax Aug 9, 2024
3e341e9
k_proj bias is needed for qk norm
arda-argmax Aug 9, 2024
7c075c3
fix apply rope
arda-argmax Aug 9, 2024
b7c4fad
fix x_pos_embedder for sd3
arda-argmax Aug 9, 2024
628c0c1
fix flux layer sdpa
arda-argmax Aug 10, 2024
8de2028
SD3 works again
arda-argmax Aug 10, 2024
de5c542
flux working
arda-argmax Aug 10, 2024
0c3aaf8
flux repo name change and wire up flux pipeline to cli
arda-argmax Aug 10, 2024
411ca34
bfloat16 activation and weight loading support
arda-argmax Aug 10, 2024
221aeed
num steps and sampler fix
arda-argmax Aug 12, 2024
3e6efb9
float16 for sd3, bfloat16 for flux support
arda-argmax Aug 13, 2024
f97b61a
divided flux pipeline from diffusion pipeline
arda-argmax Aug 13, 2024
8d610b8
can count number of model downloads
arda-argmax Aug 13, 2024
925cc51
low memory mode reduces peak memory for sdpa
arda-argmax Aug 13, 2024
ca313d7
change sdpa flash attn threshold
arda-argmax Aug 13, 2024
f05ac99
remove FIXMEs
arda-argmax Aug 13, 2024
49fd435
update version
arda-argmax Aug 14, 2024
eb8ae9a
Clean up
atiorh Aug 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pipeline = DiffusionPipeline(
w16=True,
shift=3.0,
use_t5=False,
model_size="2b",
model_version="2b",
low_memory_mode=False,
a16=True,
)
Expand Down
182 changes: 153 additions & 29 deletions python/src/diffusionkit/mlx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from .model_io import (
_DEFAULT_MODEL,
load_flux,
load_mmdit,
load_t5_encoder,
load_t5_tokenizer,
Expand All @@ -27,13 +28,14 @@
load_vae_decoder,
load_vae_encoder,
)
from .sampler import ModelSamplingDiscreteFlow
from .sampler import FluxSampler, ModelSamplingDiscreteFlow

logger = get_logger(__name__)

MMDIT_CKPT = {
"2b": "mmdit_2b",
"8b": "models/sd3_8b_beta.safetensors",
"stable-diffusion-3-medium": "stabilityai/stable-diffusion-3-medium",
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
"FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
}


Expand All @@ -44,21 +46,29 @@ def __init__(
w16: bool = False,
shift: float = 1.0,
use_t5: bool = True,
model_size: str = "2b",
model_version: str = "stable-diffusion-3-medium",
low_memory_mode: bool = True,
a16: bool = False,
local_ckpt=None,
):
model_io.LOCAl_SD3_CKPT = local_ckpt
self.dtype = mx.float16 if w16 else mx.float32
self.activation_dtype = mx.float16 if a16 else mx.float32
self.float16_dtype = mx.float16
model_io._FLOAT16 = self.float16_dtype
self.dtype = self.float16_dtype if w16 else mx.float32
self.activation_dtype = self.float16_dtype if a16 else mx.float32
self.use_t5 = use_t5
mmdit_ckpt = MMDIT_CKPT[model_size]
mmdit_ckpt = MMDIT_CKPT[model_version]
self.low_memory_mode = low_memory_mode
self.mmdit = load_mmdit(float16=w16, model_key=mmdit_ckpt)
self.mmdit = load_mmdit(
float16=w16,
key=mmdit_ckpt,
model_key=model_version,
low_memory_mode=low_memory_mode,
)
self.sampler = ModelSamplingDiscreteFlow(shift=shift)
self.decoder = load_vae_decoder(float16=w16)
self.encoder = load_vae_encoder(float16=False)
self.decoder = load_vae_decoder(float16=w16, key=mmdit_ckpt)
self.encoder = load_vae_encoder(float16=False, key=mmdit_ckpt)
self.latent_format = SD3LatentFormat()

self.clip_l = load_text_encoder(
model,
Expand Down Expand Up @@ -90,7 +100,7 @@ def __init__(
def set_up_t5(self):
if self.t5_encoder is None:
self.t5_encoder = load_t5_encoder(
float16=True if self.dtype == mx.float16 else False,
float16=True if self.dtype == self.float16_dtype else False,
low_memory_mode=self.low_memory_mode,
)
if self.t5_tokenizer is None:
Expand All @@ -110,9 +120,10 @@ def unload_t5(self):
def ensure_models_are_loaded(self):
mx.eval(self.mmdit.parameters())
mx.eval(self.clip_l.parameters())
mx.eval(self.clip_g.parameters())
mx.eval(self.decoder.parameters())
if self.use_t5:
if hasattr(self, "clip_g"):
mx.eval(self.clip_g.parameters())
if hasattr(self, "t5_encoder") and self.use_t5:
mx.eval(self.t5_encoder.parameters())

def _tokenize(self, tokenizer, text: str, negative_text: Optional[str] = None):
Expand Down Expand Up @@ -213,7 +224,7 @@ def denoise_latents(
denoise = 1.0
else:
x_T = self.encode_image_to_latents(image_path, seed=seed)
x_T = SD3LatentFormat().process_in(x_T)
x_T = self.latent_format.process_in(x_T)
noise = self.get_noise(seed, x_T)
sigmas = self.get_sigmas(self.sampler, num_steps)
sigmas = sigmas[int(num_steps * (1 - denoise)) :]
Expand All @@ -228,7 +239,9 @@ def denoise_latents(
latent, iter_time = sample_euler(
CFGDenoiser(self), noise_scaled, sigmas, extra_args=extra_args
)
latent = SD3LatentFormat().process_out(latent)

latent = self.latent_format.process_out(latent)

return latent, iter_time

def generate_image(
Expand Down Expand Up @@ -305,9 +318,11 @@ def generate_image(

# unload T5 and CLIP models after obtaining conditioning in low memory mode
if self.low_memory_mode:
del self.clip_g
if hasattr(self, "t5_encoder"):
del self.t5_encoder
if hasattr(self, "clip_g"):
del self.clip_g
del self.clip_l
del self.t5_encoder
gc.collect()

logger.debug(f"Conditioning dtype before casting: {conditioning.dtype}")
Expand Down Expand Up @@ -406,7 +421,7 @@ def generate_image(
logger.info(
f"Pre decode active memory: {log['decoding']['pre']['active_memory']}GB"
)

latents = latents.astype(mx.float32)
decoded = self.decode_latents_to_image(latents)
mx.eval(decoded)

Expand Down Expand Up @@ -447,6 +462,16 @@ def generate_image(

return Image.fromarray(np.array(x)), log

def generate_ids(self, latent_size: Tuple[int]):
h, w = latent_size
img_ids = mx.zeros((h // 2, w // 2, 3))
img_ids[..., 1] = img_ids[..., 1] + mx.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + mx.arange(w // 2)[None, :]
img_ids = img_ids.reshape(1, -1, 3)

txt_ids = mx.zeros((1, 256, 3)) # Hardcoded to context length of T5
return img_ids, txt_ids

def read_image(self, image_path: str):
# Read the image
img = Image.open(image_path)
Expand All @@ -473,12 +498,15 @@ def get_noise(self, seed, x_T):
def get_sigmas(self, sampler, num_steps: int):
start = sampler.timestep(sampler.sigma_max).item()
end = sampler.timestep(sampler.sigma_min).item()
if isinstance(sampler, FluxSampler):
num_steps += 1
timesteps = mx.linspace(start, end, num_steps)
sigs = []
for x in range(len(timesteps)):
ts = timesteps[x]
sigs.append(sampler.sigma(ts))
sigs += [0.0]
if not isinstance(sampler, FluxSampler):
sigs += [0.0]
return mx.array(sigs)

def get_empty_latent(self, *shape):
Expand All @@ -505,6 +533,81 @@ def encode_image_to_latents(self, image_path: str, seed):
return mean + std * noise


class FluxPipeline(DiffusionPipeline):
def __init__(
self,
model: str = _DEFAULT_MODEL,
w16: bool = False,
shift: float = 1.0,
use_t5: bool = True,
model_version: str = "FLUX.1-schnell",
low_memory_mode: bool = True,
a16: bool = False,
local_ckpt=None,
):
model_io.LOCAl_SD3_CKPT = local_ckpt
self.float16_dtype = mx.bfloat16
model_io._FLOAT16 = self.float16_dtype
self.dtype = self.float16_dtype if w16 else mx.float32
self.activation_dtype = self.float16_dtype if a16 else mx.float32
mmdit_ckpt = MMDIT_CKPT[model_version]
self.low_memory_mode = low_memory_mode
self.mmdit = load_flux(float16=w16, low_memory_mode=low_memory_mode)
self.sampler = FluxSampler(shift=shift)
self.decoder = load_vae_decoder(float16=w16, key=mmdit_ckpt)
self.encoder = load_vae_encoder(float16=False, key=mmdit_ckpt)
self.latent_format = FluxLatentFormat()

if not use_t5:
logger.warning("FLUX can not be used without T5. Loading T5..")
self.use_t5 = True

self.clip_l = load_text_encoder(
model,
w16,
model_key="clip_l",
)
self.tokenizer_l = load_tokenizer(
model,
merges_key="tokenizer_l_merges",
vocab_key="tokenizer_l_vocab",
pad_with_eos=True,
)
self.t5_encoder = None
self.t5_tokenizer = None
if self.use_t5:
self.set_up_t5()

def encode_text(
self,
text: str,
cfg_weight: float = 7.5,
negative_text: str = "",
):
tokens_l = self._tokenize(
self.tokenizer_l,
text,
(negative_text if cfg_weight > 1 else None),
)
conditioning_l = self.clip_l(tokens_l[[0], :]) # Ignore negative text
pooled_conditioning = conditioning_l.pooled_output

tokens_t5 = self._tokenize(
self.t5_tokenizer,
text,
(negative_text if cfg_weight > 1 else None),
)
padded_tokens_t5 = mx.zeros((1, 256)).astype(tokens_t5.dtype)
padded_tokens_t5[:, : tokens_t5.shape[1]] = tokens_t5[
[0], :
] # Ignore negative text
t5_conditioning = self.t5_encoder(padded_tokens_t5)
mx.eval(t5_conditioning)
conditioning = t5_conditioning

return conditioning, pooled_conditioning


class CFGDenoiser(nn.Module):
"""Helper for applying CFG Scaling to diffusion outputs"""

Expand All @@ -515,9 +618,13 @@ def __init__(self, model: DiffusionPipeline):
def __call__(
self, x_t, t, conditioning, cfg_weight: float = 7.5, pooled_conditioning=None
):
x_t_mmdit = mx.concatenate([x_t] * 2, axis=0).astype(
self.model.activation_dtype
)
if cfg_weight <= 0:
logger.debug("CFG Weight disabled")
x_t_mmdit = x_t.astype(self.model.activation_dtype)
else:
x_t_mmdit = mx.concatenate([x_t] * 2, axis=0).astype(
self.model.activation_dtype
)
t_mmdit = mx.broadcast_to(t, [len(x_t_mmdit)])
timestep = self.model.sampler.timestep(t_mmdit).astype(
self.model.activation_dtype
Expand All @@ -530,21 +637,24 @@ def __call__(
),
"timestep": timestep,
}

mmdit_output = self.model.mmdit(**mmdit_input)
eps_pred = self.model.sampler.calculate_denoised(
t_mmdit, mmdit_output, x_t_mmdit
)

eps_text, eps_neg = eps_pred.split(2)
return eps_neg + cfg_weight * (eps_text - eps_neg)
if cfg_weight <= 0:
return eps_pred
else:
eps_text, eps_neg = eps_pred.split(2)
return eps_neg + cfg_weight * (eps_text - eps_neg)


class SD3LatentFormat:
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
class LatentFormat:
"""Base class for latent format conversion"""

def __init__(self):
self.scale_factor = 1.5305
self.shift_factor = 0.0609
self.scale_factor = 1.0
self.shift_factor = 0.0

def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
Expand All @@ -553,6 +663,20 @@ def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor


class SD3LatentFormat(LatentFormat):
def __init__(self):
super().__init__()
self.scale_factor = 1.5305
self.shift_factor = 0.0609


class FluxLatentFormat(LatentFormat):
def __init__(self):
super().__init__()
self.scale_factor = 0.3611
self.shift_factor = 0.1159


def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
Expand Down
4 changes: 3 additions & 1 deletion python/src/diffusionkit/mlx/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def __init__(self, config: CLIPTextModelConfig):
def _get_mask(self, N, dtype):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
mask = mask.astype(dtype) * (
-6e4 if (dtype == mx.bfloat16 or dtype == mx.float16) else -1e9
)
return mask

def __call__(self, x):
Expand Down
Loading
Loading