forked from mesolitica/vllm-whisper
Optimize Triton MoE Kernel (vllm-project#2979)
Co-authored-by: Cade Daniel <[email protected]>
1 parent f783298, commit 75a7d51
Showing 7 changed files with 297 additions and 15 deletions.
@@ -0,0 +1,172 @@
import json
import os
import sys

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from vllm.model_executor.layers.fused_moe import fused_moe
import torch
import torch.nn.functional as F
import triton


def main():
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs, method=method)


def run_grid(bs, method):
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []
    if bs <= 16:
        BLOCK_SIZES_M = [16]
    elif bs <= 32:
        BLOCK_SIZES_M = [16, 32]
    elif bs <= 64:
        BLOCK_SIZES_M = [16, 32, 64]
    elif bs <= 128:
        BLOCK_SIZES_M = [16, 32, 64, 128]
    else:
        BLOCK_SIZES_M = [16, 32, 64, 128, 256]

    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in BLOCK_SIZES_M:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        configs.append({
                            "BLOCK_SIZE_M": block_size_m,
                            "BLOCK_SIZE_N": block_size_n,
                            "BLOCK_SIZE_K": block_size_k,
                            "GROUP_SIZE_M": group_size_m,
                            "num_warps": num_warps,
                            "num_stages": 4,
                        })

    best_config = None
    best_time_us = 1e20

    for config in configs:
        print(f'{tp_size=} {bs=}')
        print(f'{config}')
        # warmup
        print(f'warming up')
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        print(f'benchmarking')
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            print(
                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f} {bs=} {tp_size=} {top_k=} {num_total_experts=} {d_model=} {model_intermediate_size=} {num_layers=}'
            )

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    filename = "/tmp/config.jsonl"
    print(f"writing config to file {filename}")
    with open(filename, "a") as f:
        f.write(json.dumps({str(bs): best_config}) + "\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.bfloat16,
    )

    ws = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2s = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=ws,
            w2=w2s,
            gating_output=gating_output[i],
            topk=2,
            renormalize=True,
            inplace=True,
            override_config=config,
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
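run_grid() appends one {batch_size: best_config} object per line to /tmp/config.jsonl. A minimal sketch (not part of this commit) of merging those lines into a single mapping with the same shape as the config files added below:

import json

# Illustrative post-processing of the benchmark output above: collapse the
# one-object-per-line JSONL into a single dict keyed by batch size (string).
merged = {}
with open("/tmp/config.jsonl") as f:
    for line in f:
        merged.update(json.loads(line))

print(json.dumps(merged, indent=4))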
@@ -0,0 +1,5 @@
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

__all__ = [
    "fused_moe",
]
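This re-export is what lets the benchmark script above import the kernel as "from vllm.model_executor.layers.fused_moe import fused_moe". A minimal standalone call, sketched from the shapes and keyword arguments used in run_timing above (the full signature lives in fused_moe.py, which this diff does not show):

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.fused_moe import fused_moe

# Shapes mirror run_timing: 8 experts, top-2 routing, d_model=4096,
# per-shard intermediate size 7168 (Mixtral with TP2), batch size 4.
x = torch.rand((4, 4096), device="cuda:0", dtype=torch.bfloat16)
w1 = torch.rand((8, 2 * 7168, 4096), device=x.device, dtype=x.dtype)
w2 = torch.rand((8, 4096, 7168), device=x.device, dtype=x.dtype)
gating = F.softmax(
    torch.rand((4, 8), device=x.device, dtype=torch.float32), dim=-1)

out = fused_moe(hidden_states=x, w1=w1, w2=w2, gating_output=gating,
                topk=2, renormalize=True, inplace=True)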
20 additions & 0 deletions: ...model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json
@@ -0,0 +1,20 @@
{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
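Each key is a batch size M and each value is a Triton launch configuration for the fused MoE kernel. A hedged sketch of consuming such a file; the actual selection logic lives in fused_moe.py, which is not part of this excerpt, so the nearest-key lookup below is only an assumption:

import json

def load_moe_config(path: str, m: int) -> dict:
    # Assumption: choose the entry whose batch-size key is closest to m.
    with open(path) as f:
        table = {int(k): v for k, v in json.load(f).items()}
    nearest = min(table, key=lambda k: abs(k - m))
    return table[nearest]

# Example: pick a config for batch size 48 from the A100 TP4 file above.
cfg = load_moe_config("E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json", 48)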
24 additions & 0 deletions: ...model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,24 @@
{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}
@@ -0,0 +1,10 @@
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.

The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
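Given that naming scheme, a hedged sketch of how a caller could locate the right file for the current GPU; the helper name and the space-to-underscore substitution are assumptions inferred from the filenames added in this commit, not code taken from the diff:

import torch

def moe_config_filename(num_experts: int, shard_intermediate_size: int) -> str:
    # Assumption: spaces in torch.cuda.get_device_name() are replaced with
    # underscores to match filenames like
    # E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json.
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    return (f"E={num_experts},N={shard_intermediate_size},"
            f"device_name={device_name}.json")

# e.g. Mixtral with TP2 on H100: moe_config_filename(8, 7168)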