Skip to content

Commit

Permalink
main : option to disable context shift (#9484)
Browse files Browse the repository at this point in the history
* added cli arg to disable context shift

* reverted precommit

* updated README.md for main

* white space

* allow disabling context shift in the server

* Update common/arg.cpp

no-context-shift only works for main example

Co-authored-by: Georgi Gerganov <[email protected]>

* added server example to --no-context-shift args

* removed server changes

* white space

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
VJHack and ggerganov authored Sep 16, 2024
1 parent c4965a6 commit 441b72b
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 15 deletions.
8 changes: 7 additions & 1 deletion common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.n_keep = value;
}
));
add_opt(llama_arg(
{"--no-context-shift"},
format("disables context shift on inifinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
[](gpt_params & params) {
params.ctx_shift = false;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(llama_arg(
{"--chunks"}, "N",
format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
Expand Down Expand Up @@ -1985,4 +1992,3 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,

return ctx_arg;
}

1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ struct gpt_params {
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on inifinite text generation

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool logits_all = false; // return logits for all tokens in the batch
Expand Down
2 changes: 2 additions & 0 deletions examples/main/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ A value of -1 will enable infinite text generation, even though we have a finite

If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.

The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full.

It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.

### Temperature
Expand Down
34 changes: 20 additions & 14 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,29 +559,35 @@ int main(int argc, char ** argv) {
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches

if (n_past + (int) embd.size() >= n_ctx) {
if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
if (!params.ctx_shift){
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
break;
}
} else {
if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
}

const int n_left = n_past - params.n_keep;
const int n_discard = n_left/2;
const int n_left = n_past - params.n_keep;
const int n_discard = n_left/2;

LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);

llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);

n_past -= n_discard;
n_past -= n_discard;

LOG_DBG("after swap: n_past = %d\n", n_past);
LOG_DBG("after swap: n_past = %d\n", n_past);

LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());

LOG_DBG("clear session path\n");
path_session.clear();
LOG_DBG("clear session path\n");
path_session.clear();
}
}
} else {
// context extension via Self-Extend
Expand Down

0 comments on commit 441b72b

Please sign in to comment.