diff --git a/README.md b/README.md
index 197bc086..2b4755be 100644
--- a/README.md
+++ b/README.md
@@ -84,7 +84,7 @@ wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin
 You can also prompt the model with a prefix:
 
 ```bash
-./run stories42M.bin -t 1.0 -s 256 -p "A big dog"
+./run stories42M.bin -t 0.8 -n 256 -i "A big dog"
 ```
 
 When prompting, the temperature and steps parameters are needed since we use simple positional arguments.
@@ -106,6 +106,7 @@ The original author trained a series of small models on TinyStories, which took
 The upstream project owner trained the llama2.c storyteller models on a 4X A100 40GB box provided by [Lambda labs](https://lambdalabs.com/service/gpu-cloud).
 
+Quick note on sampling: the recommendation for good results is to use `-t 1.0 -p 0.9`, i.e. top-p sampling at 0.9 with temperature 1.0 (this is the default). To control the diversity of samples use either the temperature (i.e. vary `-t` between 0 and 1 and keep top-p off with `-p 0`) or the top-p value (i.e. vary `-p` between 0 and 1 and keep `-t 1`), but not both. Nice explainers on LLM sampling strategies include [this](https://peterchng.com/blog/2023/05/02/token-selection-strategies-top-k-top-p-and-temperature/), [this](https://docs.cohere.com/docs/controlling-generation-with-top-k-top-p) or [this](https://huggingface.co/blog/how-to-generate).
 
 ```bash
 ./run llama2_7b.bin
@@ -128,7 +129,7 @@ Where
   -t is the *optional* temperature in the range `0.0` to `1.0` and a default of **0.9**.\
      `0` makes outputs with same or no prompt reproducible.
 
-  -s is the *optional* number of steps in the range `1` to `256` and a default of **256**.\
+  -n is the *optional* number of steps in the range `1` to `256` and a default of **256**.\
      `0` sets it to context length of model.\
      This option defines the number of tokens to infer and output.
 
@@ -136,7 +137,7 @@ Where
      `0` sets it to context length of model.\
      This increases the interactive performance. Use values such as `4`, `8`, `16`, `32`, `64`, `128` ... YMMV!
 
-  -p is the *optional* prompt such as `"A car"` to pass on to guide inference.\
+  -i is the *optional* input prompt such as `"A car"` to pass on to guide inference.\
      If omitted the model will infer on its own.
 
 **Minimal Usage**
diff --git a/run.c b/run.c
index 4958927b..2986cfad 100644
--- a/run.c
+++ b/run.c
@@ -123,6 +123,11 @@ typedef struct {
     float* wcls;
 } TransformerWeights;
 
+typedef struct {
+    float prob;
+    int index;
+} ProbIndex; // struct used when sorting probabilities during top-p sampling
+
 typedef struct {
     // current wave of activations
     float *x; // activation at current time stamp (dim,)
@@ -135,6 +140,7 @@ typedef struct {
     float *v; // value (dim,)
     float *att; // buffer for scores/attention values (n_heads, seq_len)
     float *logits; // output logits
+    ProbIndex *probindex; // buffer used in top-p sampling
     // kv cache
     float* key_cache;   // (layer, seq_len, dim)
     float* value_cache; // (layer, seq_len, dim)
@@ -152,12 +158,13 @@ void malloc_run_state(RunState* s, Config* p) {
     s->v = calloc(p->dim, sizeof(float));
     s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
     s->logits = calloc(p->vocab_size, sizeof(float));
+    s->probindex = calloc(p->vocab_size, sizeof(ProbIndex));
     s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
     s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
     // ensure all mallocs went fine
     if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
      || !s->k || !s->v || !s->att || !s->logits || !s->key_cache
-     || !s->value_cache) {
+     || !s->value_cache || !s->probindex) {
         printf("malloc failed!\n");
         exit(EXIT_FAILURE);
     }
@@ -174,6 +181,7 @@ void free_run_state(RunState* s) {
     free(s->v);
     free(s->att);
     free(s->logits);
+    free(s->probindex);
     free(s->key_cache);
     free(s->value_cache);
 }
@@ -476,7 +484,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
 }
 
 // ----------------------------------------------------------------------------
-// utilities
+// utilities: time / rng
 
 long time_in_ms() {
     // return time in milliseconds, for benchmarking the model speed
@@ -497,8 +505,24 @@ float random_f32() { // random float32 in [0,1)
     return (random_u32() >> 8) / 16777216.0f;
 }
 
+// ----------------------------------------------------------------------------
+// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
+
+int argmax(float* probabilities, int n) {
+    // return the index that has the highest probability
+    int max_i = 0;
+    float max_p = probabilities[0];
+    for (int i = 1; i < n; i++) {
+        if (probabilities[i] > max_p) {
+            max_i = i;
+            max_p = probabilities[i];
+        }
+    }
+    return max_i;
+}
+
 int sample(float* probabilities, int n) {
-    // sample index from probabilities, they must sum to 1
+    // sample index from probabilities (they must sum to 1!)
     float r = random_f32();
     float cdf = 0.0f;
     for (int i = 0; i < n; i++) {
@@ -510,25 +534,59 @@ int sample(float* probabilities, int n) {
     return n - 1; // in case of rounding errors
 }
 
-int argmax(float* v, int n) {
-    // return argmax of v in elements 0..n
-    int max_i = 0;
-    float max_p = v[0];
-    for (int i = 1; i < n; i++) {
-        if (v[i] > max_p) {
-            max_i = i;
-            max_p = v[i];
+int compare(const void* a, const void* b) {
+    ProbIndex* a_ = (ProbIndex*) a;
+    ProbIndex* b_ = (ProbIndex*) b;
+    if (a_->prob > b_->prob) return -1;
+    if (a_->prob < b_->prob) return 1;
+    return 0;
+}
+
+int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
+    // top-p sampling (or "nucleus sampling") samples from the smallest set of
+    // tokens that exceed probability topp. This way we never sample tokens that
+    // have very low probabilities and are less likely to go "off the rails".
+
+    // quicksort indices in descending order of probabilities
+    for (int i = 0; i < n; i++) {
+        probindex[i].index = i;
+        probindex[i].prob = probabilities[i];
+    }
+    qsort(probindex, n, sizeof(ProbIndex), compare);
+
+    // truncate the list where cumulative probability exceeds topp
+    float cumulative_prob = 0.0f;
+    int last_idx = 0;
+    for (int i = 0; i < n; i++) {
+        cumulative_prob += probindex[i].prob;
+        if (cumulative_prob > topp) {
+            last_idx = i;
+            break; // we've exceeded topp by including last_idx
         }
     }
-    return max_i;
+
+    // sample from the truncated list
+    float r = random_f32() * cumulative_prob;
+    float cdf = 0.0f;
+    for (int i = 0; i <= last_idx; i++) {
+        cdf += probindex[i].prob;
+        if (r < cdf) {
+            return probindex[i].index;
+        }
+    }
+    return probindex[last_idx].index; // in case of rounding errors
 }
+
+
 // ----------------------------------------------------------------------------
+// int main
 
 int main(int argc, char *argv[]) {
 
     // default inits
     char *checkpoint = NULL;  // e.g. out/model.bin
-    float temperature = 0.9f; // 0.0 = greedy & deterministic, 1.0 = max uncertainty
+    float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
+    float topp = 0.9f;        // top-p in nucleus sampling
     rng_seed = (unsigned int)time(NULL); // seed rng with time by default
     int steps = 256;          // number of steps to run for
     char *prompt = NULL;      // prompt string
@@ -552,21 +610,22 @@ int main(int argc, char *argv[]) {
     }
     if (argc >= 2) { checkpoint = argv[1]; }
     for (int i = 2; i < argc; i++) {
-        // do some basic validation
+        // do some basic validation - add rng_seed and other checks
        switch (argv[i][0]) {
            case '-':
                switch (argv[i][1]) {
                    // optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
                    case 't': if (i + 1 < argc) { temperature = atof(argv[++i]); } break;
+                   case 'p': if (i + 1 < argc) { topp = atof(argv[++i]); } break;
                    case 's': if (i + 1 < argc) { rng_seed = atoi(argv[++i]); } break;
                    case 'n': if (i + 1 < argc) { steps = atoi(argv[++i]); } break;
                    case 'b': if (i + 1 < argc) { buffertokens = atoi(argv[++i]); } break;
-                   case 'p': if (i + 1 < argc) { prompt = argv[++i]; } break;
+                   case 'i': if (i + 1 < argc) { prompt = argv[++i]; } break;
                    default: printf("Invalid option: %s\n", argv[i]); exit(EXIT_FAILURE);
                }
                break;
            default:
-               printf("Usage: %s -t [temperature] -s [seed] -n [steps] -b [buffertokens] -p [prompt] \n", argv[0]);
+               printf("Usage: %s -t [temperature] -p [top-p] -s [seed] -n [steps] -b [buffertokens] -i [prompt] \n", argv[0]);
                exit(EXIT_FAILURE);
        }
    }
@@ -670,7 +729,13 @@ int main(int argc, char *argv[]) {
                 // apply softmax to the logits to get the probabilities for next token
                 softmax(state.logits, config.vocab_size);
                 // we sample from this distribution to get the next token
-                next = sample(state.logits, config.vocab_size);
+                if (topp <= 0) {
+                    // simply sample from the predicted probability distribution
+                    next = sample(state.logits, config.vocab_size);
+                } else {
+                    // top-p (nucleus) sampling, clamping the least likely tokens to zero
+                    next = sample_topp(state.logits, config.vocab_size, topp, state.probindex);
+                }
             }
         }
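The new `sample_topp` above is easiest to sanity-check with a concrete distribution in mind. The following self-contained sketch is not part of the patch (all names are local to the example); it mirrors the sort-then-truncate logic on a toy five-token distribution and shows that with `topp = 0.9` the lowest-probability token can never be sampled:

```c
#include <stdio.h>
#include <stdlib.h>

// mirrors the patch's ProbIndex pair: probability plus original token index
typedef struct { float prob; int index; } Pair;

// sort pairs by probability, descending (same ordering as the patch's compare())
static int cmp_desc(const void* a, const void* b) {
    const Pair* pa = (const Pair*)a;
    const Pair* pb = (const Pair*)b;
    if (pa->prob > pb->prob) return -1;
    if (pa->prob < pb->prob) return 1;
    return 0;
}

int main(void) {
    // toy next-token distribution over 5 "tokens" (sums to 1)
    float probs[5] = {0.40f, 0.25f, 0.20f, 0.10f, 0.05f};
    float topp = 0.9f;

    Pair pairs[5];
    for (int i = 0; i < 5; i++) { pairs[i].prob = probs[i]; pairs[i].index = i; }
    qsort(pairs, 5, sizeof(Pair), cmp_desc);

    // keep the smallest prefix whose cumulative probability exceeds topp
    float cumulative = 0.0f;
    int last = 0;
    for (int i = 0; i < 5; i++) {
        cumulative += pairs[i].prob;
        if (cumulative > topp) { last = i; break; }
    }

    // with topp = 0.9 the kept prefix is tokens 0..3 (0.40+0.25+0.20+0.10 = 0.95 > 0.9),
    // so token 4 (p = 0.05) is cut off and can never be sampled
    printf("kept %d of 5 tokens, cumulative mass %.2f\n", last + 1, cumulative);
    return 0;
}
```

In the real `sample_topp`, drawing `r = random_f32() * cumulative_prob` and walking the cdf over this kept prefix amounts to sampling from the truncated distribution renormalized by its cumulative mass.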
diff --git a/train.py b/train.py
index 811dd8ac..dbf0b240 100644
--- a/train.py
+++ b/train.py
@@ -212,7 +212,7 @@ def estimate_loss():
             X, Y = next(batch_iter)
             with ctx:
                 logits = model(X, Y)
-                loss = model.last_loss
+                loss = raw_model.last_loss
             losses[k] = loss.item()
         out[split] = losses.mean()
     model.train()
@@ -296,7 +296,7 @@ def get_lr(it):
             model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
         with ctx:
             logits = model(X, Y)
-            loss = model.last_loss
+            loss = raw_model.last_loss
             loss = loss / gradient_accumulation_steps
         # immediately async prefetch next batch while model is doing the forward pass on the GPU
         X, Y = next(train_batch_iter)