Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
atiorh committed Aug 14, 2024
1 parent 49fd435 commit 67f4043
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 68 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pipeline = DiffusionPipeline(
w16=True,
shift=3.0,
use_t5=False,
model_size="2b",
model_version="2b",
low_memory_mode=False,
a16=True,
)
Expand Down
28 changes: 13 additions & 15 deletions python/src/diffusionkit/mlx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
logger = get_logger(__name__)

MMDIT_CKPT = {
"2b": "stabilityai/stable-diffusion-3-medium",
"8b": "models/sd3_8b_beta.safetensors",
"flux": "argmaxinc/mlx-FLUX.1-schnell",
"stable-diffusion-3-medium": "stabilityai/stable-diffusion-3-medium",
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
"FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
}


Expand All @@ -46,7 +46,7 @@ def __init__(
w16: bool = False,
shift: float = 1.0,
use_t5: bool = True,
model_size: str = "2b",
model_version: str = "stable-diffusion-3-medium",
low_memory_mode: bool = True,
a16: bool = False,
local_ckpt=None,
Expand All @@ -57,12 +57,12 @@ def __init__(
self.dtype = self.float16_dtype if w16 else mx.float32
self.activation_dtype = self.float16_dtype if a16 else mx.float32
self.use_t5 = use_t5
mmdit_ckpt = MMDIT_CKPT[model_size]
mmdit_ckpt = MMDIT_CKPT[model_version]
self.low_memory_mode = low_memory_mode
self.mmdit = load_mmdit(
float16=w16,
key=mmdit_ckpt,
model_key=model_size,
model_key=model_version,
low_memory_mode=low_memory_mode,
)
self.sampler = ModelSamplingDiscreteFlow(shift=shift)
Expand Down Expand Up @@ -120,9 +120,10 @@ def unload_t5(self):
def ensure_models_are_loaded(self):
mx.eval(self.mmdit.parameters())
mx.eval(self.clip_l.parameters())
mx.eval(self.clip_g.parameters())
mx.eval(self.decoder.parameters())
if self.use_t5:
if hasattr(self, "clip_g"):
mx.eval(self.clip_g.parameters())
if hasattr(self, "t5_encoder") and self.use_t5:
mx.eval(self.t5_encoder.parameters())

def _tokenize(self, tokenizer, text: str, negative_text: Optional[str] = None):
Expand Down Expand Up @@ -539,7 +540,7 @@ def __init__(
w16: bool = False,
shift: float = 1.0,
use_t5: bool = True,
model_size: str = "2b",
model_version: str = "FLUX.1-schnell",
low_memory_mode: bool = True,
a16: bool = False,
local_ckpt=None,
Expand All @@ -549,8 +550,7 @@ def __init__(
model_io._FLOAT16 = self.float16_dtype
self.dtype = self.float16_dtype if w16 else mx.float32
self.activation_dtype = self.float16_dtype if a16 else mx.float32
self.use_t5 = use_t5
mmdit_ckpt = MMDIT_CKPT[model_size]
mmdit_ckpt = MMDIT_CKPT[model_version]
self.low_memory_mode = low_memory_mode
self.mmdit = load_flux(float16=w16, low_memory_mode=low_memory_mode)
self.sampler = FluxSampler(shift=shift)
Expand All @@ -559,8 +559,8 @@ def __init__(
self.latent_format = FluxLatentFormat()

if not use_t5:
logger.info("FLUX model is being used without T5. Setting use_t5 to True.")
self.use_t5 = True
logger.warning("FLUX can not be used without T5. Loading T5..")
self.use_t5 = True

self.clip_l = load_text_encoder(
model,
Expand Down Expand Up @@ -664,8 +664,6 @@ def process_out(self, latent):


class SD3LatentFormat(LatentFormat):
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""

def __init__(self):
super().__init__()
self.scale_factor = 1.5305
Expand Down
14 changes: 7 additions & 7 deletions python/src/diffusionkit/mlx/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
_DEFAULT_MMDIT = "stabilityai/stable-diffusion-3-medium"
_MMDIT = {
"stabilityai/stable-diffusion-3-medium": {
"2b": "sd3_medium.safetensors",
"stable-diffusion-3-medium": "sd3_medium.safetensors",
"vae": "sd3_medium.safetensors",
},
"argmaxinc/mlx-FLUX.1-schnell": {
"flux": "flux-schnell.safetensors",
"FLUX.1-schnell": "flux-schnell.safetensors",
"vae": "ae.safetensors",
},
}
Expand Down Expand Up @@ -72,12 +72,12 @@
_FLOAT16 = mx.bfloat16

DEPTH = {
"2b": 24,
"8b": 38,
"stable-diffusion-3-medium": 24,
"sd3-8b-unreleased": 38,
}
MAX_LATENT_RESOLUTION = {
"2b": 96,
"8b": 192,
"stable-diffusion-3-medium": 96,
"sd3-8b-unreleased": 192,
}

LOCAl_SD3_CKPT = None
Expand Down Expand Up @@ -675,7 +675,7 @@ def load_mmdit(
def load_flux(
key: str = "argmaxinc/mlx-FLUX.1-schnell",
float16: bool = False,
model_key: str = "flux",
model_key: str = "FLUX.1-schnell",
low_memory_mode: bool = True,
):
"""Load the MM-DiT Flux model from the checkpoint file."""
Expand Down
49 changes: 25 additions & 24 deletions python/src/diffusionkit/mlx/scripts/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,25 @@
import argparse

from argmaxtools.utils import get_logger
from diffusionkit.mlx import DiffusionPipeline, FluxPipeline
from diffusionkit.mlx import DiffusionPipeline, FluxPipeline, MMDIT_CKPT

logger = get_logger(__name__)

# Defaults
HEIGHT = {
"2b": 512,
"8b": 1024,
"stable-diffusion-3-medium": 512,
"sd3-8b-unreleased": 1024,
"FLUX.1-schnell": 512,
}
WIDTH = {
"2b": 512,
"8b": 1024,
"stable-diffusion-3-medium": 512,
"sd3-8b-unreleased": 1024,
"FLUX.1-schnell": 512,
}
SHIFT = {
"2b": 3.0,
"8b": 3.0,
"stable-diffusion-3-medium": 3.0,
"sd3-8b-unreleased": 3.0,
"FLUX.1-schnell": 1.0,
}


Expand All @@ -34,10 +38,10 @@ def cli():
"--image-path", type=str, help="Path to the image prompt", default=None
)
parser.add_argument(
"--model-size",
choices=("2b", "8b", "flux"),
default="2b",
help="Stable Diffusion 3 model size (2b or 8b).",
"--model-version",
choices=tuple(MMDIT_CKPT.keys()),
default="FLUX.1-schnell",
help="Diffusion model version, e.g. FLUX.1-schnell, stable-diffusion-3-medium",
)
parser.add_argument(
"--steps", type=int, default=50, help="Number of diffusion steps."
Expand Down Expand Up @@ -81,12 +85,6 @@ def cli():
dest="low_memory_mode",
help="Disable low memory mode: No models offloading",
)
parser.add_argument(
"--w16", action="store_true", help="Loads the models in float16."
)
parser.add_argument(
"--a16", action="store_true", help="Use float16 for the model activations."
)
parser.add_argument(
"--benchmark-mode",
action="store_true",
Expand All @@ -106,8 +104,11 @@ def cli():
)
args = parser.parse_args()

if args.model_size == "flux":
logger.warning("Disabling CFG for flux-schnell model.")
args.w16 = True
args.a16 = True

if args.model_version == "FLUX.1-schnell" and args.cfg > 0.0:
logger.warning("Disabling CFG for FLUX.1-schnell model.")
args.cfg = 0.0

if args.benchmark_mode:
Expand All @@ -118,16 +119,16 @@ def cli():
if args.denoise < 0.0 or args.denoise > 1.0:
raise ValueError("Denoising factor must be between 0.0 and 1.0")

shift = args.shift or SHIFT[args.model_size]
pipeline_class = FluxPipeline if args.model_size == "flux" else DiffusionPipeline
shift = args.shift or SHIFT[args.model_version]
pipeline_class = FluxPipeline if "FLUX" in args.model_version else DiffusionPipeline

# Load the models
sd = pipeline_class(
model="argmaxinc/stable-diffusion",
w16=args.w16,
shift=shift,
use_t5=args.t5,
model_size=args.model_size,
model_version=args.model_version,
low_memory_mode=args.low_memory_mode,
a16=args.a16,
local_ckpt=args.local_ckpt,
Expand All @@ -137,8 +138,8 @@ def cli():
if args.preload_models:
sd.ensure_models_are_loaded()

height = args.height or HEIGHT[args.model_size]
width = args.width or WIDTH[args.model_size]
height = args.height or HEIGHT[args.model_version]
width = args.width or WIDTH[args.model_version]
logger.info(f"Output image resolution will be {height}x{width}")

if args.benchmark_mode:
Expand Down
41 changes: 20 additions & 21 deletions python/src/diffusionkit/tests/mlx/test_diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@
import os
import unittest

import mlx.core as mx
import numpy as np
from argmaxtools.utils import get_logger
from diffusionkit.mlx import DiffusionPipeline
from diffusionkit.mlx import DiffusionPipeline, MMDIT_CKPT
from diffusionkit.utils import image_psnr

from huggingface_hub import hf_hub_download
from PIL import Image

logger = get_logger(__name__)

W16 = False
A16 = False
W16 = True
A16 = True
TEST_PSNR_THRESHOLD = 20
TEST_MIN_SPEEDUP = 0.95
SD3_TEST_IMAGES_REPO = "argmaxinc/sd-test-images"
Expand All @@ -27,7 +26,7 @@

LOW_MEMORY_MODE = True
SAVE_IMAGES = True
MODEL_SIZE = "2b"
MODEL_VERSION = "stable-diffusion-3-medium"
USE_T5 = False
SKIP_CORRECTNESS = False

Expand All @@ -51,13 +50,13 @@ def test_sd3_pipeline_correctness(self):
metadata = json.load(f)

# Group metadata by model version
model_examples = {"2b": [], "8b": []}
model_examples = {"stable-diffusion-3-medium": []}
for data in metadata:
model_examples[data["model_size"]].append(data)
model_examples[data["model_version"]].append(data)

for model_size, examples in model_examples.items():
for model_version, examples in model_examples.items():
sd3 = DiffusionPipeline(
model_size=model_size, w16=W16, low_memory_mode=LOW_MEMORY_MODE, a16=A16
model_version=model_version, w16=W16, low_memory_mode=LOW_MEMORY_MODE, a16=A16
)
if not LOW_MEMORY_MODE:
sd3.ensure_models_are_loaded()
Expand Down Expand Up @@ -93,7 +92,7 @@ def test_sd3_pipeline_correctness(self):
if LOW_MEMORY_MODE:
del sd3
sd3 = DiffusionPipeline(
model_size=model_size,
model_version=model_version,
w16=W16,
low_memory_mode=LOW_MEMORY_MODE,
a16=A16,
Expand All @@ -105,21 +104,21 @@ def test_memory_usage(self):
metadata = json.load(f)

# Group metadata by model version
model_examples = {"2b": [], "8b": []}
model_examples = {"stable-diffusion-3-medium": []}
for data in metadata:
model_examples[data["model_size"]].append(data)
model_examples[data["model_version"]].append(data)

sd3 = DiffusionPipeline(
model_size=MODEL_SIZE, w16=W16, low_memory_mode=LOW_MEMORY_MODE, a16=A16
model_version=MODEL_VERSION, w16=W16, low_memory_mode=LOW_MEMORY_MODE, a16=A16
)
if not LOW_MEMORY_MODE:
sd3.ensure_models_are_loaded()

log = None
for example in model_examples[MODEL_SIZE]:
for example in model_examples[MODEL_VERSION]:
sd3.use_t5 = USE_T5
logger.info(
f"Testing memory usage... USE_T5 = {USE_T5} | MODEL_SIZE = {MODEL_SIZE}"
f"Testing memory usage... USE_T5 = {USE_T5} | MODEL_VERSION = {MODEL_VERSION}"
)
_, log = sd3.generate_image(
text=example["prompt"],
Expand All @@ -132,7 +131,7 @@ def test_memory_usage(self):
break

out_folder = os.path.join(TEST_CACHE_DIR, CACHE_SUBFOLDER)
out_path = os.path.join(out_folder, f"{MODEL_SIZE}_log.json")
out_path = os.path.join(out_folder, f"{MODEL_VERSION}_log.json")
if not os.path.exists(out_folder):
os.makedirs(out_folder, exist_ok=True)
with open(out_path, "w") as f:
Expand All @@ -142,12 +141,12 @@ def test_memory_usage(self):


def main(args):
global LOW_MEMORY_MODE, SAVE_IMAGES, SKIP_CORRECTNESS, MODEL_SIZE, W16, A16, CACHE_SUBFOLDER, USE_T5
global LOW_MEMORY_MODE, SAVE_IMAGES, SKIP_CORRECTNESS, MODEL_VERSION, W16, A16, CACHE_SUBFOLDER, USE_T5

LOW_MEMORY_MODE = args.low_memory_mode
SAVE_IMAGES = args.save_images
SKIP_CORRECTNESS = args.skip_correctness
MODEL_SIZE = args.model_size
MODEL_VERSION = args.model_version
W16 = args.w16
A16 = args.a16
CACHE_SUBFOLDER = args.subfolder
Expand Down Expand Up @@ -181,7 +180,7 @@ def main(args):
"--skip-correctness", action="store_true", help="Skip the correctness test."
)
parser.add_argument(
"--model-size", type=str, default="2b", help="Model size to use for the test."
"--model-size", type=str, default="stable-diffusion-3-medium", choices=tuple(MMDIT_CKPT.keys()), help="model version to test"
)
parser.add_argument(
"--w16", action="store_true", help="Loads the models in float16."
Expand All @@ -193,7 +192,7 @@ def main(args):
"--subfolder",
default="default",
type=str,
help=f"If specified, this string will be appended to the cache directory name.",
help="If specified, this string will be appended to the cache directory name.",
)
parser.add_argument(
"--use-t5", action="store_true", help="Use T5 model for text generation."
Expand Down

0 comments on commit 67f4043

Please sign in to comment.