[ Misc ] Apply MoE Refactor to Qwen2 + Deepseekv2 To Support Fp8 #6417

Status: Merged (23 commits, Jul 14, 2024)

Commits
9824dca
add qwen moe fp8
robertgshaw2-redhat Jul 3, 2024
ea21bac
Update fp8.py
robertgshaw2-redhat Jul 3, 2024
79f59fe
Update fp8.py
robertgshaw2-redhat Jul 3, 2024
65eca7d
Merge branch 'main' into qwen-fp8
robertgshaw2-redhat Jul 13, 2024
c3bee0d
added fp8 to qwen
robertgshaw2-redhat Jul 13, 2024
0ef1255
added test coverage for fp8 moes
robertgshaw2-redhat Jul 13, 2024
d5444cc
updated qwen
robertgshaw2-redhat Jul 13, 2024
80d3ecd
Update vllm/model_executor/layers/fused_moe/layer.py
robertgshaw2-redhat Jul 13, 2024
2ca7385
Update vllm/model_executor/layers/fused_moe/layer.py
robertgshaw2-redhat Jul 13, 2024
ce10c8c
stash
robertgshaw2-redhat Jul 13, 2024
f7c6d24
formatted
robertgshaw2-redhat Jul 13, 2024
a6dd8c3
its working!
robertgshaw2-redhat Jul 13, 2024
480b8a1
added
robertgshaw2-redhat Jul 13, 2024
d9e4477
formatting
robertgshaw2-redhat Jul 13, 2024
c45ac7c
factor out expert_params_mapping
robertgshaw2-redhat Jul 13, 2024
0d55344
Delete .buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.…
robertgshaw2-redhat Jul 13, 2024
a954c5a
Delete .buildkite/lm-eval-harness/configs/models-large-fp8.txt
robertgshaw2-redhat Jul 13, 2024
8817127
Update test-pipeline.yaml
robertgshaw2-redhat Jul 13, 2024
1cad213
fixes
robertgshaw2-redhat Jul 13, 2024
2b1e2c0
Merge branch 'deepseek-fp8' of https://github.com/neuralmagic/nm-vllm…
robertgshaw2-redhat Jul 13, 2024
6c85445
added routing scaling factor
robertgshaw2-redhat Jul 13, 2024
da4bf83
format
robertgshaw2-redhat Jul 13, 2024
2ff2b35
format
robertgshaw2-redhat Jul 13, 2024
11 changes: 11 additions & 0 deletions .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2
model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.671
- name: "exact_match,flexible-extract"
value: 0.664
limit: 1000
num_fewshot: 5
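The recorded metric values act as the ground truth that the CI accuracy check compares against. A minimal sketch of how such a config could be consumed, assuming a PyYAML-based checker and a relative-tolerance constant (both hypothetical, not the exact test code):

import yaml

RTOL = 0.05  # assumed relative tolerance around the recorded value

def check_lm_eval_results(config_path: str, measured: dict) -> None:
    # `measured` maps metric names (e.g. "exact_match,strict-match")
    # to the values reported by lm_eval for the configured task.
    with open(config_path) as f:
        config = yaml.safe_load(f)
    for task in config["tasks"]:
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[metric["name"]]
            assert abs(got - expected) <= RTOL * expected, (
                f"{task['name']}/{metric['name']}: measured {got}, "
                f"expected {expected} +/- {RTOL:.0%}")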
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-large.txt
@@ -1,3 +1,4 @@
Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml
DeepSeek-V2-Lite-Chat.yaml
2 changes: 1 addition & 1 deletion .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done

lm_eval --model vllm \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray" \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true \
--tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
--batch_size $BATCH_SIZE
36 changes: 26 additions & 10 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -394,14 +394,16 @@ def fused_topk(


# This is used by the Deepseek-V2 model
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
):
def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0):

assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")

scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group,
@@ -557,6 +559,9 @@ def fused_moe(
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
@@ -579,6 +584,10 @@
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -592,8 +601,15 @@
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
topk, renormalize,
num_expert_group, topk_group)
else:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)

return fused_experts(hidden_states,
w1,
w2,
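The grouped routing added above can be summarized as follows. This is a condensed, self-contained sketch of the selection logic; the lines beyond the visible hunk are reconstructed for illustration and may differ in detail from the final implementation.

import torch

def grouped_topk_sketch(hidden_states: torch.Tensor,
                        gating_output: torch.Tensor,
                        topk: int,
                        renormalize: bool,
                        num_expert_group: int,
                        topk_group: int):
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    scores = torch.softmax(gating_output, dim=-1)      # [num_tokens, num_experts]
    num_token = scores.shape[0]
    # Score each expert group by its best expert.
    group_scores = scores.view(num_token, num_expert_group,
                               -1).max(dim=-1).values   # [num_tokens, num_groups]
    # Keep only the top-k groups and mask out experts in all other groups.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.shape[-1] // num_expert_group).reshape(num_token, -1)
    masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)
    # Final expert selection happens only inside the surviving groups.
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids

fused_moe then dispatches to this grouped path when use_grouped_topk=True and both num_expert_group and topk_group are provided; otherwise it falls back to fused_topk as before.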
93 changes: 78 additions & 15 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Optional
from typing import List, Optional, Tuple

import torch

@@ -29,7 +29,10 @@ def apply(self,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
raise NotImplementedError


@@ -63,15 +66,21 @@ def apply(self,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:

return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True)
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group)


class FusedMoE(torch.nn.Module):
@@ -104,6 +113,9 @@ def __init__(
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
):
@@ -119,6 +131,11 @@
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
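For context, a DeepSeek-V2-style model would now construct this layer with the grouped-top-k arguments enabled. A hedged sketch (the helper function and the config attribute names are illustrative, not vLLM's exact model code):

from typing import Optional

from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

def build_grouped_topk_moe(config,
                           quant_config: Optional[QuantizationConfig] = None) -> FusedMoE:
    return FusedMoE(
        num_experts=config.n_routed_experts,      # routed experts per layer
        top_k=config.num_experts_per_tok,         # experts selected per token
        hidden_size=config.hidden_size,
        intermediate_size=config.moe_intermediate_size,
        renormalize=config.norm_topk_prob,
        use_grouped_topk=True,                    # DeepSeek-V2 routes within expert groups
        num_expert_group=config.n_group,
        topk_group=config.topk_group,
        quant_config=quant_config,
        reduce_results=False)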
@@ -140,24 +157,30 @@ def weight_loader(self, param: torch.nn.Parameter,
shard_id: int, expert_id: int):
param_data = param.data

# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
# Follow up PR to enable fp8 for other MoE models.
if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
# Follow up PR to enable fp8 for other MoE models.
# Weight scales
elif "weight_scale" in weight_name:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
assert "w1" in weight_name or "w3" in weight_name
shard_id = 0 if "w1" in weight_name else 1
param_data[expert_id][shard_id] = loaded_weight
# If we are in merged column case (gate_up_proj)
# shard_id 0 == gate_proj / w1
# shard_id 2 == up_proj / w3
if shard_id == 0 or shard_id == 2:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == 0 else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
# shard_id 1 == down_proj / w2
else:
param_data[expert_id] = loaded_weight
# Weights
else:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.intermediate_size_per_partition
@@ -188,10 +211,50 @@ def forward(self, hidden_states: torch.Tensor,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize)
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group)

if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

return final_hidden_states

@classmethod
def make_expert_params_mapping(
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, int]]:

gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
gate_down_up = [
ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
]

return [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_scale"
if weight_name in gate_up else "experts.w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
shard_id) for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_weight"
if weight_name in gate_up else "experts.w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
] + [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.a13_scale"
if weight_name in gate_up else "experts.a2_scale",
f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
shard_id) for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
]
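A hedged sketch of how a MoE model's load_weights loop might consume this mapping (the helper name, checkpoint module names, and the params_dict/weights arguments are illustrative assumptions):

from typing import Iterable, Tuple

import torch

from vllm.model_executor.layers.fused_moe import FusedMoE

def load_expert_weights(params_dict: dict,
                        weights: Iterable[Tuple[str, torch.Tensor]],
                        num_experts: int) -> None:
    # One (param_name, ckpt_weight_name, expert_id, shard_id) tuple per expert
    # weight and scale in the checkpoint.
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=num_experts)

    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            # e.g. "...experts.3.gate_proj.weight" -> "...experts.w13_weight"
            param = params_dict[name.replace(weight_name, param_name)]
            param.weight_loader(param, loaded_weight, weight_name,
                                shard_id=shard_id, expert_id=expert_id)
            break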
10 changes: 8 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -377,7 +377,10 @@ def apply(self,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:

return fused_moe(x,
layer.w13_weight,
@@ -390,7 +393,10 @@ def apply(self,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
a1_scale=layer.a13_scale,
a2_scale=layer.a2_scale)
a2_scale=layer.a2_scale,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group)


class Fp8KVCacheMethod(QuantizeMethodBase):
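With Qwen2-MoE and DeepSeek-V2 now routed through the shared FusedMoE layer and its FP8 method, an FP8 MoE checkpoint can be served end to end. A minimal usage sketch (the checkpoint name and parallelism settings are illustrative, not taken from this PR):

from vllm import LLM, SamplingParams

# Illustrative checkpoint name; substitute any FP8-quantized Qwen2-MoE or
# DeepSeek-V2 checkpoint that vLLM can load.
llm = LLM(model="Qwen2-57B-A14B-Instruct-FP8",
          tensor_parallel_size=4,
          trust_remote_code=True)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)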