Contribute higher utilization matrix multiplication (#23)
from neuronxcc import nki
import neuronxcc.nki.language as nl
import neuronxcc.nki.isa as ni
import numpy as np


def matmul(A_DRAM, B_DRAM, Z_DRAM, TILES_IN_BLOCK_K=8, TILES_IN_BLOCK_M=4, TILES_IN_BLOCK_N=4):
    """
    Optimized matrix multiplication kernel

    Args:
        A_DRAM: an input tensor of shape [K, M], where K is a multiple of 1024
            and M is a multiple of 512. It is the left-hand-side argument of the
            matrix multiplication, delivered transposed for optimal performance.
        B_DRAM: an input tensor of shape [K, N], where K is a multiple of 1024
            and N is a multiple of 2048. It is the right-hand-side argument of
            the matrix multiplication.
        Z_DRAM: the resulting output tensor of shape [M, N]
    """
    K, M = A_DRAM.shape
    _, N = B_DRAM.shape

    TILE_K = nl.tile_size.pmax
    TILE_M = nl.tile_size.gemm_stationary_fmax
    TILE_N = nl.tile_size.gemm_moving_fmax
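    # These are the hardware tile limits for a single nc_matmul: the contraction
    # (K) dimension is bounded by the number of SBUF/PSUM partitions, and the
    # stationary/moving free dimensions are bounded separately (128, 128, and
    # 512 respectively on current NeuronCore generations). Reading them from
    # nl.tile_size keeps the kernel tied to the hardware rather than hard-coded.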

    NUM_BLOCK_K = K // (TILES_IN_BLOCK_K * TILE_K)
    NUM_BLOCK_M = M // (TILES_IN_BLOCK_M * TILE_M)
    NUM_BLOCK_N = N // (TILES_IN_BLOCK_N * TILE_N)

    assert NUM_BLOCK_K * TILES_IN_BLOCK_K * TILE_K == K
    assert NUM_BLOCK_M * TILES_IN_BLOCK_M * TILE_M == M
    assert NUM_BLOCK_N * TILES_IN_BLOCK_N * TILE_N == N
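    # With the default block factors and those tile limits, one block of A is
    # (8 * 128) x (4 * 128) = 1024 x 512, one block of B is 1024 x (4 * 512) = 1024 x 2048,
    # and one block of Z is 512 x 2048, which is where the "multiple of 1024 / 512 / 2048"
    # shape requirements in the docstring come from.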

    for n2 in nl.affine_range(NUM_BLOCK_N):
        for m2 in nl.affine_range(NUM_BLOCK_M):

            # Partition Z and then ensure that we are Z-block stationary.
            # This way, no matter how large K, M, and N are, Z is never
            # spilled/loaded; we only need to store once.
            Z_SBUF = nl.zeros((TILES_IN_BLOCK_M, nl.par_dim(TILE_M), TILES_IN_BLOCK_N * TILE_N), dtype=Z_DRAM.dtype, buffer=nl.sbuf)

            for k2 in nl.affine_range(NUM_BLOCK_K):
                A_SBUF = nl.ndarray((TILES_IN_BLOCK_K, nl.par_dim(TILE_K), TILES_IN_BLOCK_M * TILE_M), dtype=A_DRAM.dtype, buffer=nl.sbuf)
                B_SBUF = nl.ndarray((TILES_IN_BLOCK_K, nl.par_dim(TILE_K), TILES_IN_BLOCK_N * TILE_N), dtype=B_DRAM.dtype, buffer=nl.sbuf)

                # Load in a block of A and a block of B
                for k1 in nl.affine_range(TILES_IN_BLOCK_K):
                    k_start = k2 * TILES_IN_BLOCK_K * TILE_K + k1 * TILE_K
                    k_end = k_start + TILE_K

                    m_start = m2 * TILES_IN_BLOCK_M * TILE_M
                    m_end = m_start + TILES_IN_BLOCK_M * TILE_M

                    n_start = n2 * TILES_IN_BLOCK_N * TILE_N
                    n_end = n_start + TILES_IN_BLOCK_N * TILE_N

                    # We coalesce memory accesses by loading TILES_IN_BLOCK_M * TILE_M
                    # values of A at a time. We cannot coalesce across K because K gets
                    # split across the partition dimension.
                    A_SBUF[k1] = nl.load(A_DRAM[k_start:k_end, m_start:m_end])

                    # We coalesce memory accesses by loading TILES_IN_BLOCK_N * TILE_N
                    # values of B at a time. We cannot coalesce across K because K gets
                    # split across the partition dimension.
                    B_SBUF[k1] = nl.load(B_DRAM[k_start:k_end, n_start:n_end])

                for m1 in nl.affine_range(TILES_IN_BLOCK_M):
                    for n1 in nl.affine_range(TILES_IN_BLOCK_N):
                        # Keep the tile of Z stationary in the PSUM buffer to minimize
                        # the number of calls to nl.loop_reduce.
                        Z_PSUM = nl.zeros((TILE_M, TILE_N), dtype=nl.float32, buffer=nl.psum)

                        m_start = m1 * TILE_M
                        m_end = m_start + TILE_M

                        n_start = n1 * TILE_N
                        n_end = n_start + TILE_N

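                        # ni.nc_matmul contracts over the partition (K) axis, computing
                        # A_tile.T @ B_tile for each pair of tiles (per the NKI ISA
                        # documentation); this is why A is expected K-major, i.e.
                        # "delivered transposed" as the docstring puts it.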
                        for k1 in nl.affine_range(TILES_IN_BLOCK_K):
                            Z_PSUM += ni.nc_matmul(A_SBUF[k1, :, m_start:m_end], B_SBUF[k1, :, n_start:n_end])

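                        # k2 is an affine_range, so the PSUM partial sums produced for
                        # each K block are combined into Z_SBUF with nl.loop_reduce (an
                        # add reduction over the k2 loop) instead of serializing the loop.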
                        Z_SBUF[m1, :, n_start:n_end] = nl.loop_reduce(Z_PSUM, op=np.add, loop_indices=[k2], dtype=Z_DRAM.dtype)

            for m1 in nl.affine_range(TILES_IN_BLOCK_M):
                m_start = m2 * TILES_IN_BLOCK_M * TILE_M + m1 * TILE_M
                m_end = m_start + TILE_M

                n_start = n2 * TILES_IN_BLOCK_N * TILE_N
                n_end = n_start + TILES_IN_BLOCK_N * TILE_N

                # We coalesce memory accesses by storing TILES_IN_BLOCK_N * TILE_N
                # values of Z at a time. We cannot coalesce across M because M gets
                # split across the partition dimension.
                nl.store(Z_DRAM[m_start:m_end, n_start:n_end], value=Z_SBUF[m1])


def check_correct():
    K, M, N = 1024, 4096, 2048
    A = np.random.random_sample([K, M]).astype(np.float16)
    B = np.random.random_sample([K, N]).astype(np.float16)
    Z = np.ndarray(shape=[M, N], dtype=np.float16)

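    # nki.baremetal compiles the kernel and runs it on the attached NeuronCore;
    # the result lands in Z because Z_DRAM is passed as the output argument and
    # written with nl.store inside the kernel.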
    baremetal_func = nki.baremetal()(matmul)
    baremetal_func(A, B, Z)

    Z_corr = A.T @ B

    print("Is close?", np.all(np.isclose(Z, Z_corr, atol=1e-4, rtol=1e-2)))


def benchmark_kernel():
    K, M, N = 8192, 4096, 8192
    A = np.random.random_sample([K, M]).astype(np.float16)
    B = np.random.random_sample([K, N]).astype(np.float16)
    Z = np.ndarray(shape=[M, N], dtype=np.float16)

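    # nki.benchmark times the kernel on device over the warmup and measured
    # iterations; the exact reporting (e.g. latency percentiles) depends on the
    # installed neuronxcc version.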
    benchmark_func = nki.benchmark(warmup=5, iters=10)(matmul)
    benchmark_func(A, B, Z)


def main():
    check_correct()
    benchmark_kernel()


if __name__ == "__main__":
    main()
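The blocking scheme itself is hardware-agnostic. As a reference, here is a minimal pure-NumPy sketch of the same decomposition (host-side illustration only, not part of the kernel file; the tile sizes are assumed to match the nl.tile_size limits used above, and the shapes are assumed to be exact multiples of the tiles):

import numpy as np

def blocked_matmul_reference(A, B, TILE_K=128, TILE_M=128, TILE_N=512):
    """Computes Z = A.T @ B one (TILE_M x TILE_N) output tile at a time."""
    K, M = A.shape
    _, N = B.shape
    Z = np.zeros((M, N), dtype=np.float32)
    for m in range(0, M, TILE_M):
        for n in range(0, N, TILE_N):
            # This accumulator plays the role of the PSUM tile in the kernel above.
            acc = np.zeros((TILE_M, TILE_N), dtype=np.float32)
            for k in range(0, K, TILE_K):
                # Each step mirrors one ni.nc_matmul call: contract one K tile,
                # with the A tile supplied K-major (hence the transpose here).
                a_tile = A[k:k + TILE_K, m:m + TILE_M].astype(np.float32)
                b_tile = B[k:k + TILE_K, n:n + TILE_N].astype(np.float32)
                acc += a_tile.T @ b_tile
            Z[m:m + TILE_M, n:n + TILE_N] = acc
    return Z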