Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
trholding committed Aug 1, 2023
2 parents b77335d + a8f3e1c commit d0ecdab
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 44 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,31 @@ jobs:
id: build_msvc
run: |
.\build_msvc.bat
windows-latest-mingw:
runs-on: windows-latest

defaults:
run:
shell: msys2 {0}

strategy:
matrix:
include:
- { sys: mingw64, env: x86_64 }

steps:
- name: Checkout
id: checkout
uses: actions/checkout@v3

- uses: msys2/setup-msys2@v2
id: setup-msys2
with:
msystem: ${{ matrix.sys }}
install: mingw-w64-${{matrix.env}}-gcc make

- name: Build ${{ matrix.sys }} ${{ matrix.env }}
id: build_mingw
run: |
make win64
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ runomp: run.c

.PHONY: win64
win64:
x86_64-w64-mingw32-gcc-win32 -Ofast -D_WIN32 -o run.exe -I. run.c win.c
x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c

# compiles with gnu99 standard flags for compatibility with Amazon Linux, CoreOS, etc.
.PHONY: rungnu
Expand Down
2 changes: 1 addition & 1 deletion build_msvc.bat
Original file line number Diff line number Diff line change
@@ -1 +1 @@
cl.exe /Ox /openmp /I. run.c win.c
cl.exe /fp:fast /Ox /openmp /I. run.c win.c
12 changes: 6 additions & 6 deletions export_meta_llama_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def serialize(key):

# final rmsnorm
serialize('norm.weight')
# freqs_cis
freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']]
state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']]
serialize('freqs_cis.real')
serialize('freqs_cis.imag')
# freqs_cos, freqs_sin
freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']]
state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']]
serialize('freqs_cos')
serialize('freqs_sin')

# finally write the output weights
serialize('output.weight')
Expand Down
57 changes: 37 additions & 20 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis

freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return freqs_cos, freqs_sin

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
Expand All @@ -51,17 +51,31 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)

# reshape xq and xk to match the complex representation
xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)

# reshape freqs_cos and freqs_sin for broadcasting
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

# apply rotation using real numbers
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

# flatten last two dimensions
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

return xq_out.type_as(xq), xk_out.type_as(xk)

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
Expand Down Expand Up @@ -103,7 +117,8 @@ def __init__(self, args: ModelArgs):
def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
bsz, seqlen, _ = x.shape

Expand All @@ -114,7 +129,7 @@ def forward(
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

# RoPE relative positional embeddings
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

# grouped multiquery attention: expand out keys and values
xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
Expand Down Expand Up @@ -176,8 +191,8 @@ def __init__(self, layer_id: int, args: ModelArgs):
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

def forward(self, x, freqs_cis):
h = x + self.attention.forward(self.attention_norm(x), freqs_cis)
def forward(self, x, freqs_cos, freqs_sin):
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out

Expand All @@ -201,8 +216,9 @@ def __init__(self, params: ModelArgs):
self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying

# precompute the frequency tables for the RoPE relative positional embeddings. TODO: why `* 2` here? (unclear)
freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)

# init all weights
self.apply(self._init_weights)
Expand All @@ -223,10 +239,11 @@ def forward(self, tokens, targets=None):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
h = self.dropout(h)
freqs_cis = self.freqs_cis[:seqlen]
freqs_cos = self.freqs_cos[:seqlen]
freqs_sin = self.freqs_sin[:seqlen]

for layer in self.layers:
h = layer(h, freqs_cis)
h = layer(h, freqs_cos, freqs_sin)
h = self.norm(h)

if targets is not None:
Expand Down Expand Up @@ -359,8 +376,8 @@ def serialize(t):
serialize(self.norm.weight)
# note: no need to write final classifier weights due to weight sharing
# freqs_cis, written as two tables: real (cos) part, then imaginary (sin) part
serialize(self.freqs_cis.real[:p.max_seq_len])
serialize(self.freqs_cis.imag[:p.max_seq_len])
serialize(self.freqs_cos[:p.max_seq_len])
serialize(self.freqs_sin[:p.max_seq_len])

# write to binary file
f.close()
Expand Down
2 changes: 1 addition & 1 deletion run.c
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ int main(int argc, char *argv[]) {
TransformerWeights weights;
int fd = 0; // file descriptor for memory mapping
float* data = NULL; // memory mapped data pointer
long file_size; // size of the checkpoint file in bytes
ssize_t file_size; // size of the checkpoint file in bytes
{
FILE *file = fopen(checkpoint, "rb");
if (!file) { printf("Couldn't open file %s\n", checkpoint); return 1; }
Expand Down
22 changes: 8 additions & 14 deletions win.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#include <errno.h>
#include <io.h>


#ifndef FILE_MAP_EXECUTE
#define FILE_MAP_EXECUTE 0x0020
#endif /* FILE_MAP_EXECUTE */
Expand Down Expand Up @@ -53,30 +52,25 @@ static DWORD __map_mmap_prot_file(const int prot)
return desiredAccess;
}

void* mmap(void *addr, size_t len, int prot, int flags, int fildes, off_t off)
void* mmap(void *addr, size_t len, int prot, int flags, int fildes, ssize_t off)
{
HANDLE fm, h;

void * map = MAP_FAILED;

#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable: 4293)
#endif

const DWORD dwFileOffsetLow = (sizeof(off_t) <= sizeof(DWORD)) ?
(DWORD)off : (DWORD)(off & 0xFFFFFFFFL);
const DWORD dwFileOffsetHigh = (sizeof(off_t) <= sizeof(DWORD)) ?
(DWORD)0 : (DWORD)((off >> 32) & 0xFFFFFFFFL);
const DWORD dwFileOffsetLow = (DWORD)(off & 0xFFFFFFFFL);
const DWORD dwFileOffsetHigh = (DWORD)((off >> 32) & 0xFFFFFFFFL);
const DWORD protect = __map_mmap_prot_page(prot);
const DWORD desiredAccess = __map_mmap_prot_file(prot);

const off_t maxSize = off + (off_t)len;
const ssize_t maxSize = off + (ssize_t)len;

const DWORD dwMaxSizeLow = (sizeof(off_t) <= sizeof(DWORD)) ?
(DWORD)maxSize : (DWORD)(maxSize & 0xFFFFFFFFL);
const DWORD dwMaxSizeHigh = (sizeof(off_t) <= sizeof(DWORD)) ?
(DWORD)0 : (DWORD)((maxSize >> 32) & 0xFFFFFFFFL);
const DWORD dwMaxSizeLow = (DWORD)(maxSize & 0xFFFFFFFFL);
const DWORD dwMaxSizeHigh = (DWORD)((maxSize >> 32) & 0xFFFFFFFFL);

#ifdef _MSC_VER
#pragma warning(pop)
Expand Down Expand Up @@ -110,11 +104,11 @@ void* mmap(void *addr, size_t len, int prot, int flags, int fildes, off_t off)
errno = __map_mman_error(GetLastError(), EPERM);
return MAP_FAILED;
}

map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, len);

CloseHandle(fm);

if (map == NULL)
{
errno = __map_mman_error(GetLastError(), EPERM);
Expand Down
4 changes: 3 additions & 1 deletion win.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <windows.h>
#include <time.h>

#define ssize_t __int64
#define ftell _ftelli64

// Below code is originally from mman-win32
//
Expand Down Expand Up @@ -51,7 +53,7 @@ extern "C" {
/* Flags for portable clock_gettime call. */
#define CLOCK_REALTIME 0

void* mmap(void *addr, size_t len, int prot, int flags, int fildes, off_t off);
void* mmap(void *addr, size_t len, int prot, int flags, int fildes, ssize_t off);
int munmap(void *addr, size_t len);
int mprotect(void *addr, size_t len, int prot);
int msync(void *addr, size_t len, int flags);
Expand Down

0 comments on commit d0ecdab

Please sign in to comment.