[Minor] Fix type annotation in fused moe (vllm-project#3045)
WoosukKwon authored Feb 27, 2024
1 parent 67dd7f0 commit f7382f6
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -2,7 +2,7 @@
 import functools
 import json
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import triton
@@ -137,7 +137,7 @@ def fused_moe_kernel(
 
 def moe_align_block_size(
         topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
@@ -185,7 +185,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int, config: dict):
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any]) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
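
For context: the old annotation -> (torch.Tensor, torch.Tensor, torch.Tensor) is not a valid type hint; it merely evaluates to a tuple of type objects at runtime, and static checkers such as mypy reject it. Tuple[...] from typing is the correct spelling for a fixed-size tuple return type, and config: Dict[str, Any] with an explicit -> None likewise makes the helper's interface precise. Below is a minimal sketch of the corrected annotation pattern; the function body is a placeholder for illustration only, not the actual vLLM implementation:

    from typing import Tuple

    import torch


    def moe_align_block_size(
            topk_ids: torch.Tensor, block_size: int,
            num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Placeholder body (illustration only): the real function pads and sorts
        # token ids so each expert's tokens fill whole blocks of block_size.
        sorted_token_ids = torch.sort(topk_ids.flatten()).values
        expert_ids = torch.arange(num_experts, dtype=torch.int32)
        num_tokens_post_padded = torch.tensor([sorted_token_ids.numel()])
        return sorted_token_ids, expert_ids, num_tokens_post_padded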
