reshuffle repkv a bit, i wrote it from scratch. the kernel is still correct. repkv backward looks correct. rope backward is trivial so i don't see how it's not correct, and i also checked it. basically i'm really confused right now
karpathy committed Sep 27, 2024
1 parent 8d49062 commit 7d945e9
Showing 4 changed files with 60 additions and 51 deletions.
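A note on the commit message's claim that the rope backward is trivial: RoPE rotates each (even, odd) pair of channels by a fixed angle, so its backward is just the transposed rotation, i.e. the same rotation with the sine negated. A minimal standalone sketch of that identity (illustration only, not the repo's rope_backward_inplace):

#include <stdio.h>

// forward: rotate one (even, odd) channel pair by the angle whose cosine/sine are (c, s)
void rope_pair_forward(float x0, float x1, float c, float s, float* y0, float* y1) {
    *y0 = x0 * c - x1 * s;
    *y1 = x0 * s + x1 * c;
}

// backward: apply the transposed rotation to the upstream gradient (negate the sine)
void rope_pair_backward(float dy0, float dy1, float c, float s, float* dx0, float* dx1) {
    *dx0 =  dy0 * c + dy1 * s;
    *dx1 = -dy0 * s + dy1 * c;
}

int main(void) {
    // adjoint check: <dy, J v> must equal <J^T dy, v> for any v and dy
    float c = 0.8f, s = 0.6f;                  // a valid rotation since c*c + s*s == 1
    float v0 = 0.3f, v1 = -1.2f, dy0 = 0.5f, dy1 = 2.0f;
    float Jv0, Jv1, JTdy0, JTdy1;
    rope_pair_forward(v0, v1, c, s, &Jv0, &Jv1);        // J v (the forward is linear in x)
    rope_pair_backward(dy0, dy1, c, s, &JTdy0, &JTdy1); // J^T dy
    printf("%f == %f\n", dy0 * Jv0 + dy1 * Jv1, JTdy0 * v0 + JTdy1 * v1);
    return 0;
}

Both sides of the printed equality should agree; that adjoint property is all the rope backward has to satisfy.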
52 changes: 29 additions & 23 deletions dev/cuda/repkv_backward.cu
@@ -11,39 +11,46 @@ Block size 128 seems fastest on H100
 
 // cpu reference code
 void repkv_backward_cpu(float* dinp, const float* dout,
-                        const int B, const int T, const int Cout,
-                        const int hd, const int qh, const int kh, const int vh) {
-
-    assert(Cout == (hd * (3 * qh)));
+                        int B, int T, int C,
+                        int hd, int qh, int kh, int vh) {
+    // inp is (B, T, C)
+    // out is (B, T, 3, NH, HD)
+    // hd = head dimension
+    // qh, kh, vh = number of query, key, value heads
+    assert(C == hd * (qh + kh + vh));
     assert(kh == vh);
     int nrep = qh / kh; // number of times to replicate key/value vectors
-    int Cin = hd * (qh + kh + vh); // output channels
+    int Cout = hd * (qh * 3); // output channels
 
     for (int b = 0; b < B; b++) {
         for (int t = 0; t < T; t++) {
-            // seek to the input position dout[b,t,:]
-            const float* x = dout + b * T * Cout + t * Cout;
+            // seek to the input position inp[b,t,:]
+            float* dx = dinp + b * T * C + t * C;
             // seek to the output position out[b,t,:]
-            float* y = dinp + b * T * Cin + t * Cin;
+            const float* dy = dout + b * T * Cout + t * Cout;
             // copy all the query vectors, no changes
-            for (int i = 0; i < hd * qh; i++) { y[i] = x[i]; }
-            x += hd * qh; // advance input pointer
-            y += hd * qh; // advance output pointer
-            // copy key vectors, and replicate them nrep times
+            for (int i = 0; i < hd * qh; i++) { dx[i] = dy[i]; }
+            dx += hd * qh; // advance input pointer
+            dy += hd * qh; // advance output pointer
+            // gather gradients from the key vectors
             for (int h = 0; h < kh; h++) {
+                // init the gradient to 0
+                for (int i = 0; i < hd; i++) { dx[i] = 0.0f; }
                 for (int n = 0; n < nrep; n++) {
-                    for (int i = 0; i < hd; i++) { y[i] += x[i]; }
-                    x += hd; // advance input pointer
+                    for (int i = 0; i < hd; i++) { dx[i] += dy[i]; }
+                    dy += hd; // advance output pointer
                 }
-                y += hd; // advance output pointer
+                dx += hd; // advance input pointer
             }
-            // copy value vectors, and replicate them nrep times
+            // gather gradients from the value vectors
             for (int h = 0; h < vh; h++) {
+                // init the gradient to 0
+                for (int i = 0; i < hd; i++) { dx[i] = 0.0f; }
                 for (int n = 0; n < nrep; n++) {
-                    for (int i = 0; i < hd; i++) { y[i] += x[i]; }
-                    x += hd; // advance input pointer
+                    for (int i = 0; i < hd; i++) { dx[i] += dy[i]; }
+                    dy += hd; // advance output pointer
                 }
-                y += hd; // advance output pointer
+                dx += hd; // advance input pointer
             }
         }
     }
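For reference when reading the gather above: the forward pass that this backward inverts copies the query block and writes each key/value head nrep times, which is why the backward zero-initializes and then sums nrep slices per K/V head. A minimal CPU sketch of that forward under the same layout assumptions (illustration only, not taken from the repo):

#include <assert.h>

// sketch: inp is (B, T, hd*(qh + kh + vh)), out is (B, T, hd*(3*qh));
// queries are copied as-is, each key/value head is replicated nrep = qh/kh times
void repkv_forward_cpu_sketch(float* out, const float* inp,
                              int B, int T, int hd, int qh, int kh, int vh) {
    assert(kh == vh);
    int nrep = qh / kh;
    int Cin = hd * (qh + kh + vh);
    int Cout = hd * (qh * 3);
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            const float* x = inp + (b * T + t) * Cin;
            float* y = out + (b * T + t) * Cout;
            // query block: straight copy
            for (int i = 0; i < hd * qh; i++) { y[i] = x[i]; }
            x += hd * qh; y += hd * qh;
            // key block, then value block: each head written nrep times
            for (int kv = 0; kv < 2; kv++) {
                for (int h = 0; h < kh; h++) {
                    for (int n = 0; n < nrep; n++) {
                        for (int i = 0; i < hd; i++) { y[i] = x[i]; }
                        y += hd;
                    }
                    x += hd;
                }
            }
        }
    }
}

Replicating in the forward means summing over the replicas in the backward, which is exactly what the loops above do.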
@@ -76,7 +83,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout,
         dinp[dinp_idx] = __ldcs(&dout[dout_idx]);
     } else if (c == 1) {
         if (nh % replicate_factor == 0) {
-            float reduced_sum = 0;
+            float reduced_sum = 0.0f;
             for (int i = 0; i < replicate_factor; i++) {
                 reduced_sum += __ldcs(&dout[dout_idx+HD*i]);
             }
@@ -87,7 +94,7 @@
 
     } else {
         if (nh % replicate_factor == 0) {
-            float reduced_sum = 0;
+            float reduced_sum = 0.0f;
             for (int i = 0; i < replicate_factor; i++) {
                 reduced_sum += __ldcs(&dout[dout_idx+HD*i]);
             }
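In the kernel fragment above, only elements whose head index satisfies nh % replicate_factor == 0 perform the reduction, summing replicate_factor head-sized slices of dout into a single dinp element. In closed-index form, the same gather for one key element looks roughly like this (a host-side sketch derived from the CPU reference's layout; the kernel's actual dinp_idx/dout_idx arithmetic may differ):

// gradient of one key element inp[b, t, hd*qh + h*hd + i], where h indexes the kh key heads:
// sum the nrep replicas that the forward wrote into out (value elements work the same way,
// offset by a further hd*qh past the key block)
float repkv_backward_key_element(const float* dout, int T, int hd, int qh, int kh,
                                 int b, int t, int h, int i) {
    int nrep = qh / kh;
    int Cout = hd * (qh * 3);
    const float* dy = dout + (b * T + t) * Cout + hd * qh; // skip the query block
    float sum = 0.0f;
    for (int n = 0; n < nrep; n++) {
        sum += dy[(h * nrep + n) * hd + i];
    }
    return sum;
}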
@@ -141,7 +148,6 @@ int main(int argc, char **argv) {
 
     // allocate (and fill) CPU memory
     float* dinp = (float*)malloc(B * T * Cin * sizeof(float));
-    memset(dinp, 0, B * T * Cin * sizeof(float));
    float* dout = make_random_float(B * T * Cout * sizeof(float));
 
     // allocate GPU memory
@@ -160,7 +166,7 @@
     printf("Using kernel %d\n", kernel_num);
 
     // CPU reference calculate
-    repkv_backward_cpu(dinp, dout, B, T, Cout, hd, qh, kh, vh);
+    repkv_backward_cpu(dinp, dout, B, T, Cin, hd, qh, kh, vh);
 
     // check the correctness of the kernel at all block sizes
     int block_sizes[] = {32, 64, 128, 256, 512, 1024};
4 changes: 2 additions & 2 deletions llmc/repkv.cuh
@@ -74,7 +74,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout,
         dinp[dinp_idx] = __ldcs(&dout[dout_idx]);
     } else if (c == 1) {
         if (nh % replicate_factor == 0) {
-            float reduced_sum = 0;
+            float reduced_sum = 0.0f;
             for (int i = 0; i < replicate_factor; i++) {
                 reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]);
             }
@@ -85,7 +85,7 @@
 
     } else {
         if (nh % replicate_factor == 0) {
-            float reduced_sum = 0;
+            float reduced_sum = 0.0f;
             for (int i = 0; i < replicate_factor; i++) {
                 reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]);
             }
43 changes: 23 additions & 20 deletions train_llama3.cu
@@ -897,10 +897,31 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
         rmsnorm_backward(dresidual, dl_ln2w, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_rstd, B, T, C, main_stream);
         matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream);
 
+        // <--- gradient here matches OK
+
+        #ifdef ENABLE_CUDNN
+        printf("cuDNN path TODO\n"); exit(0);
+        float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor
+        attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream);
+        #else
+        floatX* l_att = acts.att + l * B * NH * T * T;
+        // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory
+        floatX* buffer_a = l_atty;
+        floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need
+        attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
+        #endif
+        // backward rope (this can be done in-place)
+        rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream);
+        // backward repkv (use scratchX as gradient buffer here)
+        repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd);
+
+        // <--- here the gradients don't match
+        // so there is an issue with one of attention, rope, or repkv, or how they are called
+
         // ------------------------------------------------------------------------
         // DEBUGGING: we only work until this point right now, so exit here
         // transfer the first 32 elements to CPU and print them
-        float* output = (float*)dl_btc;
+        float* output = (float*)dl_bt4c2;
         floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX));
         cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost));
         for (int i = 0; i < 32; i++) {

Jake-Song commented on Oct 1, 2024, on the line "// <--- here the gradients don't match":
@karpathy I'm trying to reconstruct the gradients to see if I can fix it; in my case dl_bt4c2 is all zeroes. Are those the correct values? I'm not sure I'm reconstructing them correctly.
@@ -909,7 +930,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
         // write to .bin file
         // move output to cpu
         // int sz = B*T*qkv_channels; //B*T*C;
-        int sz = B*T*C;
+        int sz = B*T*qkv_channels;
         floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX));
         cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost));
         FILE* f = fopen("out.bin", "wb");
@@ -918,24 +939,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
         exit(0);
         // ------------------------------------------------------------------------
-
-        #ifdef ENABLE_CUDNN
-        printf("cuDNN path TODO\n"); exit(0);
-        float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor
-        attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream);
-        #else
-        floatX* l_att = acts.att + l * B * NH * T * T;
-        // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory
-        floatX* buffer_a = l_atty;
-        floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need
-        attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
-        #endif
-        // backward rope (this can be done in-place)
-        rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream);
-        // backward repkv (use scratchX as gradient buffer here)
-        repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd);
-
-        // <--- here the gradients don't match, so there is an issue in between
 
         // backward QKV projection
         if(model->recompute >= 2) {
             rmsnorm_forward(l_ln1, l_ln1_rstd, residual, l_ln1w, B, T, C, main_stream);
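One way to localize the mismatch flagged above (and to answer the question of what dl_bt4c2 should contain): after loss.backward(), the DEBUG_POINT leaf added in train_llama3.py should hold in its .grad exactly the gradient flowing into qkv, which is what dl_bt4c2 is meant to reproduce. Dumping that reference tensor to disk and diffing it against out.bin narrows the bug down to attention, rope, or repkv. A minimal checker, assuming both dumps are raw float32 in the same (B, T, qkv_channels) order; the file names, dtype, and the dump on the Python side are assumptions, not something this commit adds:

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

// compare two raw float32 dumps element by element and report the largest deviation
int main(int argc, char** argv) {
    if (argc != 3) { fprintf(stderr, "usage: %s out.bin ref.bin\n", argv[0]); return 1; }
    FILE* fa = fopen(argv[1], "rb");
    FILE* fb = fopen(argv[2], "rb");
    if (fa == NULL || fb == NULL) { fprintf(stderr, "failed to open inputs\n"); return 1; }
    float a, b, max_diff = 0.0f;
    long n = 0, worst = 0;
    while (fread(&a, sizeof(float), 1, fa) == 1 && fread(&b, sizeof(float), 1, fb) == 1) {
        float d = fabsf(a - b);
        if (d > max_diff) { max_diff = d; worst = n; }
        n++;
    }
    printf("compared %ld elements, max abs diff %e at index %ld\n", n, max_diff, worst);
    fclose(fa); fclose(fb);
    return 0;
}

Build it with any C compiler and pass the two dump files as arguments; a large max abs diff at an early index would point at the very first mismatched head slice.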
12 changes: 6 additions & 6 deletions train_llama3.py
@@ -168,6 +168,12 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
         qkv = self.c_attn(x)
+
+        DEBUG_POINT = qkv.detach()
+        DEBUG_POINT = DEBUG_POINT.requires_grad_(True)
+        self.DEBUG_POINT = DEBUG_POINT
+        qkv = DEBUG_POINT
+
         q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1)
         q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v)) # (B, T, NH, HD)
         q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) <-- 1. difference compared to GPT-2
@@ -197,12 +203,6 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
             att = F.softmax(scores.float(), dim=-1).type_as(q)
             y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD)
         y = y.transpose(1, 2).contiguous().view(B, T, C)
-
-        DEBUG_POINT = y.detach()
-        DEBUG_POINT = DEBUG_POINT.requires_grad_(True)
-        self.DEBUG_POINT = DEBUG_POINT
-        y = DEBUG_POINT
-
         y = self.c_proj(y)
         return y
 
