ggml : add GPU support for Mamba models #6758
I tried to add both of these operations (in a very naive form: just the boilerplate and a copy/paste of the CPU implementation, replacing memcpy with for loops, and using 1 block with WARP_SIZE==32 threads) here: While it works in the sense that the implementation generates plausible-looking output, that output differs from the CPU version's. The initial prompt seems to be processed incorrectly: the first generated token already looks off, and the generation is not influenced by the content of the prompt. Furthermore, changing the batch size (the -b parameter) affects the generation (it shouldn't). So maybe something else needs to be improved besides the new ops implementation. Sample correct output from the CPU version:
Sample incorrect output from the GPU version:
The GPU output looks especially mangled in the example above, but with a single-token prompt consisting of a space it almost looks "normal" for this model, which makes me suspect that maybe it's something unrelated to the new CUDA kernels:
(The same also happens when I reduce the CUDA kernels to use just a single thread rather than 32.)
Another observation: with one offloaded layer (-ngl 1) I get the same output as for -ngl 0 (= CPU). But as the number of offloaded layers is increased, the output changes each time.
You can add tests for the ops in tests/test-backend-ops.cpp.
Thanks. I added a test for each of the new ops here: 35f2f86 The CPU vs. GPU comparison ("test-backend-ops test -o SSM_CONV", "test-backend-ops test -o SSM_SCAN") passes, maybe because of non-representative input. I set "sq", whatever it means, to zero because initializing it to any other value triggered an assertion on the CPU backend. (I should mention that I have little clue about how the Mamba algorithm works in theory or practice; I just attempted a blind port, FWIW.)
@jploski Thank you for doing this!
Zero is correct, but note that this is being changed in #7531. Lines 16349 to 16418 in 61200ef
Hmm. If both result in the same output with random input (for all tensors apart from sq), I wonder if using ctx_layer instead of ctx_split makes a difference:

diff --git a/llama.cpp b/llama.cpp
index 841be1de..0c710f8d 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6099,7 +6099,7 @@ static bool llm_load_tensors(
layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
- layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
+ layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
@@ -6108,7 +6108,7 @@ static bool llm_load_tensors(
layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
// no "weight" suffix for these
- layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
+ layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});
        // out_proj

I have no idea what exactly this would do, but I guess it's worth a try?
Replacing ctx_split by ctx_layer did not change anything. Maybe it's best to hold off on further tests until #7531 is merged (there's a slight hope that the observed issue might go away...).
So I was impatient and applied my changes to #7531 here: The result is as before. The CPU output is the same as before (meaning that #7531 did not break anything). The CPU vs. GPU backend tests for ssm_conv and ssm_scan with random data pass, but the actual generation derails (even faster than before, namely already with -ngl 1). The differences feel like there's some unintended "latent space exploration" going on, and with more GPU layers it gets worse.

CPU: Sara and Ben are playing in the snow. They make a big snowman with a hat, a scarf and a carrot nose. They put on their warm clothes and go inside to get some cold chocolate.

GPU (-ngl 1): Sara and Ben are playing in the snow. They make a big snowman with a hat, a scarf and a carrot nose. They put on their warm clothes and go inside to have some hot chocolate.
Btw, running perplexity with CUDA enabled produces broken results:

LLAMA_CUDA=1 make -j && ./perplexity -m models/mamba-130m/ggml-model-f16.gguf -f build-cublas/wikitext-2-raw/wiki.test.raw -ngl 0
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes
llm_load_tensors: ggml ctx size = 0.12 MiB
llm_load_tensors: offloading 0 repeating layers to GPU
llm_load_tensors: offloaded 0/25 layers to GPU
llm_load_tensors: CPU buffer size = 256.96 MiB
.................................................
llama_new_context_with_model: n_ctx = 2048
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA_Host KV buffer size = 10.69 MiB
llama_new_context_with_model: KV self size = 10.69 MiB, K (f32): 1.69 MiB, V (f32): 9.00 MiB
llama_new_context_with_model: CUDA_Host output buffer size = 0.77 MiB
llama_new_context_with_model: CUDA0 compute buffer size = 184.98 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 10.07 MiB
llama_new_context_with_model: graph nodes = 896
llama_new_context_with_model: graph splits = 292
system_info: n_threads = 16 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
perplexity: tokenizing the input ..
perplexity: tokenization took 807.662 ms
perplexity: calculating perplexity over 650 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 1.21 seconds per pass - ETA 3.27 minutes
[1]54658547752546633926049792.0000,[2]179461831369242333449027584.0000,[3]31554836485675804180611072.0000,[4]6706279437817913437847552.0000

If I build without CUDA, it is OK:

make -j && ./perplexity -m models/mamba-130m/ggml-model-f16.gguf -f build-cublas/wikitext-2-raw/wiki.test.raw
llm_load_tensors: ggml ctx size = 0.12 MiB
llm_load_tensors: CPU buffer size = 256.96 MiB
.................................................
llama_new_context_with_model: n_ctx = 2048
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CPU KV buffer size = 10.69 MiB
llama_new_context_with_model: KV self size = 10.69 MiB, K (f32): 1.69 MiB, V (f32): 9.00 MiB
llama_new_context_with_model: CPU output buffer size = 0.77 MiB
llama_new_context_with_model: CPU compute buffer size = 99.71 MiB
llama_new_context_with_model: graph nodes = 896
llama_new_context_with_model: graph splits = 1
system_info: n_threads = 16 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
perplexity: tokenizing the input ..
perplexity: tokenization took 805.807 ms
perplexity: calculating perplexity over 650 chunks, n_ctx=512, batch_size=2048, n_seq=4
perplexity: 1.43 seconds per pass - ETA 3.85 minutes
[1]17.9926,[2]26.3160,[3]26.9249,[4]27.3857,[5]25.7235,[6]24.7161,[7]24.1991,[8]24.1345
Unfortunately, since the model output deteriorates to the point of not generating an eos token, I am pretty sure it's not a slight difference in this case (even without inspecting the perplexity metric; as @ggerganov pointed out, there seems to be something wrong with the perplexity calculation for the CUDA version as well).
@jploski Try making the tests process more tokens and more sequences at a time. Patch for the tests:

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index e902a72e..ecfcdbc6 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1562,18 +1562,26 @@ struct test_leaky_relu : public test_case {
// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type;
+ const int64_t d_conv;
+ const int64_t d_inner;
+ const int64_t n_seq_tokens;
+ const int64_t n_seqs;
std::string vars() override {
- return VARS_TO_STR4(type, 3, 1536, 4);
+ return VARS_TO_STR5(type, d_conv, d_inner, n_seq_tokens, n_seqs);
}
- test_ssm_conv(ggml_type type = GGML_TYPE_F32)
- : type(type) {}
+ test_ssm_conv(ggml_type type = GGML_TYPE_F32,
+ int64_t d_conv = 4,
+ int64_t d_inner = 1536,
+ int64_t n_seq_tokens = 7,
+ int64_t n_seqs = 2)
+ : type(type), d_conv(d_conv), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 3, 1536, 1);
- ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 1);
- ggml_tensor * c = ggml_new_tensor_2d(ctx, type, 4, 1536);
+ ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_conv - 1, d_inner, n_seqs);
+ ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
+ ggml_tensor * c = ggml_new_tensor_2d(ctx, type, d_conv, d_inner);
ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c);
return out;
}
@@ -1582,21 +1590,29 @@ struct test_ssm_conv : public test_case {
// GGML_OP_SSM_SCAN
struct test_ssm_scan : public test_case {
const ggml_type type;
+ const int64_t d_state;
+ const int64_t d_inner;
+ const int64_t n_seq_tokens;
+ const int64_t n_seqs;
std::string vars() override {
- return VARS_TO_STR4(type, 16, 1536, 2);
+ return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
}
- test_ssm_scan(ggml_type type = GGML_TYPE_F32)
- : type(type) {}
+ test_ssm_scan(ggml_type type = GGML_TYPE_F32,
+ int64_t d_state = 16,
+ int64_t d_inner = 1536,
+ int64_t n_seq_tokens = 7,
+ int64_t n_seqs = 2)
+ : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 16, 1536, 1);
- ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 2);
- ggml_tensor * dt = ggml_new_tensor_2d(ctx, type, 1536, 2);
- ggml_tensor * A = ggml_new_tensor_2d(ctx, type, 16, 1536);
- ggml_tensor * B = ggml_new_tensor_2d(ctx, type, 16, 2);
- ggml_tensor * C = ggml_new_tensor_2d(ctx, type, 16, 2);
+ ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_state, d_inner, n_seqs);
+ ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
+ ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
+ ggml_tensor * A = ggml_new_tensor_2d(ctx, type, d_state, d_inner);
+ ggml_tensor * B = ggml_new_tensor_3d(ctx, type, d_state, n_seq_tokens, n_seqs);
+ ggml_tensor * C = ggml_new_tensor_3d(ctx, type, d_state, n_seq_tokens, n_seqs);
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
return out;
    }

Looking at the tests, they don't really compare all outputs, since in #7531 the state inputs are also written back. I think I should change the output of the SSM operators back to concatenated tensors, just to make the output comparisons easier and to make the graph dependencies more representative of the actual data dependencies.
@ggerganov My hypothesis in this case is that it's probably related to where the KV cache is located.
vs
But maybe ...
Thanks, I applied your patch, and the test for ssm_conv fails now, which sounds like good progress!
This should fix the CUDA ppl. Not sure if both conts are actually needed.

diff --git a/llama.cpp b/llama.cpp
index 841be1de..b311467a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -10499,6 +10499,9 @@ struct llm_build_context {
struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
+ x = ggml_cont(ctx0, x);
+ z = ggml_cont(ctx0, z);
+
// conv
{
        // Custom operator which is needed only to ease simultaneous sequence processing.
I tried applying the patch to my PR #7531 branch around here: https://github.com/jploski/llama.cpp/blob/mamba_cuda_pr7531/llama.cpp#L8694 - and as a result the perplexity calculation on CPU appeared OK. But with this fix the GPU-based generation fails as follows (so I did not commit it):
I got the failing ssm_conv test working now, even if I increase n_tokens from 7 to 1024 and n_seqs to 8. The failure was due to a misplaced parenthesis (jploski@982a22e), a mistake I made today. However, even with the test working, the output qualitatively degrades with an increasing number of GPU layers (as was the case with the original non-PR-7531 version, which did not suffer from that parenthesis bug). So I'd say the overall state of the GPU port is now as bad as yesterday, but not worse, and we have more representative tests for ssm_conv and ssm_scan in place thanks to @compilade.
I managed to do the equivalent of Lines 8739 to 8751 in 8fb57ac
But performance is measurably worse for prompt processing (around ...). I'm not sure if I should remove ... @jploski note that I also changed ...
I updated my branch to catch up, but the non-ssm_conv implementation now fails on the GPU because apparently MUL_MAT is not completely implemented (or maybe it could be worked around with some transpose; I didn't examine it further):
I imagine that a version with a more coarse-grained op might offer more potential for optimization on the GPU, keyword "fused kernel" (not claiming that I know how to do it for CUDA specifically, but generally, knowing "what comes next" in a computation pipeline aids optimization).
@jploski Thank you for trying this. I did not expect it would fail like this. It seems this is also a problem on SYCL and Vulkan, which also expect ... A transpose can't work around this. The first tensor in that ... I guess it will likely be simpler (and faster) with ... I'm not sure whether or not I should fuse the concat into it. The implementation is much simpler when it's separate (the shift doesn't need to move the state, it's only +1 on a pointer, no need to shift a temporary buffer), but it's also slightly slower (...).
The op can be extended to support broadcast in all backends. If it's not too much hassle, try to keep the ...
I think it's also possible to replace that (around 95% ...).
I compared the tensor contents of the CPU and GPU versions to find the source of the first discrepancy. While small rounding errors, which might account for confusing hot chocolate with cold chocolate, are evident from the get-go, the first GPU tensor with a really alarming difference was the one resulting from ggml_silu(ctx, z). I wrapped it with ggml_cont to fix this. With that (and using ssm_conv for the reasons explained previously) the tensor looks almost identical. Although the facts about Sara's and Ben's encounter with the abominable snowman still do not match between CPU and GPU, I no longer get strange tokens, and the generation remains on topic regardless of the CPU/GPU layer split.
I pushed a new branch https://github.com/jploski/llama.cpp/tree/falcon_mamba_cuda, which is based on the recent master of llama.cpp with reapplied patches from my original attempt in June (https://github.com/jploski/llama.cpp/tree/mamba_cuda_pr7531). With one small additional fix (fae826f) this implementation is now working. I tested it with https://huggingface.co/tiiuae/falcon-mamba-7b-instruct. It produces coherent output in f16, q8_0 and q5_k_m quantizations. Maybe someone with CUDA experience could have a look at ggml/src/ggml-cuda/ssm_scan.cu and ggml/src/ggml-cuda/ssm_conv.cu regarding grid configuration and memory access patterns in those kernels (remember I just copy-pasted the CPU version without regard for what is good for CUDA's parallel execution, so it could probably be optimized for performance).
Forgive a small bit of promotion: we have implemented initial SYCL and CUDA support for RWKV; everyone is welcome to use and improve it! #10133
Can someone tell me the typical sizes for the input x (B, L, N) and state (B, D, N), i.e. the values of D and N? I am not familiar with the Mamba model, but I would like to try writing the CUDA kernels for ssm_conv and ssm_scan.
Those kernels are already implemented in PR #9186 |
I have looked at the code of PR #9186, and I think the CUDA code in the scan section is more like CPU code than GPU code. It uses one warp and one block and loops within the block. Perhaps we can consider dividing dimension D across multiple blocks and multiple warps and computing them simultaneously. In addition, I saw in the Mamba paper that placing the state in shared memory can also reduce state reads and writes, so I believe there are some possibilities for optimization. I tried running mamba-130M, which has D = 1536 and N = 16, and mamba-370M, which has D = 2048 (if I remember correctly; the running data is not with me now), so I think splitting D is a reasonable optimization.
Recently, initial Mamba support (CPU-only) has been introduced in #5328 by @compilade
In order to support running these models efficiently on the GPU, we seem to be lacking kernel implementations for the following 2 ops:
GGML_OP_SSM_CONV
GGML_OP_SSM_SCAN
Creating this issue to keep track of this and give more visibility of this feature. Help with implementing the missing kernels for CUDA and Metal (and other backends potentially) is welcome. We can also discuss if anything else is required to better support this architecture in
llama.cpp