From b257cd7859662ca77e2d64a4be30e37feb771ae7 Mon Sep 17 00:00:00 2001 From: Rye <85601223+Cunxiao2002@users.noreply.github.com> Date: Thu, 17 Oct 2024 02:12:58 +0800 Subject: [PATCH] [fix]add hip target (#2) --- cmake/modules/ROCM.cmake | 4 +- include/tvm/tir/builtin.h | 46 + python/tvm/_ffi/runtime_ctypes.py | 1 + python/tvm/contrib/hipcc.py | 111 ++ python/tvm/script/ir_builder/tir/ir.py | 8 + python/tvm/tir/__init__.py | 1 + python/tvm/tir/op.py | 247 +++++ python/tvm/tir/tensor_intrin/hip.py | 441 ++++++++ src/target/opt/build_rocm_nort.cc | 121 +++ src/target/opt/build_rocm_on.cc | 184 ++++ src/target/source/codegen_hip.cc | 1295 ++++++++++++++++++++++++ src/target/source/codegen_hip.h | 115 +++ src/target/target_kind.cc | 35 + src/tir/op/builtin.cc | 12 + 14 files changed, 2620 insertions(+), 1 deletion(-) create mode 100644 python/tvm/contrib/hipcc.py create mode 100644 python/tvm/tir/tensor_intrin/hip.py create mode 100644 src/target/opt/build_rocm_nort.cc create mode 100644 src/target/opt/build_rocm_on.cc create mode 100644 src/target/source/codegen_hip.cc create mode 100644 src/target/source/codegen_hip.h diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake index 37fcd716464e..753dd1f2d6a6 100644 --- a/cmake/modules/ROCM.cmake +++ b/cmake/modules/ROCM.cmake @@ -33,6 +33,7 @@ if(USE_ROCM) message(STATUS "Build with ROCM support") tvm_file_glob(GLOB RUNTIME_ROCM_SRCS src/runtime/rocm/*.cc) list(APPEND RUNTIME_SRCS ${RUNTIME_ROCM_SRCS}) + list(APPEND COMPILER_SRCS src/target/opt/build_rocm_on.cc) list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPHCC_LIBRARY}) if (ROCM_HSA_LIBRARY) list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HSA_LIBRARY}) @@ -69,5 +70,6 @@ if(USE_ROCM) endif(USE_THRUST) else(USE_ROCM) - list(APPEND COMPILER_SRCS src/target/opt/build_rocm_off.cc) + #list(APPEND COMPILER_SRCS src/target/opt/build_rocm_off.cc) + list(APPEND COMPILER_SRCS src/target/opt/build_rocm_nort.cc) endif(USE_ROCM) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 10e5b462d1d1..30947516d1f1 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -616,6 +616,52 @@ TVM_DLL const Op& tvm_store_matrix_sync(); */ TVM_DLL const Op& ptx_mma(); +/*! + * \brief tvm intrinsic for amd matrix core mfma instructions. + * + * void tvm_mfma(StringImm shape, StringImm A_layout, StringImm B_layout, + * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype, + * Var multiplicand_a, Expr a_index, + * Var multiplicand_b, Expr b_index, + * Var accumulator, Expr c_index); + */ +TVM_DLL const Op& tvm_mfma(); + +/*! + * \brief tvm intrinsic for storing the result of AMD MFMA into a destination pointer. + * + * There is no real instruction that does that, but we want to hide details of + * complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g. + * LowerWarpMemory) like cuda ptx backend does. + * + * void tvm_mfma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var + * dst_stride); + */ +TVM_DLL const Op& tvm_mfma_store(); + +/*! + * \brief tvm intrinsic for amd rdna matrix core instructions. + * + * void tvm_rdna_wmma(StringImm shape, StringImm A_layout, StringImm B_layout, + * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype, + * Var multiplicand_a, Expr a_index, + * Var multiplicand_b, Expr b_index, + * Var accumulator, Expr c_index); + */ +TVM_DLL const Op& tvm_rdna_wmma(); + +/*! + * \brief tvm intrinsic for storing the result of AMD RDNA WMMA into a destination pointer. 
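+ *        It is the RDNA WMMA counterpart of tvm_mfma_store.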
+ * + * There is no real instruction that does that, but we want to hide details of + * complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g. + * LowerWarpMemory) like cuda ptx backend does. + * + * void tvm_rdna_wmma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var + * dst_stride); + */ +TVM_DLL const Op& tvm_rdna_wmma_store(); + /*! * \brief tvm intrinsic for ptx predicate load with 32-bit data type. * diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 099cbe972a4a..e25d391ce63c 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -309,6 +309,7 @@ class Device(ctypes.Structure): "metal": kDLMetal, "vpi": kDLVPI, "rocm": kDLROCM, + "hip": kDLROCM, "ext_dev": kDLExtDev, "hexagon": kDLHexagon, "webgpu": kDLWebGPU, diff --git a/python/tvm/contrib/hipcc.py b/python/tvm/contrib/hipcc.py new file mode 100644 index 000000000000..9feeb536f9e1 --- /dev/null +++ b/python/tvm/contrib/hipcc.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Utility to invoke hipcc compiler in the system""" +from __future__ import absolute_import as _abs + +import subprocess +import os +import warnings + +import tvm._ffi +from tvm.target import Target + +from . import utils +from .._ffi.base import py_str +from .rocm import get_rocm_arch, find_rocm_path + + +def compile_hip(code, target_format="hsaco", arch=None, options=None, path_target=None): + """Compile HIP code with hipcc. + + Parameters + ---------- + code : str + The HIP code. + + target_format : str + The target format of hipcc compiler. + + arch : str + The AMD GPU architecture. + + options : str or list of str + The additional options. + + path_target : str, optional + Output file. 
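+        If given, the compiled binary is written to this path instead of a
+        temporary file.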
+ + Return + ------ + hsaco : bytearray + The bytearray of the hsaco + """ + if arch is None: + rocm_path = find_rocm_path() + arch = get_rocm_arch(rocm_path) + + temp = utils.tempdir() + if target_format not in ["hsaco"]: + raise ValueError("target_format must be hsaco") + temp_code = temp.relpath("my_kernel.cc") + temp_target = temp.relpath("my_kernel.%s" % target_format) + + with open(temp_code, "w") as out_file: + out_file.write(code) + + file_target = path_target if path_target else temp_target + cmd = ["hipcc"] + cmd += ["-O3", '-c'] + if isinstance(arch, str): + cmd += [f"--offload-arch={arch}"] + if target_format == "hsaco": + cmd += ["--genco"] + if options: + if isinstance(options, str): + cmd += [options] + elif isinstance(options, list): + cmd += options + else: + raise ValueError("options must be str or list of str") + + cmd += ["-o", file_target] + cmd += [temp_code] + print(f"cmd: {cmd}") + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + + (out, _) = proc.communicate() + + if proc.returncode != 0: + msg = code + msg += "\nCompilation error:\n" + msg += py_str(out) + raise RuntimeError(msg) + + with open(file_target, "rb") as f: + data = bytearray(f.read()) + if not data: + raise RuntimeError("Compilation error: empty result is generated") + return data + + +@tvm._ffi.register_func("tvm_callback_hip_compile") +def tvm_callback_hip_compile(code): + """use hipcc to generate fatbin code for better optimization""" + hsaco = compile_hip(code, target_format="hsaco") + return hsaco \ No newline at end of file diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 70cb6d801b04..3991a2f82ea0 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1914,6 +1914,10 @@ def wrapped(*args, **kwargs): vectorlow = _dtype_forward(_tir_op.vectorlow) vectorhigh = _dtype_forward(_tir_op.vectorhigh) vectorcombine = _dtype_forward(_tir_op.vectorcombine) +tvm_mfma = _dtype_forward(_tir_op.tvm_mfma) +tvm_mfma_store = _dtype_forward(_tir_op.tvm_mfma_store) +tvm_rdna_wmma = _dtype_forward(_tir_op.tvm_rdna_wmma) +tvm_rdna_wmma_store = _dtype_forward(_tir_op.tvm_rdna_wmma_store) broadcast = Broadcast @@ -2170,6 +2174,10 @@ def wrapped(*args, **kwargs): "vectorlow", "vectorhigh", "vectorcombine", + "tvm_mfma", + "tvm_mfma_store", + "tvm_rdna_wmma", + "tvm_rdna_wmma_store", "assume", "undef", "tvm_call_packed", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 4d6ad44e106f..757f3471d6a9 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -73,6 +73,7 @@ ptx_wait_barrier, create_barriers, ) +from .op import tvm_mfma, tvm_mfma_store, tvm_rdna_wmma, tvm_rdna_wmma_store from .op import vectorlow, vectorhigh, vectorcombine from .op import infinity, reinterpret from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 6b72e63f2990..ee2ce40f389f 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1457,6 +1457,253 @@ def ptx_wait_group(num): """ return call_intrin("", "tir.ptx_wait_group", num) +def tvm_mfma( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, +): + """TVM intrinsic for amd matrix core mfma instructions + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma + + Parameters + ---------- + 
dtype : str + The data type of the result. + + shape : str + The shape of mma fragment. + + A_layout : Literal["row", "col"] + The layout of multiplicand fragment A. + + B_layout : Literal["row", "col"] + The layout of multiplicand fragment B. + + A_dtype : str + The data type of multiplicand fragment A. + + B_dtype : str + The data type of multiplicand fragment B. + + C_dtype : str + The data type of accumulator fragment C. + + multiplicand_a : Var + The multiplicand fragment A variable. + + a_index : Expr + The index of multiplicand fragment A. + + multiplicand_b : Var + The multiplicand fragment B variable. + + b_index : Expr + The index of multiplicand fragment A. + + accumulator : Var + The accumulator fragment C variable. + + c_index : Expr + The index of accumulator fragment C. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + "tir.tvm_mfma", + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + ) + +def tvm_mfma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): + """TVM intrinsic for storing the result of PTX MMA into a destination pointer + + Parameters + ---------- + dtype : str + The data type of the result. + + m : IntImm + The shape of mma fragment. + + n : IntImm + The shape of mma fragment. + + dst_ptr : Var + The destination pointer variable. + + src_ptr : Var + The source pointer variable. + + src_offset : Expr + The source offset. + + dst_stride : Var + The destination stride. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + "tir.tvm_mfma_store", + m, + n, + dst_ptr, + src_ptr, + src_offset, + dst_stride, + ) + +def tvm_rdna_wmma( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, +): + """TVM intrinsic for amd matrix core mfma instructions + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma + + Parameters + ---------- + dtype : str + The data type of the result. + + shape : str + The shape of mma fragment. + + A_layout : Literal["row", "col"] + The layout of multiplicand fragment A. + + B_layout : Literal["row", "col"] + The layout of multiplicand fragment B. + + A_dtype : str + The data type of multiplicand fragment A. + + B_dtype : str + The data type of multiplicand fragment B. + + C_dtype : str + The data type of accumulator fragment C. + + multiplicand_a : Var + The multiplicand fragment A variable. + + a_index : Expr + The index of multiplicand fragment A. + + multiplicand_b : Var + The multiplicand fragment B variable. + + b_index : Expr + The index of multiplicand fragment A. + + accumulator : Var + The accumulator fragment C variable. + + c_index : Expr + The index of accumulator fragment C. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + "tir.tvm_rdna_wmma", + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + ) + + +def tvm_rdna_wmma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): + """TVM intrinsic for storing the result of PTX MMA into a destination pointer + + Parameters + ---------- + dtype : str + The data type of the result. + + m : IntImm + The shape of mma fragment. + + n : IntImm + The shape of mma fragment. 
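+        Together with m, this gives the m x n shape of the accumulator tile being stored.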
+ + dst_ptr : Var + The destination pointer variable. + + src_ptr : Var + The source pointer variable. + + src_offset : Expr + The source offset. + + dst_stride : Var + The destination stride. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + "tir.tvm_rdna_wmma_store", + m, + n, + dst_ptr, + src_ptr, + src_offset, + dst_stride, + ) + def ptx_cp_async_barrier(barrier_id): """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive diff --git a/python/tvm/tir/tensor_intrin/hip.py b/python/tvm/tir/tensor_intrin/hip.py new file mode 100644 index 000000000000..18f36677118e --- /dev/null +++ b/python/tvm/tir/tensor_intrin/hip.py @@ -0,0 +1,441 @@ +import tvm.tir +from tvm.runtime import convert +from tvm.tir.expr import Cast, IntImm +from tvm.tir.function import TensorIntrin +from tvm.script import tir as T + +lift = convert + +WARP_SIZE = 64 +M_DIM = 16 +N_DIM = 16 + + +def shared_16x4_to_local_64x1_layout_A(i, j): + thread_id = j * 16 + i + return thread_id, convert(0) + + +def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id): + i = thread_id % 16 + j = thread_id // 16 + return i, j + + +def shared_4x16_to_local_64x1_layout_B(i, j): + thread_id = i * 16 + j + return thread_id, convert(0) + + +def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id): + i = thread_id // 16 + j = thread_id % 16 + return i, j + + +def shared_16x16_to_local_64x4_layout_C(i, j): + thread_id = j + (i // 4) * 16 + local = i % 4 + return thread_id, local + + +def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id): + i = thread_id % 16 + j = (thread_id // 16) * 4 + local_id + return i, j + + +def shared_16x16_to_local_64x4_layout_A(i, j): + thread_id = i + 16 * (j // 4) + local = j % 4 + return thread_id, local + + +def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id): + i = local_id + (thread_id // 16) * 4 + j = thread_id % 16 + return i, j + + +def shared_16x16_to_local_64x4_layout_B(i, j): + thread_id = j + (i // 4) * 16 + local = i % 4 + return thread_id, local + + +def thread_id_shared_access_64x4_to_16x16_layout_C(thread_id, local_id): + i = local_id + (thread_id // 16) * 4 + j = thread_id % 16 + return i, j + + +def get_mma_fill_intrin(dtype, local_size): + zero = IntImm("int32", 0).astype(dtype) + + # Assume M = N = 16 + index_map = shared_16x16_to_local_64x4_layout_C + + @T.prim_func + def mma_fill_desc(a: T.handle) -> None: + C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp") + + with T.block("root"): + T.reads() + T.writes(C_warp[0:WARP_SIZE, 0:local_size]) + for i0, i1 in T.grid(M_DIM, N_DIM): + with T.block("C_warp"): + i, j = T.axis.remap("SS", [i0, i1]) + thread_id, local_id = T.meta_var(index_map(i, j)) + T.reads() + T.writes(C_warp[thread_id, local_id]) + C_warp[thread_id, local_id] = zero + + @T.prim_func + def mma_fill_impl(a: T.handle) -> None: + C_warp = T.match_buffer( + a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 + ) + + with T.block("root"): + T.reads() + T.writes(C_warp[0:WARP_SIZE, 0:local_size]) + for tx in T.thread_binding(WARP_SIZE, "threadIdx.x"): + for local_id in T.serial(0, local_size): + C_warp[tx, local_id] = zero + + return mma_fill_desc, mma_fill_impl + + +def get_mfma_load_intrin( + k_dim=4, + dtype="float32", + scope="shared", + is_b=False, + transposed=False, +): + local_size = (M_DIM * k_dim) // WARP_SIZE if not is_b else (N_DIM * k_dim) // WARP_SIZE + memory_shape = (M_DIM, k_dim) + if is_b: 
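+        # For the B operand the shared tile is (k_dim, N_DIM); a pre-transposed B is stored as (N_DIM, k_dim).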
+ memory_shape = (N_DIM, k_dim) if transposed else (k_dim, N_DIM) + + row_dim, col_dim = memory_shape + + if k_dim == 4: + index_map = shared_16x4_to_local_64x1_layout_A + reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A + if is_b: + index_map = ( + shared_16x4_to_local_64x1_layout_A + if transposed + else shared_4x16_to_local_64x1_layout_B + ) + reverse_index_map = ( + thread_id_shared_access_64x1_to_16x4_layout_A + if transposed + else thread_id_shared_access_64x1_to_4x16_layout_B + ) + elif k_dim == 16: + index_map = shared_16x16_to_local_64x4_layout_A + reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A + + if is_b: + index_map = ( + shared_16x16_to_local_64x4_layout_A + if transposed + else shared_16x16_to_local_64x4_layout_B + ) + reverse_index_map = ( + thread_id_shared_access_64x4_to_16x16_layout_A + if transposed + else thread_id_shared_access_64x4_to_16x16_layout_B + ) + else: + raise ValueError("k_dim must be 4 or 16 currently") + + @T.prim_func + def mfma_load_desc(reg_handle: T.handle, memory_handle: T.handle) -> None: + memory = T.match_buffer( + memory_handle, + memory_shape, + dtype, + offset_factor=1, + scope=scope, + ) + reg = T.match_buffer( + reg_handle, (WARP_SIZE, local_size), dtype, offset_factor=1, scope="warp" + ) + + with T.block("root"): + T.reads(memory[0:row_dim, 0:col_dim]) + T.writes(reg[0:WARP_SIZE, 0:local_size]) + + for ax0, ax1 in T.grid(row_dim, col_dim): + with T.block("memory_reg"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(memory[v0, v1]) + + thread_id, local_id = T.meta_var(index_map(v0, v1)) + T.writes(reg[thread_id, local_id]) + reg[thread_id, local_id] = memory[v0, v1] + + @T.prim_func + def mfma_load_impl(reg_handle: T.handle, memory_handle: T.handle) -> None: + s0 = T.int32() + s1 = T.int32() + + memory = T.match_buffer( + memory_handle, + memory_shape, + dtype, + align=64, + offset_factor=1, + scope=scope, + strides=[s0, s1], + ) + reg = T.match_buffer( + reg_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=1, scope="warp" + ) + + with T.block("root"): + T.reads(memory[0:row_dim, 0:col_dim]) + T.writes(reg[0:WARP_SIZE, 0:local_size]) + for tx in T.thread_binding(WARP_SIZE, "threadIdx.x"): + for local_id in T.serial(0, local_size): + row, col = T.meta_var(reverse_index_map(tx, local_id)) + reg[tx, local_id] = memory[row, col] + + return mfma_load_desc, mfma_load_impl + + +def get_mfma_intrin(k_dim, in_dtype="float32", out_dtype="float32", b_transposed=False): + local_size = (M_DIM * k_dim) // WARP_SIZE + local_size_out = (M_DIM * N_DIM) // WARP_SIZE + compute_in_dtype = in_dtype if local_size == 1 else f"{in_dtype}x{local_size}" + compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}" + + if k_dim == 4: + index_map_A = shared_16x4_to_local_64x1_layout_A + index_map_B = shared_4x16_to_local_64x1_layout_B + index_map_C = shared_16x16_to_local_64x4_layout_C + elif k_dim == 16: + index_map_A = shared_16x16_to_local_64x4_layout_A + index_map_B = shared_16x16_to_local_64x4_layout_B + index_map_C = shared_16x16_to_local_64x4_layout_C + else: + raise ValueError("k_dim must be 4 or 16 currently") + + out_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[out_dtype] + + in_dtype_abbrv = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}[in_dtype] + + mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}" + + def maybe_cast(v): + if out_dtype != in_dtype: + return Cast(out_dtype, v) + return v + + 
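+    # Swap (row, col) indices for the B operand when it is stored transposed.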
def maybe_swap(i, j): + if b_transposed: + return j, i + return i, j + + @T.prim_func + def mfma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") + B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") + C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=1, scope="warp") + + with T.block("root"): + T.reads( + C[0:WARP_SIZE, 0:local_size_out], + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + + for i, j, k in T.grid(M_DIM, N_DIM, k_dim): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i, j, k]) + b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j)) + + thread_id_C, local_id_C = T.meta_var(index_map_C(i, j)) + thread_id_A, local_id_A = T.meta_var(index_map_A(i, k)) + thread_id_B, local_id_B = T.meta_var(index_map_B(b_row_ind, b_col_ind)) + + T.reads( + C[thread_id_C, local_id_C], + A[thread_id_A, local_id_A], + B[thread_id_B, local_id_B], + ) + T.writes(C[thread_id_C, local_id_C]) + + C[thread_id_C, local_id_C] += maybe_cast( + A[thread_id_A, local_id_A] + ) * maybe_cast(B[thread_id_B, local_id_B]) + + @T.prim_func + def mfma_sync_impl_float(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") + B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") + C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=1, scope="warp") + + with T.block("root"): + T.reads( + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + C[0:WARP_SIZE, 0:local_size_out], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + T.evaluate(T.tvm_mfma( + mfma_suffix, + "row", + "row", + compute_in_dtype, + compute_in_dtype, + compute_out_dtype, + A.data, + A.elem_offset, + B.data, + B.elem_offset, + C.data, + C.elem_offset // (WARP_SIZE * local_size_out), + dtype=compute_out_dtype, + )) + + @T.prim_func + def mfma_sync_impl_integer(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") + B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, offset_factor=1, scope="warp") + C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, offset_factor=1, scope="warp") + + with T.block("root"): + T.reads( + A[0:WARP_SIZE, 0:local_size], + B[0:WARP_SIZE, 0:local_size], + C[0:WARP_SIZE, 0:local_size_out], + ) + T.writes(C[0:WARP_SIZE, 0:local_size_out]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + + T.evaluate( + T.tvm_mfma( + mfma_suffix, + "row", + "row", + compute_in_dtype, + compute_in_dtype, + compute_out_dtype, + T.call_intrin("int32", "tir.reinterpret", A.data), + A.elem_offset, + T.call_intrin("int32", "tir.reinterpret", B.data), + B.elem_offset, + C.data, + C.elem_offset // (WARP_SIZE * local_size_out), + dtype=compute_out_dtype, + ) + ) + + return ( + (mfma_sync_desc, mfma_sync_impl_integer) + if in_dtype == "int8" + else (mfma_sync_desc, mfma_sync_impl_float) + ) + + +def get_mfma_store_intrin(local_size=4, dtype="float32", scope="global"): + index_map = shared_16x16_to_local_64x4_layout_C + + @T.prim_func + def mfma_store_desc(a: T.handle, c: T.handle) -> None: + C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp") + C = T.match_buffer(c, [M_DIM, 
N_DIM], dtype=dtype, scope=scope) + + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + for i0, i1 in T.grid(M_DIM, N_DIM): + with T.block("C_warp"): + v0, v1 = T.axis.remap("SS", [i0, i1]) + thread_id, local_id = T.meta_var(index_map(v0, v1)) + T.reads(C_warp[thread_id, local_id]) + T.writes(C[v0, v1]) + C[v0, v1] = C_warp[thread_id, local_id] + + @T.prim_func + def mfma_store_impl(a: T.handle, c: T.handle) -> None: + s0 = T.int32() + s1 = T.int32() + + C_warp = T.match_buffer( + a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 + ) + C = T.match_buffer( + c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1] + ) + + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + tx = T.env_thread("threadIdx.x") + T.launch_thread(tx, WARP_SIZE) + for i in range(local_size): + C[((tx // 16) * 4) + i, (tx % 16)] = C_warp[tx, i] + + return mfma_store_desc, mfma_store_impl + + +HIP_MFMA_fill_16x16_f32_INTRIN = "HIP_mfma_fill_16x16_f32" +TensorIntrin.register(HIP_MFMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 4)) + +HIP_MFMA_fill_16x16_i32_INTRIN = "HIP_mfma_fill_16x16_i32" +TensorIntrin.register(HIP_MFMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int", 4)) + +HIP_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN = "hip_mfma_load_16x16_a_shared_s8" +TensorIntrin.register( + HIP_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN, *get_mfma_load_intrin(16, "int8", "shared") +) +HIP_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN = "hip_mfma_load_b_16x16_shared_s8" +TensorIntrin.register( + HIP_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN, *get_mfma_load_intrin(16, "int8", "shared", is_b=True) +) + +HIP_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN = "hip_mfma_load_16x16_a_shared_f16" +TensorIntrin.register( + HIP_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN, *get_mfma_load_intrin(16, "float16", "shared") +) +HIP_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN = "hip_mfma_load_b_16x16_shared_f16" +TensorIntrin.register( + HIP_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN, + *get_mfma_load_intrin(16, "float16", "shared", is_b=True), +) + +HIP_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN = "hip_mfma_load_16x4_a_shared_f32" +TensorIntrin.register( + HIP_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN, *get_mfma_load_intrin(4, "float32", "shared") +) +HIP_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN = "hip_mfma_load_b_16x4_shared_f32" +TensorIntrin.register( + HIP_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN, *get_mfma_load_intrin(4, "float32", "shared", is_b=True) +) + + +HIP_MFMA_f32f32f32_INTRIN = "hip_mfma_f32f32f32" +TensorIntrin.register(HIP_MFMA_f32f32f32_INTRIN, *get_mfma_intrin(4, "float32", "float32")) + +HIP_MFMA_f16f16f32_INTRIN = "hip_mfma_f16f16f32" +TensorIntrin.register(HIP_MFMA_f16f16f32_INTRIN, *get_mfma_intrin(16, "float16", "float32")) + +HIP_MFMA_s8s8s32_INTRIN = "hip_mfma_s8s8s32" +TensorIntrin.register(HIP_MFMA_s8s8s32_INTRIN, *get_mfma_intrin(16, "int8", "int32")) + +HIP_MFMA_STORE_16x16_s32_INTRIN = "hip_mfma_store_16x16_s32" +TensorIntrin.register(HIP_MFMA_STORE_16x16_s32_INTRIN, *get_mfma_store_intrin(4, "int32", "global")) + +HIP_MFMA_STORE_16x16_f32_INTRIN = "hip_mfma_store_16x16_f32" +TensorIntrin.register( + HIP_MFMA_STORE_16x16_f32_INTRIN, *get_mfma_store_intrin(4, "float32", "global") +) diff --git a/src/target/opt/build_rocm_nort.cc b/src/target/opt/build_rocm_nort.cc new file mode 100644 index 000000000000..5a62cbf3a509 --- /dev/null +++ b/src/target/opt/build_rocm_nort.cc @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Optional module when build rocm is switched to off + */ + +#include +#include + +#include "../../runtime/rocm/rocm_module.h" +#include "../build_common.h" +#include "../source/codegen_hip.h" +#include "../source/codegen_source_base.h" + +namespace tvm { +namespace runtime { + +class ROCMModuleNode : public runtime::ModuleNode { + public: + explicit ROCMModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, + std::string hip_source, std::string assembly) + : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) {} + + const char* type_key() const final { return "hip"; } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + ICHECK(0) << "Not implemented when rocm is not enabled in TVM."; + return PackedFunc(); + }; + + std::string GetSource(const std::string& format) final { + if (format == fmt_) { + return data_; + } + if (format == "llvm" || format == "") { + return hip_source_; + } + if (format == "asm") { + return assembly_; + } + return ""; + } + + + private: + // the binary data + std::string data_; + // The format + std::string fmt_; + // function information table. + std::unordered_map fmap_; + // The hip source. + std::string hip_source_; + // The gcn asm. 
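+  // (kept so GetSource("asm") can still return it from this runtime-free stub)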
+ std::string assembly_; + // internal mutex when updating the module + std::mutex mutex_; +}; + +Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string hip_source, + std::string assembly) { + auto n = make_object(data, fmt, fmap, hip_source, assembly); + return Module(n); +} + +} // namespace runtime +} // namespace tvm +namespace tvm { +namespace codegen { +using tvm::runtime::Registry; +runtime::Module BuildHIP(IRModule mod, Target target) { + using tvm::runtime::Registry; + bool output_ssa = false; + CodeGenHIP cg; + cg.Init(output_ssa); + + for (auto kv : mod->functions) { + ICHECK(kv.second->IsInstance()) << "CodeGenHIP: Can only take PrimFunc"; + auto f = Downcast(kv.second); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenHIP: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + cg.AddFunction(f); + } + + std::string code = cg.Finish(); + + if (const auto* f = Registry::Get("tvm_callback_hip_postproc")) { + code = (*f)(code).operator std::string(); + } + std::string fmt = "ptx"; + std::string ptx; + const auto* f_enter = Registry::Get("target.TargetEnterScope"); + (*f_enter)(target); + const auto* f_exit = Registry::Get("target.TargetExitScope"); + (*f_exit)(target); + return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string()); +} + +TVM_REGISTER_GLOBAL("target.build.hip").set_body_typed(BuildHIP); +} // namespace codegen +} // namespace tvm diff --git a/src/target/opt/build_rocm_on.cc b/src/target/opt/build_rocm_on.cc new file mode 100644 index 000000000000..d46e7256295a --- /dev/null +++ b/src/target/opt/build_rocm_on.cc @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
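+ * \file build_rocm_on.cc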
+ * Optional module when build rocm is switched to on + */ + +#if defined(__linux__) +#include +#endif + +#include +#include + +#include +#include + +#include "../../runtime/rocm/rocm_module.h" +#include "../build_common.h" +#include "../source/codegen_hip.h" +#include "../source/codegen_source_base.h" + +namespace tvm { +namespace codegen { + +#define HIPRTC_CALL(x) \ + \ + { \ + \ + hiprtcResult result = x; \ + \ + if (result != HIPRTC_SUCCESS) { \ + \ + LOG(FATAL) \ + << "HiprtcError: " #x " failed with error: " << hiprtcGetErrorString(result); \ + \ + \ + } \ + \ + \ + } + +std::string FindHIPIncludePath() { +#if defined(_WIN32) + const std::string delimiter = "\\"; +#else + const std::string delimiter = "/"; +#endif + std::string hip_include_path; + const char* hip_path_env = std::getenv("HIP_PATH"); + if (hip_path_env != nullptr) { + hip_include_path += hip_path_env; + hip_include_path += delimiter + "include"; + return hip_include_path; + } + +#if defined(__linux__) + struct stat st; + hip_include_path = "/opt/rocm/hip/include"; + if (stat(hip_include_path.c_str(), &st) == 0) { + return hip_include_path; + } + + if (stat("/usr/include/hip/hip_runtime.h", &st) == 0) { + return "/usr/include/hip"; + } +#endif + LOG(FATAL) << "Cannot find HIP include path." + << "HIP_PATH is not set or ROCm is not installed in the default installation path." + << "In other than linux, it is necessary to set HIP_PATH."; + return hip_include_path; +} + +std::string HIPRTCCompile(const std::string& code, bool include_path = false) { + std::vector compile_params; + std::vector param_cstrings{}; + hiprtcProgram prog; + std::string cc = "gfx900"; // 默认目标架构(可以根据需要更改) + int major, minor; + hipError_t e1 = hipDeviceGetAttribute(&major, hipDeviceAttributeComputeCapabilityMajor, 0); + hipError_t e2 = hipDeviceGetAttribute(&minor, hipDeviceAttributeComputeCapabilityMinor, 0); + + if (e1 == hipSuccess && e2 == hipSuccess) { + cc = "gfx" + std::to_string(major * 100 + minor * 10); + } else { + LOG(WARNING) << "cannot detect compute capability from your device, " + << "fall back to gfx900."; + } + + compile_params.push_back("--gpu-architecture=" + cc); + + if (include_path) { + std::string include_option = "--include-path=" + FindHIPIncludePath(); + compile_params.push_back(include_option); + } + + for (const auto& string : compile_params) { + param_cstrings.push_back(string.c_str()); + } + HIPRTC_CALL(hiprtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr)); + hiprtcResult compile_res = + hiprtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); + + size_t log_size; + HIPRTC_CALL(hiprtcGetProgramLogSize(prog, &log_size)); + std::string log; + log.resize(log_size); + HIPRTC_CALL(hiprtcGetProgramLog(prog, &log[0])); + ICHECK_EQ(compile_res, HIPRTC_SUCCESS) << log; + size_t code_size; + HIPRTC_CALL(hiprtcGetCodeSize(prog, &code_size)); + + std::string code_out; + code_out.resize(code_size); + HIPRTC_CALL(hiprtcGetCode(prog, &code_out[0])); + HIPRTC_CALL(hiprtcDestroyProgram(&prog)); + + return code_out; +} + +runtime::Module BuildHIP(IRModule mod, Target target) { + using tvm::runtime::Registry; + bool output_ssa = false; + CodeGenHIP cg; + cg.Init(output_ssa); + + Map functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance()) << "CodeGenHIP: Can only take PrimFunc"; + auto prim_func = Downcast(base_func); + auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenHIP: 
expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + cg.DeclareFunction(gvar, prim_func); + } + for (auto [gvar, prim_func] : functions) { + cg.AddFunction(gvar, prim_func); + } + + std::string code = cg.Finish(); + + if (const auto* f = Registry::Get("tvm_callback_hip_postproc")) { + code = (*f)(code).operator std::string(); + } + std::string fmt = "ptx"; + std::string ptx; + const auto* f_enter = Registry::Get("target.TargetEnterScope"); + (*f_enter)(target); + if (const auto* f = Registry::Get("tvm_callback_hip_compile")) { + ptx = (*f)(code).operator std::string(); + // Dirty matching to check PTX vs hsaco. + // TODO(leiwang1999) more reliable checks + if (ptx[0] != '/') fmt = "hsaco"; + } else { + ptx = HIPRTCCompile(code, cg.need_include_path()); + } + const auto* f_exit = Registry::Get("target.TargetExitScope"); + (*f_exit)(target); + return ROCMModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code, std::string()); +} + +TVM_REGISTER_GLOBAL("target.build.hip").set_body_typed(BuildHIP); +} // namespace codegen +} // namespace tvm diff --git a/src/target/source/codegen_hip.cc b/src/target/source/codegen_hip.cc new file mode 100644 index 000000000000..cc8478e2448a --- /dev/null +++ b/src/target/source/codegen_hip.cc @@ -0,0 +1,1295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file codegen_hip.cc + */ + +#include "codegen_hip.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "ptx.h" + +namespace tvm { +namespace codegen { + +/*! + * \brief Replace patterns with replacement strings. + * \note should use std::format instead when codebase is ported to C++20. + */ +class Replacer { + public: + void register_rule(const std::string& pattern, const std::string& replacement) { + _rules.emplace_back(pattern, replacement); + } + std::string rewrite(std::string str) { + for (auto&& rule : _rules) { + auto [pattern, replacement] = rule; + size_t len = pattern.size(); + size_t new_len = replacement.size(); + size_t pos = str.find(pattern); + while (pos != std::string::npos) { + str = str.replace(pos, len, replacement); + pos = str.find(pattern, pos + new_len); + } + } + return str; + } + void empty_rules() { _rules.clear(); } + + private: + std::vector> _rules; +}; + +CodeGenHIP::CodeGenHIP() { restrict_keyword_ = "__restrict__"; } + +void CodeGenHIP::Init(bool output_ssa) { + CodeGenC::Init(output_ssa); + this->cuda_codegen_.Init(output_ssa); +} + +void CodeGenHIP::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ void"; } + +std::string CodeGenHIP::Finish() { + // hip must need a header file. 
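+  // The HIP runtime header and the vector typedefs below are required by the
+  // kernels this codegen emits.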
+ decl_stream << "#include \n"; + + if (enable_fp16_) { + decl_stream << "#include \n"; + decl_stream << "using float16_t = _Float16;\n"; + decl_stream << "using float16x2\n"; + decl_stream << " = __attribute__((__vector_size__(2 * sizeof(float16_t)))) float16_t;\n"; + decl_stream << "using float16x4\n"; + decl_stream << " = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t;\n"; + decl_stream << "using float16x8\n"; + decl_stream << " = __attribute__((__vector_size__(8 * sizeof(float16_t)))) float16_t;\n"; + decl_stream << "using float16x16\n"; + decl_stream << " = __attribute__((__vector_size__(16 * sizeof(float16_t)))) float16_t;\n"; + } + + if (need_math_constants_h_) { + decl_stream << "#include \n"; + } + + if (need_wmma_h_) { + decl_stream << "#include \n"; + } + decl_stream << "using int32x4\n"; + decl_stream << " = __attribute__((__vector_size__(4 * sizeof(int)))) int;\n"; + decl_stream << "using float32x4\n"; + decl_stream << " = __attribute__((__vector_size__(4 * sizeof(float)))) float;\n"; + decl_stream << "using float32x16\n"; + decl_stream << " = __attribute__((__vector_size__(16 * sizeof(float)))) float;\n"; + + return CodeGenC::Finish(); +} + +class ThreadIdxExtractor : public tir::StmtVisitor { + private: + void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tir::attr::thread_extent) { + IterVar iv = Downcast(op->node); + if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") { + threadIdx_x_ext = op->value; + } + if (iv->var->name_hint == "threadIdx.y" || iv->thread_tag == "threadIdx.y") { + threadIdx_y_ext = op->value; + } + if (iv->var->name_hint == "threadIdx.z" || iv->thread_tag == "threadIdx.z") { + threadIdx_z_ext = op->value; + } + } + StmtVisitor::VisitStmt_(op); + } + + public: + PrimExpr threadIdx_x_ext = Integer(1); + PrimExpr threadIdx_y_ext = Integer(1); + PrimExpr threadIdx_z_ext = Integer(1); +}; + +/*void CodeGenHIP::PrintExtraAttrs(const PrimFunc& f) { + ThreadIdxExtractor extractor; + extractor(f->body); + arith::Analyzer analyzer; + PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * + extractor.threadIdx_z_ext); + if (const IntImmNode* const threadIdx_ext_int = threadIdx_ext.as()) { + if (threadIdx_ext_int->value == 1) { + // unable to extract the number of threads per block, hence directly return + return; + } + stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; + } +} +*/ + +void CodeGenHIP::VisitStmt_(const tir::ForNode* op) { + ICHECK(is_const_int(op->min, 0)); + if (op->kind == tir::ForKind::kUnrolled) { + PrintIndent(); + stream << "#pragma unroll\n"; + } + CodeGenC::VisitStmt_(op); +} + +void CodeGenHIP::BindThreadIndex(const IterVar& iv) { + ICHECK(!var_idmap_.count(iv->var.get())); + var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); +} + +void CodeGenHIP::PrintType(DataType t, std::ostream& os) { // NOLINT(*) + int lanes = t.lanes(); + if (t.is_handle()) { + ICHECK(t.is_scalar()) << "do not yet support vector types"; + os << "void*"; + return; + } + + if (t.is_void()) { + os << "void"; + return; + } + + bool fail = false; + if (t.is_float()) { + switch (t.bits()) { + case 16: + enable_fp16_ = true; + if (t.is_scalar()) { + os << "half"; + } else if (lanes <= 8) { + // Emit CUDA code to access fp16 vector elements. 
+ // + // half4 is stored as uint2 + // + // h4.x is emitted as *(half2*)(&(u2.x)).x + // h4.y is emitted as *(half2*)(&(u2.x)).y + // h4.z is emitted as *(half2*)(&(u2.y)).x + // h4.w is emitted as *(half2*)(&(u2.y)).y + // + ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; + os << "uint" << lanes / 2; + } else { + fail = true; + } + break; + case 32: + if (lanes <= 4) { + os << "float"; + } else if (lanes <= 8) { + // Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8. + // + // float8 is stored as ulonglong4 + // + // f8.v1 is emitted as *(float2*)(&(ul4.x)).x + // f8.v2 is emitted as *(float2*)(&(ul4.x)).y + // + ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4"; + os << "ulonglong" << lanes / 2; + } else { + fail = true; + } + break; + case 64: + os << "double"; + break; + default: + fail = true; + break; + } + if (!fail && (t.is_scalar() || t.bits() == 16)) return; + if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return; + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; + return; + } + } else if (t.is_bfloat16()) { + enable_bf16_ = true; + if (t.is_scalar()) { + os << "nv_bfloat16"; + } else if (lanes <= 8) { + ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type"; + os << "uint" << lanes / 2; + } else { + fail = true; + } + if (!fail) return; + } else if (t == DataType::Bool()) { + os << "bool"; + return; + } else if (t.is_vector_bool()) { + // CUDA does not support bool vectors. + // Use ushort vectors to represent instead. + int n = t.lanes(); + if (n <= 4) { + os << "ushort" << n; + return; + } + } else if (t.is_uint() || t.is_int()) { + if (t.is_uint()) { + os << "u"; + } + switch (t.bits()) { + case 1: { + if (t.is_scalar()) { + os << "int"; + return; + } else if (t.lanes() == 8) { + os << "int8_t"; + return; + } else if (t.lanes() == 16) { + os << "int16_t"; + return; + } else if (t.lanes() == 32) { + os << "int"; + return; + } else { + LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; + } + } + case 4: { + if (t.is_scalar()) { + os << "int"; + return; + } else if (t.lanes() == 4) { + os << "int16_t"; + return; + } else if (t.lanes() == 8) { + // directly 8 4-bit int in integer. + os << "int"; + return; + } else if (t.lanes() == 16) { + os << "int2"; + return; + } else if (t.lanes() == 32) { + os << "int4"; + return; + } else if (t.lanes() == 64) { + os << "int8"; + return; + } else { + LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; + } + } + case 8: { + if (t.lanes() == 4) { + // directly 4 8 bit int in integer. + enable_int8_ = true; + + // We use int for int8x4 instead of char4 because using char4 is + // likely to produce extra instructions to pack four int8 elements + // into 32-bit data. + os << "int"; + return; + } else if (t.lanes() == 8) { + enable_int8_ = true; + os << "int2"; + return; + } else if (t.lanes() == 16) { + enable_int8_ = true; + os << "int4"; + return; + } else if (!t.is_uint() && t.is_scalar()) { + os << "signed char"; + break; + } else { + os << "char"; + break; + } + } + case 16: { + if (t.is_scalar()) { + os << "short"; + } else if (t.lanes() <= 4) { + os << "short" << lanes; + } else if (t.lanes() <= 8) { + // Emit CUDA code to access int16 vector elements. 
+ // + // short4 is stored as int2 + // + // s4.x is emitted as *(short2*)(&(i2.x)).x + // s4.y is emitted as *(short2*)(&(i2.x)).y + // s4.z is emitted as *(short2*)(&(i2.y)).x + // s4.w is emitted as *(short2*)(&(i2.y)).y + // + ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for shorT type with lanes > 4"; + os << "int" << t.lanes() / 2; + } else { + fail = true; + } + if (!fail) { + return; + } + break; + } + case 32: { + if (t.is_scalar()) { + os << "int"; + } else if (t.lanes() <= 4) { + os << "int" << t.lanes(); + } else if (t.lanes() <= 8) { + // Emit CUDA code to access int32 vector elements for 4 < lanes <= 8. + // + // int8 is stored as longlong4 + // + // i8.v1 is emitted as *(int2*)(&(l4.x)).x + // i8.v2 is emitted as *(int2*)(&(l4.x)).y + // + ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4"; + os << "longlong" << lanes / 2; + } else { + fail = true; + } + if (!fail) { + return; + } + break; + } + case 64: { + if (t.is_scalar()) { + os << "int64_t"; + } else if (t.lanes() == 2) { + os << "longlong2"; + } else if (t.lanes() == 3) { + os << "longlong3"; + } else if (t.lanes() == 4) { + os << "longlong4"; + } + return; + } + default: + fail = true; + break; + } + if (!fail && lanes == 1) { + return; + } + if (!fail && (lanes >= 2 && lanes <= 4)) { + os << lanes; + return; + } + } + LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; +} + +/*void CodeGenHIP::VisitStmt_(const RasterNode* op) { + ICHECK(is_positive_const(op->stage)); + stream << "\n"; + PrintIndent(); + stream << "const int MAX_BLOCK_N = " << op->stage << ";"; + stream << R"( + const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; + const auto totalPanel = (gridDim.x * gridDim.y +MAX_BLOCK_N * gridDim.x - 1) / (MAX_BLOCK_N * gridDim.x); + const auto totalBlock = gridDim.x * gridDim.y; + const auto panelIdx = baseBlockIdx / (MAX_BLOCK_N *gridDim.x); + const auto strideLd = panelIdx + 1 < totalPanel ?MAX_BLOCK_N : (totalBlock - panelIdx * (MAX_BLOCK_N *gridDim.x)) / gridDim.x; + const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * MAX_BLOCK_N * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) / strideLd; + const auto by = (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) % strideLd + panelIdx * MAX_BLOCK_N; + const auto bz = blockIdx.z; + const dim3 blockIdx(bx, by, bz); + +)"; +}*/ + +void CodeGenHIP::PrintStorageSync(const CallNode* op) { + const std::string& sync = op->args[0].as()->value; + if (sync == "warp") { + // DO nothing. 
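+      // A warp-level sync lowers to nothing here, mirroring the CUDA codegen.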
+ } else if (sync == "shared" || sync == "shared.dyn") { + this->PrintIndent(); + this->stream << "__syncthreads();\n"; + } else if (sync == "global") { + if (!need_global_barrier_) { + need_global_barrier_ = true; + this->decl_stream << "extern \"C\" __device__ unsigned " << vid_global_barrier_state_ + << ";\n"; + } + // global synchronizer + std::string is_load = PrintExpr(op->args[1]); + std::string num_blocks = PrintExpr(op->args[2]); + this->PrintIndent(); + // In theory only threadfence is needed + // but we observed problems with only threadfence + this->stream << "__threadfence_system();\n"; + this->PrintIndent(); + this->stream << "if (" << is_load << ") {\n"; + int wb = this->BeginScope(); + this->PrintIndent(); + this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n"; + this->PrintIndent(); + std::string ptr = name_supply_->FreshName("pf"); + this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n"; + this->PrintIndent(); + this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n"; + this->PrintIndent(); + this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n"; + this->EndScope(wb); + this->PrintIndent(); + this->stream << "}\n"; + this->PrintIndent(); + this->stream << "__syncthreads();\n"; + } +} + +void CodeGenHIP::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) + ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass " + "all global arrays as input instead"; + if (scope == "shared") { + os << "__shared__ "; + } else if (scope == "shared.dyn") { + os << "extern __shared__ "; + } +} + +void CodeGenHIP::VisitExpr_(const CastNode* op, std::ostream& os) { + DataType from_ty = op->value.dtype(); + DataType target_ty = op->dtype; + ICHECK_EQ(target_ty.lanes(), from_ty.lanes()); + + // Emit simple C-style type conversion. + if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); + + // We could emit make_float4 like calls, but the emitted code looks + // too compact to read. Emit this as vectorized unary ops. 
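+  // Declare a temporary of the target vector type, then cast and store each lane
+  // individually before returning the temporary's name.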
+ std::string sret = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(target_ty, stream); + stream << ' ' << sret << ";\n"; + { + std::string src = SSAGetID(PrintExpr(op->value), from_ty); + for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { + std::ostringstream val; + val << "("; + PrintType(target_ty.element_of(), val); + val << ")("; + PrintVecElemLoad(src, from_ty, i, val); + val << ")"; + PrintVecElemStore(sret, target_ty, i, val.str()); + } + } + os << sret; +} + +void CodeGenHIP::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == tir::attr::fragment_shape) { + const VarNode* buffer = op->node.as(); + const StringImmNode* shape_str = op->value.as(); + fragment_shapes[buffer] = shape_str->value; + } else if (op->attr_key == tir::attr::fragment_layout) { + const VarNode* buffer = op->node.as(); + const StringImmNode* layout_str = op->value.as(); + fragment_layouts[buffer] = layout_str->value; + } else if (op->attr_key == tir::attr::async_commit_queue_scope) { + const IntImmNode* queue_id = op->value.as(); + ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; + this->VisitStmt(op->body); + auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {}); + this->VisitExpr(commit_group, this->stream); + return; + } else if (op->attr_key == tir::attr::async_wait_queue_scope) { + auto wait_attrs = GetAsyncWaitAttributes(op); + auto queue_id = wait_attrs.first.as(); + ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; + auto wait_cnt = wait_attrs.second; + auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt}); + this->VisitExpr(wait_group, this->stream); + auto inner = op->body.as(); + ICHECK(inner); + this->VisitStmt(inner->body); + return; + } + CodeGenC::VisitStmt_(op); +} + +void CodeGenHIP::VisitStmt_(const EvaluateNode* op) { + if (is_const_int(op->value)) return; + CodeGenC::VisitStmt_(op); +} + +void CodeGenHIP::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) + int lanes = op->dtype.lanes(); + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && lanes == 4) { + // make_int8x4 + const int64_t* p = as_const_int(op->value); + ICHECK(p); + int64_t v = *p & 0xFF; + v = (v << 24) | (v << 16) | (v << 8) | v; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + return; + } + + if (op->dtype.is_float16()) { + std::string v = PrintExpr(op->value); + PrintVecConstructor(op->dtype, os); + os << "make_"; + //PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) os << ", "; + os << "__pack_half2(" << v << ", " << v << ")"; + } + os << ')'; + return; + } + + if (op->dtype.is_bfloat16()) { + std::string v = PrintExpr(op->value); + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 2; ++i) { + if (i != 0) os << ", "; + os << "__pack_nv_bfloat162(" << v << ", " << v << ")"; + } + os << ')'; + return; + } + + if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) { + bool fail = false; + const int64_t* p = as_const_int(op->value); + ICHECK(p); + int64_t v = *p & 0xF; + + if (lanes == 4) { + v = (v << 12) | (v << 8) | (v << 4) | v; + if (op->dtype.is_uint()) { + os << "(uint16_t)" << v; + } else { + os << "(int16_t)" << v; + } + } else { + v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v; + if (lanes == 8) { + if 
(op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } else if (lanes == 16 || lanes == 32) { + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes / 8; ++i) { + if (i != 0) os << ", "; + if (op->dtype.is_uint()) { + os << "(uint)" << v; + } else { + os << "(int)" << v; + } + } + os << ')'; + } else { + fail = true; + } + } + + if (!fail) { + return; + } + } + + std::string v = PrintExpr(op->value); + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0; i < lanes; ++i) { + if (i != 0) os << ", "; + os << v; + } + os << ')'; +} + +void CodeGenHIP::VisitExpr_(const ShuffleNode* op, std::ostream& os) { + std::vector to_shuffle(op->vectors.size()); + for (int i = 0, e = op->vectors.size(); i < e; ++i) { + ICHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!"; + to_shuffle[i] = PrintExpr(op->vectors[i]); + } + os << "make_"; + PrintType(op->dtype, os); + os << '('; + for (int i = 0, e = op->indices.size(); i < e; ++i) { + const int64_t* val = as_const_int(op->indices[i]); + ICHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size()); + if (i != 0) os << ", "; + os << to_shuffle[*val]; + } + os << ')'; +} + +void CodeGenHIP::VisitExpr_(const SelectNode* op, std::ostream& os) { + // Non-vector cases. + /*if (!op->dtype.is_vector()) { + CodeGenC::VisitExpr_(op, os); + return; + }*/ + if (!op->dtype.is_fixed_length_vector()) { + CodeGenC::VisitExpr_(op, os); + return; + } + + // Codegen vector condition case by serializing the select op. + ICHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype && + op->dtype.lanes() == op->condition.dtype().lanes()); + + std::string r_var = name_supply_->FreshName("_"); + this->PrintIndent(); + this->PrintType(op->dtype, stream); + stream << ' ' << r_var << ";\n"; + { + std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype); + std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype); + std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype); + + // The condition is stored as an ushort vector. + int lanes = op->dtype.lanes(); + DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes); + + for (int i = 0; i < lanes; ++i) { + std::ostringstream item; + item << "(bool("; + PrintVecElemLoad(c_var, memory_ty, i, item); + item << ")?"; + PrintVecElemLoad(t_var, op->dtype, i, item); + item << ':'; + PrintVecElemLoad(f_var, op->dtype, i, item); + item << ')'; + PrintVecElemStore(r_var, op->dtype, i, item.str()); + } + } + os << r_var; +} + +inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenHIP* p) { // NOLINT(*) + // Type code is kBFloat + if (op->dtype.is_bfloat16()) { + os << "__float2bfloat16_rn"; + os << '(' << std::scientific << op->value << 'f' << ')'; + return; + } + // Type code is kFloat + switch (op->dtype.bits()) { + case 64: + case 32: { + std::ostringstream temp; + if (std::isinf(op->value)) { + if (op->value < 0) { + temp << "-"; + } + temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF"); + p->need_math_constants_h_ = true; + } else if (std::isnan(op->value)) { + temp << ((op->dtype.bits() == 32) ? 
"CUDART_NAN_F" : "CUDART_NAN"); + p->need_math_constants_h_ = true; + } else { + temp << std::scientific << op->value; + if (op->dtype.bits() == 32) temp << 'f'; + } + p->MarkConst(temp.str()); + os << temp.str(); + break; + } + case 16: { + os << "__float2half_rn" << '('; + FloatImm const_f32 = FloatImm(DataType::Float(32), op->value); + PrintConst(const_f32.get(), os, p); + os << ')'; + break; + } + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + } +} + +void CodeGenHIP::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) + PrintConst(op, os, this); +} + +void CodeGenHIP::VisitExpr_(const CallNode* op, std::ostream& os) { + if (op->op.same_as(builtin::tvm_fill_fragment())) { + need_wmma_h_ = true; + ICHECK_EQ(op->args.size(), 6U); + os << "rocwmma::fill_fragment("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[5], os); + os << ")"; + } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { + need_wmma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "rocwmma::load_matrix_sync("; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[5], os); + os << ", "; + this->PrintExpr(op->args[6], os); + os << ")"; + } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { + need_wmma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "rocwmma::store_matrix_sync("; + this->PrintExpr(op->args[5], os); + os << ", "; + this->PrintExpr(op->args[0], os); + os << "["; + this->PrintExpr(op->args[4], os); + os << "], "; + this->PrintExpr(op->args[6], os); + if (const StringImmNode* str = op->args[7].as()) { + os << ", rocwmma::mem_" << str->value; + } else { + LOG(FATAL) << "Invalid parameters"; + } + os << ")"; + } else if (op->op.same_as(builtin::tvm_mma_sync())) { + need_wmma_h_ = true; + ICHECK_EQ(op->args.size(), 8U); + os << "rocwmma::mma_sync("; + for (int i = 0; i < 4; ++i) { + this->PrintExpr(op->args[i * 2], os); + os << "["; + this->PrintExpr(op->args[i * 2 + 1], os); + os << "]" << ((i < 3) ? ", " : ")"); + } + } else if (op->op.same_as(builtin::tvm_mfma())) { + // arg 0: prefix: {otype}_16x16x16{itype} + // arg 1: A layout: row/col + // arg 2: B layout: row/col + // arg 3: A precision: float16, float32, ... + // arg 4: B precision: float16, float32, ... + // arg 5: C precision: float32, float64, ... 
+    // arg 6: A multiplicand
+    // arg 7: A multiplicand index
+    // arg 8: B multiplicand
+    // arg 9: B multiplicand index
+    // arg 10: C accumulator
+    // arg 11: C accumulator index
+
+    ICHECK(op->args.size() == 12U) << "Invalid number of arguments for tvm_mfma";
+    std::string prefix = Downcast<StringImm>(op->args[0])->value;
+    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
+    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
+    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
+    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
+    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
+    std::string a_ref = this->PrintExpr(op->args[6]);
+    std::string a_bias = this->PrintExpr(op->args[7]);
+    std::string b_ref = this->PrintExpr(op->args[8]);
+    std::string b_bias = this->PrintExpr(op->args[9]);
+    std::string c_ref = this->PrintExpr(op->args[10]);
+    std::string c_bias = this->PrintExpr(op->args[11]);
+    ICHECK(A_layout == "row" || B_layout == "row") << "Matrix core only supports row-major layout";
+    // Map TIR dtype names to HIP types, e.g. float16 -> half, float32x4 -> float32x4.
+    std::unordered_map<std::string, std::string> dtype_map = {
+        {"int8", "char"},
+        {"int32", "int"},
+        {"int32x4", "int32x4"},
+        {"float16", "half"},
+        {"float32", "float"},
+        {"float64", "double"},
+        {"float16x4", "float16x4"},
+        {"float32x4", "float32x4"},
+        {"float32x16", "float32x16"}};
+    std::string call_mfma_code = R"({
+    *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_builtin}(*((({A_dtype}*){a_ref}) + {a_bias}),
+                                                         *((({B_dtype}*){b_ref}) + {b_bias}),
+                                                         *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
+  })";
+    std::string mfma_builtin = "__builtin_amdgcn_mfma_" + prefix;
+    Replacer replacer;
+    replacer.register_rule("{mfma_builtin}", mfma_builtin);
+    replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
+    replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
+    replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
+    replacer.register_rule("{a_ref}", a_ref);
+    replacer.register_rule("{a_bias}", a_bias);
+    replacer.register_rule("{b_ref}", b_ref);
+    replacer.register_rule("{b_bias}", b_bias);
+    replacer.register_rule("{c_ref}", c_ref);
+    replacer.register_rule("{c_bias}", c_bias);
+    os << replacer.rewrite(call_mfma_code);
+  } else if (op->op.same_as(builtin::tvm_mfma_store())) {
+    int m = Downcast<IntImm>(op->args[0])->value;
+    int n = Downcast<IntImm>(op->args[1])->value;
+    std::string dst = this->PrintExpr(op->args[2]);
+    std::string src = this->PrintExpr(op->args[3]);
+    std::string src_offset = this->PrintExpr(op->args[4]);
+    PrimExpr stride = op->args[5];
+
+    ICHECK((m == 16 && n == 16) || (m == 32 && n == 32))
+        << "Only m == 16 && n == 16 or m == 32 && n == 32 cases are supported for now";
+
+    if (m == 16) {
+      // Each thread in a wavefront holds a fixed number of elements of an MFMA output.
+      // For a 16x16 tile, each of the 64 threads holds 4 elements in its registers, so
+      // conceptually the wavefront's registers form a 64x4 block. The index map below
+      // describes how a 16x16 tile maps onto that 64x4 block.
+
+      // To store the 64x4 register block back to a 16x16 tile in shared or global memory, we
+      // invert this map to determine the output location of each of the 4 elements.
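+      // Illustrative sketch only (hypothetical names, not emitted verbatim): for a 16x16
+      // float32 accumulator fragment c_frag stored to a buffer C with row stride C_s, the
+      // loop generated below has roughly this shape, where dst_index is the inverse of the
+      // shared_16x16_to_ldmatrix_64x4_layout map applied to (threadIdx.x, local_id):
+      //
+      //   for (int local_id = 0; local_id < 4; ++local_id) {
+      //     C[dst_index(threadIdx.x, local_id)] = c_frag[c_offset + local_id];
+      //   }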
+      const auto* index_map_func =
+          runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_64x4_layout");
+      ICHECK(index_map_func);
+
+      arith::Analyzer analyzer;
+      auto inverse_index_map =
+          IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}, &analyzer);
+      auto indices_16x16 = inverse_index_map->final_indices;
+
+      // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are
+      // fine. FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually
+      // replace them with the plain ones here.
+      class LowerFloorDivMod : public ExprMutator {
+       public:
+        PrimExpr VisitExpr_(const FloorDivNode* op) {
+          return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
+        }
+        PrimExpr VisitExpr_(const FloorModNode* op) {
+          return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
+        }
+      };
+
+      auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
+
+      var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
+      var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
+
+      os << "for (int local_id = 0; local_id < 4; ++local_id) {\n";
+      os << dst << "[" + this->PrintExpr(dst_ind) + "]"
+         << " = " << src << "[" << src_offset << " + local_id];\n";
+      os << "}\n";
+    } else {
+      // Each thread in a wavefront holds a fixed number of elements of an MFMA output.
+      // For a 32x32 tile, each of the 64 threads holds 16 elements in its registers, so
+      // conceptually the wavefront's registers form a 64x16 block. The index map below
+      // describes how a 32x32 tile maps onto that 64x16 block.
+
+      // To store the 64x16 register block back to a 32x32 tile in shared or global memory, we
+      // invert this map to determine the output location of each of the 16 elements.
+      const auto* index_map_func =
+          runtime::Registry::Get("tir.index_map.shared_32x32_to_ldmatrix_64x16_layout");
+      ICHECK(index_map_func);
+
+      arith::Analyzer analyzer;
+      auto inverse_index_map =
+          IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}, &analyzer);
+      auto indices_32x32 = inverse_index_map->final_indices;
+
+      // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are
+      // fine. FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually
+      // replace them with the plain ones here.
+      class LowerFloorDivMod : public ExprMutator {
+       public:
+        PrimExpr VisitExpr_(const FloorDivNode* op) {
+          return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
+        }
+        PrimExpr VisitExpr_(const FloorModNode* op) {
+          return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
+        }
+      };
+
+      auto dst_ind = LowerFloorDivMod()(indices_32x32[0] * stride + indices_32x32[1]);
+
+      var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
+      var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
+
+      os << "for (int local_id = 0; local_id < 16; ++local_id) {\n";
+      os << dst << "[" + this->PrintExpr(dst_ind) + "]"
+         << " = " << src << "[" << src_offset << " + local_id];\n";
+      os << "}\n";
+    }
+
+  } else if (op->op.same_as(builtin::tvm_rdna_wmma())) {
+    // arg 0: prefix: {otype}_16x16x16{itype}
+    // arg 1: A layout: row/col
+    // arg 2: B layout: row/col
+    // arg 3: A precision: float16, float32, ...
+    // arg 4: B precision: float16, float32, ...
+    // arg 5: C precision: float32, float64, ...
+    // arg 6: A multiplicand
+    // arg 7: A multiplicand index
+    // arg 8: B multiplicand
+    // arg 9: B multiplicand index
+    // arg 10: C accumulator
+    // arg 11: C accumulator index
+
+    ICHECK(op->args.size() == 12U) << "Invalid number of arguments for tvm_rdna_wmma";
+    std::string prefix = Downcast<StringImm>(op->args[0])->value;
+    std::string A_layout = Downcast<StringImm>(op->args[1])->value;
+    std::string B_layout = Downcast<StringImm>(op->args[2])->value;
+    std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
+    std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
+    std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
+    std::string a_ref = this->PrintExpr(op->args[6]);
+    std::string a_bias = this->PrintExpr(op->args[7]);
+    std::string b_ref = this->PrintExpr(op->args[8]);
+    std::string b_bias = this->PrintExpr(op->args[9]);
+    std::string c_ref = this->PrintExpr(op->args[10]);
+    std::string c_bias = this->PrintExpr(op->args[11]);
+    ICHECK(A_layout == "row" || B_layout == "row") << "Matrix core only supports row-major layout";
+    // Map TIR dtype names to HIP types, e.g. float16 -> half, float32x4 -> float32x4.
+    std::unordered_map<std::string, std::string> dtype_map = {
+        {"int8", "char"},
+        {"int32", "int"},
+        {"int32x4", "int32x4"},
+        {"float16", "half"},
+        {"float32", "float"},
+        {"float64", "double"},
+        {"float16x4", "float16x4"},
+        {"float16x8", "float16x16"},
+        {"float16x16", "float16x16"},
+        {"float32x4", "float32x4"},
+        {"float32x16", "float32x16"}};
+    std::string call_wmma_code = R"({
+    *((({C_dtype}*){c_ref}) + {c_bias}) = {wmma_builtin}(*((({A_dtype}*){a_ref}) + {a_bias}),
+                                                         *((({B_dtype}*){b_ref}) + {b_bias}),
+                                                         *((({C_dtype}*){c_ref}) + {c_bias}), false);
+  })";
+    std::string wmma_builtin = "__builtin_amdgcn_wmma_" + prefix;
+    Replacer replacer;
+    replacer.register_rule("{wmma_builtin}", wmma_builtin);
+    replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
+    replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
+    replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
+    replacer.register_rule("{a_ref}", a_ref);
+    replacer.register_rule("{a_bias}", a_bias);
+    replacer.register_rule("{b_ref}", b_ref);
+    replacer.register_rule("{b_bias}", b_bias);
+    replacer.register_rule("{c_ref}", c_ref);
+    replacer.register_rule("{c_bias}", c_bias);
+    os << replacer.rewrite(call_wmma_code);
+  } else if (op->op.same_as(builtin::tvm_rdna_wmma_store())) {
+    int m = Downcast<IntImm>(op->args[0])->value;
+    int n = Downcast<IntImm>(op->args[1])->value;
+    std::string dst = this->PrintExpr(op->args[2]);
+    std::string src = this->PrintExpr(op->args[3]);
+    std::string src_offset = this->PrintExpr(op->args[4]);
+    PrimExpr stride = op->args[5];
+
+    ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 is supported for now";
+
+    if (m == 16) {
+      // Each thread in an RDNA wave holds a fixed number of elements of a WMMA output.
+      // For a 16x16 tile computed by a wave of 32 threads, each thread holds 8 elements in its
+      // registers, so conceptually the wave's registers form a 32x8 block. The index map below
+      // describes how a 16x16 tile maps onto that block.
+
+      // To store the register block back to a 16x16 tile in shared or global memory, we use this
+      // map to determine the output location of each of the 8 elements.
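+      // Illustrative sketch only (hypothetical names, not emitted verbatim): for a 16x16
+      // accumulator fragment c_frag written to a buffer C with row stride C_s, the loop
+      // generated below looks roughly like
+      //
+      //   for (int local_id = 0; local_id < 8; ++local_id) {
+      //     C[dst_index(threadIdx.x, local_id)] = c_frag[c_offset + local_id * 2];
+      //   }
+      //
+      // where dst_index comes from the shared_16x16_to_ldmatrix_16x16_layout index map and the
+      // "* 2" stride reflects that only every other element of the accumulator fragment is read.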
+      const auto* index_map_func =
+          runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_16x16_layout");
+      ICHECK(index_map_func);
+
+      auto index_map = IndexMap::FromFunc(2, *index_map_func);
+      auto indices_16x16 = index_map->final_indices;
+
+      // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are
+      // fine. FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually
+      // replace them with the plain ones here.
+      class LowerFloorDivMod : public ExprMutator {
+       public:
+        PrimExpr VisitExpr_(const FloorDivNode* op) {
+          return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
+        }
+        PrimExpr VisitExpr_(const FloorModNode* op) {
+          return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
+        }
+      };
+
+      auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
+
+      var_idmap_[index_map->initial_indices[0].get()] = "threadIdx.x";
+      var_idmap_[index_map->initial_indices[1].get()] = "local_id";
+
+      os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
+      os << dst << "[" + this->PrintExpr(dst_ind) + "]"
+         << " = " << src << "[" << src_offset << " + local_id * 2];\n";
+      os << "}\n";
+    }
+  } else {
+    CodeGenC::VisitExpr_(op, os);
+  }
+}
+
+void CodeGenHIP::VisitStmt_(const AllocateNode* op) {
+  ICHECK(!is_zero(op->condition));
+  std::string vid = AllocVarID(op->buffer_var.get());
+
+  this->PrintIndent();
+  std::string scope = GetPtrStorageScope(op->buffer_var);
+  const VarNode* buffer = op->buffer_var.as<VarNode>();
+  if (scope.find("wmma.") == 0) {
+    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
+      ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
+             op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
+             op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
+             op->dtype == DataType::BFloat(16))
+          << "Matrix_a and matrix_b only support half, bfloat16, char, unsigned char, "
+          << "int4, uint4, or int1 type for now";
+    } else {
+      ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
+             op->dtype == DataType::Int(32))
+          << "Accumulator only supports half, float and int type for now";
+    }
+    PrintWmmaScope(scope, op->dtype, buffer, stream);
+  } else {
+    PrintStorageScope(scope, stream);
+    PrintType(op->dtype, stream);
+  }
+
+  if (scope == "shared.dyn") {
+    stream << ' ' << vid << "[];\n";
+  } else {
+    size_t constant_size = op->ConstantAllocationSize();
+    ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
+
+    if (scope.find("wmma.") == 0) {
+      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
+    }
+    if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
+         op->dtype == DataType::Int(1)) &&
+        scope == "shared") {
+      constant_size = constant_size / (32 / op->dtype.bits());
+    }
+    stream << ' ' << vid << '[' << constant_size << "];\n";
+  }
+
+  RegisterHandleType(op->buffer_var.get(), op->dtype);
+  this->PrintStmt(op->body);
+}
+
+void CodeGenHIP::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
+                                 bool skip_first_arg, std::ostream& os) {  // NOLINT(*)
+  DataType ret_dtype = GetRuntimeDataType(ret_type);
+  if (ret_dtype.is_fixed_length_vector()) {
+    //
+    // Emit an unsupported vector call
+    //
+    // v = intrin_f((float4*)A[0], (float4*)B[0])
+    //
+    // as
+    //
+    // float4 __ret;
+    // {
+    //   float4 __arg0 = ((float4*)A)[0];
+    //   float4 __arg1 = ((float4*)B)[0];
+    //   __ret.x = intrin_f(__arg0.x, __arg1.x);
+    //   __ret.y
= intrin_f(__arg0.y, __arg1.y);
+    //   __ret.z = intrin_f(__arg0.z, __arg1.z);
+    //   __ret.w = intrin_f(__arg0.w, __arg1.w);
+    // }
+    // v = __ret;
+    //
+    // Declare the result vector.
+    std::string sret = name_supply_->FreshName("_");
+    this->PrintIndent();
+    this->PrintType(ret_dtype, stream);
+    stream << ' ' << sret << ";\n";
+    {
+      // Load arguments.
+      std::vector<std::string> sargs;
+      size_t arg_begin = static_cast<size_t>(skip_first_arg);
+      for (size_t i = arg_begin; i < args.size(); ++i) {
+        std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype());
+        sargs.push_back(std::move(val));
+      }
+
+      // Emit a scalar call for each lane.
+      for (int i = 0; i < ret_dtype.lanes(); ++i) {
+        std::ostringstream scall;
+        scall << global_symbol << "(";
+        for (size_t j = 0; j < sargs.size(); ++j) {
+          if (j > 0) scall << ", ";
+          PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
+        }
+        scall << ")";
+        PrintVecElemStore(sret, ret_dtype, i, scall.str());
+      }
+    }
+    os << sret;
+  } else {
+    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
+  }
+}
+
+void CodeGenHIP::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
+                                std::ostream& os) {
+  std::stringstream type;
+  PrintType(t, type);
+  ICHECK(fragment_shapes.count(variable))
+      << "Cannot find shape of the wmma fragment " << variable->name_hint;
+  std::string shape_str = fragment_shapes.at(variable);
+  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
+    type.str(std::string());
+    if (t.is_int()) {
+      if (t.bits() == 4) {
+        type << "rocwmma::experimental::precision::s4";
+      } else if (t.bits() == 1) {
+        type << "rocwmma::experimental::precision::b1";
+      } else {
+        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
+      }
+    } else if (t.is_uint()) {
+      if (t.bits() == 4) {
+        type << "rocwmma::experimental::precision::u4";
+      } else {
+        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
+      }
+    }
+  }
+  if (scope == "wmma.matrix_a") {
+    need_wmma_h_ = true;
+    std::string layout_str = fragment_layouts[variable];
+    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
+    os << "rocwmma::fragment<rocwmma::matrix_a, " << shape_str << ", " << type.str()
+       << ", rocwmma::" << layout_str << ">";
+  } else if (scope == "wmma.matrix_b") {
+    need_wmma_h_ = true;
+    std::string layout_str = fragment_layouts[variable];
+    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
+    os << "rocwmma::fragment<rocwmma::matrix_b, " << shape_str << ", " << type.str()
+       << ", rocwmma::" << layout_str << ">";
+  } else if (scope == "wmma.accumulator") {
+    need_wmma_h_ = true;
+    os << "rocwmma::fragment<rocwmma::accumulator, " << shape_str << ", " << type.str() << ">";
+  }
+}
+
+int32_t CodeGenHIP::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable,
+                                        int32_t size) {
+  ICHECK(fragment_shapes.count(variable))
+      << "Cannot find shape of the wmma fragment " << variable->name_hint;
+
+  auto stoi = [](const std::string& str) {
+    try {
+      return std::stoi(str);
+    } catch (std::invalid_argument& e) {
+      LOG(FATAL) << "Cannot convert \"" << str << "\" to int";
+      throw;
+    }
+  };
+
+  std::string shape_str = fragment_shapes.at(variable);
+  size_t m, n, k;
+  size_t last_pos = 0, pos = 0;
+  pos = shape_str.find(", ", last_pos);
+  m = stoi(shape_str.substr(last_pos, pos - last_pos));
+  last_pos = pos + 2;
+  pos = shape_str.find(", ", last_pos);
+  n = stoi(shape_str.substr(last_pos, pos - last_pos));
+  last_pos = pos + 2;
+  k = stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
+  if (scope == "wmma.matrix_a") {
+    return size / m / k;
+  } else if (scope == "wmma.matrix_b") {
+    return size / n / k;
+  } else if (scope == "wmma.accumulator") {
+    return size / m / n;
+  }
+  return 0;
+}
+
+void
CodeGenHIP::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, + std::ostream& os) { + ICHECK_GT(t.lanes(), 1); + if (t.bits() == 8 && (t.is_int() || t.is_uint())) { + if (!(t.lanes() == 2 || t.lanes() == 3)) { + if (i != 0) { + os << "|"; + } + os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; + return; + } + } + + if (t.is_float16()) { + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_half2(" << value; + } else { + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + } + return; + } + + if (t.is_bfloat16()) { + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << '('; + } + if (i % 2 == 0) { + os << "__pack_bfloat162(" << value; + } else { + os << "," << value << ")"; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + } + return; + } + + if (i == 0) { + os << "make_"; + PrintType(t, os); + os << "("; + } + os << value; + if (i != t.lanes() - 1) { + os << ","; + } else { + os << ")"; + } + return; +} + +} // namespace codegen +} // namespace tvm diff --git a/src/target/source/codegen_hip.h b/src/target/source/codegen_hip.h new file mode 100644 index 000000000000..87a99647d3d9 --- /dev/null +++ b/src/target/source/codegen_hip.h @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file codegen_hip.h
+ * \brief Utility to generate hip source code
+ */
+#ifndef TVM_TARGET_SOURCE_CODEGEN_HIP_H_
+#define TVM_TARGET_SOURCE_CODEGEN_HIP_H_
+
+#include <tvm/target/codegen.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+
+#include <string>
+#include <unordered_map>
+
+#include "codegen_c.h"
+#include "codegen_cuda.h"
+
+namespace tvm {
+namespace codegen {
+
+class CodeGenHIP final : public CodeGenC {
+ public:
+  CodeGenHIP();
+  void Init(bool output_ssa);
+  std::string Finish();
+  bool need_include_path() { return (need_math_constants_h_ || need_wmma_h_); }
+  // override behavior
+  void PrintFuncPrefix(std::ostream& os) final;
+  void VisitStmt_(const ForNode* op) final;
+  void PrintStorageSync(const CallNode* op) final;
+  void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
+  void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
+                        std::ostream& os) {  // NOLINT(*)
+    cuda_codegen_.PrintVecBinaryOp(op, t, lhs, rhs, os);
+  }
+  void PrintType(DataType t, std::ostream& os) final;  // NOLINT(*)
+  void PrintVecElemLoad(const std::string& vec, DataType t, int i,
+                        std::ostream& os) {  // NOLINT(*)
+    return cuda_codegen_.PrintVecElemLoad(vec, t, i, os);
+  }
+  void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) {
+    return cuda_codegen_.PrintVecElemStore(vec, t, i, value);
+  }
+  void BindThreadIndex(const IterVar& iv) final;  // NOLINT(*)
+  void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
+  std::string CastFromTo(std::string value, DataType from, DataType target) {
+    return cuda_codegen_.CastFromTo(value, from, target);
+  }
+  // overload visitor
+  void VisitExpr_(const RampNode* op, std::ostream& os) {  // NOLINT(*)
+    return cuda_codegen_.VisitExpr_(op, os);
+  }
+  void VisitExpr_(const ShuffleNode* op, std::ostream& os) final;    // NOLINT(*)
+  void VisitExpr_(const SelectNode* op, std::ostream& os) final;     // NOLINT(*)
+  void VisitExpr_(const BroadcastNode* op, std::ostream& os) final;  // NOLINT(*)
+  void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
+  void VisitExpr_(const CallNode* op, std::ostream& os) final;
+  void VisitExpr_(const CastNode* op, std::ostream& os) final;
+  void VisitStmt_(const EvaluateNode* op) final;
+  void VisitStmt_(const AllocateNode* op) final;
+  void VisitStmt_(const AttrStmtNode* op) final;
+
+ protected:
+  void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
+                       bool skip_first_arg, std::ostream& os) final;  // NOLINT(*)
+
+ private:
+  CodeGenCUDA cuda_codegen_;
+  // Whether global barrier is needed.
+  bool need_global_barrier_{false};
+  // Global barrier state
+  std::string vid_global_barrier_state_;
+  // Global barrier expected node.
+  std::string vid_global_barrier_expect_;
+
+  // whether math_constants.h is needed
+  bool need_math_constants_h_{false};
+  // whether rocwmma.h is needed
+  bool need_wmma_h_{false};
+  // whether fp16 is enabled
+  bool enable_fp16_{false};
+  // whether bf16 is enabled
+  bool enable_bf16_{false};
+  // whether int8 is enabled
+  bool enable_int8_{false};
+  std::unordered_map<const VarNode*, std::string> fragment_shapes;
+  std::unordered_map<const VarNode*, std::string> fragment_layouts;
+  friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenHIP* p);
+  void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
+                      std::ostream& os);
+  int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size);
+};
+
+}  // namespace codegen
+}  // namespace tvm
+
+#endif  // TVM_TARGET_SOURCE_CODEGEN_HIP_H_
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 708d3ccd7621..2a06c369277e 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -193,6 +193,30 @@ TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
   return target;
 }
 
+/*!
+ * \brief Update the attributes in the HIP target.
+ * \param target The Target to update
+ * \return The updated attributes
+ */
+TargetJSON UpdateHIPAttrs(TargetJSON target) {
+  using tvm::runtime::Registry;
+  // Update -mcpu=gfx
+  std::string arch = "gfx900";
+  if (target.count("mcpu")) {
+    String mcpu = Downcast<String>(target.at("mcpu"));
+    arch = ExtractStringWithPrefix(mcpu, "gfx");
+    ICHECK(!arch.empty()) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
+  } else {
+    TVMRetValue val;
+    if (const auto* f_get_rocm_arch = Registry::Get("tvm_callback_rocm_get_arch")) {
+      arch = (*f_get_rocm_arch)().operator std::string();
+    }
+  }
+  target.Set("mcpu", String(arch));
+  LOG(INFO) << "HIP target uses -mcpu=" << arch;
+  return target;
+}
+
 /*!
  * \brief Update the attributes in the LLVM ROCm target.
  * \param target The Target to update
@@ -326,6 +350,17 @@ TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
     .set_default_keys({"cuda", "gpu"})
     .set_target_parser(UpdateNVPTXAttrs);
 
+TVM_REGISTER_TARGET_KIND("hip", kDLROCM)
+    .add_attr_option<String>("mcpu")
+    // TODO(lei/masihi): Support querying from a target device
+    // On RDNA cards, thread_warp_size should be 32
+    .add_attr_option<Integer>("max_num_threads", Integer(256))
+    .add_attr_option<Integer>("max_threads_per_block", Integer(256))
+    .add_attr_option<Integer>("max_shared_memory_per_block", Integer(65536))
+    .add_attr_option<Integer>("thread_warp_size", Integer(64))
+    .set_default_keys({"rocm", "gpu"})
+    .set_target_parser(UpdateHIPAttrs);
+
 TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
     .add_attr_option<String>("mcpu")
     .add_attr_option<String>("mtriple")
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index fbe31c890dad..a55f4532e8b1 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -327,6 +327,18 @@ TIR_DEFINE_BUILTIN_FUNC(mma_fill)
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
                                          Integer(ScriptDtypePrintLocation::kFirst));
 
+TIR_DEFINE_BUILTIN_FUNC(tvm_mfma)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_mfma_store)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_rdna_wmma)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(tvm_rdna_wmma_store)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",