Skip to content

Commit

Permalink
Add fused top-K softmax kernel for MoE (vllm-project#2769)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored and jimpang committed Feb 20, 2024
1 parent 4a3a3a0 commit 2d2e6ee
Show file tree
Hide file tree
Showing 9 changed files with 591 additions and 50 deletions.
7 changes: 7 additions & 0 deletions csrc/moe/moe_ops.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#include "moe_ops.h"

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
}
9 changes: 9 additions & 0 deletions csrc/moe/moe_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

#include <torch/extension.h>

void topk_softmax(
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
499 changes: 499 additions & 0 deletions csrc/moe/topk_softmax_kernels.cu

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

#ifndef USE_ROCM
// Quantization ops
#ifndef USE_ROCM
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
Expand Down
11 changes: 11 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,17 @@ def get_torch_arch_list() -> Set[str]:
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
vllm_extension_sources.append("csrc/custom_all_reduce.cu")

# Add MoE kernels.
ext_modules.append(
CUDAExtension(
name="vllm._moe_C",
sources=glob("csrc/moe/*.cu") + glob("csrc/moe/*.cpp"),
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
))

if not _is_neuron():
vllm_extension = CUDAExtension(
name="vllm._C",
Expand Down
26 changes: 10 additions & 16 deletions tests/kernels/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
Run `pytest tests/kernels/test_moe.py`.
"""

import pytest
import torch

from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

Expand All @@ -14,22 +12,21 @@
from vllm.model_executor.models.mixtral import MixtralMoE


def torch_moe(a, w1, w2, topk_weight, topk_ids):
def torch_moe(a, w1, w2, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
out = torch.zeros(B * topk_ids.shape[1],
w2.shape[1],
dtype=a.dtype,
device=a.device)
topk_ids = topk_ids.view(-1)
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1)).sum(dim=1)
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


@pytest.mark.parametrize("m", [512, 222, 33, 1])
Expand All @@ -51,11 +48,8 @@ def test_fused_moe(
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.softmax(score, dim=-1)
topk_weight, topk_ids = torch.topk(score, topk)

triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)


Expand All @@ -75,7 +69,7 @@ def test_mixtral_moe(dtype: torch.dtype):
intermediate_size=config.intermediate_size,
params_dtype=dtype,
tp_size=1,
)
).cuda()

# Load the weights
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
Expand Down
58 changes: 48 additions & 10 deletions vllm/model_executor/layers/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import triton.language as tl

from vllm._C import ops
from vllm.utils import is_hip


@triton.jit
Expand Down Expand Up @@ -177,7 +178,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int, config: dict):

assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

Expand Down Expand Up @@ -210,28 +210,35 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
)


def fused_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=False):
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The weights for the top-k selected experts.
- topk_ids (torch.Tensor): The indices of the top-k selected experts.
- gating_output (torch.Tensor): The output of the gating operation (before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. Defaults to False.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
Expand All @@ -241,6 +248,37 @@ def fused_moe(hidden_states: torch.Tensor,
M, _ = hidden_states.shape
E, N, _ = w1.shape

if is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels

topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
Expand Down
15 changes: 3 additions & 12 deletions vllm/model_executor/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import torch
from torch import nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
Expand Down Expand Up @@ -155,20 +154,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
shared_output = self.shared_experts(hidden_states)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)

if self.config.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
routing_weights,
selected_experts,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True)

if self.config.n_shared_experts is not None:
Expand Down
14 changes: 3 additions & 11 deletions vllm/model_executor/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F

from torch import nn
from transformers import MixtralConfig

Expand Down Expand Up @@ -128,18 +126,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

final_hidden_states = fused_moe(hidden_states,
self.ws,
self.w2s,
routing_weights,
selected_experts,
router_logits,
self.top_k,
renormalize=True,
inplace=True)

if self.tp_size > 1:
Expand Down

0 comments on commit 2d2e6ee

Please sign in to comment.