-
Notifications
You must be signed in to change notification settings - Fork 8
/
matmul.py
127 lines (94 loc) · 4.8 KB
/
matmul.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from neuronxcc import nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as ni
import numpy as np
def matmul(A_DRAM, B_DRAM, Z_DRAM, TILES_IN_BLOCK_K=8, TILES_IN_BLOCK_M=4, TILES_IN_BLOCK_N=4):
    """
    Optimized NKI matrix-multiplication kernel computing Z = A.T @ B.

    The contraction dimension K of both inputs is mapped onto the SBUF
    partition dimension, which is why A is delivered transposed ([K, M]
    rather than [M, K]).

    Args:
        A_DRAM: an input tensor of shape [K, M], where K is a multiple of 1024
            and M is a multiple of 512. It is the left-hand-side argument of the
            matrix multiplication, delivered transposed for optimal performance.
        B_DRAM: an input tensor of shape [K, N], where K is a multiple of 1024
            and N is a multiple of 2048. It is the right-hand-side argument of
            the matrix multiplication.
        Z_DRAM: the resulting output tensor of shape [M, N]
        TILES_IN_BLOCK_K: number of K-dimension hardware tiles grouped into one
            SBUF block (K must be divisible by TILES_IN_BLOCK_K * TILE_K).
        TILES_IN_BLOCK_M: number of M-dimension tiles per block (same
            divisibility requirement along M).
        TILES_IN_BLOCK_N: number of N-dimension tiles per block (same
            divisibility requirement along N).
    """
    K, M = A_DRAM.shape
    _, N = B_DRAM.shape

    # Hardware tile-size limits taken from the NKI language constants:
    # pmax is the partition-dimension maximum; the gemm_*_fmax values are the
    # free-dimension maxima for the stationary (A) and moving (B) matmul
    # operands respectively.
    TILE_K = nl.tile_size.pmax
    TILE_M = nl.tile_size.gemm_stationary_fmax
    TILE_N = nl.tile_size.gemm_moving_fmax

    # Number of blocks along each dimension; each block is a group of
    # TILES_IN_BLOCK_* hardware tiles loaded/processed together.
    NUM_BLOCK_K = K // (TILES_IN_BLOCK_K * TILE_K)
    NUM_BLOCK_M = M // (TILES_IN_BLOCK_M * TILE_M)
    NUM_BLOCK_N = N // (TILES_IN_BLOCK_N * TILE_N)

    # The blocking must tile the problem exactly — no remainder handling below.
    assert NUM_BLOCK_K * TILES_IN_BLOCK_K * TILE_K == K
    assert NUM_BLOCK_M * TILES_IN_BLOCK_M * TILE_M == M
    assert NUM_BLOCK_N * TILES_IN_BLOCK_N * TILE_N == N

    for n2 in nl.affine_range(NUM_BLOCK_N):
        for m2 in nl.affine_range(NUM_BLOCK_M):
            # Partition Z and then ensure that we are Z-block stationary.
            # This way, no matter how large K, M, and N are, Z is never
            # spilled/loaded: the (m2, n2) output block stays resident in SBUF
            # across the whole K loop, and we only need to store it once.
            Z_SBUF = nl.zeros((TILES_IN_BLOCK_M, nl.par_dim(TILE_M), TILES_IN_BLOCK_N * TILE_N), dtype=Z_DRAM.dtype, buffer=nl.sbuf)

            for k2 in nl.affine_range(NUM_BLOCK_K):
                # Fresh SBUF staging buffers for this K block of A and B.
                A_SBUF = nl.ndarray((TILES_IN_BLOCK_K, nl.par_dim(TILE_K), TILES_IN_BLOCK_M * TILE_M), dtype=A_DRAM.dtype, buffer=nl.sbuf)
                B_SBUF = nl.ndarray((TILES_IN_BLOCK_K, nl.par_dim(TILE_K), TILES_IN_BLOCK_N * TILE_N), dtype=B_DRAM.dtype, buffer=nl.sbuf)

                # Load in a block of A and a block of B, one K tile at a time.
                for k1 in nl.affine_range(TILES_IN_BLOCK_K):
                    k_start = k2 * TILES_IN_BLOCK_K * TILE_K + k1 * TILE_K
                    k_end = k_start + TILE_K
                    m_start = m2 * TILES_IN_BLOCK_M * TILE_M
                    m_end = m_start + TILES_IN_BLOCK_M * TILE_M
                    n_start = n2 * TILES_IN_BLOCK_N * TILE_N
                    n_end = n_start + TILES_IN_BLOCK_N * TILE_N

                    # We coalesce memory accesses by loading TILES_IN_BLOCK_M * TILE_M
                    # values of A at a time. We cannot coalesce across K because K gets
                    # split across the partition dimension.
                    A_SBUF[k1] = nl.load(A_DRAM[k_start:k_end, m_start:m_end])

                    # We coalesce memory accesses by loading TILES_IN_BLOCK_N * TILE_N
                    # values of B at a time. We cannot coalesce across K because K gets
                    # split across the partition dimension.
                    B_SBUF[k1] = nl.load(B_DRAM[k_start:k_end, n_start:n_end])

                for m1 in nl.affine_range(TILES_IN_BLOCK_M):
                    for n1 in nl.affine_range(TILES_IN_BLOCK_N):
                        # Keep the tile of Z stationary in the PSUM buffer to
                        # minimize the number of calls to nl.loop_reduce.
                        # Accumulation is done in float32 regardless of the
                        # input/output dtype.
                        Z_PSUM = nl.zeros((TILE_M, TILE_N), dtype=nl.float32, buffer=nl.psum)
                        m_start = m1 * TILE_M
                        m_end = m_start + TILE_M
                        n_start = n1 * TILE_N
                        n_end = n_start + TILE_N

                        # Sum the partial products of all K tiles in this block.
                        for k1 in nl.affine_range(TILES_IN_BLOCK_K):
                            Z_PSUM += ni.nc_matmul(A_SBUF[k1, :, m_start:m_end], B_SBUF[k1, :, n_start:n_end])

                        # loop_reduce over k2 folds each K-block's PSUM partial
                        # sum into Z_SBUF, since the k2 iterations of an
                        # affine_range may not execute as an ordinary
                        # sequential accumulation.
                        Z_SBUF[m1, :, n_start:n_end] = nl.loop_reduce(Z_PSUM, op=np.add, loop_indices=[k2], dtype=Z_DRAM.dtype)

            # Write the finished (m2, n2) output block back to DRAM.
            for m1 in nl.affine_range(TILES_IN_BLOCK_M):
                m_start = m2 * TILES_IN_BLOCK_M * TILE_M + m1 * TILE_M
                m_end = m_start + TILE_M
                n_start = n2 * TILES_IN_BLOCK_N * TILE_N
                n_end = n_start + TILES_IN_BLOCK_N * TILE_N

                # We coalesce memory accesses by storing TILES_IN_BLOCK_N * TILE_N
                # values of Z at a time. We cannot coalesce across M because M gets
                # split across the partition dimension.
                nl.store(Z_DRAM[m_start:m_end, n_start:n_end], value=Z_SBUF[m1])
def check_correct():
    """Validate the kernel against a NumPy reference on a small problem.

    Runs `matmul` in NKI baremetal mode on float16 inputs and compares the
    result against `A.T @ B`, printing whether they match within tolerance.
    """
    K, M, N = 1024, 4096, 2048
    A = np.random.random_sample([K, M]).astype(np.float16)
    B = np.random.random_sample([K, N]).astype(np.float16)
    # np.empty is the documented way to create an uninitialized output buffer;
    # constructing np.ndarray directly is discouraged by NumPy.
    Z = np.empty(shape=[M, N], dtype=np.float16)
    baremetal_func = nki.baremetal()(matmul)
    baremetal_func(A, B, Z)
    Z_corr = A.T @ B
    # Loose tolerances: inputs are float16, so the reference and the kernel
    # (which accumulates in fp32 PSUM) legitimately differ in low bits.
    print("Is close?", np.allclose(Z, Z_corr, atol=1e-4, rtol=1e-2))
def benchmark_kernel():
    """Benchmark the kernel on a large problem (5 warmup + 10 timed runs)."""
    K, M, N = 8192, 4096, 8192
    A = np.random.random_sample([K, M]).astype(np.float16)
    B = np.random.random_sample([K, N]).astype(np.float16)
    # np.empty is the documented way to create an uninitialized output buffer;
    # constructing np.ndarray directly is discouraged by NumPy.
    Z = np.empty(shape=[M, N], dtype=np.float16)
    benchmark_func = nki.benchmark(warmup=5, iters=10)(matmul)
    benchmark_func(A, B, Z)
def main():
    """Run the correctness check first, then the performance benchmark."""
    check_correct()
    benchmark_kernel()


if __name__ == "__main__":
    main()