extra/gemm/max_matmul: start of custom kernels for GEMM #6926
Conversation
Added an unoptimized FP16 input / FP16 acc MMA example.
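For readers who want a concrete picture of the "unoptimized" starting point, here is a minimal sketch of a naive FP16-in / FP16-acc tensor-core matmul. It uses the WMMA C++ API with one warp per 16x16 output tile rather than the raw mma.sync PTX the actual kernels use; the kernel name, launch shape, and size assumptions are illustrative only, not the PR's code.

```cuda
// Hypothetical wmma_naive kernel: one warp computes one 16x16 tile of C = A @ B.
// A, B, C are row-major half; M, N, K are assumed to be multiples of 16.
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

__global__ void wmma_naive(const half* A, const half* B, half* C, int M, int N, int K) {
  const int tileM = blockIdx.y * 16;   // row of this warp's output tile
  const int tileN = blockIdx.x * 16;   // column of this warp's output tile
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b;
  wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc;
  wmma::fill_fragment(acc, __float2half(0.0f));
  for (int k = 0; k < K; k += 16) {
    wmma::load_matrix_sync(a, A + tileM * K + k, K);   // 16x16 tile of A
    wmma::load_matrix_sync(b, B + k * N + tileN, N);   // 16x16 tile of B
    wmma::mma_sync(acc, a, b, acc);                    // acc += a @ b on tensor cores
  }
  wmma::store_matrix_sync(C + tileM * N + tileN, acc, N, wmma::mem_row_major);
}

// launch: one warp (32 threads) per 16x16 output tile
// wmma_naive<<<dim3(N / 16, M / 16), 32>>>(A, B, C, M, N, K);
```

Every load goes straight from global memory and nothing is pipelined, which is why a kernel shaped like this sits far below the numbers discussed in this PR.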
So I tested this with Triton once; what does Triton get? IIRC it was well over 200. We should at least match that with handcoded stuff.
Added a 3-stage pipeline with swizzled SMEM inputs example for FP16 acc that does 270 TF (still less than the 330 TF in CUTLASS, but not as bad as before). Note this PR depends on #6956 being landed first. A rough sketch of the pipelining structure is below.
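As a rough illustration of what "3-stage pipeline with swizzled SMEM inputs" means (not the exact code in this PR), here is a sketch of a cp.async ring buffer with an XOR-swizzled shared-memory layout. The tile sizes, swizzle function, kernel name, and launch assumptions are all hypothetical, and the MMA math that would consume each tile is elided.

```cuda
// Hypothetical pipeline_skeleton: BMxBK tiles of A are streamed through a
// 3-stage shared-memory ring buffer with cp.async, and the 16-byte column
// chunks are XOR-swizzled so later reads (e.g. via ldmatrix) avoid bank conflicts.
#include <cuda_fp16.h>
#include <cuda_pipeline_primitives.h>

constexpr int STAGES = 3;   // depth of the ring buffer
constexpr int BM = 128;     // rows of A per block tile
constexpr int BK = 64;      // K-extent of one tile, in halves

// XOR swizzle of the 8-half (16-byte) chunk index within a row
__device__ __forceinline__ int swizzle(int row, int chunk) { return chunk ^ (row & 7); }

// assumes blockDim.x == 256, K a multiple of BK, K/BK >= STAGES-1, A 16B-aligned
__global__ void pipeline_skeleton(const half* A, int K) {
  __shared__ alignas(16) half smemA[STAGES][BM][BK];
  const int iters = K / BK;

  auto issue_copy = [&](int stage, int kTile) {
    // each thread issues 16-byte cp.async copies covering the BMxBK tile
    for (int idx = threadIdx.x; idx < BM * BK / 8; idx += blockDim.x) {
      int row = idx / (BK / 8), chunk = idx % (BK / 8);
      __pipeline_memcpy_async(
          &smemA[stage][row][swizzle(row, chunk) * 8],
          &A[(blockIdx.y * BM + row) * K + kTile * BK + chunk * 8], 16);
    }
    __pipeline_commit();
  };

  // prologue: prefetch the first STAGES-1 tiles before the main loop
  for (int s = 0; s < STAGES - 1; s++) issue_copy(s, s);

  for (int k = 0; k < iters; k++) {
    if (k + STAGES - 1 < iters) issue_copy((k + STAGES - 1) % STAGES, k + STAGES - 1);
    else __pipeline_commit();           // empty commit keeps wait_prior bookkeeping uniform
    __pipeline_wait_prior(STAGES - 2);  // tile k's copies are now complete
    __syncthreads();
    // ... ldmatrix from smemA[k % STAGES] + mma.sync would consume the tile here ...
    __syncthreads();
  }
}
```

The point of the extra stages is that the copy for tile k+2 is in flight while tile k is being consumed, so the tensor cores are not stalled on global memory; the swizzle keeps the shared-memory reads conflict-free without padding.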
Happy to remove these extra stages/variations and just keep the "max" variations, but I figure they might be useful for comparison as future kernel rendering features are incrementally added.
As a prerequisite to implementing full-speed GEMM for NV, here are the handwritten versions of GEMM that show the incremental progress needed to get there and the associated speed improvements.
The command for hcopt is:
PYTHONPATH=. CUDA=1 GEMM_VARIATION="hcopt" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=float CNT=1024 INPUT=RAND python3 ./extra/gemm/max_matmul.py
The command for the rest of the FP32 acc variations is:
PYTHONPATH=. CUDA=1 GEMM_VARIATION="$VARIATION" DTYPE_IN=half DTYPE_OUT=float DTYPE_ACC=float CNT=1024 INPUT=RAND python3 ./extra/gemm/max_matmul.py
The command for the rest of the FP16 acc variations is:
PYTHONPATH=. CUDA=1 GEMM_VARIATION="$VARIATION" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=half CNT=1024 INPUT=ONES python3 ./extra/gemm/max_matmul.py