[TIR] Extend DP4A tensor intrin (#16293)
* update dp4a tensor intrin

* update dp4a tensor intrin

* lint

---------

Co-authored-by: Lufang CHEN 陈橹方 <[email protected]>
vincentccc and Lufang CHEN 陈橹方 authored Jan 8, 2024
1 parent 8e54a9e commit 4c77f0f
Showing 6 changed files with 154 additions and 56 deletions.
9 changes: 7 additions & 2 deletions python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -14,11 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
# pylint: disable=invalid-name,missing-function-docstring,unused-import
"""Intrinsics for ARM tensorization."""
from tvm.script import tir as T
from .. import TensorIntrin
from .dot_product_common import DP4A_INTRIN # pylint: disable=unused-import
from .dot_product_common import (
    DP4A_S8S8S32_INTRIN,
    DP4A_S8U8S32_INTRIN,
    DP4A_U8S8S32_INTRIN,
    DP4A_U8U8U32_INTRIN,
)


# TODO(masahi): Parametrize the TVMScript description of dot product by
82 changes: 49 additions & 33 deletions python/tvm/tir/tensor_intrin/dot_product_common.py
@@ -20,36 +20,52 @@
from .. import TensorIntrin


@T.prim_func
def dp4a_desc(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])
        for i in range(0, 4):
            with T.block("update"):
                vi = T.axis.remap("R", [i])
                C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")


@T.prim_func
def dp4a_impl(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])

        C[0] += T.call_pure_extern(
            "__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
        )


DP4A_INTRIN = "dp4a"

TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)
def get_dp4a_intrin(dtype_a, dtype_b, dtype_c):
    if dtype_c == "uint32":
        assert dtype_a == dtype_b == "uint8"
    vec_type_a = "int8x4" if dtype_a == "int8" else "uint8x4"
    vec_type_b = "int8x4" if dtype_b == "int8" else "uint8x4"

    @T.prim_func
    def dp4a_desc(
        A: T.Buffer((4,), dtype_a, offset_factor=1, align=4, scope="shared"),
        B: T.Buffer((4,), dtype_b, offset_factor=1, align=4, scope="shared"),
        C: T.Buffer((1,), dtype_c, offset_factor=1, align=4, scope="local"),
    ) -> None:
        with T.block("root"):
            T.reads(C[0], A[0:4], B[0:4])
            T.writes(C[0])
            for i in range(0, 4):
                with T.block("update"):
                    vi = T.axis.remap("R", [i])
                    C[0] = C[0] + T.cast(A[vi], dtype_c) * T.cast(B[vi], dtype_c)

    @T.prim_func
    def dp4a_impl(
        A: T.Buffer((4,), dtype_a, offset_factor=1, align=4, scope="shared"),
        B: T.Buffer((4,), dtype_b, offset_factor=1, align=4, scope="shared"),
        C: T.Buffer((1,), dtype_c, offset_factor=1, align=4, scope="local"),
    ) -> None:
        with T.block("root"):
            T.reads(C[0], A[0:4], B[0:4])
            T.writes(C[0])

            C[0] += T.call_pure_extern(
                "__dp4a",
                A.vload([0], vec_type_a),
                B.vload([0], vec_type_b),
                T.uint32(0) if dtype_c == "uint32" else T.int32(0),
                dtype=dtype_c,
            )

    return dp4a_desc, dp4a_impl


DP4A_S8S8S32_INTRIN = "dp4a_s8s8s32"
TensorIntrin.register(DP4A_S8S8S32_INTRIN, *get_dp4a_intrin("int8", "int8", "int32"))
DP4A_U8S8S32_INTRIN = "dp4a_u8s8s32"
TensorIntrin.register(DP4A_U8S8S32_INTRIN, *get_dp4a_intrin("uint8", "int8", "int32"))
DP4A_S8U8S32_INTRIN = "dp4a_s8u8s32"
TensorIntrin.register(DP4A_S8U8S32_INTRIN, *get_dp4a_intrin("int8", "uint8", "int32"))
DP4A_U8U8U32_INTRIN = "dp4a_u8u8u32"
TensorIntrin.register(DP4A_U8U8U32_INTRIN, *get_dp4a_intrin("uint8", "uint8", "uint32"))
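For reference, a plain-Python sketch (not part of this diff; the helper name dp4a_reference is ours) of the scalar semantics all four variants describe — a 4-element dot product accumulated into a 32-bit value:

def dp4a_reference(a, b, c):
    # Four element-wise products summed into the accumulator, as in dp4a_desc above.
    assert len(a) == len(b) == 4
    return c + sum(int(x) * int(y) for x, y in zip(a, b))

# 1*5 + (-2)*6 + 3*7 + 4*8 + 0 = 46
print(dp4a_reference([1, -2, 3, 4], [5, 6, 7, 8], 0))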
3 changes: 2 additions & 1 deletion python/tvm/tir/tensor_intrin/rocm.py
@@ -20,7 +20,7 @@

from tvm.runtime import convert
from tvm.tir.expr import Cast, IntImm
from .dot_product_common import dp4a_desc
from .dot_product_common import get_dp4a_intrin
from .. import TensorIntrin


@@ -50,6 +50,7 @@ def sdot4(

AMDGPU_SDOT4_INTRIN = "sdot4"

dp4a_desc, _ = get_dp4a_intrin("int8", "int8", "int32")
TensorIntrin.register(AMDGPU_SDOT4_INTRIN, dp4a_desc, sdot4)

WARP_SIZE = 64
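Only the TVMScript description is reused here; the implementation stays the AMD-specific sdot4 defined above. As a usage sketch (assuming the tvm.tir.TensorIntrin.get lookup API), a registered name resolves back to its desc/impl pair:

from tvm.tir import TensorIntrin

intrin = TensorIntrin.get("sdot4")  # likewise for the dp4a_* names registered above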
2 changes: 2 additions & 0 deletions src/target/source/codegen_cuda.cc
@@ -35,6 +35,7 @@

#include "../../tir/transforms/ir_utils.h"
#include "literal/cuda_half_t.h"
#include "literal/cuda_int8_t.h"
#include "ptx.h"

namespace tvm {
@@ -130,6 +131,7 @@ std::string CodeGenCUDA::Finish() {
  if (enable_int8_) {
    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
    decl_stream << "#include <sm_61_intrinsics.h>\n";
    decl_stream << _cuda_int8_t_def;
    decl_stream << "#endif\n";
}

64 changes: 64 additions & 0 deletions src/target/source/literal/cuda_int8_t.h
@@ -0,0 +1,64 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file cuda_int8_t.h
 * \brief Extra int8 intrinsics for CUDA codegen.
 */
#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_
#define TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_

static constexpr const char* _cuda_int8_t_def = R"(

#if defined(__CUDACC_RTC__)
#define __SM_61_INTRINSICS_DECL__ __device__
#else /* !__CUDACC_RTC__ */
#define __SM_61_INTRINSICS_DECL__ static __device__ __inline__
#endif /* __CUDACC_RTC__ */

#ifndef __CUDA_ARCH__
#define __DEF_IF_HOST { }
#else /* !__CUDA_ARCH__ */
#define __DEF_IF_HOST ;
#endif /* __CUDA_ARCH__ */

__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) __DEF_IF_HOST
__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) __DEF_IF_HOST

#undef __DEF_IF_HOST

#if !defined(__CUDACC_RTC__) && defined(__CUDA_ARCH__)
__SM_61_INTRINSICS_DECL__ int __dp4a(unsigned int srcA, int srcB, int c) {
  int ret;
  asm volatile ("dp4a.u32.s32 %0, %1, %2, %3;" : "=r"(ret) : "r"(srcA), "r"(srcB), "r"(c));
  return ret;
}

__SM_61_INTRINSICS_DECL__ int __dp4a(int srcA, unsigned int srcB, int c) {
  int ret;
  asm volatile ("dp4a.s32.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(srcA), "r"(srcB), "r"(c));
  return ret;
}
#endif /* !__CUDACC_RTC__ && defined(__CUDA_ARCH__) */

#undef __SM_61_INTRINSICS_DECL__

)";

#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_INT8_T_H_
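CUDA's sm_61_intrinsics.h covers the uniformly signed and unsigned __dp4a overloads, so this literal supplies the two mixed-signedness ones the new intrinsics can emit. A NumPy sketch (ours, not from the diff) of why operand signedness matters — the same byte pattern 0xFF contributes -1 as int8 but 255 as uint8:

import numpy as np

raw = np.array([0xFF, 1, 2, 3], dtype=np.uint8)
as_s8 = raw.astype(np.int8).astype(np.int32)  # [-1, 1, 2, 3]
as_u8 = raw.astype(np.int32)                  # [255, 1, 2, 3]

print(int(np.dot(as_s8, as_u8)))  # s8·u8: -255 + 1 + 4 + 9 = -241
print(int(np.dot(as_u8, as_u8)))  # u8·u8: 65025 + 1 + 4 + 9 = 65039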
50 changes: 30 additions & 20 deletions tests/python/tir-schedule/test_tir_schedule_tensorize.py
@@ -26,7 +26,10 @@
    verify_trace_roundtrip,
)
from tvm.tir.tensor_intrin.arm_cpu import (
    DP4A_INTRIN,
    DP4A_S8S8S32_INTRIN,
    DP4A_U8U8U32_INTRIN,
    DP4A_U8S8S32_INTRIN,
    DP4A_S8U8S32_INTRIN,
    ARM_DOT_4x4_i8_NEON_INTRIN,
    ARM_DOT_4x4_i8_SDOT_INTRIN,
)
@@ -687,26 +690,25 @@ def test_tensorize_vdmpy():
    verify_trace_roundtrip(sch=sch, mod=func)


def test_tensorize_dpa4():
    m, n, k = 128, 128, 128

    X = te.placeholder((m, k), name="X", dtype="int8")
    W = te.placeholder((n, k), name="W", dtype="int8")
    ak = te.reduce_axis((0, k), name="k")

    matmul = te.compute(
        (m, n),
        lambda i, j: te.sum(
            X[i, ak].astype("int32")
            * W[j, ak].astype("int32"),
            axis=ak,
        ),
        name="compute",
    )
def test_tensorize_dp4a():
    # pylint: disable=too-many-locals
    def _test_intrin(dtype_a, dtype_b, dtype_c, intrin):
        m, n, k = 128, 128, 128
        X = te.placeholder((m, k), name="X", dtype=dtype_a)
        W = te.placeholder((n, k), name="W", dtype=dtype_b)
        ak = te.reduce_axis((0, k), name="k")

        matmul = te.compute(
            (m, n),
            lambda i, j: te.sum(
                X[i, ak].astype(dtype_c) * W[j, ak].astype(dtype_c),
                axis=ak,
            ),
            name="compute",
        )

    func = te.create_prim_func([X, W, matmul])
        func = te.create_prim_func([X, W, matmul])

    for intrin in [AMDGPU_SDOT4_INTRIN, DP4A_INTRIN]:
        sch = tir.Schedule(func, debug_mask="all")
        block = sch.get_block("compute")
        i, j, k = sch.get_loops(block)
@@ -717,7 +719,6 @@ def test_tensorize_dpa4():
        ko, kt = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2))

        sch.reorder(by, bx, ty, tx, yi, xi)

        CC = sch.cache_write(block, 0, "local")
        sch.reverse_compute_at(CC, tx)

@@ -734,6 +735,15 @@ def fetch_to_shared(block, idx):

        verify_trace_roundtrip(sch=sch, mod=func)

    for args in [
        ("int8", "int8", "int32", AMDGPU_SDOT4_INTRIN),
        ("int8", "int8", "int32", DP4A_S8S8S32_INTRIN),
        ("int8", "uint8", "int32", DP4A_S8U8S32_INTRIN),
        ("uint8", "int8", "int32", DP4A_U8S8S32_INTRIN),
        ("uint8", "uint8", "uint32", DP4A_U8U8U32_INTRIN),
    ]:
        _test_intrin(*args)


def test_tensor_intrin_look_up():
    intrin_name = 'non_existent_intrin'
