-
-
Notifications
You must be signed in to change notification settings - Fork 4.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Integrate fused Mixtral MoE with Marlin kernels #7079
Integrate fused Mixtral MoE with Marlin kernels #7079
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
🚀 |
/unready |
Refactoring for maintainability
@dsikka I've added some |
expert_id: int, | ||
is_gptq: bool = False, | ||
): | ||
if is_gptq: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'd want to use the weight loading functionality already present.
@@ -37,7 +37,7 @@ | |||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), | |||
"MistralForCausalLM": ("llama", "LlamaForCausalLM"), | |||
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), | |||
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), | |||
"QuantMixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we'd want mixtral_quant by default
] | ||
return ([ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we leverage what already exists?
This PR's functionality has been implemented through PRs #8217 and #8032. I'm closing it.
Reimplement quantized Mixtral to combine Marlin kernels with fused MoE.
This PR rewrites the Mixtral model to run a modified Marlin kernel that takes advantage of
fused_moe
functionality.The C++ code takes in all expert data and
topk_ids
tensor. It runs a kernel to computesorted_ids
offsets related to each expert, and then feeds them to the Marlin kernels. The Marlin kernels are run multiple times per each expert, using current expert number to figure out the current position insidesorted_ids
and the number of tokens to process in each particular call. The values ofsorted_ids
are then used to indirectly access the rows of input/outputA
/C
tensors. If the the rows of inputA
are identical for each oftopk
experts that access them (first MMM of fused MoE), tensorA
consists ofM x K
elements, with each row being accessedtopk
times by the relevant experts. Otherwise (second MMM of fused MoE),A
consists ofM x topk x K
elements, with each row being accessed once.Unit testing:
End-to-end testing:
Run
offline_inference.py
withSonnet benchmark results (no act order, 4-bit):
Sonnet benchmark results (with act order, 8-bit):