Streamk v0.2 #646

Open: wants to merge 24 commits from streamk_v0.2 into main_perf

Commits (24)
83f298a
add Lixun's occ.sh to here, will move to proper location later
xiaohuguo2023 Aug 29, 2024
b927376
the v0.2 streamk kernel and tuning script
xiaohuguo2023 Aug 29, 2024
50fc4f6
update readme
xiaohuguo2023 Aug 29, 2024
5711d6e
fix manual k-loop peeling bug and experiment with cache modifier
xiaohuguo2023 Sep 11, 2024
556aa1a
add missing dependencies
xiaohuguo2023 Sep 13, 2024
ec58261
update tune_streamk script
xiaohuguo2023 Sep 13, 2024
64535f7
add wrapper for people to test
xiaohuguo2023 Sep 16, 2024
d686de8
tl.load() now doesn't result in a race cond.
neoblizz Sep 18, 2024
b29688f
reduce register usage using tl.multiple_of before tl.load
xiaohuguo2023 Sep 20, 2024
71c5793
change num_sms to compiletime constant
xiaohuguo2023 Sep 25, 2024
72fc9f3
add num_sms to tuning space
xiaohuguo2023 Sep 25, 2024
1838c8f
Merge branch 'main_perf' into streamk_v0.2
xiaohuguo2023 Sep 25, 2024
cca9c8e
add unit test for streamk kernel
xiaohuguo2023 Sep 25, 2024
b7a99c0
add CI tests for streamk kernel
xiaohuguo2023 Sep 25, 2024
7272ed6
fix format issues
xiaohuguo2023 Sep 25, 2024
1edd889
remove unused wrapper and occ.sh
xiaohuguo2023 Sep 25, 2024
e98f789
fix the format issues
xiaohuguo2023 Sep 25, 2024
3275cdd
update README
xiaohuguo2023 Sep 25, 2024
2755b72
fix git issue
xiaohuguo2023 Sep 25, 2024
4ff9faa
more change to make git work
xiaohuguo2023 Sep 25, 2024
8add102
need cache_modifier when loading P_
xiaohuguo2023 Oct 4, 2024
bf8cfe8
add back matmul wrapper for now
xiaohuguo2023 Oct 4, 2024
c129a0a
fix wrapper
xiaohuguo2023 Oct 16, 2024
2ac73f6
Merge branch 'main_perf' into streamk_v0.2
xiaohuguo2023 Oct 19, 2024
2 changes: 2 additions & 0 deletions .github/workflows/amd_perf_kernel_Integration_tests.yml
@@ -129,7 +129,9 @@ jobs:
pytest -vvvv ./python/perf-kernels/softmax.py
pytest -vvv ./python/perf-kernels/rmsnorm.py
pytest -vvv ./python/perf-kernels/layernorm.py
sh ./python/perf-kernels/streamk/utils/unittest.sh
pytest -vvv ./python/perf-kernels/multreduce_matmul_kernel.py

- name: Run Perf Kernels Benchmark
run: |
python ./python/perf-kernels/flash-attention.py
216 changes: 216 additions & 0 deletions python/perf-kernels/streamk/03-matrix-multiplication-stream-k.py
@@ -0,0 +1,216 @@
import torch
import triton
import random

from streamk_kernel import streamk_gemm
#from persistent_loop import streamk_gemm

torch.manual_seed(123)
random.seed(123)

total_sm = 304
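# 304 matches the CU count of an MI300X; adjust total_sm for the target GPU.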
print(f"total SMs: {total_sm}")


class matmul(torch.autograd.Function):

_debug = True

@staticmethod
def set_debug(debug: bool):
matmul._debug = debug

@staticmethod
def _call(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch.Tensor, P: torch.Tensor,
locks: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, gsize_m: int,
two_tiles: bool, num_stages: int, num_warps: int, waves_per_eu: int, mfmaInstrSize: int, kpack: int):

# assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported"
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape

total_blocks_M = triton.cdiv(M, BLK_M)
total_blocks_N = triton.cdiv(N, BLK_N)
iters_per_tile = triton.cdiv(K, BLK_K)
total_tiles = total_blocks_M * total_blocks_N
even_k = K % BLK_K == 0

if total_programs_streamk > 0: # Stream-K
# last wave may occupy less than total_programs_streamk SMs
total_tiles_streamk = total_tiles % total_programs_streamk
# for two-tile Stream-K + data-parallel from original paper
# if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk:
# total_tiles_streamk += total_programs_streamk
# remaining tiles are computed using classical blocking
total_blocking_tiles = total_tiles - total_tiles_streamk
total_iters_streamk = total_tiles_streamk * iters_per_tile
# iterations related to full waves
total_full_tiles_streamk = total_iters_streamk // total_programs_streamk
# iterations related to last (partial) wave
total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk

else: # all tiles are computed using classical blocking
total_blocking_tiles = total_tiles
total_tiles_streamk = 0
total_full_tiles_streamk = 0
total_partial_tiles_streamk = 0
total_iters_streamk = 0
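
# Worked example (illustrative numbers): with total_tiles = 310,
# iters_per_tile = 68 and total_programs_streamk = 304, the split is
#   total_tiles_streamk         = 310 % 304  = 6
#   total_blocking_tiles        = 310 - 6    = 304
#   total_iters_streamk         = 6 * 68     = 408
#   total_full_tiles_streamk    = 408 // 304 = 1
#   total_partial_tiles_streamk = 408 % 304  = 104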

if matmul._debug:
print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}")
print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}")
print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}")
print(f"{total_programs_streamk=}")
print(f"{total_blocking_tiles=}")
print(f"{total_full_tiles_streamk=}")
print(f"{iters_per_tile=}")
print(f"{total_iters_streamk=}")
print("total_remainder_iters_streamk=", total_partial_tiles_streamk)
use_bias = False
# compute grid (work to do per SM on the first wave)
grids = total_programs_streamk
stride_bias = bias.stride(0) if use_bias else 0
# P=P*0.0
# locks=locks*0
kk = streamk_gemm[(grids, )](
a,
b,
c,
bias,
P,
locks,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
stride_bias,
BLOCK_SIZE_M=BLK_M,
BLOCK_SIZE_N=BLK_N,
BLOCK_SIZE_K=BLK_K,
GROUP_SIZE_M=gsize_m,
NUM_SMS=total_programs_streamk,
BIAS=use_bias,
EVEN_K=even_k,
num_stages=num_stages,
num_warps=num_warps,
waves_per_eu=waves_per_eu,
matrix_instr_nonkdim=mfmaInstrSize,
kpack=kpack,
)
if matmul._debug:
print(f"{kk.n_regs} registers used, {kk.n_spills} spills")

# print(kk.asm['ttgir'])
# print(kk.asm['amdgcn'])

return c

@staticmethod
def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch.Tensor, P: torch.Tensor,
locks: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, gsize_m=1, two_tiles=True, num_stages=3,
num_warps=4, waves_per_eu=2, mfmaInstrSize=16, kpack=1):
matmul._call(a=a, b=b, c=c, bias=bias, P=P, locks=locks, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N,
BLK_K=BLK_K, gsize_m=gsize_m, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages,
waves_per_eu=waves_per_eu, mfmaInstrSize=mfmaInstrSize, kpack=kpack)
return c


# ---------------------------------------------------------------------------
# Example and Benchmark
# ---------------------------------------------------------------------------

perf = lambda ms: 2 * m * n * k * 1e-12 / (ms * 1e-3)
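# perf converts a measured time in milliseconds to TFLOPS; a GEMM performs
# 2 * m * n * k floating-point operations.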

#m, n, k = 4864, 4096, 8256 # some problem size to test
#m, n, k = 4096, 4096, 8192 # some problem size to test
#m, n, k = 1, 1024, 256
#m, n, k = 8133, 8132, 8172 # some problem size to test
#m, n, k = 8192, 8192, 8192 # some problem size to test
#m, n, k = 8128, 6878, 7378 # some problem size to test
#m, n, k = 8192, 4864, 6878 # some problem size to test
#m, n, k = 512, 512, 512 # some problem size to test
#m, n, k = 6912, 768, 256 # some problem size to test
#m, n, k =4864, 8192, 4160 # some problem size to test
#m, n, k = 5632, 6656, 7936
m, n, k = 4864, 4096, 4300

A = torch.randn(m, k, device="cuda", dtype=torch.float16)
B = torch.randn(n, k, device="cuda", dtype=torch.float16).T
# allocates output
C = torch.zeros((m, n), device="cuda", dtype=A.dtype)
bias = torch.zeros((m, ), device="cuda", dtype=A.dtype)
#bias = None
BLK_M = 256
BLK_N = 256
BLK_K = 64
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
gsize_m = 8
two_tiles = True
num_stages = 2
num_warps = 8
waves_per_eu = 0
mfmaInstrSize = 16
kpack = 2

##for total_sm in range(1, 305):
## print(f"{total_sm=}")
## matmul.set_debug(True)
## locks = torch.zeros((total_sm,), device = "cuda", dtype = torch.int32)
## P = torch.zeros((total_sm, BLK_M*BLK_N), device="cuda", dtype=torch.float32)
## C = matmul.apply(A, B, C, P, locks, total_sm, BLK_M, BLK_N, BLK_K, gsize_m, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack)
## #exit(0)
## matmul.set_debug(False)
## expected = A @ B
##
## assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}"
## print("pass validation test")
## triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, C, P, locks, total_sm, BLK_M, BLK_N, BLK_K, gsize_m, two_tiles, num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))
## print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")

#for total_sm in range(1, 305):
print(f"{total_sm=}")
matmul.set_debug(True)
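# P is per-program scratch for partial accumulators (one BLK_M x BLK_N tile
# per program); locks holds the per-program flags used to synchronize the
# hand-off of those partials between programs.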
locks = torch.zeros((total_sm, ), device="cuda", dtype=torch.int32)
P = torch.zeros((total_sm, BLK_M * BLK_N), device="cuda", dtype=torch.float32)
C = matmul.apply(A, B, C, bias, P, locks, total_sm, BLK_M, BLK_N, BLK_K, gsize_m, two_tiles, num_stages, num_warps,
waves_per_eu, mfmaInstrSize, kpack)
#exit(0)
matmul.set_debug(False)
expected = A @ B

#assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}"
print("pass validation test")

# for debugging, uncomment the following line
#exit(0)

triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B))
print(f"PyTorch: {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")

locks = torch.zeros((total_sm, ), device="cuda", dtype=torch.int32)
P = torch.zeros((total_sm, BLK_M * BLK_N), device="cuda", dtype=torch.float32)
triton_ms = triton.testing.do_bench(
lambda: matmul.apply(A, B, C, bias, P, locks, total_sm, BLK_M, BLK_N, BLK_K, gsize_m, two_tiles, num_stages,
num_warps, waves_per_eu, mfmaInstrSize, kpack))
print(f"hybrid stream-k (grid={total_sm}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")

locks = torch.zeros((total_sm * 2, ), device="cuda", dtype=torch.int32)
P = torch.zeros((total_sm * 2, BLK_M * BLK_N), device="cuda", dtype=torch.float32)
triton_ms = triton.testing.do_bench(
lambda: matmul.apply(A, B, C, bias, P, locks, total_sm * 2, BLK_M, BLK_N, BLK_K, gsize_m, two_tiles, num_stages,
num_warps, waves_per_eu, mfmaInstrSize, kpack))
print(f"hybrid stream-k (grid={total_sm * 2}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")

triton_ms = triton.testing.do_bench(
lambda: matmul.apply(A, B, C, bias, P, locks, total_tiles, BLK_M, BLK_N, BLK_K, gsize_m, two_tiles, num_stages,
num_warps, waves_per_eu, mfmaInstrSize, kpack))
print(f"tile matmul (grid={total_tiles}): {triton_ms:.3f} ms {perf(triton_ms):.3f} tflops")
30 changes: 30 additions & 0 deletions python/perf-kernels/streamk/README.md
@@ -1,3 +1,33 @@
# streamk gemm script v0.2

### features added:

- new streamk tuning script to reduce compilation and profiling time

- use load/store cache modifiers to reimplement the spin lock (a minimal sketch of the pattern follows this list)

- add CI test for the streamk kernel

- able to use the new stream pipeliner (V2)
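
The rough shape of the cache-modifier spin lock is sketched below. This is illustrative only, not the actual `streamk_gemm` code: the helper name, the `.wt`/`.cg` modifiers and the `volatile` flag are assumptions, and both sides of the hand-off are folded into one helper for brevity.

```
import triton
import triton.language as tl


@triton.jit
def handoff_partials(P, locks, acc, pid, partner_pid, BLOCK: tl.constexpr):
    # Producer side: flush the partial accumulator to the scratch buffer P,
    # then raise the flag with a write-through store so the update reaches
    # memory and can be observed by workgroups on other CUs.
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    tl.store(P + offs, acc, cache_modifier=".wt")
    tl.debug_barrier()  # keep the data store ahead of the flag store
    tl.store(locks + pid, 1, cache_modifier=".wt")

    # Consumer side: spin on the partner's flag, asking the load to bypass
    # the first-level cache, then pick up the partner's partial results.
    while tl.load(locks + partner_pid, cache_modifier=".cg", volatile=True) != 1:
        pass
    acc += tl.load(P + partner_pid * BLOCK + tl.arange(0, BLOCK), cache_modifier=".cg")
    return acc
```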

### potential issues:

- there may be a hang when using random grid sizes
- large register spills when using tile size 256x256x64

### tuning command

```
TRITON_HIP_USE_NEW_STREAM_PIPELINE=1 python tune_streamk.py --gemm_size_file input_nn_size.yaml --ngpus 8 --jobs 24
```

### calculate occupancy

```
TRITON_HIP_USE_NEW_STREAM_PIPELINE=1 ../../occ.sh "python tune_streamk.py --gemm_size_file single_item.yaml --compare_wo_tuning"
```


# streamk gemm script v0.1

The plan is to use this version as the base for future triton streamk gemm development.