diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py
index 2ace29240b9b..09f23c92b4b2 100644
--- a/python/tutorials/03-matrix-multiplication-cpu.py
+++ b/python/tutorials/03-matrix-multiplication-cpu.py
@@ -159,6 +159,7 @@
 BLOCK_SIZE_K = 32
 GROUP_SIZE_M = 8
 USE_GPU = False
+USE_BLOCK_POINTERS = False


 @triton.jit
@@ -176,6 +177,7 @@ def matmul_kernel(
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
         GROUP_SIZE_M: tl.constexpr,  #
+        USE_BLOCK_POINTERS: tl.constexpr,  #
 ):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -193,6 +195,9 @@ def matmul_kernel(
     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
     pid_m = first_pid_m + (pid % group_size_m)
     pid_n = (pid % num_pid_in_group) // group_size_m
+    if USE_BLOCK_POINTERS:
+        block_offset_m = pid_m * BLOCK_SIZE_M
+        block_offset_n = pid_n * BLOCK_SIZE_N

     # ----------------------------------------------------------
     # Create pointers for the first blocks of A and B.
@@ -201,11 +206,29 @@
     # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
     # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
     # See above `Pointer Arithmetic` section for details
-    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    if USE_BLOCK_POINTERS:
+        a_tile_ptr = tl.make_block_ptr(
+            base=a_ptr,
+            shape=(M, K),
+            strides=(stride_am, stride_ak),
+            offsets=(block_offset_m, 0),
+            block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
+            order=(1, 0)
+        )
+        b_tile_ptr = tl.make_block_ptr(
+            base=b_ptr,
+            shape=(K, N),
+            strides=(stride_bk, stride_bn),
+            offsets=(0, block_offset_n),
+            block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
+            order=(1, 0)
+        )
+    else:
+        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        offs_k = tl.arange(0, BLOCK_SIZE_K)
+        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -217,37 +240,56 @@
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
-        # TODO: Currently masked load is not supported yet.
-        # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-        # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-        a = tl.load(a_ptrs)
-        b = tl.load(b_ptrs)
+        if USE_BLOCK_POINTERS:
+            # TODO: Currently masked load is not supported yet.
+            a = tl.load(a_tile_ptr, boundary_check=(0, 1))
+            b = tl.load(b_tile_ptr, boundary_check=(0, 1))
+        else:
+            # TODO: Currently masked load is not supported yet.
+            # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+            # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+            a = tl.load(a_ptrs)
+            b = tl.load(b_ptrs)

         # We accumulate along the K dimension.
         accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32)
         # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
+        if USE_BLOCK_POINTERS:
+            a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K])
+            b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0])
+        else:
+            a_ptrs += BLOCK_SIZE_K * stride_ak
+            b_ptrs += BLOCK_SIZE_K * stride_bk

     # Convert the accumulator to the output matrix C's type if needed.
     c = accumulator

-    # -----------------------------------------------------------
-    # Write back the block of the output matrix C with masks.
-    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    if USE_BLOCK_POINTERS:
+        # TODO: masking
+        c_block_ptr = tl.make_block_ptr(
+            base=c_ptr,
+            shape=(M, N),
+            strides=(stride_cm, stride_cn),
+            offsets=(block_offset_m, block_offset_n),
+            block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
+            order=(1, 0)
+        )
+        tl.store(c_block_ptr, c, boundary_check=(0, 1))
+    else:
+        # -----------------------------------------------------------
+        # Write back the block of the output matrix C with masks.
+        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]

-    # TODO: Currently masked load is not supported yet.
-    # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
-    # tl.store(c_ptrs, c, mask=c_mask)
-    tl.store(c_ptrs, c)
+        # TODO: Currently masked load is not supported yet.
+        # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+        # tl.store(c_ptrs, c, mask=c_mask)
+        tl.store(c_ptrs, c)


 # %%
 # We can now create a convenience wrapper function that only takes two input tensors,
 # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
-
-
 def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
@@ -272,6 +314,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
         c.stride(0), c.stride(1),  #
         BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,  #
         GROUP_SIZE_M=GROUP_SIZE_M,  #
+        USE_BLOCK_POINTERS=USE_BLOCK_POINTERS,  #
     )
     return c
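
Reference note (not part of the patch): the hunks above gate the tutorial between two addressing styles behind USE_BLOCK_POINTERS. Below is a minimal, self-contained sketch of the block-pointer API the new path relies on (tl.make_block_ptr, tl.advance, and boundary-checked tl.load/tl.store), written as a simple 2-D tile copy. The kernel and wrapper names (copy_kernel, copy) and the block sizes are illustrative assumptions rather than code from the tutorial, and a working Triton build (e.g. triton-cpu for CPU tensors) is assumed.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, M, N, stride_m, stride_n,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Illustrative kernel: copies one row-band of a (M, N) tensor, tile by tile.
    pid_m = tl.program_id(axis=0)
    # A block pointer records the whole tensor (shape/strides) plus the tile this
    # program instance works on (offsets/block_shape); order=(1, 0) marks the last
    # dimension as contiguous, matching the matmul kernel in the patch above.
    src = tl.make_block_ptr(base=src_ptr, shape=(M, N), strides=(stride_m, stride_n),
                            offsets=(pid_m * BLOCK_M, 0),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    dst = tl.make_block_ptr(base=dst_ptr, shape=(M, N), strides=(stride_m, stride_n),
                            offsets=(pid_m * BLOCK_M, 0),
                            block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    for _ in range(0, tl.cdiv(N, BLOCK_N)):
        # boundary_check masks the out-of-bounds part of a tile, replacing the
        # manual `mask=` arithmetic used with plain pointer blocks.
        tile = tl.load(src, boundary_check=(0, 1))
        tl.store(dst, tile, boundary_check=(0, 1))
        # tl.advance shifts the tile by element offsets, the block-pointer
        # analogue of `ptrs += BLOCK_SIZE_K * stride_k` in the patch.
        src = tl.advance(src, (0, BLOCK_N))
        dst = tl.advance(dst, (0, BLOCK_N))


def copy(x: torch.Tensor) -> torch.Tensor:
    # Host-side wrapper: one program per row-band of BLOCK_M rows.
    out = torch.empty_like(x)
    BLOCK_M, BLOCK_N = 32, 32
    grid = (triton.cdiv(x.shape[0], BLOCK_M), )
    copy_kernel[grid](x, out, x.shape[0], x.shape[1], x.stride(0), x.stride(1),
                      BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return out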