Commit

[FEAT][SwitchMixtureOfExperts]
Kye committed Jan 22, 2024
1 parent e7241b9 commit 3231e04
Showing 3 changed files with 117 additions and 18 deletions.
6 changes: 3 additions & 3 deletions example.py
@@ -1,5 +1,5 @@
import torch
from moe_mamba.model import MoEMamba


# Create a tensor of shape (1, 1024, 512)
@@ -23,4 +23,4 @@
out = model(x)

# Print the shape of the output tensor
print(out)
97 changes: 97 additions & 0 deletions moe_mamba/block.py
@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(FeedForward, self).__init__()
self.network = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
)

def forward(self, x):
return self.network(x)


class SwitchMixtureOfExperts(nn.Module):
def __init__(
self,
input_dim,
hidden_dim,
expert_output_dim,
num_experts,
top_k=1,
):
super(SwitchMixtureOfExperts, self).__init__()
self.num_experts = num_experts
self.top_k = top_k

# Router: MLP to generate logits for expert selection
self.router = nn.Linear(input_dim, num_experts)

# Experts: a list of FeedForward networks
self.experts = nn.ModuleList(
[
FeedForward(input_dim, hidden_dim, expert_output_dim)
for _ in range(num_experts)
]
)

def forward(self, x):
batch_size, seq_len, input_dim = x.shape
x_flat = x.view(-1, input_dim) # Flatten to [B*SEQLEN, dim]

# Routing tokens to experts
router_logits = self.router(x_flat)
topk_logits, topk_indices = router_logits.topk(
self.top_k, dim=1
)
        topk_gates = F.softmax(
            topk_logits, dim=1
        )  # Normalize the top-k logits (with top_k == 1 this gate is always 1.0)

# Initializing the output
output_flat = torch.zeros(
batch_size * seq_len,
self.experts[0].network[-1].out_features,
device=x.device,
)

        # Distribute tokens to their selected experts and aggregate the results.
        # Each token is pushed through its expert individually, which is simple
        # but loops over every token in Python.
        for i in range(self.top_k):
            expert_index = topk_indices[:, i]
            gate_value = topk_gates[:, i].unsqueeze(1)

            expert_output = torch.stack(
                [
                    self.experts[idx](x_flat[n])
                    for n, idx in enumerate(expert_index)
                ]
            )

output_flat += gate_value * expert_output

        # Reshape back to [B, SEQLEN, expert_output_dim]
output = output_flat.view(batch_size, seq_len, -1)
return output


# Example Usage
batch_size = 32
seq_len = 10
input_dim = 512
hidden_dim = 2048
expert_output_dim = 1024
num_experts = 4
top_k = 1

moe = SwitchMixtureOfExperts(
input_dim, hidden_dim, expert_output_dim, num_experts, top_k
)
x = torch.rand(batch_size, seq_len, input_dim) # Example input tensor
output = moe(x)
print(output)
print(output.shape)
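
The per-token dispatch above is easy to follow but slow, since it calls an expert once for every token. Below is a minimal, illustrative sketch (not part of the commit) of a batched alternative; it assumes the FeedForward and SwitchMixtureOfExperts definitions from this file, plus the moe and x objects from the example usage, and the helper name vectorized_switch_forward is hypothetical. It also takes the softmax over all router logits before selecting the top-k, so the gate keeps a gradient path to the router even when top_k == 1.

def vectorized_switch_forward(moe: SwitchMixtureOfExperts, x: torch.Tensor) -> torch.Tensor:
    batch_size, seq_len, input_dim = x.shape
    x_flat = x.view(-1, input_dim)

    # Softmax over all experts, then keep the top-k probabilities per token.
    router_probs = F.softmax(moe.router(x_flat), dim=1)
    topk_gates, topk_indices = router_probs.topk(moe.top_k, dim=1)

    out_dim = moe.experts[0].network[-1].out_features
    output_flat = torch.zeros(
        x_flat.size(0), out_dim, device=x.device, dtype=x.dtype
    )

    for k in range(moe.top_k):
        idx = topk_indices[:, k]
        gate = topk_gates[:, k].unsqueeze(1)
        # Run each expert once on the whole batch of tokens routed to it.
        for e, expert in enumerate(moe.experts):
            token_mask = idx == e
            if token_mask.any():
                output_flat[token_mask] += gate[token_mask] * expert(x_flat[token_mask])

    return output_flat.view(batch_size, seq_len, -1)


# Should produce the same shape as the loop-based forward above.
print(vectorized_switch_forward(moe, x).shape)  # torch.Size([32, 10, 1024])
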
32 changes: 17 additions & 15 deletions moe_mamba/model.py
@@ -51,13 +51,17 @@ def forward(self, x: Tensor, use_aux_loss=False):
top_k_scores, top_k_indices = gate_scores.topk(1, dim=-1)

# Mask to enforce sparsity
mask = torch.zeros_like(gate_scores).scatter_(1, top_k_indices, 1)
mask = torch.zeros_like(gate_scores).scatter_(
1, top_k_indices, 1
)

# Combine gating scores with the mask
masked_gate_scores = gate_scores * mask

# Denominators
denominators = masked_gate_scores.sum(0, keepdim=True) + self.epsilon
denominators = (
masked_gate_scores.sum(0, keepdim=True) + self.epsilon
)

# Norm gate scores to sum to the capacity
gate_scores = (masked_gate_scores / denominators) * capacity
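
As a quick illustration of what this hunk computes, here is a small self-contained sketch (not from the commit) of the same top-1 masking and capacity normalization on a toy tensor. The capacity value of num_tokens / num_experts is an assumption, since its definition lies outside the visible hunk.

import torch

gate_scores = torch.softmax(torch.randn(6, 4), dim=-1)  # 6 tokens, 4 experts
top_k_scores, top_k_indices = gate_scores.topk(1, dim=-1)

# Keep only each token's top-1 expert score.
mask = torch.zeros_like(gate_scores).scatter_(1, top_k_indices, 1)
masked_gate_scores = gate_scores * mask

# Rescale so every selected expert's column sums to the assumed capacity.
capacity = gate_scores.size(0) / gate_scores.size(1)  # tokens per expert
denominators = masked_gate_scores.sum(0, keepdim=True) + 1e-6
gate_scores = (masked_gate_scores / denominators) * capacity

print(gate_scores.sum(0))  # ~capacity for selected experts, 0 for the rest
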
@@ -146,7 +150,9 @@ def forward(self, x: Tensor):
"""
# (batch_size, seq_len, num_experts)
gate_scores, loss = self.gate(x, use_aux_loss=self.use_aux_loss)
gate_scores, loss = self.gate(
x, use_aux_loss=self.use_aux_loss
)

# Dispatch to experts
expert_outputs = [expert(x) for expert in self.experts]
@@ -161,7 +167,9 @@ def forward(self, x: Tensor):
expert_outputs, dim=-1
) # (batch_size, seq_len, output_dim, num_experts)
if torch.isnan(stacked_expert_outputs).any():
stacked_expert_outputs[torch.isnan(stacked_expert_outputs)] = 0
stacked_expert_outputs[
torch.isnan(stacked_expert_outputs)
] = 0

# Combine expert outputs and gating scores
moe_output = torch.sum(
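
The page truncates this hunk in the middle of the combine step, so the exact expression is not visible here. The sketch below (not from the commit) shows a typical way gating scores of shape (batch_size, seq_len, num_experts) are combined with stacked expert outputs of shape (batch_size, seq_len, output_dim, num_experts); the broadcasting layout is an assumption.

import torch

batch_size, seq_len, output_dim, num_experts = 2, 5, 8, 4
gate_scores = torch.softmax(
    torch.randn(batch_size, seq_len, num_experts), dim=-1
)
stacked_expert_outputs = torch.randn(
    batch_size, seq_len, output_dim, num_experts
)

# Broadcast the per-expert gates across the feature dimension and sum the
# experts away, giving a gate-weighted mixture per token.
moe_output = torch.sum(
    gate_scores.unsqueeze(-2) * stacked_expert_outputs, dim=-1
)
print(moe_output.shape)  # torch.Size([2, 5, 8])
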
@@ -174,7 +182,7 @@
class MoEMambaBlock(nn.Module):
"""
MoEMambaBlock is a module that combines MambaBlock and SwitchMoE layers.
Args:
dim (int): The input dimension.
depth (int): The number of MambaBlock layers.
@@ -188,6 +196,7 @@ class MoEMambaBlock(nn.Module):
m_expand (int, optional): The expansion factor for the hidden dimension. Defaults to 4.
num_experts (int, optional): The number of experts in the SwitchMoE layer. Defaults to 4.
"""

def __init__(
self,
dim,
@@ -213,7 +222,7 @@ def __init__(
self.dim_head = dim_head
self.m_expand = m_expand
self.num_experts = num_experts

self.layers = nn.ModuleList([])
self.ffn_layers = nn.ModuleList([])
self.hidden_dim = dim * m_expand
@@ -241,10 +250,10 @@ def __init__(
def forward(self, x):
"""
Forward pass of the MoEMambaBlock module.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
"""
@@ -254,13 +263,6 @@ def forward(self, x):
return x









class MoEMamba(nn.Module):
"""
MoEMamba is a PyTorch module that implements the MoE-Mamba model.
