-
Notifications
You must be signed in to change notification settings - Fork 9.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
llama_sampling_sample with default args is more naively usable #6519
llama_sampling_sample with default args is more naively usable #6519
Conversation
llama.cpp
Outdated
// non-trivial case scan for the last output. | ||
for (int32_t i = ctx->output_ids.size() - 1; i >= 0; --i) { | ||
const int32_t candidate = ctx->output_ids[i]; | ||
if (candidate >= 0) { | ||
return i; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pretty sure there's a faster and simpler way to do this. If ctx->n_outputs
is set to the total number of outputs after a batch is processed, the last logits would be available at ctx->logits + (ctx->n_outputs - 1)*ctx->model.hparams.n_vocab
.
(this could be a special case of llama_get_logits_ith
instead, to avoid trying to find the last logits index, which is still non-trivial)
In llama_decode_internal
, the number of outputs is already calculated as n_outputs
, so a simple lctx.n_outputs = n_outputs
near the end of the function (way after the compute graph and its inputs are built, outside of the ubatch
loop) would do the trick.
(BTW, I worked on adding ctx->output_ids
, ctx->n_outputs
, and other related stuff in #6122)
(click to expand) A possible way to implement what I'm trying to explain (diffed from `master` (d4f220a at the time of writing))
diff --git a/common/sampling.h b/common/sampling.h
index 56ed991b..639b819a 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -129,7 +129,7 @@ llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
- int idx = 0);
+ int idx = -1);
// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
llama_token_data_array llama_sampling_prepare(
diff --git a/llama.cpp b/llama.cpp
index 21772618..adfb55fa 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2177,7 +2177,7 @@ struct llama_context {
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
- int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch
+ int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
bool logits_all = false;
@@ -10411,6 +10411,9 @@ static int llama_decode_internal(
n_outputs_prev += lctx.n_outputs;
}
+ // set to total number of outputs in the batch, for use in llama_get_logits_ith
+ lctx.n_outputs = n_outputs;
+
// wait for the computation to finish (automatically done when obtaining the model output)
//llama_synchronize(&lctx);
@@ -15511,23 +15514,29 @@ float * llama_get_logits(struct llama_context * ctx) {
}
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+ int32_t j;
+
llama_synchronize(ctx);
try {
if (ctx->logits == nullptr) {
throw std::runtime_error("no logits");
}
- if ((size_t) i >= ctx->output_ids.size()) {
+
+ if (i < 0) {
+ j = ctx->n_outputs + i;
+ } else if ((size_t) i >= ctx->output_ids.size()) {
throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ } else {
+ j = ctx->output_ids[i];
}
- const int32_t j = ctx->output_ids[i];
if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
- if ((size_t) j >= ctx->output_size) {
+ if (j >= ctx->n_outputs) {
// This should not happen
- throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+ throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
}
return ctx->logits + j*ctx->model.hparams.n_vocab;
@@ -15547,23 +15556,29 @@ float * llama_get_embeddings(struct llama_context * ctx) {
}
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+ int32_t j;
+
llama_synchronize(ctx);
try {
if (ctx->embd == nullptr) {
throw std::runtime_error("no embeddings");
}
- if ((size_t) i >= ctx->output_ids.size()) {
+
+ if (i < 0) {
+ j = ctx->n_outputs + i;
+ } else if ((size_t) i >= ctx->output_ids.size()) {
throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ } else {
+ j = ctx->output_ids[i];
}
- const int32_t j = ctx->output_ids[i];
if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
- if ((size_t) j >= ctx->output_size) {
+ if (j >= ctx->n_outputs) {
// This should not happen
- throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+ throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
}
return ctx->embd + j*ctx->model.hparams.n_embd;
Note that this also makes llama_get_embeddings_ith
support negative indices, which may or may not be useful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That seems reasonable to me!
I had been worried about changing the public api of llama_get_logits_ith
, but if that is viable, then I concur, it seems guaranteed optimal to do it there, rather than on the index.
I had been hopeful that the worst case was uncommon, as it seemed likely that the last index would be set for wanting logits, and therefore the for loop would, in practice, usually exit on the first loop.
Would you like me to update this PR to conform to the suggested code? I think the suggestion looks great, and even adds a cool piece of functionality of python-style negative indexing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had been worried about changing the public api of
llama_get_logits_ith
Since this seems like a backward-compatible change, there is no need to worry. API changes are quite common in llama.cpp
, there's even a "Recent API changes" section in the README.md
for this purpose.
I had been hopeful that the worst case was uncommon, as it seemed likely that the last index would be set for wanting logits, and therefore the for loop would, in practice, usually exit on the first loop.
You are right, the worst case isn't common (at least when using batch.logits
). Another way would have been to always return ctx->cparams.n_batch - 1
or ctx->output_ids.size() - 1
, and let llama_get_logits_ith
do its asserts to fail in unexpected cases, but this would not work with llama_batch_get_one
.
Thinking a bit more about this, the worst case is actually quite common in the sense that it happens all the time when using llama_batch_get_one
, because the "last valid index" in this case is 0.
Would you like me to update this PR to conform to the suggested code?
Yes. You had the idea of making this, so feel free to make commits in your name.
I hereby allow @TheFlipbook to use my code review suggestion from this comment thread and do absolutely anything with it, including but not limited to committing it (even with modifications) without having to give me attribution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated the PR with these suggestions. Apologies for the double-push, the first one had a formatting mistake.
Let me know if you'd suggest any further updates! Thanks for all the help!
* Batches populated by either llama_batch_get_one or llama_batch_add work with default args * Previously get_one could use the default argument * Previously add should usually have used the last index where logits[idx] == true * This hopefully encourages the use of llama_batch_add * By giving expected results when using default arguments. * Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith * Believed to work with any currently well behaved program * Default arg now works for both cases (previously would give strange results for add case) * Any non-negative number is unaffected and behaves as previously * Negative arguments were previously invalid. * Implemented as a special case of indexing as suggested by @compilade in ggerganov#6519
4d54281
to
0d574fc
Compare
* Batches populated by either llama_batch_get_one or llama_batch_add work with default args * Previously get_one could use the default argument * Previously add should usually have used the last index where logits[idx] == true * This hopefully encourages the use of llama_batch_add * By giving expected results when using default arguments. * Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith * Believed to work with any currently well behaved program * Default arg now works for both cases (previously would give strange results for add case) * Any non-negative number is unaffected and behaves as previously * Negative arguments were previously invalid. * Implemented as a special case of indexing as suggested by @compilade in ggerganov#6519
0d574fc
to
95bf5f7
Compare
* cited in macOS CI tests * Missed in original updates based on PR feedback in ggerganov#6519
It looks like I'm trying to read the logs, and I see:
If I'm following the chain for that error correctly it's because Perhaps the CI error is a transient error? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The macOS workflow has been failing occasionally recently - last time I looked into this, it seemed the runner can end up randomly without a GPU device for some reason. Anyway, it's not relevant for this change
== Relevant log messages from source repo: commit b73e564b16086845a8b4fffd26e22685d3e0c3db Author: Georgi Gerganov <[email protected]> Date: Mon Apr 8 16:23:01 2024 +0300 quantize : fix precedence of cli args (#6541) commit e3c337d87ca650972105a51c6ce302dd236c07ad Author: Rick G <[email protected]> Date: Mon Apr 8 06:02:30 2024 -0700 llama : support negative ith in llama_get_ API (#6519) * llama_sampling_sample with default args is more naively usable * Batches populated by either llama_batch_get_one or llama_batch_add work with default args * Previously get_one could use the default argument * Previously add should usually have used the last index where logits[idx] == true * This hopefully encourages the use of llama_batch_add * By giving expected results when using default arguments. * Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith * Believed to work with any currently well behaved program * Default arg now works for both cases (previously would give strange results for add case) * Any non-negative number is unaffected and behaves as previously * Negative arguments were previously invalid. * Implemented as a special case of indexing as suggested by @compilade in ggerganov/llama.cpp#6519 * Fixed mismatch type errors * cited in macOS CI tests * Missed in original updates based on PR feedback in ggerganov/llama.cpp#6519 commit beea6e1b16e783a0886e78dec01002a8c00db24d Author: Jan Boon <[email protected]> Date: Mon Apr 8 20:43:30 2024 +0800 llama : save and restore kv cache for single seq id (#6341) * llama : save and restore kv cache for single seq id * remove trailing whitespace * respond error in case there's no space in the kv cache * add kv seq save restore to test case * add --slot-save-path arg to enable save restore and restrict save location * Returning 0 for some cases, instead of asserting. * cleanup error cases * rename sequence state functions * rename state get set functions * add previous function names back in with DEPRECATED notice * update doc * adjust endpoints to preferred style * fix restoring zero cell count * handle seq rm return value * unused param * keep in the size check * fix return types * add server test case for slot save restore * cleanup * add cake * cleanup style * add special * removing a whole sequence never fails * move sequence state file functionality from server to llama to match session api and add version tags * catch exceptions on save as well * error log messages * check types for stricter restore * update server doc * readme : update API changes date * strict filename validation * move include, reject bom as well * also reject empty filename * reject whitespace and trailing dot --------- Co-authored-by: Martin Evans <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]>
* llama_sampling_sample with default args is more naively usable * Batches populated by either llama_batch_get_one or llama_batch_add work with default args * Previously get_one could use the default argument * Previously add should usually have used the last index where logits[idx] == true * This hopefully encourages the use of llama_batch_add * By giving expected results when using default arguments. * Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith * Believed to work with any currently well behaved program * Default arg now works for both cases (previously would give strange results for add case) * Any non-negative number is unaffected and behaves as previously * Negative arguments were previously invalid. * Implemented as a special case of indexing as suggested by @compilade in ggerganov#6519 * Fixed mismatch type errors * cited in macOS CI tests * Missed in original updates based on PR feedback in ggerganov#6519
* llama_sampling_sample with default args is more naively usable * Batches populated by either llama_batch_get_one or llama_batch_add work with default args * Previously get_one could use the default argument * Previously add should usually have used the last index where logits[idx] == true * This hopefully encourages the use of llama_batch_add * By giving expected results when using default arguments. * Adds "negative indexing" feature to llama_get_logits_ith and llama_get_embeddings_ith * Believed to work with any currently well behaved program * Default arg now works for both cases (previously would give strange results for add case) * Any non-negative number is unaffected and behaves as previously * Negative arguments were previously invalid. * Implemented as a special case of indexing as suggested by @compilade in ggerganov/llama.cpp#6519 * Fixed mismatch type errors * cited in macOS CI tests * Missed in original updates based on PR feedback in ggerganov/llama.cpp#6519
llama_batch_get_one
orllama_batch_add
work with default argsget_one
could use the default argumentadd
should usually have used the last index wherebatch->logits[idx] == true
llama_batch_add