Skip to content

Commit

Permalink
Merge pull request #23 from EduardoPach/add-quantize-arg
Browse files Browse the repository at this point in the history
Add quantize arg
  • Loading branch information
atiorh authored Aug 21, 2024
2 parents 3d1b7b7 + 4e778a0 commit 13c2a05
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
8 changes: 8 additions & 0 deletions python/src/diffusionkit/mlx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@
"stable-diffusion-3-medium": "stabilityai/stable-diffusion-3-medium",
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
"FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
"FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
}

# Per-model-version T5 length limit, keyed by the same model version names
# used elsewhere in this module. NOTE(review): presumably the maximum token
# sequence length fed to the T5 text encoder -- confirm against the caller.
T5_MAX_LENGTH = {
    "stable-diffusion-3-medium": 512,
    "FLUX.1-schnell": 256,
    # The 4-bit quantized checkpoint uses the same value as FLUX.1-schnell.
    "FLUX.1-schnell-4bit-quantized": 256,
}


Expand Down Expand Up @@ -592,6 +594,7 @@ def __init__(
low_memory_mode: bool = True,
a16: bool = False,
local_ckpt=None,
quantize_mmdit: bool = False,
):
model_io.LOCAl_SD3_CKPT = local_ckpt
self.float16_dtype = mx.bfloat16
Expand All @@ -606,16 +609,21 @@ def __init__(
self.latent_format = FluxLatentFormat()
self.use_t5 = True
self.use_clip_g = False
self.quantize_mmdit = quantize_mmdit
self.check_and_load_models()

def load_mmdit(self, only_modulation_dict=False):
if only_modulation_dict:
return load_flux(
key=self.mmdit_ckpt,
model_key=self.model_version,
float16=True if self.dtype == self.float16_dtype else False,
low_memory_mode=self.low_memory_mode,
only_modulation_dict=only_modulation_dict,
)
self.mmdit = load_flux(
key=self.mmdit_ckpt,
model_key=self.model_version,
float16=True if self.dtype == self.float16_dtype else False,
low_memory_mode=self.low_memory_mode,
only_modulation_dict=only_modulation_dict,
Expand Down
27 changes: 23 additions & 4 deletions python/src/diffusionkit/mlx/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import mlx.core as mx
from huggingface_hub import hf_hub_download
from mlx import nn
from mlx.utils import tree_flatten, tree_unflatten
from transformers import T5Config

Expand Down Expand Up @@ -41,6 +42,10 @@
"FLUX.1-schnell": "flux-schnell.safetensors",
"vae": "ae.safetensors",
},
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": {
"FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
"vae": "ae.safetensors",
},
}
_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
_MODELS = {
Expand All @@ -66,6 +71,10 @@
"vae_encoder": "encoder.",
"vae_decoder": "decoder.",
},
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": {
"vae_encoder": "encoder.",
"vae_decoder": "decoder.",
},
}

_FLOAT16 = mx.bfloat16
Expand Down Expand Up @@ -693,10 +702,20 @@ def load_flux(
flux_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, flux_weights)
hf_hub_download(key, "config.json")
weights = mx.load(flux_weights_ckpt)
weights = flux_state_dict_adjustments(
weights, prefix="", hidden_size=config.hidden_size, mlp_ratio=config.mlp_ratio
)
weights = {k: v.astype(dtype) for k, v in weights.items()}

if model_key == "FLUX.1-schnell":
weights = flux_state_dict_adjustments(
weights,
prefix="",
hidden_size=config.hidden_size,
mlp_ratio=config.mlp_ratio,
)
elif model_key == "FLUX.1-schnell-4bit-quantized": # 4-bit ckpt already adjusted
nn.quantize(model)

weights = {
k: v.astype(dtype) if v.dtype != mx.uint32 else v for k, v in weights.items()
}
if only_modulation_dict:
weights = {k: v for k, v in weights.items() if "adaLN" in k}
return tree_flatten(weights)
Expand Down
5 changes: 4 additions & 1 deletion python/src/diffusionkit/mlx/scripts/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
"stable-diffusion-3-medium": 512,
"sd3-8b-unreleased": 1024,
"FLUX.1-schnell": 512,
"FLUX.1-schnell-4bit-quantized": 512,
}
# Image width looked up by model version name.
# NOTE(review): presumably the default output width in pixels used by the
# CLI when the user does not pass one -- confirm against cli().
WIDTH = {
    "stable-diffusion-3-medium": 512,
    "sd3-8b-unreleased": 1024,
    "FLUX.1-schnell": 512,
    "FLUX.1-schnell-4bit-quantized": 512,
}
# Shift value looked up by model version name.
# NOTE(review): presumably the default sampler/timestep shift passed to the
# pipeline by the CLI -- confirm against cli().
SHIFT = {
    "stable-diffusion-3-medium": 3.0,
    "sd3-8b-unreleased": 3.0,
    "FLUX.1-schnell": 1.0,
    "FLUX.1-schnell-4bit-quantized": 1.0,
}


Expand Down Expand Up @@ -107,7 +110,7 @@ def cli():
args.w16 = True
args.a16 = True

if args.model_version == "FLUX.1-schnell" and args.cfg > 0.0:
if "FLUX" in args.model_version and args.cfg > 0.0:
logger.warning("Disabling CFG for FLUX.1-schnell model.")
args.cfg = 0.0

Expand Down

0 comments on commit 13c2a05

Please sign in to comment.