diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py index 827a63c07..00046ca4b 100644 --- a/src/axolotl/common/architectures.py +++ b/src/axolotl/common/architectures.py @@ -10,6 +10,7 @@ "JetMoeMoE", ], "mixtral": "MixtralSparseMoeBlock", + "phimoe": "PhiMoESparseMoeBlock", "qwen2_moe": "Qwen2MoeSparseMoeBlock", "deepseek_v2": "DeepseekV2MoE", } diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 44fc4cb47..bbdeb564d 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -18,6 +18,7 @@ "falcon", "phi", "phi3", + "phimoe", "gemma", "gemma2", "gemmoe", @@ -31,6 +32,8 @@ def patch_for_multipack(model_type, model_name=None, is_remote_code=False): patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe") elif model_type == "deepseek_v2": patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek") + elif model_type == "phimoe": + patch_remote(model_name, ".configuration_phimoe", ".modeling_phimoe") elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code: transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data