speculative : PoC for speeding-up inference via speculative sampling #2926

Merged
3 commits merged into master on Sep 3, 2023

Conversation

ggerganov
Owner

@ggerganov ggerganov commented Aug 31, 2023

ref: #2030

Initial results with the following config indicate a 2x speed-up:

  • target model: Code Llama 34B F16
  • draft model: Code Llama 7B Q4_1

Todo:

Usage:

# standard F16 sampling (using "main" tool)
./bin/main \
-m ../models/codellama-34b/ggml-model-f16.gguf \
-p "// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage:\n\n#include" \
-e -ngl 1 -t 4 -n 256 -c 4096 -s 8 --top_k 1

# speculative F16 sampling with Q4_1 draft (using "speculative" tool)

# example 0
./bin/speculative \
-m ../models/codellama-34b/ggml-model-f16.gguf \
-md ../models/codellama-7b/ggml-model-q4_1.gguf \
-p "// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage:\n\n#include" \
-e -ngl 1 -t 4 -n 256 -c 4096 -s 8 --top_k 1 --draft 16

# example 1
./bin/speculative \
-m ../models/codellama-34b/ggml-model-f16.gguf \
-md ../models/codellama-7b/ggml-model-q4_1.gguf \
-p "// Dijkstra algorithm in C++ (4 spaces indentation + detailed comments) + sample usage:\n\n" \
-e -ngl 1 -t 4 -n 4096 -c 4096 -s 20 --top_k 1 --draft 16

# example 2
./bin/speculative \
-m ../models/codellama-34b/ggml-model-f16.gguf \
-md ../models/codellama-7b/ggml-model-q4_1.gguf \
-p "# Dijkstra's shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n" \
-e -ngl 1 -t 4 -n 512 -c 4096 -s 20 --top_k 1 --draft 16

In some cases (e.g. low-temperature code generation), this clocks in at ~25 t/s for a full-precision F16 34B model on M2 Ultra.

 // Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage:

#include <stdio.h>
#include <stdlib.h>

// Swap two elements of an array
void swap(int *a, int i, int j) {
    int t = a[i];
    a[i] = a[j];
    a[j] = t;
}

// Partition the array around a pivot
int partition(int *a, int l, int r) {
    // Choose the rightmost element as pivot
    int x = a[r], i = l;
    for (int j = l; j <= r - 1; j++) {
        if (a[j] <= x) {
            swap(a, i, j);
            i++;
        }
    }
    swap(a, i, r);
    return i;
}

// Quicksort implementation
void quickSort(int *a, int l, int r) {
    if (l < r) {
        // Partition the array around a pivot and get index of pivot
        int p = partition(a, l, r);
        // Recursively sort elements before and after pivot
        quickSort(a, l, p - 1);
        quickSort(a, p + 1, r);
    }
}

// Print an array on a single line
void printArray(int *a, int n) {
    for (int i = 0; i < n; i++) {
        printf("%d ", a[i]);
    }
    printf("\n");
}

// Driver code
int main() {
    // Sample array:
    int arr[] = {12, 11, 13, 5, 6};
    int n = sizeof(arr) / sizeof(arr[0]);

    printf("Given array is \n");
    printArray(arr, n);

    quickSort(arr, 0, n - 1);

    printf("\nSorted array is \n");
    printArray(arr, n);
}


encoded   25 tokens in    0.319 seconds, speed:   78.298 t/s
decoded  471 tokens in   19.278 seconds, speed:   24.432 t/s

n_draft   = 16
n_predict = 471
n_drafted = 481
n_accept  = 398
accept    = 82.744%

draft:

llama_print_timings:        load time =   359.01 ms
llama_print_timings:      sample time =   913.76 ms /     1 runs   (  913.76 ms per token,     1.09 tokens per second)
llama_print_timings: prompt eval time =    47.18 ms /    25 tokens (    1.89 ms per token,   529.89 tokens per second)
llama_print_timings:        eval time =  6778.24 ms /   537 runs   (   12.62 ms per token,    79.22 tokens per second)
llama_print_timings:       total time = 19596.36 ms

target:

llama_print_timings:        load time =  3315.15 ms
llama_print_timings:      sample time =   330.96 ms /   471 runs   (    0.70 ms per token,  1423.13 tokens per second)
llama_print_timings: prompt eval time =  9864.56 ms /   563 tokens (   17.52 ms per token,    57.07 tokens per second)
llama_print_timings:        eval time =  1565.36 ms /    15 runs   (  104.36 ms per token,     9.58 tokens per second)
llama_print_timings:       total time = 19960.61 ms
ggml_metal_free: deallocating
ggml_metal_free: deallocating
  • Standard F16 34B Code Llama sampling: ~10 t/s
speculative-1.mp4
  • Speculative F16 34B Code Llama + Q4_1 7B Code Llama: ~20 t/s
speculative-2.mp4
speculative-0.mp4

@JohannesGaessler
Collaborator

I tested this PR with 70b f16 and 7b q8_0. When using CPU only and the default sampling parameters, the average t/s increases from 0.44 to 0.52:

 Llamas are animals that 90% of the world doesn’t know about, but they’re actually super important.
Llamas are part of a group of animals called Camelidae which have four legs and a tail, like cows or sheep do; their head is also shaped differently than other members in this family (the shape resembles an alpaca). They were domesticated around 4500 BCE near Peru’s coast.
Llamas are the most important animal to Andean life and culture, with a history that dates back to at least 1200 BC. Llama meat

generated 129 tokens in 248.340 seconds, speed: 0.519 t/s

n_draft   = 8
n_predict = 129
n_drafted = 132
n_accept  = 47
accept    = 35.606%

draft:

llama_print_timings:        load time = 42345.19 ms
llama_print_timings:      sample time =     0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings: prompt eval time =   363.60 ms /     8 tokens (   45.45 ms per token,    22.00 tokens per second)
llama_print_timings:        eval time = 27360.44 ms /   214 runs   (  127.85 ms per token,     7.82 tokens per second)
llama_print_timings:       total time = 290684.19 ms

target:

llama_print_timings:        load time = 44875.62 ms
llama_print_timings:      sample time =    85.21 ms /   129 runs   (    0.66 ms per token,  1513.96 tokens per second)
llama_print_timings: prompt eval time = 170542.94 ms /   196 tokens (  870.12 ms per token,     1.15 tokens per second)
llama_print_timings:        eval time = 57030.81 ms /    25 runs   ( 2281.23 ms per token,     0.44 tokens per second)
llama_print_timings:       total time = 295872.32 ms

@JohannesGaessler
Collaborator

With 70b q6_K and 7b q8_0 on 3x P40 the performance is 3.63 t/s, which is only about half of what I get with regular inference. The problem is most likely that the CUDA code that I wrote has not been optimized for this use case. I would expect the performance to end up better given the right optimizations, though.

@ggerganov
Owner Author

Yes, so far I observe that this strategy is most effective for code generation with ~2x speedup for 34B/7B and ~1.5x speedup for 13B/7B pairs using --top_k 1 sampling. In such scenarios, the acceptance rate is above 75% which helps a lot.

If you try to generate free-form text, then the acceptance rate drops significantly and the method does not offer any benefit. I'm still tweaking, but my gut feeling is that this might be very efficient for cases where we have a very constrained grammar.

@JohannesGaessler
Collaborator

Even for free-form text I would expect there to be quite a large speedup if you have a weak GPU and the CLI allows you to set the GPU layers for the draft and the target model separately. If you can fully offload the draft model it's essentially being evaluated instantaneously compared to the larger model on the CPU so even an acceptance rate of only 33% should translate to +50% t/s.

my gut feeling is that this might be very efficient for cases where we have a very constrained grammar.

I would expect this technique to also work very well for cases where you have a lot of unconventional terms that consist of multiple tokens: in those situations the first token of such a term is almost always followed by the other tokens of the term. So I would expect large performance gains for program code and non-English languages.
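
As a back-of-envelope check of the "+50% t/s from ~33% acceptance" estimate above, the C++ sketch below (illustrative only, not part of the PR) assumes the draft model is essentially free, that a batched target pass costs about the same as a single-token pass, and that each draft token is accepted independently with probability p so that only a prefix of the draft survives; each target pass then yields 1 + p + p^2 + ... + p^n tokens.

// Illustrative estimate only (not llama.cpp code): expected speedup of greedy
// speculative decoding when the draft model is free, a batched target pass costs
// about one regular eval, and each draft token is accepted with probability p.
#include <cstdio>
#include <initializer_list>

// expected length of the accepted draft prefix: p + p^2 + ... + p^n
static double expected_accepted(double p, int n_draft) {
    double e = 0.0, pk = 1.0;
    for (int k = 0; k < n_draft; k++) {
        pk *= p;
        e  += pk;
    }
    return e;
}

int main() {
    const int n_draft = 8;
    for (const double p : { 0.33, 0.50, 0.80 }) {
        // each target pass yields one token of its own plus the accepted prefix
        const double speedup = 1.0 + expected_accepted(p, n_draft);
        printf("acceptance %.2f -> ~%.2fx baseline t/s\n", p, speedup);
    }
    return 0;
}
// prints approximately: 0.33 -> 1.49x, 0.50 -> 2.00x, 0.80 -> 4.33x

With p around 1/3 this comes out to roughly 1.5x the baseline throughput, matching the +50% figure.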

@JohannesGaessler
Collaborator

JohannesGaessler commented Aug 31, 2023

Out of curiosity I tested LLaMA 2 7b q8_0 with itself as a draft model. With free-form text and the default sampling parameters it had an acceptance rate of 37%. Meanwhile, when I used 7b q8_0 as the draft model for 70b f16 with --top-k 1 I got an acceptance rate of 69%. This suggests to me that sampling is more important than the quality of the draft model.

If I understand the current implementation correctly, the draft model always chooses the token with the highest probability for creating the draft. But maybe you could get a higher acceptance rate by sampling from the draft and the target model in the exact same way (including the same RNG seed)? @charliexchen did you investigate this?

@ggerganov
Owner Author

ggerganov commented Aug 31, 2023

If I understand the current implementation correctly, the draft model always chooses the token with the highest probability for creating the draft.

Yes, that is the case now. There is room for experimentation, although it makes more sense to me to always draft the best token.

Btw, when we add batched inference support, we should be able to implement Staged Speculative Decoding, which might give some extra boost. Basically, instead of sampling 1 draft sequence, we sample N, and then the target sampling can accept from any one of them. It would be an interesting experiment.
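
To make the N-draft idea concrete, here is a minimal hypothetical sketch (not llama.cpp API — the draft_sample and target_argmax callbacks stand in for the two models, and a real implementation would verify all drafts in one batched target pass, ideally as a token tree, rather than token by token):

#include <functional>
#include <vector>

using Token  = int;
using Tokens = std::vector<Token>;

// One speculative step with N independently sampled drafts: keep the draft whose
// prefix agrees with the target model for the most tokens.
Tokens staged_speculative_step(
        const Tokens & ctx,
        int n_seq,
        int n_draft,
        const std::function<Tokens(const Tokens &, int)> & draft_sample,   // small model, stochastic
        const std::function<Token (const Tokens &)>      & target_argmax)  // big model, top token
{
    Tokens best;
    for (int s = 0; s < n_seq; s++) {
        const Tokens candidate = draft_sample(ctx, n_draft);

        Tokens accepted;
        Tokens cur = ctx;
        for (const Token t : candidate) {
            // shown sequentially for clarity; in practice all drafts are scored at once
            if (target_argmax(cur) != t) {
                break;
            }
            accepted.push_back(t);
            cur.push_back(t);
        }

        if (accepted.size() > best.size()) {
            best = accepted;
        }
    }
    return best;
}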

@zhisbug

zhisbug commented Aug 31, 2023

@ggerganov I think an earlier and better paper from CMU, called SpecInfer, first studied the idea of using multiple models to speculate, together with tree-like verification.

They have an implementation in FlexFlow: https://github.com/flexflow/FlexFlow/tree/inference

Worth looking at.


@charliexchen

charliexchen commented Sep 1, 2023

@JohannesGaessler The random seed doesn't actually matter. However, you should definitely apply the same kind of sampling to both models (temp + top-k), along with the modified rejection scheme from the paper.

(EDIT: To clarify, for both vanilla speculative sampling and SpecInfer, there is a stochastic resampling algorithm. This should have a higher acceptance rate than greedily sampling the draft.)
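
For reference, the stochastic acceptance/resampling rule referred to here (from the speculative sampling papers) looks roughly as follows; this is a minimal sketch over plain probability vectors, not llama.cpp's sampling API:

#include <algorithm>
#include <cstddef>
#include <random>
#include <vector>

// Decide the next token given the drafted token `x`, the target probabilities `p`
// and the draft probabilities `q` at this position. Accept `x` with probability
// min(1, p[x]/q[x]); otherwise resample from the renormalized residual
// max(0, p - q). The returned token is distributed exactly according to `p`.
int accept_or_resample(int x,
                       const std::vector<float> & p,
                       const std::vector<float> & q,
                       std::mt19937 & rng) {
    std::uniform_real_distribution<float> unif(0.0f, 1.0f);

    if (q[x] > 0.0f && unif(rng) < std::min(1.0f, p[x] / q[x])) {
        return x; // accepted
    }

    // rejected: sample a replacement from the residual distribution
    std::vector<float> residual(p.size());
    for (size_t i = 0; i < p.size(); i++) {
        residual[i] = std::max(0.0f, p[i] - q[i]);
    }
    std::discrete_distribution<int> dist(residual.begin(), residual.end());
    return dist(rng);
}

Because the combined accept/resample step reproduces the target distribution exactly, it can accept drafted tokens more often than requiring an exact greedy match would.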

@goliaro

goliaro commented Sep 1, 2023

@zhisbug Thanks for mentioning our work!

@ggerganov I'm one of the authors of the SpecInfer paper (https://arxiv.org/abs/2305.09781) and a lead contributor of FlexFlow Serve, a distributed framework for LLM inference. I'm really glad to see so much interest in speculative decoding techniques from the community, both in terms of new arXiv paper uploads and integrations with existing open-source projects.

If you or someone else wants to take a look at how we implemented the key ideas in our paper, this is the file to look at in our repo: request_manager.cc.

Overall, FlexFlow Serve, which is also implemented in C++, is currently 1.3-2.4× faster than existing distributed LLM inference systems and 2.6-3.5× faster than offloading-based inference frameworks.

@kalomaze
Contributor

kalomaze commented Sep 1, 2023

Pardon my ignorance, but you need two whole models loaded for this to work, yeah? And I assume the second 'draft' model can't be fully outside of VRAM if it's going to provide decent speed-ups...
Really cool work if I'm understanding it correctly, either way.

@KerfuffleV2
Collaborator

you need two whole models to be loaded in for this to work yea?

The point is that one of the models is much smaller than the main model and can be used to avoid running a full generation on the large model. The more tokens you can skip running the big model for, the bigger the speedup.

@JohannesGaessler
Collaborator

Unless I'm misunderstanding something, you don't actually skip any tokens for the large model. Instead, you first write a draft with the small model, one token at a time. Then you pass all of those tokens at once to the larger model to validate the draft and use as many tokens from the draft as were correctly predicted.
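
Schematically, the greedy draft-then-verify loop described above looks something like the following (a sketch of the control flow only, not the actual examples/speculative code; the two callbacks stand in for the models, and verification is shown token by token even though the target model scores the whole draft in a single batched pass):

#include <cstddef>
#include <functional>
#include <vector>

using Token  = int;
using Tokens = std::vector<Token>;

// Greedy speculative decoding in outline: draft n_draft tokens with the small
// model, let the big model check the draft, keep the longest agreeing prefix plus
// one token from the big model itself, then repeat.
Tokens speculate(Tokens ctx,
                 int n_predict,
                 int n_draft,
                 const std::function<Token(const Tokens &)> & draft_next,    // small model, top token
                 const std::function<Token(const Tokens &)> & target_next) { // big model,  top token
    Tokens out;

    while ((int) out.size() < n_predict) {
        // 1) write a draft with the small model, one token at a time
        Tokens draft;
        Tokens cur = ctx;
        for (int i = 0; i < n_draft; i++) {
            const Token t = draft_next(cur);
            draft.push_back(t);
            cur.push_back(t);
        }

        // 2) verify: keep the prefix of the draft that the big model agrees with
        size_t n_accept = 0;
        Token  next     = 0;
        bool   mismatch = false;
        cur = ctx;
        for (const Token t : draft) {
            next = target_next(cur);
            if (next != t) {
                mismatch = true;
                break;
            }
            n_accept++;
            cur.push_back(t);
        }
        if (!mismatch) {
            // whole draft accepted - the same target pass also yields the token after it
            next = target_next(cur);
        }

        // 3) commit the accepted prefix plus the big model's own token
        for (size_t i = 0; i < n_accept; i++) {
            ctx.push_back(draft[i]);
            out.push_back(draft[i]);
        }
        ctx.push_back(next);
        out.push_back(next);
    }

    return out;
}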

@KerfuffleV2
Collaborator

My explanation wasn't the best, but the overall effect is that you don't have to run a full evaluation of the large model per generated token, as you would without speculative sampling. The effect is pretty much the same as skipping the big model's evaluation for some of the tokens.

@am-randombit

am-randombit commented Sep 3, 2023

May i suggest a wild idea?

How feasible is it to train the small speculative model on the answers of the large model, on the fly, with the weight differences cached on disk after every use?
Side quest: how about adjusting the inference parameters of the speculative model on the fly as well? Or perhaps some kind of "training" or optimization of the inference parameters prior to the run, on a test/init sample?

Say, you run a 7B float16 (does it fit in 24GB?) on a 4090 and a 70B 8bit on CPU.

Unrelated to the above, a question:
Is it possible to estimate the speculative execution performance boost, on Mac hardware and on PC hardware? On Mac there is a big resource conflict, but on PCs different GPUs or the CPU have separate resources.
In essence, is a 5-10x speed-up possible on PCs with a 70B+13B model pair?

I'm really excited for this feature. It will bridge the gap for us 'GPU-poor', and it will set this project apart in performance and capacity for bigger models on the same system versus exllama, which is hard-capped by GPU VRAM capacity. I think speculative execution matters even more on the PC side, where a 3060 12GB might really boost a 34B or even 70B model running on the CPU into usable speeds.

Top-notch work on llama.cpp, guys.

@ggerganov ggerganov changed the base branch from master to build-metal-default September 3, 2023 10:29
@ggerganov ggerganov changed the base branch from build-metal-default to master September 3, 2023 10:29
@ggerganov
Owner Author

ggerganov commented Sep 3, 2023

@ejones Adding grammar support to this example almost works, but we are missing a way to restore the grammar state to a previous state.

To clarify this, we need 2 grammar contexts - one for the small "draft" model and one for the big "target" model.
Let's say we sample 16 draft tokens using the "draft" grammar context and then accept, say, 5 of them using the "target" grammar context. We now discard the remaining 11 draft tokens and need to sample a new batch of 16 "draft" tokens, but to do that, the "draft" grammar context has to "go back" to the state after sampling the first 5 accepted tokens.

I think it's an easy fix - I can add a struct llama_grammar * llama_grammar_copy(struct llama_grammar * grammar); that makes a deep copy of the grammar, and make a copy after each drafted token. Let me know if this sounds good or if there is an easier way to do it.

Is this code correct, or am I misunderstanding:

struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };

    // redirect elements in stacks to point to new rules
    for (size_t is = 0; is < result->stacks.size(); is++) {
        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
                        result->stacks[is][ie] = &result->rules[ir0][ir1];
                    }
                }
            }
        }
    }

    return result;
}
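
One possible way to use such a copy in the speculative loop (a hypothetical sketch, not necessarily the final implementation): whenever a new draft is about to be written, discard the draft grammar and replace it with a deep copy of the target grammar, which has only ever consumed the tokens that were actually accepted.

// Hypothetical usage sketch. The declarations mirror llama.h of the period plus
// the copy proposed above and are repeated here only to keep the sketch
// self-contained.
struct llama_grammar;

struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
void                   llama_grammar_free(struct llama_grammar * grammar);

// Before writing a new draft, roll the draft grammar back to the accepted state
// by replacing it with a deep copy of the target grammar.
static void resync_draft_grammar(struct llama_grammar *& grammar_dft,
                                 const struct llama_grammar * grammar_tgt) {
    if (grammar_dft) {
        llama_grammar_free(grammar_dft);
    }
    grammar_dft = llama_grammar_copy(grammar_tgt);
}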

@ggerganov ggerganov merged commit 47068e5 into master Sep 3, 2023
27 checks passed
@YangWang92

Adding my results on M2 Ultra GPU:

./build-metal/bin/main \
-m /opt/codellama/CodeLlama-34b/ggml-model-f16.bin \
-p "// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage:\n\n#include" \
-e -ngl 120 -t 24 -n 256 -c 4096 -s 8 --top_k 1
[results screenshot]
./build-metal/bin/speculative \
-m /opt/codellama/CodeLlama-34b/ggml-model-f16.bin \
-md /opt/codellama/CodeLlama-7b/ggml-model-q4_1.bin \
-p "// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage:\n\n#include" \
-e -ngl 120 -t 24 -n 256 -c 4096 -s 8 --top_k 1 --draft 16
[results screenshot]

@JianbangZ

@ggerganov I came across this repo https://github.com/FasterDecoding/Medusa, which discusses their approach vs speculative decoding.

@sorasoras

Quick question:
draft model vocab must closely match target model to use speculation but target vocab size 152064 does not match draft vocab size 151936 - difference 128, max allowed 100
I encounter this when I try to run a 13B model with a 1.8B draft model.
Why is there a limit on the difference in vocab size?

@ggerganov
Owner Author

This is to prevent using drastically incompatible vocabs - you can increase the limit if you know what you are doing.
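
For context, the check that produces the error quoted above is essentially of the following shape (a paraphrase for illustration; the constant and function names here are assumptions, not necessarily the exact code in examples/speculative):

#include <cstdio>
#include <cstdlib>

// maximum tolerated difference between target and draft vocab sizes
// (name and value chosen to match the error message above)
static const int SPEC_VOCAB_MAX_SIZE_DIFFERENCE = 100;

static void check_vocab_compat(int n_vocab_tgt, int n_vocab_dft) {
    const int diff = std::abs(n_vocab_tgt - n_vocab_dft);
    if (diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
        fprintf(stderr,
                "draft model vocab must closely match target model to use speculation but "
                "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
                n_vocab_tgt, n_vocab_dft, diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
        exit(1);
    }
}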

@cermeng

cermeng commented Jan 25, 2024

I'm wondering if this PR supports batched speculative decoding? What if each sequence in a batch has a different length of accepted draft tokens?
