Add fused top-K softmax kernel for MoE #2769

Merged
42 commits merged on Feb 6, 2024

Changes from all commits (42 commits)
ad66935
Add CUTLASS as a submodule
WoosukKwon Jan 31, 2024
396e537
Port CUTLASS extensions
WoosukKwon Jan 31, 2024
0cd9436
Port MoE kernels
WoosukKwon Jan 31, 2024
cb4524c
Move moe_kernels
WoosukKwon Jan 31, 2024
c191207
Port MoE GEMM
WoosukKwon Jan 31, 2024
cfa4554
Port CUTLASS kernels
WoosukKwon Jan 31, 2024
90ccdfa
Remove MoE gemm
WoosukKwon Jan 31, 2024
3e90c1a
Merge branch 'main' into cutlass-moe
WoosukKwon Feb 2, 2024
77a5c8d
Remove unused CUTLASS kernels
WoosukKwon Feb 2, 2024
f1583de
Minor
WoosukKwon Feb 2, 2024
de7a749
Add topk_softmax kernels
WoosukKwon Feb 2, 2024
e5c62e8
Remove unnecessary headers
WoosukKwon Feb 2, 2024
e127d9b
Add MoE namespace
WoosukKwon Feb 2, 2024
c3096a0
Minor
WoosukKwon Feb 2, 2024
9a561cc
Add permute_kernels
WoosukKwon Feb 2, 2024
ba07256
Remove unused
WoosukKwon Feb 2, 2024
def2ccd
Move
WoosukKwon Feb 2, 2024
72256cc
Move
WoosukKwon Feb 2, 2024
e86fd06
Remove
WoosukKwon Feb 4, 2024
612f961
Add MoE MLP
WoosukKwon Feb 4, 2024
0bf8fb9
Add cudaUtils
WoosukKwon Feb 4, 2024
c09179d
Fix headers
WoosukKwon Feb 4, 2024
2ab65df
Enable BF16
WoosukKwon Feb 4, 2024
c74fc79
Err msg
WoosukKwon Feb 4, 2024
6320de4
Add unpermute_and_reduce
WoosukKwon Feb 4, 2024
9b57e39
Add renormalize
WoosukKwon Feb 5, 2024
55fae45
Add FusedMoE
WoosukKwon Feb 5, 2024
5dcf104
Remove dependency on cutlass
WoosukKwon Feb 5, 2024
92ac8dd
Remove CUTLASS
WoosukKwon Feb 5, 2024
cf559dc
Fix Mixtral & DeepSeek
WoosukKwon Feb 5, 2024
3246328
Minor
WoosukKwon Feb 5, 2024
51940a4
Add .cuda
WoosukKwon Feb 5, 2024
26d327f
Fix setup
WoosukKwon Feb 5, 2024
1cc84e1
Fix MoE test
WoosukKwon Feb 5, 2024
67f04c3
Fix copyright header
WoosukKwon Feb 5, 2024
186901b
yapf
WoosukKwon Feb 5, 2024
26ef5a0
Minor fix
WoosukKwon Feb 5, 2024
6f33c73
Add minor comment
WoosukKwon Feb 5, 2024
a94dd8c
Address review on test_moe
WoosukKwon Feb 5, 2024
fe8c108
Fix docstring
WoosukKwon Feb 5, 2024
9a5d9d8
Add assert statements
WoosukKwon Feb 5, 2024
7cd63e9
Merge branch 'main' into topk-softmax
WoosukKwon Feb 6, 2024
7 changes: 7 additions & 0 deletions csrc/moe/moe_ops.cpp
@@ -0,0 +1,7 @@
#include "moe_ops.h"

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
}
9 changes: 9 additions & 0 deletions csrc/moe/moe_ops.h
@@ -0,0 +1,9 @@
#pragma once

#include <torch/extension.h>

void topk_softmax(
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
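The header above fixes the calling convention for the new op: the caller allocates the output tensors and the kernel fills them in place. A minimal Python-side sketch of that convention, with shapes and dtypes inferred from how fused_moe.py invokes it later in this diff (the sizes here are illustrative only):

import torch
import vllm._moe_C as moe_kernels  # extension built from csrc/moe/* (see setup.py below)

M, E, topk = 4, 8, 2  # tokens, experts, experts selected per token
gating_output = torch.randn(M, E, dtype=torch.float32, device="cuda")

# Outputs are pre-allocated by the caller and written in place by the kernel.
topk_weights = torch.empty(M, topk, dtype=torch.float32, device="cuda")
topk_ids = torch.empty(M, topk, dtype=torch.int32, device="cuda")
token_expert_indices = torch.empty(M, topk, dtype=torch.int32, device="cuda")

moe_kernels.topk_softmax(topk_weights, topk_ids, token_expert_indices,
                         gating_output)
# topk_weights now holds the softmax probabilities of the top-k experts and
# topk_ids the corresponding expert indices, for each of the M tokens.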
499 changes: 499 additions & 0 deletions csrc/moe/topk_softmax_kernels.cu

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion csrc/pybind.cpp
@@ -48,8 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

#ifndef USE_ROCM
// Quantization ops
#ifndef USE_ROCM
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
11 changes: 11 additions & 0 deletions setup.py
@@ -317,6 +317,17 @@ def get_torch_arch_list() -> Set[str]:
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
vllm_extension_sources.append("csrc/custom_all_reduce.cu")

# Add MoE kernels.
ext_modules.append(
CUDAExtension(
name="vllm._moe_C",
sources=glob("csrc/moe/*.cu") + glob("csrc/moe/*.cpp"),
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
))

if not _is_neuron():
vllm_extension = CUDAExtension(
name="vllm._C",
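After this change, the MoE sources are compiled into a separate extension module, vllm._moe_C, alongside the existing vllm._C. A hypothetical smoke test after rebuilding from source (CUDA build assumed, since the kernels are not yet supported on ROCm per the fallback added in fused_moe.py below):

import vllm._moe_C as moe_kernels  # should import cleanly once the extension is built
assert hasattr(moe_kernels, "topk_softmax")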
26 changes: 10 additions & 16 deletions tests/kernels/test_moe.py
@@ -2,10 +2,8 @@

Run `pytest tests/kernels/test_moe.py`.
"""

import pytest
import torch

from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

@@ -14,22 +12,21 @@
from vllm.model_executor.models.mixtral import MixtralMoE


def torch_moe(a, w1, w2, topk_weight, topk_ids):
def torch_moe(a, w1, w2, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
out = torch.zeros(B * topk_ids.shape[1],
w2.shape[1],
dtype=a.dtype,
device=a.device)
topk_ids = topk_ids.view(-1)
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1)).sum(dim=1)
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


@pytest.mark.parametrize("m", [512, 222, 33, 1])
@@ -51,11 +48,8 @@ def test_fused_moe(
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.softmax(score, dim=-1)
topk_weight, topk_ids = torch.topk(score, topk)

triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)


@@ -75,7 +69,7 @@ def test_mixtral_moe(dtype: torch.dtype):
intermediate_size=config.intermediate_size,
params_dtype=dtype,
tp_size=1,
)
).cuda()

# Load the weights
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
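The reference torch_moe above spells out the computation the fused path has to reproduce. In equation form, with $g(x)$ the gating logits, $W_{1,i}$ and $W_{2,i}$ the weights of expert $i$, and renormalization controlled by fused_moe's renormalize flag (disabled in this test):

$$p = \mathrm{softmax}(g(x)), \qquad (w, \mathcal{I}) = \mathrm{topk}_k(p),$$
$$y = \sum_{i \in \mathcal{I}} \tilde{w}_i \,\mathrm{SiluAndMul}\!\left(x W_{1,i}^{\top}\right) W_{2,i}^{\top}, \qquad \tilde{w}_i = \begin{cases} w_i \,/\, \sum_{j \in \mathcal{I}} w_j & \text{if renormalize,} \\ w_i & \text{otherwise.} \end{cases}$$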
58 changes: 48 additions & 10 deletions vllm/model_executor/layers/fused_moe.py
@@ -4,6 +4,7 @@
import triton.language as tl

from vllm._C import ops
from vllm.utils import is_hip


@triton.jit
@@ -177,7 +178,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int, config: dict):

assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

@@ -210,28 +210,35 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
)


def fused_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace=False):
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_weights (torch.Tensor): The weights for the top-k selected experts.
- topk_ids (torch.Tensor): The indices of the top-k selected experts.
- gating_output (torch.Tensor): The output of the gating operation (before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. Defaults to False.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions"
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
@@ -241,6 +248,37 @@ def fused_moe(hidden_states: torch.Tensor,
M, _ = hidden_states.shape
E, N, _ = w1.shape

if is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels

topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
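With the new signature, callers hand fused_moe the raw gating logits and let it do the top-k softmax (and optional renormalization) internally. A minimal usage sketch mirroring the unit test above; the sizes are illustrative, and the w1/w2 shapes follow the assertions in fused_moe plus the reference torch_moe (w1 stacks the gate and up projections, hence the factor of 2):

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

m, n, k, e, topk = 33, 128, 511, 8, 2  # tokens, intermediate size, hidden size, experts, top-k
dtype = torch.float16

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10          # hidden_states
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10  # fused gate/up projection per expert
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10      # down projection per expert
score = torch.randn((m, e), device="cuda", dtype=dtype)           # raw gating logits (pre-softmax)

out = fused_moe(a, w1, w2, score, topk, renormalize=False)        # output of shape (m, k)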
15 changes: 3 additions & 12 deletions vllm/model_executor/models/deepseek.py
@@ -25,7 +25,6 @@

import torch
from torch import nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
@@ -155,20 +154,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
shared_output = self.shared_experts(hidden_states)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)

if self.config.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
routing_weights,
selected_experts,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
inplace=True)

if self.config.n_shared_experts is not None:
14 changes: 3 additions & 11 deletions vllm/model_executor/models/mixtral.py
@@ -24,8 +24,6 @@
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F

from torch import nn
from transformers import MixtralConfig

@@ -128,18 +126,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

final_hidden_states = fused_moe(hidden_states,
self.ws,
self.w2s,
routing_weights,
selected_experts,
router_logits,
self.top_k,
renormalize=True,
inplace=True)

if self.tp_size > 1:
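For reference, the routing code removed here (softmax over the router logits, top-k selection, renormalization) is what fused_moe(..., renormalize=True) now performs internally. A sketch of that unfused equivalent, useful for sanity checks:

import torch
import torch.nn.functional as F

def reference_routing(router_logits: torch.Tensor, top_k: int):
    # Equivalent to the removed Mixtral code (dim=-1 matches dim=1 for 2-D logits).
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize=True
    return routing_weights, selected_experts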