[ Misc ] Apply MoE Refactor to Deepseekv2 To Support Fp8 (vllm-projec…
robertgshaw2-neuralmagic authored and jimpang committed Jul 24, 2024
1 parent 2fa0a11 commit 381e2cb
Showing 9 changed files with 223 additions and 137 deletions.
11 changes: 11 additions & 0 deletions .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
@@ -0,0 +1,11 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2
model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.671
- name: "exact_match,flexible-extract"
value: 0.664
limit: 1000
num_fewshot: 5
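This config records the GSM8K accuracy the CI expects from DeepSeek-V2-Lite-Chat. A minimal sketch of how such a YAML might be checked against a fresh lm-eval run (the tolerance, function name, and the shape of the measured results are assumptions, not part of this commit):

```python
# Hypothetical sketch of a CI correctness check driven by this YAML.
# The RTOL value, function name, and the `measured` dict layout are assumptions.
import yaml

RTOL = 0.02  # assumed relative tolerance on each metric


def check_lm_eval_config(config_path: str, measured: dict) -> None:
    """Compare measured lm-eval metrics against the recorded ground truth."""
    with open(config_path) as f:
        config = yaml.safe_load(f)
    for task in config["tasks"]:
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[task["name"]][metric["name"]]
            assert abs(got - expected) <= RTOL * expected, (
                f"{task['name']}/{metric['name']}: "
                f"measured {got}, expected {expected}")


# Example:
# check_lm_eval_config("DeepSeek-V2-Lite-Chat.yaml",
#     {"gsm8k": {"exact_match,strict-match": 0.67,
#                "exact_match,flexible-extract": 0.66}})
```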
1 change: 1 addition & 0 deletions .buildkite/lm-eval-harness/configs/models-large.txt
@@ -1,3 +1,4 @@
Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml
DeepSeek-V2-Lite-Chat.yaml
.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
Expand Up @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done

lm_eval --model vllm \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray" \
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true \
--tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
--batch_size $BATCH_SIZE
36 changes: 26 additions & 10 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -394,14 +396,16 @@ def fused_topk(


# This is used by the Deepseek-V2 model
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
):
def grouped_topk(hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0):

assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")

scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group,
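The hunk is truncated here. For context, the rest of grouped_topk scores each expert group, keeps the top topk_group groups per token, masks out experts from the discarded groups, and only then takes the per-token top-k; a hedged sketch of that selection (an approximation, not the literal file contents):

```python
# Approximate body of grouped_topk for DeepSeek-V2 routing (sketch only;
# see vllm/model_executor/layers/fused_moe/fused_moe.py for the exact code).
import torch


def grouped_topk_sketch(hidden_states: torch.Tensor,
                        gating_output: torch.Tensor,
                        topk: int,
                        renormalize: bool,
                        num_expert_group: int,
                        topk_group: int):
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    scores = torch.softmax(gating_output, dim=-1)
    num_token = scores.shape[0]
    # Score each expert group by its strongest expert.
    group_scores = scores.view(num_token, num_expert_group,
                               -1).max(dim=-1).values
    # Keep only the best topk_group groups for each token.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                           sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    # Expand the group mask back to per-expert granularity and zero out
    # experts that belong to discarded groups.
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.shape[-1] // num_expert_group).reshape(num_token, -1)
    masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)
    # Standard per-token top-k over the surviving experts.
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1,
                                        sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```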
@@ -557,6 +559,9 @@ def fused_moe(
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
@@ -579,6 +584,10 @@
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -592,8 +601,15 @@
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
topk, renormalize,
num_expert_group, topk_group)
else:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)

return fused_experts(hidden_states,
w1,
w2,
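To show how the new routing arguments are meant to be threaded through, here is an illustrative call into fused_moe with grouped top-k enabled (shapes, expert counts, and group sizes are invented; a CUDA device is assumed for the underlying Triton kernels):

```python
# Illustrative call into fused_moe with DeepSeek-V2-style grouped routing.
# All shapes, the expert count, and the group sizes are invented for this
# example; the fp8 arguments (use_fp8, *_scale) are left at their defaults.
import torch

from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

num_tokens, hidden_size = 4, 2048
num_experts, intermediate_size = 64, 1408

hidden_states = torch.randn(num_tokens, hidden_size,
                            dtype=torch.float16, device="cuda")
w1 = torch.randn(num_experts, 2 * intermediate_size, hidden_size,
                 dtype=torch.float16, device="cuda")  # fused gate/up (w13)
w2 = torch.randn(num_experts, hidden_size, intermediate_size,
                 dtype=torch.float16, device="cuda")  # down projection
gating_output = torch.randn(num_tokens, num_experts,
                            dtype=torch.float16, device="cuda")

out = fused_moe(hidden_states, w1, w2, gating_output,
                topk=6,
                renormalize=True,
                inplace=True,
                use_grouped_topk=True,  # route via grouped_topk, not fused_topk
                num_expert_group=8,
                topk_group=3)
```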
93 changes: 78 additions & 15 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Optional
from typing import List, Optional, Tuple

import torch

@@ -29,7 +29,10 @@ def apply(self,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
raise NotImplementedError


@@ -63,15 +66,21 @@ def apply(self,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:

return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True)
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group)


class FusedMoE(torch.nn.Module):
@@ -104,6 +113,9 @@ def __init__(
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
):
@@ -119,6 +131,11 @@
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -140,24 +157,30 @@ def weight_loader(self, param: torch.nn.Parameter,
shard_id: int, expert_id: int):
param_data = param.data

# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
# Follow up PR to enable fp8 for other MoE models.
if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
# Follow up PR to enable fp8 for other MoE models.
# Weight scales
elif "weight_scale" in weight_name:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
assert "w1" in weight_name or "w3" in weight_name
shard_id = 0 if "w1" in weight_name else 1
param_data[expert_id][shard_id] = loaded_weight
# If we are in merged column case (gate_up_proj)
# shard_id 0 == gate_proj / w1
# shard_id 2 == up_proj / w3
if shard_id == 0 or shard_id == 2:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == 0 else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
# shard_id 1 == down_proj / w2
else:
param_data[expert_id] = loaded_weight
# Weights
else:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.intermediate_size_per_partition
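The hunk is truncated before the tensor-parallel weight branch. Roughly, each expert's gate_proj (shard_id 0) and up_proj (shard_id 2) are sliced along the intermediate dimension and written into the two halves of w13_weight, while down_proj (shard_id 1) is sliced along its input dimension into w2_weight; a hedged standalone sketch:

```python
# Standalone approximation of the tensor-parallel weight branch that the hunk
# above truncates; parameter names mirror weight_loader's locals, but this is
# not the literal continuation of the diff.
import torch


def load_expert_weight(param_data: torch.Tensor, loaded_weight: torch.Tensor,
                       expert_id: int, shard_id: int, tp_rank: int,
                       shard_size: int) -> None:
    shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    if shard_id == 0:    # gate_proj / w1 -> first half of w13_weight
        param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
    elif shard_id == 2:  # up_proj / w3 -> second half of w13_weight
        param_data[expert_id,
                   shard_size:2 * shard_size, :] = loaded_weight[shard, :]
    elif shard_id == 1:  # down_proj / w2 -> sliced along the input dimension
        param_data[expert_id, :, :] = loaded_weight[:, shard]
    else:
        raise ValueError(f"Unexpected shard_id: {shard_id}")
```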
@@ -188,10 +211,50 @@ def forward(self, hidden_states: torch.Tensor,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize)
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group)

if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

return final_hidden_states

@classmethod
def make_expert_params_mapping(
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, int]]:

gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
gate_down_up = [
ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
]

return [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_scale"
if weight_name in gate_up else "experts.w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
shard_id) for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_weight"
if weight_name in gate_up else "experts.w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
] + [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.a13_scale"
if weight_name in gate_up else "experts.a2_scale",
f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
shard_id) for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
]
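The model-side wiring (e.g., in deepseek_v2.py, not rendered in this view) is expected to construct the layer with grouped top-k enabled and to use the new classmethod when loading expert weights. A hedged sketch, with illustrative sizes; constructor argument names outside the hunks above are assumed from the surrounding signature:

```python
# Hedged sketch of model-side usage. Sizes are illustrative, and
# num_experts / top_k / hidden_size / intermediate_size are assumed names.
# (Assumes vLLM's tensor-parallel state is already initialized.)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE

experts = FusedMoE(
    num_experts=64,
    top_k=6,
    hidden_size=2048,
    intermediate_size=1408,
    reduce_results=False,
    renormalize=False,       # e.g. driven by the model config's norm_topk_prob
    use_grouped_topk=True,   # DeepSeek-V2 routes with grouped_topk
    num_expert_group=8,
    topk_group=3,
    quant_config=None,       # pass an Fp8Config here to exercise the fp8 path
)

# At load time, checkpoint expert tensors are matched to the fused parameters:
expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=64)
# Each entry is (param_name, weight_name, expert_id, shard_id); the model's
# load_weights loop uses it to call weight_loader with the right expert/shard.
```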
10 changes: 8 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -377,7 +377,10 @@ def apply(self,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True) -> torch.Tensor:
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:

return fused_moe(x,
layer.w13_weight,
@@ -390,7 +393,10 @@ def apply(self,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
a1_scale=layer.a13_scale,
a2_scale=layer.a2_scale)
a2_scale=layer.a2_scale,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group)


class Fp8KVCacheMethod(QuantizeMethodBase):