Support for different data types (float16, float32) #93

Closed
19 changes: 17 additions & 2 deletions README.md
@@ -46,10 +46,10 @@ This still runs at interactive rates and samples more coherent and diverse stories
As the neural net architecture is identical, we can also inference the Llama 2 models released by Meta. Sadly there is a bit of friction here due to licensing (I can't directly upload the checkpoints, I think). So Step 1, get the Llama 2 checkpoints by following the [Meta instructions](https://github.com/facebookresearch/llama). Once we have those checkpoints, we have to convert them into the llama2.c format. For this we use the `export_meta_llama_bin.py` file, e.g. for the 7B model:

```bash
python export_meta_llama_bin.py path/to/llama/model/7B llama2_7b.bin
python export_meta_llama_bin.py path/to/llama/model/7B llama2_7b.bin float32
```

The export will take ~10 minutes or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. It has been [reported](https://github.com/karpathy/llama2.c/pull/85) that despite efforts, the 13B export currently doesn't work for unknown reaons (accepting PRs for fix). We can run the model as normal:
The export will take ~1 minute or so and generate a 26GB file (the weights of the 7B model in float32) called `llama2_7b.bin` in the current directory. It has been [reported](https://github.com/karpathy/llama2.c/pull/85) that despite efforts, the 13B export currently doesn't work for unknown reaons (accepting PRs for fix). We can run the model as normal:
Contributor
reaons -> reasons


```bash
./run llama2_7b.bin
@@ -120,6 +120,21 @@ $ pytest

Currently you will need two files to test or sample: the [model.bin](https://drive.google.com/file/d/1aTimLdx3JktDXxcHySNrZJOOk8Vb1qBR/view?usp=share_link) file and the [model.ckpt](https://drive.google.com/file/d/1SM0rMxzy7babB-v4MfTg1GFqOCgWar5w/view?usp=share_link) file from PyTorch training I ran earlier. I have to think through running the tests without having to download 200MB of data.

## data types

Models can be stored in different data types, for example float16 or float32.

```bash
gcc -O3 -o run run.c -lm -DDTYPE=float16 # float 16
gcc -O3 -o run run.c -lm -DDTYPE=float # float 32
```

To run the float16 version, the model has to be exported in float16:

```bash
python export_meta_llama_bin.py path/to/llama/model/7B llama2_7b.bin float16
```
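
Not part of this PR, just a quick way to check which data type a checkpoint was exported with: the header written by `export_meta_llama_bin.py` is eight int32 values, and the last one is the dtype flag (0 = float16, 1 = float32). A minimal sketch, with an example checkpoint path:

```c
/* Sketch: print the dtype flag of an exported checkpoint. Header order follows
   export_meta_llama_bin.py: dim, hidden_dim, n_layers, n_heads, n_kv_heads,
   -vocab_size, max_seq_len, dtype. */
#include <stdio.h>

int main(void) {
    FILE* f = fopen("llama2_7b.bin", "rb");   /* example path */
    if (!f) { printf("open failed\n"); return 1; }
    int header[8];
    if (fread(header, sizeof(int), 8, f) != 8) { printf("read failed\n"); fclose(f); return 1; }
    fclose(f);
    printf("dtype flag: %d (%s)\n", header[7], header[7] == 1 ? "float32" : "float16");
    return 0;
}
```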

## performance

*(NOTE: this guide is not great because I personally spend a lot of my time in Python land and don't have an amazing understanding of a lot of these features and flags. If someone does and is willing to help document and briefly describe some of these and their tradeoffs, I'd welcome a PR)*
18 changes: 10 additions & 8 deletions export_meta_llama_bin.py
@@ -11,26 +11,27 @@
from model import precompute_freqs_cis


def export(p, state_dict, filepath='model.bin'):
def export(p, state_dict, filepath, dtype):
"""export the model weights in fp32 into .bin file to be read from C"""
f = open(filepath, 'wb')

def serialize(key):
print(f"writing {key}...")
t = state_dict[key].contiguous().view(-1).type(torch.float32).numpy()
t = state_dict[key].contiguous().view(-1).type(dtype).numpy()
f.write(memoryview(t))
del state_dict[key]

# first write out the header
hidden_dim = state_dict['layers.0.feed_forward.w1.weight'].shape[0]
p['vocab_size'] = 32000
p['max_seq_len'] = 2048
p['dtype'] = dtype == torch.float32  # header flag: 1 for float32, 0 for float16

n_kv_heads = p.get('n_kv_heads') or p['n_heads']
header = struct.pack(
'iiiiiii',
'iiiiiiii',
p['dim'], hidden_dim, p['n_layers'], p['n_heads'],
n_kv_heads, -p['vocab_size'], p['max_seq_len']
n_kv_heads, -p['vocab_size'], p['max_seq_len'], p['dtype']
)
# NOTE ABOVE: -ve vocab_size is indicating that the classifier weights are present
# in the checkpoint and should be loaded.
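
For reference, a small round-trip of the header conventions used above (a sketch, assuming this mirrors how the C side interprets the fields; it is not code from this diff): a negative vocab_size signals that the classifier weights are stored in the checkpoint, and the trailing int is the dtype flag.

```c
/* Sketch: decode the header conventions written by export().
   header[5] is -vocab_size when classifier weights are present (not shared with
   the token embedding); header[7] is the dtype flag (0 = float16, 1 = float32). */
#include <stdio.h>
#include <stdlib.h>

int main(void) {
    /* example values in the order packed above (roughly the 7B config) */
    int header[8] = {4096, 11008, 32, 32, 32, -32000, 2048, 1};
    int shared_weights = header[5] > 0 ? 1 : 0;
    int vocab_size = abs(header[5]);
    printf("vocab=%d shared_weights=%d dtype=%s\n",
           vocab_size, shared_weights, header[7] == 1 ? "float32" : "float16");
    return 0;
}
```
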
@@ -88,7 +89,7 @@ def concat_weights(models):
return state_dict


def load_and_export(model_path, output_path):
def load_and_export(model_path, output_path, dtype):
with open(model_path + 'params.json') as f:
params = json.load(f)
print(params)
@@ -101,14 +102,15 @@ def load_and_export(model_path, output_path):

state_dict = concat_weights(models)
del models
export(params, state_dict, output_path)
export(params, state_dict, output_path, dtype)


if __name__ == '__main__':
if len(sys.argv) == 1:
print('[Llama model folder path] [output path]')
print('[Llama model folder path] [output path] [dtype (float16, float32)]')
exit()

model_path = sys.argv[1]
output_path = sys.argv[2]
load_and_export(model_path, output_path)
dtype = torch.float16 if len(sys.argv) > 3 and sys.argv[3] == "float16" else torch.float32  # default to float32 if no dtype argument is given
load_and_export(model_path, output_path, dtype)
100 changes: 63 additions & 37 deletions run.c
@@ -17,6 +17,10 @@ Then run with:
#include <fcntl.h>
#include <sys/mman.h>

#ifndef DTYPE
Contributor
It's better to typedef dtype rather than spreading the DTYPE macro all over.

Suggested change
#ifdef DTYPE
typedef DTYPE dtype;
#else
typedef float dtype;
#endif

#define DTYPE float
#endif

// ----------------------------------------------------------------------------
// Transformer and RunState structs, and related memory management

@@ -28,30 +28,31 @@ typedef struct {
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
int vocab_size; // vocabulary size, usually 256 (byte-level)
int seq_len; // max sequence length
int dtype; // 0: float16, 1:float32
} Config;

typedef struct {
// token embedding table
float* token_embedding_table; // (vocab_size, dim)
DTYPE* token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms
float* rms_att_weight; // (layer, dim) rmsnorm weights
float* rms_ffn_weight; // (layer, dim)
DTYPE* rms_att_weight; // (layer, dim) rmsnorm weights
DTYPE* rms_ffn_weight; // (layer, dim)
// weights for matmuls
float* wq; // (layer, dim, dim)
float* wk; // (layer, dim, dim)
float* wv; // (layer, dim, dim)
float* wo; // (layer, dim, dim)
DTYPE* wq; // (layer, dim, dim)
DTYPE* wk; // (layer, dim, dim)
DTYPE* wv; // (layer, dim, dim)
DTYPE* wo; // (layer, dim, dim)
// weights for ffn
float* w1; // (layer, hidden_dim, dim)
float* w2; // (layer, dim, hidden_dim)
float* w3; // (layer, hidden_dim, dim)
DTYPE* w1; // (layer, hidden_dim, dim)
DTYPE* w2; // (layer, dim, hidden_dim)
DTYPE* w3; // (layer, hidden_dim, dim)
// final rmsnorm
float* rms_final_weight; // (dim,)
DTYPE* rms_final_weight; // (dim,)
// freq_cis for RoPE relatively positional embeddings
float* freq_cis_real; // (seq_len, dim/2)
float* freq_cis_imag; // (seq_len, dim/2)
DTYPE* freq_cis_real; // (seq_len, dim/2)
DTYPE* freq_cis_imag; // (seq_len, dim/2)
// (optional) classifier weights for the logits, on the last layer
float* wcls;
DTYPE* wcls;
} TransformerWeights;
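
The pattern in the struct above is that weight tensors move to the compile-time DTYPE while activations (the RunState below) stay plain float. A minimal sketch of that mixed-precision idea, not code from this PR; with the default DTYPE it is all float32, and with a 16-bit DTYPE each weight is widened to float at the point of use:

```c
/* Sketch only: weights stored in DTYPE, activations and accumulator kept in float. */
#include <stdio.h>

#ifndef DTYPE
#define DTYPE float   /* same fallback as the PR uses */
#endif

float dot(const DTYPE* w, const float* x, int n) {
    float acc = 0.0f;                  /* accumulate in float regardless of DTYPE */
    for (int i = 0; i < n; i++) {
        acc += (float)w[i] * x[i];     /* each weight is widened to float on use */
    }
    return acc;
}

int main(void) {
    DTYPE w[3] = {0.5, 1.5, -2.0};     /* weights in storage precision */
    float x[3] = {1.0f, 2.0f, 3.0f};   /* activations stay float */
    printf("%f\n", dot(w, x, 3));      /* 0.5*1 + 1.5*2 - 2.0*3 = -2.5 */
    return 0;
}
```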

typedef struct {
@@ -86,8 +91,8 @@ void malloc_run_state(RunState* s, Config* p) {
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
// ensure all mallocs went fine
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|| !s->value_cache) {
printf("malloc failed!\n");
exit(1);
Expand All @@ -112,8 +117,8 @@ void free_run_state(RunState* s) {
// ----------------------------------------------------------------------------
// initialization: read from checkpoint

void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) {
float* ptr = f;
void checkpoint_init_weights(TransformerWeights *w, Config* p, DTYPE* f, int shared_weights) {
DTYPE* ptr = f;
w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim;
w->rms_att_weight = ptr;
@@ -153,7 +158,7 @@ void accum(float *a, float *b, int size) {
}
}

void rmsnorm(float* o, float* x, float* weight, int size) {
void rmsnorm(float* o, float* x, DTYPE* weight, int size) {
// calculate sum of squares
float ss = 0.0f;
for (int j = 0; j < size; j++) {
@@ -188,7 +193,7 @@ void softmax(float* x, int size) {
}
}

void matmul(float* xout, float* x, float* w, int n, int d) {
void matmul(float* xout, float* x, DTYPE* w, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
#pragma omp parallel for
for (int i = 0; i < d; i++) {
@@ -201,24 +206,26 @@ void matmul(float* xout, float* x, DTYPE* w, int n, int d) {
}

void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {

// a few convenience variables
float *x = s->x;
int dim = p->dim;
int hidden_dim = p->hidden_dim;
int head_size = dim / p->n_heads;

// copy the token embedding into x
float* content_row = &(w->token_embedding_table[token * dim]);
memcpy(x, content_row, dim*sizeof(*x));
DTYPE* content_row = &(w->token_embedding_table[token * dim]);
for (int i = 0; i < dim; i++) {
x[i] = content_row[i];
}

// pluck out the "pos" row of freq_cis_real and freq_cis_imag
float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2;
float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2;
DTYPE* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2;
DTYPE* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2;

// forward all the layers
for(int l = 0; l < p->n_layers; l++) {

// attention rmsnorm
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

@@ -253,7 +260,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
float* value_cache_row = s->value_cache + loff + pos * dim;
memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row));
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));

// multihead attention. iterate over all heads
#pragma omp parallel for
for (int h = 0; h < p->n_heads; h++) {
@@ -277,7 +284,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*

// softmax the scores to get attention weights, from 0..pos inclusively
softmax(att, pos + 1);

// weighted sum of the values, store back into xb
for (int i = 0; i < head_size; i++) {
float val = 0.0f;
@@ -301,12 +308,12 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
// first calculate self.w1(x) and self.w3(x)
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);

// F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid
for (int i = 0; i < hidden_dim; i++) {
s->hb[i] = s->hb[i] * (1.0f / (1.0f + expf(-s->hb[i])));
}

// elementwise multiply with w3(x)
for (int i = 0; i < hidden_dim; i++) {
s->hb[i] = s->hb[i] * s->hb2[i];
@@ -318,7 +325,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
// residual connection
accum(x, s->xb, dim);
}

// final rmsnorm
rmsnorm(x, x, w->rms_final_weight, dim);

@@ -388,13 +395,13 @@ int main(int argc, char *argv[]) {
}

// seed rng with time. if you want deterministic behavior use temperature 0.0
srand((unsigned int)time(NULL));
srand((unsigned int)time(NULL));

// read in the model.bin file
Config config;
TransformerWeights weights;
int fd = 0;
float* data = NULL;
DTYPE* data = NULL;
long file_size;
{
FILE *file = fopen(checkpoint, "rb");
@@ -416,8 +423,27 @@ int main(int argc, char *argv[]) {
if (fd == -1) { printf("open failed!\n"); return 1; }
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; }
float* weights_ptr = data + sizeof(Config)/sizeof(float);
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
enum dtype {
float16 = 0,
float32 = 1
};
switch (config.dtype) {
default:
printf("dtype not supported!\n");
return 1;

case float16:
if (sizeof(DTYPE) != sizeof(_Float16)) { printf("dtype doesn't match!\n"); return 1; }
DTYPE* weights_ptr_float16 = data + sizeof(Config)/sizeof(DTYPE);
checkpoint_init_weights(&weights, &config, weights_ptr_float16, shared_weights);
break;

case float32:
if (sizeof(DTYPE) != sizeof(float)) { printf("dtype doesn't match!\n"); return 1; }
DTYPE* weights_ptr_float32 = data + sizeof(Config)/sizeof(DTYPE);
checkpoint_init_weights(&weights, &config, weights_ptr_float32, shared_weights);
break;
}
}
// right now we cannot run for more than config.seq_len steps
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
@@ -444,7 +470,7 @@ int main(int argc, char *argv[]) {
// create and init the application RunState
RunState state;
malloc_run_state(&state, &config);

// the current position we are in
long start = time_in_ms();
int next;