
Commit

[Dev][AMD]add tensor_
Cunxiao2002 committed Oct 29, 2024
1 parent 072a5a1 commit e04eecd
Showing 2 changed files with 130 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/tir/tensor_intrin/__init__.py
@@ -16,4 +16,4 @@
# under the License.
# pylint: disable=unused-import
"""Intrinsics for tensorization."""
from . import arm_cpu, cuda, rocm, x86, hexagon
from . import arm_cpu, cuda, rocm, x86, hexagon, hip
129 changes: 129 additions & 0 deletions python/tvm/tir/tensor_intrin/hip.py
@@ -1,8 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tvm.tir
from tvm.runtime import convert
from tvm.tir.expr import Cast, IntImm
from tvm.tir.function import TensorIntrin
from tvm.script import tir as T
from typing import Dict, Optional, Tuple, Literal, List

lift = convert

@@ -439,3 +442,129 @@ def mfma_store_impl(a: T.handle, c: T.handle) -> None:
TensorIntrin.register(
HIP_MFMA_STORE_16x16_f32_INTRIN, *get_mfma_store_intrin(4, "float32", "global")
)

def get_mfma_intrin_group(
    load_scope: Literal["shared", "shared.dyn"] = "shared",
    store_scope: Literal["global", "shared", "shared.dyn"] = "global",
    a_dtype: Literal["float16", "int8", "bfloat16", "e4m3_float8", "e5m2_float8"] = "float16",
    b_dtype: Literal["float16", "int8", "bfloat16", "e4m3_float8", "e5m2_float8"] = "float16",
    out_dtype: Literal["float16", "float32", "int32"] = "float16",
    trans_a: bool = False,
    trans_b: bool = False,
    not_use_mfma_store_intrinic: bool = True,
    store_to_smem_dtype: Optional[Literal["float16", "float32", "int32"]] = None,
) -> Dict[str, str]:
"""Get a group of intrinsics for mma tensor core with the given configurations
Parameters
----------
load_scope : Literal["shared", "shared.dyn"]
The memory scope of the input buffer.
store_scope : Literal["global", "shared", "shared.dyn"]
The memory scope of the result buffer.
a_dtype : str
The dtype of the input matrix A.
b_dtype : str
The dtype of the input matrix B.
out_dtype : str
The output data dtype.
trans_b : bool
Whether the input matrix B is transposed.
not_use_mma_store_intrinic : bool
Whether to not use the mma_store intrinsic. If True, use BufferStore stmts to store the
result of mma. Otherwise, use mfma_store intrinsic.
This is because if we use mfma_store intrinsic, during swizzling shared memory visits, our
rearrangement scheme will involve areas accessed by different mma_store calls. This makes
swizzling quite complex. But BufferStore will not face this problem.
store_to_smem_dtype : Optional[Literal["float16", "float32", "int32"]]
The dtype that we use to store from register to shared memory. By default it is out_dtype.
Returns
-------
ret : Dict[str, str]
A group of tensor intrinsics.
"""
    assert load_scope in ["shared", "shared.dyn"]
    assert store_scope in ["global", "shared", "shared.dyn"]
    assert a_dtype in ["float16", "bfloat16", "int8", "e4m3_float8", "e5m2_float8"]
    assert b_dtype in ["float16", "bfloat16", "int8", "e4m3_float8", "e5m2_float8"]
    assert out_dtype in ["float16", "float32", "int32"]

    shape = "16x16"

    dtype_mapping = {
        "float16": "f16",
        "bfloat16": "bf16",
        "float32": "f32",
        "int8": "i8",
        "e4m3_float8": "e4m3",
        "e5m2_float8": "e5m2",
        "int32": "i32",
    }
    a_dtype = dtype_mapping[a_dtype]
    b_dtype = dtype_mapping[b_dtype]
    out_dtype = dtype_mapping[out_dtype]

    # e.g. HIP_mfma_fill_16x16_f32
    init_intrin = f"HIP_mfma_fill_{shape}_{out_dtype}"

    # TODO: should change these
    # e.g. hip_mfma_load_16x4_a_shared_f32, hip_mfma_load_16x16_a_shared_s8
    trans_a = "_trans" if trans_a else ""
    trans_b = "_trans" if trans_b else ""
    if a_dtype == "f32":
        load_a_intrin = f"hip_mfma_load_16x4_a_shared_{out_dtype}"
    else:
        load_a_intrin = f"hip_mfma_load_16x16_a_shared_{out_dtype}"

    if b_dtype == "f32":
        load_b_intrin = f"hip_mfma_load_b_16x4_shared_{out_dtype}"
    else:
        load_b_intrin = f"hip_mfma_load_b_16x16_shared_{out_dtype}"

    # e.g. hip_mfma_f32f32f32
    compute_intrin = f"hip_mfma_{a_dtype}{b_dtype}{out_dtype}"

    # e.g. hip_mfma_store_16x16_s32
    # store_scope = store_scope.replace(".", "_")
    # store_to_smem_dtype = dtype_mapping[store_to_smem_dtype] if store_to_smem_dtype else out_dtype
    store_intrin = f"hip_mfma_store_{shape}_{a_dtype}"

    index_map_c = shared_16x16_to_local_64x4_layout_C
    if a_dtype in ["f16", "bf16"]:
        index_map_a = shared_16x16_to_local_64x4_layout_A
        index_map_b = shared_16x16_to_local_64x4_layout_B
    elif a_dtype in ["i8", "e4m3", "e5m2"]:
        index_map_a = shared_16x4_to_local_64x1_layout_A
        index_map_b = shared_4x16_to_local_64x1_layout_B
    else:
        raise ValueError(f"Unsupported a_dtype: {a_dtype}")

    # micro kernel size, in the order [m, n, k]
    micro_kernel: List[int]
    if a_dtype in ["f16", "bf16"]:
        micro_kernel = [16, 16, 16]
    elif a_dtype in ["i8", "e4m3", "e5m2"]:
        micro_kernel = [16, 16, 32]
    else:
        raise ValueError(f"Unsupported a_dtype: {a_dtype}")

    return {
        "init": init_intrin,
        "load_a": load_a_intrin,
        "load_b": load_b_intrin,
        "compute": compute_intrin,
        "store": store_intrin,
        "index_map": [index_map_a, index_map_b, index_map_c],
        "micro_kernel": micro_kernel,
    }
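
A minimal usage sketch of the new helper (not part of this commit). It assumes the module is importable as tvm.tir.tensor_intrin.hip once this change is applied; the TensorIR schedule, the block name "C_mfma", and the loop index in the commented-out lines are illustrative assumptions, not taken from this PR.

from tvm.tir.tensor_intrin.hip import get_mfma_intrin_group

# Look up the intrinsic names for an f16 x f16 -> f32, 16x16x16 MFMA micro kernel.
group = get_mfma_intrin_group(
    load_scope="shared",
    store_scope="global",
    a_dtype="float16",
    b_dtype="float16",
    out_dtype="float32",
)
print(group["compute"])       # "hip_mfma_f16f16f32", per the dtype mapping above
print(group["micro_kernel"])  # [16, 16, 16] for float16 inputs

# Hypothetical schedule usage: `sch` would be a tvm.tir.Schedule whose matmul block has
# already been tiled down to the 16x16x16 micro kernel before tensorizing.
# sch.tensorize(sch.get_loops(sch.get_block("C_mfma"))[-3], group["compute"])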
