-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from guangyusong/main
Added wkv_cuda.cu and wkv_op.cpp files for RWKV
- Loading branch information
Showing
3 changed files
with
150 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
#include <stdio.h> | ||
#include <assert.h> | ||
|
||
#define MIN_VALUE (-1e38) | ||
|
||
template <typename F> | ||
__global__ void kernel_forward(const int B, const int T, const int C, | ||
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, | ||
F *__restrict__ const _y) { | ||
const int idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
const int _b = idx / C; | ||
const int _c = idx % C; | ||
const int _offset = _b * T * C + _c; | ||
|
||
F u = _u[_c]; | ||
F w = _w[_c]; | ||
const F *__restrict__ const k = _k + _offset; | ||
const F *__restrict__ const v = _v + _offset; | ||
F *__restrict__ const y = _y + _offset; | ||
|
||
F p = 0, q = 0, o = MIN_VALUE; | ||
// p and q are running sums divided by exp(o) (to avoid overflows) | ||
for (int i = 0; i < T; i++) { | ||
const int ii = i * C; | ||
|
||
F no = max(o, u + k[ii]); | ||
F A = exp(o - no); | ||
F B = exp(u + k[ii] - no); | ||
y[ii] = (A * p + B * v[ii]) / (A * q + B); | ||
|
||
no = max(w + o, k[ii]); | ||
A = exp(w + o - no); | ||
B = exp(k[ii] - no); | ||
p = A * p + B * v[ii]; | ||
q = A * q + B; | ||
o = no; | ||
} | ||
} | ||
|
||
template <typename F> | ||
__global__ void kernel_backward(const int B, const int T, const int C, | ||
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, | ||
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { | ||
const int idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
const int _b = idx / C; | ||
const int _c = idx % C; | ||
const int _offset = _b * T * C + _c; | ||
|
||
F u = _u[_c]; | ||
F w = _w[_c]; | ||
const F *__restrict__ const k = _k + _offset; | ||
const F *__restrict__ const v = _v + _offset; | ||
const F *__restrict__ const gy = _gy + _offset; | ||
|
||
F *__restrict__ const gk = _gk + _offset; | ||
F *__restrict__ const gv = _gv + _offset; | ||
|
||
F y[Tmax], z[Tmax], zexp[Tmax]; | ||
|
||
F gw = 0, gu = 0; | ||
F p = 0, q = 0; | ||
F dpdw = 0, dqdw = 0; | ||
F o = MIN_VALUE; | ||
for (int i = 0; i < T; i++) { | ||
const int ii = i * C; | ||
F no = max(o, k[ii] + u); | ||
F A = exp(o - no); | ||
F B = exp(k[ii] + u - no); | ||
|
||
F num = A * p + B * v[ii]; | ||
F iden = 1 / (A * q + B); | ||
|
||
y[i] = num * iden; | ||
z[i] = iden; | ||
zexp[i] = k[ii] + u - no; | ||
|
||
gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A; | ||
gu += gy[ii] * (v[ii] - y[i]) * B * iden; | ||
|
||
no = max(w + o, k[ii]); | ||
A = exp(w + o - no); | ||
B = exp(k[ii] - no); | ||
dpdw = A * (p + dpdw); | ||
dqdw = A * (q + dqdw); | ||
p = A * p + B * v[ii]; | ||
q = A * q + B; | ||
o = no; | ||
} | ||
|
||
F gp = 0, gq = 0; | ||
o = MIN_VALUE; | ||
for (int i = T - 1; i >= 0; i--) { | ||
const int ii = i * C; | ||
F A = gy[ii] * z[i] * exp(zexp[i]); | ||
F B = exp(k[ii] + o); | ||
gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq); | ||
gv[ii] = A + B * gp; | ||
|
||
F no = max(w + o, zexp[i] - k[ii] - u); | ||
A = exp(w + o - no); | ||
B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no); | ||
gp = A * gp + B; | ||
gq = A * gq - B * y[i]; | ||
o = no; | ||
} | ||
|
||
// Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass | ||
const int _offsetBC = _b * C + _c; | ||
_gw[_offsetBC] += gw * _w[_c]; | ||
_gu[_offsetBC] += gu; | ||
} | ||
|
||
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { | ||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance | ||
assert(B * C % threadsPerBlock.x == 0); | ||
dim3 numBlocks(B * C / threadsPerBlock.x); | ||
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y); | ||
} | ||
|
||
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) { | ||
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance | ||
assert(B * C % threadsPerBlock.x == 0); | ||
dim3 numBlocks(B * C / threadsPerBlock.x); | ||
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#include <torch/extension.h> | ||
|
||
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); | ||
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); | ||
|
||
void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { | ||
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); | ||
} | ||
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { | ||
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); | ||
} | ||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
m.def("forward", &forward, "wkv forward"); | ||
m.def("backward", &backward, "wkv backward"); | ||
} | ||
|
||
TORCH_LIBRARY(wkv, m) { | ||
m.def("forward", forward); | ||
m.def("backward", backward); | ||
} |