Skip to content

Commit

Permalink
Merge pull request #43 from EduardoPach/add-sd3.5-4bit
Browse files Browse the repository at this point in the history
Add SD3.5-large 4-bit quantized model
  • Loading branch information
arda-argmax authored Oct 29, 2024
2 parents 925ef5d + 1d05fce commit d737473
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 11 deletions.
2 changes: 2 additions & 0 deletions python/src/diffusionkit/mlx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
MMDIT_CKPT = {
"argmaxinc/mlx-stable-diffusion-3-medium": "argmaxinc/mlx-stable-diffusion-3-medium",
"argmaxinc/mlx-stable-diffusion-3.5-large": "argmaxinc/mlx-stable-diffusion-3.5-large",
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized",
"argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
"argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev",
Expand All @@ -45,6 +46,7 @@
T5_MAX_LENGTH = {
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
"argmaxinc/mlx-stable-diffusion-3.5-large": 512,
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 512,
"argmaxinc/mlx-FLUX.1-schnell": 256,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256,
"argmaxinc/mlx-FLUX.1-dev": 512,
Expand Down
52 changes: 41 additions & 11 deletions python/src/diffusionkit/mlx/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@
"argmaxinc/mlx-stable-diffusion-3.5-large": "sd3.5_large.safetensors",
"vae": "sd3.5_large.safetensors",
},
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": {
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": "sd3.5_large_4bit_quantized.safetensors",
"vae": "sd3.5_large_4bit_quantized.safetensors",
},
}
_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
_MODELS = {
Expand Down Expand Up @@ -92,6 +96,10 @@
"vae_encoder": "first_stage_model.encoder.",
"vae_decoder": "first_stage_model.decoder.",
},
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": {
"vae_encoder": "first_stage_model.encoder.",
"vae_decoder": "first_stage_model.decoder.",
},
}

_CONFIG = {
Expand All @@ -100,17 +108,20 @@
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": FLUX_SCHNELL,
"argmaxinc/mlx-FLUX.1-dev": FLUX_SCHNELL,
"argmaxinc/mlx-stable-diffusion-3.5-large": SD3_8b,
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": SD3_8b,
}

# NOTE: despite the name, this alias is bound to MLX's bfloat16 dtype, not
# float16.
_FLOAT16 = mx.bfloat16

# Per-model depth value, keyed by the model identifier string
# (presumably the Hugging Face repo id — confirm against callers).
# NOTE(review): the value looks like the number of MMDiT transformer blocks;
# verify where DEPTH is consumed.
# The 4-bit quantized entry carries the same value as the full-precision
# 3.5-large entry.
DEPTH = {
"argmaxinc/mlx-stable-diffusion-3-medium": 24,
"argmaxinc/mlx-stable-diffusion-3.5-large": 38,
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 38,
}
# Per-model maximum latent resolution, keyed by the model identifier string.
# NOTE(review): units are presumably latent-space pixels per side — confirm
# against the code that reads this table.
# The 4-bit quantized entry carries the same value as the full-precision
# 3.5-large entry.
MAX_LATENT_RESOLUTION = {
"argmaxinc/mlx-stable-diffusion-3-medium": 96,
"argmaxinc/mlx-stable-diffusion-3.5-large": 192,
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 192,
}

LOCAl_SD3_CKPT = None
Expand Down Expand Up @@ -712,12 +723,23 @@ def load_mmdit(
mmdit_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, mmdit_weights)
hf_hub_download(key, "config.json")
weights = mx.load(mmdit_weights_ckpt)
weights = mmdit_state_dict_adjustments(weights, prefix="model.diffusion_model.")
weights = {k: v.astype(dtype) for k, v in weights.items()}
prefix = "model.diffusion_model."

if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized":
weights = mmdit_state_dict_adjustments(weights, prefix=prefix)
else:
nn.quantize(
model, class_predicate=lambda _, module: isinstance(module, nn.Linear)
)
weights = {k.replace(prefix, ""): v for k, v in weights.items() if prefix in k}

weights = {
k: v.astype(dtype) if v.dtype != mx.uint32 else v for k, v in weights.items()
}
if only_modulation_dict:
weights = {k: v for k, v in weights.items() if "adaLN" in k}
return tree_flatten(weights)
model.update(tree_unflatten(tree_flatten(weights)))
model.load_weights(list(weights.items()))

return model

Expand Down Expand Up @@ -852,11 +874,15 @@ def load_vae_decoder(
vae_weights = _MMDIT[key][model_key]
vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights)
weights = mx.load(vae_weights_ckpt)
weights = vae_decoder_state_dict_adjustments(
weights, prefix=_PREFIX[key]["vae_decoder"]
)
prefix = _PREFIX[key]["vae_decoder"]

if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized":
weights = vae_decoder_state_dict_adjustments(weights, prefix=prefix)
else:
weights = {k.replace(prefix, ""): v for k, v in weights.items() if prefix in k}

weights = {k: v.astype(dtype) for k, v in weights.items()}
model.update(tree_unflatten(tree_flatten(weights)))
model.load_weights(list(weights.items()))

return model

Expand All @@ -880,11 +906,15 @@ def load_vae_encoder(
vae_weights = _MMDIT[key][model_key]
vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights)
weights = mx.load(vae_weights_ckpt)
weights = vae_encoder_state_dict_adjustments(
weights, prefix=_PREFIX[key]["vae_encoder"]
)
prefix = _PREFIX[key]["vae_encoder"]

if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized":
weights = vae_encoder_state_dict_adjustments(weights, prefix=prefix)
else:
weights = {k.replace(prefix, ""): v for k, v in weights.items() if prefix in k}

weights = {k: v.astype(dtype) for k, v in weights.items()}
model.update(tree_unflatten(tree_flatten(weights)))
model.load_weights(list(weights.items()))

return model

Expand Down
4 changes: 4 additions & 0 deletions python/src/diffusionkit/mlx/scripts/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,23 @@
# Per-model height value, keyed by the model identifier string.
# NOTE(review): presumably the default output image height in pixels used by
# the CLI when no explicit height is given — confirm in cli().
# Each "-4bit-quantized" entry carries the same value as its full-precision
# counterpart.
HEIGHT = {
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
"argmaxinc/mlx-stable-diffusion-3.5-large": 1024,
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 1024,
"argmaxinc/mlx-FLUX.1-schnell": 512,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
"argmaxinc/mlx-FLUX.1-dev": 512,
}
# Per-model width value, keyed by the model identifier string.
# NOTE(review): presumably the default output image width in pixels used by
# the CLI when no explicit width is given — confirm in cli().
# The entries have the same keys and values as the HEIGHT table.
WIDTH = {
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
"argmaxinc/mlx-stable-diffusion-3.5-large": 1024,
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 1024,
"argmaxinc/mlx-FLUX.1-schnell": 512,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
"argmaxinc/mlx-FLUX.1-dev": 512,
}
SHIFT = {
"argmaxinc/mlx-stable-diffusion-3-medium": 3.0,
"argmaxinc/mlx-stable-diffusion-3.5-large": 3.0,
"argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 3.0,
"argmaxinc/mlx-FLUX.1-schnell": 1.0,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0,
"argmaxinc/mlx-FLUX.1-dev": 1.0,
Expand Down Expand Up @@ -108,6 +111,7 @@ def cli():
type=str,
help="Path to the local mmdit checkpoint.",
)

args = parser.parse_args()

args.w16 = True
Expand Down

0 comments on commit d737473

Please sign in to comment.