Merge pull request #13 from guangyusong/main
Added wkv_cuda.cu and wkv_op.cpp files for RWKV
simran-arora authored Jan 18, 2024
2 parents 23519aa + 59b1012 commit e324d66
Showing 3 changed files with 150 additions and 3 deletions.
7 changes: 4 additions & 3 deletions zoology/mixers/rwkv.py
@@ -38,9 +38,10 @@ def backward(ctx, grad_output):
 # # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

 from torch.utils.cpp_extension import load
-wkv_cuda = load(name="wkv", sources=
-    ["/var/cr05_data/sim_data/code/petting-zoo/src/models/mixers/cuda/wkv_op.cpp",
-     "/var/cr05_data/sim_data/code/petting-zoo/src/models/mixers/cuda/wkv_cuda.cu"],
+dir_path = os.path.dirname(os.path.realpath(__file__))
+wkv_cuda = load(name="wkv", sources=[
+    os.path.join(dir_path, "./rwkv/v4/wkv_op.cpp"),
+    os.path.join(dir_path, "./rwkv/v4/wkv_cuda.cu")],
 verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])

class WKV(torch.autograd.Function):
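For context, the collapsed WKV class typically wraps the compiled extension as in this minimal sketch, modeled on the public RWKV v4 implementation; the actual class body is not shown in this diff, so the details below are assumptions:

class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        assert T <= T_MAX  # the kernels' per-thread buffers are sized by -DTmax
        w = -torch.exp(w.float().contiguous())  # the w -> -exp(w) preprocessing
        u, k, v = u.float().contiguous(), k.float().contiguous(), v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device=k.device, dtype=torch.float32)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)  # fills y in place
        return y

    @staticmethod
    def backward(ctx, gy):
        w, u, k, v = ctx.saved_tensors
        B, T, C = k.shape
        # gw/gu are accumulated per (batch, channel) by the kernel, so they
        # must start at zero and be reduced over the batch dimension afterwards
        gw = torch.zeros((B, C), device=gy.device)
        gu = torch.zeros((B, C), device=gy.device)
        gk = torch.zeros((B, T, C), device=gy.device)
        gv = torch.zeros((B, T, C), device=gy.device)
        wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        return (None, None, None, gw.sum(dim=0), gu.sum(dim=0), gk, gv)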
125 changes: 125 additions & 0 deletions zoology/mixers/rwkv/v4/wkv_cuda.cu
@@ -0,0 +1,125 @@
#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    F p = 0, q = 0, o = MIN_VALUE;
    // p and q are running sums divided by exp(o) (to avoid overflows)
    for (int i = 0; i < T; i++) {
        const int ii = i * C;

        F no = max(o, u + k[ii]);
        F A = exp(o - no);
        F B = exp(u + k[ii] - no);
        y[ii] = (A * p + B * v[ii]) / (A * q + B);

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }
}
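/*
   What kernel_forward computes per channel (inferred from the code above,
   not stated in the commit). With w already preprocessed to -exp(w) < 0:

       y_t = ( sum_{i<t} exp((t-1-i)*w + k_i) * v_i + exp(u + k_t) * v_t )
             / ( sum_{i<t} exp((t-1-i)*w + k_i) + exp(u + k_t) )

   p and q hold the numerator and denominator rescaled by exp(-o), and the
   shift o = max(...) keeps every exp() argument <= 0: a streaming
   log-sum-exp trick that prevents overflow over long sequences.
*/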

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const gy = _gy + _offset;

    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    F y[Tmax], z[Tmax], zexp[Tmax];

    F gw = 0, gu = 0;
    F p = 0, q = 0;
    F dpdw = 0, dqdw = 0;
    F o = MIN_VALUE;
    // First pass (forward in time): recompute y[i] and accumulate gw, gu.
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        F no = max(o, k[ii] + u);
        F A = exp(o - no);
        F B = exp(k[ii] + u - no);

        F num = A * p + B * v[ii];
        F iden = 1 / (A * q + B);

        y[i] = num * iden;
        z[i] = iden;
        zexp[i] = k[ii] + u - no;

        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
        gu += gy[ii] * (v[ii] - y[i]) * B * iden;

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        dpdw = A * (p + dpdw);
        dqdw = A * (q + dqdw);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }

    F gp = 0, gq = 0;
    o = MIN_VALUE;
    // Second pass (backward in time): accumulate gk, gv via reverse running sums.
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        F A = gy[ii] * z[i] * exp(zexp[i]);
        F B = exp(k[ii] + o);
        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
        gv[ii] = A + B * gp;

        F no = max(w + o, zexp[i] - k[ii] - u);
        A = exp(w + o - no);
        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
        gp = A * gp + B;
        gq = A * gq - B * y[i];
        o = no;
    }

    // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] += gw * _w[_c];
    _gu[_offsetBC] += gu;
}
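/*
   Gradient structure (hedged summary derived from the code above, with D_t
   denoting the denominator of y_t):

       dL/du   = sum_t gy_t * (v_t - y_t) * exp(u + k_t) / D_t
       dL/dv_i = gy_i * exp(u + k_i) / D_i
                 + sum_{t>i} gy_t * exp((t-1-i)*w + k_i) / D_t
       dL/dk_i = gy_i * (v_i - y_i) * exp(u + k_i) / D_i
                 + sum_{t>i} gy_t * (v_i - y_t) * exp((t-1-i)*w + k_i) / D_t

   The t > i sums are what (gp, gq) carry in the reverse loop, using the same
   max-shift rescaling as the forward kernel. gw additionally picks up the
   chain-rule factor from w -> -exp(w), applied as the final gw * _w[_c].
*/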

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    // One thread per (batch, channel) pair; each thread scans all T steps.
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
21 changes: 21 additions & 0 deletions zoology/mixers/rwkv/v4/wkv_op.cpp
@@ -0,0 +1,21 @@
#include <torch/extension.h>

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
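As a quick sanity check, the compiled module can be exercised directly. This is a hedged sketch assuming a CUDA device, the build from rwkv.py above, and the shapes implied by the kernels (w, u per channel; k, v, y of shape B x T x C):

import torch

B, T, C = 2, 16, 64                               # B*C must be divisible by min(C, 32); T <= Tmax
w = -torch.exp(torch.randn(C, device="cuda"))     # decays must already be -exp(...), i.e. negative
u = torch.randn(C, device="cuda")
k = torch.randn(B, T, C, device="cuda")
v = torch.randn(B, T, C, device="cuda")
y = torch.empty(B, T, C, device="cuda")
wkv_cuda.forward(B, T, C, w, u, k, v, y)          # fills y in place
print(y.shape, torch.isfinite(y).all().item())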
