Add Mamba (minimal) #918

Closed
wants to merge 27 commits into from

Conversation

swfsql (Contributor) commented Feb 2, 2024

Ports a minimal (non-optimized) implementation of Mamba (submitted 2023-12), which is closely related to S4 (submitted 2021-10).

In short, Mamba is an alternative (with trade-offs) to the attention mechanism. Mamba can be used in RNNs that step over a single sequence point at a time (instead of needing to observe multiple sequence points at once, although the previous state must be carried over), so its memory and time requirements are fixed for each sequence point.
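A minimal plain-Rust sketch of that fixed-cost recurrence (illustrative only; the names and the scalar layout are assumptions, not the dfdx types used in this PR):

```rust
// Conceptual selective-SSM step: the hidden state has a fixed size, so each new
// sequence point costs the same amount of memory and time.
fn ssm_step(h: &mut [f32], x: f32, a_bar: &[f32], b_bar: &[f32], c: &[f32]) -> f32 {
    let mut y = 0.0;
    for n in 0..h.len() {
        h[n] = a_bar[n] * h[n] + b_bar[n] * x; // discretized state update
        y += c[n] * h[n];                      // readout
    }
    y
}

fn main() {
    let (a_bar, b_bar, c) = (vec![0.9f32; 4], vec![0.1f32; 4], vec![1.0f32; 4]);
    let mut h = vec![0.0f32; 4]; // state cache carried across the sequence
    for &x in &[0.5f32, -1.0, 2.0] {
        println!("y = {}", ssm_step(&mut h, x, &a_bar, &b_bar, &c));
    }
}
```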

Implementation references:

This PR requires others (some of which are drafts or are useful for an app using this Module):

The commits specific to this Mamba PR are:

Tasks

  • Add an example that loads the hf/state-spaces models and generates token predictions, in a similar vein to the candle example.
  • Add initialization (how to determine the default initial values for a random model).
  • Be able to train (forward_mut and backpropagation). Note: This stateless interface is appropriate for training only, not for inference.
  • Add an extra stateful version for the calls, requiring the state cache alongside the usual input.
    • This requires constant memory and time for each new token prediction. On my home computer, the mamba-130m f32 model generates ~23 tokens/s on the CPU.
    • Add it for forward_mut (training). Note: this stateful interface is appropriate for inference only, not for training.
  • Find a way to avoid the Vec conversion near the end of selective_scan for the stateless version.
    • Add a softplus tensor operation that takes the threshold into account to avoid loss of precision. This could be an inline function; it does not need to be an actual new operator (see the sketch after this list).
  • Support Cuda (indirectly through the layers used, not intended to be an optimized/fused Cuda kernel).
  • Add tests.
  • Test an optimization.
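For the softplus item above, a minimal sketch of what a threshold-aware inline helper could look like (the function name and the default threshold value are assumptions, following the common convention of returning `x` directly for large inputs to avoid `exp` overflow):

```rust
/// Numerically-stable softplus: ln(1 + e^x), except that for x above `threshold`
/// the result is effectively x itself, so x is returned and the exp is skipped.
#[inline]
fn softplus(x: f32, threshold: f32) -> f32 {
    if x > threshold {
        x
    } else {
        x.exp().ln_1p()
    }
}

fn main() {
    assert!((softplus(0.0, 20.0) - std::f32::consts::LN_2).abs() < 1e-6);
    assert_eq!(softplus(50.0, 20.0), 50.0); // large inputs pass through unchanged
}
```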

YouTube Videos

S4
Mamba

rainiwu and others added 14 commits January 26, 2024 00:29
- Makes the safetensors module private.
  - It is not exported in the preamble, avoiding a naming clash with the external safetensors crate.
- Change how and when the period is inserted.
  - This should make it closer to how the fields are accessed in the code.
- Add the try_normalize_rms related functions.
- Add the `LayerRMSNorm1D` module.
- Add `TrySplitShapeAlong` and `TrySplitTensorAlong`.
- Minor linting and docs fix.
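A plain-Rust sketch of the RMS-norm computation that `try_normalize_rms` / `LayerRMSNorm1D` are named after (the actual dfdx signatures are not shown here; this is illustrative only):

```rust
// y[i] = x[i] / sqrt(mean(x^2) + eps) * gamma[i]
fn rms_norm(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(gamma).map(|(v, g)| v * inv_rms * g).collect()
}

fn main() {
    let y = rms_norm(&[1.0, 2.0, 3.0], &[1.0, 1.0, 1.0], 1e-5);
    println!("{y:?}");
}
```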

TODO
- Check if the tape should be returned. If not, it can be removed from the interface.
- Add cuda kernel.
- Consider a different interface where the input could be split into more than two tensors, possibly returned in a Vec.
  That way it would be closer to the PyTorch interface (chunk); see the sketch below.
- Also added `from_fn` for Arrays.

Note: the interface currently requires two passes for construction, one for creating a list of tensors with NoneTape and another for putting tapes into those tensors.
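A plain-Rust sketch of the `chunk`-like semantics that TODO refers to (not the dfdx `TrySplitTensorAlong` API; the names and the 1-D layout here are assumptions):

```rust
// Split a 1-D buffer into `n` nearly-equal pieces, returned in a Vec,
// rounding the piece size up as PyTorch's `chunk` does.
fn chunk(data: &[f32], n: usize) -> Vec<Vec<f32>> {
    let size = data.len().div_ceil(n);
    data.chunks(size).map(|c| c.to_vec()).collect()
}

fn main() {
    let pieces = chunk(&[1.0, 2.0, 3.0, 4.0, 5.0], 2);
    assert_eq!(pieces, vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0]]);
}
```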
swfsql force-pushed the mamba-minimal branch 2 times, most recently from cadf65c to 9a2cf25 on February 7, 2024 22:02
This alternative method:
- Requires load/read to decide whether it should skip missing tensors;
- Requires load/read/save/write to decide how keys should be mapped (see the sketch below).
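A hypothetical sketch of the key-mapping decision (the function name, signature, and key patterns are assumptions for illustration, not the PR's actual loader interface):

```rust
// Maps an in-memory field path to the on-disk safetensors key; returning None
// marks the tensor as one the loader is allowed to skip if the file lacks it.
fn map_key(field_path: &str) -> Option<String> {
    match field_path {
        p if p.ends_with("optional.bias") => None, // placeholder for a skippable tensor
        p => Some(format!("backbone.{p}")),        // placeholder prefix rewrite
    }
}

fn main() {
    assert_eq!(map_key("layer.weight"), Some("backbone.layer.weight".to_string()));
    assert_eq!(map_key("some.optional.bias"), None);
}
```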
swfsql force-pushed the mamba-minimal branch 2 times, most recently from 1207867 to ce6d624 on February 9, 2024 17:27
swfsql force-pushed the mamba-minimal branch 2 times, most recently from 3f392a6 to 165abc9 on February 9, 2024 17:55
- Add stateless forward impl.
  - Efficient for training (but training is not yet implemented).
  - Input requires the entire sequence, and requires no state cache.
  - Generates one output for each input sequence point.
- Add stateful forward impl.
  - Efficient for inference.
  - Input requires the last single sequence point, and requires the last state cache.
  - Generates a single output referring to the last input.
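Illustrative signatures only, contrasting the two impls above (the actual dfdx Module traits, tensor types, and method names in this PR differ; everything here is an assumption):

```rust
struct Mamba;
struct MambaState; // fixed-size cache carried between stateful calls

impl Mamba {
    /// Stateless: takes the whole sequence, yields one output per sequence point.
    fn forward_sequence(&self, seq: &[Vec<f32>]) -> Vec<Vec<f32>> {
        seq.to_vec() // placeholder body
    }

    /// Stateful: takes only the newest sequence point plus the previous cache,
    /// yields a single output and the updated cache.
    fn forward_step(&self, x: Vec<f32>, state: MambaState) -> (Vec<f32>, MambaState) {
        (x, state) // placeholder body
    }
}

fn main() {
    let m = Mamba;
    let _all_outputs = m.forward_sequence(&[vec![0.0f32; 4], vec![1.0; 4]]);
    let (_y, _state) = m.forward_step(vec![2.0f32; 4], MambaState);
}
```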
swfsql (Contributor, Author) commented Mar 1, 2024

I'll prioritize moving this experiment to a separate crate, but feel free to ping me in case anyone has a question or suggestion.
Edit: moved to here.

swfsql closed this Mar 1, 2024