-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding backward kernel for repkv on llama3
branch (cudamode-irl)
#764
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
- [ ] WIP: CPU kernel - [ ] Cuda kernel
- [ ] WIP cuda version
insop
changed the title
DRAFT: Adding backward kernel for repkv on
DRAFT: Adding backward kernel for repkv on Sep 22, 2024
llama3
branch (cuda-mode-irl)llama3
branch (cudamode-irl)
- kernel 1 is tested - build ``` make repkv_backward /usr/local/cuda/bin/nvcc -O3 --use_fast_math --generate-code arch=compute_80,code=[compute_80,sm_80] -lcublas -lcublasLt -std=c++17 repkv_backward.cu -o repkv_backward ``` - test run on A30 ``` Using kernel 1 Checking block size 32. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 Checking block size 64. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 Checking block size 128. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 Checking block size 256. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 Checking block size 512. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 Checking block size 1024. 0.531524 0.531524 0.600285 0.600285 0.458787 0.458787 0.296680 0.296680 -0.911627 -0.911627 All results match. Starting benchmarks. block_size 32 time 3.2461 ms block_size 64 time 1.7509 ms block_size 128 time 1.7374 ms block_size 256 time 1.7441 ms block_size 512 time 1.8092 ms block_size 1024 time 2.0443 ms ```
insop
changed the title
DRAFT: Adding backward kernel for repkv on
Adding backward kernel for repkv on Sep 26, 2024
llama3
branch (cudamode-irl)llama3
branch (cudamode-irl)
@gordicaleksa , @ngc92, @ademeure, @karpathy PTAL, |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
PTAL,
repkv_backward
is updated and tested.I will update
repkv.cuh
once this PR is merged.CC: @karpathy
This is an WIP
repkv backward kernel
, started as a cudamode-irl project.Once the following work is done, will remove draft sign.
This work was supported by ALEKSA (@gordicaleksa) , Eric (@ngc92), ARUN (@ademeure) during the irl event.
pytorch backward test code
CPU kernel
Cuda kernel
build