Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Kernel] Optimize FP8 support for MoE kernel / Mixtral via static sca…
Browse files Browse the repository at this point in the history
…les (vllm-project#4343)

Co-authored-by: Woosuk Kwon <[email protected]>
  • Loading branch information
2 people authored and robertgshaw2-neuralmagic committed May 6, 2024
1 parent 192c704 commit 1e88172
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 18 deletions.
7 changes: 6 additions & 1 deletion csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,12 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);

void scaled_fp8_quant(
void static_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);

void dynamic_scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);
Expand Down
3 changes: 2 additions & 1 deletion csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
Expand Down
25 changes: 24 additions & 1 deletion csrc/quantization/fp8/fp8_cuda_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(

} // namespace vllm

void scaled_fp8_quant(
void static_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
dim3 grid(num_tokens);
dim3 block(1024);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"scaled_fp8_quant_kernel",
[&] {
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(),
input.data_ptr<scalar_t>(),
scale.data_ptr<float>(),
num_elems);
});
}

void dynamic_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
Expand Down
12 changes: 9 additions & 3 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,16 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,


# fp8
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
vllm_ops.scaled_fp8_quant(output, input, scale)
if scale is None:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
else:
vllm_ops.static_scaled_fp8_quant(output, input, scale)
return output, scale


Expand Down
13 changes: 9 additions & 4 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,9 @@ def moe_align_block_size(


def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B_scale: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
Expand All @@ -232,10 +233,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
assert sorted_token_ids.stride(0) == 1

if not use_fp8:
A_scale = None
assert A_scale is None
assert B_scale is None
else:
A, A_scale = ops.scaled_fp8_quant(A)
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None

grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
Expand Down Expand Up @@ -318,6 +319,8 @@ def fused_moe(
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
Expand Down Expand Up @@ -434,6 +437,7 @@ def fused_moe(
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
Expand All @@ -451,6 +455,7 @@ def fused_moe(
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
Expand Down
9 changes: 8 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""

def __init__(
self,
activation_scheme: str = "dynamic",
) -> None:
self.activation_scheme = activation_scheme

@classmethod
def get_name(cls) -> str:
return "fp8"
Expand All @@ -35,7 +41,8 @@ def get_config_filenames(cls) -> List[str]:

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
return cls()
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(activation_scheme)

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
Expand Down
44 changes: 37 additions & 7 deletions vllm/model_executor/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ def __init__(
device="cuda",
dtype=self.params_dtype))

set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
})

# Scaling factors for FP8 weights
self.ws_scale = nn.Parameter(
torch.ones(
Expand All @@ -115,12 +122,23 @@ def __init__(
self.num_total_experts, device="cuda", dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None

set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
})
# Scaling factors for FP8 activations
need_act_scales = (self.use_fp8
and quant_config.activation_scheme == "static")
self.as_scale = nn.Parameter(
torch.zeros(1, device="cuda", dtype=torch.float32),
requires_grad=False) if need_act_scales else None
self.a2s_scale = nn.Parameter(
torch.zeros(1, device="cuda", dtype=torch.float32),
requires_grad=False) if need_act_scales else None

if need_act_scales:
set_weight_attrs(self.as_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2s_scale, {
"weight_loader": self.weight_loader,
})

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
Expand All @@ -135,6 +153,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name:
param_data[:] = param_data[:].max(loaded_weight)

def process_weights_after_loading(self):
if self.use_fp8:
Expand Down Expand Up @@ -162,7 +182,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.ws_scale,
w2_scale=self.w2s_scale)
w2_scale=self.w2s_scale,
a1_scale=self.as_scale,
a2_scale=self.a2s_scale)

if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
Expand Down Expand Up @@ -443,11 +465,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
]

expert_params_mapping = [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]

params_dict = dict(self.named_parameters())
Expand Down

0 comments on commit 1e88172

Please sign in to comment.