Skip to content

Commit

Permalink
talk-llama : use llama_decode instead of llama_eval
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Mar 8, 2024
1 parent 8e409d1 commit 2f5a5a6
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions examples/talk-llama/talk-llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ int main(int argc, char ** argv) {

prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);

llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1);

// init session
std::string path_session = params.path_session;
std::vector<llama_token> session_tokens;
Expand Down Expand Up @@ -426,8 +428,21 @@ int main(int argc, char ** argv) {
printf("\n");
printf("%s : initializing - please wait ...\n", __func__);

if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
// prepare batch
{
batch.n_tokens = embd_inp.size();

for (int i = 0; i < batch.n_tokens; i++) {
batch.token[i] = embd_inp[i];
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = i == batch.n_tokens - 1;
}
}

if (llama_decode(ctx_llama, batch)) {
fprintf(stderr, "%s : failed to decode\n", __func__);
return 1;
}

Expand Down Expand Up @@ -647,8 +662,21 @@ int main(int argc, char ** argv) {
n_session_consumed = session_tokens.size();
}

if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
// prepare batch
{
batch.n_tokens = embd.size();

for (int i = 0; i < batch.n_tokens; i++) {
batch.token[i] = embd[i];
batch.pos[i] = n_past + i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = i == batch.n_tokens - 1;
}
}

if (llama_decode(ctx_llama, batch)) {
fprintf(stderr, "%s : failed to decode\n", __func__);
return 1;
}
}
Expand Down

0 comments on commit 2f5a5a6

Please sign in to comment.