diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py
index 6165225d2d819..f22bcf97365d2 100644
--- a/tests/kernels/test_moe.py
+++ b/tests/kernels/test_moe.py
@@ -81,11 +81,13 @@ def test_mixtral_moe(dtype: torch.dtype):
         vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data

     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
-    inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
+    hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
+    # vLLM uses 1D query [num_tokens, hidden_dim]
+    vllm_inputs = hf_inputs.flatten(0, 1)

     # Run forward passes for both MoE blocks
-    hf_states, _ = hf_moe.forward(inputs)
-    vllm_states = vllm_moe.forward(inputs)
+    hf_states, _ = hf_moe.forward(hf_inputs)
+    vllm_states = vllm_moe.forward(vllm_inputs)

     mixtral_moe_tol = {
         torch.float32: 1e-3,
@@ -93,7 +95,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         torch.bfloat16: 1e-2,
     }

-    assert torch.allclose(hf_states,
+    assert torch.allclose(hf_states.flatten(0, 1),
                           vllm_states,
                           rtol=mixtral_moe_tol[dtype],
                           atol=mixtral_moe_tol[dtype])
diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py
index eff93e706f5dc..08c851f85c17b 100644
--- a/vllm/model_executor/models/deepseek.py
+++ b/vllm/model_executor/models/deepseek.py
@@ -150,11 +150,11 @@ def pack_params(self):
         self.w2 = self.w2.view(len(w2), *w2s[0].shape)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         if self.config.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
                                         self.w1,
@@ -169,8 +169,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

-        return final_hidden_states.view(batch_size, sequence_length,
-                                        hidden_dim)
+        return final_hidden_states.view(num_tokens, hidden_dim)


 class DeepseekAttention(nn.Module):
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 68a3a298444ae..f0138b6f9b1db 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -124,9 +124,9 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
             param_data[expert_id, :, :] = loaded_weight[:, shard]

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_size = hidden_states.shape
+        num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
                                         self.ws,
@@ -140,8 +140,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)

-        return final_hidden_states.view(batch_size, sequence_length,
-                                        hidden_size)
+        return final_hidden_states.view(num_tokens, hidden_size)


 class MixtralAttention(nn.Module):
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index b4dfc439d50e9..b8d6b45a36dd6 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -132,9 +132,9 @@ def __init__(
                                      linear_method=None)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (batch * sequence_length, n_experts)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
@@ -158,7 +158,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             final_hidden_states.add_(current_hidden_states)

         return tensor_model_parallel_all_reduce(final_hidden_states).view(
-            batch_size, sequence_length, hidden_dim)
+            num_tokens, hidden_dim)


 class MixtralAttention(nn.Module):
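
Note (not part of the diff): the change above switches the MoE forward paths from a 3D [batch_size, seq_len, hidden_dim] input to a flattened 2D [num_tokens, hidden_dim] input, and the test flattens the HuggingFace reference output before comparing. The sketch below is a minimal, standalone illustration of that flatten-then-compare pattern using plain PyTorch; the router_and_mix function is a hypothetical stand-in for the fused MoE path, not a vLLM API.

import torch

def router_and_mix(hidden_states: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the fused MoE call: expects a 2D
    # (num_tokens, hidden_dim) tensor; identity mixing for brevity.
    assert hidden_states.dim() == 2
    return hidden_states

batch_size, seq_len, hidden_dim = 1, 64, 32
hf_style = torch.randn(batch_size, seq_len, hidden_dim)  # 3D, HF convention
vllm_style = hf_style.flatten(0, 1)                      # 2D, (num_tokens, hidden_dim)

out_2d = router_and_mix(vllm_style)
# Flatten the 3D reference the same way before comparing, mirroring the
# updated assertion in tests/kernels/test_moe.py.
assert torch.allclose(hf_style.flatten(0, 1), out_2d)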