diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index f33432645cc3..5fc42392c337 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -651,6 +651,33 @@ TVM_DLL const Op& ptx_cp_async();
 TVM_DLL const Op& ptx_commit_group();
 TVM_DLL const Op& ptx_wait_group();
+/*!
+ * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
+ * For example, if each thread in a warp of size 32 has 4 elements from the result of
+ * m16xn8xk16 MMA in its registers, this intrinsic can be used to store the result in a
+ * 16x8 region in shared or global memory.
+ *
+ * There is no real PTX instruction that does that, but we want to hide details of
+ * complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g.
+ * LowerWarpMemory).
+ *
+ * void mma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var dst_stride);
+ */
+TVM_DLL const Op& mma_store();
+
+/*!
+ * \brief tvm intrinsic for zero-initializing an MMA accumulation register.
+ * For example, if each thread in a warp of size 32 holds 4 elements of the m16xn8xk16 MMA
+ * accumulator in its registers, this intrinsic can be used to zero-initialize those
+ * 4 accumulation registers.
+ *
+ * There is no real PTX instruction that does that, but we introduce this intrinsic for the
+ * same reason as mma_store above.
+ *
+ * void mma_fill(IntImm local_size, Var local_ptr, Expr offset);
+ */
+TVM_DLL const Op& mma_fill();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py
index 4115c3b90070..a3b47ff6d5d7 100644
--- a/python/tvm/tir/tensor_intrin/__init__.py
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -20,3 +20,4 @@
 from .arm_cpu import *
 from .dot_product_common import *
 from .rocm import *
+from .cuda import *
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
new file mode 100644
index 000000000000..853a37735486
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -0,0 +1,469 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring
+"""Intrinsics for tensorization on NVIDIA GPU."""
+from tvm.script import tir as T
+from .. import IntImm, Cast
+from ..._ffi import register_func
+from ...runtime import convert
+from .. 
import TensorIntrin + + +def shared_16x16_to_ldmatrix_32x8_layout(i, j): + thread_id = 4 * (i % 8) + (j % 8) // 2 + return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2) + + +def shared_16x32_to_ldmatrix_32x16_layout(i, j): + thread_id = 4 * (i % 8) + (j % 16) // 4 + return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4 + + +def shared_32x16_to_ldmatrix_32x16_layout(i, j): + thread_id = (i % 4) + 4 * (j % 8) + return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4 + + +@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout") +def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind): + i, j = ind[0], ind[1] + thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j) + return convert([thread_id, local_id]) + + +lift = convert + +M_DIM = 16 +N_DIM = 16 +WARP_SIZE = 32 +HALF_WARP = WARP_SIZE // 2 +HALF_WARP_expr = lift(HALF_WARP) + + +def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed): + local_size = (M_DIM * k_dim) // WARP_SIZE + shared_offset = None + index_map = None + + if transposed: + assert is_b, "Transposed A matrix not supported" + + ldmatrix_col_major = is_b and not transposed + + if k_dim == 16: + assert dtype == "float16" + + index_map = shared_16x16_to_ldmatrix_32x8_layout + + if transposed: + shared_offset = ( + lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr) + + stride * (tx % 8) + + 8 * ((tx % HALF_WARP_expr) // 8) + ) + else: + shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * ( + tx // HALF_WARP_expr + ) + else: + assert ( + k_dim == 32 and dtype == "int8" + ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now" + + if ldmatrix_col_major: + index_map = shared_32x16_to_ldmatrix_32x16_layout + # A dummy offset, ldmatrix cannot be used for int8 + trans case. + # We still use the ldmatrix intrinsic, but lower it to a manual loop in the codegen. + # Only the stride information is required. 
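+            # (The manual lowering for this case lives in the trans && dtype.bits() == 8
+            # branch added to CodeGenCUDA::VisitExpr_ for ptx_ldmatrix in codegen_cuda.cc
+            # below, which emits a plain element-by-element copy loop.)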
+ shared_offset = lambda _, stride: stride + elif is_b and transposed: + index_map = shared_16x32_to_ldmatrix_32x16_layout + shared_offset = ( + lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr) + + (tx % 8) * stride + + 16 * ((tx % HALF_WARP_expr) // 8) + ) + else: + index_map = shared_16x32_to_ldmatrix_32x16_layout + shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16) + + assert index_map and shared_offset + + if is_b and not transposed: + row_dim = k_dim + col_dim = M_DIM + else: + row_dim = M_DIM + col_dim = k_dim + + shmem_shape = (row_dim, col_dim) + + @T.prim_func + def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None: + shared = T.match_buffer( + shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope="shared" + ) + warp = T.match_buffer( + warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp" + ) + + with T.block("root"): + T.reads(shared[0:row_dim, 0:col_dim]) + T.writes(warp[0:WARP_SIZE, 0:local_size]) + + for ax0, ax1 in T.grid(row_dim, col_dim): + with T.block("shared_warp"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(shared[v0, v1]) + + thread_id, local_id = index_map(v0, v1) + T.writes(warp[thread_id, local_id]) + warp[thread_id, local_id] = shared[v0, v1] + + @T.prim_func + def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None: + s0 = T.var("int32") + s1 = T.var("int32") + shared = T.match_buffer( + shared_handle, + shmem_shape, + dtype, + align=128, + offset_factor=16, + scope="shared", + strides=[s0, s1], + ) + warp = T.match_buffer( + warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp" + ) + + with T.block("root"): + T.reads(shared[0:row_dim, 0:col_dim]) + T.writes(warp[0:WARP_SIZE, 0:local_size]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate( + T.ptx_ldmatrix( + ldmatrix_col_major, + 4, # Always load 4 matrices + ".b16", + warp.data, + warp.elem_offset + lift(local_size) * tx, + shared.access_ptr("r"), + shared_offset(tx, s0), + dtype=dtype, + ) + ) + + return ldmatrix_desc, ldmatrix_impl + + +def get_mma_intrin(k_dim, out_dtype, b_transposed): + local_size = (M_DIM * k_dim) // WARP_SIZE + local_size_out = (M_DIM * N_DIM) // 32 + + index_map_C = shared_16x16_to_ldmatrix_32x8_layout + + if k_dim == 16: + index_map_A = shared_16x16_to_ldmatrix_32x8_layout + index_map_B = shared_16x16_to_ldmatrix_32x8_layout + mma_prefix = "m16n8k16" + elif k_dim == 32 and b_transposed: + index_map_A = index_map_B = shared_16x32_to_ldmatrix_32x16_layout + mma_prefix = "m16n8k32" + elif k_dim == 32 and not b_transposed: + index_map_A = shared_16x32_to_ldmatrix_32x16_layout + index_map_B = shared_32x16_to_ldmatrix_32x16_layout + mma_prefix = "m16n8k32" + else: + assert False + + out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": "int32"}[out_dtype] + + if out_dtype in ["float16", "float32"]: + in_dtype = "float16" + in_dtype_abbrv = "fp16" + else: + in_dtype = "int8" + in_dtype_abbrv = "int8" + + def maybe_cast(v): + if out_dtype in ["float32", "int32"]: + return Cast(out_dtype, v) + return v + + def maybe_swap(i, j): + if b_transposed: + return j, i + return i, j + + @T.prim_func + def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + B = T.match_buffer( + b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + C = T.match_buffer( + c, 
(WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp" + ) + + with T.block("root"): + T.reads( + C[0:WARP_SIZE, 0:local_size_out], + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + + for i, j, k in T.grid(M_DIM, N_DIM, k_dim): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i, j, k]) + b_row_ind, b_col_ind = maybe_swap(k, j) + + thread_id_C, local_id_C = index_map_C(i, j) + thread_id_A, local_id_A = index_map_A(i, k) + thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind) + + T.reads( + C[thread_id_C, local_id_C], + A[thread_id_A, local_id_A], + B[thread_id_B, local_id_B], + ) + T.writes(C[thread_id_C, local_id_C]) + + C[thread_id_C, local_id_C] += maybe_cast( + A[thread_id_A, local_id_A] + ) * maybe_cast(B[thread_id_B, local_id_B]) + + @T.prim_func + def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer( + a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + B = T.match_buffer( + b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp" + ) + C = T.match_buffer( + c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp" + ) + + with T.block("root"): + T.reads( + C[0:WARP_SIZE, 0:local_size_out], + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate( + T.ptx_mma( + mma_prefix, + "row", + "col", + in_dtype_abbrv, + in_dtype_abbrv, + out_dtype_abbrv, + A.data, + A.elem_offset + tx * lift(local_size), + B.data, + B.elem_offset + tx * lift(local_size), + C.data, + C.elem_offset + tx * lift(local_size_out), + False, + dtype=out_dtype, + ) + ) + + T.evaluate( + T.ptx_mma( + mma_prefix, + "row", + "col", + in_dtype_abbrv, + in_dtype_abbrv, + out_dtype_abbrv, + A.data, + A.elem_offset + tx * lift(local_size), + B.data, + B.elem_offset + tx * lift(local_size) + lift(local_size) // 2, + C.data, + C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2, + False, + dtype=out_dtype, + ) + ) + + return mma_sync_desc, mma_sync_impl + + +def get_mma_fill_intrin(dtype, local_size): + zero = IntImm("int32", 0).astype(dtype) + + # Assume M = N = 16 + index_map = shared_16x16_to_ldmatrix_32x8_layout + + @T.prim_func + def mma_fill_desc(a: T.handle) -> None: + C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp") + + with T.block("root"): + T.reads() + T.writes(C_warp[0:WARP_SIZE, 0:local_size]) + for i0, i1 in T.grid(M_DIM, N_DIM): + with T.block("C_warp"): + i, j = T.axis.remap("SS", [i0, i1]) + thread_id, local_id = index_map(i, j) + T.reads() + T.writes(C_warp[thread_id, local_id]) + C_warp[thread_id, local_id] = zero + + @T.prim_func + def mma_fill_impl(a: T.handle) -> None: + C_warp = T.match_buffer( + a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 + ) + + with T.block("root"): + T.reads() + T.writes(C_warp[0:WARP_SIZE, 0:local_size]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype)) + + return mma_fill_desc, mma_fill_impl + + +def get_mma_store_intrin(dtype, local_size, scope="global"): + # Assume M = N = 16 + index_map = shared_16x16_to_ldmatrix_32x8_layout + + @T.prim_func + def mma_store_desc(a: T.handle, c: T.handle) -> None: + C_warp = T.match_buffer(a, [WARP_SIZE, local_size], 
dtype=dtype, scope="warp") + C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope) + + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + for i0, i1 in T.grid(M_DIM, N_DIM): + with T.block("C_warp"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + thread_id, local_id = index_map(v0, v1) + T.reads(C_warp[thread_id, local_id]) + T.writes(C[v0, v1]) + C[v0, v1] = C_warp[thread_id, local_id] + + @T.prim_func + def mma_store_impl(a: T.handle, c: T.handle) -> None: + s0 = T.var("int32") + s1 = T.var("int32") + + C_warp = T.match_buffer( + a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 + ) + C = T.match_buffer( + c, [M_DIM, N_DIM], dtype=dtype, scope="global", offset_factor=1, strides=[s0, s1] + ) + + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate( + T.mma_store( + M_DIM, + N_DIM, + C.access_ptr("w"), + C_warp.data, + C_warp.elem_offset, + s0, + dtype=dtype, + ) + ) + + return mma_store_desc, mma_store_impl + + +LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a" +TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False)) + +LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b" +TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False)) + +LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans" +TensorIntrin.register( + LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True) +) + +LDMATRIX_16x32_A_INTRIN = "mma.ldmatrix_16x32_a" +TensorIntrin.register(LDMATRIX_16x32_A_INTRIN, *get_ldmatrix_intrin(32, "int8", False, False)) + +LDMATRIX_32x16_B_INTRIN = "mma.ldmatrix_32x16_b" +TensorIntrin.register(LDMATRIX_32x16_B_INTRIN, *get_ldmatrix_intrin(32, "int8", True, False)) + +LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans" +TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True)) + +MMA_f16f16f32_INTRIN = "mma_f16f16f32" +TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False)) + +MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans" +TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, "float32", True)) + +MMA_f16f16f16_INTRIN = "mma_f16f16f16" +TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False)) + +MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans" +TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, "float16", True)) + +MMA_i8i8i32_INTRIN = "mma_i8i8i32" +TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False)) + +MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans" +TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True)) + +MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32" +TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 8)) + +MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16" +TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, *get_mma_fill_intrin("float16", 8)) + +MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32" +TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8)) + +MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_" +TensorIntrin.register( + MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, "global") +) + +MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_" +TensorIntrin.register( + 
MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, "global") +) + +MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_" +TensorIntrin.register( + MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global") +) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 7459d4c250ba..616e75f2e776 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include @@ -818,9 +819,78 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string local_ptr = this->PrintExpr(op->args[3]); std::string local_elem_offset = this->PrintExpr(op->args[4]); std::string smem_ptr = this->PrintExpr(op->args[5]); - std::string smem_elem_offset = this->PrintExpr(op->args[6]); - this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset, - smem_ptr, smem_elem_offset); + if (trans && op->dtype.bits() == 8) { + // Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an + // int8 matrix. + std::string smem_stride = this->PrintExpr(op->args[6]); + ICHECK(num == 4); + os << "for (int i = 0; i < 16; ++i) {\n"; + os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr + << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride + + "+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n"; + os << "}\n"; + } else { + std::string smem_elem_offset = this->PrintExpr(op->args[6]); + this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset, + smem_ptr, smem_elem_offset); + } + } else if (op->op.same_as(builtin::mma_store())) { + int m = Downcast(op->args[0])->value; + int n = Downcast(op->args[1])->value; + std::string dst = this->PrintExpr(op->args[2]); + std::string src = this->PrintExpr(op->args[3]); + std::string src_offset = this->PrintExpr(op->args[4]); + PrimExpr stride = op->args[5]; + + ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now"; + + // Each thread in a warp holds a certain number of elements of an MMA output. + // For example, if we compute a 16x16 tile using MMA, each thread holds 8 elements + // in its registers. So conceptually, a warp memory is organized as a 32x8 block. + // A map from a 16x16 tile to a 32x8 block of memory is specified by the index map below. + + // To store the 32x8 output back to a 16x16 tile in shared or global memory, we invert this map + // to determine the output location for each 8 element. + + const auto* index_map_func = + runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout"); + ICHECK(index_map_func); + + auto inverse_index_map = + IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}); + auto indices_16x16 = inverse_index_map->final_indices; + + // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are fine. + // FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually replace them + // to the plain ones here. 
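+    // For reference, inverting shared_16x16_to_ldmatrix_32x8_layout gives
+    //   i = (local_id % 4) / 2 * 8 + threadIdx.x / 4
+    //   j = local_id / 4 * 8 + threadIdx.x % 4 * 2 + local_id % 2
+    // so each of the 8 elements a thread holds is stored to dst[i * stride + j].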
+ class LowerFloorDivMod : public ExprMutator { + public: + PrimExpr VisitExpr_(const FloorDivNode* op) { + return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b)); + } + PrimExpr VisitExpr_(const FloorModNode* op) { + return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b)); + } + }; + + auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]); + + var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x"; + var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id"; + + os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; + os << dst << "[" + this->PrintExpr(dst_ind) + "]" + << " = " << src << "[" << src_offset << " + local_id];\n"; + os << "}\n"; + + } else if (op->op.same_as(builtin::mma_fill())) { + std::string num_elem = this->PrintExpr(op->args[0]); + std::string dst = this->PrintExpr(op->args[1]); + std::string dst_offset = this->PrintExpr(op->args[2]); + + os << "for (int i = 0; i < " << num_elem << "; ++i) {\n"; + os << dst << "[" << dst_offset << " + i] = 0.0;"; + os << "}\n"; } else if (op->op.same_as(builtin::ptx_cp_async())) { std::string dst = this->PrintExpr(op->args[0]); std::string dst_offset = this->PrintExpr(op->args[1]); diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 0415d1bbec9e..1871a3d7bf70 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -256,6 +256,12 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group) TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(mma_store).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(mma_fill).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(vectorhigh) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 40971114d416..d8250cd09888 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -101,7 +101,7 @@ namespace tir { // Visitor to find m in pattern // store warp_mem[m * warp_index + (width * m) * y + x] -class WarpStoreCoeffFinder : private StmtVisitor { +class WarpStoreCoeffFinder : private StmtExprVisitor { public: WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer) : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {} @@ -113,6 +113,18 @@ class WarpStoreCoeffFinder : private StmtVisitor { private: /// Visitor implementation + void VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as() == buffer_) { + UpdatePattern(op->args[4]); + } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as() == buffer_) { + auto* local_size = op->args[0].as(); + ICHECK(local_size) << "Integer expected for the first argument of mma_fill"; + warp_coeff_ = local_size->value; + } + + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const StoreNode* op) final { LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; } @@ -245,6 +257,37 @@ class WarpAccessRewriter : protected StmtExprMutator { } protected: + PrimExpr RewriteIndicesAt(const CallNode* op, const std::vector& indices) { + Array new_args = op->args; + for (int i : indices) { + if (op->args[i].get() == buffer_) { + PrimExpr local_index = SplitIndexByGroup(op->args[i + 1]).first; + new_args.Set(i + 1, local_index); + } + } + return Call(op->dtype, op->op, new_args); + } + + PrimExpr VisitExpr_(const CallNode* op) override { + if (op->op.same_as(builtin::ptx_mma())) { + return RewriteIndicesAt(op, {6, 8, 10}); + } + + if (op->op.same_as(builtin::ptx_ldmatrix())) { + return RewriteIndicesAt(op, {3}); + } + + if (op->op.same_as(builtin::mma_store())) { + return RewriteIndicesAt(op, {3}); + } + + if (op->op.same_as(builtin::mma_fill())) { + return RewriteIndicesAt(op, {1}); + } + + return StmtExprMutator::VisitExpr_(op); + } + PrimExpr VisitExpr_(const VarNode* op) override { ICHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py new file mode 100644 index 000000000000..67e8ae0ad836 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py @@ -0,0 +1,422 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +import tvm +from tvm import te +from tvm.tir.tensor_intrin.cuda import ( + LDMATRIX_16x16_A_INTRIN, + LDMATRIX_16x16_B_INTRIN, + LDMATRIX_16x16_B_TRANS_INTRIN, + LDMATRIX_16x32_A_INTRIN, + LDMATRIX_32x16_B_INTRIN, + LDMATRIX_16x32_B_TRANS_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_f16f16f32_TRANS_INTRIN, + MMA_f16f16f16_INTRIN, + MMA_f16f16f16_TRANS_INTRIN, + MMA_i8i8i32_INTRIN, + MMA_i8i8i32_TRANS_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_fill_16x16_f16_INTRIN, + MMA_fill_16x16_i32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + MMA_store_16x16_f16_global_INTRIN, + MMA_store_16x16_i32_global_INTRIN, + shared_16x16_to_ldmatrix_32x8_layout, + shared_32x16_to_ldmatrix_32x16_layout, + shared_16x32_to_ldmatrix_32x16_layout, +) +import tvm.testing +import numpy as np + + +M = 4096 +N = 4096 +K = 4096 +measure_perf = False +gflops = (N * M * K) * 2 / 1e9 + + +def matmul(m, n, k, in_dtype, out_dtype, b_transposed): + b_shape = (n, k) if b_transposed else (k, n) + a = te.placeholder((m, k), name="A", dtype=in_dtype) + b = te.placeholder(b_shape, name="B", dtype=in_dtype) + k = te.reduce_axis((0, k), name="k") + + def maybe_cast(v): + if in_dtype != out_dtype: + return tvm.tir.Cast(out_dtype, v) + return v + + def maybe_swap(i, j): + if b_transposed: + return j, i + return i, j + + c = te.compute( + (m, n), + lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, j)]), axis=[k]), + name="C", + ) + return (a, b, c) + + +def is_ampere_or_newer(): + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + return major >= 8 + + +def run_test( + k_inner, + in_dtype, + out_dtype, + b_transposed, + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ldmatrix_a_intrin, + ldmatrix_b_intrin, + mma_intrin, + mma_fill_intrin, + mma_store_intrin, +): + workload = te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)) + ir_module = tvm.IRModule({"main": workload}) + sch = tvm.tir.Schedule(ir_module) + + block = sch.get_block("C") + i, j, k = sch.get_loops(block) + i, i_tc = sch.split(i, factors=[None, 16]) + j, j_tc = sch.split(j, factors=[None, 16]) + k, k_tc = sch.split(k, factors=[None, k_inner]) + + sch.reorder(i, j, k, i_tc, j_tc, k_tc) + + block_inner = sch.blockize(i_tc) + block_outer, block_inner = block_inner, block + + num_ty = i_factors[2] * j_factors[2] + + i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors) + j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors) + k0, k1, k2 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim): + block_read = sch.cache_read(block, idx, "shared") + sch.compute_at(block_read, k0) + vector_size = 16 if in_dtype == "int8" else 8 + warp_size = 32 + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + offset = 8 if in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset) + + return block_read + + fetch_to_shared(block_outer, 0, 2) + fetch_to_shared(block_outer, 1, 2) + + A_warp = sch.cache_read(block_outer, 0, "warp") + B_warp = 
sch.cache_read(block_outer, 1, "warp") + + sch.compute_at(A_warp, k1) + sch.compute_at(B_warp, k1) + + C_warp = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(C_warp, thread_idy) + + ii, jj = sch.get_loops(C_warp)[-2:] + io, ii = sch.split(ii, factors=[None, 16]) + jo, ji = sch.split(jj, factors=[None, 16]) + sch.reorder(io, jo, ii, ji) + + sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3]) + block_init_c = sch.get_block("C_init") + + def tile_wmma_fragment(block_read, height, width): + i, j = sch.get_loops(block_read)[-2:] + i0, i1 = sch.split(i, factors=[None, height]) + j0, j1 = sch.split(j, factors=[None, width]) + sch.reorder(i0, j0, i1, j1) + return i1 + + loop_a = tile_wmma_fragment(A_warp, 16, k_inner) + + if b_transposed: + loop_b = tile_wmma_fragment(B_warp, 16, k_inner) + else: + loop_b = tile_wmma_fragment(B_warp, k_inner, 16) + + sch.transform_layout(A_warp, 0, "write", index_map_A) + sch.transform_layout(B_warp, 0, "write", index_map_B) + sch.transform_layout(C_warp, 0, "read", index_map_C) + + sch.tensorize(loop_a, ldmatrix_a_intrin) + sch.tensorize(loop_b, ldmatrix_b_intrin) + sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin) + sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin) + sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin) + + if not is_ampere_or_newer(): + return None + + f = tvm.build(sch.mod["main"], target="cuda", name="dense") + + dev = tvm.device("cuda", 0) + + if in_dtype == "float16": + a_np = np.random.uniform(size=(M, K)).astype("float16") + + if b_transposed: + b_np = np.random.uniform(size=(N, K)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype( + out_dtype + ) + else: + b_np = np.random.uniform(size=(K, N)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype) + else: + a_np = np.random.randint(-128, 128, (M, K)).astype("int8") + + if b_transposed: + b_np = np.random.randint(-128, 128, (N, K)).astype("int8") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype( + "int32" + ) + else: + b_np = np.random.randint(-128, 128, (K, N)).astype("int8") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32") + + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev) + + f(a, b, c) + + if out_dtype != "float16": + # The numpy reference is computed with fp32 precision (otherwise too slow). + # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation. 
+ tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) + + return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c) + + +@tvm.testing.requires_cuda +def test_f16f16f32_m16n16k16(): + def index_map(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 16 + in_dtype = "float16" + out_dtype = "float32" + i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_16x16_A_INTRIN, + LDMATRIX_16x16_B_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + ) + + if measure_perf and timer: + print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_16x16_A_INTRIN, + LDMATRIX_16x16_B_TRANS_INTRIN, + MMA_f16f16f32_TRANS_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + ) + + if measure_perf and timer: + print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) + + +@tvm.testing.requires_cuda +def test_f16f16f16_m16n16k16(): + def index_map(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 16 + in_dtype = "float16" + out_dtype = "float16" + i_factors, j_factors, k_factors = [16, 2, 1, 4, 2], [16, 2, 2, 1, 4], [128, 2, 1] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_16x16_A_INTRIN, + LDMATRIX_16x16_B_INTRIN, + MMA_f16f16f16_INTRIN, + MMA_fill_16x16_f16_INTRIN, + MMA_store_16x16_f16_global_INTRIN, + ) + + if measure_perf and timer: + print("f16f16f16_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_16x16_A_INTRIN, + LDMATRIX_16x16_B_TRANS_INTRIN, + MMA_f16f16f16_TRANS_INTRIN, + MMA_fill_16x16_f16_INTRIN, + MMA_store_16x16_f16_global_INTRIN, + ) + + if measure_perf and timer: + print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) + + +@tvm.testing.requires_cuda +def test_i8i8i32_m16n16k32(): + def index_map_A(i, j): + return ( + i // 16, + j // 32, + *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32), + ) + + def index_map_B(i, j): + return ( + i // 32, + j // 16, + *shared_32x16_to_ldmatrix_32x16_layout(i % 32, j % 16), + ) + + def index_map_C(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 32 + in_dtype = "int8" + out_dtype = "int32" + i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + LDMATRIX_16x32_A_INTRIN, + LDMATRIX_32x16_B_INTRIN, + MMA_i8i8i32_INTRIN, + MMA_fill_16x16_i32_INTRIN, + MMA_store_16x16_i32_global_INTRIN, + ) + + if measure_perf and timer: + print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean))) + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + 
k_factors, + index_map_A, + index_map_A, + index_map_C, + LDMATRIX_16x32_A_INTRIN, + LDMATRIX_16x32_B_TRANS_INTRIN, + MMA_i8i8i32_TRANS_INTRIN, + MMA_fill_16x16_i32_INTRIN, + MMA_store_16x16_i32_global_INTRIN, + ) + + if measure_perf and timer: + print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean))) + + +if __name__ == "__main__": + test_f16f16f32_m16n16k16() + test_f16f16f16_m16n16k16() + test_i8i8i32_m16n16k32()
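
A quick way to sanity-check the layout functions introduced in cuda.py (separate from this
patch, and assuming the patch is applied so that tvm.tir.tensor_intrin.cuda is importable)
is to verify that the 16x16 -> 32x8 map is a bijection onto the warp storage:

    from tvm.tir.tensor_intrin.cuda import shared_16x16_to_ldmatrix_32x8_layout

    seen = set()
    for i in range(16):
        for j in range(16):
            thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
            assert 0 <= thread_id < 32 and 0 <= local_id < 8
            seen.add((thread_id, local_id))
    # All 256 tile elements must land on distinct (thread_id, local_id) slots,
    # which is what the inverse index map used by mma_store relies on.
    assert len(seen) == 32 * 8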