Simple, minimal implementation of Mamba in one file of PyTorch.
Featuring:
- Numerical output equivalent to the official implementation for both the forward and backward passes
- Simplified, readable, annotated code
Does NOT include:
- Speed. The official implementation is heavily optimized, and these optimizations are core contributions of the Mamba paper. I kept most of the code simple for readability.
- Proper parameter initialization (though this could be added without sacrificing readability)
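As a rough sketch of what adding initialization might look like, the snippet below follows our understanding of the scheme used in the official code: an S4D-Real style initialization for `A` (stored as `A_log` so that `A = -exp(A_log)` stays negative) and a log-uniform sample of the step size `dt`, written into a bias through the inverse of softplus. The function names and the defaults `dt_min=1e-3`, `dt_max=1e-1`, `dt_floor=1e-4` are assumptions here; check them against the official repository before relying on them.

```python
import math

import torch


def init_A_log(d_inner: int, d_state: int) -> torch.Tensor:
    # Hypothetical helper (sketch). S4D-Real style: row d of A is -(1, 2, ..., d_state).
    # We store log(1..d_state) so that A = -exp(A_log) is always negative.
    A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
    return torch.log(A)  # shape (d_inner, d_state)


def init_dt_bias(d_inner: int, dt_min: float = 1e-3, dt_max: float = 1e-1,
                 dt_floor: float = 1e-4) -> torch.Tensor:
    # Hypothetical helper (sketch); defaults are assumptions.
    # Sample dt log-uniformly in [dt_min, dt_max], one value per channel.
    log_dt = torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
    dt = torch.exp(log_dt).clamp(min=dt_floor)
    # Inverse softplus: softplus(inv_dt) = log(1 + exp(inv_dt)) == dt.
    # log(exp(dt) - 1) is rewritten as dt + log(1 - exp(-dt)) for numerical stability.
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    return inv_dt  # would be copied into the bias of the dt projection
```

Storing `dt` through its inverse softplus means that, at initialization, applying softplus to the projection's bias recovers exactly the sampled step sizes.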
See demo.ipynb for examples of prompt completions.
```python
from model import Mamba
from transformers import AutoTokenizer

model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

# generate() is not part of model.py; it is a helper defined in demo.ipynb
generate(model, tokenizer, 'Mamba is the')
```
> Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)
150 meters... 🫢 scary!
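For readers curious about what a generation helper involves, here is a minimal greedy-decoding sketch. The name `greedy_generate` is hypothetical, and the only assumption about the model is that `model(input_ids)` returns logits of shape `(batch, seq_len, vocab_size)`. The notebook's own helper may differ, for example by sampling instead of taking the argmax.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def greedy_generate(model: nn.Module, input_ids: torch.Tensor, n_tokens: int) -> torch.Tensor:
    # Hypothetical helper (sketch): greedy decoding.
    # Assumes model(input_ids) -> logits of shape (batch, seq_len, vocab_size).
    for _ in range(n_tokens):
        logits = model(input_ids)                                  # (b, l, vocab)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)       # most likely next token, (b, 1)
        input_ids = torch.cat([input_ids, next_id], dim=1)         # append and feed back in
    return input_ids
```

This sketch re-runs the full prefix at every step rather than carrying a recurrent state forward, which keeps it short but makes each step cost grow with sequence length. With the model and tokenizer loaded as above, one would pass `tokenizer('Mamba is the', return_tensors='pt').input_ids` and turn the result back into text with `tokenizer.decode(out[0])`.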
The Mamba architecture was introduced in "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" by Albert Gu and Tri Dao.
The official implementation is here: https://github.com/state-spaces/mamba/tree/main
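At the heart of the architecture is the selective state-space recurrence. The step size `delta` and the matrices `B` and `C` depend on the input, `A` is discretized with zero-order hold (`A_bar = exp(delta * A)`), and `B` uses a simpler Euler-style step (`B_bar = delta * B`). The function below is a plain, sequential PyTorch sketch of that scan; its name and the tensor shapes in the docstring are assumptions for illustration, not the official kernel.

```python
import torch


def selective_scan_sequential(u, delta, A, B, C, D):
    """Sequential selective scan (illustrative sketch).

    Assumed shapes: u, delta: (b, l, d_in); A: (d_in, n); B, C: (b, l, n); D: (d_in,)
    """
    b, l, d_in = u.shape
    n = A.shape[1]
    # Zero-order-hold discretization of A: A_bar = exp(delta * A)
    deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
    # Euler-style discretization of B, already multiplied by the input: delta * B * u
    deltaB_u = torch.einsum('bld,bln,bld->bldn', delta, B, u)
    # One n-dimensional hidden state per input channel
    h = torch.zeros(b, d_in, n, dtype=u.dtype, device=u.device)
    ys = []
    for t in range(l):
        # Recurrence: h_t = A_bar_t * h_{t-1} + B_bar_t * u_t
        h = deltaA[:, t] * h + deltaB_u[:, t]
        # Readout: y_t = C_t h_t, contracting over the state dimension
        ys.append(torch.einsum('bdn,bn->bd', h, C[:, t]))
    y = torch.stack(ys, dim=1)  # (b, l, d_in)
    # Skip connection weighted by D
    return y + u * D
```

The Python loop over the sequence is what makes this version slow; the paper's hardware-aware algorithm computes the same recurrence with a fused parallel scan instead.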