From 59b101274baf5ce1819465ad6443c7dbd2da36ca Mon Sep 17 00:00:00 2001 From: guangyusong <15316444+guangyusong@users.noreply.github.com> Date: Wed, 17 Jan 2024 23:32:48 -0500 Subject: [PATCH] Add files to make cuda work --- zoology/mixers/rwkv.py | 7 +- zoology/mixers/rwkv/v4/wkv_cuda.cu | 125 +++++++++++++++++++++++++++++ zoology/mixers/rwkv/v4/wkv_op.cpp | 21 +++++ 3 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 zoology/mixers/rwkv/v4/wkv_cuda.cu create mode 100644 zoology/mixers/rwkv/v4/wkv_op.cpp diff --git a/zoology/mixers/rwkv.py b/zoology/mixers/rwkv.py index ec2429e..b4d7c70 100644 --- a/zoology/mixers/rwkv.py +++ b/zoology/mixers/rwkv.py @@ -38,9 +38,10 @@ def backward(ctx, grad_output): # # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice from torch.utils.cpp_extension import load -wkv_cuda = load(name="wkv", sources= -["/var/cr05_data/sim_data/code/petting-zoo/src/models/mixers/cuda/wkv_op.cpp", - "/var/cr05_data/sim_data/code/petting-zoo/src/models/mixers/cuda/wkv_cuda.cu"], +dir_path = os.path.dirname(os.path.realpath(__file__)) +wkv_cuda = load(name="wkv", sources=[ + os.path.join(dir_path, "./rwkv/v4/wkv_op.cpp"), + os.path.join(dir_path, "./rwkv/v4/wkv_cuda.cu")], verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']) class WKV(torch.autograd.Function): diff --git a/zoology/mixers/rwkv/v4/wkv_cuda.cu b/zoology/mixers/rwkv/v4/wkv_cuda.cu new file mode 100644 index 0000000..a4522cb --- /dev/null +++ b/zoology/mixers/rwkv/v4/wkv_cuda.cu @@ -0,0 +1,125 @@ +#include <stdio.h> +#include <assert.h> + +#define MIN_VALUE (-1e38) + +template <typename F> +__global__ void kernel_forward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, + F *__restrict__ const _y) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + 
const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + + F p = 0, q = 0, o = MIN_VALUE; + // p and q are running sums divided by exp(o) (to avoid overflows) + for (int i = 0; i < T; i++) { + const int ii = i * C; + + F no = max(o, u + k[ii]); + F A = exp(o - no); + F B = exp(u + k[ii] - no); + y[ii] = (A * p + B * v[ii]) / (A * q + B); + + no = max(w + o, k[ii]); + A = exp(w + o - no); + B = exp(k[ii] - no); + p = A * p + B * v[ii]; + q = A * q + B; + o = no; + } +} + +template <typename F> +__global__ void kernel_backward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, + F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + const F *__restrict__ const gy = _gy + _offset; + + F *__restrict__ const gk = _gk + _offset; + F *__restrict__ const gv = _gv + _offset; + + F y[Tmax], z[Tmax], zexp[Tmax]; + + F gw = 0, gu = 0; + F p = 0, q = 0; + F dpdw = 0, dqdw = 0; + F o = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + F no = max(o, k[ii] + u); + F A = exp(o - no); + F B = exp(k[ii] + u - no); + + F num = A * p + B * v[ii]; + F iden = 1 / (A * q + B); + + y[i] = num * iden; + z[i] = iden; + zexp[i] = k[ii] + u - no; + + gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A; + gu += gy[ii] * (v[ii] - y[i]) * B * iden; + + no = max(w + o, k[ii]); + A = exp(w + o - no); + B = exp(k[ii] - 
no); + dpdw = A * (p + dpdw); + dqdw = A * (q + dqdw); + p = A * p + B * v[ii]; + q = A * q + B; + o = no; + } + + F gp = 0, gq = 0; + o = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + F A = gy[ii] * z[i] * exp(zexp[i]); + F B = exp(k[ii] + o); + gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq); + gv[ii] = A + B * gp; + + F no = max(w + o, zexp[i] - k[ii] - u); + A = exp(w + o - no); + B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no); + gp = A * gp + B; + gq = A * gq - B * y[i]; + o = no; + } + + // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] += gw * _w[_c]; + _gu[_offsetBC] += gu; +} + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y); +} + +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); +} \ No newline at end of file diff --git a/zoology/mixers/rwkv/v4/wkv_op.cpp b/zoology/mixers/rwkv/v4/wkv_op.cpp new file mode 100644 index 0000000..e59a515 --- /dev/null +++ b/zoology/mixers/rwkv/v4/wkv_op.cpp @@ -0,0 +1,21 @@ +#include <torch/extension.h> + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); + +void forward(int64_t B, int64_t T, int64_t C, torch::Tensor 
&w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); +} +void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("backward", &backward, "wkv backward"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("backward", backward); +} \ No newline at end of file