// wind_rwkv7.cpp
// PyTorch C++ extension API: torch::Tensor and the TORCH_LIBRARY registration macros used below.
#include <torch/extension.h>
// CUDA bfloat16 definitions (provides __nv_bfloat16).
#include <cuda_bf16.h>
// Short alias for the CUDA bfloat16 type. Tensor storage is reinterpreted as
// bf* in this file when raw pointers are handed to the kernel launchers.
using bf = __nv_bfloat16;
// Forward-pass launcher. Only declared here; the definition is not part of
// this file. It receives the three leading extents of `w` and one raw bf16
// pointer per tensor -- no strides and no dtype information.
void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*s0, bf*y, bf*s, bf*sT);

// Torch-facing entry point for the forward pass.
//
// Reads B, T, H from the first three dimensions of `w` and passes the raw
// storage pointer of every tensor to cuda_forward().
//
// Parameters:
//   w, q, k, v, z, a, s0 : tensors declared non-mutable in the op schema.
//   y, s, sT             : tensors declared mutable (written in place) in the
//                          op schema.
//
// Raises (through TORCH_CHECK) when any tensor is not a contiguous bfloat16
// CUDA tensor, or when `w` has fewer than three dimensions. Shapes of the
// other tensors are not validated here because this file does not establish
// what the launcher expects for them.
void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &s0, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sT) {
    // The launcher reinterprets each buffer as bf16 and is given no strides,
    // so a wrong device, dtype or memory layout would be read as garbage
    // (or fault) instead of producing an error. Reject such inputs up front.
    const auto require_bf16_cuda = [](const torch::Tensor &t, const char *name) {
        TORCH_CHECK(t.is_cuda(), "wind::forward: ", name, " must be a CUDA tensor");
        TORCH_CHECK(t.scalar_type() == at::ScalarType::BFloat16,
                    "wind::forward: ", name, " must have dtype bfloat16, got ", t.scalar_type());
        TORCH_CHECK(t.is_contiguous(), "wind::forward: ", name, " must be contiguous");
    };
    require_bf16_cuda(w, "w");
    require_bf16_cuda(q, "q");
    require_bf16_cuda(k, "k");
    require_bf16_cuda(v, "v");
    require_bf16_cuda(z, "z");
    require_bf16_cuda(a, "a");
    require_bf16_cuda(s0, "s0");
    require_bf16_cuda(y, "y");
    require_bf16_cuda(s, "s");
    require_bf16_cuda(sT, "sT");

    // sizes()[i] is not bounds-checked in release builds; validate the rank
    // and use the checked size() accessor instead.
    TORCH_CHECK(w.dim() >= 3, "wind::forward: w must have at least 3 dimensions, got ", w.dim());
    const int B = static_cast<int>(w.size(0));
    const int T = static_cast<int>(w.size(1));
    const int H = static_cast<int>(w.size(2));

    cuda_forward(B, T, H,
                 static_cast<bf*>(w.data_ptr()), static_cast<bf*>(q.data_ptr()),
                 static_cast<bf*>(k.data_ptr()), static_cast<bf*>(v.data_ptr()),
                 static_cast<bf*>(z.data_ptr()), static_cast<bf*>(a.data_ptr()),
                 static_cast<bf*>(s0.data_ptr()), static_cast<bf*>(y.data_ptr()),
                 static_cast<bf*>(s.data_ptr()), static_cast<bf*>(sT.data_ptr()));
}
// Backward-pass launcher. Only declared here; the definition is not part of
// this file. It receives the three leading extents of `w` and one raw bf16
// pointer per tensor -- no strides and no dtype information.
void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, bf*s, bf*dsT, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da, bf*ds0);

// Torch-facing entry point for the backward pass.
//
// Reads B, T, H from the first three dimensions of `w` and passes the raw
// storage pointer of every tensor to cuda_backward().
//
// Parameters:
//   w, q, k, v, z, a, dy, s, dsT   : tensors declared non-mutable in the op
//                                    schema.
//   dw, dq, dk, dv, dz, da, ds0    : tensors declared mutable (written in
//                                    place) in the op schema.
//
// Raises (through TORCH_CHECK) when any tensor is not a contiguous bfloat16
// CUDA tensor, or when `w` has fewer than three dimensions. Shapes of the
// other tensors are not validated here because this file does not establish
// what the launcher expects for them.
void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
              torch::Tensor &s, torch::Tensor &dsT, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da, torch::Tensor &ds0) {
    // The launcher reinterprets each buffer as bf16 and is given no strides,
    // so a wrong device, dtype or memory layout would be read as garbage
    // (or fault) instead of producing an error. Reject such inputs up front.
    const auto require_bf16_cuda = [](const torch::Tensor &t, const char *name) {
        TORCH_CHECK(t.is_cuda(), "wind::backward: ", name, " must be a CUDA tensor");
        TORCH_CHECK(t.scalar_type() == at::ScalarType::BFloat16,
                    "wind::backward: ", name, " must have dtype bfloat16, got ", t.scalar_type());
        TORCH_CHECK(t.is_contiguous(), "wind::backward: ", name, " must be contiguous");
    };
    require_bf16_cuda(w, "w");
    require_bf16_cuda(q, "q");
    require_bf16_cuda(k, "k");
    require_bf16_cuda(v, "v");
    require_bf16_cuda(z, "z");
    require_bf16_cuda(a, "a");
    require_bf16_cuda(dy, "dy");
    require_bf16_cuda(s, "s");
    require_bf16_cuda(dsT, "dsT");
    require_bf16_cuda(dw, "dw");
    require_bf16_cuda(dq, "dq");
    require_bf16_cuda(dk, "dk");
    require_bf16_cuda(dv, "dv");
    require_bf16_cuda(dz, "dz");
    require_bf16_cuda(da, "da");
    require_bf16_cuda(ds0, "ds0");

    // sizes()[i] is not bounds-checked in release builds; validate the rank
    // and use the checked size() accessor instead.
    TORCH_CHECK(w.dim() >= 3, "wind::backward: w must have at least 3 dimensions, got ", w.dim());
    const int B = static_cast<int>(w.size(0));
    const int T = static_cast<int>(w.size(1));
    const int H = static_cast<int>(w.size(2));

    cuda_backward(B, T, H,
                  static_cast<bf*>(w.data_ptr()), static_cast<bf*>(q.data_ptr()),
                  static_cast<bf*>(k.data_ptr()), static_cast<bf*>(v.data_ptr()),
                  static_cast<bf*>(z.data_ptr()), static_cast<bf*>(a.data_ptr()),
                  static_cast<bf*>(dy.data_ptr()), static_cast<bf*>(s.data_ptr()),
                  static_cast<bf*>(dsT.data_ptr()), static_cast<bf*>(dw.data_ptr()),
                  static_cast<bf*>(dq.data_ptr()), static_cast<bf*>(dk.data_ptr()),
                  static_cast<bf*>(dv.data_ptr()), static_cast<bf*>(dz.data_ptr()),
                  static_cast<bf*>(da.data_ptr()), static_cast<bf*>(ds0.data_ptr()));
}
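/* Disabled registration, kept for reference. This form binds the functions
   directly and lets PyTorch infer the operator schemas from the C++
   signatures. NOTE(review): it appears to have been replaced by the
   explicit-schema TORCH_LIBRARY block that follows -- confirm before removing.
TORCH_LIBRARY(wind, m) {
m.def("forward", forward);
m.def("backward", backward);
}*/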
// Declares the operator schemas of the `wind` namespace. Only the signatures
// are registered here; implementations are attached per dispatch key in a
// separate TORCH_LIBRARY_IMPL block.
TORCH_LIBRARY(wind, m) {
    // Arguments annotated with an alias set such as (a!) are mutated in
    // place; both operators return nothing.
    const char *const forward_schema =
        "forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor s0, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sT) -> ()";
    const char *const backward_schema =
        "backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor dsT, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da, Tensor(g!) ds0) -> ()";

    m.def(forward_schema);
    m.def(backward_schema);
}
// Attaches the C++ wrappers defined in this file as the CUDA-dispatch
// implementations of wind::forward and wind::backward.
TORCH_LIBRARY_IMPL(wind, CUDA, m) {
    // A function name decays to a function pointer, so no explicit
    // address-of is required.
    m.impl("forward", forward);
    m.impl("backward", backward);
}