Commit

Merge remote-tracking branch 'upstream/master'
trholding committed Aug 5, 2023
2 parents 44f2e96 + d1a59a9 commit f0e5f61
Showing 6 changed files with 114 additions and 45 deletions.
22 changes: 14 additions & 8 deletions model.py
@@ -49,7 +49,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
return freqs_cis.view(shape)

def apply_rotary_emb(
xq: torch.Tensor,
@@ -59,8 +59,8 @@ def apply_rotary_emb(
) -> Tuple[torch.Tensor, torch.Tensor]:

# reshape xq and xk to match the complex representation
xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

# reshape freqs_cos and freqs_sin for broadcasting
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
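
For context, not part of the diff: the two hunks above replace argument-unpacking calls (freqs_cis.view(*shape), xq.reshape(*xq.shape[:-1], -1, 2)) with forms that take a single list or tuple. Read together with the new save_torchscript.py below, the likely motivation is that torch.jit.script handles the sequence-argument spellings more readily; numerically the two forms are identical. A minimal standalone check, independent of model.py:

    import torch

    xq = torch.randn(2, 4, 8)
    old = xq.reshape(*xq.shape[:-1], -1, 2)    # pre-merge spelling: unpacked ints
    new = xq.reshape(xq.shape[:-1] + (-1, 2))  # post-merge spelling: one tuple
    assert torch.equal(old, new)

    freqs = torch.randn(4, 3)
    shape = [1, 4, 1, 3]
    assert torch.equal(freqs.view(*shape), freqs.view(shape))
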
@@ -142,10 +142,11 @@ def forward(

# flash implementation
if self.flash:
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
else:
# manual implementation
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
assert hasattr(self, 'mask')
scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores)
@@ -198,6 +199,8 @@ def forward(self, x, freqs_cos, freqs_sin):


class Transformer(nn.Module):
last_loss: Optional[torch.Tensor]

def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
@@ -227,6 +230,9 @@ def __init__(self, params: ModelArgs):
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))

# Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
self.last_loss = None

def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@@ -235,7 +241,7 @@ def _init_weights(self, module):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

def forward(self, tokens, targets=None):
def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
h = self.dropout(h)
@@ -249,13 +255,13 @@ def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.output(h)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
else:
# inference-time mini-optimization: only forward the output on the very last position
logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
loss = None
self.last_loss = None

return logits, loss
return logits

def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
# start with all of the candidate parameters
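
For context, not part of the diff: after this merge the model's forward returns only the logits tensor and records any training loss on model.last_loss, presumably so the scripted module exposes a single, statically typed return value. A minimal usage sketch mirroring the train.py hunks further down; the small ModelArgs values are borrowed from save_torchscript.py in this same commit, and X/Y are dummy batches:

    import torch
    from model import ModelArgs, Transformer

    model = Transformer(ModelArgs(dim=288, n_layers=6, n_heads=6, n_kv_heads=6,
                                  vocab_size=32000, max_seq_len=256))
    X = torch.randint(0, 32000, (2, 16))  # dummy input token ids
    Y = torch.randint(0, 32000, (2, 16))  # dummy target token ids
    logits = model(X, Y)                  # forward now returns logits only
    loss = model.last_loss                # the loss is read from the attribute
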
34 changes: 14 additions & 20 deletions run.c
@@ -157,7 +157,7 @@ void malloc_run_state(RunState* s, Config* p) {
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|| !s->value_cache) {
printf("malloc failed!\n");
exit(1);
exit(EXIT_FAILURE);
}
}

@@ -310,24 +310,18 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim);
matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim);

// apply RoPE rotation to the q and k vectors for each head
for (int h = 0; h < p->n_heads; h++) {
// get the q and k vectors for this head
float* q = s->q + h * head_size;
float* k = s->k + h * head_size;
// rotate q and k by the freq_cis_real and freq_cis_imag
for (int i = 0; i < head_size; i+=2) {
float q0 = q[i];
float q1 = q[i+1];
float k0 = k[i];
float k1 = k[i+1];
float fcr = freq_cis_real_row[i/2];
float fci = freq_cis_imag_row[i/2];
q[i] = q0 * fcr - q1 * fci;
q[i+1] = q0 * fci + q1 * fcr;
k[i] = k0 * fcr - k1 * fci;
k[i+1] = k0 * fci + k1 * fcr;
}
// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
for (int i = 0; i < dim; i+=2) {
float q0 = s->q[i];
float q1 = s->q[i+1];
float k0 = s->k[i];
float k1 = s->k[i+1];
float fcr = freq_cis_real_row[(i % head_size) / 2];
float fci = freq_cis_imag_row[(i % head_size) / 2];
s->q[i] = q0 * fcr - q1 * fci;
s->q[i+1] = q0 * fci + q1 * fcr;
s->k[i] = k0 * fcr - k1 * fci;
s->k[i+1] = k0 * fci + k1 * fcr;
}

// save key,value at this time step (pos) to our kv cache
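
For context, not part of the diff: the old and new loops compute the same RoPE rotation; the rewrite simply walks the full q and k vectors in one flat pass instead of per head, recovering the within-head index via (i % head_size) / 2. Each even/odd pair is treated as a complex number and multiplied by cos(theta) + i*sin(theta). A small check of that identity, independent of run.c:

    import math

    q0, q1 = 0.5, -0.25  # one even/odd pair from q
    theta = 0.1          # per-position, per-dimension angle
    fcr, fci = math.cos(theta), math.sin(theta)
    rotated = complex(q0, q1) * complex(fcr, fci)
    assert abs(rotated.real - (q0 * fcr - q1 * fci)) < 1e-12
    assert abs(rotated.imag - (q0 * fci + q1 * fcr)) < 1e-12
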
@@ -440,7 +434,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
for (char *c = text; *c != '\0'; c++) {
sprintf(str_buffer, "%c", *c);
int id = str_lookup(str_buffer, vocab, vocab_size);
if (id == -1) { printf("not good\n"); exit(1);}
if (id == -1) { printf("not good\n"); exit(EXIT_FAILURE); }
tokens[*n_tokens] = id;
(*n_tokens)++;
}
66 changes: 66 additions & 0 deletions save_torchscript.py
@@ -0,0 +1,66 @@
#!/usr/bin/env python
"""Saves the model as a TorchScript.
Usage examples:
./save_torchscript.py
./save_torchscript.py --dim=300
./save_torchscript.py --gzip_output=True --zero_params=True
The resulting file can be loaded in C++ code and then used for training or
inference with:
#include <torch/script.h>
torch::jit::Module module = torch::jit::load("model.pt")
Note that the serialized model includes the initial parameters and with the default
ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
the model parameters separately you can zero out the parameters before saving it and
it will gzip down to 780K.
"""
import gzip
import os
import shutil
from inspect import signature

import torch

from model import ModelArgs, Transformer

# Model args config
dim = 288
n_layers = 6
n_heads = 6
n_kv_heads = n_heads
multiple_of = 32
max_seq_len = 256
dropout = 0.0
vocab_size = 32000
norm_eps = 1e-5
# Save config
model_path = "model.pt"
zero_params = False
gzip_output = False
# Allow config overrides
exec(open("configurator.py").read())


def main() -> None:
model_args = {k: globals()[k] for k in signature(ModelArgs).parameters}
model = Transformer(ModelArgs(**model_args))

# If requested zero params before saving the model. This is useful in
# conjunction with gzip_output.
if zero_params:
for p in model.parameters():
p.detach().zero_()

torch.jit.save(torch.jit.script(model), model_path)

if gzip_output:
with open(model_path, "rb") as f_in:
with gzip.open(f"{model_path}.gz", "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
os.unlink(model_path)


if __name__ == "__main__":
main()
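
For completeness, not part of the commit: the saved module can also be loaded back from Python. A minimal sketch, assuming save_torchscript.py has been run with its defaults so that model.pt exists in the working directory:

    import torch

    module = torch.jit.load("model.pt")
    tokens = torch.zeros((1, 8), dtype=torch.long)  # dummy prompt of 8 token ids
    logits = module(tokens)  # no targets: returns logits for the last position only
    print(logits.shape)      # torch.Size([1, 1, 32000])
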
6 changes: 4 additions & 2 deletions train.py
@@ -211,7 +211,8 @@ def estimate_loss():
for k in range(eval_iters):
X, Y = next(batch_iter)
with ctx:
logits, loss = model(X, Y)
logits = model(X, Y)
loss = model.last_loss
losses[k] = loss.item()
out[split] = losses.mean()
model.train()
@@ -294,7 +295,8 @@ def get_lr(it):
# looking at the source of that context manager, it just toggles this variable
model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
with ctx:
logits, loss = model(X, Y)
logits = model(X, Y)
loss = model.last_loss
loss = loss / gradient_accumulation_steps
# immediately async prefetch next batch while model is doing the forward pass on the GPU
X, Y = next(train_batch_iter)
28 changes: 14 additions & 14 deletions win.c
@@ -6,17 +6,17 @@
#define FILE_MAP_EXECUTE 0x0020
#endif /* FILE_MAP_EXECUTE */

static int __map_mman_error(const DWORD err, const int deferr)
static int __map_mman_error(const uint32_t err, const int deferr)
{
if (err == 0)
return 0;
//TODO: implement
return err;
}

static DWORD __map_mmap_prot_page(const int prot)
static uint32_t __map_mmap_prot_page(const int prot)
{
DWORD protect = 0;
uint32_t protect = 0;

if (prot == PROT_NONE)
return protect;
@@ -35,9 +35,9 @@ static DWORD __map_mmap_prot_page(const int prot)
return protect;
}

static DWORD __map_mmap_prot_file(const int prot)
static uint32_t __map_mmap_prot_file(const int prot)
{
DWORD desiredAccess = 0;
uint32_t desiredAccess = 0;

if (prot == PROT_NONE)
return desiredAccess;
@@ -62,15 +62,15 @@ void* mmap(void *addr, size_t len, int prot, int flags, int fildes, ssize_t off)
#pragma warning(disable: 4293)
#endif

const DWORD dwFileOffsetLow = (DWORD)(off & 0xFFFFFFFFL);
const DWORD dwFileOffsetHigh = (DWORD)((off >> 32) & 0xFFFFFFFFL);
const DWORD protect = __map_mmap_prot_page(prot);
const DWORD desiredAccess = __map_mmap_prot_file(prot);
const uint32_t dwFileOffsetLow = (uint32_t)(off & 0xFFFFFFFFL);
const uint32_t dwFileOffsetHigh = (uint32_t)((off >> 32) & 0xFFFFFFFFL);
const uint32_t protect = __map_mmap_prot_page(prot);
const uint32_t desiredAccess = __map_mmap_prot_file(prot);

const ssize_t maxSize = off + (ssize_t)len;

const DWORD dwMaxSizeLow = (DWORD)(maxSize & 0xFFFFFFFFL);
const DWORD dwMaxSizeHigh = (DWORD)((maxSize >> 32) & 0xFFFFFFFFL);
const uint32_t dwMaxSizeLow = (uint32_t)(maxSize & 0xFFFFFFFFL);
const uint32_t dwMaxSizeHigh = (uint32_t)((maxSize >> 32) & 0xFFFFFFFFL);

#ifdef _MSC_VER
#pragma warning(pop)
@@ -130,8 +130,8 @@ int munmap(void *addr, size_t len)

int mprotect(void *addr, size_t len, int prot)
{
DWORD newProtect = __map_mmap_prot_page(prot);
DWORD oldProtect = 0;
uint32_t newProtect = __map_mmap_prot_page(prot);
uint32_t oldProtect = 0;

if (VirtualProtect(addr, len, newProtect, &oldProtect))
return 0;
@@ -173,7 +173,7 @@ int munlock(const void *addr, size_t len)

// Portable clock_gettime function for Windows
int clock_gettime(int clk_id, struct timespec *tp) {
DWORD ticks = GetTickCount();
uint32_t ticks = GetTickCount();
tp->tv_sec = ticks / 1000;
tp->tv_nsec = (ticks % 1000) * 1000000;
return 0;
3 changes: 2 additions & 1 deletion win.h
@@ -4,8 +4,9 @@
#define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
#include <windows.h>
#include <time.h>
#include <stdint.h>

#define ssize_t __int64
#define ssize_t int64_t
#define ftell _ftelli64

// Below code is originally from mman-win32
