From a8b8b8543db408763604b2018f7af2789d8a1fce Mon Sep 17 00:00:00 2001 From: Vulcan <93451215+trholding@users.noreply.github.com> Date: Wed, 2 Aug 2023 22:03:30 +0530 Subject: [PATCH] run.c - Clean up, and fix for garbled output Increased MAX_BUFFER_SIZE 2048 to MAX_BUFFER_SIZE 4096 This fixes buffer overflow and garbled output when output is larger than buffer in rare cases such as when token sequences are repeated during inference. Eg: ./run -c stories110M.bin -t 0 -s 0 -p "A big dog" -b 256 (./run stories110M.bin -0 0 "A big dog" 256) --- run.c | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/run.c b/run.c index 985ef99d..af3d9172 100644 --- a/run.c +++ b/run.c @@ -544,28 +544,6 @@ int main(int argc, char *argv[]) { printf("Error: checkpoint file (model) not set. \nSet with %s -c \n",argv[0]); exit(EXIT_FAILURE); } - /* - if (argc < 2) { - printf("Usage: %s [temperature] [steps] [prompt] [buffer_tokens]\n", argv[0]); - return 1; - } - if (argc >= 2) { - checkpoint = argv[1]; - } - if (argc >= 3) { - // optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline - temperature = atof(argv[2]); - } - if (argc >= 4) { - steps = atoi(argv[3]); - } - if (argc >= 5) { - prompt = argv[4]; - } - if (argc >= 6) { - buffertokens = atoi(argv[5]); - } - */ #endif // seed rng with time. if you want deterministic behavior use temperature 0.0 @@ -642,9 +620,8 @@ int main(int argc, char *argv[]) { int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer int pos = 0; // position in the sequence int bufferflush = 1; // buffer flush after token counter - #define MAX_BUFFER_SIZE 2048 + #define MAX_BUFFER_SIZE 4096 // max buffer size char outbuff[MAX_BUFFER_SIZE]=""; // used for output buffering - outbuff[0]='\0'; memset( outbuff, '\0', sizeof( outbuff )); // clear buffer area printf("\n"); // explicit print the initial BOS token for stylistic symmetry reasons setvbuf(stdout, outbuff, _IOFBF, MAX_BUFFER_SIZE); // setup output buffering @@ -673,12 +650,6 @@ int main(int argc, char *argv[]) { // following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89) char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; - /*char* token_str; - if (token == 1 && vocab[next][0] == ' ') { - token_str = vocab[next] + 1; - } else { - token_str = vocab[next]; - }*/ printf("%s", token_str); if (bufferflush==pos && strlen(outbuff)<=MAX_BUFFER_SIZE) { fflush(stdout); bufferflush+=buffertokens; } // flush after every n tokens