
extra/gemm/max_matmul: start of custom kernels for GEMM #6926

Draft · wants to merge 9 commits into master from max_matmul
Conversation

@flammit (Contributor) commented Oct 7, 2024

As a prerequisite to implementing full-speed GEMM for NV, here are handwritten versions of GEMM that show the incremental progress needed to get there and the associated speed improvements.

| Acc | Variation | Performance |
| --- | --- | --- |
| FP32 | hcopt | 1301.50 us, 105600.10 GFLOPS matmul, 77.34 GB/s |
| FP32 | flat_smem_input | 1309.41 us, 104962.67 GFLOPS matmul, 76.88 GB/s |
| FP32 | swizzled_smem_input | 882.69 us, 155705.02 GFLOPS matmul, 114.04 GB/s |
| FP32 | 2_stage_swizzled_smem_input | 831.49 us, 165292.77 GFLOPS matmul, 121.06 GB/s |
| FP32 | max | 826.37 us, 166316.89 GFLOPS matmul, 121.81 GB/s |
| FP16 | 3_stage_swizzled | 505.66 us, 271798.97 GFLOPS matmul, 199.07 GB/s |
| FP16 | max | 404.48 us, 339791.71 GFLOPS matmul, 248.87 GB/s |
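The big FP32 jump from `flat_smem_input` to `swizzled_smem_input` comes from swizzling the shared-memory layout so that accesses spread across banks instead of conflicting. As an illustration only (this is a generic XOR-swizzle sketch, not the PR's actual layout; `swizzle` and `cols_per_row` are hypothetical names):

```python
# Generic XOR-based SMEM swizzle sketch (illustrative, not the PR's code).
# Flat layout: element (row, col) lives at row * cols_per_row + col, so a
# column walk hits the same bank every row. XOR-ing the column with low
# bits of the row rotates each row's mapping, spreading a column walk
# across distinct banks while keeping each row a permutation of itself.
def swizzle(row, col, cols_per_row=8):
    return row * cols_per_row + (col ^ (row % cols_per_row))
```

With `cols_per_row=8`, each row still occupies its own contiguous slice (the XOR is a bijection within the row), but walking down column 0 now touches 8 different bank offsets instead of one.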

The command for hcopt is:

```sh
PYTHONPATH=. CUDA=1 GEMM_VARIATION="hcopt" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=float CNT=1024 INPUT=RAND python3 ./extra/gemm/max_matmul.py
```

The command for the rest of the FP32-acc variations is:

```sh
PYTHONPATH=. CUDA=1 GEMM_VARIATION="$VARIATION" DTYPE_IN=half DTYPE_OUT=float DTYPE_ACC=float CNT=1024 INPUT=RAND python3 ./extra/gemm/max_matmul.py
```

The command for the rest of the FP16-acc variations is:

```sh
PYTHONPATH=. CUDA=1 GEMM_VARIATION="$VARIATION" DTYPE_IN=half DTYPE_OUT=half DTYPE_ACC=half CNT=1024 INPUT=ONES python3 ./extra/gemm/max_matmul.py
```
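For reference, the GFLOPS and GB/s figures in the table are consistent with a 4096×4096×4096 matmul with half-precision inputs and outputs. A quick sketch of the arithmetic (my reconstruction; the problem size is inferred from the numbers, not stated in the PR):

```python
# Reconstructing the reported throughput from kernel time, assuming an
# N=4096 square matmul with fp16 A, B, and C (an assumption inferred
# from the numbers above, not stated explicitly in the PR).
N = 4096
BYTES_PER_ELEM = 2  # fp16

def throughput(time_us):
    flops = 2 * N**3                           # one multiply-add = 2 FLOPs
    bytes_moved = 3 * N * N * BYTES_PER_ELEM   # read A and B, write C
    seconds = time_us * 1e-6
    return flops / seconds / 1e9, bytes_moved / seconds / 1e9  # GFLOPS, GB/s
```

For example, `throughput(1301.50)` gives roughly `(105600, 77.3)`, matching the hcopt row.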

@flammit (Contributor, Author) commented Oct 7, 2024

Added an unoptimized FP16-input / FP16-acc MMA example.

@geohot (Collaborator) commented Oct 8, 2024

So I tested this with Triton once; what does Triton get? IIRC it was well over 200. We should at least match that with the handcoded kernels.

@flammit flammit marked this pull request as draft October 8, 2024 18:08
@flammit flammit force-pushed the max_matmul branch 2 times, most recently from 52feba8 to 7234150 Compare October 9, 2024 23:14
(bot) This branch is currently behind tinygrad/master. The line count difference bot is disabled.

@flammit (Contributor, Author) commented Oct 11, 2024

Added a 3-stage pipeline with swizzled SMEM inputs for FP16 acc that does 270 TFLOPS (still less than the 330 TFLOPS in CUTLASS, but not as bad as before).

Note: this PR depends on #6956 being landed first.
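For context on what an "N-stage pipeline" means here: the MMA consumes tiles out of one SMEM staging buffer while loads for future K-tiles are already in flight into the others. The control flow can be sketched in Python (my illustration, not the PR's CUDA code; `load` and `compute` are hypothetical stand-ins for the async global-to-SMEM copy and the tensor-core MMA):

```python
# Sketch of an N-stage software pipeline over the GEMM K-loop
# (illustrative only; the real kernel does this with async copies
# and barriers in CUDA, not sequential Python calls).
STAGES = 3  # 3-stage pipeline, as in the PR comment

def pipelined_k_loop(num_k_tiles, load, compute):
    bufs = [None] * STAGES  # circular set of SMEM staging buffers
    # Prologue: issue loads for the first STAGES-1 tiles up front.
    for k in range(min(STAGES - 1, num_k_tiles)):
        bufs[k % STAGES] = load(k)
    # Steady state: prefetch tile k+STAGES-1, then consume tile k.
    for k in range(num_k_tiles):
        nxt = k + STAGES - 1
        if nxt < num_k_tiles:
            bufs[nxt % STAGES] = load(nxt)
        compute(bufs[k % STAGES])
```

The point of the extra stages is latency hiding: with 3 buffers, two loads can be in flight while one tile is being consumed, so the MMA units rarely stall waiting on memory.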

@flammit flammit force-pushed the max_matmul branch 2 times, most recently from 5b44292 to 9465409 Compare October 17, 2024 01:22
@flammit flammit marked this pull request as ready for review October 17, 2024 01:27
@flammit (Contributor, Author) commented Oct 17, 2024

Happy to remove the extra stages/variations and keep just the "max" variations, but I figure they might be useful for comparison as future kernel-rendering features are incrementally added.
