Several works, such as Hydra and MambaMixer, have formulated bidirectionality through quasiseparable matrices. I highly recommend reading both papers to understand how bi-directionality can be added to Mamba. Unfortunately, neither implementation ships an optimized kernel, which often increases training and inference time by more than 2x.
To overcome this issue, I wrote the following GPU kernel, which reduces both the memory overhead and the latency. It does so by fusing kernels to minimize the number of loads and stores to and from global memory.
The idea of bi-directionality is to formulate the "Attention Matrix" as a quasiseparable matrix, meaning that the matrix can be decomposed into two semiseparable matrices and a diagonal matrix. The formulation is still subquadratic, as both semiseparable matrices and the diagonal matrix can be computed in linear time. Hydra formulates the quasiseparable matrix in the following format:
This kernel formulates the quasiseparable matrix as follows:
Why? The main reason is simplicity. The shift operation adds a lot of complexity to the kernel, and shifting in SRAM is not currently supported by Triton. As I don't want to rewrite the entire kernel in CUDA, I compromise with the above formulation.
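As a rough sketch in my own notation (this just restates the decomposition described above, not the exact equations from either paper), the bidirectional "attention matrix" $M$ splits as

```math
M = \underbrace{L}_{\text{lower-triangular semiseparable}} + \underbrace{D}_{\text{diagonal}} + \underbrace{U}_{\text{upper-triangular semiseparable}}
```

where $L$ is produced by a forward (causal) scan, $U$ by a scan over the flipped sequence, and $D$ by an elementwise multiply, so $Mx$ can be evaluated in time linear in the sequence length.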
To access the kernels, run:
pip install -e .
You can access the normal ssd kernels through `ssd.uni` and the bidirectional kernels through `ssd.bi`.
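For example (a quick import check; `ssd.bi.ssd_combined` is documented in the next section, while treating `ssd.uni` as an importable package is an assumption on my part):

```python
# Bidirectional kernels (functional entry point, documented below).
from ssd.bi.ssd_combined import bimamba_chunk_scan_combined

# Normal (causal) ssd kernels live under ssd.uni.
import ssd.uni
```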
Coming soon.
There are both functional and layer-wise ways to access the bi-directional kernel. I have outlined both below:
Currently, functional access to the bi-directional kernel can be done using the following import:
from ssd.bi.ssd_combined import bimamba_chunk_scan_combined
# The Doc string of bimamba_chunk_scan_combined
def bimamba_chunk_scan_combined(...) -> torch.Tensor:
"""
Argument:
x: (batch, seqlen, nheads, headdim)
dt: (batch, seqlen, nheads)
A: (nheads)
B: (batch, seqlen, ngroups, dstate)
C: (batch, seqlen, ngroups, dstate)
chunk_size: int
D: (nheads, headdim) or (nheads,)
z: (batch, seqlen, nheads, headdim)
dt_bias: (nheads,)
dt_softplus: Whether to apply softplus to dt
Return:
out: (batch, seqlen, nheads, headdim)
"""
...
Note: Currently, using `seq_idx` as in causal Mamba2 is unsupported. Additionally, passing `init_hidden_states` is also unsupported.
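For reference, here is a minimal usage sketch with toy shapes matching the docstring above (the sizes, `chunk_size`, and random initializations below are arbitrary choices of mine, not requirements of the kernel):

```python
import torch

from ssd.bi.ssd_combined import bimamba_chunk_scan_combined

batch, seqlen, nheads, headdim = 2, 256, 4, 64
ngroups, dstate, chunk_size = 1, 64, 64

x = torch.randn(batch, seqlen, nheads, headdim, device="cuda")
dt = torch.rand(batch, seqlen, nheads, device="cuda")           # step sizes (positive)
A = -torch.rand(nheads, device="cuda")                          # negative decay, as in Mamba2
B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda")
C = torch.randn(batch, seqlen, ngroups, dstate, device="cuda")
D = torch.randn(nheads, device="cuda")                          # skip connection term

out = bimamba_chunk_scan_combined(
    x, dt, A, B, C,
    chunk_size=chunk_size,
    D=D,
    dt_softplus=True,
)
assert out.shape == (batch, seqlen, nheads, headdim)
```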
Alternatively, you can access it through a Module API, which is similar to a Mamba2 layer:
Bi-Directional Kernel
import torch
from ssd.modules import BiMamba2
batch, length, dim = 2, 64, 32
x = torch.randn(batch, length, dim).to("cuda")
model = BiMamba2(
    d_model=dim,            # Model dimension d_model
    d_state=64,             # SSM state expansion factor
    d_conv=7,               # Local non-causal convolution width
    expand=2,               # Block expansion factor
    use_mem_eff_path=False, # The memory-efficient path is not supported yet
).to("cuda")
y = model(x)
assert y.shape == x.shape
g = torch.randn_like(y)
y.backward(g)
- Write FWD Implementation
- Debug and Test FWD implementation
- Write BWD Implementation
- Debug and Test BWD Implementation
- Create PyPi Package
- Add more benchmarks
The benchmarking code can be found in the `benchmark` folder. It can be run using the following command:
python benchmark/benchmark_fwd_all.py
To find additional benchmarks, please check out BENCHMARKS.md.
Bi-Mamba2 is roughly 3x-4x faster than naively flipping and accumulating with the causal Mamba2 kernel.
Here is a comparison of the fwd pass of Bi-Mamba2 vs. Naively Flipping Mamba2 vs. Causal Mamba2.
Here is a comparison of the bwd pass of Bi-Mamba2 vs. Naively Flipping Mamba2 vs. Causal Mamba2.
Memory benchmarks coming soon.
I created a fairly thorough test suite to ensure that Bi-Mamba2 is correct. To run a test, simply use pytest along with the specific test file. For example, to run a test for the fwd pass of the kernel, use:
python -m pytest -x -s -v tests/test_fwd_scan.py::TestFwd
If you find a bug, please tell me, and I'll fix it as fast as I can.
If you find this kernel useful, please cite Mamba, Hydra, and MambaMixer (they are amazing works!).
Also, give this repo a star :)
This library uses Mamba2's Triton kernels as a starting point. The kernels change a significant amount to support bi-directionality; however, the underlying algorithm and ideas are still Albert Gu's and Tri Dao's.