From 7f551dbfd73e9d36342fa95d88ca1833f89ca2e6 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 19 Aug 2023 18:25:20 +0000 Subject: [PATCH 01/21] new model export: versions 0 (legacy) and 1 --- export.py | 243 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ model.py | 52 ------------ train.py | 3 +- 3 files changed, 245 insertions(+), 53 deletions(-) create mode 100644 export.py diff --git a/export.py b/export.py new file mode 100644 index 00000000..47106490 --- /dev/null +++ b/export.py @@ -0,0 +1,243 @@ +""" +This script has functions and utilties for model export. +Basically, we have a bunch of versions of the model, and we +want to export them to .bin files to be read from and inferenced in C. + +Among the "input" versions of PyTorch files/models: +- Official Llama 2 weights released by Meta +- Huggingface weights available on the hub +- llama2.c (this repo) trained models + +Among the "output" versions of .bin files: +- v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED) +- v1-vN: Improved .bin files with a proper header, cache alignment, etc. + +This script aspires to provide all of these conversions. +""" +import struct +import argparse +import torch +import numpy as np + +from model import ModelArgs, Transformer + +# ----------------------------------------------------------------------------- +# common utilities + +def serialize_fp32(file, tensor): + """ writes one fp32 tensor to file that is open in wb mode """ + d = tensor.detach().cpu().view(-1).numpy().astype(np.float32) + b = struct.pack(f'{len(d)}f', *d) + file.write(b) + +def serialize_int8(file, tensor): + """ writes one int8 tensor to file that is open in wb mode """ + d = tensor.detach().cpu().view(-1).numpy().astype(np.int8) + b = struct.pack(f'{len(d)}b', *d) + file.write(b) + +def quantize_q80(w, group_size): + """ + takes a tensor and returns the Q8_0 quantized version + i.e. symmetric quantization into int8, range [-127,127] + """ + assert w.numel() % group_size == 0 + ori_shape = w.shape + w = w.float() # convert to float32 + w = w.reshape(-1, group_size) + # find the max in each group + wmax = torch.abs(w).max(dim=1).values + # calculate the scaling factor such that float = quant * scale + scale = wmax / 127.0 + # scale into range [-127, 127] + quant = w / scale[:,None] + # round to nearest integer + int8val = torch.round(quant).to(torch.int8) + # dequantize by rescaling + fp32val = (int8val.float() * scale[:,None]).view(-1) + fp32valr = fp32val.reshape(-1, group_size) + # calculate the max error in each group + err = torch.abs(fp32valr - w).max(dim=1).values + # find the max error across all groups + maxerr = err.max().item() + return int8val, scale, maxerr + +# ----------------------------------------------------------------------------- +# legacy + +def legacy_export(model, filepath): + """ Original export of llama2.c bin files, i.e. version v0 """ + out_file = open(filepath, 'wb') + + # first write out the header + hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] + p = model.params + n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads + header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, + n_kv_heads, p.vocab_size, p.max_seq_len) + out_file.write(header) + + # next write out the embedding weights + serialize_fp32(out_file, model.tok_embeddings.weight) + + # now all the layers + # attention weights + for layer in model.layers: + serialize_fp32(out_file, layer.attention_norm.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.attention.wq.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.attention.wk.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.attention.wv.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.attention.wo.weight) + # ffn weights + for layer in model.layers: + serialize_fp32(out_file, layer.ffn_norm.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.feed_forward.w1.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.feed_forward.w2.weight) + for layer in model.layers: + serialize_fp32(out_file, layer.feed_forward.w3.weight) + # final rmsnorm + serialize_fp32(out_file, model.norm.weight) + # note: no need to write final classifier weights due to weight sharing + # freqs_cis + serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len]) + serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len]) + + # write to binary file + out_file.close() + print(f"wrote {filepath}") + +# ----------------------------------------------------------------------------- +# new version + +def version1_export(model, filepath, group_size=64): + """ + Export the model weights in Q8_0 into .bin file to be read from C. + That is: + - quantize all weights to symmetric int8, in range [-127, 127] + - all other tensors (the rmsnorm params) are kept and exported in fp32 + - quantization is done in groups of group_size to reduce the effects of any outliers + """ + version = 1 + + # let's first do some validation for this export type + while model.params.dim % group_size != 0: + group_size //= 2 + print(f"BACKOFF: reducing group size to {group_size} to fit hidden_dim") + weights = [ + model.tok_embeddings.weight, + *[layer.attention.wq.weight for layer in model.layers], + *[layer.attention.wk.weight for layer in model.layers], + *[layer.attention.wv.weight for layer in model.layers], + *[layer.attention.wo.weight for layer in model.layers], + *[layer.feed_forward.w1.weight for layer in model.layers], + *[layer.feed_forward.w2.weight for layer in model.layers], + *[layer.feed_forward.w3.weight for layer in model.layers], + ] + for w in weights: + assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}" + + # write + out_file = open(filepath, 'wb') + # first write out the header. the header will be 256 bytes + nbytes = 0 + # 1) write magic, which will be uint32 of "ak42" in ASCII + out_file.write(struct.pack('I', 0x616b3432)) + nbytes += 4 + # 2) write version, which will be int + out_file.write(struct.pack('i', version)) + nbytes += 4 + # 3) write the params, which will be 7 ints + p = model.params + hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] + n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads + header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, + n_kv_heads, p.vocab_size, p.max_seq_len) + out_file.write(header) + nbytes += 7*4 + # 4) write some other flags + shared_classifier = 1 # we do share a classifier, write flag as a byte + out_file.write(struct.pack('B', shared_classifier)) + nbytes += 1 + out_file.write(struct.pack('i', group_size)) # group size used for quantization + nbytes += 4 + pad = 256 - nbytes # pad the rest with zeros + assert pad >= 0 + out_file.write(b'\0' * pad) + # now that the header is done, let's write out the model + + # first let's write out all the params that we are keeping in fp32: the norms + for layer in model.layers: # attention norms + serialize_fp32(out_file, layer.attention_norm.weight) + for layer in model.layers: # MLP norms + serialize_fp32(out_file, layer.ffn_norm.weight) + serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm + + # now let's write out all the params that we are quantizing to Q8_0 + # note we skip classifier weights, which are shared with the embedding + ew = [] + scales = [] + for i, w in enumerate(weights): + # quantize this weight + q, s, err = quantize_q80(w, group_size) + # save the int8 weights to file + serialize_int8(out_file, q) # save the tensor in int8 + scales.append(s) # we'll do all the scales after all the qs + # logging + ew.append((err, w.shape)) + print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}") + + # save the scaling factors in fp32 here + # this is done to keep all the weights contiquous, making pointer arithmetic easier in C + for s in scales: + serialize_fp32(out_file, s) + + # print the highest error across all weights, should be very small, e.g. O(~0.001) + ew.sort(reverse=True) + print(f"max quantization group error across all weights: {ew[0][0]}") + + # write to binary file + out_file.close() + print(f"wrote {filepath}") + +# ----------------------------------------------------------------------------- +# API entrypoint + +def model_export(model, filepath, version): + if version == 0: + legacy_export(model, filepath) + elif version == 1: + version1_export(model, filepath) + else: + raise ValueError(f"unknown version {version}") + +# ----------------------------------------------------------------------------- +# CLI entrypoint + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("filepath", type=str, help="the output filepath") + parser.add_argument("--checkpoint", default="", type=str, help="model checkpoint, .pt file") + parser.add_argument("--version", default=0, type=int, help="the version to export with") + args = parser.parse_args() + + # load the provided model checkpoint + checkpoint_dict = torch.load(args.checkpoint, map_location='cpu') + gptconf = ModelArgs(**checkpoint_dict['model_args']) + model = Transformer(gptconf) + state_dict = checkpoint_dict['model'] + unwanted_prefix = '_orig_mod.' + for k,v in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) + model.load_state_dict(state_dict, strict=False) + model.eval() + + # export + model_export(model, args.filepath, args.version) diff --git a/model.py b/model.py index c8c82a9d..044712f2 100644 --- a/model.py +++ b/model.py @@ -338,55 +338,3 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): idx = torch.cat((idx, idx_next), dim=1) return idx - - def export(self, filepath='model.bin'): - """export the model weights in fp32 into .bin file to be read from C""" - f = open(filepath, 'wb') - - def serialize(t): - d = t.detach().cpu().view(-1).numpy().astype(np.float32) - b = struct.pack(f'{len(d)}f', *d) - f.write(b) - - # first write out the header - hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0] - p = self.params - n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads - header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, - n_kv_heads, p.vocab_size, p.max_seq_len) - f.write(header) - - # next write out the embedding weights - serialize(self.tok_embeddings.weight) - - # now all the layers - # attention weights - for layer in self.layers: - serialize(layer.attention_norm.weight) - for layer in self.layers: - serialize(layer.attention.wq.weight) - for layer in self.layers: - serialize(layer.attention.wk.weight) - for layer in self.layers: - serialize(layer.attention.wv.weight) - for layer in self.layers: - serialize(layer.attention.wo.weight) - # ffn weights - for layer in self.layers: - serialize(layer.ffn_norm.weight) - for layer in self.layers: - serialize(layer.feed_forward.w1.weight) - for layer in self.layers: - serialize(layer.feed_forward.w2.weight) - for layer in self.layers: - serialize(layer.feed_forward.w3.weight) - # final rmsnorm - serialize(self.norm.weight) - # note: no need to write final classifier weights due to weight sharing - # freqs_cis - serialize(self.freqs_cos[:p.max_seq_len]) - serialize(self.freqs_sin[:p.max_seq_len]) - - # write to binary file - f.close() - print(f"wrote {filepath}") diff --git a/train.py b/train.py index b1972dc5..e9585383 100644 --- a/train.py +++ b/train.py @@ -29,6 +29,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from tinystories import Task +from export import model_export # ----------------------------------------------------------------------------- # I/O @@ -287,7 +288,7 @@ def get_lr(it): } print(f"saving checkpoint to {out_dir}") torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt")) - raw_model.export(os.path.join(out_dir, "model.bin")) + model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0) if iter_num == 0 and eval_only: break From 4212bd6d4343ac8a13efaced5609af268e7f4730 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 19 Aug 2023 18:34:49 +0000 Subject: [PATCH 02/21] oops fix double indent on quantize def --- export.py | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/export.py b/export.py index 47106490..db874b06 100644 --- a/export.py +++ b/export.py @@ -37,30 +37,30 @@ def serialize_int8(file, tensor): file.write(b) def quantize_q80(w, group_size): - """ - takes a tensor and returns the Q8_0 quantized version - i.e. symmetric quantization into int8, range [-127,127] - """ - assert w.numel() % group_size == 0 - ori_shape = w.shape - w = w.float() # convert to float32 - w = w.reshape(-1, group_size) - # find the max in each group - wmax = torch.abs(w).max(dim=1).values - # calculate the scaling factor such that float = quant * scale - scale = wmax / 127.0 - # scale into range [-127, 127] - quant = w / scale[:,None] - # round to nearest integer - int8val = torch.round(quant).to(torch.int8) - # dequantize by rescaling - fp32val = (int8val.float() * scale[:,None]).view(-1) - fp32valr = fp32val.reshape(-1, group_size) - # calculate the max error in each group - err = torch.abs(fp32valr - w).max(dim=1).values - # find the max error across all groups - maxerr = err.max().item() - return int8val, scale, maxerr + """ + takes a tensor and returns the Q8_0 quantized version + i.e. symmetric quantization into int8, range [-127,127] + """ + assert w.numel() % group_size == 0 + ori_shape = w.shape + w = w.float() # convert to float32 + w = w.reshape(-1, group_size) + # find the max in each group + wmax = torch.abs(w).max(dim=1).values + # calculate the scaling factor such that float = quant * scale + scale = wmax / 127.0 + # scale into range [-127, 127] + quant = w / scale[:,None] + # round to nearest integer + int8val = torch.round(quant).to(torch.int8) + # dequantize by rescaling + fp32val = (int8val.float() * scale[:,None]).view(-1) + fp32valr = fp32val.reshape(-1, group_size) + # calculate the max error in each group + err = torch.abs(fp32valr - w).max(dim=1).values + # find the max error across all groups + maxerr = err.max().item() + return int8val, scale, maxerr # ----------------------------------------------------------------------------- # legacy From 4df5e2e939e214855fdccfa838dd419b37347ce6 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 19 Aug 2023 18:51:32 +0000 Subject: [PATCH 03/21] make version 1 be the legacy export but with new header. version 2 will be Q8_0 export --- export.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/export.py b/export.py index db874b06..8d431568 100644 --- a/export.py +++ b/export.py @@ -115,7 +115,60 @@ def legacy_export(model, filepath): # ----------------------------------------------------------------------------- # new version -def version1_export(model, filepath, group_size=64): +def version1_export(model, filepath): + """ + Export the model weights in full float32 .bin file to be read from C. + This is same as legacy_export, but with a proper header. + """ + version = 1 + + out_file = open(filepath, 'wb') + # first write out the header. the header will be 256 bytes + nbytes = 0 + # 1) write magic, which will be uint32 of "ak42" in ASCII + out_file.write(struct.pack('I', 0x616b3432)) + nbytes += 4 + # 2) write version, which will be int + out_file.write(struct.pack('i', version)) + nbytes += 4 + # 3) write the params, which will be 7 ints + p = model.params + hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] + n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads + header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, + n_kv_heads, p.vocab_size, p.max_seq_len) + out_file.write(header) + nbytes += 7*4 + # 4) write some other flags + shared_classifier = 1 # we do share a classifier, write flag as a byte + out_file.write(struct.pack('B', shared_classifier)) + nbytes += 1 + pad = 256 - nbytes # pad the rest with zeros + assert pad >= 0 + out_file.write(b'\0' * pad) + + # now let's write out all the params + weights = [ + *[layer.attention_norm.weight for layer in model.layers], + *[layer.ffn_norm.weight for layer in model.layers], + model.norm.weight, + model.tok_embeddings.weight, + *[layer.attention.wq.weight for layer in model.layers], + *[layer.attention.wk.weight for layer in model.layers], + *[layer.attention.wv.weight for layer in model.layers], + *[layer.attention.wo.weight for layer in model.layers], + *[layer.feed_forward.w1.weight for layer in model.layers], + *[layer.feed_forward.w2.weight for layer in model.layers], + *[layer.feed_forward.w3.weight for layer in model.layers], + ] + for w in weights: + serialize_fp32(out_file, w) + + # write to binary file + out_file.close() + print(f"wrote {filepath}") + +def version2_export(model, filepath, group_size=64): """ Export the model weights in Q8_0 into .bin file to be read from C. That is: @@ -123,7 +176,7 @@ def version1_export(model, filepath, group_size=64): - all other tensors (the rmsnorm params) are kept and exported in fp32 - quantization is done in groups of group_size to reduce the effects of any outliers """ - version = 1 + version = 2 # let's first do some validation for this export type while model.params.dim % group_size != 0: @@ -213,6 +266,8 @@ def model_export(model, filepath, version): legacy_export(model, filepath) elif version == 1: version1_export(model, filepath) + elif version == 2: + version2_export(model, filepath) else: raise ValueError(f"unknown version {version}") From fa8dfd854ebdd911e43bc6238217e343d5754796 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 19 Aug 2023 19:21:12 +0000 Subject: [PATCH 04/21] isolate read_checkpoint, because i'd like to now make it support both version 0 and version 1 --- run.c | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/run.c b/run.c index 10d468b3..59b8b295 100644 --- a/run.c +++ b/run.c @@ -148,6 +148,28 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int s w->wcls = shared_weights ? w->token_embedding_table : ptr; } +void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights, + int* fd, float** data, ssize_t* file_size) { + FILE *file = fopen(checkpoint, "rb"); + if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); } + // read in the config header + if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); } + // negative vocab size is hacky way of signaling unshared weights. bit yikes. + int shared_weights = config->vocab_size > 0 ? 1 : 0; + config->vocab_size = abs(config->vocab_size); + // figure out the file size + fseek(file, 0, SEEK_END); // move file pointer to end of file + *file_size = ftell(file); // get the file size, in bytes + fclose(file); + // memory map the Transformer weights into the data pointer + *fd = open(checkpoint, O_RDONLY); // open in read only mode + if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); } + *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0); + if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); } + float* weights_ptr = *data + sizeof(Config)/sizeof(float); + checkpoint_init_weights(weights, config, weights_ptr, shared_weights); +} + // ---------------------------------------------------------------------------- // neural net blocks @@ -604,27 +626,9 @@ int main(int argc, char *argv[]) { TransformerWeights weights; int fd = 0; // file descriptor for memory mapping float* data = NULL; // memory mapped data pointer - ssize_t file_size; // size of the checkpoint file in bytes - { - FILE *file = fopen(checkpoint, "rb"); - if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); return 1; } - // read in the config header - if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; } - // negative vocab size is hacky way of signaling unshared weights. bit yikes. - int shared_weights = config.vocab_size > 0 ? 1 : 0; - config.vocab_size = abs(config.vocab_size); - // figure out the file size - fseek(file, 0, SEEK_END); // move file pointer to end of file - file_size = ftell(file); // get the file size, in bytes - fclose(file); - // memory map the Transformer weights into the data pointer - fd = open(checkpoint, O_RDONLY); // open in read only mode - if (fd == -1) { fprintf(stderr, "open failed!\n"); return 1; } - data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0); - if (data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); return 1; } - float* weights_ptr = data + sizeof(Config)/sizeof(float); - checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights); - } + ssize_t file_size; // size of the checkpoint file in bytes + read_checkpoint(checkpoint, &config, &weights, &fd, &data, &file_size); + // right now we cannot run for more than config.seq_len steps if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } From f3db92a2dc8fc928f7877f7dac4c4c5d98a9f7ff Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 20 Aug 2023 16:51:35 +0000 Subject: [PATCH 05/21] use out_file.tell() instead of nbytes += arithmetic --- export.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/export.py b/export.py index 8d431568..ffcb5064 100644 --- a/export.py +++ b/export.py @@ -124,13 +124,10 @@ def version1_export(model, filepath): out_file = open(filepath, 'wb') # first write out the header. the header will be 256 bytes - nbytes = 0 # 1) write magic, which will be uint32 of "ak42" in ASCII out_file.write(struct.pack('I', 0x616b3432)) - nbytes += 4 # 2) write version, which will be int out_file.write(struct.pack('i', version)) - nbytes += 4 # 3) write the params, which will be 7 ints p = model.params hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] @@ -138,12 +135,10 @@ def version1_export(model, filepath): header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, n_kv_heads, p.vocab_size, p.max_seq_len) out_file.write(header) - nbytes += 7*4 # 4) write some other flags shared_classifier = 1 # we do share a classifier, write flag as a byte out_file.write(struct.pack('B', shared_classifier)) - nbytes += 1 - pad = 256 - nbytes # pad the rest with zeros + pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos assert pad >= 0 out_file.write(b'\0' * pad) @@ -198,13 +193,10 @@ def version2_export(model, filepath, group_size=64): # write out_file = open(filepath, 'wb') # first write out the header. the header will be 256 bytes - nbytes = 0 # 1) write magic, which will be uint32 of "ak42" in ASCII out_file.write(struct.pack('I', 0x616b3432)) - nbytes += 4 # 2) write version, which will be int out_file.write(struct.pack('i', version)) - nbytes += 4 # 3) write the params, which will be 7 ints p = model.params hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] @@ -212,14 +204,11 @@ def version2_export(model, filepath, group_size=64): header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, n_kv_heads, p.vocab_size, p.max_seq_len) out_file.write(header) - nbytes += 7*4 # 4) write some other flags shared_classifier = 1 # we do share a classifier, write flag as a byte out_file.write(struct.pack('B', shared_classifier)) - nbytes += 1 out_file.write(struct.pack('i', group_size)) # group size used for quantization - nbytes += 4 - pad = 256 - nbytes # pad the rest with zeros + pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos assert pad >= 0 out_file.write(b'\0' * pad) # now that the header is done, let's write out the model From 13dcee493a727d60a8c4053d162ef085c08c2348 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 20 Aug 2023 17:02:22 +0000 Subject: [PATCH 06/21] todos update --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 8c36285d..3d7ba217 100644 --- a/README.md +++ b/README.md @@ -308,11 +308,11 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg ## unsorted todos -- make it easier to add a new dataset with not too much pain -- should calculate freq_cis online in the script run.c instead of loading them -- int4/8 quantization -- export the model in a more sensible output format with a proper header, etc. +- delete the export_meta_llama_bin.py and export_meta_llama_hf_bin.py files. instead, import both of these into a proper model.py Transformer instance, and then export using the export script as usual. +- migrate the code to work with the new versions export and deprecate the original .bin files - support Llama 2 7B Chat models and tune run.c to Chat UI/UX +- make it easier to add a new dataset with not too much pain +- int8 quantization - llama2.cu investigate and merge - (LoRA) finetuning and export of Llama 2 models From c0511de61716325b3dadc86b39ce69c12f2d8b22 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 20 Aug 2023 17:18:06 +0000 Subject: [PATCH 07/21] probindex should never have been part of RunState. i apologize for this failure of abstraction --- run.c | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/run.c b/run.c index 59b8b295..614f18b6 100644 --- a/run.c +++ b/run.c @@ -50,11 +50,6 @@ typedef struct { float* wcls; } TransformerWeights; -typedef struct { - float prob; - int index; -} ProbIndex; // struct used when sorting probabilities during top-p sampling - typedef struct { // current wave of activations float *x; // activation at current time stamp (dim,) @@ -67,7 +62,6 @@ typedef struct { float *v; // value (dim,) float *att; // buffer for scores/attention values (n_heads, seq_len) float *logits; // output logits - ProbIndex *probindex; // buffer used in top-p sampling // kv cache float* key_cache; // (layer, seq_len, dim) float* value_cache; // (layer, seq_len, dim) @@ -86,13 +80,12 @@ void malloc_run_state(RunState* s, Config* p) { s->v = calloc(kv_dim, sizeof(float)); s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); s->logits = calloc(p->vocab_size, sizeof(float)); - s->probindex = calloc(p->vocab_size, sizeof(ProbIndex)); s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); // ensure all mallocs went fine if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q || !s->k || !s->v || !s->att || !s->logits || !s->key_cache - || !s->value_cache || !s->probindex) { + || !s->value_cache) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); } @@ -109,7 +102,6 @@ void free_run_state(RunState* s) { free(s->v); free(s->att); free(s->logits); - free(s->probindex); free(s->key_cache); free(s->value_cache); } @@ -499,6 +491,11 @@ float random_f32() { // random float32 in [0,1) // ---------------------------------------------------------------------------- // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling +typedef struct { + float prob; + int index; +} ProbIndex; // struct used when sorting probabilities during top-p sampling + int argmax(float* probabilities, int n) { // return the index that has the highest probability int max_i = 0; @@ -654,6 +651,7 @@ int main(int argc, char *argv[]) { // create and init the application RunState RunState state; malloc_run_state(&state, &config); + ProbIndex *probindex = malloc(config.vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling // process the prompt, if any int *prompt_tokens = NULL; @@ -693,7 +691,7 @@ int main(int argc, char *argv[]) { next = sample(state.logits, config.vocab_size); } else { // top-p (nucleus) sampling, clamping the least likely tokens to zero - next = sample_topp(state.logits, config.vocab_size, topp, state.probindex); + next = sample_topp(state.logits, config.vocab_size, topp, probindex); } } } @@ -734,6 +732,7 @@ int main(int argc, char *argv[]) { // memory and file handles cleanup free_run_state(&state); + free(probindex); for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); } free(vocab); free(vocab_scores); From 1e335a41cfc1b34c37aed6ff5074ae6533dd7084 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 20 Aug 2023 17:26:43 +0000 Subject: [PATCH 08/21] remove freq_cis fields as they are not used anymore --- run.c | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/run.c b/run.c index 614f18b6..b8a1839f 100644 --- a/run.c +++ b/run.c @@ -43,9 +43,6 @@ typedef struct { float* w3; // (layer, hidden_dim, dim) // final rmsnorm float* rms_final_weight; // (dim,) - // freq_cis for RoPE relatively positional embeddings (not used anymore) - float* freq_cis_real; // (seq_len, head_size/2) - float* freq_cis_imag; // (seq_len, head_size/2) // (optional) classifier weights for the logits, on the last layer float* wcls; } TransformerWeights; @@ -133,10 +130,8 @@ void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int s ptr += p->n_layers * p->dim * p->hidden_dim; w->rms_final_weight = ptr; ptr += p->dim; - w->freq_cis_real = ptr; - ptr += p->seq_len * head_size / 2; - w->freq_cis_imag = ptr; - ptr += p->seq_len * head_size / 2; + ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE) + ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE) w->wcls = shared_weights ? w->token_embedding_table : ptr; } From c74456f3f084c73a2865f758e341f4cfe5b54a87 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 20 Aug 2023 18:18:23 +0000 Subject: [PATCH 09/21] refactor step 1. the tokenizer, and all the other abstractions, are a total mess, refactoring things a bit --- run.c | 136 ++++++++++++++++++++++++++++++++++------------------------ 1 file changed, 81 insertions(+), 55 deletions(-) diff --git a/run.c b/run.c index b8a1839f..1c14563c 100644 --- a/run.c +++ b/run.c @@ -341,7 +341,62 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* } // ---------------------------------------------------------------------------- -// byte pair encoding (BPE) tokenizer, encodes strings into tokens so we can prompt +// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens + +typedef struct { + char** vocab; + float* vocab_scores; + int vocab_size; + unsigned int max_token_length; + char byte_piece[2]; +} Tokenizer; + +void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) { + // i should have written the vocab_size into the tokenizer file... sigh + t->vocab_size = vocab_size; + // malloc space to hold the scores and the strings + t->vocab = (char**)malloc(vocab_size * sizeof(char*)); + t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); + t->byte_piece[1] = '\0'; // null terminate the byte_piece string + // read in the file + FILE *file = fopen(tokenizer, "rb"); + if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); } + if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } + int len; + for (int i = 0; i < vocab_size; i++) { + if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);} + if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } + t->vocab[i] = (char *)malloc(len + 1); + if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } + t->vocab[i][len] = '\0'; // add the string terminating token + } + fclose(file); +} + +void free_tokenizer(Tokenizer* t) { + for (int i = 0; i < t->vocab_size; i++) { + free(t->vocab[i]); + } + free(t->vocab); + free(t->vocab_scores); +} + +char* get_piece(Tokenizer* t, int prev_token, int token) { + char *piece = t->vocab[token]; + // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) + if (prev_token == 1 && piece[0] == ' ') { piece++; } + // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' + unsigned char byte_val; + if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) { + // ok this token is a raw byte token, careful to only print printable chars or whitespace + // some of the other bytes can be various control codes, backspace, etc. => skip + if (isprint(byte_val) || isspace(byte_val)) { + t->byte_piece[0] = byte_val; + piece = &t->byte_piece[0]; + } + } + return piece; +} typedef struct { char *str; @@ -359,22 +414,23 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { return res != NULL ? res->id : -1; } -void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, unsigned int max_token_length, int *tokens, int *n_tokens) { +void bpe_encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { + // encode the string text (input) into an upper-bound preallocated tokens[] array // sort vocabulary - TokenIndex *sorted_vocab = malloc(vocab_size * sizeof(TokenIndex)); - for (int i = 0; i < vocab_size; i++) { - sorted_vocab[i].str = vocab[i]; + TokenIndex *sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex)); + for (int i = 0; i < t->vocab_size; i++) { + sorted_vocab[i].str = t->vocab[i]; sorted_vocab[i].id = i; } - qsort(sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); + qsort(sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens); // create a temporary buffer that will store merge candidates of always two consecutive tokens - char* str_buffer = malloc((max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1) + char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1) size_t str_len = 0; // add_dummy_prefix is true by default - tokens[0] = str_lookup(" ", sorted_vocab, vocab_size); + tokens[0] = str_lookup(" ", sorted_vocab, t->vocab_size); *n_tokens = 1; // the number of tokens // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia: @@ -410,7 +466,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u } // ok c+1 is not a continuation byte, so we've read in a full codepoint - int id = str_lookup(str_buffer, sorted_vocab, vocab_size); + int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size); if (id != -1) { // we found this codepoint in vocab, add it as a token @@ -434,11 +490,11 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u for (int i=0; i < (*n_tokens-1); i++) { // check if we can merge the pair (tokens[i], tokens[i+1]) - sprintf(str_buffer, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]); - int id = str_lookup(str_buffer, sorted_vocab, vocab_size); - if (id != -1 && vocab_scores[id] > best_score) { + sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]); + int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size); + if (id != -1 && t->vocab_scores[id] > best_score) { // this merge pair exists in vocab! record its score and position - best_score = vocab_scores[id]; + best_score = t->vocab_scores[id]; best_id = id; best_idx = i; } @@ -587,8 +643,8 @@ void error_usage() { int main(int argc, char *argv[]) { // default inits - char *checkpoint = NULL; // e.g. out/model.bin - char *tokenizer = "tokenizer.bin"; + char *checkpoint_path = NULL; // e.g. out/model.bin + char *tokenizer_path = "tokenizer.bin"; float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower rng_seed = 0; // seed rng with time by default @@ -596,7 +652,7 @@ int main(int argc, char *argv[]) { char *prompt = NULL; // prompt string // poor man's C argparse so we can override the defaults above from the command line - if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); } + if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); } for (int i = 2; i < argc; i+=2) { // do some basic validation if (i + 1 >= argc) { error_usage(); } // must have arg after flag @@ -608,7 +664,7 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); } else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); } else if (argv[i][1] == 'i') { prompt = argv[i + 1]; } - else if (argv[i][1] == 'z') { tokenizer = argv[i + 1]; } + else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; } else { error_usage(); } } if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);} @@ -619,29 +675,14 @@ int main(int argc, char *argv[]) { int fd = 0; // file descriptor for memory mapping float* data = NULL; // memory mapped data pointer ssize_t file_size; // size of the checkpoint file in bytes - read_checkpoint(checkpoint, &config, &weights, &fd, &data, &file_size); + read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size); // right now we cannot run for more than config.seq_len steps if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } // read in the tokenizer .bin file - char** vocab = (char**)malloc(config.vocab_size * sizeof(char*)); - float* vocab_scores = (float*)malloc(config.vocab_size * sizeof(float)); - unsigned int max_token_length; - { - FILE *file = fopen(tokenizer, "rb"); - if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); return 1; } - if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; } - int len; - for (int i = 0; i < config.vocab_size; i++) { - if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1;} - if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; } - vocab[i] = (char *)malloc(len + 1); - if (fread(vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; } - vocab[i][len] = '\0'; // add the string terminating token - } - fclose(file); - } + Tokenizer tokenizer; + build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size); // create and init the application RunState RunState state; @@ -653,7 +694,7 @@ int main(int argc, char *argv[]) { int num_prompt_tokens = 0; if (prompt != NULL) { prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); - bpe_encode(prompt, vocab, vocab_scores, config.vocab_size, max_token_length, prompt_tokens, &num_prompt_tokens); + bpe_encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens); } // start the main loop @@ -695,22 +736,9 @@ int main(int argc, char *argv[]) { // data-dependent terminating condition: the BOS (1) token delimits sequences if (next == 1) { break; } - // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) - char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; - // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' - unsigned char byte_val; - if (sscanf(token_str, "<0x%02hhX>", &byte_val) == 1) { - // ok this token is a raw byte token, carefuly to only print printable chars or whitespace - // some of the other bytes can be various control codes, backspace, etc. => skip - if (isprint(byte_val) || isspace(byte_val)) { - char byte_piece[2]; - byte_piece[0] = byte_val; - byte_piece[1] = '\0'; - printf("%s", byte_piece); - } - } else { - printf("%s", token_str); - } + // print the token as string, decode it with the Tokenizer object + char* piece = get_piece(&tokenizer, token, next); + printf("%s", piece); fflush(stdout); token = next; @@ -728,9 +756,7 @@ int main(int argc, char *argv[]) { // memory and file handles cleanup free_run_state(&state); free(probindex); - for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); } - free(vocab); - free(vocab_scores); + free_tokenizer(&tokenizer); if (prompt_tokens != NULL) free(prompt_tokens); if (data != MAP_FAILED) munmap(data, file_size); if (fd != -1) close(fd); From a72b3b0206de4a05b483bc67d3a5149cb5d2fa00 Mon Sep 17 00:00:00 2001 From: Harry Gifford Date: Sun, 20 Aug 2023 15:01:33 -0700 Subject: [PATCH 10/21] Update readme with suggestion on number of threads to use Update the documentation to make suggestions on the number of threads. The performance difference can be very large. Also linked to the PyTorch docs which are relevant here. --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2fd355d1..7d3393b9 100644 --- a/README.md +++ b/README.md @@ -217,7 +217,8 @@ When you run inference make sure to use OpenMP flags to set the number of thread OMP_NUM_THREADS=4 ./run out/model.bin ``` -Depending on your system resources you may want to tweak these hyperparameters and use more threads. But more is not always better, usually this is a bit U shaped. +Depending on your system resources you may want to tweak these hyperparameters and use more threads. But more is not always better, usually this is a bit U shaped. In particular, if your CPU has SMT (multithreading), try setting the number of threads to the number of physical cores rather than logical cores. The performance difference can be large due to cache thrashing and communication overhead. The PyTorch documentation [CPU specific optimizations +](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#cpu-specific-optimizations) has some good information that applies here too. ## platforms From 09db52c69ec7dd485618f899f854b8183eb0e87b Mon Sep 17 00:00:00 2001 From: atamyrat Date: Mon, 21 Aug 2023 02:53:50 +0300 Subject: [PATCH 11/21] Added huggingface model loader to export.py --- export.py | 106 +++++++++++++++++++++++++++++++++++++++-------- model.py | 3 +- requirements.txt | 1 + 3 files changed, 91 insertions(+), 19 deletions(-) diff --git a/export.py b/export.py index ffcb5064..6fff7f5d 100644 --- a/export.py +++ b/export.py @@ -16,8 +16,9 @@ """ import struct import argparse -import torch import numpy as np +import torch +from torch import nn from model import ModelArgs, Transformer @@ -72,6 +73,10 @@ def legacy_export(model, filepath): # first write out the header hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] p = model.params + shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight) + # legacy format uses negative/positive vocab size as a shared classifier flag + if not shared_classifier: + p.vocab_size = -p.vocab_size n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, n_kv_heads, p.vocab_size, p.max_seq_len) @@ -103,11 +108,14 @@ def legacy_export(model, filepath): serialize_fp32(out_file, layer.feed_forward.w3.weight) # final rmsnorm serialize_fp32(out_file, model.norm.weight) - # note: no need to write final classifier weights due to weight sharing # freqs_cis serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len]) serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len]) + # final classifier weights + if not shared_classifier: + serialize_fp32(out_file, model.output.weight) + # write to binary file out_file.close() print(f"wrote {filepath}") @@ -136,8 +144,8 @@ def version1_export(model, filepath): n_kv_heads, p.vocab_size, p.max_seq_len) out_file.write(header) # 4) write some other flags - shared_classifier = 1 # we do share a classifier, write flag as a byte - out_file.write(struct.pack('B', shared_classifier)) + shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight) + out_file.write(struct.pack('B', int(shared_classifier))) pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos assert pad >= 0 out_file.write(b'\0' * pad) @@ -156,6 +164,8 @@ def version1_export(model, filepath): *[layer.feed_forward.w2.weight for layer in model.layers], *[layer.feed_forward.w3.weight for layer in model.layers], ] + if not shared_classifier: + weights.append(model.output.weight) for w in weights: serialize_fp32(out_file, w) @@ -187,6 +197,9 @@ def version2_export(model, filepath, group_size=64): *[layer.feed_forward.w2.weight for layer in model.layers], *[layer.feed_forward.w3.weight for layer in model.layers], ] + shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight) + if not shared_classifier: + weights.append(model.output.weight) for w in weights: assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}" @@ -205,8 +218,7 @@ def version2_export(model, filepath, group_size=64): n_kv_heads, p.vocab_size, p.max_seq_len) out_file.write(header) # 4) write some other flags - shared_classifier = 1 # we do share a classifier, write flag as a byte - out_file.write(struct.pack('B', shared_classifier)) + out_file.write(struct.pack('B', int(shared_classifier))) out_file.write(struct.pack('i', group_size)) # group size used for quantization pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos assert pad >= 0 @@ -247,6 +259,68 @@ def version2_export(model, filepath, group_size=64): out_file.close() print(f"wrote {filepath}") + +# ----------------------------------------------------------------------------- +# Load / import functions + +def load_checkpoint(checkpoint): + + # load the provided model checkpoint + checkpoint_dict = torch.load(checkpoint, map_location='cpu') + gptconf = ModelArgs(**checkpoint_dict['model_args']) + model = Transformer(gptconf) + state_dict = checkpoint_dict['model'] + unwanted_prefix = '_orig_mod.' + for k,v in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) + model.load_state_dict(state_dict, strict=False) + model.eval() + return model + +def load_hf_model(model_path): + + from transformers import AutoModelForCausalLM + + # load HF model + hf_model = AutoModelForCausalLM.from_pretrained(model_path) + hf_dict = hf_model.state_dict() + + # convert LlamaConfig to ModelArgs + config = ModelArgs() + config.dim = hf_model.config.hidden_size + config.n_layers = hf_model.config.num_hidden_layers + config.n_heads = hf_model.config.num_attention_heads + config.n_kv_heads = hf_model.config.num_attention_heads + config.vocab_size = hf_model.config.vocab_size + config.hidden_dim = hf_model.config.intermediate_size + config.norm_eps = hf_model.config.rms_norm_eps + config.max_seq_len = hf_model.config.max_position_embeddings + + # create a new Transformer object and set weights + model = Transformer(config) + + model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight']) + model.norm.weight = nn.Parameter(hf_dict['model.norm.weight']) + + for layer in model.layers: + i = layer.layer_id + layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight']) + layer.attention.wq.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']) + layer.attention.wk.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']) + layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight']) + layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight']) + layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight']) + layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight']) + layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight']) + layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight']) + + # final classifier + model.output.weight = nn.Parameter(hf_dict['lm_head.weight']) + model.eval() + return model + + # ----------------------------------------------------------------------------- # API entrypoint @@ -267,21 +341,17 @@ def model_export(model, filepath, version): parser = argparse.ArgumentParser() parser.add_argument("filepath", type=str, help="the output filepath") - parser.add_argument("--checkpoint", default="", type=str, help="model checkpoint, .pt file") + parser.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file") + parser.add_argument("--hf", type=str, help="huggingface model") parser.add_argument("--version", default=0, type=int, help="the version to export with") args = parser.parse_args() - # load the provided model checkpoint - checkpoint_dict = torch.load(args.checkpoint, map_location='cpu') - gptconf = ModelArgs(**checkpoint_dict['model_args']) - model = Transformer(gptconf) - state_dict = checkpoint_dict['model'] - unwanted_prefix = '_orig_mod.' - for k,v in list(state_dict.items()): - if k.startswith(unwanted_prefix): - state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) - model.load_state_dict(state_dict, strict=False) - model.eval() + if args.checkpoint: + model = load_checkpoint(args.checkpoint) + elif args.hf: + model = load_hf_model(args.hf) + else: + parser.error("Input model missing: --checkpoint or --hf is required") # export model_export(model, args.filepath, args.version) diff --git a/model.py b/model.py index 044712f2..09e6aa50 100644 --- a/model.py +++ b/model.py @@ -17,6 +17,7 @@ class ModelArgs: n_heads: int = 32 n_kv_heads: Optional[int] = None vocab_size: int = 32000 + hidden_dim: int = (4 * 4096) multiple_of: int = 256 # MLP hidden layer size will be multiple of norm_eps: float = 1e-5 max_seq_len: int = 2048 @@ -186,7 +187,7 @@ def __init__(self, layer_id: int, args: ModelArgs): self.attention = Attention(args) self.feed_forward = FeedForward( dim=args.dim, - hidden_dim=4 * args.dim, + hidden_dim=args.hidden_dim, multiple_of=args.multiple_of, dropout=args.dropout, ) diff --git a/requirements.txt b/requirements.txt index 7187a737..b4054e14 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ sentencepiece==0.1.99 torch==2.0.1 tqdm==4.64.1 wandb==0.15.5 +transformers==4.31.0 From d7704bdeaa7142300edeabf179e6b1e41c637608 Mon Sep 17 00:00:00 2001 From: atamyrat Date: Mon, 21 Aug 2023 03:40:34 +0300 Subject: [PATCH 12/21] mark ModelArgs.hidden_dim as optional and calculate as previously if not provided --- model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/model.py b/model.py index 09e6aa50..9e4ce220 100644 --- a/model.py +++ b/model.py @@ -17,7 +17,7 @@ class ModelArgs: n_heads: int = 32 n_kv_heads: Optional[int] = None vocab_size: int = 32000 - hidden_dim: int = (4 * 4096) + hidden_dim: Optional[int] = None multiple_of: int = 256 # MLP hidden layer size will be multiple of norm_eps: float = 1e-5 max_seq_len: int = 2048 @@ -167,8 +167,10 @@ def forward( class FeedForward(nn.Module): def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float): super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + if hidden_dim is None: + hidden_dim = 4 * dim + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) From 155475a5235a3ad492015f856f57f0f7a61f8686 Mon Sep 17 00:00:00 2001 From: atamyrat Date: Mon, 21 Aug 2023 05:16:11 +0300 Subject: [PATCH 13/21] Fix WQ and WK permutation in huggingface models --- export.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/export.py b/export.py index 6fff7f5d..7fb43669 100644 --- a/export.py +++ b/export.py @@ -303,11 +303,15 @@ def load_hf_model(model_path): model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight']) model.norm.weight = nn.Parameter(hf_dict['model.norm.weight']) + # huggingface permutes WQ and WK, this function reverses it + def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim): + return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + for layer in model.layers: i = layer.layer_id layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight']) - layer.attention.wq.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']) - layer.attention.wk.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']) + layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight'])) + layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight'])) layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight']) layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight']) layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight']) From 0dd82158f6b058409a6001647f88942d7a89e7b2 Mon Sep 17 00:00:00 2001 From: atamyrat Date: Mon, 21 Aug 2023 06:07:29 +0300 Subject: [PATCH 14/21] removed transformers from requirements.txt, added error message --- export.py | 10 +++++++++- requirements.txt | 1 - 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/export.py b/export.py index 7fb43669..d909c9f5 100644 --- a/export.py +++ b/export.py @@ -280,7 +280,12 @@ def load_checkpoint(checkpoint): def load_hf_model(model_path): - from transformers import AutoModelForCausalLM + try: + from transformers import AutoModelForCausalLM + except ImportError: + print("Error: transformers package is required to load huggingface models") + print("Please run `pip install transformers` to install it") + return None # load HF model hf_model = AutoModelForCausalLM.from_pretrained(model_path) @@ -357,5 +362,8 @@ def model_export(model, filepath, version): else: parser.error("Input model missing: --checkpoint or --hf is required") + if model is None: + parser.error("Can't load input model!") + # export model_export(model, args.filepath, args.version) diff --git a/requirements.txt b/requirements.txt index b4054e14..7187a737 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,3 @@ sentencepiece==0.1.99 torch==2.0.1 tqdm==4.64.1 wandb==0.15.5 -transformers==4.31.0 From ae2e4f8d88366f3c01f66d553311ff23718500ef Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 21 Aug 2023 03:11:54 +0000 Subject: [PATCH 15/21] name the tokenizer methods cleaner: encode and decode --- run.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/run.c b/run.c index 1c14563c..314c7001 100644 --- a/run.c +++ b/run.c @@ -381,7 +381,7 @@ void free_tokenizer(Tokenizer* t) { free(t->vocab_scores); } -char* get_piece(Tokenizer* t, int prev_token, int token) { +char* decode(Tokenizer* t, int prev_token, int token) { char *piece = t->vocab[token]; // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) if (prev_token == 1 && piece[0] == ' ') { piece++; } @@ -414,7 +414,7 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { return res != NULL ? res->id : -1; } -void bpe_encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { +void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { // encode the string text (input) into an upper-bound preallocated tokens[] array // sort vocabulary @@ -694,7 +694,7 @@ int main(int argc, char *argv[]) { int num_prompt_tokens = 0; if (prompt != NULL) { prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); - bpe_encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens); + encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens); } // start the main loop @@ -737,7 +737,7 @@ int main(int argc, char *argv[]) { if (next == 1) { break; } // print the token as string, decode it with the Tokenizer object - char* piece = get_piece(&tokenizer, token, next); + char* piece = decode(&tokenizer, token, next); printf("%s", piece); fflush(stdout); token = next; From 8a377a1d3110875ce3d6fdeda31a86489303b12a Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 21 Aug 2023 03:55:12 +0000 Subject: [PATCH 16/21] refactor the Transformer (Config, Weights, RunState) into a single object, with build and free too --- run.c | 95 +++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 54 insertions(+), 41 deletions(-) diff --git a/run.c b/run.c index 314c7001..12425968 100644 --- a/run.c +++ b/run.c @@ -14,7 +14,7 @@ #include #endif // ---------------------------------------------------------------------------- -// Transformer and RunState structs, and related memory management +// Transformer model typedef struct { int dim; // transformer dimension @@ -64,6 +64,16 @@ typedef struct { float* value_cache; // (layer, seq_len, dim) } RunState; +typedef struct { + Config config; // the hyperparameters of the architecture (the blueprint) + TransformerWeights weights; // the weights of the model + RunState state; // buffers for the "wave" of activations in the forward pass + // some more state needed to properly clean up the memory mapping (sigh) + int fd; // file descriptor for memory mapping + float* data; // memory mapped data pointer + ssize_t file_size; // size of the checkpoint file in bytes +} Transformer; + void malloc_run_state(RunState* s, Config* p) { // we calloc instead of malloc to keep valgrind happy int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; @@ -103,10 +113,7 @@ void free_run_state(RunState* s) { free(s->value_cache); } -// ---------------------------------------------------------------------------- -// initialization: read from checkpoint - -void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) { +void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) { int head_size = p->dim / p->n_heads; w->token_embedding_table = ptr; ptr += p->vocab_size * p->dim; @@ -154,11 +161,26 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0); if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); } float* weights_ptr = *data + sizeof(Config)/sizeof(float); - checkpoint_init_weights(weights, config, weights_ptr, shared_weights); + memory_map_weights(weights, config, weights_ptr, shared_weights); +} + +void build_transformer(char* checkpoint_path, Transformer *t) { + // read in the Config and the Weights from the checkpoint + read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size); + // allocate the RunState buffers + malloc_run_state(&t->state, &t->config); +} + +void free_transformer(Transformer* t) { + // close the memory mapping + if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); } + if (t->fd != -1) { close(t->fd); } + // free the RunState buffers + free_run_state(&t->state); } // ---------------------------------------------------------------------------- -// neural net blocks +// neural net blocks; the dynamics of the Transformer void rmsnorm(float* o, float* x, float* weight, int size) { // calculate sum of squares @@ -209,9 +231,12 @@ void matmul(float* xout, float* x, float* w, int n, int d) { } } -void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) { +float* forward(Transformer* transformer, int token, int pos) { // a few convenience variables + Config* p = &transformer->config; + TransformerWeights* w = &transformer->weights; + RunState* s = &transformer->state; float *x = s->x; int dim = p->dim; int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads; @@ -338,6 +363,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* // classifier into logits matmul(s->logits, x, w->wcls, p->dim, p->vocab_size); + return s->logits; } // ---------------------------------------------------------------------------- @@ -351,7 +377,7 @@ typedef struct { char byte_piece[2]; } Tokenizer; -void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) { +void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) { // i should have written the vocab_size into the tokenizer file... sigh t->vocab_size = vocab_size; // malloc space to hold the scores and the strings @@ -359,8 +385,8 @@ void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) { t->vocab_scores = (float*)malloc(vocab_size * sizeof(float)); t->byte_piece[1] = '\0'; // null terminate the byte_piece string // read in the file - FILE *file = fopen(tokenizer, "rb"); - if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); } + FILE *file = fopen(tokenizer_path, "rb"); + if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); } if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } int len; for (int i = 0; i < vocab_size; i++) { @@ -374,9 +400,7 @@ void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) { } void free_tokenizer(Tokenizer* t) { - for (int i = 0; i < t->vocab_size; i++) { - free(t->vocab[i]); - } + for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); } free(t->vocab); free(t->vocab_scores); } @@ -667,28 +691,19 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; } else { error_usage(); } } - if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);} - - // read in the model.bin file - Config config; - TransformerWeights weights; - int fd = 0; // file descriptor for memory mapping - float* data = NULL; // memory mapped data pointer - ssize_t file_size; // size of the checkpoint file in bytes - read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size); + if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);} - // right now we cannot run for more than config.seq_len steps - if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; } + // build the Transformer via the model .bin file + Transformer transformer; + build_transformer(checkpoint_path, &transformer); + int vocab_size = transformer.config.vocab_size; // convenience copy - // read in the tokenizer .bin file + // build the Tokenizer via the tokenizer .bin file Tokenizer tokenizer; - build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size); + build_tokenizer(tokenizer_path, &tokenizer, vocab_size); // create and init the application RunState - RunState state; - malloc_run_state(&state, &config); - ProbIndex *probindex = malloc(config.vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling - + ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling // process the prompt, if any int *prompt_tokens = NULL; int num_prompt_tokens = 0; @@ -705,7 +720,7 @@ int main(int argc, char *argv[]) { while (pos < steps) { // forward the transformer to get logits for the next token - transformer(token, pos, &config, &state, &weights); + float* logits = forward(&transformer, token, pos); // advance the state state machine if(pos < num_prompt_tokens) { @@ -715,19 +730,19 @@ int main(int argc, char *argv[]) { // sample the next token if (temperature == 0.0f) { // greedy argmax sampling: take the token with the highest probability - next = argmax(state.logits, config.vocab_size); + next = argmax(logits, vocab_size); } else { // apply the temperature to the logits - for (int q=0; q= 1) { // simply sample from the predicted probability distribution - next = sample(state.logits, config.vocab_size); + next = sample(logits, vocab_size); } else { // top-p (nucleus) sampling, clamping the least likely tokens to zero - next = sample_topp(state.logits, config.vocab_size, topp, probindex); + next = sample_topp(logits, vocab_size, topp, probindex); } } } @@ -754,11 +769,9 @@ int main(int argc, char *argv[]) { } // memory and file handles cleanup - free_run_state(&state); free(probindex); + if (prompt_tokens != NULL) { free(prompt_tokens); } free_tokenizer(&tokenizer); - if (prompt_tokens != NULL) free(prompt_tokens); - if (data != MAP_FAILED) munmap(data, file_size); - if (fd != -1) close(fd); + free_transformer(&transformer); return 0; } From 3868f732a43aed3290dc855fddea31f6d0e43ec1 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 21 Aug 2023 04:23:02 +0000 Subject: [PATCH 17/21] and finally refactor the Sampler. things are starting to look a lot cleaner I think --- run.c | 114 +++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 70 insertions(+), 44 deletions(-) diff --git a/run.c b/run.c index 12425968..b7b61836 100644 --- a/run.c +++ b/run.c @@ -164,7 +164,7 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh memory_map_weights(weights, config, weights_ptr, shared_weights); } -void build_transformer(char* checkpoint_path, Transformer *t) { +void build_transformer(Transformer *t, char* checkpoint_path) { // read in the Config and the Weights from the checkpoint read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size); // allocate the RunState buffers @@ -377,7 +377,7 @@ typedef struct { char byte_piece[2]; } Tokenizer; -void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) { +void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) { // i should have written the vocab_size into the tokenizer file... sigh t->vocab_size = vocab_size; // malloc space to hold the scores and the strings @@ -542,15 +542,21 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { } // ---------------------------------------------------------------------------- -// utilities: time / rng +// The Sampler, which takes logits and returns a sampled token +// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling -long time_in_ms() { - // return time in milliseconds, for benchmarking the model speed - struct timespec time; - clock_gettime(CLOCK_REALTIME, &time); - return time.tv_sec * 1000 + time.tv_nsec / 1000000; -} +typedef struct { + float prob; + int index; +} ProbIndex; // struct used when sorting probabilities during top-p sampling +typedef struct { + int vocab_size; + ProbIndex* probindex; // buffer used in top-p sampling +} Sampler; + +// rng should technically be a state variable of the Sampler +// leaving it global here for now for convenience, maybe move later unsigned long long rng_seed; unsigned int random_u32() { // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A @@ -563,15 +569,7 @@ float random_f32() { // random float32 in [0,1) return (random_u32() >> 8) / 16777216.0f; } -// ---------------------------------------------------------------------------- -// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling - -typedef struct { - float prob; - int index; -} ProbIndex; // struct used when sorting probabilities during top-p sampling - -int argmax(float* probabilities, int n) { +int sample_argmax(float* probabilities, int n) { // return the index that has the highest probability int max_i = 0; float max_p = probabilities[0]; @@ -584,7 +582,7 @@ int argmax(float* probabilities, int n) { return max_i; } -int sample(float* probabilities, int n) { +int sample_mult(float* probabilities, int n) { // sample index from probabilities (they must sum to 1!) float r = random_f32(); float cdf = 0.0f; @@ -647,6 +645,48 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) { return probindex[last_idx].index; // in case of rounding errors } +void build_sampler(Sampler* sampler, int vocab_size) { + sampler->vocab_size = vocab_size; + // probindex might not be needed, but it's a ~small buffer so we'll just malloc it + sampler->probindex = malloc(vocab_size * sizeof(ProbIndex)); +} + +void free_sampler(Sampler* sampler) { + free(sampler->probindex); +} + +int sample(Sampler* sampler, float* logits, float temperature, float topp) { + // sample the token given the logits and some hyperparameters + int next; + if (temperature == 0.0f) { + // greedy argmax sampling: take the token with the highest probability + next = sample_argmax(logits, sampler->vocab_size); + } else { + // apply the temperature to the logits + for (int q=0; qvocab_size; q++) { logits[q] /= temperature; } + // apply softmax to the logits to get the probabilities for next token + softmax(logits, sampler->vocab_size); + // we sample from this distribution to get the next token + if (topp <= 0 || topp >= 1) { + // simply sample from the predicted probability distribution + next = sample_mult(logits, sampler->vocab_size); + } else { + // top-p (nucleus) sampling, clamping the least likely tokens to zero + next = sample_topp(logits, sampler->vocab_size, topp, sampler->probindex); + } + } + return next; +} + +// ---------------------------------------------------------------------------- +// utilities: time + +long time_in_ms() { + // return time in milliseconds, for benchmarking the model speed + struct timespec time; + clock_gettime(CLOCK_REALTIME, &time); + return time.tv_sec * 1000 + time.tv_nsec / 1000000; +} // ---------------------------------------------------------------------------- // int main @@ -695,16 +735,18 @@ int main(int argc, char *argv[]) { // build the Transformer via the model .bin file Transformer transformer; - build_transformer(checkpoint_path, &transformer); + build_transformer(&transformer, checkpoint_path); int vocab_size = transformer.config.vocab_size; // convenience copy // build the Tokenizer via the tokenizer .bin file Tokenizer tokenizer; - build_tokenizer(tokenizer_path, &tokenizer, vocab_size); + build_tokenizer(&tokenizer, tokenizer_path, vocab_size); + + // build the Sampler + Sampler sampler; + build_sampler(&sampler, vocab_size); - // create and init the application RunState - ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling - // process the prompt, if any + // encode the (string) prompt into tokens sequence, if any is given int *prompt_tokens = NULL; int num_prompt_tokens = 0; if (prompt != NULL) { @@ -723,28 +765,12 @@ int main(int argc, char *argv[]) { float* logits = forward(&transformer, token, pos); // advance the state state machine - if(pos < num_prompt_tokens) { + if (pos < num_prompt_tokens) { // if we are still processing the input prompt, force the next prompt token next = prompt_tokens[pos]; } else { - // sample the next token - if (temperature == 0.0f) { - // greedy argmax sampling: take the token with the highest probability - next = argmax(logits, vocab_size); - } else { - // apply the temperature to the logits - for (int q=0; q= 1) { - // simply sample from the predicted probability distribution - next = sample(logits, vocab_size); - } else { - // top-p (nucleus) sampling, clamping the least likely tokens to zero - next = sample_topp(logits, vocab_size, topp, probindex); - } - } + // otherwise sample the next token from the logits + next = sample(&sampler, logits, temperature, topp); } pos++; @@ -769,8 +795,8 @@ int main(int argc, char *argv[]) { } // memory and file handles cleanup - free(probindex); if (prompt_tokens != NULL) { free(prompt_tokens); } + free_sampler(&sampler); free_tokenizer(&tokenizer); free_transformer(&transformer); return 0; From 14275bd623df8ebb9ef8628df460db624b7940fa Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 21 Aug 2023 04:43:24 +0000 Subject: [PATCH 18/21] minor clean. i think a lot of chaos has been reduced for today. we shall now rest. --- run.c | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/run.c b/run.c index b7b61836..b1a278da 100644 --- a/run.c +++ b/run.c @@ -736,19 +736,18 @@ int main(int argc, char *argv[]) { // build the Transformer via the model .bin file Transformer transformer; build_transformer(&transformer, checkpoint_path); - int vocab_size = transformer.config.vocab_size; // convenience copy // build the Tokenizer via the tokenizer .bin file Tokenizer tokenizer; - build_tokenizer(&tokenizer, tokenizer_path, vocab_size); + build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size); // build the Sampler Sampler sampler; - build_sampler(&sampler, vocab_size); + build_sampler(&sampler, transformer.config.vocab_size); // encode the (string) prompt into tokens sequence, if any is given - int *prompt_tokens = NULL; - int num_prompt_tokens = 0; + int *prompt_tokens = NULL; // the sequence of prompt tokens + int num_prompt_tokens = 0; // the total number of prompt tokens if (prompt != NULL) { prompt_tokens = (int*)malloc((strlen(prompt)+1) * sizeof(int)); encode(&tokenizer, prompt, prompt_tokens, &num_prompt_tokens); From 288b3cec09ef7ed8e7d728c9632cd6bc3d62ae1e Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 21 Aug 2023 04:47:49 +0000 Subject: [PATCH 19/21] remove dagger in the eyeball --- run.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run.c b/run.c index b1a278da..fb11362f 100644 --- a/run.c +++ b/run.c @@ -245,7 +245,7 @@ float* forward(Transformer* transformer, int token, int pos) { int head_size = dim / p->n_heads; // copy the token embedding into x - float* content_row = &(w->token_embedding_table[token * dim]); + float* content_row = w->token_embedding_table + token * dim; memcpy(x, content_row, dim*sizeof(*x)); // forward all the layers From ea44f535682658f3c586719ef52fc985240461fe Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 21 Aug 2023 04:58:19 +0000 Subject: [PATCH 20/21] now that the export.py HF functionality is in master, we can delete this file, and update the readme --- README.md | 6 +- export_meta_llama_hf_bin.py | 113 ------------------------------------ 2 files changed, 5 insertions(+), 114 deletions(-) delete mode 100644 export_meta_llama_hf_bin.py diff --git a/README.md b/README.md index 7d3393b9..ff150054 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ For this we need to install the python dependencies (`pip install -r requirement python export_meta_llama_bin.py path/to/llama/model/7B llama2_7b.bin ``` -The export will take ~10 minutes or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. It has been [reported](https://github.com/karpathy/llama2.c/pull/85) that despite efforts, the 13B export currently doesn't work for unknown reasons (accepting PRs for fix). We can run the model as normal: +The export will take ~10 minutes or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. It has been [reported](https://github.com/karpathy/llama2.c/pull/85) that despite efforts. I would not attempt to run anything above 7B right now for two reasons: first, 13B+ currently doesn't work because of integer flow in pointer arithmetic, which is yet to be fixed, and second, even if it were fixed, this repo is doing float32 inference right now, so it would be fairly unusably slow. Once the export is done, we can run it: ```bash ./run llama2_7b.bin @@ -83,6 +83,10 @@ This ran at about 4 tokens/s compiled with [OpenMP](#OpenMP) on 96 threads on my base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should be possible to also inference the chat model quite easily, and have a conversation with it. And if we can find a way to run 7B more efficiently, we can start adding LoRA to our training script, and going wild with finetunes all within the repo! +## hugginface models + +We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file. + ## models For the sake of examples of smaller, from-scratch models, I trained a small model series on TinyStories. All of these trained in a few hours on my training setup (4X A100 40GB GPUs). The 110M took around 24 hours. I am hosting them on huggingface hub [tinyllamas](https://huggingface.co/karpathy/tinyllamas), both in the original PyTorch .pt, and also in the llama2.c format .bin: diff --git a/export_meta_llama_hf_bin.py b/export_meta_llama_hf_bin.py deleted file mode 100644 index e3a8c73b..00000000 --- a/export_meta_llama_hf_bin.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -This script exports the Llama 2 weights in llama2c.bin format. -""" -import os -import sys -import struct -from pathlib import Path -import json - -import torch - -from model import precompute_freqs_cis - - -def export(p, state_dict, filepath='model.bin'): - """export the model weights in fp32 into .bin file to be read from C""" - f = open(filepath, 'wb') - - def serialize(key): - print(f"writing {key}...") - t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy() - f.write(memoryview(t)) - del state_dict[key] - - # first write out the header - hidden_dim = state_dict['model.layers.0.mlp.gate_proj.weight'].shape[0] - p['vocab_size'] = 32000 - p['max_seq_len'] = 2048 - - n_kv_heads = p.get('n_kv_heads') or p['n_heads'] - header = struct.pack( - 'iiiiiii', - p['dim'], hidden_dim, p['n_layers'], p['n_heads'], - n_kv_heads, -p['vocab_size'], p['max_seq_len'] - ) - # NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present - # in the checkpoint and should be loaded. - f.write(header) - - # next write out the embedding weights - print("writing tok_embeddings...") - serialize('model.embed_tokens.weight') - - # now all the layers - # attention weights - for i in range(p['n_layers']): serialize(f'model.layers.{i}.input_layernorm.weight') - for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.q_proj.weight') - for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.k_proj.weight') - for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.v_proj.weight') - for i in range(p['n_layers']): serialize(f'model.layers.{i}.self_attn.o_proj.weight') - # ffn weights - for i in range(p['n_layers']): serialize(f'model.layers.{i}.post_attention_layernorm.weight') - for i in range(p['n_layers']): serialize(f'model.layers.{i}.mlp.gate_proj.weight') - for i in range(p['n_layers']): serialize(f'model.layers.{i}.mlp.down_proj.weight') - for i in range(p['n_layers']): serialize(f'model.layers.{i}.mlp.up_proj.weight') - - # final rmsnorm - serialize('model.norm.weight') - # freqs_cos, freqs_sin - freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2) - state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']] - state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']] - # check if this requires addtional conversion - serialize('freqs_cos') - serialize('freqs_sin') - - # finally write the output weights - serialize('lm_head.weight') - - f.close() - print(f"wrote {filepath}") - - -def concat_weights(models): - state_dict = {} - for name in list(models[0]): - tensors = [model[name] for model in models] - if len(tensors) == 1 or len(tensors[0].shape) == 1: - state_dict[name] = tensors[0] - continue - is_axis_1 = ( - name.startswith('model.embed_tokens.weight') - or name.endswith('.self_attn.o_proj.weight') - or name.endswith('.mlp.down_proj.weight') - ) - axis = 1 if is_axis_1 else 0 - state_dict[name] = torch.cat(tensors, dim=axis) - for model in models: - del model[name] - return state_dict - - -def load_and_export(model_path, output_path): - params_path = os.path.join(model_path, 'params.json') - with open(params_path) as f: - params = json.load(f) - print(params) - - model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth'))) - models = [torch.load(p, map_location='cpu') for p in model_paths] - state_dict = concat_weights(models) - del models - export(params, state_dict, output_path) - - -if __name__ == '__main__': - if len(sys.argv) == 1: - print('[Llama model folder path] [output path]') - exit() - - model_path = sys.argv[1] - output_path = sys.argv[2] - load_and_export(model_path, output_path) From dd61b13e578ef237c775ad05d8280d2a836c774b Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 21 Aug 2023 05:09:06 +0000 Subject: [PATCH 21/21] delete the save_torchscript export file, but copy its content to the new export.py for the future maybe --- export.py | 31 +++++++++++++++++++++ save_torchscript.py | 66 --------------------------------------------- 2 files changed, 31 insertions(+), 66 deletions(-) delete mode 100755 save_torchscript.py diff --git a/export.py b/export.py index d909c9f5..e486a815 100644 --- a/export.py +++ b/export.py @@ -14,6 +14,9 @@ This script aspires to provide all of these conversions. """ +import os +import gzip +import shutil import struct import argparse import numpy as np @@ -343,6 +346,34 @@ def model_export(model, filepath, version): else: raise ValueError(f"unknown version {version}") +def torchscript_export(model, filepath, zero_params=False, gzip_output=False): + """ + (This was submitted via a PR earlier. Leaving it here, but "orphaned" for now) + Saves the model as a TorchScript. + The resulting file can be loaded in C++ code and then used for training or + inference with: + #include + torch::jit::Module module = torch::jit::load("model.pt") + Note that the serialized model includes the initial parameters and with the default + ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute + the model parameters separately you can zero out the parameters before saving it and + it will gzip down to 780K. + """ + + # If requested zero params before saving the model. This is useful in + # conjunction with gzip_output. + if zero_params: + for p in model.parameters(): + p.detach().zero_() + + torch.jit.save(torch.jit.script(model), filepath) + + if gzip_output: + with open(filepath, "rb") as f_in: + with gzip.open(f"{filepath}.gz", "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + os.unlink(filepath) + # ----------------------------------------------------------------------------- # CLI entrypoint diff --git a/save_torchscript.py b/save_torchscript.py deleted file mode 100755 index af3a2995..00000000 --- a/save_torchscript.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python -"""Saves the model as a TorchScript. - -Usage examples: - ./save_torchscript.py - ./save_torchscript.py --dim=300 - ./save_torchscript.py --gzip_output=True --zero_params=True - -The resulting file can be loaded in C++ code and then used for training or -inference with: - #include - torch::jit::Module module = torch::jit::load("model.pt") - -Note that the serialized model includes the initial parameters and with the default -ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute -the model parameters separately you can zero out the parameters before saving it and -it will gzip down to 780K. -""" -import gzip -import os -import shutil -from inspect import signature - -import torch - -from model import ModelArgs, Transformer - -# Model args config -dim = 288 -n_layers = 6 -n_heads = 6 -n_kv_heads = n_heads -multiple_of = 32 -max_seq_len = 256 -dropout = 0.0 -vocab_size = 32000 -norm_eps = 1e-5 -# Save config -model_path = "model.pt" -zero_params = False -gzip_output = False -# Allow config overrides -exec(open("configurator.py").read()) - - -def main() -> None: - model_args = {k: globals()[k] for k in signature(ModelArgs).parameters} - model = Transformer(ModelArgs(**model_args)) - - # If requested zero params before saving the model. This is useful in - # conjunction with gzip_output. - if zero_params: - for p in model.parameters(): - p.detach().zero_() - - torch.jit.save(torch.jit.script(model), model_path) - - if gzip_output: - with open(model_path, "rb") as f_in: - with gzip.open(f"{model_path}.gz", "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - os.unlink(model_path) - - -if __name__ == "__main__": - main()