Fused MOE for Mixtral #2542
Conversation
MMLU evaluation on this PR looks good as well:
Latency numbers:
This PR:
Master:
So very nice improvements on both throughput and latency (except for some medium batch sizes, but maybe that can be further optimized by tuning the block sizes better).
With the latest version of the fused MOE kernel, the fused kernel is now strictly dominating the current master (same settings as above):
@WoosukKwon It probably makes sense to review/merge #2453 first since the fused_moe kernel is from there :)
@pcmoritz I tried importing your code from here and found that there is an absolute maximum difference between the outputs. It seems this is a large difference. Could you add a test between the normal and fused MoE implementations?
Thanks @casper-hansen, let me look into this some more / compare the numerics with HuggingFace. Here is what I've figured out so far.

HuggingFace:

```python
import torch
from transformers import AutoModelForCausalLM

from vllm.model_executor.layers.moe import MoE
from vllm.model_executor.models.mixtral import MixtralModel
from vllm.model_executor.models.mixtral import MixtralForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
config = model.config
mixtral_moe = model.model.layers[0].block_sparse_moe
hidden_states = torch.randn((1, 1, 4096))
output = mixtral_moe.forward(hidden_states)
```

vLLM: First initialize model parallelism (this is needed b/c the model is trying to get the tensor parallelism, which needs this to be initialized -- maybe going forward we can make the models run independent of that, it might be useful e.g. for unit tests):

```python
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

torch.distributed.init_process_group(
    backend="nccl",
    world_size=1,
    rank=0,
    init_method="file:///tmp/test",
)
initialize_model_parallel()

vllm_moe = MoE(
    config.num_local_experts,
    config.num_experts_per_tok,
    config.hidden_size,
    config.intermediate_size,
    params_dtype=torch.bfloat16,
)

# Load weights:
from vllm.model_executor.weight_utils import hf_model_weights_iterator

expert_params_mapping = [
    ("ws" if weight_name in ["w1", "w3"] else "w2s",
     f"experts.{expert_id}.{weight_name}.weight", expert_id)
    for expert_id in range(config.num_local_experts)
    for weight_name in ["w1", "w2", "w3"]
]

params_dict = dict(vllm_moe.named_parameters())
for name, loaded_weight in hf_model_weights_iterator("mistralai/Mixtral-8x7B-v0.1"):
    if name == "model.layers.0.block_sparse_moe.gate.weight":
        params_dict["gate.weight"][:, :] = loaded_weight
    if name.startswith("model.layers.0.block_sparse_moe.experts"):
        for param_name, weight_name, expert_id in expert_params_mapping:
            if weight_name in name:
                param = params_dict[param_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, weight_name, expert_id=expert_id)

vllm_moe.forward(hidden_states.bfloat16().to("cuda"))
```
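For reference, a minimal sketch of how the two outputs could then be compared numerically. It assumes the HF block returns a `(hidden_states, router_logits)` tuple and that the vLLM MoE layer returns a plain tensor of shape `(num_tokens, hidden_size)`; these details may differ across versions, so treat this as a sketch rather than the exact check used here:

```python
# Hedged sketch: compare the HF reference output with the vLLM fused output.
# Assumes `output` is the (hidden_states, router_logits) tuple from the HF block
# and that vllm_moe.forward(...) returns a (num_tokens, hidden_size) tensor.
hf_out = output[0].to("cuda", dtype=torch.bfloat16).view(-1, config.hidden_size)
vllm_out = vllm_moe.forward(hidden_states.bfloat16().to("cuda")).view(-1, config.hidden_size)
print("max abs difference:", (hf_out - vllm_out).abs().max().item())
```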
See the test I created below for reference. I'm not sure what causes the difference, but it seems to be a large one. https://github.com/casper-hansen/AutoAWQ/blob/mixtral_fused/tests/test_fused_moe.py
vllm/model_executor/layers/moe.py
@pcmoritz Should we move the MoE class back to the Mixtral model file? It seems like this MoE layer is not shared between Mixtral and DeepSeek.
Sounds good to me! Feel free to make any edits to the PR you'd like to make or let me know if I should make them :)
I'd appreciate it if you can do it!
@casper-hansen I don't know if you followed the discussion -- we looked into the numerical differences (#2453 (comment)) and they are due to the TensorFloat-32 (TF32) tensor cores being used, so it is expected :)
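For anyone who wants to confirm this locally, here is a minimal sketch of how TF32 can be turned off so the matmuls run in full fp32, in which case the difference should shrink to ordinary rounding noise. These are standard PyTorch switches, not something specific to this PR:

```python
import torch

# TensorFloat-32 tensor cores trade a few mantissa bits for speed on Ampere/Hopper
# GPUs; when they are used, small numerical differences between implementations are
# expected. Turning them off forces full fp32 matmuls for a stricter comparison.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
```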
@pcmoritz LGTM! Thanks for the great work!
Just to be sure, I re-ran MMLU on the latest version of this PR and the result looks good:
Oh BTW, this PR will break the quantization support for Mixtral. 🤦
Co-authored-by: chen shen <[email protected]>
@pcmoritz are you using any specific implementation to run the MMLU benchmark (and others) on LLMs served through vLLM? It would be great if you could share the details.
To back off vllm-project#2542 manually.
This builds on #2453 and #2293 to fuse the MOE kernel for the Mixtral model.
It seems to give a significant performance improvement (in my setup, from 28,600 to 33,600 tok/s with 1000 input tokens and 50 output tokens on H100).
Latency with:

```
python benchmarks/benchmark_latency.py --model=mistralai/Mixtral-8x7B-Instruct-v0.1 --input-len 1000 --output-len 50 -tp 8 --num-iters 100 --batch-size <bs>
```

This PR:

Compare to master:
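For anyone who wants to reproduce these measurements without the benchmark script, a rough sketch of an equivalent offline measurement with the vLLM Python API is below; the exact `LLM`/`SamplingParams`/`generate` arguments are assumed from vLLM around this release and may differ in other versions:

```python
import time

from vllm import LLM, SamplingParams

# Roughly mirrors benchmarks/benchmark_latency.py: dummy prompts of 1000 tokens,
# 50 generated tokens, tensor parallelism across 8 GPUs.
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=8)
sampling_params = SamplingParams(max_tokens=50, ignore_eos=True)

batch_size = 8  # sweep over the batch sizes of interest
dummy_prompt_token_ids = [[0] * 1000 for _ in range(batch_size)]

start = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
             sampling_params=sampling_params,
             use_tqdm=False)
print(f"batch_size={batch_size}: {time.perf_counter() - start:.3f} s")
```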