From 9824dcae954ab32c3a68c372f4d9ae235af0df22 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 3 Jul 2024 02:52:22 +0000 Subject: [PATCH 01/21] add qwen moe fp8 --- vllm/model_executor/layers/fused_moe/layer.py | 26 ++++++++----- .../model_executor/layers/quantization/fp8.py | 3 +- vllm/model_executor/models/qwen2_moe.py | 37 +++++++++++++++++-- 3 files changed, 50 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 73cfcd7fc85f2..dde5ba29bbbc2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -140,9 +140,8 @@ def weight_loader(self, param: torch.nn.Parameter, shard_id: int, expert_id: int): param_data = param.data - # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. - # Follow up PR to enable fp8 for other MoE models. - if "input_scale" in weight_name or "w2.weight_scale" in weight_name: + # Input scales can be loaded directly and should be equal. + if "input_scale" in weight_name: if param_data[expert_id] != 1 and (param_data[expert_id] - loaded_weight).abs() > 1e-5: raise ValueError( @@ -150,14 +149,21 @@ def weight_loader(self, param: torch.nn.Parameter, f"must be equal. But got {param_data[expert_id]} " f"vs. {loaded_weight}") param_data[expert_id] = loaded_weight - # FIXME(robertgshaw2-neuralmagic): Overfit to Mixtral. - # Follow up PR to enable fp8 for other MoE models. + # Weight scales elif "weight_scale" in weight_name: - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. - assert "w1" in weight_name or "w3" in weight_name - shard_id = 0 if "w1" in weight_name else 1 - param_data[expert_id][shard_id] = loaded_weight + # If we are in merged column case (gate_up_proj) + # * shard_id 0 == gate_proj / w1 + # * shard_id 2 == up_proj / w3 + if shard_id == 0 or shard_id == 2: + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == 0 else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + # * shard_id 1 == down_proj / w2 + else: + param_data[expert_id] = loaded_weight + # Weights else: tp_rank = get_tensor_model_parallel_rank() shard_size = self.intermediate_size_per_partition diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index dc2ca35c6d2c0..81eda669e86af 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -240,7 +240,6 @@ def apply(self, # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_scale computed from x. # If static, layer.input_scale is scalar and x_scale is input_scale. - if bias is None and self.cutlass_fp8_supported: qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale) @@ -407,7 +406,7 @@ def process_weights_after_loading(self, layer: Module) -> None: print_warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer. 
") + "for each layer.") layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), requires_grad=False) layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index ccaa6f20893e0..23ed4c1d2c16c 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -50,6 +50,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.utils import print_warning_once class Qwen2MoeMLP(nn.Module): @@ -405,13 +406,29 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] expert_params_mapping = [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_scale" + if weight_name in ["gate_proj", "up_proj"] else "experts.w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, + shard_id) for expert_id in range(self.config.num_experts) + for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) + ] + [ # These are the weights for the experts # (param_name, weight_name, expert_id, shard_id) ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"] else "experts.w2_weight", f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(self.config.num_experts) for shard_id, - weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) + for expert_id in range(self.config.num_experts) + for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) + ] + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ("experts.a13_scale" + if weight_name in ["gate_proj", "up_proj"] else "experts.a2_scale", + f"experts.{expert_id}.{weight_name}.input_scale", expert_id, + shard_id) for expert_id in range(self.config.num_experts) + for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) ] params_dict = dict(self.named_parameters()) @@ -459,8 +476,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name not in params_dict: - continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). 
" + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr(param, "weight_loader", From ea21bacaae79d7570d6b9cffabc4f4b4fc14b4c4 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 2 Jul 2024 22:53:53 -0400 Subject: [PATCH 02/21] Update fp8.py --- vllm/model_executor/layers/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 81eda669e86af..37ccdabcbc46b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -406,7 +406,7 @@ def process_weights_after_loading(self, layer: Module) -> None: print_warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " - "for each layer.") + "for each layer. ") layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), requires_grad=False) layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), From 79f59fec5cb8269431b62df4c38d1c66330ea9fb Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Tue, 2 Jul 2024 22:54:29 -0400 Subject: [PATCH 03/21] Update fp8.py --- vllm/model_executor/layers/quantization/fp8.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 37ccdabcbc46b..dc2ca35c6d2c0 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -240,6 +240,7 @@ def apply(self, # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_scale computed from x. # If static, layer.input_scale is scalar and x_scale is input_scale. + if bias is None and self.cutlass_fp8_supported: qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale) From c3bee0d6a8b6c5c95178af89c54b91874656788e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 11:41:44 +0000 Subject: [PATCH 04/21] added fp8 to qwen --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/models/qwen2_moe.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index dde5ba29bbbc2..a600bd2847084 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -149,7 +149,7 @@ def weight_loader(self, param: torch.nn.Parameter, f"must be equal. But got {param_data[expert_id]} " f"vs. 
{loaded_weight}") param_data[expert_id] = loaded_weight - # Weight scales + # Weight scales elif "weight_scale" in weight_name: # If we are in merged column case (gate_up_proj) # * shard_id 0 == gate_proj / w1 diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 177e33ba6362c..356bd0a0901b2 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -410,27 +410,29 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [ # These are the weight scales for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_scale" - if weight_name in ["gate_proj", "up_proj"] else "experts.w2_scale", + ("experts.w13_scale" if weight_name in ["gate_proj", "up_proj" + ] else "experts.w2_scale", f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, shard_id) for expert_id in range(self.config.num_experts) - for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) + for shard_id, weight_name in enumerate( + ["gate_proj", "down_proj", "up_proj"]) ] + [ # These are the weights for the experts # (param_name, weight_name, expert_id, shard_id) ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"] else "experts.w2_weight", f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(self.config.num_experts) - for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) + for expert_id in range(self.config.num_experts) for shard_id, + weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) ] + [ # These are the weight scales for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.a13_scale" - if weight_name in ["gate_proj", "up_proj"] else "experts.a2_scale", + ("experts.a13_scale" if weight_name in ["gate_proj", "up_proj" + ] else "experts.a2_scale", f"experts.{expert_id}.{weight_name}.input_scale", expert_id, shard_id) for expert_id in range(self.config.num_experts) - for shard_id, weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) + for shard_id, weight_name in enumerate( + ["gate_proj", "down_proj", "up_proj"]) ] params_dict = dict(self.named_parameters()) From 0ef1255b3d52bb65a22dc5baeeb97eaf2e10d9e0 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 11:52:13 +0000 Subject: [PATCH 05/21] added test coverage for fp8 moes --- .../configs/Qwen2-57B-A14-Instruct-FP8.yaml | 11 +++++++++++ .../configs/Qwen2-57B-A14-Instruct.yaml | 10 +++++----- .../lm-eval-harness/configs/models-large-fp8.txt | 3 +++ .buildkite/test-pipeline.yaml | 10 +++++++++- 4 files changed, 28 insertions(+), 6 deletions(-) create mode 100644 .buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml create mode 100644 .buildkite/lm-eval-harness/configs/models-large-fp8.txt diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml new file mode 100644 index 0000000000000..45d5efc8860f5 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml @@ -0,0 +1,11 @@ +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4 +model_name: "Qwen/Qwen2-57B-A14B-Instruct" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.792 + - name: "exact_match,flexible-extract" + value: 0.824 +limit: 250 +num_fewshot: 5 diff --git 
a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml index 45d5efc8860f5..2e6eb0b98dae5 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml @@ -1,11 +1,11 @@ -# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4 -model_name: "Qwen/Qwen2-57B-A14B-Instruct" +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m 0.777 -b "auto" -l 250 -f 5 -t 4 +model_name: "nm-testing/Qwen2-57B-A14B-Instruct-FP8-KV" tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.792 + value: 0.823 - name: "exact_match,flexible-extract" - value: 0.824 -limit: 250 + value: 0.777 +limit: 1000 num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large-fp8.txt b/.buildkite/lm-eval-harness/configs/models-large-fp8.txt new file mode 100644 index 0000000000000..aa2963322da2e --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/models-large-fp8.txt @@ -0,0 +1,3 @@ +Mixtral-8x7B-Instruct-v0.1-FP8.yaml +Qwen2-57B-A14-Instruct-FP8.yaml + diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c8f53224b1dcf..a1542b606e06b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -260,7 +260,15 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 -- label: LM Eval Large Models +- label: LM Eval Large Models - L4 Fp8 + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + commands: + - pip install lm-eval + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - bash ./run-tests.sh -c configs/models-large-fp8.txt -t 4 + +- label: LM Eval Large Models - A100 gpu: a100 num_gpus: 4 working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" From d5444cc7f7b7f1ce2360047134978310e9fd07cf Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 11:56:02 +0000 Subject: [PATCH 06/21] updated qwen --- .../configs/Qwen2-57B-A14-Instruct-FP8.yaml | 10 +++++----- .../configs/Qwen2-57B-A14-Instruct.yaml | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml index 45d5efc8860f5..38cf8218c8443 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml @@ -1,11 +1,11 @@ -# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4 -model_name: "Qwen/Qwen2-57B-A14B-Instruct" +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-57B-A14B-Instruct-FP8-KV -b "auto" -l 250 -f 5 -t 4 +model_name: "nm-testing/Qwen2-57B-A14B-Instruct-FP8-KV" tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.792 + value: 0.823 - name: "exact_match,flexible-extract" - value: 0.824 -limit: 250 + value: 0.777 +limit: 1000 num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml index 2e6eb0b98dae5..45d5efc8860f5 100644 --- a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml +++ b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct.yaml @@ -1,11 +1,11 @@ -# bash ./run-lm-eval-gsm-vllm-baseline.sh -m 0.777 -b "auto" -l 250 -f 5 -t 4 -model_name: 
"nm-testing/Qwen2-57B-A14B-Instruct-FP8-KV" +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2-57B-A14B-Instruct -b "auto" -l 250 -f 5 -t 4 +model_name: "Qwen/Qwen2-57B-A14B-Instruct" tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.823 + value: 0.792 - name: "exact_match,flexible-extract" - value: 0.777 -limit: 1000 + value: 0.824 +limit: 250 num_fewshot: 5 From 80d3ecd62ff077ba558852bedf8eb19e250f8c09 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sat, 13 Jul 2024 10:15:50 -0400 Subject: [PATCH 07/21] Update vllm/model_executor/layers/fused_moe/layer.py Co-authored-by: Michael Goin --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a600bd2847084..56defc2ddefc8 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -160,7 +160,7 @@ def weight_loader(self, param: torch.nn.Parameter, idx = 0 if shard_id == 0 else 1 param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) - # * shard_id 1 == down_proj / w2 + # shard_id 1 == down_proj / w2 else: param_data[expert_id] = loaded_weight # Weights From 2ca7385a573c46dab17eac354c52bda634852f0d Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sat, 13 Jul 2024 10:15:55 -0400 Subject: [PATCH 08/21] Update vllm/model_executor/layers/fused_moe/layer.py Co-authored-by: Michael Goin --- vllm/model_executor/layers/fused_moe/layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 56defc2ddefc8..6620e8850029b 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -152,8 +152,8 @@ def weight_loader(self, param: torch.nn.Parameter, # Weight scales elif "weight_scale" in weight_name: # If we are in merged column case (gate_up_proj) - # * shard_id 0 == gate_proj / w1 - # * shard_id 2 == up_proj / w3 + # shard_id 0 == gate_proj / w1 + # shard_id 2 == up_proj / w3 if shard_id == 0 or shard_id == 2: # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. 
From ce10c8ca12cb3e799964adf20330891cd7de28bb Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 16:29:18 +0000 Subject: [PATCH 09/21] stash --- .../layers/fused_moe/fused_moe.py | 26 +++++- vllm/model_executor/layers/fused_moe/layer.py | 25 +++++- .../model_executor/layers/quantization/fp8.py | 10 ++- vllm/model_executor/models/deepseek_v2.py | 82 ++++++------------- 4 files changed, 77 insertions(+), 66 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 3c62008fbfcc1..f978c6d1cf714 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -400,8 +400,11 @@ def grouped_topk( topk: int, renormalize: bool, num_expert_group: int = 0, - topk_group: int = 0, -): + topk_group: int = 0): + + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + scores = torch.softmax(gating_output, dim=-1) num_token = scores.shape[0] group_scores = scores.view(num_token, num_expert_group, @@ -557,11 +560,15 @@ def fused_moe( renormalize: bool, inplace: bool = False, override_config: Optional[Dict[str, Any]] = None, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + use_grouped_topk: bool = False, use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -579,8 +586,13 @@ def fused_moe( Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -592,8 +604,14 @@ def fused_moe( # Check constraints. 
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) + if use_grouped_topk: + topk_weights, topk_ids = grouped_topk( + hidden_states, gating_output, topk, renormalize, + num_expert_group, topk_group) + else: + topk_weights, topk_ids = fused_topk( + hidden_states, gating_output, topk, renormalize) + return fused_experts(hidden_states, w1, w2, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6620e8850029b..5e89faf0773a2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,7 +29,10 @@ def apply(self, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, - renormalize: bool = True) -> torch.Tensor: + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: raise NotImplementedError @@ -63,7 +66,10 @@ def apply(self, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, - renormalize: bool = True) -> torch.Tensor: + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: return fused_moe(x, layer.w13_weight, @@ -71,7 +77,10 @@ def apply(self, router_logits, top_k, renormalize=renormalize, - inplace=True) + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group) class FusedMoE(torch.nn.Module): @@ -104,6 +113,9 @@ def __init__( params_dtype: Optional[torch.dtype] = None, reduce_results: bool = False, renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, ): @@ -119,6 +131,13 @@ def __init__( self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + else: + assert num_expert_group is None and topk_group is None if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 0c2d2bd3fabe5..5c916c9b4d7e4 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -377,7 +377,10 @@ def apply(self, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, - renormalize: bool = True) -> torch.Tensor: + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: return fused_moe(x, layer.w13_weight, @@ -390,7 +393,10 @@ def apply(self, w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale) + a2_scale=layer.a2_scale, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group) class Fp8KVCacheMethod(QuantizeMethodBase): diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index fb4097fd1e9b3..b5bd6895f5644 100644 --- 
a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -29,9 +29,9 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, +from vllm.distributed import (get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_experts, grouped_topk from vllm.model_executor.layers.layernorm import RMSNorm @@ -91,32 +91,30 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() - self.n_routed_experts = config.n_routed_experts - self.top_k = config.num_experts_per_tok self.routed_scaling_factor = config.routed_scaling_factor - if self.tp_size > self.n_routed_experts: + self.n_shared_experts = config.n_shared_experts + if self.tp_size > config.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.n_routed_experts}.") - - self.experts = nn.ModuleList([ - DeepseekV2MLP(hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False) - for idx in range(self.n_routed_experts) - ]) - self.pack_params() + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.experts = FusedMoE(num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config) self.gate = ReplicatedLinear(config.hidden_size, - self.n_routed_experts, + config.n_routed_experts, bias=False, quant_config=None) - if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * config.n_shared_experts) @@ -128,50 +126,20 @@ def __init__( reduce_results=False, ) - def pack_params(self): - w1 = [] - w2 = [] - for expert in self.experts: - w1.append(expert.gate_up_proj.weight) - w2.append(expert.down_proj.weight) - self.w1 = torch._utils._flatten_dense_tensors(w1) - w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) - for data, param in zip(w1s, w1): - param.data = data - self.w1 = self.w1.view(len(w1), *w1s[0].shape) - - self.w2 = torch._utils._flatten_dense_tensors(w2) - w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) - for data, param in zip(w2s, w2): - param.data = data - - self.w2 = self.w2.view(len(w2), *w2s[0].shape) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - if self.config.n_shared_experts is not None: + if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - topk_weights, topk_ids = grouped_topk( - hidden_states, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, - num_expert_group=self.config.n_group, - topk_group=self.config.topk_group) 
- final_hidden_states = fused_experts( - hidden_states, - self.w1, - self.w2, - topk_weights, - topk_ids, - inplace=True) * self.routed_scaling_factor - if self.config.n_shared_experts is not None: + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + if shared_output is not None: final_hidden_states = final_hidden_states + shared_output - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) return final_hidden_states.view(num_tokens, hidden_dim) From f7c6d24a288467d10e96e2206665abd89bb39a0f Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 16:58:53 +0000 Subject: [PATCH 10/21] formatted --- .../layers/fused_moe/fused_moe.py | 29 +++++++++---------- vllm/model_executor/models/deepseek_v2.py | 3 +- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f978c6d1cf714..2a8bb956c0104 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -394,13 +394,12 @@ def fused_topk( # This is used by the Deepseek-V2 model -def grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0): +def grouped_topk(hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0): assert hidden_states.shape[0] == gating_output.shape[0], ( "Number of tokens mismatch") @@ -560,15 +559,14 @@ def fused_moe( renormalize: bool, inplace: bool = False, override_config: Optional[Dict[str, Any]] = None, + use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, - use_grouped_topk: bool = False, use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -605,13 +603,14 @@ def fused_moe( assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" if use_grouped_topk: - topk_weights, topk_ids = grouped_topk( - hidden_states, gating_output, topk, renormalize, - num_expert_group, topk_group) + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, + topk, renormalize, + num_expert_group, topk_group) else: - topk_weights, topk_ids = fused_topk( - hidden_states, gating_output, topk, renormalize) - + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + return fused_experts(hidden_states, w1, w2, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index b5bd6895f5644..e4989fd87132f 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -31,9 +31,8 @@ from vllm.config import CacheConfig from vllm.distributed import (get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_experts, grouped_topk +from 
vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, From a6dd8c35b4ee622aeda918e4f3d3c9e96a6fe817 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 17:34:13 +0000 Subject: [PATCH 11/21] its working! --- vllm/model_executor/layers/fused_moe/layer.py | 7 ++- vllm/model_executor/models/deepseek_v2.py | 62 ++++++++++++++----- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5e89faf0773a2..ffc63febb01a3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -138,6 +138,8 @@ def __init__( self.topk_group = topk_group else: assert num_expert_group is None and topk_group is None + self.num_expert_group = None + self.topk_group = None if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -213,7 +215,10 @@ def forward(self, hidden_states: torch.Tensor, x=hidden_states, router_logits=router_logits, top_k=self.top_k, - renormalize=self.renormalize) + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + num_expert_group=self.num_expert_group, + topk_group=self.topk_group) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index e4989fd87132f..16537b9a55459 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -108,7 +108,10 @@ def __init__( intermediate_size=config.moe_intermediate_size, reduce_results=False, renormalize=config.norm_topk_prob, - quant_config=quant_config) + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group) self.gate = ReplicatedLinear(config.hidden_size, config.n_routed_experts, @@ -471,34 +474,61 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] + expert_params_mapping = [ + # These are the weights for the experts + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"] + else "experts.w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) + for expert_id in range(self.config.n_routed_experts) for shard_id, + weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) + and name not in params_dict): + continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. 
- if (("mlp.experts." in name or "mlp.shared_experts." in name) - and name not in params_dict): - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if (("mlp.experts." in name or "mlp.shared_experts." in name) - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 480b8a1b18ee0e88e42bb15a54fc03795329128c Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 17:53:08 +0000 Subject: [PATCH 12/21] added --- .../configs/DeepSeek-V2-Lite-Chat.yaml | 11 +++++++++++ .buildkite/lm-eval-harness/configs/models-large.txt | 1 + .../lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh | 2 +- 3 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 .buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml diff --git a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml new file mode 100644 index 0000000000000..15268395ec68b --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml @@ -0,0 +1,11 @@ +# bash ./run-lm-eval-gsm-vllm-baseline.sh -m deepseek-ai/DeepSeek-V2-Lite-Chat -b "auto" -l 1000 -f 5 -t 2 +model_name: "deepseek-ai/DeepSeek-V2-Lite-Chat" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.671 + - name: "exact_match,flexible-extract" + value: 0.664 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 2007dd2e1cfa1..94b15a87235b9 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -1,3 +1,4 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml +DeepSeek-V2-Lite-Chat.yaml diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index 1bddbd89e4ab1..dbb21be4f86e4 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do done lm_eval --model vllm \ - --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray" \ + --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true \ --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \ --batch_size $BATCH_SIZE From d9e4477ab7b1939ea32d6359dca436242f535280 Mon Sep 17 
00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 17:58:01 +0000 Subject: [PATCH 13/21] formatting --- vllm/model_executor/layers/fused_moe/layer.py | 8 ++------ vllm/model_executor/models/deepseek_v2.py | 3 +-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ffc63febb01a3..2385c1ee30c09 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -134,12 +134,8 @@ def __init__( self.use_grouped_topk = use_grouped_topk if self.use_grouped_topk: assert num_expert_group is not None and topk_group is not None - self.num_expert_group = num_expert_group - self.topk_group = topk_group - else: - assert num_expert_group is None and topk_group is None - self.num_expert_group = None - self.topk_group = None + self.num_expert_group = num_expert_group + self.topk_group = topk_group if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 16537b9a55459..1a73d83f41356 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -498,8 +498,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) - and name not in params_dict): + if (("mlp.experts." in name) and name not in params_dict): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. From c45ac7cc19a05727f7595f543f3743609e3fedb9 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 18:30:24 +0000 Subject: [PATCH 14/21] factor out expert_params_mapping --- .../layers/fused_moe/fused_moe.py | 1 - vllm/model_executor/layers/fused_moe/layer.py | 37 ++++++++++++++++++- vllm/model_executor/models/deepseek_v2.py | 16 ++++---- vllm/model_executor/models/mixtral.py | 32 ++++------------ vllm/model_executor/models/qwen2_moe.py | 34 ++++------------- 5 files changed, 57 insertions(+), 63 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2a8bb956c0104..413c0b6d0924e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -590,7 +590,6 @@ def fused_moe( note: Deepseekv2 model uses grouped_topk - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. 
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2385c1ee30c09..d6e5cbc8a7ddb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Optional +from typing import Optional, List, Tuple, Iterator import torch @@ -221,3 +221,38 @@ def forward(self, hidden_states: torch.Tensor, final_hidden_states) return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int + ) -> List[Tuple[str, str, int, int]]: + + gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] + gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name] + + return [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_scale" if weight_name in gate_up else "experts.w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, + shard_id) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_weight" if weight_name in gate_up else "experts.w2_weight", + f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) + for expert_id in range(num_experts) for shard_id, + weight_name in enumerate(gate_down_up) + ] + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ("experts.a13_scale" if weight_name in gate_up else "experts.a2_scale", + f"experts.{expert_id}.{weight_name}.input_scale", expert_id, + shard_id) for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 1a73d83f41356..64616957c3466 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -474,15 +474,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] - expert_params_mapping = [ - # These are the weights for the experts - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"] - else "experts.w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(self.config.n_routed_experts) for shard_id, - weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) - ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_local_experts) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e5bd58a9e97b0..0c456ada61230 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -372,31 +372,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] - expert_params_mapping = [ - # These are the weight scales for the 
experts - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_scale" - if weight_name in ["w1", "w3"] else "experts.w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, - shard_id) for expert_id in range(self.config.num_local_experts) - for shard_id, weight_name in enumerate(["w1", "w2", "w3"]) - ] + [ - # These are the weights for the experts - # (param_name, weight_name, expert_id) - ("experts.w13_weight" - if weight_name in ["w1", "w3"] else "experts.w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(self.config.num_local_experts) - for shard_id, weight_name in enumerate(["w1", "w2", "w3"]) - ] + [ - # These are the activation scales for the experts - # (param_name, weight_name, expert_id) - ("experts.a13_scale" - if weight_name in ["w1", "w3"] else "experts.a2_scale", - f"experts.{expert_id}.{weight_name}.input_scale", expert_id, - shard_id) for expert_id in range(self.config.num_local_experts) - for shard_id, weight_name in enumerate(["w1", "w2", "w3"]) - ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 356bd0a0901b2..dad7f9c2f454c 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -407,33 +407,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] - expert_params_mapping = [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_scale" if weight_name in ["gate_proj", "up_proj" - ] else "experts.w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, - shard_id) for expert_id in range(self.config.num_experts) - for shard_id, weight_name in enumerate( - ["gate_proj", "down_proj", "up_proj"]) - ] + [ - # These are the weights for the experts - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"] - else "experts.w2_weight", - f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(self.config.num_experts) for shard_id, - weight_name in enumerate(["gate_proj", "down_proj", "up_proj"]) - ] + [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id, shard_id) - ("experts.a13_scale" if weight_name in ["gate_proj", "up_proj" - ] else "experts.a2_scale", - f"experts.{expert_id}.{weight_name}.input_scale", expert_id, - shard_id) for expert_id in range(self.config.num_experts) - for shard_id, weight_name in enumerate( - ["gate_proj", "down_proj", "up_proj"]) - ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_local_experts) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: From 0d553441e09c301d64608aac5fcba52c60d22308 Mon Sep 17 00:00:00 2001 From: Robert Shaw 
<114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sat, 13 Jul 2024 14:31:37 -0400 Subject: [PATCH 15/21] Delete .buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml --- .../configs/Qwen2-57B-A14-Instruct-FP8.yaml | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 .buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml deleted file mode 100644 index 38cf8218c8443..0000000000000 --- a/.buildkite/lm-eval-harness/configs/Qwen2-57B-A14-Instruct-FP8.yaml +++ /dev/null @@ -1,11 +0,0 @@ -# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-57B-A14B-Instruct-FP8-KV -b "auto" -l 250 -f 5 -t 4 -model_name: "nm-testing/Qwen2-57B-A14B-Instruct-FP8-KV" -tasks: -- name: "gsm8k" - metrics: - - name: "exact_match,strict-match" - value: 0.823 - - name: "exact_match,flexible-extract" - value: 0.777 -limit: 1000 -num_fewshot: 5 From a954c5acf72ea88037ac100a581f4f7ff15ce00d Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sat, 13 Jul 2024 14:31:46 -0400 Subject: [PATCH 16/21] Delete .buildkite/lm-eval-harness/configs/models-large-fp8.txt --- .buildkite/lm-eval-harness/configs/models-large-fp8.txt | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 .buildkite/lm-eval-harness/configs/models-large-fp8.txt diff --git a/.buildkite/lm-eval-harness/configs/models-large-fp8.txt b/.buildkite/lm-eval-harness/configs/models-large-fp8.txt deleted file mode 100644 index aa2963322da2e..0000000000000 --- a/.buildkite/lm-eval-harness/configs/models-large-fp8.txt +++ /dev/null @@ -1,3 +0,0 @@ -Mixtral-8x7B-Instruct-v0.1-FP8.yaml -Qwen2-57B-A14-Instruct-FP8.yaml - From 8817127926825286b5a033de7270e1d46ef1f394 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com> Date: Sat, 13 Jul 2024 14:32:08 -0400 Subject: [PATCH 17/21] Update test-pipeline.yaml --- .buildkite/test-pipeline.yaml | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a1542b606e06b..c8f53224b1dcf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -260,15 +260,7 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 -- label: LM Eval Large Models - L4 Fp8 - num_gpus: 4 - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" - commands: - - pip install lm-eval - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-large-fp8.txt -t 4 - -- label: LM Eval Large Models - A100 +- label: LM Eval Large Models gpu: a100 num_gpus: 4 working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" From 1cad213c20ab8e870ac3a4677712016898270f68 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 18:34:23 +0000 Subject: [PATCH 18/21] fixes --- vllm/model_executor/models/deepseek_v2.py | 2 +- vllm/model_executor/models/qwen2_moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 64616957c3466..03b76a7eaae06 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -480,7 +480,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): 
ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_local_experts) + num_experts=self.config.n_routed_experts) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index dad7f9c2f454c..2cc2f1440d147 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -413,7 +413,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", - num_experts=self.config.num_local_experts) + num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: From 6c8544534d1e5d44d4e5038a848211a1d6236f98 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 18:40:51 +0000 Subject: [PATCH 19/21] added routing scaling factor --- vllm/model_executor/models/deepseek_v2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 03b76a7eaae06..2d12ceb7f3dbf 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -93,6 +93,7 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.routed_scaling_factor = config.routed_scaling_factor self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor if self.tp_size > config.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -135,8 +136,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: From da4bf83f9b77bfc5e6c8016ed26dfd8ddfe9f7eb Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 18:41:45 +0000 Subject: [PATCH 20/21] format --- vllm/model_executor/layers/fused_moe/layer.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d6e5cbc8a7ddb..45d7137a47d3d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Optional, List, Tuple, Iterator +from typing import Optional, List, Tuple import torch @@ -224,34 +224,36 @@ def forward(self, hidden_states: torch.Tensor, @classmethod def make_expert_params_mapping( - cls, - ckpt_gate_proj_name: str, - ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int - ) -> List[Tuple[str, str, int, int]]: + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> List[Tuple[str, str, int, int]]: gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] - gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name] + gate_down_up = [ + ckpt_gate_proj_name, 
ckpt_down_proj_name, ckpt_up_proj_name + ] return [ # These are the weight scales for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_scale" if weight_name in gate_up else "experts.w2_scale", + ("experts.w13_scale" + if weight_name in gate_up else "experts.w2_scale", f"experts.{expert_id}.{weight_name}.weight_scale", expert_id, shard_id) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) ] + [ # These are the weights for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_weight" if weight_name in gate_up else "experts.w2_weight", + ("experts.w13_weight" + if weight_name in gate_up else "experts.w2_weight", f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id) - for expert_id in range(num_experts) for shard_id, - weight_name in enumerate(gate_down_up) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) ] + [ # These are the weight scales for the experts # (param_name, weight_name, expert_id, shard_id) - ("experts.a13_scale" if weight_name in gate_up else "experts.a2_scale", + ("experts.a13_scale" + if weight_name in gate_up else "experts.a2_scale", f"experts.{expert_id}.{weight_name}.input_scale", expert_id, shard_id) for expert_id in range(num_experts) for shard_id, weight_name in enumerate(gate_down_up) From 2ff2b355ee087fe94af52414c18387e7ff74bdcd Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sat, 13 Jul 2024 18:41:54 +0000 Subject: [PATCH 21/21] format --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 45d7137a47d3d..3904f3e3d0e76 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Optional, List, Tuple +from typing import List, Optional, Tuple import torch
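
After this series, a model's MoE block is expected to route through FusedMoE instead of a hand-rolled expert list. A minimal sketch of the DeepSeek-V2 wiring introduced above might look like the following; the sizes and group counts are placeholders, and the shared-expert path and routed scaling factor are omitted, so this is an illustration rather than the real module.

import torch
from torch import nn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear


class SimpleGroupedMoE(nn.Module):
    """Condensed sketch of the DeepseekV2MoE wiring from the patches above."""

    def __init__(self, hidden_size=2048, intermediate_size=1408,
                 num_experts=64, top_k=6, num_expert_group=8, topk_group=3,
                 quant_config=None):
        super().__init__()
        # Router producing per-expert logits, as in the patch.
        self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False,
                                     quant_config=None)
        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                reduce_results=False,
                                renormalize=False,  # config.norm_topk_prob in the real model
                                use_grouped_topk=True,  # DeepSeek-V2 uses grouped top-k
                                num_expert_group=num_expert_group,
                                topk_group=topk_group,
                                quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)
        return self.experts(hidden_states=hidden_states,
                            router_logits=router_logits)
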
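The mapping produced by the new FusedMoE.make_expert_params_mapping classmethod is consumed the same way in each model's load_weights. The sketch below follows the Qwen2-MoE/DeepSeek-V2 pattern from the patches; the checkpoint projection names and expert count are illustrative, the helper function is hypothetical, and the surrounding non-expert weight handling is omitted.

from vllm.model_executor.layers.fused_moe import FusedMoE

# (param_name, weight_name, expert_id, shard_id) tuples covering expert
# weights, fp8 weight scales and fp8 activation scales.
expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=64)  # placeholder expert count


def load_expert_tensors(params_dict, weights):
    """Hypothetical helper showing how load_weights walks the mapping."""
    for name, loaded_weight in weights:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            param = params_dict[name.replace(weight_name, param_name)]
            # FusedMoE.weight_loader places the tensor in the right expert
            # slot and scale index, as sketched earlier in the series.
            param.weight_loader(param,
                                loaded_weight,
                                weight_name,
                                shard_id=shard_id,
                                expert_id=expert_id)
            break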