diff --git a/csrc/ops.h b/csrc/ops.h index ff7a3de1a0a8c..03bb1e24dc68e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -146,7 +146,12 @@ void gptq_shuffle( torch::Tensor q_perm, int bit); -void scaled_fp8_quant( +void static_scaled_fp8_quant( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& scale); + +void dynamic_scaled_fp8_quant( torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a5b16c5abc3ed..2250c7f69f0ab 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -73,7 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); + ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); + ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); ops.def( "moe_align_block_size", &moe_align_block_size, diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index c3337cede1282..2477051eb60d7 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel( } // namespace vllm -void scaled_fp8_quant( +void static_scaled_fp8_quant( + torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] +{ + int64_t num_tokens = input.numel() / input.size(-1); + int64_t num_elems = input.numel(); + dim3 grid(num_tokens); + dim3 block(1024); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), + "scaled_fp8_quant_kernel", + [&] { + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + scale.data_ptr(), + num_elems); + }); +} + +void dynamic_scaled_fp8_quant( torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] torch::Tensor& scale) // [1] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 508d35656eb00..5ba104bada7ac 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -168,10 +168,16 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, # fp8 -def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) - vllm_ops.scaled_fp8_quant(output, input, scale) + if scale is None: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + vllm_ops.static_scaled_fp8_quant(output, input, scale) return output, scale diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index aed2c350bdd10..d37837a0b2ce8 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -220,8 +220,9 @@ def moe_align_block_size( def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - B_scale: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, @@ -232,10 +233,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, assert sorted_token_ids.stride(0) == 1 if not use_fp8: - A_scale = None + assert A_scale is None assert B_scale is None else: - A, A_scale = ops.scaled_fp8_quant(A) + A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ @@ -318,6 +319,8 @@ def fused_moe( use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -434,6 +437,7 @@ def fused_moe( invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, + a1_scale, w1_scale, topk_weights, topk_ids, @@ -451,6 +455,7 @@ def fused_moe( invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, + a2_scale, w2_scale, topk_weights, topk_ids, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 39679834b545c..ba9f3149649c1 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -14,6 +14,12 @@ class Fp8Config(QuantizationConfig): """Config class for FP8.""" + def __init__( + self, + activation_scheme: str = "dynamic", + ) -> None: + self.activation_scheme = activation_scheme + @classmethod def get_name(cls) -> str: return "fp8" @@ -35,7 +41,8 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": - return cls() + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(activation_scheme) def get_quant_method( self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 7847df735ab44..c5dd1a63e2f7a 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -105,6 +105,13 @@ def __init__( device="cuda", dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + # Scaling factors for FP8 weights self.ws_scale = nn.Parameter( torch.ones( @@ -115,12 +122,23 @@ def __init__( self.num_total_experts, device="cuda", dtype=torch.float32), requires_grad=False) if self.use_fp8 else None - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) + # Scaling factors for FP8 activations + need_act_scales = (self.use_fp8 + and quant_config.activation_scheme == "static") + self.as_scale = nn.Parameter( + torch.zeros(1, device="cuda", dtype=torch.float32), + requires_grad=False) if need_act_scales else None + self.a2s_scale = nn.Parameter( + torch.zeros(1, device="cuda", dtype=torch.float32), + requires_grad=False) if need_act_scales else None + + if need_act_scales: + set_weight_attrs(self.as_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.a2s_scale, { + "weight_loader": self.weight_loader, + }) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int): @@ -135,6 +153,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_size:2 * shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] + if "act_scale" in weight_name: + param_data[:] = param_data[:].max(loaded_weight) def process_weights_after_loading(self): if self.use_fp8: @@ -162,7 +182,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: inplace=True, use_fp8=self.use_fp8, w1_scale=self.ws_scale, - w2_scale=self.w2s_scale) + w2_scale=self.w2s_scale, + a1_scale=self.as_scale, + a2_scale=self.a2s_scale) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -443,11 +465,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] expert_params_mapping = [ + # These are the weights for the experts # (param_name, weight_name, expert_id) ("ws" if weight_name in ["w1", "w3"] else "w2s", f"experts.{expert_id}.{weight_name}.weight", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the activation scales for the experts + # (param_name, weight_name, expert_id) + ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale", + f"experts.{expert_id}.{weight_name}.act_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters())