From 3231e04e82503918c34431aa224e05e0b73a1483 Mon Sep 17 00:00:00 2001
From: Kye
Date: Sun, 21 Jan 2024 22:22:08 -0500
Subject: [PATCH] [FEAT][SwitchMixtureOfExperts]

---
 example.py         |  6 +--
 moe_mamba/block.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++
 moe_mamba/model.py | 32 ++++++++-------
 3 files changed, 117 insertions(+), 18 deletions(-)
 create mode 100644 moe_mamba/block.py

diff --git a/example.py b/example.py
index 6665b03..233fa36 100644
--- a/example.py
+++ b/example.py
@@ -1,5 +1,5 @@
-import torch
-from moe_mamba.model import MoEMamba
+import torch
+from moe_mamba.model import MoEMamba
 
 
 # Create a tensor of shape (1, 1024, 512)
@@ -23,4 +23,4 @@ out = model(x)
 
 
 # Print the shape of the output tensor
-print(out)
\ No newline at end of file
+print(out)
diff --git a/moe_mamba/block.py b/moe_mamba/block.py
new file mode 100644
index 0000000..670450a
--- /dev/null
+++ b/moe_mamba/block.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FeedForward(nn.Module):
+    def __init__(self, input_dim, hidden_dim, output_dim):
+        super(FeedForward, self).__init__()
+        self.network = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, output_dim),
+        )
+
+    def forward(self, x):
+        return self.network(x)
+
+
+class SwitchMixtureOfExperts(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        hidden_dim,
+        expert_output_dim,
+        num_experts,
+        top_k=1,
+    ):
+        super(SwitchMixtureOfExperts, self).__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+
+        # Router: MLP to generate logits for expert selection
+        self.router = nn.Linear(input_dim, num_experts)
+
+        # Experts: a list of FeedForward networks
+        self.experts = nn.ModuleList(
+            [
+                FeedForward(input_dim, hidden_dim, expert_output_dim)
+                for _ in range(num_experts)
+            ]
+        )
+
+    def forward(self, x):
+        batch_size, seq_len, input_dim = x.shape
+        x_flat = x.view(-1, input_dim)  # Flatten to [B*SEQLEN, dim]
+
+        # Routing tokens to experts
+        router_logits = self.router(x_flat)
+        topk_logits, topk_indices = router_logits.topk(
+            self.top_k, dim=1
+        )
+        topk_gates = F.softmax(
+            topk_logits, dim=1
+        )  # Normalizing the top-k logits
+
+        # Initializing the output
+        output_flat = torch.zeros(
+            batch_size * seq_len,
+            self.experts[0].network[-1].out_features,
+            device=x.device,
+        )
+
+        # Distributing tokens to the experts and aggregating the results
+        for i in range(self.top_k):
+            expert_index = topk_indices[:, i]
+            gate_value = topk_gates[:, i].unsqueeze(1)
+
+            expert_output = torch.stack(
+                [
+                    self.experts[idx](x_flat[n])
+                    for n, idx in enumerate(expert_index)
+                ]
+            )
+
+            output_flat += gate_value * expert_output
+
+        # Reshape the output to the original input shape [B, SEQLEN, expert_output_dim]
+        output = output_flat.view(batch_size, seq_len, -1)
+        return output
+
+
+# Example Usage
+batch_size = 32
+seq_len = 10
+input_dim = 512
+hidden_dim = 2048
+expert_output_dim = 1024
+num_experts = 4
+top_k = 1
+
+moe = SwitchMixtureOfExperts(
+    input_dim, hidden_dim, expert_output_dim, num_experts, top_k
+)
+x = torch.rand(batch_size, seq_len, input_dim)  # Example input tensor
+output = moe(x)
+print(output)
+print(output.shape)
\ No newline at end of file
diff --git a/moe_mamba/model.py b/moe_mamba/model.py
index e266f4e..10b24d0 100644
--- a/moe_mamba/model.py
+++ b/moe_mamba/model.py
@@ -51,13 +51,17 @@ def forward(self, x: Tensor, use_aux_loss=False):
         top_k_scores, top_k_indices = gate_scores.topk(1, dim=-1)
 
         # Mask to enforce sparsity
-        mask = torch.zeros_like(gate_scores).scatter_(1, top_k_indices, 1)
+        mask = torch.zeros_like(gate_scores).scatter_(
+            1, top_k_indices, 1
+        )
 
         # Combine gating scores with the mask
         masked_gate_scores = gate_scores * mask
 
         # Denominators
-        denominators = masked_gate_scores.sum(0, keepdim=True) + self.epsilon
+        denominators = (
+            masked_gate_scores.sum(0, keepdim=True) + self.epsilon
+        )
 
         # Norm gate scores to sum to the capacity
         gate_scores = (masked_gate_scores / denominators) * capacity
@@ -146,7 +150,9 @@ def forward(self, x: Tensor):
 
         """
         # (batch_size, seq_len, num_experts)
-        gate_scores, loss = self.gate(x, use_aux_loss=self.use_aux_loss)
+        gate_scores, loss = self.gate(
+            x, use_aux_loss=self.use_aux_loss
+        )
 
         # Dispatch to experts
         expert_outputs = [expert(x) for expert in self.experts]
@@ -161,7 +167,9 @@ def forward(self, x: Tensor):
             expert_outputs, dim=-1
         )  # (batch_size, seq_len, output_dim, num_experts)
         if torch.isnan(stacked_expert_outputs).any():
-            stacked_expert_outputs[torch.isnan(stacked_expert_outputs)] = 0
+            stacked_expert_outputs[
+                torch.isnan(stacked_expert_outputs)
+            ] = 0
 
         # Combine expert outputs and gating scores
         moe_output = torch.sum(
@@ -174,7 +182,7 @@ class MoEMambaBlock(nn.Module):
 class MoEMambaBlock(nn.Module):
     """
     MoEMambaBlock is a module that combines MambaBlock and SwitchMoE layers.
-    
+
     Args:
         dim (int): The input dimension.
         depth (int): The number of MambaBlock layers.
@@ -188,6 +196,7 @@ class MoEMambaBlock(nn.Module):
         m_expand (int, optional): The expansion factor for the hidden dimension. Defaults to 4.
         num_experts (int, optional): The number of experts in the SwitchMoE layer. Defaults to 4.
     """
+
     def __init__(
         self,
         dim,
@@ -213,7 +222,7 @@ def __init__(
         self.dim_head = dim_head
         self.m_expand = m_expand
         self.num_experts = num_experts
-        
+
         self.layers = nn.ModuleList([])
         self.ffn_layers = nn.ModuleList([])
         self.hidden_dim = dim * m_expand
@@ -241,10 +250,10 @@ def __init__(
     def forward(self, x):
        """
         Forward pass of the MoEMambaBlock module.
-        
+
         Args:
             x (torch.Tensor): The input tensor.
-        
+
         Returns:
             torch.Tensor: The output tensor.
         """
@@ -254,13 +263,6 @@ def forward(self, x):
         return x
 
 
-
-
-
-
-
-
-
 class MoEMamba(nn.Module):
     """
     MoEMamba is a PyTorch module that implements the MoE-Mamba model.