[FEAT][MoEMambaBlock]
Kye committed Jan 22, 2024
1 parent 3a99c01 commit ec5673a
Showing 2 changed files with 44 additions and 16 deletions.
18 changes: 16 additions & 2 deletions README.md
@@ -12,9 +12,23 @@ Implementation of MoE Mamba from the paper: "MoE-Mamba: Efficient Selective Stat
pip install moe-mamba
```

# Usage
## Usage

### `MoEMambaBlock`
```python
print("hello world")
import torch
from moe_mamba import MoEMambaBlock

x = torch.randn(1, 10, 512)
model = MoEMambaBlock(
    dim=512,
    depth=6,
    d_state=128,
    expand=4,
    num_experts=4,
)
out = model(x)
print(out)

```
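The snippet above prints the full output tensor. As a supplementary sketch (not taken from the diff), the same call can be used for a quick shape check; this assumes `MoEMambaBlock` maps a `(batch, seq_len, dim)` input to an output of the same shape, since the `SwitchMoE` layers are constructed with `output_dim=dim`:

```python
# Illustrative shape check, reusing the constructor arguments from the example above.
import torch
from moe_mamba import MoEMambaBlock

x = torch.randn(1, 10, 512)
model = MoEMambaBlock(dim=512, depth=6, d_state=128, expand=4, num_experts=4)
out = model(x)
# Assumption: the output keeps the input's (batch, seq_len, dim) shape.
print(out.shape)  # expected: torch.Size([1, 10, 512])
```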

42 changes: 28 additions & 14 deletions moe_mamba/model.py
@@ -4,6 +4,22 @@
from swarms_torch import SwitchMoE

class MoEMambaBlock(nn.Module):
"""
MoEMambaBlock is a module that combines MambaBlock and SwitchMoE layers.
Args:
dim (int): The input dimension.
depth (int): The number of MambaBlock layers.
d_state (int): The dimension of the state.
causal (bool, optional): Whether to use causal attention. Defaults to True.
dropout (float, optional): The dropout rate. Defaults to 0.1.
shared_qk (bool, optional): Whether to share the query and key projections. Defaults to True.
exact_window_size (bool, optional): Whether to use exact window size for attention. Defaults to False.
heads (int, optional): The number of attention heads. Defaults to None.
dim_head (int, optional): The dimension of each attention head. Defaults to None.
m_expand (int, optional): The expansion factor for the hidden dimension. Defaults to 4.
num_experts (int, optional): The number of experts in the SwitchMoE layer. Defaults to 4.
"""
def __init__(
self,
dim,
@@ -42,7 +58,6 @@ def __init__(
                    dim=dim,
                    depth=depth,
                    d_state=d_state,
                    expand=m_expand,
                    *args,
                    **kwargs,
                )
@@ -54,26 +69,25 @@ def __init__(
                    hidden_dim=self.hidden_dim,
                    output_dim=dim,
                    num_experts=num_experts,
                    mult=m_expand,
                )
            )

    def forward(self, x):
        for attn, moe in zip(self.layers, self.ffn_layers):
            x, _ = moe(x)
            x = attn(x, x, x) + x
        """
        Forward pass of the MoEMambaBlock module.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The output tensor.
        """
        for mamba, moe in zip(self.layers, self.ffn_layers):
            x = mamba(x)
            x, _ = moe(x)
        return x

x = torch.randn(1, 10, 512)
model = MoEMambaBlock(
    dim=512,
    depth=6,
    d_state=128,
    expand=4,
    num_experts=4,
)
model(x).shape
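
For orientation, here is a self-contained sketch of the layer pattern that the rewritten `forward` implements: each depth step applies a sequence-mixing layer and then a mixture-of-experts layer. `ToyMoE` and `ToyInterleavedBlock` are hypothetical stand-ins (plain `nn.Linear` in place of `MambaBlock`, a toy two-value return in place of `SwitchMoE`); only the interleaving and the `x, _ = moe(x)` unpacking mirror the code above.

```python
import torch
from torch import nn


class ToyMoE(nn.Module):
    """Stand-in for SwitchMoE: returns (output, aux_loss), matching the `x, _ = moe(x)` call."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x), torch.tensor(0.0)


class ToyInterleavedBlock(nn.Module):
    """Mirrors the loop in MoEMambaBlock.forward: sequence-mixing layer, then MoE layer, per depth step."""

    def __init__(self, dim, depth):
        super().__init__()
        # nn.Linear is a placeholder for MambaBlock; only the interleaving pattern is the point here.
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
        self.ffn_layers = nn.ModuleList([ToyMoE(dim) for _ in range(depth)])

    def forward(self, x):
        for mamba, moe in zip(self.layers, self.ffn_layers):
            x = mamba(x)   # token/sequence mixing (MambaBlock in the real module)
            x, _ = moe(x)  # expert FFN; the second return value (aux loss) is discarded
        return x


x = torch.randn(1, 10, 512)
print(ToyInterleavedBlock(dim=512, depth=6)(x).shape)  # torch.Size([1, 10, 512])
```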




