diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 88ce63c..7bbc28c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -14,8 +14,16 @@ jobs: - name: Checkout code uses: actions/checkout@v3 + - name: Cache Homebrew + uses: actions/cache@v3 + with: + path: /home/linuxbrew/.linuxbrew + key: ${{ runner.os }}-homebrew-${{ hashFiles('**/Brewfile.lock.json') }} + restore-keys: | + ${{ runner.os }}-homebrew- + - name: Set up Homebrew - id: set-up-homebrew + if: steps.cache-homebrew.outputs.cache-hit != 'true' uses: Homebrew/actions/setup-homebrew@master - name: Setup environment diff --git a/README.md b/README.md index 5b578ff..b9d837c 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,7 @@ Some notable optional arguments: - For image-to-image, use `--image-path` (path to input image) and `--denoise` (value between 0. and 1.) - T5 text embeddings, use `--t5` - For different resolutions, use `--height` and `--width` +- For using a local checkpoint, use `--local-ckpt <path/to/ckpt.safetensors>` (e.g. `~/models/stable-diffusion-3-medium/sd3_medium.safetensors`). Please refer to the help menu for all available arguments: `diffusionkit-cli -h`. 
diff --git a/python/src/mlx/__init__.py b/python/src/mlx/__init__.py index 917f305..904a868 100644 --- a/python/src/mlx/__init__.py +++ b/python/src/mlx/__init__.py @@ -48,7 +48,9 @@ def __init__( model_size: str = "2b", low_memory_mode: bool = True, a16: bool = False, + local_ckpt=None, ): + model_io.LOCAl_SD3_CKPT = local_ckpt self.dtype = mx.float16 if w16 else mx.float32 self.activation_dtype = mx.float16 if a16 else mx.float32 self.use_t5 = use_t5 diff --git a/python/src/mlx/model_io.py b/python/src/mlx/model_io.py index 684852f..1036dec 100644 --- a/python/src/mlx/model_io.py +++ b/python/src/mlx/model_io.py @@ -62,6 +62,8 @@ "8b": 192, } +LOCAl_SD3_CKPT = None + def mmdit_state_dict_adjustments(state_dict, prefix=""): # Remove prefix @@ -453,7 +455,8 @@ def load_mmdit( model = MMDiT(config) mmdit_weights = _MMDIT[key][model_key] - weights = mx.load(hf_hub_download(key, mmdit_weights)) + mmdit_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, mmdit_weights) + weights = mx.load(mmdit_weights_ckpt) weights = mmdit_state_dict_adjustments(weights, prefix="model.diffusion_model.") weights = {k: v.astype(dtype) for k, v in weights.items()} model.update(tree_unflatten(tree_flatten(weights))) @@ -548,7 +551,8 @@ def load_vae_decoder( dtype = mx.float16 if float16 else mx.float32 vae_weights = _MMDIT[key][model_key] - weights = mx.load(hf_hub_download(key, vae_weights)) + vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights) + weights = mx.load(vae_weights_ckpt) weights = vae_decoder_state_dict_adjustments( weights, prefix="first_stage_model.decoder." ) @@ -575,7 +579,8 @@ def load_vae_encoder( dtype = mx.float16 if float16 else mx.float32 vae_weights = _MMDIT[key][model_key] - weights = mx.load(hf_hub_download(key, vae_weights)) + vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights) + weights = mx.load(vae_weights_ckpt) weights = vae_encoder_state_dict_adjustments( weights, prefix="first_stage_model.encoder." 
) diff --git a/python/src/mlx/scripts/generate_images.py b/python/src/mlx/scripts/generate_images.py index 99ffde5..92f2785 100644 --- a/python/src/mlx/scripts/generate_images.py +++ b/python/src/mlx/scripts/generate_images.py @@ -99,6 +99,12 @@ def cli(): default=0.0, help="Denoising factor when an input image is provided. (between 0.0 and 1.0)", ) + parser.add_argument( + "--local-ckpt", + default=None, + type=str, + help="Path to the local mmdit checkpoint.", + ) args = parser.parse_args() if args.benchmark_mode: @@ -119,6 +125,7 @@ def cli(): model_size=args.model_size, low_memory_mode=args.low_memory_mode, a16=args.a16, + local_ckpt=args.local_ckpt, ) # Ensure that models are read in memory if needed