
Adding backward kernel for repkv on llama3 branch (cudamode-irl) #764

Merged (14 commits) on Sep 27, 2024

Conversation


@insop insop commented Sep 22, 2024

PTAL: repkv_backward is updated and tested.
I will update repkv.cuh once this PR is merged.

CC: @karpathy

This is a WIP repkv backward kernel, started as a cudamode-irl project.
Once the following work is done, I will remove the draft status.

This work was supported by Aleksa (@gordicaleksa), Eric (@ngc92), and Arun (@ademeure) during the IRL event.
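For context, repkv replicates each key/value head across its query-head group (as in grouped-query attention for Llama 3), so the backward pass must reduce the incoming gradient back over the replicas. A minimal NumPy sketch of that relationship (hypothetical function names and shapes, not the actual llm.c code):

```python
import numpy as np

def repkv_forward(kv, n_rep):
    # kv: (B, T, n_kv_heads, head_dim)
    # Replicate each KV head n_rep times so K/V match the query head count:
    # output is (B, T, n_kv_heads * n_rep, head_dim).
    return np.repeat(kv, n_rep, axis=2)

def repkv_backward(d_out, n_rep):
    # d_out: (B, T, n_kv_heads * n_rep, head_dim)
    # The forward op is linear, so the backward pass sums the gradient
    # over each group of n_rep replicated heads.
    B, T, H, D = d_out.shape
    return d_out.reshape(B, T, H // n_rep, n_rep, D).sum(axis=3)
```

Because the forward op is linear, the backward pass can be checked with the adjoint identity `<d_out, forward(kv)> == <backward(d_out), kv>`.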

  • PyTorch backward test code

  • CPU kernel

  • CUDA kernel

  • build

```
make repkv_backward
/usr/local/cuda/bin/nvcc -O3 --use_fast_math --generate-code arch=compute_80,code=[compute_80,sm_80] -lcublas -lcublasLt -std=c++17 repkv_backward.cu -o repkv_backward
```
  • test run on A30
```
Using kernel 1
Checking block size 32.
0.531524 0.531524
0.600285 0.600285
0.458787 0.458787
0.296680 0.296680
-0.911627 -0.911627
Checking block size 64.
0.531524 0.531524
0.600285 0.600285
0.458787 0.458787
0.296680 0.296680
-0.911627 -0.911627
Checking block size 128.
0.531524 0.531524
0.600285 0.600285
0.458787 0.458787
0.296680 0.296680
-0.911627 -0.911627
Checking block size 256.
0.531524 0.531524
0.600285 0.600285
0.458787 0.458787
0.296680 0.296680
-0.911627 -0.911627
Checking block size 512.
0.531524 0.531524
0.600285 0.600285
0.458787 0.458787
0.296680 0.296680
-0.911627 -0.911627
Checking block size 1024.
0.531524 0.531524
0.600285 0.600285
0.458787 0.458787
0.296680 0.296680
-0.911627 -0.911627
All results match. Starting benchmarks.

block_size   32 time 3.2461 ms
block_size   64 time 1.7509 ms
block_size  128 time 1.7374 ms
block_size  256 time 1.7441 ms
block_size  512 time 1.8092 ms
block_size 1024 time 2.0443 ms
```
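The paired columns in the log come from a validate-then-benchmark pattern: print a handful of CPU-reference vs. GPU values side by side, then check every element against the reference within a tolerance. A hedged NumPy sketch of that check (names and tolerance are assumptions, not the llm.c test harness):

```python
import numpy as np

def validate(reference, result, atol=1e-5, n_print=5):
    # Print the first few reference/result pairs, as in the test log.
    for r, g in zip(reference.ravel()[:n_print], result.ravel()[:n_print]):
        print(f"{r:.6f} {g:.6f}")
    # Fail loudly if any element differs beyond the tolerance.
    if not np.allclose(reference, result, atol=atol):
        raise AssertionError("results mismatch")
```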

- [ ] WIP: CPU kernel
- [ ] CUDA kernel
@insop insop marked this pull request as draft September 22, 2024 03:46
@insop insop changed the title DRAFT: Adding backward kernel for repkv on llama3 branch (cuda-mode-irl) DRAFT: Adding backward kernel for repkv on llama3 branch (cudamode-irl) Sep 22, 2024
- kernel 1 is tested

@insop insop changed the title DRAFT: Adding backward kernel for repkv on llama3 branch (cudamode-irl) Adding backward kernel for repkv on llama3 branch (cudamode-irl) Sep 26, 2024
@insop (Author) commented Sep 26, 2024

@gordicaleksa , @ngc92, @ademeure, @karpathy

PTAL: the repkv_backward CPU and CUDA kernels are updated and tested.
I will update repkv.cuh once this PR is merged.

@insop insop marked this pull request as ready for review September 26, 2024 00:45
2 participants