Skip to content

Commit

Permalink
[PTX-MMA] Add full PTX MMA code generation support (#9909)
Browse files Browse the repository at this point in the history
  • Loading branch information
KnowingNothing authored Jan 24, 2022
1 parent 74a2fa8 commit d066441
Show file tree
Hide file tree
Showing 6 changed files with 2,822 additions and 0 deletions.
11 changes: 11 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,17 @@ TVM_DLL const Op& tvm_fill_fragment();
*/
TVM_DLL const Op& tvm_store_matrix_sync();

/*!
* \brief tvm intrinsic for ptx tensor core mma instructions.
*
* void ptx_mma(StringImm shape, StringImm A_layout, StringImm B_layout,
* StringImm A_dtype, StringImm B_dtype, StringImm C_dtype,
* Var multiplicand_a, Expr a_index,
* Var multiplicand_b, Expr b_index,
* Var accumulator, Expr c_index, bool saturate);
*/
TVM_DLL const Op& ptx_mma();

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
Expand Down
33 changes: 33 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <vector>

#include "literal/cuda_half_t.h"
#include "ptx_mma.h"

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -723,6 +724,38 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->op.same_as(builtin::ptx_mma())) {
// arg 0: shape: mXnXkX
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: fp16, fp64, ...
// arg 4: B precision: fp16, fp64, ...
// arg 5: C precision: fp32, fp64, ...
// arg 6: A multiplicand
// arg 7: A multiplicand index
// arg 8: B multiplicand
// arg 9: B multiplicand index
// arg 10: C accumulator
// arg 11: C accumulator index
// arg 12: saturate
ICHECK_EQ(op->args.size(), 13U);
std::string shape = Downcast<StringImm>(op->args[0])->value;
std::string A_layout = Downcast<StringImm>(op->args[1])->value;
std::string B_layout = Downcast<StringImm>(op->args[2])->value;
std::string A_dtype = Downcast<StringImm>(op->args[3])->value;
std::string B_dtype = Downcast<StringImm>(op->args[4])->value;
std::string C_dtype = Downcast<StringImm>(op->args[5])->value;
std::string a_ref = this->PrintExpr(op->args[6]);
std::string a_bias = this->PrintExpr(op->args[7]);
std::string b_ref = this->PrintExpr(op->args[8]);
std::string b_bias = this->PrintExpr(op->args[9]);
std::string c_ref = this->PrintExpr(op->args[10]);
std::string c_bias = this->PrintExpr(op->args[11]);
bool saturate = (Downcast<IntImm>(op->args[12])->value != 0);
std::string asm_code = PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype,
a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate);

this->stream << asm_code;
} else {
CodeGenC::VisitExpr_(op, os);
}
Expand Down
Loading

0 comments on commit d066441

Please sign in to comment.