From f5abc62838c4ce1a3ac5b4f6ae23504067c988c0 Mon Sep 17 00:00:00 2001 From: Arda Okan Date: Fri, 14 Jun 2024 00:30:30 -0700 Subject: [PATCH 1/2] convert to coreml from local ckpt --- README.md | 7 +++++-- tests/torch2coreml/test_mmdit.py | 4 +++- tests/torch2coreml/test_vae.py | 4 +++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 23c6e52..fe007e3 100644 --- a/README.md +++ b/README.md @@ -47,14 +47,17 @@ huggingface-cli login --token YOUR_HF_HUB_TOKEN **Step 2:** Prepare the denoise model (MMDiT) Core ML model files (`.mlpackage`) ```shell -python -m tests.torch2coreml.test_mmdit --sd3-ckpt-path --model-version {2b} -o --latent-size {64, 128} +python -m tests.torch2coreml.test_mmdit --sd3-ckpt-path --model-version {2b} -o --latent-size {64, 128} ``` **Step 3:** Prepare the VAE Decoder Core ML model files (`.mlpackage`) ```shell -python -m tests.torch2coreml.test_vae --sd3-ckpt-path -o --latent-size {64, 128} +python -m tests.torch2coreml.test_vae --sd3-ckpt-path -o --latent-size {64, 128} ``` + +Note: +- `--sd3-ckpt-path` can be a path to a local `.safetensors` file or a HuggingFace repo (e.g. `stabilityai/stable-diffusion-3-medium`). ## Image Generation with Python MLX diff --git a/tests/torch2coreml/test_mmdit.py b/tests/torch2coreml/test_mmdit.py index 81a18b8..0cd94f8 100644 --- a/tests/torch2coreml/test_mmdit.py +++ b/tests/torch2coreml/test_mmdit.py @@ -47,6 +47,7 @@ class TestSD3MMDiT(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase): @classmethod def setUpClass(cls): + global TEST_SD3_CKPT_PATH cls.model_name = "MultiModalDiffusionTransformer" cls.test_output_names = ["denoiser_output"] cls.test_cache_dir = TEST_CACHE_DIR @@ -60,7 +61,7 @@ def setUpClass(cls): .eval() ) logger.info("Initialized.") - TEST_SD3_CKPT_PATH = hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors") + TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors") if TEST_SD3_CKPT_PATH is not None: logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}") @@ -133,6 +134,7 @@ def get_test_inputs(cfg: mmdit.MMDiTConfig) -> Dict[str, torch.Tensor]: parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int) args = parser.parse_args() + TEST_SD3_CKPT_PATH = args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None TEST_SD3_HF_REPO = args.sd3_ckpt_path TEST_LATENT_SIZE = args.latent_size TEST_CKPT_FILE_NAME = args.ckpt_file_name diff --git a/tests/torch2coreml/test_vae.py b/tests/torch2coreml/test_vae.py index de43608..1b4a466 100644 --- a/tests/torch2coreml/test_vae.py +++ b/tests/torch2coreml/test_vae.py @@ -45,6 +45,7 @@ class TestSD3VAEDecoder(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCa @classmethod def setUpClass(cls): + global TEST_SD3_CKPT_PATH cls.model_name = "VAEDecoder" cls.test_output_names = ["image"] cls.test_cache_dir = TEST_CACHE_DIR @@ -56,7 +57,7 @@ def setUpClass(cls): ) logger.info("Initialized.") - TEST_SD3_CKPT_PATH = hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors") + TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors") if TEST_SD3_CKPT_PATH is not None: logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}") _load_vae_decoder_weights(cls.test_torch_model, TEST_SD3_CKPT_PATH) @@ -103,6 +104,7 @@ def get_test_inputs(config: vae.VAEDecoderConfig) -> Dict[str, torch.Tensor]: parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int) args = parser.parse_args() + TEST_SD3_CKPT_PATH = args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None TEST_SD3_HF_REPO = args.sd3_ckpt_path TEST_LATENT_SIZE = args.latent_size From 416fb7646cd774b71db72125d1e0ce4d78585357 Mon Sep 17 00:00:00 2001 From: Arda Okan Date: Fri, 14 Jun 2024 00:48:05 -0700 Subject: [PATCH 2/2] black style change --- tests/torch2coreml/test_mmdit.py | 8 ++++++-- tests/torch2coreml/test_vae.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/torch2coreml/test_mmdit.py b/tests/torch2coreml/test_mmdit.py index 0cd94f8..118b021 100644 --- a/tests/torch2coreml/test_mmdit.py +++ b/tests/torch2coreml/test_mmdit.py @@ -61,7 +61,9 @@ def setUpClass(cls): .eval() ) logger.info("Initialized.") - TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors") + TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download( + TEST_SD3_HF_REPO, "sd3_medium.safetensors" + ) if TEST_SD3_CKPT_PATH is not None: logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}") @@ -134,7 +136,9 @@ def get_test_inputs(cfg: mmdit.MMDiTConfig) -> Dict[str, torch.Tensor]: parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int) args = parser.parse_args() - TEST_SD3_CKPT_PATH = args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None + TEST_SD3_CKPT_PATH = ( + args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None + ) TEST_SD3_HF_REPO = args.sd3_ckpt_path TEST_LATENT_SIZE = args.latent_size TEST_CKPT_FILE_NAME = args.ckpt_file_name diff --git a/tests/torch2coreml/test_vae.py b/tests/torch2coreml/test_vae.py index 1b4a466..512cd16 100644 --- a/tests/torch2coreml/test_vae.py +++ b/tests/torch2coreml/test_vae.py @@ -57,7 +57,9 @@ def setUpClass(cls): ) logger.info("Initialized.") - TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors") + TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download( + TEST_SD3_HF_REPO, "sd3_medium.safetensors" + ) if TEST_SD3_CKPT_PATH is not None: logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}") _load_vae_decoder_weights(cls.test_torch_model, TEST_SD3_CKPT_PATH) @@ -104,7 +106,9 @@ def get_test_inputs(config: vae.VAEDecoderConfig) -> Dict[str, torch.Tensor]: parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int) args = parser.parse_args() - TEST_SD3_CKPT_PATH = args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None + TEST_SD3_CKPT_PATH = ( + args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None + ) TEST_SD3_HF_REPO = args.sd3_ckpt_path TEST_LATENT_SIZE = args.latent_size