Investigate alternative ggml_compute_forward_mul_mat_q_f32() implementation #909

Closed
ggerganov opened this issue Apr 12, 2023 · 7 comments · Fixed by #951
Labels
help wanted Extra attention is needed performance Speed related topics research 🔬


@ggerganov
Owner

ggerganov commented Apr 12, 2023

This is the most computationally significant call in the entire transformer evaluation, so we have to be sure that it is running optimally.

It computes the matrix multiplication: z = x * y

  • x is quantized
  • y is F32
  • z is F32

Currently, it runs in 2 modes, depending on the tensor shapes:

  • (A) for bigger tensors, if BLAS is available, x is dequantized to F32 and we use sgemm to perform the matrix multiplication
  • (B) for smaller tensors, or if BLAS is not available, y is quantized to 4-bits on-the-fly and we use integer-based dot products to perform the matrix multiplication

The former method is much more accurate than the latter. This can be clearly observed during perplexity computations.
However, during text generation (i.e. batch = 1), it is not feasible to use it: in my experience, calling BLAS on the small tensor shapes typical of single-token inference carries significant overhead.

There are at least three alternative modes of operation that can be explored:

  • (C) for smaller tensors, or if BLAS is not available, x is dequantized to F32 and we use ggml_vec_dot_f32() to perform the multiplication
  • (D) for smaller tensors, or if BLAS is not available, x is dequantized to F16, y is converted to F16 and we use ggml_vec_dot_f16() to perform the multiplication
  • (E) for smaller tensors, or if BLAS is not available, y is quantized on-the-fly to 8 bits and we use a new ggml dot-product call that operates on 4-bit x and 8-bit y. This call will still unpack x into 8 bits as usual and perform the 8-bit dot product as in the existing routines, but in contrast to (B), y will already be unpacked to 8 bits, so the precision loss will be significantly smaller

To me it is not immediately clear if (C) or (D) would be significantly slower compared to (B), but they should be much more accurate compared to (B) and probably as accurate as (A).

I think one has to be careful and choose the respective mode based on the tensor shapes, trying to find a good balance between speed and accuracy. Ideally, I am hoping that after this investigation we will achieve a noticeable perplexity improvement without using BLAS, at the cost of a slightly slower single-token (i.e. batch = 1) computation.
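A shape-based choice like this boils down to a small dispatch function. The following C sketch is hypothetical: the enum, the function name and the cut-off of 32 are illustrative assumptions rather than ggml's actual heuristic, and only modes (A) and (B) are shown.

```c
#include <assert.h>
#include <stdbool.h>

/* Hypothetical modes; (C), (D) and (E) would be added as further values. */
typedef enum {
    MUL_MAT_DEQUANT_BLAS, /* (A): dequantize x to F32, call sgemm */
    MUL_MAT_QUANT_DOT     /* (B): quantize y on the fly, integer dot products */
} mul_mat_mode;

/* z = x * y where x holds m rows of length k and y holds n columns of length k
 * (n is the number of tokens in the batch). The cut-off of 32 is an assumed
 * value below which BLAS call overhead is expected to dominate. */
static mul_mat_mode choose_mul_mat_mode(long k, long m, long n, bool have_blas) {
    assert(k > 0 && m > 0 && n > 0);
    const long min_dim = 32;
    if (have_blas && k >= min_dim && m >= min_dim && n >= min_dim) {
        return MUL_MAT_DEQUANT_BLAS;
    }
    return MUL_MAT_QUANT_DOT;
}
```

With batch = 1 (n = 1), as in single-token generation, this sketch always falls back to the quantized path, which is exactly where a more accurate non-BLAS mode would pay off.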

Edit: after the analysis and discussion in #896, I added a new mode (E), which I think is very important to explore. Unless I am missing something, I believe this mode can be exactly as efficient as (B), but with significantly higher accuracy, much higher than what can be achieved by improving the quantization RMS.
So I believe we have to investigate this with very high priority.

@ggerganov ggerganov added performance Speed related topics research 🔬 labels Apr 12, 2023
@ggerganov ggerganov pinned this issue Apr 12, 2023
@ggerganov ggerganov added the help wanted Extra attention is needed label Apr 12, 2023
@sw
Contributor

sw commented Apr 12, 2023

(C) for smaller tensors, or if BLAS is not available, x is dequantized to F32 and we use ggml_vec_dot_f32() to perform the multiplication

I guess you could simply have the vec_dot_q functions accept float for y? See master...sw:llama.cpp:vec-dot-hybrid (quick & dirty; only AVX2 and scalar will even build...)

This is at least not catastrophically worse in terms of performance for text generation. But perplexity ETA went from 12 hours to 48 hours on my slowpoke Intel Core i3.
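Letting a vec_dot_q function accept float for y could look roughly like this scalar C sketch. It is a guess at the idea, not the code in the linked branch; the block layout, the nibble order and the `vec_dot_q4_f32()` name are assumptions.

```c
#include <assert.h>
#include <stdint.h>

#define QK4 32

/* Hypothetical Q4 block: one float scale, 4-bit quants offset by 8. */
typedef struct {
    float   d;
    uint8_t qs[QK4 / 2];
} block_q4;

/* Dot product of a quantized row x with an F32 row y. x is unpacked on the
 * fly, so no dequantized copy of the row is ever written to memory. */
static float vec_dot_q4_f32(int n, const block_q4 *x, const float *y) {
    assert(n % QK4 == 0);
    float sum = 0.0f;
    for (int b = 0; b < n / QK4; b++) {
        const float *yb = y + b * QK4;
        float bsum = 0.0f; /* unscaled partial sum for this block */
        for (int j = 0; j < QK4 / 2; j++) {
            bsum += (float)((x[b].qs[j] & 0x0F) - 8) * yb[2 * j];
            bsum += (float)((x[b].qs[j] >> 4) - 8) * yb[2 * j + 1];
        }
        sum += x[b].d * bsum; /* apply the block scale once */
    }
    return sum;
}
```

y loses no precision here, but the inner loop does float multiply-adds instead of integer ones, so this variant trades speed for accuracy.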

@F286

F286 commented Apr 13, 2023

Are you using all the cores, and is the code fully vectorized?

I'm familiar with AVX but not NEON, but with AVX at least it would likely be possible to unpack the quantized values to floats in a vectorized way using broadcast and shuffle operations.

Depending on whether the bottleneck is compute or memory bandwidth, it could possibly be worth doing the compute in int16 space, etc.

@ggerganov
Owner Author

@sw and all

I just added a new potential mode (E) for investigation. I have a very good feeling about it, but I'm not sure I will get to play with this idea soon. Feel free to investigate if you have the time.

@sw
Contributor

sw commented Apr 13, 2023

  • (E) for smaller tensors, or if BLAS is not available, y is quantized on-the-fly to 8-bits and we use a new ggml dot-product call that operates on 4-bit x and 8-bit y

Having tried it out, it seems like a no-brainer (except for the higher memory use for wdata).

master...sw:llama.cpp:mulmat-q8

(AVX2/AVX/scalar only)
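A rough estimate of the extra wdata, assuming y is stored in 32-value blocks with one 4-byte float scale in both formats (20 bytes per Q4 block, 36 bytes per Q8 block), for one row of K = 4096 values (the LLaMA-7B hidden size):

```latex
\begin{aligned}
n_{\text{blocks}} &= K / 32 = 4096 / 32 = 128 \\
\text{Q4: } & 128 \times (4 + 16) = 2560 \text{ bytes per row} \\
\text{Q8: } & 128 \times (4 + 32) = 4608 \text{ bytes per row} \\
\text{ratio} &= 4608 / 2560 = 1.8
\end{aligned}
```

Under these assumptions the Q8 copy of y is still well below the 16384 bytes of the F32 row it is made from.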

If I squint with both eyes, I can even see a tiny speedup for text generation...

Perplexity:

master
59.43 seconds per pass - ETA 10.81 hours
[1]4.5741,[2]5.0601,[3]5.9543,[4]6.6776,[5]6.7585,[6]6.7574,[7]6.9444,[8]7.0527,^C

mulmat-q8
58.04 seconds per pass - ETA 10.56 hours
[1]4.3843,[2]4.9599,[3]5.8288,[4]6.4722,[5]6.5473,[6]6.5418,[7]6.7172,[8]6.8064,^C
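On the chunks shown (both runs were interrupted, so these are running values, not final ones), the relative drop in perplexity works out to:

```latex
\begin{aligned}
\text{chunk 1: } & \frac{4.5741 - 4.3843}{4.5741} = \frac{0.1898}{4.5741} \approx 4.1\% \\
\text{chunk 8: } & \frac{7.0527 - 6.8064}{7.0527} = \frac{0.2463}{7.0527} \approx 3.5\%
\end{aligned}
```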

@ggerganov
Owner Author

ggerganov commented Apr 13, 2023

Yes, (E) is the way 🦙 !

I just implemented an ARM_NEON SIMD version, and the initial perplexity gains are similar to your observations.
The good thing is that the speed is the same.

See #951

Edit: interestingly, the perplexity calculation does become faster for some reason:

master:

23.48 seconds per pass - ETA 4.27 hours
[1]4.6633,[2]5.2204,[3]6.0946,[4]6.74

q8_0:

19.45 seconds per pass - ETA 3.54 hours
[1]4.3858,[2]4.9662,[3]5.8382,[4]6.4833,^C
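The reported ETAs are consistent with ETA ≈ seconds per pass × number of chunks / 3600. The chunk count is not printed in these excerpts; a value of about 655 is inferred here from the numbers themselves:

```latex
\begin{aligned}
\text{master: } & 23.48 \times 655 / 3600 \approx 4.27 \text{ h} \\
\text{q8\_0: } & 19.45 \times 655 / 3600 \approx 3.54 \text{ h} \\
\text{time per pass saved: } & (23.48 - 19.45) / 23.48 = 4.03 / 23.48 \approx 17\%
\end{aligned}
```

The same count also reproduces the earlier x86 ETAs (59.43 × 655 / 3600 ≈ 10.81 h and 58.04 × 655 / 3600 ≈ 10.56 h).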

@ggerganov
Owner Author

Btw, now I am very curious how the 2-bit quantized model will behave using the 8-bit intermediate data.
If it turns out it is not totally "drunk" as before, we can build a WASM page that runs 2-bit LLaMA inference. I think it will just fit in memory.

@glinscott
Collaborator

Wow, that's a huge win. Appears to capture nearly all the benefit of running with f32/BLAS! Very exciting :)
