diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index 046f11d957bdd..2356b9ec18b0d 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)
-        vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
-        vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
+        vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
+        vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
 
     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
     hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 9ff9ba298588a..efa4de7516212 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -78,6 +78,8 @@ def __init__(
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size // self.tp_size
+        self.quant_config = quant_config
+
         # FIXME(pcmoritz): Make this more general to support different
         # quantization schemes
         self.use_fp8 = isinstance(quant_config, Fp8Config)
@@ -86,55 +88,79 @@ def __init__(
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
+        # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
                                      params_dtype=self.params_dtype,
                                      quant_config=None)
 
-        self.ws = nn.Parameter(
+        if self.use_fp8:
+            params_dtype = torch.float8_e4m3fn
+
+        self.w13_weight = nn.Parameter(
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        dtype=self.params_dtype))
-        self.w2s = nn.Parameter(
+                        dtype=params_dtype))
+        self.w2_weight = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        dtype=self.params_dtype))
+                        dtype=params_dtype))
 
-        set_weight_attrs(self.ws, {
+        set_weight_attrs(self.w13_weight, {
             "weight_loader": self.weight_loader,
         })
-        set_weight_attrs(self.w2s, {
+        set_weight_attrs(self.w2_weight, {
             "weight_loader": self.weight_loader,
         })
 
-        # Scaling factors for FP8 weights
-        self.ws_scale = nn.Parameter(
-            torch.ones(self.num_total_experts, dtype=torch.float32),
-            requires_grad=False) if self.use_fp8 else None
-        self.w2s_scale = nn.Parameter(
-            torch.ones(self.num_total_experts, dtype=torch.float32),
-            requires_grad=False) if self.use_fp8 else None
-
-        # Scaling factors for FP8 activations
-        need_act_scales = (self.use_fp8
-                           and quant_config.activation_scheme == "static")
-        self.as_scale = nn.Parameter(
-            torch.zeros(1, dtype=torch.float32),
-            requires_grad=False) if need_act_scales else None
-        self.a2s_scale = nn.Parameter(
-            torch.zeros(1, dtype=torch.float32),
-            requires_grad=False) if need_act_scales else None
-
-        if need_act_scales:
-            set_weight_attrs(self.as_scale, {
-                "weight_loader": self.weight_loader,
-            })
-            set_weight_attrs(self.a2s_scale, {
-                "weight_loader": self.weight_loader,
-            })
+        # Used for fp8.
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                    dtype=torch.float32),
+                                         requires_grad=False)
+
+            # If loading an fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            # process_weights_after_loading()).
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(self.w13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.w2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for a checkpoint "
+                        "that was not serialized as fp8.")
+                self.a13_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                              requires_grad=False)
+                self.a2_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                             requires_grad=False)
+
+                set_weight_attrs(self.a13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.a2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str, expert_id: int):
@@ -149,20 +175,49 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name:
-            param_data[:] = param_data[:].max(loaded_weight)
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight
 
     def process_weights_after_loading(self):
-        if self.use_fp8:
-            ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
-            w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If the checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(self.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(self.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant(
-                    self.ws.data[expert, :, :])
-                w2s[expert, :, :], self.w2s_scale[
-                    expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
-            self.ws = nn.Parameter(ws, requires_grad=False)
-            self.w2s = nn.Parameter(w2s, requires_grad=False)
+                w13_weight[expert, :, :], self.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], self.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w2_weight.data[expert, :, :])
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If the checkpoint is fp8 + static, clean up the act_scales:
+        # the state_dict has one act_scale per expert, but our kernels
+        # expect a single act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but the "
+                    "activation scales are None.")
+
+            if (not all_close_1d(self.a13_scale)
+                    or not all_close_1d(self.a2_scale)):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer.")
+
+            self.a13_scale = nn.Parameter(self.a13_scale.max(),
+                                          requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(),
+                                         requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
@@ -170,17 +225,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
-                                        self.ws,
-                                        self.w2s,
+                                        self.w13_weight,
+                                        self.w2_weight,
                                         router_logits,
                                         self.top_k,
                                         renormalize=True,
                                         inplace=True,
                                         use_fp8=self.use_fp8,
-                                        w1_scale=self.ws_scale,
-                                        w2_scale=self.w2s_scale,
-                                        a1_scale=self.as_scale,
-                                        a2_scale=self.a2s_scale)
+                                        w1_scale=self.w13_scale,
+                                        w2_scale=self.w2_scale,
+                                        a1_scale=self.a13_scale,
+                                        a2_scale=self.a2_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -222,7 +277,9 @@ def __init__(self,
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
-        if isinstance(quant_config, Fp8Config):
+        if isinstance(
+                quant_config,
+                Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
             print_warning_once(
                 "For Mixtral FP8 quantization, we currently do not quantize "
                 "the attention layers until their FP8 performance is improved."
             )
@@ -461,16 +518,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         expert_params_mapping = [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id)
-            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ] + [
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
-            ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
@@ -512,3 +576,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))