diff --git a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
new file mode 100644
index 0000000000000..15268395ec68b
--- /dev/null
+++ b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml
@@ -0,0 +1,11 @@
+# bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2
+model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.671
+  - name: "exact_match,flexible-extract"
+    value: 0.664
+limit: 1000
+num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt
index 2007dd2e1cfa1..94b15a87235b9 100644
--- a/.buildkite/lm-eval-harness/configs/models-large.txt
+++ b/.buildkite/lm-eval-harness/configs/models-large.txt
@@ -1,3 +1,4 @@
 Meta-Llama-3-70B-Instruct.yaml
 Mixtral-8x7B-Instruct-v0.1.yaml
 Qwen2-57B-A14-Instruct.yaml
+DeepSeek-V2-Lite-Chat.yaml
diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
index 1bddbd89e4ab1..dbb21be4f86e4 100644
--- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
+++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done
 
 lm_eval --model vllm \
-  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray" \
+  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true \
  --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
  --batch_size $BATCH_SIZE
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 3c62008fbfcc1..413c0b6d0924e 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -394,14 +394,16 @@ def fused_topk(
 
 
 # This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
+def grouped_topk(hidden_states: torch.Tensor,
+                 gating_output: torch.Tensor,
+                 topk: int,
+                 renormalize: bool,
+                 num_expert_group: int = 0,
+                 topk_group: int = 0):
+
+    assert hidden_states.shape[0] == gating_output.shape[0], (
+        "Number of tokens mismatch")
+
     scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
     group_scores = scores.view(num_token, num_expert_group,
@@ -557,6 +559,9 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     override_config: Optional[Dict[str, Any]] = None,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     use_fp8: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -579,6 +584,10 @@
         Defaults to False.
     - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
+    - use_grouped_topk: If True, use grouped_topk instead of fused_topk for
+        routing. Note: the DeepSeek-V2 model uses grouped_topk.
+    - num_expert_group (Optional[int]): number of expert groups (grouped_topk).
+    - topk_group (Optional[int]): number of groups to select (grouped_topk).
     - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -592,8 +601,15 @@
 
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                        renormalize)
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
+                                              topk, renormalize,
+                                              num_expert_group, topk_group)
+    else:
+        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
+                                            renormalize)
+
     return fused_experts(hidden_states,
                          w1,
                          w2,
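Aside, for reviewers (not part of the patch): the sketch below mirrors in plain PyTorch what grouped routing does for DeepSeek-V2. Experts are split into `num_expert_group` groups, only the best `topk_group` groups survive per token, and the final top-k experts are chosen from the surviving groups. The function and variable names here are illustrative, not vLLM API.

```python
import torch


def grouped_topk_reference(gating_output: torch.Tensor, topk: int,
                           renormalize: bool, num_expert_group: int,
                           topk_group: int):
    # Router probabilities per token: [num_tokens, num_experts].
    scores = torch.softmax(gating_output, dim=-1)
    num_token, num_expert = scores.shape
    # Score each group by its best expert: [num_tokens, num_expert_group].
    group_scores = scores.view(num_token, num_expert_group,
                               -1).max(dim=-1).values
    # Keep only the top `topk_group` groups per token.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
                           sorted=False).indices
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    # Expand the group mask back to per-expert granularity and zero out
    # experts that live in discarded groups.
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        num_expert // num_expert_group).reshape(num_token, -1)
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
    # Standard top-k over the remaining experts.
    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1,
                                        sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


# Tiny smoke test: 4 tokens, 8 experts in 4 groups, keep 2 groups, pick 2 experts.
weights, ids = grouped_topk_reference(torch.randn(4, 8), topk=2,
                                      renormalize=True, num_expert_group=4,
                                      topk_group=2)
```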
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 73cfcd7fc85f2..3904f3e3d0e76 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import torch
 
@@ -29,7 +29,10 @@ def apply(self,
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
         raise NotImplementedError
 
 
@@ -63,7 +66,10 @@ def apply(self,
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
 
         return fused_moe(x,
                          layer.w13_weight,
@@ -71,7 +77,10 @@ def apply(self,
                          router_logits,
                          top_k,
                          renormalize=renormalize,
-                         inplace=True)
+                         inplace=True,
+                         use_grouped_topk=use_grouped_topk,
+                         num_expert_group=num_expert_group,
+                         topk_group=topk_group)
 
 
 class FusedMoE(torch.nn.Module):
@@ -104,6 +113,9 @@ def __init__(
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
     ):
@@ -119,6 +131,11 @@ def __init__(
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
         self.renormalize = renormalize
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -140,9 +157,8 @@ def weight_loader(self, param: torch.nn.Parameter,
                       shard_id: int, expert_id: int):
         param_data = param.data
 
-        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
-        if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
+        # Input scales can be loaded directly and should be equal.
+        if "input_scale" in weight_name:
             if param_data[expert_id] != 1 and (param_data[expert_id] -
                                                loaded_weight).abs() > 1e-5:
                 raise ValueError(
@@ -150,14 +166,21 @@ def weight_loader(self, param: torch.nn.Parameter,
                     f"must be equal. But got {param_data[expert_id]} "
                     f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
-        # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral.
-        # Follow up PR to enable fp8 for other MoE models.
+        # Weight scales
         elif "weight_scale" in weight_name:
-            # We have to keep the weight scales of w1 and w3 because
-            # we need to re-quantize w1/w3 weights after weight loading.
-            assert "w1" in weight_name or "w3" in weight_name
-            shard_id = 0 if "w1" in weight_name else 1
-            param_data[expert_id][shard_id] = loaded_weight
+            # If we are in the merged column case (gate_up_proj):
+            #   shard_id 0 == gate_proj / w1
+            #   shard_id 2 == up_proj / w3
+            if shard_id == 0 or shard_id == 2:
+                # We have to keep the weight scales of w1 and w3 because
+                # we need to re-quantize w1/w3 weights after weight loading.
+                idx = 0 if shard_id == 0 else 1
+                param_data[expert_id][idx] = loaded_weight
+            # If we are in the row parallel case (down_proj):
+            #   shard_id 1 == down_proj / w2
+            else:
+                param_data[expert_id] = loaded_weight
+        # Weights
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
@@ -188,10 +211,50 @@ def forward(self, hidden_states: torch.Tensor,
             x=hidden_states,
             router_logits=router_logits,
             top_k=self.top_k,
-            renormalize=self.renormalize)
+            renormalize=self.renormalize,
+            use_grouped_topk=self.use_grouped_topk,
+            num_expert_group=self.num_expert_group,
+            topk_group=self.topk_group)
 
         if self.reduce_results and self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
         return final_hidden_states
+
+    @classmethod
+    def make_expert_params_mapping(
+            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
+            ckpt_up_proj_name: str,
+            num_experts: int) -> List[Tuple[str, str, int, int]]:
+
+        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
+        gate_down_up = [
+            ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
+        ]
+
+        return [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_scale"
+             if weight_name in gate_up else "experts.w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the weights for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.w13_weight"
+             if weight_name in gate_up else "experts.w2_weight",
+             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ] + [
+            # These are the activation scales for the experts
+            # (param_name, weight_name, expert_id, shard_id)
+            ("experts.a13_scale"
+             if weight_name in gate_up else "experts.a2_scale",
+             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
+             shard_id) for expert_id in range(num_experts)
+            for shard_id, weight_name in enumerate(gate_down_up)
+        ]
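Aside, for reviewers (not part of the patch): the new `make_expert_params_mapping` classmethod replaces the three hand-written list comprehensions that previously lived in mixtral.py (removed further down). A rough illustration of what it returns, using Mixtral-style checkpoint names and a made-up two-expert model; the tuple values shown in comments follow directly from the generator expressions above.

```python
from vllm.model_executor.layers.fused_moe import FusedMoE

# (param_name, weight_name, expert_id, shard_id) tuples for 2 experts.
mapping = FusedMoE.make_expert_params_mapping(ckpt_gate_proj_name="w1",
                                              ckpt_down_proj_name="w2",
                                              ckpt_up_proj_name="w3",
                                              num_experts=2)
# The weight entries for expert 0 look like:
#   ("experts.w13_weight", "experts.0.w1.weight", 0, 0)   # gate_proj -> w13
#   ("experts.w2_weight",  "experts.0.w2.weight", 0, 1)   # down_proj -> w2
#   ("experts.w13_weight", "experts.0.w3.weight", 0, 2)   # up_proj   -> w13
# plus the matching "weight_scale" and "input_scale" entries used for fp8.
```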
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 0c2d2bd3fabe5..5c916c9b4d7e4 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -377,7 +377,10 @@ def apply(self,
               x: torch.Tensor,
               router_logits: torch.Tensor,
               top_k: int,
-              renormalize: bool = True) -> torch.Tensor:
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
 
         return fused_moe(x,
                          layer.w13_weight,
@@ -390,7 +393,10 @@ def apply(self,
                          w1_scale=layer.w13_scale,
                          w2_scale=layer.w2_scale,
                          a1_scale=layer.a13_scale,
-                         a2_scale=layer.a2_scale)
+                         a2_scale=layer.a2_scale,
+                         use_grouped_topk=use_grouped_topk,
+                         num_expert_group=num_expert_group,
+                         topk_group=topk_group)
 
 
 class Fp8KVCacheMethod(QuantizeMethodBase):
" + "Only silu is supported for now.") + + self.experts = FusedMoE(num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group) self.gate = ReplicatedLinear(config.hidden_size, - self.n_routed_experts, + config.n_routed_experts, bias=False, quant_config=None) - if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) @@ -128,50 +129,21 @@ def __init__( reduce_results=False, ) - def pack_params(self): - w1 = [] - w2 = [] - for expert in self.experts: - w1.append(expert.gate_up_proj.weight) - w2.append(expert.down_proj.weight) - self.w1 = torch._utils._flatten_dense_tensors(w1) - w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) - for data, param in zip(w1s, w1): - param.data = data - self.w1 = self.w1.view(len(w1), *w1s[0].shape) - - self.w2 = torch._utils._flatten_dense_tensors(w2) - w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) - for data, param in zip(w2s, w2): - param.data = data - - self.w2 = self.w2.view(len(w2), *w2s[0].shape) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.config.n_shared_experts is not None: + if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - topk_weights, topk_ids = grouped_topk( - hidden_states, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - num_expert_group=self.config.n_group, - topk_group=self.config.topk_group) - final_hidden_states = fused_experts( - hidden_states, - self.w1, - self.w2, - topk_weights, - topk_ids, - inplace=True) * self.routed_scaling_factor - if self.config.n_shared_experts is not None: + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if shared_output is not None: final_hidden_states = final_hidden_states + shared_output - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) @@ -504,34 +476,58 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue + # We have mlp.experts[0].gate_proj in the checkpoint. 
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index e5bd58a9e97b0..0c456ada61230 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -372,31 +372,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        expert_params_mapping = [
-            # These are the weight scales for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_scale"
-             if weight_name in ["w1", "w3"] else "experts.w2_scale",
-             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
-             shard_id) for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ] + [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id)
-            ("experts.w13_weight"
-             if weight_name in ["w1", "w3"] else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ] + [
-            # These are the activation scales for the experts
-            # (param_name, weight_name, expert_id)
-            ("experts.a13_scale"
-             if weight_name in ["w1", "w3"] else "experts.a2_scale",
-             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
-             shard_id) for expert_id in range(self.config.num_local_experts)
-            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts)
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 7b18b5e04b275..2cc2f1440d147 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -50,6 +50,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
+from vllm.utils import print_warning_once
 
 
 class Qwen2MoeMLP(nn.Module):
@@ -406,15 +407,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = [
-            # These are the weights for the experts
-            # (param_name, weight_name, expert_id, shard_id)
-            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
-             else "experts.w2_weight",
-             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
-            for expert_id in range(self.config.num_experts) for shard_id,
-            weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
@@ -461,8 +460,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name not in params_dict:
-                    continue
+                # Remapping the name of FP8 kv-scale.
+                if name.endswith("kv_scale"):
+                    remapped_kv_scale_name = name.replace(
+                        ".kv_scale", ".attn.kv_scale")
+                    if remapped_kv_scale_name not in params_dict:
+                        print_warning_once(
+                            "Found kv scale in the checkpoint "
+                            f"(e.g. {name}), but not found the expected "
+                            f"name in the model "
+                            f"(e.g. {remapped_kv_scale_name}). "
+                            "kv-scale is not loaded.")
+                        continue
+                    else:
+                        name = remapped_kv_scale_name
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
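Aside, for reviewers (not part of the patch): the new kv-scale branch above only renames checkpoint keys before loading. A minimal sketch with a hypothetical key:

```python
# Hypothetical FP8 checkpoint key for a kv-cache scale.
name = "model.layers.0.self_attn.kv_scale"
remapped = name.replace(".kv_scale", ".attn.kv_scale")
print(remapped)  # model.layers.0.self_attn.attn.kv_scale

# If the remapped key is missing from params_dict, the loader warns once via
# print_warning_once and drops the scale; otherwise it loads the tensor under
# the remapped name.
```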