[HGEMM] Update toy-hgemm library 0.1.0 (#152)
* Update hgemm.py

* Update hgemm_cublas.cu

* Update hgemm_mma_stage_tn_cute.cu

* Update hgemm_mma_stage.cu

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md
DefTruth authored Nov 28, 2024
1 parent edf80bb commit 37f1554
Showing 5 changed files with 27 additions and 18 deletions.
README.md: 20 changes (10 additions & 10 deletions)
@@ -16,28 +16,28 @@

<div id="contents"></div>

-📚 **Modern CUDA Learn Notes with PyTorch** for Beginners: It includes **Tensor/CUDA Cores, TF32/F16/BF16/F8**, [📖150+ CUDA Kernels🔥🔥](#cuda-kernel) with PyTorch bindings, [📖30+ LLM/VLM🔥](#my-blogs-part-1), [📖40+ CV/C++...🔥](#my-blogs-part-2), [📖50+ CUDA/CuTe...🔥](#other-blogs) Blogs and [📖toy-hgemm library🔥🔥](./kernels/hgemm) which can achieve the performance of **cuBLAS**, check [📖HGEMM Supported Matrix👇](#hgemm-sgemm) for more details. Welcome to 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉
+📚 **Modern CUDA Learn Notes with PyTorch** for Beginners: It includes **Tensor/CUDA Cores, TF32/F16/BF16/F8**, [📖150+ CUDA Kernels🔥🔥](#cuda-kernel) with PyTorch bindings, [📖30+ LLM/VLM🔥](#my-blogs-part-1), [📖40+ CV/C++...🔥](#my-blogs-part-2), [📖50+ CUDA/CuTe...🔥](#other-blogs) Blogs and [📖toy-hgemm library⚡️⚡️](./kernels/hgemm) which can achieve `98%~100%` of **cuBLAS** performance, check [📖HGEMM Supported Matrix👇](#hgemm-sgemm) for technical details. Welcome to 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉

<div id="hgemm-sgemm"></div>

<div align='center'>
-<img src='https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99' height="150px" width="267px">
-<img src='https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85' height="150px" width="267px">
-<img src='https://github.com/user-attachments/assets/9472e970-c083-4b31-9252-3eeecc761078' height="150px" width="267px">
+<img src='https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99' height="170px" width="270px">
+<img src='https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85' height="170px" width="270px">
+<img src='https://github.com/user-attachments/assets/9472e970-c083-4b31-9252-3eeecc761078' height="170px" width="270px">
</div>

-Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA)` implemented in this repo (`blue`🔵) can achieve `99%~100%+` of its (`orange`🟠) performance. Please check [toy-hgemm library🔥🔥](./kernels/hgemm) for more details.
+Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores math algorithm `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, the `HGEMM (WMMA/MMA/CuTe)` implemented in this repo (`blue`🔵) can achieve `98%~100%` of its (`orange`🟠) performance. Please check [toy-hgemm library⚡️⚡️](./kernels/hgemm) for more details.

-|CUDA Cores|Sliced K(Loop over K)|Tile Block|Tile Thread|
+|CUDA Cores|Sliced K (Loop over K)|Tile Block (BMxBK)|Tile Thread (t 8x8)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
-|WMMA(m16n16k16)|MMA(m16n8k16)|Pack LDST(128 bits)|SMEM Padding|
+|WMMA (m16n16k16)|MMA (m16n8k16)|Pack LDST (128 bits)|SMEM Padding|
|✔️|✔️|✔️|✔️|
-|Copy Async|Tile MMA(More Threads)|Tile Warp(More Values)|Multi Stages|
+|Copy Async|Tile MMA (More Threads)|Tile Warp (More Values)|Multi Stages (2/3/4)|
|✔️|✔️|✔️|✔️|
-|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle(CuTe)|
+|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe)|
|✔️|✔️|✔️|✔️|
-|Collective Store(Warp Shfl)|Row Major(NN)|Col Major(TN)|SGEMM F32/TF32|
+|Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)|SGEMM FP32/TF32|
|✔️|✔️|✔️|✔️|
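
For orientation, here is a minimal sketch of the WMMA (m16n16k16) path listed in the table above, using CUDA's `nvcuda::wmma` API. The kernel name and launch shape are hypothetical, and the sketch assumes row-major `half` matrices with M, N, K divisible by 16; it deliberately omits the shared-memory staging, cp.async multi-stage pipelining, and swizzling that the actual toy-hgemm kernels add on top.

```cuda
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// One warp computes one 16x16 tile of C = A @ B (all half, row-major).
// Launch with blockDim.x = 32 (one warp per block) and a grid of
// (N/16, M/16) blocks; purely illustrative, no shared memory or pipelining.
__global__ void wmma_hgemm_naive(const half *A, const half *B, half *C,
                                 int M, int N, int K) {
  const int tile_m = blockIdx.y;  // which 16-row strip of C
  const int tile_n = blockIdx.x;  // which 16-col strip of C

  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
  wmma::fill_fragment(c_frag, __float2half(0.0f));

  // Sliced-K loop: accumulate m16n16k16 MMAs along the K dimension.
  for (int k = 0; k < K; k += 16) {
    wmma::load_matrix_sync(a_frag, A + tile_m * 16 * K + k, K);
    wmma::load_matrix_sync(b_frag, B + k * N + tile_n * 16, N);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  }
  wmma::store_matrix_sync(C + tile_m * 16 * N + tile_n * 16, c_frag, N,
                          wmma::mem_row_major);
}
```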

## ©️Citations🎉🎉
kernels/hgemm/cublas/hgemm_cublas.cu: 2 changes (2 additions & 0 deletions)
@@ -184,6 +184,8 @@ int main(int argc, char *argv[]) {
total_sec += this_sec;
}

+// 1 TFLOPS = 10^12 FLOPS
+// ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
double avg_sec = total_sec / outer_repeat;
double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
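
For context on the added comment: an (M, N, K) GEMM performs M×N×K multiply-accumulates, i.e. 2×M×N×K floating-point operations, and 1 TFLOPS = 10^12 FLOPS, which is where the `* 2 * 1e-12` factor comes from. A tiny standalone sketch with made-up numbers (not measurements from this repo):

```cuda
#include <cstdio>

int main() {
  // FLOPs for one (M, N, K) GEMM: 2 * M * N * K.
  double M = 4096, N = 4096, K = 4096;
  double flops = 2.0 * M * N * K;           // ~1.374e11 FLOPs
  double avg_sec = 1.0e-3;                  // assume 1 ms per GEMM (made up)
  double tflops = flops * 1e-12 / avg_sec;  // ~137.4 TFLOPS
  printf("%.1f TFLOPS\n", tflops);
  return 0;
}
```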

kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu: 2 changes (2 additions & 0 deletions)
@@ -410,6 +410,8 @@ int main() {
total_sec += this_sec;
}

+// 1 TFLOPS = 10^12 FLOPS
+// ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
double avg_sec = total_sec / outer_repeat;
double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;

kernels/hgemm/hgemm.py: 17 changes (10 additions & 7 deletions)
@@ -136,17 +136,20 @@ def run_benchmark(perf_func: callable,
torch.cuda.synchronize()

end = time.time()
-total_time = (end - start) * 1000 # ms
-mean_time = total_time / iters
+total_time_secs = (end - start) # secs
+mean_time_secs = total_time_secs / iters
out_info = f"{tag}"
out_flat = out.flatten()
out_val_first = out_flat[:2].detach().cpu().numpy().tolist()
out_val_last = out_flat[-2:].detach().cpu().numpy().tolist()
out_val = [out_val_first[0], out_val_last[-1]]
out_val = [round(v, 8) for v in out_val]
out_val = [f"{v:<12}"[:10] for v in out_val]
-TFLOPS = (2 * M * N * K) * 1e-9 / (mean_time)
-mean_time = str(f"{mean_time:<12}")[:8]
+# 1 TFLOPS = 10^12 FLOPS
+# ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
+TFLOPS = (2 * M * N * K) * 1e-12 / (mean_time_secs)
+mean_time_ms = mean_time_secs * 1000
+mean_time_ms = str(f"{mean_time_ms:<12}")[:8] # ms
swizzle_stride = 'NOOP' if swizzle_stride == 1 else swizzle_stride

# calculate TFLOPS improved.
@@ -157,11 +157,11 @@ def run_benchmark(perf_func: callable,
else:
improve = 0
MAX_TFLOPS = TFLOPS
-print(f"{out_info:>50}: {out_val}, time:{mean_time}ms, "
+print(f"{out_info:>50}: {out_val}, time:{mean_time_ms}ms, "
f"swizzle<block>: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)")
else:
if not only_show_improved or "cublas" in tag:
-print(f"{out_info:>50}: {out_val}, time:{mean_time}ms, "
+print(f"{out_info:>50}: {out_val}, time:{mean_time_ms}ms, "
f"swizzle<block>: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}")
if show_matrix: print(out)
if args.plot_flops:
@@ -186,7 +189,7 @@ def run_benchmark(perf_func: callable,
gc.collect()
torch.cuda.empty_cache()
time.sleep(args.sleep_duration)
-return out, mean_time
+return out, mean_time_ms


def get_topk_tflops():
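
The harness above measures wall-clock time around `torch.cuda.synchronize()`. A common CUDA-side alternative is event-based timing; the snippet below is a generic sketch where `launch_once` is a placeholder callback, not an API from this repo.

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Time `iters` launches of some kernel with CUDA events and report the mean.
// `launch_once()` stands in for whatever HGEMM kernel launch is benchmarked.
void benchmark(void (*launch_once)(), int iters) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  for (int i = 0; i < 5; ++i) launch_once();  // warmup
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) launch_once();
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);  // wait for all timed launches to finish

  float total_ms = 0.0f;
  cudaEventElapsedTime(&total_ms, start, stop);  // elapsed time in ms
  printf("mean: %.4f ms\n", total_ms / iters);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
}
```
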
kernels/hgemm/mma/hgemm_mma_stage.cu: 4 changes (3 additions & 1 deletion)
@@ -2032,7 +2032,9 @@ int main() {
min_sec = min(min_sec, this_sec);
total_sec += this_sec;
}


+// 1 TFLOPS = 10^12 FLOPS
+// ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
double avg_sec = total_sec / outer_repeat;
double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;

