From 5f8a1850890477378b7567c7bd9755b6ce4ab49d Mon Sep 17 00:00:00 2001 From: Eduardo Pacheco Date: Mon, 28 Oct 2024 14:39:09 +0100 Subject: [PATCH 1/2] add: sd-3.5 4-bit quantized --- python/src/diffusionkit/mlx/__init__.py | 2 + python/src/diffusionkit/mlx/model_io.py | 52 +++++++++++++++---- .../mlx/scripts/generate_images.py | 12 +++++ 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/python/src/diffusionkit/mlx/__init__.py b/python/src/diffusionkit/mlx/__init__.py index 3635b3e..6c6499a 100644 --- a/python/src/diffusionkit/mlx/__init__.py +++ b/python/src/diffusionkit/mlx/__init__.py @@ -37,6 +37,7 @@ MMDIT_CKPT = { "argmaxinc/mlx-stable-diffusion-3-medium": "argmaxinc/mlx-stable-diffusion-3-medium", "argmaxinc/mlx-stable-diffusion-3.5-large": "argmaxinc/mlx-stable-diffusion-3.5-large", + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized", "argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell", "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized", "argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev", @@ -45,6 +46,7 @@ T5_MAX_LENGTH = { "argmaxinc/mlx-stable-diffusion-3-medium": 512, "argmaxinc/mlx-stable-diffusion-3.5-large": 512, + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 512, "argmaxinc/mlx-FLUX.1-schnell": 256, "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256, "argmaxinc/mlx-FLUX.1-dev": 512, diff --git a/python/src/diffusionkit/mlx/model_io.py b/python/src/diffusionkit/mlx/model_io.py index e541055..c1d59c8 100644 --- a/python/src/diffusionkit/mlx/model_io.py +++ b/python/src/diffusionkit/mlx/model_io.py @@ -55,6 +55,10 @@ "argmaxinc/mlx-stable-diffusion-3.5-large": "sd3.5_large.safetensors", "vae": "sd3.5_large.safetensors", }, + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": { + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": "sd3.5_large_4bit_quantized.safetensors", + "vae": "sd3.5_large_4bit_quantized.safetensors", + }, } _DEFAULT_MODEL = "argmaxinc/stable-diffusion" _MODELS = { @@ -92,6 +96,10 @@ "vae_encoder": "first_stage_model.encoder.", "vae_decoder": "first_stage_model.decoder.", }, + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": { + "vae_encoder": "first_stage_model.encoder.", + "vae_decoder": "first_stage_model.decoder.", + }, } _CONFIG = { @@ -100,6 +108,7 @@ "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": FLUX_SCHNELL, "argmaxinc/mlx-FLUX.1-dev": FLUX_SCHNELL, "argmaxinc/mlx-stable-diffusion-3.5-large": SD3_8b, + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": SD3_8b, } _FLOAT16 = mx.bfloat16 @@ -107,10 +116,12 @@ DEPTH = { "argmaxinc/mlx-stable-diffusion-3-medium": 24, "argmaxinc/mlx-stable-diffusion-3.5-large": 38, + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 38, } MAX_LATENT_RESOLUTION = { "argmaxinc/mlx-stable-diffusion-3-medium": 96, "argmaxinc/mlx-stable-diffusion-3.5-large": 192, + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 192, } LOCAl_SD3_CKPT = None @@ -712,12 +723,23 @@ def load_mmdit( mmdit_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, mmdit_weights) hf_hub_download(key, "config.json") weights = mx.load(mmdit_weights_ckpt) - weights = mmdit_state_dict_adjustments(weights, prefix="model.diffusion_model.") - weights = {k: v.astype(dtype) for k, v in weights.items()} + prefix = "model.diffusion_model." + + if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": + weights = mmdit_state_dict_adjustments(weights, prefix=prefix) + else: + nn.quantize( + model, class_predicate=lambda _, module: isinstance(module, nn.Linear) + ) + weights = {k.replace(prefix, ""): v for k, v in weights.items() if prefix in k} + + weights = { + k: v.astype(dtype) if v.dtype != mx.uint32 else v for k, v in weights.items() + } if only_modulation_dict: weights = {k: v for k, v in weights.items() if "adaLN" in k} return tree_flatten(weights) - model.update(tree_unflatten(tree_flatten(weights))) + model.load_weights(list(weights.items())) return model @@ -852,11 +874,15 @@ def load_vae_decoder( vae_weights = _MMDIT[key][model_key] vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights) weights = mx.load(vae_weights_ckpt) - weights = vae_decoder_state_dict_adjustments( - weights, prefix=_PREFIX[key]["vae_decoder"] - ) + prefix = _PREFIX[key]["vae_decoder"] + + if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": + weights = vae_decoder_state_dict_adjustments(weights, prefix=prefix) + else: + weights = {k.replace(prefix, ""): v for k, v in weights.items() if prefix in k} + weights = {k: v.astype(dtype) for k, v in weights.items()} - model.update(tree_unflatten(tree_flatten(weights))) + model.load_weights(list(weights.items())) return model @@ -880,11 +906,15 @@ def load_vae_encoder( vae_weights = _MMDIT[key][model_key] vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights) weights = mx.load(vae_weights_ckpt) - weights = vae_encoder_state_dict_adjustments( - weights, prefix=_PREFIX[key]["vae_encoder"] - ) + prefix = _PREFIX[key]["vae_encoder"] + + if key != "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": + weights = vae_encoder_state_dict_adjustments(weights, prefix=prefix) + else: + weights = {k.replace(prefix, ""): v for k, v in weights.items() if prefix in k} + weights = {k: v.astype(dtype) for k, v in weights.items()} - model.update(tree_unflatten(tree_flatten(weights))) + model.load_weights(list(weights.items())) return model diff --git a/python/src/diffusionkit/mlx/scripts/generate_images.py b/python/src/diffusionkit/mlx/scripts/generate_images.py index d6c4bdb..c14ecf6 100644 --- a/python/src/diffusionkit/mlx/scripts/generate_images.py +++ b/python/src/diffusionkit/mlx/scripts/generate_images.py @@ -15,6 +15,7 @@ HEIGHT = { "argmaxinc/mlx-stable-diffusion-3-medium": 512, "argmaxinc/mlx-stable-diffusion-3.5-large": 1024, + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 1024, "argmaxinc/mlx-FLUX.1-schnell": 512, "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512, "argmaxinc/mlx-FLUX.1-dev": 512, @@ -22,6 +23,7 @@ WIDTH = { "argmaxinc/mlx-stable-diffusion-3-medium": 512, "argmaxinc/mlx-stable-diffusion-3.5-large": 1024, + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 1024, "argmaxinc/mlx-FLUX.1-schnell": 512, "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512, "argmaxinc/mlx-FLUX.1-dev": 512, @@ -29,6 +31,7 @@ SHIFT = { "argmaxinc/mlx-stable-diffusion-3-medium": 3.0, "argmaxinc/mlx-stable-diffusion-3.5-large": 3.0, + "argmaxinc/mlx-stable-diffusion-3.5-large-4bit-quantized": 3.0, "argmaxinc/mlx-FLUX.1-schnell": 1.0, "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0, "argmaxinc/mlx-FLUX.1-dev": 1.0, @@ -108,6 +111,11 @@ def cli(): type=str, help="Path to the local mmdit checkpoint.", ) + parser.add_argument( + "--quantized", + action="store_true", + help="Use 4-bit quantized model if available", + ) args = parser.parse_args() args.w16 = True @@ -125,6 +133,10 @@ def cli(): if args.denoise < 0.0 or args.denoise > 1.0: raise ValueError("Denoising factor must be between 0.0 and 1.0") + if args.quantized and "4bit-quantized" not in args.model_version: + args.model_version += "-4bit-quantized" + logger.info(f"Using 4-bit quantized model: {args.model_version}") + shift = args.shift or SHIFT[args.model_version] pipeline_class = FluxPipeline if "FLUX" in args.model_version else DiffusionPipeline From 1d05fcef21d42b553992ddd322b28e74d212fb78 Mon Sep 17 00:00:00 2001 From: Eduardo Pacheco Date: Mon, 28 Oct 2024 14:45:20 +0100 Subject: [PATCH 2/2] remove: unnecessary arg in generate_image --- python/src/diffusionkit/mlx/scripts/generate_images.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/python/src/diffusionkit/mlx/scripts/generate_images.py b/python/src/diffusionkit/mlx/scripts/generate_images.py index c14ecf6..22adfc6 100644 --- a/python/src/diffusionkit/mlx/scripts/generate_images.py +++ b/python/src/diffusionkit/mlx/scripts/generate_images.py @@ -111,11 +111,7 @@ def cli(): type=str, help="Path to the local mmdit checkpoint.", ) - parser.add_argument( - "--quantized", - action="store_true", - help="Use 4-bit quantized model if available", - ) + args = parser.parse_args() args.w16 = True @@ -133,10 +129,6 @@ def cli(): if args.denoise < 0.0 or args.denoise > 1.0: raise ValueError("Denoising factor must be between 0.0 and 1.0") - if args.quantized and "4bit-quantized" not in args.model_version: - args.model_version += "-4bit-quantized" - logger.info(f"Using 4-bit quantized model: {args.model_version}") - shift = args.shift or SHIFT[args.model_version] pipeline_class = FluxPipeline if "FLUX" in args.model_version else DiffusionPipeline