Skip to content

Commit

Permalink
run.c - Clean up, and fix for garbled output
Browse files Browse the repository at this point in the history
Increased MAX_BUFFER_SIZE 2048 to MAX_BUFFER_SIZE 4096

This fixes buffer overflow and garbled output when output is larger than buffer in rare cases such as when token sequences are repeated during inference.

Eg:

./run -c stories110M.bin -t 0  -s 0  -p "A big dog" -b 256

(./run stories110M.bin -0  0  "A big dog" 256)
  • Loading branch information
trholding committed Aug 2, 2023
1 parent 341c309 commit a8b8b85
Showing 1 changed file with 1 addition and 30 deletions.
31 changes: 1 addition & 30 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -544,28 +544,6 @@ int main(int argc, char *argv[]) {
printf("Error: checkpoint file (model) not set. \nSet with %s -c <checkpoint_file>\n",argv[0]);
exit(EXIT_FAILURE);
}
/*
if (argc < 2) {
printf("Usage: %s <checkpoint_file> [temperature] [steps] [prompt] [buffer_tokens]\n", argv[0]);
return 1;
}
if (argc >= 2) {
checkpoint = argv[1];
}
if (argc >= 3) {
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
temperature = atof(argv[2]);
}
if (argc >= 4) {
steps = atoi(argv[3]);
}
if (argc >= 5) {
prompt = argv[4];
}
if (argc >= 6) {
buffertokens = atoi(argv[5]);
}
*/
#endif

// seed rng with time. if you want deterministic behavior use temperature 0.0
Expand Down Expand Up @@ -642,9 +620,8 @@ int main(int argc, char *argv[]) {
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
int bufferflush = 1; // buffer flush after token counter
#define MAX_BUFFER_SIZE 2048
#define MAX_BUFFER_SIZE 4096 // max buffer size
char outbuff[MAX_BUFFER_SIZE]=""; // used for output buffering
outbuff[0]='\0';
memset( outbuff, '\0', sizeof( outbuff )); // clear buffer area
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
setvbuf(stdout, outbuff, _IOFBF, MAX_BUFFER_SIZE); // setup output buffering
Expand Down Expand Up @@ -673,12 +650,6 @@ int main(int argc, char *argv[]) {

// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
/*char* token_str;
if (token == 1 && vocab[next][0] == ' ') {
token_str = vocab[next] + 1;
} else {
token_str = vocab[next];
}*/

printf("%s", token_str);
if (bufferflush==pos && strlen(outbuff)<=MAX_BUFFER_SIZE) { fflush(stdout); bufferflush+=buffertokens; } // flush after every n tokens
Expand Down

0 comments on commit a8b8b85

Please sign in to comment.