diff --git a/README.md b/README.md
index bc067b1..9dd6181 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,23 @@ Implementation of MoE Mamba from the paper: "MoE-Mamba: Efficient Selective Stat
 pip install moe-mamba
 ```
 
-# Usage
+## Usage
+
+### `MoEMambaBlock`
 
 ```python
-print("hello world")
+import torch
+from moe_mamba import MoEMambaBlock
+
+x = torch.randn(1, 10, 512)
+model = MoEMambaBlock(
+    dim=512,
+    depth=6,
+    d_state=128,
+    expand=4,
+    num_experts=4,
+)
+out = model(x)
+print(out)
 ```
 
diff --git a/moe_mamba/model.py b/moe_mamba/model.py
index 6000691..de0de35 100644
--- a/moe_mamba/model.py
+++ b/moe_mamba/model.py
@@ -4,6 +4,22 @@ from swarms_torch import SwitchMoE
 
 
 class MoEMambaBlock(nn.Module):
+    """
+    MoEMambaBlock is a module that combines MambaBlock and SwitchMoE layers.
+
+    Args:
+        dim (int): The input dimension.
+        depth (int): The number of MambaBlock layers.
+        d_state (int): The dimension of the state.
+        causal (bool, optional): Whether to use causal attention. Defaults to True.
+        dropout (float, optional): The dropout rate. Defaults to 0.1.
+        shared_qk (bool, optional): Whether to share the query and key projections. Defaults to True.
+        exact_window_size (bool, optional): Whether to use exact window size for attention. Defaults to False.
+        heads (int, optional): The number of attention heads. Defaults to None.
+        dim_head (int, optional): The dimension of each attention head. Defaults to None.
+        m_expand (int, optional): The expansion factor for the hidden dimension. Defaults to 4.
+        num_experts (int, optional): The number of experts in the SwitchMoE layer. Defaults to 4.
+    """
     def __init__(
         self,
         dim,
@@ -42,7 +58,6 @@ def __init__(
                     dim=dim,
                     depth=depth,
                     d_state=d_state,
-                    expand=m_expand,
                     *args,
                     **kwargs,
                 )
@@ -54,26 +69,25 @@ def __init__(
                     hidden_dim=self.hidden_dim,
                     output_dim=dim,
                     num_experts=num_experts,
-                    mult=m_expand,
                 )
             )
 
     def forward(self, x):
-        for attn, moe in zip(self.layers, self.ffn_layers):
-            x, _ = moe(x)
-            x = attn(x, x, x) + x
+        """
+        Forward pass of the MoEMambaBlock module.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor.
+        """
+        for mamba, moe in zip(self.layers, self.ffn_layers):
+            x = mamba(x)
             x, _ = moe(x)
         return x
 
-x = torch.randn(1, 10, 512)
-model = MoEMambaBlock(
-    dim=512,
-    depth=6,
-    d_state=128,
-    expand=4,
-    num_experts=4,
-)
-model(x).shape
+
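
For context on the `forward` rewrite above, here is a minimal, self-contained sketch of the new layer ordering: each Mamba layer's output feeds a SwitchMoE feedforward layer, replacing the earlier MoE-then-attention residual path. `ToyMambaBlock`, `ToyMoE`, and `ToyMoEMamba` are hypothetical stand-ins built from plain linear layers, not the real `MambaBlock` or the `SwitchMoE` imported from `swarms_torch`; only the loop structure mirrors the diff.

```python
import torch
from torch import nn


class ToyMambaBlock(nn.Module):
    """Hypothetical stand-in for a Mamba sequence-mixing layer."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class ToyMoE(nn.Module):
    """Hypothetical stand-in for SwitchMoE; returns (output, aux_output) as the diff assumes."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor):
        return self.ff(x), torch.tensor(0.0)


class ToyMoEMamba(nn.Module):
    """Mirrors the revised forward pass: Mamba layer, then MoE layer, repeated `depth` times."""

    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.layers = nn.ModuleList(ToyMambaBlock(dim) for _ in range(depth))
        self.ffn_layers = nn.ModuleList(ToyMoE(dim) for _ in range(depth))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for mamba, moe in zip(self.layers, self.ffn_layers):
            x = mamba(x)   # sequence-mixing step
            x, _ = moe(x)  # expert feedforward step; auxiliary output discarded, as in the diff
        return x


x = torch.randn(1, 10, 512)
print(ToyMoEMamba(dim=512, depth=6)(x).shape)  # torch.Size([1, 10, 512])
```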