Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert to coreml from local ckpt #2

Merged
merged 2 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,17 @@ huggingface-cli login --token YOUR_HF_HUB_TOKEN
**Step 2:** Prepare the denoise model (MMDiT) Core ML model files (`.mlpackage`)

```shell
python -m tests.torch2coreml.test_mmdit --sd3-ckpt-path <path-to-sd3-mmdit.safetensors> --model-version {2b} -o <output-mlpackages-directory> --latent-size {64, 128}
python -m tests.torch2coreml.test_mmdit --sd3-ckpt-path <path-to-sd3-mmdit.safetensors or model-version-string-from-hub> --model-version {2b} -o <output-mlpackages-directory> --latent-size {64, 128}
```

**Step 3:** Prepare the VAE Decoder Core ML model files (`.mlpackage`)

```shell
python -m tests.torch2coreml.test_vae --sd3-ckpt-path <path-to-sd3-mmdit.safetensors> -o <output-mlpackages-directory> --latent-size {64, 128}
python -m tests.torch2coreml.test_vae --sd3-ckpt-path <path-to-sd3-mmdit.safetensors or model-version-string-from-hub> -o <output-mlpackages-directory> --latent-size {64, 128}
```

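Note:
- `--sd3-ckpt-path` can be either a path to a local `.safetensors` file or a Hugging Face Hub repo ID (e.g. `stabilityai/stable-diffusion-3-medium`). If the value is not an existing local path, it is treated as a repo ID and `sd3_medium.safetensors` is downloaded from that repo.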
</details>

## <a name="image-generation-with-python-mlx"></a> Image Generation with Python MLX
Expand Down
8 changes: 7 additions & 1 deletion tests/torch2coreml/test_mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class TestSD3MMDiT(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase):

@classmethod
def setUpClass(cls):
global TEST_SD3_CKPT_PATH
cls.model_name = "MultiModalDiffusionTransformer"
cls.test_output_names = ["denoiser_output"]
cls.test_cache_dir = TEST_CACHE_DIR
Expand All @@ -60,7 +61,9 @@ def setUpClass(cls):
.eval()
)
logger.info("Initialized.")
TEST_SD3_CKPT_PATH = hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors")
TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(
TEST_SD3_HF_REPO, "sd3_medium.safetensors"
)
if TEST_SD3_CKPT_PATH is not None:

logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}")
Expand Down Expand Up @@ -133,6 +136,9 @@ def get_test_inputs(cfg: mmdit.MMDiTConfig) -> Dict[str, torch.Tensor]:
parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int)
args = parser.parse_args()

TEST_SD3_CKPT_PATH = (
args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None
)
TEST_SD3_HF_REPO = args.sd3_ckpt_path
TEST_LATENT_SIZE = args.latent_size
TEST_CKPT_FILE_NAME = args.ckpt_file_name
Expand Down
8 changes: 7 additions & 1 deletion tests/torch2coreml/test_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class TestSD3VAEDecoder(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCa

@classmethod
def setUpClass(cls):
global TEST_SD3_CKPT_PATH
cls.model_name = "VAEDecoder"
cls.test_output_names = ["image"]
cls.test_cache_dir = TEST_CACHE_DIR
Expand All @@ -56,7 +57,9 @@ def setUpClass(cls):
)
logger.info("Initialized.")

TEST_SD3_CKPT_PATH = hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors")
TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(
TEST_SD3_HF_REPO, "sd3_medium.safetensors"
)
if TEST_SD3_CKPT_PATH is not None:
logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}")
_load_vae_decoder_weights(cls.test_torch_model, TEST_SD3_CKPT_PATH)
Expand Down Expand Up @@ -103,6 +106,9 @@ def get_test_inputs(config: vae.VAEDecoderConfig) -> Dict[str, torch.Tensor]:
parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int)
args = parser.parse_args()

TEST_SD3_CKPT_PATH = (
args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None
)
TEST_SD3_HF_REPO = args.sd3_ckpt_path
TEST_LATENT_SIZE = args.latent_size

Expand Down