Local ckpt support for python mlx #7

Merged 3 commits on Jun 19, 2024
10 changes: 9 additions & 1 deletion .github/workflows/lint.yml
@@ -14,8 +14,16 @@ jobs:
       - name: Checkout code
         uses: actions/checkout@v3

+      - name: Cache Homebrew
+        id: cache-homebrew  # required so the `if:` on the next step can read this step's outputs
+        uses: actions/cache@v3
+        with:
+          path: /home/linuxbrew/.linuxbrew
+          key: ${{ runner.os }}-homebrew-${{ hashFiles('**/Brewfile.lock.json') }}
+          restore-keys: |
+            ${{ runner.os }}-homebrew-

       - name: Set up Homebrew
         id: set-up-homebrew
+        if: steps.cache-homebrew.outputs.cache-hit != 'true'
         uses: Homebrew/actions/setup-homebrew@master

       - name: Setup environment
1 change: 1 addition & 0 deletions README.md
@@ -74,6 +74,7 @@ Some notable optional arguments:
 - For image-to-image, use `--image-path` (path to input image) and `--denoise` (value between 0. and 1.)
 - T5 text embeddings, use `--t5`
 - For different resolutions, use `--height` and `--width`
+- For using a local checkpoint, use `--local-ckpt </path/to/ckpt.safetensors>` (e.g. `~/models/stable-diffusion-3-medium/sd3_medium.safetensors`).

 Please refer to the help menu for all available arguments: `diffusionkit-cli -h`.

2 changes: 2 additions & 0 deletions python/src/mlx/__init__.py
@@ -48,7 +48,9 @@ def __init__(
         model_size: str = "2b",
         low_memory_mode: bool = True,
         a16: bool = False,
+        local_ckpt=None,
     ):
+        model_io.LOCAl_SD3_CKPT = local_ckpt
         self.dtype = mx.float16 if w16 else mx.float32
         self.activation_dtype = mx.float16 if a16 else mx.float32
         self.use_t5 = use_t5
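Because the constructor stores the path in a module-level variable of `model_io`, the setting is shared by every pipeline created in the same process. The sketch below illustrates that behavior with stand-in names (`LOCAL_CKPT`, `Pipeline`, and `weights_source` are hypothetical and not the project's code). Whether this matters in practice depends on when the real loaders run.

```python
from typing import Optional

# Stand-in for the module-level setting in model_io (hypothetical name).
LOCAL_CKPT: Optional[str] = None


class Pipeline:
    """Toy pipeline that mirrors the constructor pattern in the diff."""

    def __init__(self, local_ckpt: Optional[str] = None):
        global LOCAL_CKPT
        LOCAL_CKPT = local_ckpt  # overwrites whatever an earlier instance set

    def weights_source(self) -> str:
        # Same "local or download" fallback as the loaders, with a placeholder
        # string standing in for a real Hub download.
        return LOCAL_CKPT or "<hub download>"


first = Pipeline(local_ckpt="/models/sd3_medium.safetensors")
second = Pipeline()  # no local checkpoint requested

# Both instances now report the Hub fallback, because the global was reset.
print(first.weights_source())
print(second.weights_source())
```

If per-instance behavior is wanted, one option would be to keep the path on the instance and pass it to the loaders explicitly.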
11 changes: 8 additions & 3 deletions python/src/mlx/model_io.py
@@ -62,6 +62,8 @@
     "8b": 192,
 }

+LOCAl_SD3_CKPT = None
+

 def mmdit_state_dict_adjustments(state_dict, prefix=""):
     # Remove prefix
@@ -453,7 +455,8 @@ def load_mmdit(
     model = MMDiT(config)

     mmdit_weights = _MMDIT[key][model_key]
-    weights = mx.load(hf_hub_download(key, mmdit_weights))
+    mmdit_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, mmdit_weights)
+    weights = mx.load(mmdit_weights_ckpt)
     weights = mmdit_state_dict_adjustments(weights, prefix="model.diffusion_model.")
     weights = {k: v.astype(dtype) for k, v in weights.items()}
     model.update(tree_unflatten(tree_flatten(weights)))
@@ -548,7 +551,8 @@ def load_vae_decoder(

     dtype = mx.float16 if float16 else mx.float32
     vae_weights = _MMDIT[key][model_key]
-    weights = mx.load(hf_hub_download(key, vae_weights))
+    vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights)
+    weights = mx.load(vae_weights_ckpt)
     weights = vae_decoder_state_dict_adjustments(
         weights, prefix="first_stage_model.decoder."
     )
@@ -575,7 +579,8 @@ def load_vae_encoder(

     dtype = mx.float16 if float16 else mx.float32
     vae_weights = _MMDIT[key][model_key]
-    weights = mx.load(hf_hub_download(key, vae_weights))
+    vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights)
+    weights = mx.load(vae_weights_ckpt)
     weights = vae_encoder_state_dict_adjustments(
         weights, prefix="first_stage_model.encoder."
     )
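The three loaders in this file repeat the same fallback: use the local checkpoint if one was given, otherwise download the file from the Hugging Face Hub. A minimal sketch of that pattern as one helper follows. `resolve_checkpoint` is a hypothetical name, not part of this PR; it assumes `huggingface_hub` is installed for the download branch, and it adds an existence check that the PR itself does not perform.

```python
import os
from typing import Optional


def resolve_checkpoint(local_ckpt: Optional[str], repo_id: str, filename: str) -> str:
    """Return a path to a checkpoint file, preferring a local file over the Hub.

    Hypothetical helper for illustration only; not part of the PR.
    """
    if local_ckpt:
        # Fail early with a clear message instead of erroring inside the loader.
        if not os.path.isfile(local_ckpt):
            raise FileNotFoundError(f"Local checkpoint not found: {local_ckpt}")
        return local_ckpt

    # Imported lazily so the local-file branch needs neither the package
    # nor network access.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id, filename)
```

With a helper like this, each loader could call `mx.load(resolve_checkpoint(...))` instead of repeating the `or` expression.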
7 changes: 7 additions & 0 deletions python/src/mlx/scripts/generate_images.py
@@ -99,6 +99,12 @@ def cli():
         default=0.0,
         help="Denoising factor when an input image is provided. (between 0.0 and 1.0)",
     )
+    parser.add_argument(
+        "--local-ckpt",
+        default=None,
+        type=str,
+        help="Path to the local mmdit checkpoint.",
+    )
     args = parser.parse_args()

     if args.benchmark_mode:
@@ -119,6 +125,7 @@ def cli():
         model_size=args.model_size,
         low_memory_mode=args.low_memory_mode,
         a16=args.a16,
+        local_ckpt=args.local_ckpt,
     )

     # Ensure that models are read in memory if needed
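The README example passes a path starting with `~`, which the shell normally expands but which arrives unexpanded if the argument is quoted. The sketch below shows one way the flag could expand it at parse time, and how `--local-ckpt` surfaces as `args.local_ckpt`. `expanded_path` and `build_parser` are hypothetical names for illustration; the PR itself uses `type=str`.

```python
import argparse
import os


def expanded_path(value: str) -> str:
    """Expand a leading "~" to the user's home directory (hypothetical helper)."""
    return os.path.expanduser(value)


def build_parser() -> argparse.ArgumentParser:
    # Stand-in parser with only the flag added in this PR.
    parser = argparse.ArgumentParser(description="--local-ckpt parsing sketch")
    parser.add_argument(
        "--local-ckpt",
        default=None,  # None means "fall back to the Hub download"
        type=expanded_path,
        help="Path to the local mmdit checkpoint.",
    )
    return parser


# argparse converts the dash in the flag name to an underscore on the namespace.
args = build_parser().parse_args(["--local-ckpt", "~/models/sd3_medium.safetensors"])
print(args.local_ckpt)
```

Leaving the flag out yields `args.local_ckpt is None`, which keeps the existing download behavior.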