diff --git a/README.md b/README.md
index 6e7bff4..936913d 100644
--- a/README.md
+++ b/README.md
@@ -93,7 +93,7 @@ pipeline = DiffusionPipeline(
     w16=True,
     shift=3.0,
     use_t5=False,
-    model_size="2b",
+    model_version="stable-diffusion-3-medium",
     low_memory_mode=False,
     a16=True,
 )
diff --git a/python/src/diffusionkit/mlx/__init__.py b/python/src/diffusionkit/mlx/__init__.py
index 8e066bf..d14f02d 100644
--- a/python/src/diffusionkit/mlx/__init__.py
+++ b/python/src/diffusionkit/mlx/__init__.py
@@ -33,9 +33,9 @@
 logger = get_logger(__name__)
 
 MMDIT_CKPT = {
-    "2b": "stabilityai/stable-diffusion-3-medium",
-    "8b": "models/sd3_8b_beta.safetensors",
-    "flux": "argmaxinc/mlx-FLUX.1-schnell",
+    "stable-diffusion-3-medium": "stabilityai/stable-diffusion-3-medium",
+    "sd3-8b-unreleased": "models/sd3_8b_beta.safetensors",  # unreleased
+    "FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
 }
 
@@ -46,7 +46,7 @@ def __init__(
         w16: bool = False,
         shift: float = 1.0,
         use_t5: bool = True,
-        model_size: str = "2b",
+        model_version: str = "stable-diffusion-3-medium",
         low_memory_mode: bool = True,
         a16: bool = False,
         local_ckpt=None,
@@ -57,12 +57,12 @@ def __init__(
         self.dtype = self.float16_dtype if w16 else mx.float32
         self.activation_dtype = self.float16_dtype if a16 else mx.float32
         self.use_t5 = use_t5
-        mmdit_ckpt = MMDIT_CKPT[model_size]
+        mmdit_ckpt = MMDIT_CKPT[model_version]
         self.low_memory_mode = low_memory_mode
         self.mmdit = load_mmdit(
             float16=w16,
             key=mmdit_ckpt,
-            model_key=model_size,
+            model_key=model_version,
             low_memory_mode=low_memory_mode,
         )
         self.sampler = ModelSamplingDiscreteFlow(shift=shift)
@@ -120,9 +120,10 @@ def unload_t5(self):
     def ensure_models_are_loaded(self):
         mx.eval(self.mmdit.parameters())
         mx.eval(self.clip_l.parameters())
-        mx.eval(self.clip_g.parameters())
         mx.eval(self.decoder.parameters())
-        if self.use_t5:
+        if hasattr(self, "clip_g"):
+            mx.eval(self.clip_g.parameters())
+        if hasattr(self, "t5_encoder") and self.use_t5:
             mx.eval(self.t5_encoder.parameters())
 
     def _tokenize(self, tokenizer, text: str, negative_text: Optional[str] = None):
@@ -539,7 +540,7 @@ def __init__(
         w16: bool = False,
         shift: float = 1.0,
         use_t5: bool = True,
-        model_size: str = "2b",
+        model_version: str = "FLUX.1-schnell",
         low_memory_mode: bool = True,
         a16: bool = False,
         local_ckpt=None,
@@ -549,8 +550,7 @@ def __init__(
         model_io._FLOAT16 = self.float16_dtype
         self.dtype = self.float16_dtype if w16 else mx.float32
         self.activation_dtype = self.float16_dtype if a16 else mx.float32
-        self.use_t5 = use_t5
-        mmdit_ckpt = MMDIT_CKPT[model_size]
+        mmdit_ckpt = MMDIT_CKPT[model_version]
         self.low_memory_mode = low_memory_mode
         self.mmdit = load_flux(float16=w16, low_memory_mode=low_memory_mode)
         self.sampler = FluxSampler(shift=shift)
@@ -559,8 +559,8 @@ def __init__(
         self.latent_format = FluxLatentFormat()
 
         if not use_t5:
-            logger.info("FLUX model is being used without T5. Setting use_t5 to True.")
-            self.use_t5 = True
+            logger.warning("FLUX can not be used without T5. Loading T5..")
+            self.use_t5 = True
 
         self.clip_l = load_text_encoder(
             model,
@@ -664,8 +664,6 @@ def process_out(self, latent):
 
 
 class SD3LatentFormat(LatentFormat):
-    """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
-
     def __init__(self):
         super().__init__()
         self.scale_factor = 1.5305
diff --git a/python/src/diffusionkit/mlx/model_io.py b/python/src/diffusionkit/mlx/model_io.py
index 9b02a5c..7e27d98 100644
--- a/python/src/diffusionkit/mlx/model_io.py
+++ b/python/src/diffusionkit/mlx/model_io.py
@@ -35,11 +35,11 @@
 _DEFAULT_MMDIT = "stabilityai/stable-diffusion-3-medium"
 _MMDIT = {
     "stabilityai/stable-diffusion-3-medium": {
-        "2b": "sd3_medium.safetensors",
+        "stable-diffusion-3-medium": "sd3_medium.safetensors",
         "vae": "sd3_medium.safetensors",
     },
     "argmaxinc/mlx-FLUX.1-schnell": {
-        "flux": "flux-schnell.safetensors",
+        "FLUX.1-schnell": "flux-schnell.safetensors",
         "vae": "ae.safetensors",
     },
 }
@@ -72,12 +72,12 @@
 _FLOAT16 = mx.bfloat16
 
 DEPTH = {
-    "2b": 24,
-    "8b": 38,
+    "stable-diffusion-3-medium": 24,
+    "sd3-8b-unreleased": 38,
 }
 MAX_LATENT_RESOLUTION = {
-    "2b": 96,
-    "8b": 192,
+    "stable-diffusion-3-medium": 96,
+    "sd3-8b-unreleased": 192,
 }
 
 LOCAl_SD3_CKPT = None
@@ -675,7 +675,7 @@ def load_mmdit(
 def load_flux(
     key: str = "argmaxinc/mlx-FLUX.1-schnell",
     float16: bool = False,
-    model_key: str = "flux",
+    model_key: str = "FLUX.1-schnell",
     low_memory_mode: bool = True,
 ):
     """Load the MM-DiT Flux model from the checkpoint file."""
diff --git a/python/src/diffusionkit/mlx/scripts/generate_images.py b/python/src/diffusionkit/mlx/scripts/generate_images.py
index 8a1f6f1..8d29070 100644
--- a/python/src/diffusionkit/mlx/scripts/generate_images.py
+++ b/python/src/diffusionkit/mlx/scripts/generate_images.py
@@ -7,21 +7,25 @@
 import argparse
 
 from argmaxtools.utils import get_logger
-from diffusionkit.mlx import DiffusionPipeline, FluxPipeline
+from diffusionkit.mlx import DiffusionPipeline, FluxPipeline, MMDIT_CKPT
 
 logger = get_logger(__name__)
 
+# Defaults
 HEIGHT = {
-    "2b": 512,
-    "8b": 1024,
+    "stable-diffusion-3-medium": 512,
+    "sd3-8b-unreleased": 1024,
+    "FLUX.1-schnell": 512,
 }
 WIDTH = {
-    "2b": 512,
-    "8b": 1024,
+    "stable-diffusion-3-medium": 512,
+    "sd3-8b-unreleased": 1024,
+    "FLUX.1-schnell": 512,
 }
 SHIFT = {
-    "2b": 3.0,
-    "8b": 3.0,
+    "stable-diffusion-3-medium": 3.0,
+    "sd3-8b-unreleased": 3.0,
+    "FLUX.1-schnell": 1.0,
 }
@@ -34,10 +38,10 @@ def cli():
         "--image-path", type=str, help="Path to the image prompt", default=None
     )
     parser.add_argument(
-        "--model-size",
-        choices=("2b", "8b", "flux"),
-        default="2b",
-        help="Stable Diffusion 3 model size (2b or 8b).",
+        "--model-version",
+        choices=tuple(MMDIT_CKPT.keys()),
+        default="FLUX.1-schnell",
+        help="Diffusion model version, e.g. FLUX.1-schnell, stable-diffusion-3-medium",
     )
     parser.add_argument(
         "--steps", type=int, default=50, help="Number of diffusion steps."
@@ -81,12 +85,6 @@ def cli():
         dest="low_memory_mode",
         help="Disable low memory mode: No models offloading",
     )
-    parser.add_argument(
-        "--w16", action="store_true", help="Loads the models in float16."
-    )
-    parser.add_argument(
-        "--a16", action="store_true", help="Use float16 for the model activations."
-    )
     parser.add_argument(
         "--benchmark-mode",
         action="store_true",
@@ -106,8 +104,11 @@ def cli():
     )
     args = parser.parse_args()
 
-    if args.model_size == "flux":
-        logger.warning("Disabling CFG for flux-schnell model.")
+    args.w16 = True
+    args.a16 = True
+
+    if args.model_version == "FLUX.1-schnell" and args.cfg > 0.0:
+        logger.warning("Disabling CFG for FLUX.1-schnell model.")
         args.cfg = 0.0
 
     if args.benchmark_mode:
@@ -118,8 +119,8 @@ def cli():
     if args.denoise < 0.0 or args.denoise > 1.0:
         raise ValueError("Denoising factor must be between 0.0 and 1.0")
 
-    shift = args.shift or SHIFT[args.model_size]
-    pipeline_class = FluxPipeline if args.model_size == "flux" else DiffusionPipeline
+    shift = args.shift or SHIFT[args.model_version]
+    pipeline_class = FluxPipeline if "FLUX" in args.model_version else DiffusionPipeline
 
     # Load the models
     sd = pipeline_class(
@@ -127,7 +128,7 @@ def cli():
         w16=args.w16,
         shift=shift,
         use_t5=args.t5,
-        model_size=args.model_size,
+        model_version=args.model_version,
         low_memory_mode=args.low_memory_mode,
         a16=args.a16,
         local_ckpt=args.local_ckpt,
@@ -137,8 +138,8 @@ def cli():
     if args.preload_models:
         sd.ensure_models_are_loaded()
 
-    height = args.height or HEIGHT[args.model_size]
-    width = args.width or WIDTH[args.model_size]
+    height = args.height or HEIGHT[args.model_version]
+    width = args.width or WIDTH[args.model_version]
     logger.info(f"Output image resolution will be {height}x{width}")
 
     if args.benchmark_mode:
diff --git a/python/src/diffusionkit/tests/mlx/test_diffusion_pipeline.py b/python/src/diffusionkit/tests/mlx/test_diffusion_pipeline.py
index 005d985..47a8714 100644
--- a/python/src/diffusionkit/tests/mlx/test_diffusion_pipeline.py
+++ b/python/src/diffusionkit/tests/mlx/test_diffusion_pipeline.py
@@ -7,18 +7,17 @@
 import os
 import unittest
 
-import mlx.core as mx
-import numpy as np
 from argmaxtools.utils import get_logger
-from diffusionkit.mlx import DiffusionPipeline
+from diffusionkit.mlx import DiffusionPipeline, MMDIT_CKPT
 from diffusionkit.utils import image_psnr
+
 from huggingface_hub import hf_hub_download
 from PIL import Image
 
 logger = get_logger(__name__)
 
-W16 = False
-A16 = False
+W16 = True
+A16 = True
 TEST_PSNR_THRESHOLD = 20
 TEST_MIN_SPEEDUP = 0.95
 SD3_TEST_IMAGES_REPO = "argmaxinc/sd-test-images"
@@ -27,7 +26,7 @@
 LOW_MEMORY_MODE = True
 SAVE_IMAGES = True
-MODEL_SIZE = "2b"
+MODEL_VERSION = "stable-diffusion-3-medium"
 USE_T5 = False
 SKIP_CORRECTNESS = False
 
@@ -51,13 +50,13 @@ def test_sd3_pipeline_correctness(self):
             metadata = json.load(f)
 
         # Group metadata by model size
-        model_examples = {"2b": [], "8b": []}
+        model_examples = {"stable-diffusion-3-medium": []}
         for data in metadata:
-            model_examples[data["model_size"]].append(data)
+            model_examples[data["model_version"]].append(data)
 
-        for model_size, examples in model_examples.items():
+        for model_version, examples in model_examples.items():
             sd3 = DiffusionPipeline(
-                model_size=model_size, w16=W16, low_memory_mode=LOW_MEMORY_MODE, a16=A16
+                model_version=model_version, w16=W16, low_memory_mode=LOW_MEMORY_MODE, a16=A16
             )
             if not LOW_MEMORY_MODE:
                 sd3.ensure_models_are_loaded()
@@ -93,7 +92,7 @@ def test_sd3_pipeline_correctness(self):
                 if LOW_MEMORY_MODE:
                     del sd3
                     sd3 = DiffusionPipeline(
-                        model_size=model_size,
+                        model_version=model_version,
                         w16=W16,
                         low_memory_mode=LOW_MEMORY_MODE,
                         a16=A16,
@@ -105,21 +104,21 @@ def test_memory_usage(self):
             metadata = json.load(f)
 
         # Group metadata by model size
-        model_examples = {"2b": [], "8b": []}
+        model_examples = {"stable-diffusion-3-medium": []}
         for data in metadata:
-            model_examples[data["model_size"]].append(data)
+            model_examples[data["model_version"]].append(data)
 
         sd3 = DiffusionPipeline(
-            model_size=MODEL_SIZE, w16=W16, low_memory_mode=LOW_MEMORY_MODE, a16=A16
+            model_version=MODEL_VERSION, w16=W16, low_memory_mode=LOW_MEMORY_MODE, a16=A16
         )
         if not LOW_MEMORY_MODE:
             sd3.ensure_models_are_loaded()
 
         log = None
-        for example in model_examples[MODEL_SIZE]:
+        for example in model_examples[MODEL_VERSION]:
             sd3.use_t5 = USE_T5
             logger.info(
-                f"Testing memory usage... USE_T5 = {USE_T5} | MODEL_SIZE = {MODEL_SIZE}"
+                f"Testing memory usage... USE_T5 = {USE_T5} | MODEL_VERSION = {MODEL_VERSION}"
             )
             _, log = sd3.generate_image(
                 text=example["prompt"],
@@ -132,7 +131,7 @@ def test_memory_usage(self):
                 break
 
         out_folder = os.path.join(TEST_CACHE_DIR, CACHE_SUBFOLDER)
-        out_path = os.path.join(out_folder, f"{MODEL_SIZE}_log.json")
+        out_path = os.path.join(out_folder, f"{MODEL_VERSION}_log.json")
         if not os.path.exists(out_folder):
             os.makedirs(out_folder, exist_ok=True)
         with open(out_path, "w") as f:
@@ -142,12 +141,12 @@
 
 
 def main(args):
-    global LOW_MEMORY_MODE, SAVE_IMAGES, SKIP_CORRECTNESS, MODEL_SIZE, W16, A16, CACHE_SUBFOLDER, USE_T5
+    global LOW_MEMORY_MODE, SAVE_IMAGES, SKIP_CORRECTNESS, MODEL_VERSION, W16, A16, CACHE_SUBFOLDER, USE_T5
     LOW_MEMORY_MODE = args.low_memory_mode
     SAVE_IMAGES = args.save_images
     SKIP_CORRECTNESS = args.skip_correctness
-    MODEL_SIZE = args.model_size
+    MODEL_VERSION = args.model_version
     W16 = args.w16
     A16 = args.a16
     CACHE_SUBFOLDER = args.subfolder
@@ -181,7 +180,7 @@ def main(args):
         "--skip-correctness", action="store_true", help="Skip the correctness test."
     )
     parser.add_argument(
-        "--model-size", type=str, default="2b", help="Model size to use for the test."
+        "--model-version", type=str, default="stable-diffusion-3-medium", choices=tuple(MMDIT_CKPT.keys()), help="Model version to test."
    )
     parser.add_argument(
         "--w16", action="store_true", help="Loads the models in float16."
@@ -193,7 +192,7 @@ def main(args):
         "--subfolder",
         default="default",
         type=str,
-        help=f"If specified, this string will be appended to the cache directory name.",
+        help="If specified, this string will be appended to the cache directory name.",
     )
     parser.add_argument(
         "--use-t5", action="store_true", help="Use T5 model for text generation."
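
Usage note (editorial, not part of the patch): after this rename, callers select a checkpoint with `model_version`, whose valid values are the `MMDIT_CKPT` keys introduced above. The sketch below simply restates the updated README call in runnable form, using only names that appear in this diff; it assumes a working `diffusionkit` install.

```python
# Minimal sketch of the renamed API (assumes diffusionkit is installed).
from diffusionkit.mlx import DiffusionPipeline, MMDIT_CKPT

# Valid model_version values after this change:
# 'stable-diffusion-3-medium', 'sd3-8b-unreleased', 'FLUX.1-schnell'
print(tuple(MMDIT_CKPT.keys()))

# The old model_size="2b" call becomes model_version="stable-diffusion-3-medium".
pipeline = DiffusionPipeline(
    w16=True,
    shift=3.0,
    use_t5=False,
    model_version="stable-diffusion-3-medium",
    low_memory_mode=False,
    a16=True,
)
```

The CLI follows the same rename: `--model-size` becomes `--model-version`, and the default height, width, and shift are looked up by the new keys.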