Matrix multiplication tutorial block pointer variant (triton-lang#1)
Adds a `USE_BLOCK_POINTERS` flag to `matmul_kernel` so we can get IR for pointers-to-tensors instead of tensors-of-pointers.
rolfmorel authored and Devjiu committed Nov 13, 2024
1 parent 7d16b21 commit eb8fbdf
Showing 1 changed file with 66 additions and 21 deletions.
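For readers skimming the diff below: with tensors-of-pointers the kernel builds a [BLOCK_M, BLOCK_N] tensor of raw addresses via broadcasted arithmetic and loads through it elementwise, whereas `tl.make_block_ptr` hands the compiler one structured handle (base, shape, strides, offsets, block shape, order) that it can advance and bounds-check as a unit; that handle is what surfaces as a pointer-to-tensor in the IR. Here is a minimal, self-contained sketch of the two styles, separate from the tutorial file. The copy kernel, its names, and the sizes are illustrative assumptions; only `tl.make_block_ptr`, `tl.load`/`tl.store` with `boundary_check`, and mask-based loads are standard Triton API. The masked path mirrors the mainline GPU tutorial even though the TODOs in the diff note this CPU backend does not support masked accesses yet.

import torch

import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, M, N, stride_m, stride_n,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                USE_BLOCK_POINTERS: tl.constexpr):
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    if USE_BLOCK_POINTERS:
        # Pointer-to-tensor: one handle carrying shape/stride/offset metadata;
        # out-of-bounds handling is requested declaratively via boundary_check.
        src = tl.make_block_ptr(base=src_ptr, shape=(M, N),
                                strides=(stride_m, stride_n),
                                offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
        dst = tl.make_block_ptr(base=dst_ptr, shape=(M, N),
                                strides=(stride_m, stride_n),
                                offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
        tl.store(dst, tl.load(src, boundary_check=(0, 1)), boundary_check=(0, 1))
    else:
        # Tensor-of-pointers: a [BLOCK_M, BLOCK_N] tensor of addresses built by
        # broadcasting; out-of-bounds handling needs an explicit elementwise mask.
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs = offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask, other=0.0), mask=mask)


x = torch.randn(128, 96)
y = torch.empty_like(x)
grid = (triton.cdiv(128, 32), triton.cdiv(96, 32))
copy_kernel[grid](x, y, 128, 96, x.stride(0), x.stride(1),
                  BLOCK_M=32, BLOCK_N=32, USE_BLOCK_POINTERS=True)

With `USE_BLOCK_POINTERS=True` the generated TTIR should contain `tt.make_tensor_ptr`/`tt.advance` ops over a pointer-to-tensor type rather than a tensor of pointers, which is the IR shape this commit is after.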
python/tutorials/03-matrix-multiplication-cpu.py
@@ -159,6 +159,7 @@
 BLOCK_SIZE_K = 32
 GROUP_SIZE_M = 8
 USE_GPU = False
+USE_BLOCK_POINTERS = False
 
 
 @triton.jit
@@ -176,6 +177,7 @@ def matmul_kernel(
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
         GROUP_SIZE_M: tl.constexpr, #
+        USE_BLOCK_POINTERS: tl.constexpr, #
 ):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -193,6 +195,9 @@
     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
     pid_m = first_pid_m + (pid % group_size_m)
     pid_n = (pid % num_pid_in_group) // group_size_m
+    if USE_BLOCK_POINTERS:
+        block_offset_m = pid_m * BLOCK_SIZE_M
+        block_offset_n = pid_n * BLOCK_SIZE_N
 
     # ----------------------------------------------------------
     # Create pointers for the first blocks of A and B.
@@ -201,11 +206,29 @@
     # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
     # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
     # See above `Pointer Arithmetic` section for details
-    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    if USE_BLOCK_POINTERS:
+        a_tile_ptr = tl.make_block_ptr(
+            base=a_ptr,
+            shape=(M, K),
+            strides=(stride_am, stride_ak),
+            offsets=(block_offset_m, 0),
+            block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
+            order=(1, 0)
+        )
+        b_tile_ptr = tl.make_block_ptr(
+            base=b_ptr,
+            shape=(K, N),
+            strides=(stride_bk, stride_bn),
+            offsets=(0, block_offset_n),
+            block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
+            order=(1, 0)
+        )
+    else:
+        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -217,30 +240,51 @@
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
 
-        # TODO: Currently masked load is not supported yet.
-        # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-        # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-        a = tl.load(a_ptrs)
-        b = tl.load(b_ptrs)
+        if USE_BLOCK_POINTERS:
+            # TODO: Currently masked load is not supported yet.
+            a = tl.load(a_tile_ptr, boundary_check=(0, 1))
+            b = tl.load(b_tile_ptr, boundary_check=(0, 1))
+        else:
+            # TODO: Currently masked load is not supported yet.
+            # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+            # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+            a = tl.load(a_ptrs)
+            b = tl.load(b_ptrs)
         # We accumulate along the K dimension.
         accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32)
         # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
+        if USE_BLOCK_POINTERS:
+            a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
+            b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
+        else:
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+            b_ptrs += BLOCK_SIZE_K * stride_bk
 
     # Convert the accumulator to the output matrix C's type if needed.
     c = accumulator
 
-    # -----------------------------------------------------------
-    # Write back the block of the output matrix C with masks.
-    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    if USE_BLOCK_POINTERS:
+        # TODO: masking
+        c_block_ptr = tl.make_block_ptr(
+            base=c_ptr,
+            shape=(M, N),
+            strides=(stride_cm, stride_cn),
+            offsets=(block_offset_m, block_offset_n),
+            block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
+            order=(1, 0)
+        )
+        tl.store(c_block_ptr, c, boundary_check=(0, 1))
+    else:
+        # -----------------------------------------------------------
+        # Write back the block of the output matrix C with masks.
+        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
 
-    # TODO: Currently masked load is not supported yet.
-    # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
-    # tl.store(c_ptrs, c, mask=c_mask)
-    tl.store(c_ptrs, c)
+        # TODO: Currently masked load is not supported yet.
+        # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+        # tl.store(c_ptrs, c, mask=c_mask)
+        tl.store(c_ptrs, c)
 
 
 # %%
@@ -273,6 +317,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0):
         BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, #
         GROUP_SIZE_M=GROUP_SIZE_M, #
         num_threads=num_threads, #
+        USE_BLOCK_POINTERS=USE_BLOCK_POINTERS, #
     )
     return c
 

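Since `USE_BLOCK_POINTERS` is a module-level constant threaded into the kernel as a `tl.constexpr`, flipping it at the top of the tutorial specializes `matmul_kernel` for the other addressing scheme at compile time; no call sites change. A hypothetical driver, with the shapes and the correctness check as illustrative assumptions rather than part of the commit:

import torch

# Assumes the tutorial module is loaded with USE_BLOCK_POINTERS = True at the top.
a = torch.randn(512, 512, dtype=torch.float32)
b = torch.randn(512, 512, dtype=torch.float32)
c = torch.empty(512, 512, dtype=torch.float32)

matmul(a, b, c)  # same entry point; the flag is baked into the compiled kernel
torch.testing.assert_close(c, torch.matmul(a, b), rtol=1e-4, atol=1e-4)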