[Bugfix] Support 2D input shape in MoE layer (vllm-project#6287)
WoosukKwon authored Jul 10, 2024
1 parent 8a924d2 commit e72ae80
Showing 2 changed files with 7 additions and 4 deletions.
vllm/model_executor/models/mixtral.py (3 additions, 2 deletions)

@@ -88,12 +88,13 @@ def __init__(self,
                                 tp_size=tp_size)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(num_tokens, hidden_size)
+        return final_hidden_states.view(orig_shape)


 class MixtralAttention(nn.Module):
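The fix is a flatten/restore pattern: whatever shape the caller passes in, forward() views it as 2D (num_tokens, hidden_size) for the router and experts, then views the result back to the caller's original shape. The snippet below is a minimal standalone sketch of that pattern, not the vLLM code; moe_like_forward is a hypothetical name and the expert computation is replaced by a placeholder.

import torch

def moe_like_forward(hidden_states: torch.Tensor, hidden_size: int) -> torch.Tensor:
    # hidden_states may be 1D (hidden_size,) or 2D (num_tokens, hidden_size).
    orig_shape = hidden_states.shape
    x = hidden_states.view(-1, hidden_size)   # always 2D for the gate/experts
    out = x * 2.0                             # placeholder for router + experts
    return out.view(orig_shape)               # restore the caller's layout

# Both layouts round-trip unchanged (illustrative sizes):
assert moe_like_forward(torch.randn(4096), 4096).shape == (4096,)
assert moe_like_forward(torch.randn(8, 4096), 4096).shape == (8, 4096)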
vllm/model_executor/models/qwen2_moe.py (4 additions, 2 deletions)

@@ -126,7 +126,9 @@ def __init__(
                                      bias=False)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
         shared_output = None
         if self.shared_expert is not None:
@@ -145,7 +147,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states.view(orig_shape)


 class Qwen2MoeAttention(nn.Module):
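The qwen2_moe change differs only in how it obtains the hidden size: rather than unpacking a fixed (num_tokens, hidden_dim) pair, it reads the last dimension, which is the hidden size for both the 1D and the 2D layout. A quick illustrative check (sizes are made up):

import torch

# shape[-1] is the hidden size whether hidden_states is 1D or 2D.
assert torch.randn(4096).shape[-1] == 4096        # 1D: (hidden_size,)
assert torch.randn(8, 4096).shape[-1] == 4096     # 2D: (num_tokens, hidden_size)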
