diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 86857a33cdf4..e6665f965b5b 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -580,6 +580,17 @@ TVM_DLL const Op& tvm_fill_fragment(); */ TVM_DLL const Op& tvm_store_matrix_sync(); +/*! + * \brief tvm intrinsic for ptx tensor core mma instructions. + * + * void ptx_mma(StringImm shape, StringImm A_layout, StringImm B_layout, + * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype, + * Var multiplicand_a, Expr a_index, + * Var multiplicand_b, Expr b_index, + * Var accumulator, Expr c_index, bool saturate); + */ +TVM_DLL const Op& ptx_mma(); + // TODO(tvm-team) replace the usage of the vector operations by Shuffle. /*! * \brief Get the high level half of the vector diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index a52564c34a68..a1f257391db4 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -33,6 +33,7 @@ #include #include "literal/cuda_half_t.h" +#include "ptx_mma.h" namespace tvm { namespace codegen { @@ -723,6 +724,38 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } + } else if (op->op.same_as(builtin::ptx_mma())) { + // arg 0: shape: mXnXkX + // arg 1: A layout: row/col + // arg 2: B layout: row/col + // arg 3: A precision: fp16, fp64, ... + // arg 4: B precision: fp16, fp64, ... + // arg 5: C precision: fp32, fp64, ... + // arg 6: A multiplicand + // arg 7: A multiplicand index + // arg 8: B multiplicand + // arg 9: B multiplicand index + // arg 10: C accumulator + // arg 11: C accumulator index + // arg 12: saturate + ICHECK_EQ(op->args.size(), 13U); + std::string shape = Downcast(op->args[0])->value; + std::string A_layout = Downcast(op->args[1])->value; + std::string B_layout = Downcast(op->args[2])->value; + std::string A_dtype = Downcast(op->args[3])->value; + std::string B_dtype = Downcast(op->args[4])->value; + std::string C_dtype = Downcast(op->args[5])->value; + std::string a_ref = this->PrintExpr(op->args[6]); + std::string a_bias = this->PrintExpr(op->args[7]); + std::string b_ref = this->PrintExpr(op->args[8]); + std::string b_bias = this->PrintExpr(op->args[9]); + std::string c_ref = this->PrintExpr(op->args[10]); + std::string c_bias = this->PrintExpr(op->args[11]); + bool saturate = (Downcast(op->args[12])->value != 0); + std::string asm_code = PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, + a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate); + + this->stream << asm_code; } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx_mma.cc new file mode 100644 index 000000000000..b6182720416c --- /dev/null +++ b/src/target/source/ptx_mma.cc @@ -0,0 +1,1374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ptx_mma.cc + */ + +#include "ptx_mma.h" + +namespace tvm { +namespace codegen { + +std::string ReplaceMMAArgument(std::string asm_code, const std::string& original, + const std::string& new_arg) { + size_t len = original.size(); + size_t new_len = new_arg.size(); + size_t pos = asm_code.find(original); + while (pos != std::string::npos) { + asm_code = asm_code.replace(pos, len, new_arg); + pos = asm_code.find(original, pos + new_len); + } + return asm_code; +} + +std::string PrintMMAm8n8k4Assembly(const std::string& A_layout, const std::string& B_layout, + const std::string& A_dtype, const std::string& B_dtype, + const std::string& C_dtype, const std::string& a_ref, + const std::string& a_bias, const std::string& b_ref, + const std::string& b_bias, const std::string& c_ref, + const std::string& c_bias, bool saturate) { + std::string asm_code = ""; + std::string new_a_ref = ""; + std::string new_b_ref = ""; + std::string new_c_ref = ""; + ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) || + ((A_dtype == "fp64") && (B_dtype == "fp64"))); + ICHECK(saturate == false) << "Saturate is not allowed for m8n8k4 mma."; + if ((A_dtype == "fp16") && (B_dtype == "fp16")) { + // A/B multiplicand is fp16, SM 70 Tensor Core instructions + ICHECK((C_dtype == "fp16") || (C_dtype == "fp32")); + if (C_dtype == "fp16") { + // C accumulator is fp16 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k4.left_layout.right_layout.f16.f16.f16.f16 " + "{%0,%1,%2,%3}, {%4,%5}, {%6,%7}, " + "{%8,%9,%10,%11};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // C accumulator is fp32 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k4.left_layout.right_layout.f32.f16.f16.f32 " + "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " + "{%12,%13,%14,%15,%16,%17,%18,%19};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]), + "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + } + )"; + } + } else { + // A/B multiplicand is fp64, SM 80 Tensor Core instructions + ICHECK(C_dtype == "fp64"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Fp64 Tensor Core instructions " + << "with shape m8n8k4 expect A layout is row major and B layout is col major."; + // C accumulator is fp64 + new_a_ref = "((double *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((double *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((double *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=d"(D[0]), "=d"(D[1]) + : "d"(A[0]), "d"(B[0]), + "d"(C[0]), "d"(C[1])); + } + )"; + } + asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); + asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); + asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); + asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); + asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); + asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); + return asm_code; +} + +std::string PrintMMAm16n8k8Assembly(const std::string& A_layout, const std::string& B_layout, + const std::string& A_dtype, const std::string& B_dtype, + const std::string& C_dtype, const std::string& a_ref, + const std::string& a_bias, const std::string& b_ref, + const std::string& b_bias, const std::string& c_ref, + const std::string& c_bias, bool saturate) { + std::string asm_code = ""; + std::string new_a_ref = ""; + std::string new_b_ref = ""; + std::string new_c_ref = ""; + ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) || + ((A_dtype == "bf16") && (B_dtype == "bf16"))); + ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 mma."; + if ((A_dtype == "fp16") && (B_dtype == "fp16")) { + // A/B multiplicand is fp16, SM 75 Tensor Core instructions + ICHECK((C_dtype == "fp16") || (C_dtype == "fp32")); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM75 Tensor Core instructions " + << "with shape m16n8k8 expect A layout is row major and B layout is col major."; + if (C_dtype == "fp16") { + // C accumulator is fp16 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%5}, " + "{%5,%6};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } else { + // C accumulator is fp32 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + } + )"; + } + } else { + // A/B multiplicand is bf16, SM 80 Tensor Core instructions + ICHECK(C_dtype == "fp32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k8 expect A layout is row major and B layout is col major."; + // C accumulator is fp32 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + } + )"; + } + asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); + asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); + asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); + asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); + asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); + asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); + return asm_code; +} + +std::string PrintMMAm8n8k16Assembly(const std::string& A_layout, const std::string& B_layout, + const std::string& A_dtype, const std::string& B_dtype, + const std::string& C_dtype, const std::string& a_ref, + const std::string& a_bias, const std::string& b_ref, + const std::string& b_bias, const std::string& c_ref, + const std::string& c_bias, bool saturate) { + std::string asm_code = ""; + std::string new_a_ref = ""; + std::string new_b_ref = ""; + std::string new_c_ref = ""; + ICHECK(((A_dtype == "int8") && (B_dtype == "int8")) || + ((A_dtype == "uint8") && (B_dtype == "int8")) || + ((A_dtype == "int8") && (B_dtype == "uint8")) || + ((A_dtype == "uint8") && (B_dtype == "uint8"))); + if ((A_dtype == "int8") && (B_dtype == "int8")) { + // A/B multiplicand is int8, SM 75 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM75 Tensor Core instructions " + << "with shape m8n8k16 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } + } else if ((A_dtype == "uint8") && (B_dtype == "int8")) { + // A multiplicand is uint8, B multiplicand is int8 + // SM 75 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM75 Tensor Core instructions " + << "with shape m8n8k16 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } + } else if ((A_dtype == "int8") && (B_dtype == "uint8")) { + // A multiplicand is int8, B multiplicand is uint8 + // SM 75 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM75 Tensor Core instructions " + << "with shape m8n8k16 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } + } else { + // A/B multiplicand is uint8, SM 75 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM75 Tensor Core instructions " + << "with shape m8n8k16 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } + } + asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); + asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); + asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); + asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); + asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); + asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); + return asm_code; +} + +std::string PrintMMAm8n8k32Assembly(const std::string& A_layout, const std::string& B_layout, + const std::string& A_dtype, const std::string& B_dtype, + const std::string& C_dtype, const std::string& a_ref, + const std::string& a_bias, const std::string& b_ref, + const std::string& b_bias, const std::string& c_ref, + const std::string& c_bias, bool saturate) { + std::string asm_code = ""; + std::string new_a_ref = ""; + std::string new_b_ref = ""; + std::string new_c_ref = ""; + ICHECK(((A_dtype == "int4") && (B_dtype == "int4")) || + ((A_dtype == "uint4") && (B_dtype == "int4")) || + ((A_dtype == "int4") && (B_dtype == "uint4")) || + ((A_dtype == "uint4") && (B_dtype == "uint4"))); + if ((A_dtype == "int4") && (B_dtype == "int4")) { + // A/B multiplicand is int4, SM 75 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM75 Tensor Core instructions " + << "with shape m8n8k32 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } + } else if ((A_dtype == "uint4") && (B_dtype == "int4")) { + // A multiplicand is uint4, B multiplicand is int4 + // SM 75 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM75 Tensor Core instructions " + << "with shape m8n8k32 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } + } else if ((A_dtype == "int4") && (B_dtype == "uint4")) { + // A multiplicand is int4, B multiplicand is uint4 + // SM 75 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM75 Tensor Core instructions " + << "with shape m8n8k32 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } + } else { + // A/B multiplicand is uint4, SM 75 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM75 Tensor Core instructions " + << "with shape m8n8k32 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 " + "{%0,%1}, {%2}, {%3}, " + "{%4,%5};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(B[0]), + "r"(C[0]), "r"(C[1])); + } + )"; + } + } + asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); + asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); + asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); + asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); + asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); + asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); + return asm_code; +} + +std::string PrintMMAm16n8k4Assembly(const std::string& A_layout, const std::string& B_layout, + const std::string& A_dtype, const std::string& B_dtype, + const std::string& C_dtype, const std::string& a_ref, + const std::string& a_bias, const std::string& b_ref, + const std::string& b_bias, const std::string& c_ref, + const std::string& c_bias, bool saturate) { + std::string asm_code = ""; + std::string new_a_ref = ""; + std::string new_b_ref = ""; + std::string new_c_ref = ""; + ICHECK((A_dtype == "tf32") && (B_dtype == "tf32")); + ICHECK(saturate == false) << "Saturate is not allowed for m16n8k4 mma."; + // A/B multiplicand is tf32, SM 80 Tensor Core instructions + ICHECK(C_dtype == "fp32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k4 expect A layout is row major and B layout is col major."; + // C accumulator is fp32 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "f"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + } + )"; + asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); + asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); + asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); + asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); + asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); + asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); + return asm_code; +} + +std::string PrintMMAm16n8k16Assembly(const std::string& A_layout, const std::string& B_layout, + const std::string& A_dtype, const std::string& B_dtype, + const std::string& C_dtype, const std::string& a_ref, + const std::string& a_bias, const std::string& b_ref, + const std::string& b_bias, const std::string& c_ref, + const std::string& c_bias, bool saturate) { + std::string asm_code = ""; + std::string new_a_ref = ""; + std::string new_b_ref = ""; + std::string new_c_ref = ""; + ICHECK(((A_dtype == "fp16") && (B_dtype == "fp16")) || + ((A_dtype == "bf16") && (B_dtype == "bf16")) || + ((A_dtype == "int8") && (B_dtype == "int8")) || + ((A_dtype == "uint8") && (B_dtype == "int8")) || + ((A_dtype == "int8") && (B_dtype == "uint8")) || + ((A_dtype == "uint8") && (B_dtype == "uint8"))); + if ((A_dtype == "fp16") && (B_dtype == "fp16")) { + ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 fp16 mma."; + // A/B multiplicand is fp16, SM 80 Tensor Core instructions + ICHECK((C_dtype == "fp16") || (C_dtype == "fp32")); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k16 expect A layout is row major and B layout is col major."; + if (C_dtype == "fp16") { + // C accumulator is fp16 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((unsigned *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, " + "{%8,%9};\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1])); + } + )"; + } else { + // C accumulator is fp32 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + } + )"; + } + } else if ((A_dtype == "bf16") && (B_dtype == "bf16")) { + // A/B multiplicand is bf16, SM 80 Tensor Core instructions + ICHECK(saturate == false) << "Saturate is not allowed for m16n8k8 bf16 mma."; + ICHECK(C_dtype == "fp32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k16 expect A layout is row major and B layout is col major."; + // C accumulator is fp32 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((float *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + } + )"; + } else if ((A_dtype == "int8") && (B_dtype == "int8")) { + // A/B multiplicand is int8, SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k16 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } else if ((A_dtype == "uint8") && (B_dtype == "int8")) { + // A multiplicand is uint8, B multiplicand is int8 + // SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k16 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } else if ((A_dtype == "int8") && (B_dtype == "uint8")) { + // A multiplicand is int8, B multiplicand is uint8 + // SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k16 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } else { + // A/B multiplicand is uint8, SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k16 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, " + "{%7,%8,%9,%10};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } + asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); + asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); + asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); + asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); + asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); + asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); + return asm_code; +} + +std::string PrintMMAm16n8k32Assembly(const std::string& A_layout, const std::string& B_layout, + const std::string& A_dtype, const std::string& B_dtype, + const std::string& C_dtype, const std::string& a_ref, + const std::string& a_bias, const std::string& b_ref, + const std::string& b_bias, const std::string& c_ref, + const std::string& c_bias, bool saturate) { + std::string asm_code = ""; + std::string new_a_ref = ""; + std::string new_b_ref = ""; + std::string new_c_ref = ""; + ICHECK(((A_dtype == "int8") && (B_dtype == "int8")) || + ((A_dtype == "uint8") && (B_dtype == "int8")) || + ((A_dtype == "int8") && (B_dtype == "uint8")) || + ((A_dtype == "uint8") && (B_dtype == "uint8"))); + if ((A_dtype == "int8") && (B_dtype == "int8")) { + // A/B multiplicand is int8, SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k32 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } else if ((A_dtype == "uint8") && (B_dtype == "int8")) { + // A multiplicand is uint8, B multiplicand is int8 + // SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k32 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } else if ((A_dtype == "int8") && (B_dtype == "uint8")) { + // A multiplicand is int8, B multiplicand is uint8 + // SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k32 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } else { + // A/B multiplicand is uint8, SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k32 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } + asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); + asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); + asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); + asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); + asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); + asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); + return asm_code; +} + +std::string PrintMMAm16n8k64Assembly(const std::string& A_layout, const std::string& B_layout, + const std::string& A_dtype, const std::string& B_dtype, + const std::string& C_dtype, const std::string& a_ref, + const std::string& a_bias, const std::string& b_ref, + const std::string& b_bias, const std::string& c_ref, + const std::string& c_bias, bool saturate) { + std::string asm_code = ""; + std::string new_a_ref = ""; + std::string new_b_ref = ""; + std::string new_c_ref = ""; + ICHECK(((A_dtype == "int4") && (B_dtype == "int4")) || + ((A_dtype == "uint4") && (B_dtype == "int4")) || + ((A_dtype == "int4") && (B_dtype == "uint4")) || + ((A_dtype == "uint4") && (B_dtype == "uint4"))); + if ((A_dtype == "int4") && (B_dtype == "int4")) { + // A/B multiplicand is int4, SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k64 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } else if ((A_dtype == "uint4") && (B_dtype == "int4")) { + // A multiplicand is uint4, B multiplicand is int4 + // SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k64 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } else if ((A_dtype == "int4") && (B_dtype == "uint4")) { + // A multiplicand is int4, B multiplicand is uint4 + // SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k64 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } else { + // A/B multiplicand is uint4, SM 75 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k64 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + if (!saturate) { + // no saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // saturate + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + } + asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); + asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); + asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); + asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); + asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); + asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); + return asm_code; +} + +std::string PrintMMAm16n8k256Assembly(const std::string& A_layout, const std::string& B_layout, + const std::string& A_dtype, const std::string& B_dtype, + const std::string& C_dtype, const std::string& a_ref, + const std::string& a_bias, const std::string& b_ref, + const std::string& b_bias, const std::string& c_ref, + const std::string& c_bias, bool saturate) { + std::string asm_code = ""; + std::string new_a_ref = ""; + std::string new_b_ref = ""; + std::string new_c_ref = ""; + ICHECK(((A_dtype == "uint1") && (B_dtype == "uint1")) || + ((A_dtype == "int1") && (B_dtype == "int1"))); + if ((A_dtype == "uint1") && (B_dtype == "uint1")) { + // A/B multiplicand is uint1, SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k256 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } else { + // A/B multiplicand is int1, SM 80 Tensor Core instructions + ICHECK(C_dtype == "int32"); + ICHECK((A_layout == "row") && (B_layout == "col")) + << "SM80 Tensor Core instructions " + << "with shape m16n8k256 expect A layout is row major and B layout is col major."; + // C accumulator is int32 + new_a_ref = "((unsigned *)(" + a_ref + " + " + a_bias + "))"; + new_b_ref = "((unsigned *)(" + b_ref + " + " + b_bias + "))"; + new_c_ref = "((int *)(" + c_ref + " + " + c_bias + "))"; + asm_code = R"( + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, " + "{%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + } + )"; + } + asm_code = ReplaceMMAArgument(asm_code, "left_layout", A_layout); + asm_code = ReplaceMMAArgument(asm_code, "right_layout", B_layout); + asm_code = ReplaceMMAArgument(asm_code, "A", new_a_ref); + asm_code = ReplaceMMAArgument(asm_code, "B", new_b_ref); + asm_code = ReplaceMMAArgument(asm_code, "C", new_c_ref); + asm_code = ReplaceMMAArgument(asm_code, "D", new_c_ref); + return asm_code; +} + +std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, + const std::string& B_layout, const std::string& A_dtype, + const std::string& B_dtype, const std::string& C_dtype, + const std::string& a_ref, const std::string& a_bias, + const std::string& b_ref, const std::string& b_bias, + const std::string& c_ref, const std::string& c_bias, bool saturate) { + ICHECK((shape == "m8n8k4") || (shape == "m16n8k8") || (shape == "m8n8k16") || + (shape == "m8n8k32") || (shape == "m16n8k4") || (shape == "m16n8k16") || + (shape == "m16n8k32") || (shape == "m16n8k64") || (shape == "m16n8k256")); + ICHECK((A_layout == "row") || (A_layout == "col")) << "Unknown A layout: " << A_layout; + ICHECK((B_layout == "row") || (B_layout == "col")) << "Unknown B layout: " << B_layout; + + if (shape == "m8n8k4") { + return PrintMMAm8n8k4Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, + b_ref, b_bias, c_ref, c_bias, saturate); + } else if (shape == "m16n8k8") { + return PrintMMAm16n8k8Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, + b_ref, b_bias, c_ref, c_bias, saturate); + } else if (shape == "m8n8k16") { + return PrintMMAm8n8k16Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, + b_ref, b_bias, c_ref, c_bias, saturate); + } else if (shape == "m8n8k32") { + return PrintMMAm8n8k32Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, + b_ref, b_bias, c_ref, c_bias, saturate); + } else if (shape == "m16n8k4") { + return PrintMMAm16n8k4Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, + b_ref, b_bias, c_ref, c_bias, saturate); + } else if (shape == "m16n8k16") { + return PrintMMAm16n8k16Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, + b_ref, b_bias, c_ref, c_bias, saturate); + } else if (shape == "m16n8k32") { + return PrintMMAm16n8k32Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, + b_ref, b_bias, c_ref, c_bias, saturate); + } else if (shape == "m16n8k64") { + return PrintMMAm16n8k64Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, + b_ref, b_bias, c_ref, c_bias, saturate); + } else if (shape == "m16n8k256") { + return PrintMMAm16n8k256Assembly(A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, + b_ref, b_bias, c_ref, c_bias, saturate); + } + /* + * TODO: add mma.m16n8k128 + */ + throw Error("Unknown PTX mma instructions."); +} + +} // namespace codegen +} // namespace tvm diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx_mma.h new file mode 100644 index 000000000000..d2a7a6705d6d --- /dev/null +++ b/src/target/source/ptx_mma.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ptx_mma.h + * \brief MMA code generation with inlined PTX code. + */ +#ifndef TVM_TARGET_SOURCE_PTX_MMA_H_ +#define TVM_TARGET_SOURCE_PTX_MMA_H_ + +#include + +#include +#include + +namespace tvm { +namespace codegen { + +std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, + const std::string& B_layout, const std::string& A_dtype, + const std::string& B_dtype, const std::string& C_dtype, + const std::string& a_ref, const std::string& a_bias, + const std::string& b_ref, const std::string& b_bias, + const std::string& c_ref, const std::string& c_bias, bool saturate); + +} // namespace codegen +} // namespace tvm + +#endif // TVM_TARGET_SOURCE_PTX_MMA_H_ diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index c593cbf7290c..5fccae040dc9 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -234,6 +234,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment) TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(vectorhigh) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/tests/python/unittest/test_tir_ptx_mma.py b/tests/python/unittest/test_tir_ptx_mma.py new file mode 100644 index 000000000000..4b8e3fcaffef --- /dev/null +++ b/tests/python/unittest/test_tir_ptx_mma.py @@ -0,0 +1,1356 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +import pytest + +import tvm +from tvm.script import tir as T +import numpy as np +import tvm.testing + + +@T.prim_func +def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [8, 4], dtype="float64") + B = T.match_buffer(b, [8, 4], dtype="float64") + C = T.match_buffer(c, [8, 8], dtype="float64") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([1], "float64", scope="local") + MultiB = T.allocate([1], "float64", scope="local") + Accum = T.allocate([2], "float64", scope="local") + for i in range(2): + Accum[i] = T.float64(0) + + MultiA[0] = A[(tx % 32) // 4, (tx % 32) % 4] + MultiB[0] = B[(tx % 32) // 4, (tx % 32) % 4] + T.evaluate( + T.ptx_mma( + "m8n8k4", + "row", + "col", + "fp64", + "fp64", + "fp64", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="float64", + ) + ) + for mma_accum_c_id in range(2): + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( + "float64", Accum, mma_accum_c_id + ) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m8n8k4_row_col_fp64pf64fp64(): + sch = tvm.tir.Schedule(gemm_mma_m8n8k4_row_col_fp64pf64fp64) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [8, 4]).astype("float64") + B_np = np.random.uniform(-1, 1, [8, 4]).astype("float64") + C_np = np.zeros([8, 8]).astype("float64") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("float64"), B_np.astype("float64").T) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m8n8k4_row_row_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 4], dtype="float16") + B = T.match_buffer(b, [4, 16], dtype="float16") + C = T.match_buffer(c, [16, 16], dtype="float16") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([4], "float16", scope="local") + MultiB = T.allocate([4], "float16", scope="local") + Accum = T.allocate([8], "float16", scope="local") + for i in range(8): + Accum[i] = T.float32(0) + + for mma_multi_a_col in T.vectorized(4): + MultiA[mma_multi_a_col] = A[ + ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)), + mma_multi_a_col, + ] + for mma_multi_b_col in T.vectorized(4): + MultiB[mma_multi_b_col] = B[ + (tx % 32) % 4, + mma_multi_b_col + (4 * ((tx % 32) // 8)), + ] + T.evaluate( + T.ptx_mma( + "m8n8k4", + "row", + "row", + "fp16", + "fp16", + "fp16", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="float16", + ) + ) + for mma_accum_c_id in range(8): + C[ + ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)), + mma_accum_c_id % 4 + (4 * ((tx % 32) % 16 // 8)) + mma_accum_c_id // 4 * 8, + ] = T.load("float16", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m8n8k4_row_row_fp16fp16fp16(): + sch = tvm.tir.Schedule(gemm_mma_m8n8k4_row_row_fp16fp16fp16) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 7: + # Require at least SM70 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [16, 4]).astype("float16") + B_np = np.random.uniform(-1, 1, [4, 16]).astype("float16") + C_np = np.zeros([16, 16]).astype("float16") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("float16"), B_np.astype("float16")) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 4], dtype="float16") + B = T.match_buffer(b, [4, 16], dtype="float16") + C = T.match_buffer(c, [16, 16], dtype="float32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([4], "float16", scope="local") + MultiB = T.allocate([4], "float16", scope="local") + Accum = T.allocate([8], "float32", scope="local") + for i in range(8): + Accum[i] = T.float32(0) + + for mma_multi_a_col in T.vectorized(4): + MultiA[mma_multi_a_col] = A[ + ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)), + mma_multi_a_col, + ] + for mma_multi_b_col in T.vectorized(4): + MultiB[mma_multi_b_col] = B[ + (tx % 32) % 4, + mma_multi_b_col + (4 * ((tx % 32) // 8)), + ] + T.evaluate( + T.ptx_mma( + "m8n8k4", + "row", + "row", + "fp16", + "fp16", + "fp32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="float32", + ) + ) + for mma_accum_c_id in range(8): + C[ + ((tx % 32) % 2) + + ((mma_accum_c_id // 2 % 2) * 2) + + 4 * ((tx % 32) // 16) + + ((tx % 32) % 16 // 4) % 2 * 8, + (tx % 32) % 4 // 2 * 2 + + (tx % 32) % 16 // 8 * 4 + + mma_accum_c_id % 2 + + mma_accum_c_id // 4 * 8, + ] = T.load("float32", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m8n8k4_row_row_fp16fp16fp32(): + sch = tvm.tir.Schedule(gemm_mma_m8n8k4_row_row_fp16fp16fp32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 7: + # Require at least SM70 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [16, 4]).astype("float16") + B_np = np.random.uniform(-1, 1, [4, 16]).astype("float16") + C_np = np.zeros([16, 16]).astype("float32") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("float32"), B_np.astype("float32")) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [8, 16], dtype="int8") + B = T.match_buffer(b, [8, 16], dtype="int8") + C = T.match_buffer(c, [8, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([4], "int8", scope="local") + MultiB = T.allocate([4], "int8", scope="local") + Accum = T.allocate([2], "int32", scope="local") + for i in range(2): + Accum[i] = T.int32(0) + + for mma_multi_a_col in T.vectorized(4): + MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 4] + for mma_multi_b_col in T.vectorized(4): + MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 4] + T.evaluate( + T.ptx_mma( + "m8n8k16", + "row", + "col", + "int8", + "int8", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(2): + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( + "int32", Accum, mma_accum_c_id + ) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m8n8k16_row_col_s8s8s32(): + sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8s8s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major * 10 + minor < 75: + # Require at least SM75 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-10, 10, [8, 16]).astype("int8") + B_np = np.random.uniform(-10, 10, [8, 16]).astype("int8") + C_np = np.zeros([8, 8]).astype("int32") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("int32"), B_np.astype("int32").T) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [8, 16], dtype="int8") + B = T.match_buffer(b, [8, 16], dtype="uint8") + C = T.match_buffer(c, [8, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([4], "int8", scope="local") + MultiB = T.allocate([4], "uint8", scope="local") + Accum = T.allocate([2], "int32", scope="local") + for i in range(2): + Accum[i] = T.int32(0) + + for mma_multi_a_col in T.vectorized(4): + MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 4] + for mma_multi_b_col in T.vectorized(4): + MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 4] + T.evaluate( + T.ptx_mma( + "m8n8k16", + "row", + "col", + "int8", + "uint8", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(2): + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( + "int32", Accum, mma_accum_c_id + ) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m8n8k16_row_col_s8u8s32(): + sch = tvm.tir.Schedule(gemm_mma_m8n8k16_row_col_s8u8s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major * 10 + minor < 75: + # Require at least SM75 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-10, 10, [8, 16]).astype("int8") + B_np = np.random.uniform(-10, 10, [8, 16]).astype("uint8") + C_np = np.zeros([8, 8]).astype("int32") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("int32"), B_np.astype("int32").T) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [8, 32], dtype="int4") + B = T.match_buffer(b, [8, 32], dtype="int4") + C = T.match_buffer(c, [8, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([8], "int4", scope="local") + MultiB = T.allocate([8], "int4", scope="local") + Accum = T.allocate([2], "int32", scope="local") + for i in range(2): + Accum[i] = T.int32(0) + + for mma_multi_a_col in T.vectorized(8): + MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8] + for mma_multi_b_col in T.vectorized(8): + MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8] + T.evaluate( + T.ptx_mma( + "m8n8k32", + "row", + "col", + "int4", + "int4", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(2): + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( + "int32", Accum, mma_accum_c_id + ) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m8n8k32_row_col_s4s4s32(): + sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4s4s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major * 10 + minor < 75: + # Require at least SM75 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + ctx = tvm.cuda() + A_tvm = tvm.nd.empty([8, 32], "int4", ctx) + B_tvm = tvm.nd.empty([8, 32], "int4", ctx) + C_tvm = tvm.nd.empty([8, 8], "int32", ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + # Currently the correctness is not checked. + # TODO: add correctness checking here. + + +@T.prim_func +def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [8, 32], dtype="int4") + B = T.match_buffer(b, [8, 32], dtype="uint4") + C = T.match_buffer(c, [8, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([8], "int4", scope="local") + MultiB = T.allocate([8], "uint4", scope="local") + Accum = T.allocate([2], "int32", scope="local") + for i in range(2): + Accum[i] = T.int32(0) + + for mma_multi_a_col in T.vectorized(8): + MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8] + for mma_multi_b_col in T.vectorized(8): + MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8] + T.evaluate( + T.ptx_mma( + "m8n8k32", + "row", + "col", + "int4", + "uint4", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(2): + C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load( + "int32", Accum, mma_accum_c_id + ) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m8n8k32_row_col_s4u4s32(): + sch = tvm.tir.Schedule(gemm_mma_m8n8k32_row_col_s4u4s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major * 10 + minor < 75: + # Require at least SM75 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + ctx = tvm.cuda() + A_tvm = tvm.nd.empty([8, 32], "int4", ctx) + B_tvm = tvm.nd.empty([8, 32], "uint4", ctx) + C_tvm = tvm.nd.empty([8, 8], "int32", ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + # Currently the correctness is not checked. + # TODO: add correctness checking here. + + +@T.prim_func +def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 8], dtype="float16") + B = T.match_buffer(b, [8, 8], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([4], "float16", scope="local") + MultiB = T.allocate([2], "float16", scope="local") + Accum = T.allocate([4], "float32", scope="local") + for i in range(4): + Accum[i] = T.float32(0) + + for mma_multi_a_col in T.vectorized(4): + MultiA[mma_multi_a_col] = A[ + (tx % 32) // 4 + mma_multi_a_col // 2 * 8, (tx % 32) % 4 * 2 + mma_multi_a_col % 2 + ] + for mma_multi_b_col in T.vectorized(4): + MultiB[mma_multi_b_col] = B[ + (tx % 32) // 4 + mma_multi_b_col // 2 * 8, (tx % 32) % 4 * 2 + mma_multi_b_col % 2 + ] + T.evaluate( + T.ptx_mma( + "m16n8k8", + "row", + "col", + "fp16", + "fp16", + "fp32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="float32", + ) + ) + for mma_accum_c_id in range(4): + C[ + (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2 + ] = T.load("float32", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m16n8k8_row_col_fp16fp16fp32(): + sch = tvm.tir.Schedule(gemm_mma_m16n8k8_row_col_fp16fp16fp32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [16, 8]).astype("float16") + B_np = np.random.uniform(-1, 1, [8, 8]).astype("float16") + C_np = np.zeros([16, 8]).astype("float32") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("float32"), B_np.astype("float32").T) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m16n8k16_row_col_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 16], dtype="float16") + B = T.match_buffer(b, [8, 16], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float16") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([8], "float16", scope="local") + MultiB = T.allocate([4], "float16", scope="local") + Accum = T.allocate([4], "float16", scope="local") + for i in range(4): + Accum[i] = T.float32(0) + + for mma_multi_a_col in range(8): + MultiA[mma_multi_a_col] = A[ + (tx % 32) // 4 + mma_multi_a_col % 4 // 2 * 8, + (tx % 32) % 4 * 2 + mma_multi_a_col % 2 + mma_multi_a_col // 4 * 8, + ] + for mma_multi_b_col in T.vectorized(4): + MultiB[mma_multi_b_col] = B[ + (tx % 32) // 4, + (tx % 32) % 4 * 2 + mma_multi_b_col % 2 + mma_multi_b_col // 2 * 8, + ] + T.evaluate( + T.ptx_mma( + "m16n8k16", + "row", + "col", + "fp16", + "fp16", + "fp16", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="float16", + ) + ) + for mma_accum_c_id in range(4): + C[ + (tx % 32) // 4 + mma_accum_c_id // 2 * 8, + (tx % 32) % 4 * 2 + mma_accum_c_id % 2, + ] = T.load("float16", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m16n8k16_row_col_fp16fp16fp16(): + sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_fp16fp16fp16) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16") + B_np = np.random.uniform(-1, 1, [8, 16]).astype("float16") + C_np = np.zeros([16, 8]).astype("float16") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("float16"), B_np.astype("float16").T) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 16], dtype="float16") + B = T.match_buffer(b, [8, 16], dtype="float16") + C = T.match_buffer(c, [16, 8], dtype="float32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([8], "float16", scope="local") + MultiB = T.allocate([4], "float16", scope="local") + Accum = T.allocate([4], "float32", scope="local") + for i in range(4): + Accum[i] = T.float32(0) + + for mma_multi_a_col in range(8): + MultiA[mma_multi_a_col] = A[ + (tx % 32) // 4 + mma_multi_a_col % 4 // 2 * 8, + (tx % 32) % 4 * 2 + mma_multi_a_col % 2 + mma_multi_a_col // 4 * 8, + ] + for mma_multi_b_col in T.vectorized(4): + MultiB[mma_multi_b_col] = B[ + (tx % 32) // 4, + (tx % 32) % 4 * 2 + mma_multi_b_col % 2 + mma_multi_b_col // 2 * 8, + ] + T.evaluate( + T.ptx_mma( + "m16n8k16", + "row", + "col", + "fp16", + "fp16", + "fp32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="float32", + ) + ) + for mma_accum_c_id in range(4): + C[ + (tx % 32) // 4 + mma_accum_c_id // 2 * 8, + (tx % 32) % 4 * 2 + mma_accum_c_id % 2, + ] = T.load("float32", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m16n8k16_row_col_fp16fp16fp32(): + sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_fp16fp16fp32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-1, 1, [16, 16]).astype("float16") + B_np = np.random.uniform(-1, 1, [8, 16]).astype("float16") + C_np = np.zeros([16, 8]).astype("float32") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("float32"), B_np.astype("float32").T) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m16n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 16], dtype="int8") + B = T.match_buffer(b, [8, 16], dtype="int8") + C = T.match_buffer(c, [16, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([8], "int8", scope="local") + MultiB = T.allocate([4], "int8", scope="local") + Accum = T.allocate([4], "int32", scope="local") + for i in range(4): + Accum[i] = T.int32(0) + + for mma_multi_a_col in range(8): + MultiA[mma_multi_a_col] = A[ + (tx % 32) // 4 + mma_multi_a_col // 4 * 8, + (tx % 32) % 4 * 4 + mma_multi_a_col % 4, + ] + for mma_multi_b_col in T.vectorized(4): + MultiB[mma_multi_b_col] = B[ + (tx % 32) // 4, + (tx % 32) % 4 * 4 + mma_multi_b_col, + ] + T.evaluate( + T.ptx_mma( + "m16n8k16", + "row", + "col", + "int8", + "int8", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(4): + C[ + (tx % 32) // 4 + mma_accum_c_id // 2 * 8, + (tx % 32) % 4 * 2 + mma_accum_c_id % 2, + ] = T.load("int32", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m16n8k16_row_col_s8s8s32(): + sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_s8s8s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-10, 10, [16, 16]).astype("int8") + B_np = np.random.uniform(-10, 10, [8, 16]).astype("int8") + C_np = np.zeros([16, 8]).astype("int32") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("int32"), B_np.astype("int32").T) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 16], dtype="int8") + B = T.match_buffer(b, [8, 16], dtype="uint8") + C = T.match_buffer(c, [16, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([8], "int8", scope="local") + MultiB = T.allocate([4], "uint8", scope="local") + Accum = T.allocate([4], "int32", scope="local") + for i in range(4): + Accum[i] = T.int32(0) + + for mma_multi_a_col in range(8): + MultiA[mma_multi_a_col] = A[ + (tx % 32) // 4 + mma_multi_a_col // 4 * 8, + (tx % 32) % 4 * 4 + mma_multi_a_col % 4, + ] + for mma_multi_b_col in T.vectorized(4): + MultiB[mma_multi_b_col] = B[ + (tx % 32) // 4, + (tx % 32) % 4 * 4 + mma_multi_b_col, + ] + T.evaluate( + T.ptx_mma( + "m16n8k16", + "row", + "col", + "int8", + "uint8", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(4): + C[ + (tx % 32) // 4 + mma_accum_c_id // 2 * 8, + (tx % 32) % 4 * 2 + mma_accum_c_id % 2, + ] = T.load("int32", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m16n8k16_row_col_s8u8s32(): + sch = tvm.tir.Schedule(gemm_mma_m16n8k16_row_col_s8u8s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-10, 10, [16, 16]).astype("int8") + B_np = np.random.uniform(-10, 10, [8, 16]).astype("uint8") + C_np = np.zeros([16, 8]).astype("int32") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("int32"), B_np.astype("int32").T) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 32], dtype="int8") + B = T.match_buffer(b, [8, 32], dtype="int8") + C = T.match_buffer(c, [16, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([16], "int8", scope="local") + MultiB = T.allocate([8], "int8", scope="local") + Accum = T.allocate([4], "int32", scope="local") + for i in range(4): + Accum[i] = T.int32(0) + + for mma_multi_a_col in range(16): + MultiA[mma_multi_a_col] = A[ + (tx % 32) // 4 + mma_multi_a_col % 8 // 4 * 8, + (tx % 32) % 4 * 4 + mma_multi_a_col % 4 + mma_multi_a_col // 8 * 16, + ] + for mma_multi_b_col in range(8): + MultiB[mma_multi_b_col] = B[ + (tx % 32) // 4, + (tx % 32) % 4 * 4 + mma_multi_b_col % 4 + mma_multi_b_col // 4 * 16, + ] + T.evaluate( + T.ptx_mma( + "m16n8k32", + "row", + "col", + "int8", + "int8", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(4): + C[ + (tx % 32) // 4 + mma_accum_c_id // 2 * 8, + (tx % 32) % 4 * 2 + mma_accum_c_id % 2, + ] = T.load("int32", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m16n8k32_row_col_s8s8s32(): + sch = tvm.tir.Schedule(gemm_mma_m16n8k32_row_col_s8s8s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-10, 10, [16, 32]).astype("int8") + B_np = np.random.uniform(-10, 10, [8, 32]).astype("int8") + C_np = np.zeros([16, 8]).astype("int32") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("int32"), B_np.astype("int32").T) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 32], dtype="int8") + B = T.match_buffer(b, [8, 32], dtype="uint8") + C = T.match_buffer(c, [16, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([16], "int8", scope="local") + MultiB = T.allocate([8], "uint8", scope="local") + Accum = T.allocate([4], "int32", scope="local") + for i in range(4): + Accum[i] = T.int32(0) + + for mma_multi_a_col in range(16): + MultiA[mma_multi_a_col] = A[ + (tx % 32) // 4 + mma_multi_a_col % 8 // 4 * 8, + (tx % 32) % 4 * 4 + mma_multi_a_col % 4 + mma_multi_a_col // 8 * 16, + ] + for mma_multi_b_col in range(8): + MultiB[mma_multi_b_col] = B[ + (tx % 32) // 4, + (tx % 32) % 4 * 4 + mma_multi_b_col % 4 + mma_multi_b_col // 4 * 16, + ] + T.evaluate( + T.ptx_mma( + "m16n8k32", + "row", + "col", + "int8", + "uint8", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(4): + C[ + (tx % 32) // 4 + mma_accum_c_id // 2 * 8, + (tx % 32) % 4 * 2 + mma_accum_c_id % 2, + ] = T.load("int32", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m16n8k32_row_col_s8u8s32(): + sch = tvm.tir.Schedule(gemm_mma_m16n8k32_row_col_s8u8s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + cuda_mod = tvm.build(sch.mod, target="cuda") + + A_np = np.random.uniform(-10, 10, [16, 32]).astype("int8") + B_np = np.random.uniform(-10, 10, [8, 32]).astype("uint8") + C_np = np.zeros([16, 8]).astype("int32") + + ctx = tvm.cuda() + A_tvm = tvm.nd.array(A_np, ctx) + B_tvm = tvm.nd.array(B_np, ctx) + C_tvm = tvm.nd.array(C_np, ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + + golden = np.matmul(A_np.astype("int32"), B_np.astype("int32").T) + + C_numpy = C_tvm.numpy() + + tvm.testing.assert_allclose(golden, C_numpy, atol=1e-3, rtol=1e-3) + + +@T.prim_func +def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 64], dtype="int4") + B = T.match_buffer(b, [8, 64], dtype="int4") + C = T.match_buffer(c, [16, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([32], "int4", scope="local") + MultiB = T.allocate([16], "int4", scope="local") + Accum = T.allocate([4], "int32", scope="local") + for i in range(4): + Accum[i] = T.int32(0) + + for mma_multi_a_col in range(32): + MultiA[mma_multi_a_col] = A[ + (tx % 32) // 4 + mma_multi_a_col % 16 // 8 * 8, + (tx % 32) % 4 * 8 + mma_multi_a_col % 8 + mma_multi_a_col // 16 * 32, + ] + for mma_multi_b_col in range(16): + MultiB[mma_multi_b_col] = B[ + (tx % 32) // 4, + (tx % 32) % 4 * 8 + mma_multi_b_col % 8 + mma_multi_b_col // 8 * 32, + ] + T.evaluate( + T.ptx_mma( + "m8n8k32", + "row", + "col", + "int4", + "int4", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(4): + C[ + (tx % 32) // 4 + mma_accum_c_id // 2 * 8, + (tx % 32) % 4 * 2 + mma_accum_c_id % 2, + ] = T.load("int32", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m16n8k64_row_col_s4s4s32(): + sch = tvm.tir.Schedule(gemm_mma_m16n8k64_row_col_s4s4s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + cuda_mod = tvm.build(sch.mod, target="cuda") + + ctx = tvm.cuda() + A_tvm = tvm.nd.empty([16, 64], "int4", ctx) + B_tvm = tvm.nd.empty([8, 64], "int4", ctx) + C_tvm = tvm.nd.empty([16, 8], "int32", ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + # Currently the correctness is not checked. + # TODO: add correctness checking here. + + +@T.prim_func +def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 64], dtype="int4") + B = T.match_buffer(b, [8, 64], dtype="uint4") + C = T.match_buffer(c, [16, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([32], "int4", scope="local") + MultiB = T.allocate([16], "uint4", scope="local") + Accum = T.allocate([4], "int32", scope="local") + for i in range(4): + Accum[i] = T.int32(0) + + for mma_multi_a_col in range(32): + MultiA[mma_multi_a_col] = A[ + (tx % 32) // 4 + mma_multi_a_col % 16 // 8 * 8, + (tx % 32) % 4 * 8 + mma_multi_a_col % 8 + mma_multi_a_col // 16 * 32, + ] + for mma_multi_b_col in range(16): + MultiB[mma_multi_b_col] = B[ + (tx % 32) // 4, + (tx % 32) % 4 * 8 + mma_multi_b_col % 8 + mma_multi_b_col // 8 * 32, + ] + T.evaluate( + T.ptx_mma( + "m8n8k32", + "row", + "col", + "int4", + "uint4", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(4): + C[ + (tx % 32) // 4 + mma_accum_c_id // 2 * 8, + (tx % 32) % 4 * 2 + mma_accum_c_id % 2, + ] = T.load("int32", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m16n8k64_row_col_s4u4s32(): + sch = tvm.tir.Schedule(gemm_mma_m16n8k64_row_col_s4u4s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + cuda_mod = tvm.build(sch.mod, target="cuda") + + ctx = tvm.cuda() + A_tvm = tvm.nd.empty([16, 64], "int4", ctx) + B_tvm = tvm.nd.empty([8, 64], "uint4", ctx) + C_tvm = tvm.nd.empty([16, 8], "int32", ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + # Currently the correctness is not checked. + # TODO: add correctness checking here. + + +@T.prim_func +def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle): + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + A = T.match_buffer(a, [16, 256], dtype="int1") + B = T.match_buffer(b, [8, 256], dtype="int1") + C = T.match_buffer(c, [16, 8], dtype="int32") + brow = T.env_thread("blockIdx.y") + bcol = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(brow, 1) + T.launch_thread(bcol, 1) + T.launch_thread(tx, 32) + MultiA = T.allocate([128], "int1", scope="local") + MultiB = T.allocate([64], "int1", scope="local") + Accum = T.allocate([4], "int32", scope="local") + for i in range(4): + Accum[i] = T.int32(0) + + for mma_multi_a_col in range(128): + MultiA[mma_multi_a_col] = A[ + (tx % 32) // 4 + mma_multi_a_col % 64 // 32 * 8, + (tx % 32) % 4 * 32 + mma_multi_a_col % 32 + mma_multi_a_col // 64 * 128, + ] + for mma_multi_b_col in range(16): + MultiB[mma_multi_b_col] = B[ + (tx % 32) // 4, + (tx % 32) % 4 * 32 + mma_multi_b_col % 32 + mma_multi_b_col // 32 * 128, + ] + T.evaluate( + T.ptx_mma( + "m16n8k256", + "row", + "col", + "int1", + "int1", + "int32", + MultiA, + 0, + MultiB, + 0, + Accum, + 0, + False, + dtype="int32", + ) + ) + for mma_accum_c_id in range(4): + C[ + (tx % 32) // 4 + mma_accum_c_id // 2 * 8, + (tx % 32) % 4 * 2 + mma_accum_c_id % 2, + ] = T.load("int32", Accum, mma_accum_c_id) + + +@tvm.testing.requires_cuda +def test_gemm_mma_m16n8k256_row_col_b1b1s32(): + sch = tvm.tir.Schedule(gemm_mma_m16n8k256_row_col_b1b1s32) + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # Require at least SM80 + return + cuda_mod = tvm.build(sch.mod, target="cuda") + cuda_mod = tvm.build(sch.mod, target="cuda") + + ctx = tvm.cuda() + A_tvm = tvm.nd.empty([16, 256], "int1", ctx) + B_tvm = tvm.nd.empty([8, 256], "int1", ctx) + C_tvm = tvm.nd.empty([16, 8], "int32", ctx) + + cuda_mod(A_tvm, B_tvm, C_tvm) + # Currently the correctness is not checked. + # TODO: add correctness checking here. + + +if __name__ == "__main__": + test_gemm_mma_m8n8k4_row_col_fp64pf64fp64() + test_gemm_mma_m8n8k4_row_row_fp16fp16fp16() + test_gemm_mma_m8n8k4_row_row_fp16fp16fp32() + test_gemm_mma_m8n8k16_row_col_s8s8s32() + test_gemm_mma_m8n8k16_row_col_s8u8s32() + test_gemm_mma_m8n8k32_row_col_s4s4s32() + test_gemm_mma_m8n8k32_row_col_s4u4s32() + test_gemm_mma_m16n8k8_row_col_fp16fp16fp32() + test_gemm_mma_m16n8k16_row_col_fp16fp16fp16() + test_gemm_mma_m16n8k16_row_col_fp16fp16fp32() + test_gemm_mma_m16n8k16_row_col_s8s8s32() + test_gemm_mma_m16n8k16_row_col_s8u8s32() + test_gemm_mma_m16n8k32_row_col_s8s8s32() + test_gemm_mma_m16n8k32_row_col_s8u8s32() + test_gemm_mma_m16n8k64_row_col_s4s4s32() + test_gemm_mma_m16n8k64_row_col_s4u4s32() + test_gemm_mma_m16n8k256_row_col_b1b1s32()