From e04eecd96b54021f78a461139881ae76b0eed8e2 Mon Sep 17 00:00:00 2001
From: Cunxiao2002 <972845868@qq.com>
Date: Tue, 29 Oct 2024 02:45:53 +0000
Subject: [PATCH] [Dev][AMD]add tensor_

---
 python/tvm/tir/tensor_intrin/__init__.py |   2 +-
 python/tvm/tir/tensor_intrin/hip.py      | 130 +++++++++++++++++++++++
 2 files changed, 131 insertions(+), 1 deletion(-)

diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py
index 7e5a26bdeb43..c4a3368258e7 100644
--- a/python/tvm/tir/tensor_intrin/__init__.py
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -16,4 +16,4 @@
 # under the License.
 # pylint: disable=unused-import
 """Intrinsics for tensorization."""
-from . import arm_cpu, cuda, rocm, x86, hexagon
+from . import arm_cpu, cuda, rocm, x86, hexagon, hip
diff --git a/python/tvm/tir/tensor_intrin/hip.py b/python/tvm/tir/tensor_intrin/hip.py
index 18f36677118e..a82f53091bcf 100644
--- a/python/tvm/tir/tensor_intrin/hip.py
+++ b/python/tvm/tir/tensor_intrin/hip.py
@@ -1,8 +1,11 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
 import tvm.tir
 from tvm.runtime import convert
 from tvm.tir.expr import Cast, IntImm
 from tvm.tir.function import TensorIntrin
 from tvm.script import tir as T
+from typing import Any, Dict, List, Literal, Optional
 
 lift = convert
 
@@ -439,3 +442,130 @@ def mfma_store_impl(a: T.handle, c: T.handle) -> None:
 TensorIntrin.register(
     HIP_MFMA_STORE_16x16_f32_INTRIN, *get_mfma_store_intrin(4, "float32", "global")
 )
+
+def get_mfma_intrin_group(
+    load_scope: Literal["shared", "shared.dyn"] = "shared",
+    store_scope: Literal["global", "shared", "shared.dyn"] = "global",
+    a_dtype: Literal["float16", "int8", "bfloat16", "e4m3_float8", "e5m2_float8"] = "float16",
+    b_dtype: Literal["float16", "int8", "bfloat16", "e4m3_float8", "e5m2_float8"] = "float16",
+    out_dtype: Literal["float16", "float32", "int32"] = "float16",
+    trans_a: bool = False,
+    trans_b: bool = False,
+    not_use_mfma_store_intrinic: bool = True,
+    store_to_smem_dtype: Optional[Literal["float16", "float32", "int32"]] = None,
+) -> Dict[str, Any]:
+    """Get a group of intrinsics for the AMD MFMA tensor core with the given configuration.
+
+    Parameters
+    ----------
+    load_scope : Literal["shared", "shared.dyn"]
+        The memory scope of the input buffers.
+
+    store_scope : Literal["global", "shared", "shared.dyn"]
+        The memory scope of the result buffer.
+
+    a_dtype : str
+        The dtype of the input matrix A.
+
+    b_dtype : str
+        The dtype of the input matrix B.
+
+    out_dtype : str
+        The dtype of the output matrix.
+
+    trans_a : bool
+        Whether the input matrix A is transposed.
+
+    trans_b : bool
+        Whether the input matrix B is transposed.
+
+    not_use_mfma_store_intrinic : bool
+        Whether to avoid the mfma_store intrinsic. If True, use BufferStore stmts to store the
+        result of mfma; otherwise, use the mfma_store intrinsic.
+
+        This is because if we use the mfma_store intrinsic, the rearrangement scheme for
+        swizzled shared-memory accesses would span areas touched by different mfma_store
+        calls, which makes swizzling quite complex. BufferStore does not have this problem.
+
+    store_to_smem_dtype : Optional[Literal["float16", "float32", "int32"]]
+        The dtype used when storing from registers to shared memory. Defaults to out_dtype.
+
+    Returns
+    -------
+    ret : Dict[str, Any]
+        A group of tensor intrinsics.
+    """
+    assert load_scope in ["shared", "shared.dyn"]
+    assert store_scope in ["global", "shared", "shared.dyn"]
+    assert a_dtype in ["float16", "bfloat16", "int8", "e4m3_float8", "e5m2_float8"]
+    assert b_dtype in ["float16", "bfloat16", "int8", "e4m3_float8", "e5m2_float8"]
+    assert out_dtype in ["float16", "float32", "int32"]
+
+    shape = "16x16"
+
+    dtype_mapping = {
+        "float16": "f16",
+        "bfloat16": "bf16",
+        "float32": "f32",
+        "int8": "i8",
+        "e4m3_float8": "e4m3",
+        "e5m2_float8": "e5m2",
+        "int32": "i32",
+    }
+    a_dtype = dtype_mapping[a_dtype]
+    b_dtype = dtype_mapping[b_dtype]
+    out_dtype = dtype_mapping[out_dtype]
+
+    # e.g. HIP_mfma_fill_16x16_f32
+    init_intrin = f"HIP_mfma_fill_{shape}_{out_dtype}"
+
+    # TODO: the trans_a/trans_b suffixes are not yet applied to the load intrinsic names below.
+    # e.g. hip_mfma_load_16x4_a_shared_f32, hip_mfma_load_16x16_a_shared_f16
+    trans_a = "_trans" if trans_a else ""
+    trans_b = "_trans" if trans_b else ""
+    if a_dtype == "f32":
+        load_a_intrin = f"hip_mfma_load_16x4_a_shared_{a_dtype}"
+    else:
+        load_a_intrin = f"hip_mfma_load_16x16_a_shared_{a_dtype}"
+
+    if b_dtype == "f32":
+        load_b_intrin = f"hip_mfma_load_b_16x4_shared_{b_dtype}"
+    else:
+        load_b_intrin = f"hip_mfma_load_b_16x16_shared_{b_dtype}"
+
+    # e.g. hip_mfma_f16f16f32
+    compute_intrin = f"hip_mfma_{a_dtype}{b_dtype}{out_dtype}"
+
+    # e.g. hip_mfma_store_16x16_f32
+    # store_scope = store_scope.replace(".", "_")
+    # store_to_smem_dtype = dtype_mapping[store_to_smem_dtype] if store_to_smem_dtype else out_dtype
+    store_intrin = f"hip_mfma_store_{shape}_{out_dtype}"
+
+    index_map_c = shared_16x16_to_local_64x4_layout_C
+    if a_dtype in ["f16", "bf16"]:
+        index_map_a = shared_16x16_to_local_64x4_layout_A
+        index_map_b = shared_16x16_to_local_64x4_layout_B
+    elif a_dtype in ["i8", "e4m3", "e5m2"]:
+        index_map_a = shared_16x4_to_local_64x1_layout_A
+        index_map_b = shared_4x16_to_local_64x1_layout_B
+    else:
+        raise ValueError(f"Unsupported a_dtype: {a_dtype}")
+
+    # micro kernel size, the order is [m, n, k]
+    micro_kernel: List[int]
+    if a_dtype in ["f16", "bf16"]:
+        micro_kernel = [16, 16, 16]
+    elif a_dtype in ["i8", "e4m3", "e5m2"]:
+        micro_kernel = [16, 16, 32]
+    else:
+        raise ValueError(f"Unsupported a_dtype: {a_dtype}")
+
+    return {
+        "init": init_intrin,
+        "load_a": load_a_intrin,
+        "load_b": load_b_intrin,
+        "compute": compute_intrin,
+        "store": store_intrin,
+        "index_map": [index_map_a, index_map_b, index_map_c],
+        "micro_kernel": micro_kernel,
+    }
\ No newline at end of file
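
Usage sketch (for review context, not part of the patch): a minimal example of how the new
helper is expected to be consumed, assuming the patch is applied and the HIP intrinsic names
it constructs (e.g. "hip_mfma_f16f16f32") are registered earlier in hip.py. The schedule
object `sch`, the loop variables, and the tensorize calls at the end are illustrative
assumptions about the surrounding scheduling code, not something this change introduces.

    from tvm.tir.tensor_intrin.hip import get_mfma_intrin_group

    # Pick the fp16 x fp16 -> fp32 group with shared-memory loads and a global store.
    group = get_mfma_intrin_group(
        load_scope="shared",
        store_scope="global",
        a_dtype="float16",
        b_dtype="float16",
        out_dtype="float32",
    )

    # The returned dict carries intrinsic names plus layout and tile metadata.
    print(group["compute"])       # "hip_mfma_f16f16f32"
    print(group["micro_kernel"])  # [16, 16, 16], i.e. the m, n, k tile handled per MFMA
    index_map_a, index_map_b, index_map_c = group["index_map"]

    # With a tir.Schedule `sch` whose GEMM block is already tiled to the 16x16x16 shape
    # expected by the intrinsics (hypothetical loops, depends on the schedule):
    # sch.tensorize(loop_a, group["load_a"])
    # sch.tensorize(loop_b, group["load_b"])
    # sch.tensorize(loop_compute, group["compute"])
    # sch.tensorize(loop_store, group["store"])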