Extend llama_kv_cache_seq_rm to allow matching any sequence (#3843)
* Extend llama_kv_cache_seq_rm to allow matching any sequence

* Replace llama_kv_cache_tokens_rm with llama_kv_cache_clear

Use llama_kv_cache_clear for cache clearing

Change calls to llama_kv_cache_tokens_rm that delete by position to use llama_kv_cache_seq_rm instead
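
As a rough illustration (not part of the commit), the two replacement patterns described above look like this in caller code; ctx and n_keep are assumptions, not taken from the diff:

#include "llama.h"

// Sketch of the migration only; assumes a valid llama_context * ctx and a
// hypothetical prefix length n_keep.
static void migrate_kv_cache_calls(struct llama_context * ctx, llama_pos n_keep) {
    // Old: llama_kv_cache_tokens_rm(ctx, -1, -1);   // full clear
    llama_kv_cache_clear(ctx);

    // Old: llama_kv_cache_tokens_rm(ctx, n_keep, -1);   // delete by position
    // New: seq_id < 0 matches any sequence, so every token at pos >= n_keep is dropped.
    llama_kv_cache_seq_rm(ctx, -1, n_keep, -1);
}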
KerfuffleV2 authored Oct 29, 2023
1 parent 2046eb4 commit 6e08281
Showing 8 changed files with 30 additions and 32 deletions.
2 changes: 1 addition & 1 deletion common/common.cpp
@@ -889,7 +889,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par

         std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
-        llama_kv_cache_tokens_rm(lctx, -1, -1);
+        llama_kv_cache_clear(lctx);
         llama_reset_timings(lctx);
     }

2 changes: 1 addition & 1 deletion examples/batched-bench/batched-bench.cpp
@@ -185,7 +185,7 @@ int main(int argc, char ** argv) {

        const auto t_pp_start = ggml_time_us();

-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);

        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
            LOG_TEE("%s: llama_decode() failed\n", __func__);
4 changes: 2 additions & 2 deletions examples/llama-bench/llama-bench.cpp
@@ -1037,7 +1037,7 @@ int main(int argc, char ** argv) {

        test t(inst, lmodel, ctx);

-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);

        // warmup run
        if (t.n_prompt > 0) {
@@ -1048,7 +1048,7 @@
        }

        for (int i = 0; i < params.reps; i++) {
-            llama_kv_cache_tokens_rm(ctx, -1, -1);
+            llama_kv_cache_clear(ctx);

            uint64_t t_start = get_time_ns();
            if (t.n_prompt > 0) {
2 changes: 1 addition & 1 deletion examples/main/main.cpp
@@ -298,7 +298,7 @@ int main(int argc, char ** argv) {
        }

        // remove any "future" tokens that we might have inherited from the previous session
-        llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1);
+        llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
    }

    LOGLN(
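The hunk above relies on the new "match any sequence" behaviour; a hedged sketch of what the updated call does during session restore (the helper name is hypothetical, only the API call comes from the diff):

// Hypothetical helper, not main.cpp code: keep the first n_matched cached
// tokens from the restored session and evict everything after them,
// regardless of which sequence the cells belong to (seq_id = -1, p1 = -1).
static void trim_restored_session(struct llama_context * ctx, size_t n_matched) {
    llama_kv_cache_seq_rm(ctx, -1, (llama_pos) n_matched, -1);
}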
6 changes: 3 additions & 3 deletions examples/perplexity/perplexity.cpp
@@ -210,7 +210,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
        const auto t_start = std::chrono::high_resolution_clock::now();

        // clear the KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
@@ -339,7 +339,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
        const auto t_start = std::chrono::high_resolution_clock::now();

        // clear the KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
@@ -573,7 +573,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
        }

        // clear the KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);

        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
        if (logits.empty()) {
2 changes: 1 addition & 1 deletion examples/server/server.cpp
@@ -857,7 +857,7 @@ struct llama_server_context

    void kv_cache_clear() {
        // clear the entire KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);
        clean_kv_cache = false;
    }

29 changes: 15 additions & 14 deletions llama.cpp
@@ -1466,17 +1466,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
    return 0;
 }

-static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
-    if (c0 < 0) c0 = 0;
-    if (c1 < 0) c1 = cache.size;
-
-    for (int32_t i = c0; i < c1; ++i) {
+static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
+    for (int32_t i = 0; i < cache.size; ++i) {
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
-
-    // Searching for a free slot can start here since we know it will be empty.
-    cache.head = uint32_t(c0);
+    cache.head = 0;
 }

 static void llama_kv_cache_seq_rm(
@@ -1490,8 +1485,14 @@ static void llama_kv_cache_seq_rm(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.cells[i].seq_id.erase(seq_id);
+        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cache.cells[i].seq_id.clear();
+            } else if (cache.cells[i].has_seq_id(seq_id)) {
+                cache.cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
             if (cache.cells[i].seq_id.empty()) {
                 cache.cells[i].pos = -1;
                 if (new_head == cache.size) new_head = i;
@@ -9207,8 +9208,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
    return ctx->kv_self.head;
 }

-void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) {
-    llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1);
+void llama_kv_cache_clear(struct llama_context * ctx) {
+    llama_kv_cache_clear(ctx->kv_self);
 }

 void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -9654,7 +9655,7 @@ int llama_eval(
            llama_token * tokens,
            int32_t n_tokens,
            int n_past) {
-    llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
+    llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);

    const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
    if (ret < 0) {
@@ -9669,7 +9670,7 @@ int llama_eval_embd(
            float * embd,
            int32_t n_tokens,
            int n_past) {
-    llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
+    llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);

    llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };

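The new removal rule above is easiest to see on a cell shared by two sequences. A self-contained toy model of the same logic (illustration only, not llama.cpp code):

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

// Toy re-implementation of the updated rule, for illustration only:
// a cell with pos in [p0, p1) is cleared entirely when seq_id < 0, detached
// from seq_id when it holds it, and skipped otherwise; pos is reset only
// once no sequence references the cell any more.
struct toy_cell { int32_t pos; std::set<int32_t> seq_id; };

static void toy_seq_rm(std::vector<toy_cell> & cells, int32_t seq_id, int32_t p0, int32_t p1) {
    for (auto & c : cells) {
        if (c.pos < p0 || c.pos >= p1) continue;
        if (seq_id < 0) {
            c.seq_id.clear();
        } else if (c.seq_id.count(seq_id)) {
            c.seq_id.erase(seq_id);
        } else {
            continue;
        }
        if (c.seq_id.empty()) c.pos = -1;
    }
}

int main() {
    std::vector<toy_cell> cells = { {5, {0}}, {5, {0, 1}} };
    toy_seq_rm(cells, 0, 0, 10);                        // remove sequence 0 in [0, 10)
    std::printf("%d %d\n", cells[0].pos, cells[1].pos); // -1 5: the shared cell survives
    toy_seq_rm(cells, -1, 0, 10);                       // seq_id < 0: match any sequence
    std::printf("%d\n", cells[1].pos);                  // -1: now the shared cell is gone too
    return 0;
}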
15 changes: 6 additions & 9 deletions llama.h
@@ -334,17 +334,14 @@ extern "C" {
    LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
            "avoid using this, it will be removed in the future, instead - count the tokens in user code");

-    // Remove all tokens data of cells in [c0, c1)
-    // c0 < 0 : [0, c1]
-    // c1 < 0 : [c0, inf)
-    LLAMA_API void llama_kv_cache_tokens_rm(
-            struct llama_context * ctx,
-            int32_t c0,
-            int32_t c1);
+    // Clear the KV cache
+    LLAMA_API void llama_kv_cache_clear(
+            struct llama_context * ctx);

     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
-    // p0 < 0 : [0, p1]
-    // p1 < 0 : [p0, inf)
+    // seq_id < 0 : match any sequence
+    // p0 < 0 : [0, p1]
+    // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_rm(
             struct llama_context * ctx,
             llama_seq_id seq_id,
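A hedged usage sketch against the updated header; everything except the two declared API calls is illustrative:

#include "llama.h"

// Illustration only: ctx is assumed to be a valid llama_context created elsewhere.
static void prune_kv_cache(struct llama_context * ctx) {
    llama_kv_cache_seq_rm(ctx, 2, 10, 20);    // drop sequence 2 tokens with pos in [10, 20)
    llama_kv_cache_seq_rm(ctx, -1, 100, -1);  // seq_id < 0, p1 < 0: drop pos >= 100 in every sequence
    llama_kv_cache_clear(ctx);                // wipe the whole cache
}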
