[Kernel] reloading fused_moe config on the last chunk (vllm-project#6210)
avshalomman authored and jimpang committed Jul 24, 2024
1 parent 82f9898 commit 160b60a
Showing 1 changed file with 36 additions and 15 deletions.
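Context for the change: fused_experts processes the token batch in chunks of VLLM_FUSED_MOE_CHUNK_SIZE, so every chunk except possibly the last has the same size, and before this commit the kernel config chosen for the full chunk size was reused for the smaller final chunk. A minimal sketch of the chunk arithmetic (the chunk size and token count below are illustrative values, not vLLM defaults):

# Sketch of the chunking in fused_experts; CHUNK_SIZE and num_tokens
# are made-up numbers for illustration.
CHUNK_SIZE = 65536
num_tokens = 70000

for begin_chunk_idx in range(0, num_tokens, CHUNK_SIZE):
    end_chunk_idx = min(begin_chunk_idx + CHUNK_SIZE, num_tokens)
    tokens_in_chunk = end_chunk_idx - begin_chunk_idx
    print(begin_chunk_idx, end_chunk_idx, tokens_in_chunk)
# 0 65536 65536     <- full chunk: a config tuned for M = 65536 fits well
# 65536 70000 4464  <- last chunk is much smaller, hence the config reload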
51 changes: 36 additions & 15 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -332,6 +332,31 @@ def get_default_config(
     return config
 
 
+def try_get_optimal_moe_config(
+    w1_shape: Tuple[int, ...],
+    w2_shape: Tuple[int, ...],
+    top_k: int,
+    dtype: Optional[str],
+    M: int,
+    override_config: Optional[Dict[str, Any]] = None,
+):
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        E, _, N = w2_shape
+        configs = get_moe_configs(E, N, dtype)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
+    return config
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
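The lookup in the new function, configs[min(configs.keys(), key=lambda x: abs(x - M))], picks the pre-tuned entry whose batch-size key is nearest to the actual M. A self-contained sketch with made-up keys and values (real maps come from get_moe_configs and hold Triton tile parameters):

# Hypothetical tuned-config map keyed by batch size M; the keys and
# values here are invented for illustration only.
configs = {1: {"BLOCK_SIZE_M": 16}, 64: {"BLOCK_SIZE_M": 32},
           4096: {"BLOCK_SIZE_M": 64}, 65536: {"BLOCK_SIZE_M": 128}}

M = 4464  # e.g. the token count of a short last chunk
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(config)  # {'BLOCK_SIZE_M': 64} -- 4096 is the key nearest to 4464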
@@ -428,22 +453,16 @@ def fused_experts(hidden_states: torch.Tensor,
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
     M = min(num_tokens, CHUNK_SIZE)
 
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
-                                  "float8" if use_fp8 else None)
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        "float8" if use_fp8 else None,
+        override_config=override_config,
+    )
 
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = get_default_config(M, E, N, w1.shape[2],
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None)
+    config = get_config_func(M)
 
     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
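functools.partial freezes every argument of try_get_optimal_moe_config except M, so the chunk loop further down can cheaply re-query a config per chunk size. A toy stand-in showing the pattern (the function body and shapes below are invented for illustration):

import functools

# Toy stand-in for try_get_optimal_moe_config: everything but M is frozen.
def pick_config(w1_shape, w2_shape, top_k, dtype, M, override_config=None):
    return override_config or {"tuned_for_M": M}

get_config_func = functools.partial(
    pick_config, (8, 14336, 4096), (8, 4096, 7168), 2, None)

print(get_config_func(65536))  # {'tuned_for_M': 65536} -- full chunks
print(get_config_func(4464))   # {'tuned_for_M': 4464}  -- the last chunk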
@@ -478,6 +497,8 @@ def fused_experts(hidden_states: torch.Tensor,
             intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
             intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
             intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+            # reload config to get better performance on the last chunk
+            config = get_config_func(tokens_in_chunk)
 
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
