Skip to content

Commit

Permalink
[CODEGEN] Support cuda tensorcore subbyte int data type in auto tensorcore (apache#4546)
Browse files Browse the repository at this point in the history

* support cuda tensorcore subbyte int data type in auto tensorcore

* add license

* pass cpplint

* fix code review comments

* merge the int4/int1 codegen tutorial into the existing auto tensorcore tutorial

* using master's new API

* disable tuning when cuda is not enabled

* address cr comment

* do not run the tuning

* fix test failure

* fix cpplint error

* fix bool type reduction bug

* 1. fix an index bug 2. fix returned bytes value of int1/int4/uint4

* fix typo
  • Loading branch information
Orion34-lanbo authored and alexwong committed Feb 26, 2020
1 parent 744e6eb commit 0125add
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 24 deletions.
7 changes: 6 additions & 1 deletion include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,12 @@ class DataType {
inline int GetVectorBytes(DataType dtype) {
int data_bits = dtype.bits() * dtype.lanes();
// allow bool to exist
if (dtype == DataType::Bool()) return 1;
if (dtype == DataType::Bool() ||
dtype == DataType::Int(4) ||
dtype == DataType::UInt(4) ||
dtype == DataType::Int(1)) {
return 1;
}
CHECK_EQ(data_bits % 8, 0U)
<< "Need to load/store by multiple of bytes";
return data_bits / 8;
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,18 @@ constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
* }
*/
constexpr const char* tvm_mma_sync = "tvm_mma_sync";
/*!
* \brief tvm intrinsic for tensor core bmma_sync operators.
*
* void tvm_bmma_sync(Var fragment_d, Expr index_d,
* Var fragment_a, Expr index_a,
* Var fragment_b, Expr index_b,
* Var fragment_c, Expr index_c) {
* nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a],
* fragment_b[index_b], fragment_c[index_c]);
* }
*/
constexpr const char* tvm_bmma_sync = "tvm_bmma_sync";
/*!
* \brief tvm intrinsic for tensor core fill_fragment operators.
*
Expand Down
7 changes: 6 additions & 1 deletion src/runtime/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ inline void VerifyDataType(DLDataType dtype) {
} else {
// allow uint1 as a special flag for bool.
if (dtype.bits == 1 && dtype.code == kDLUInt) return;
CHECK_EQ(dtype.bits % 8, 0);
// allow int1/uint4/int4
else if (dtype.bits == 1 && dtype.code == kDLInt) return;
else if (dtype.bits == 4 && dtype.code == kDLUInt) return;
else if (dtype.bits == 4 && dtype.code == kDLInt) return;
else
CHECK_EQ(dtype.bits % 8, 0);
}
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
}
Expand Down
14 changes: 12 additions & 2 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,13 @@ std::string CodeGenC::GetBufferRef(
} else {
os << vid;
}
os << '[';
os << "[(";
PrintExpr(index, os);
os << ")";
if (t.bits() == 4 ||
(t.bits() == 1 && t.is_int())) {
os << " / " << (32 / t.bits());
}
os << ']';
} else {
// Buffer declared as vector type.
Expand Down Expand Up @@ -205,8 +210,13 @@ std::string CodeGenC::GetBufferRef(
PrintType(t.element_of(), os);
os << "*)";
}
os << vid << " + ";
os << vid << " + (";
PrintExpr(index, os);
os << ")";
if (t.bits() == 4 ||
(t.bits() == 1 && t.is_int())) {
os << " / " << (32 / t.bits());
}
os << "))[0]";
}
return os.str();
Expand Down
73 changes: 70 additions & 3 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,37 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}
}
switch (t.bits()) {
case 1: {
if (t.lanes() == 1) {
os << "int"; return;
} else if (t.lanes() == 8) {
os << "int8_t"; return;
} else if (t.lanes() == 16) {
os << "int16_t"; return;
} else if (t.lanes() == 32) {
os << "int"; return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 4: {
if (t.lanes() == 1) {
os << "int"; return;
} else if (t.lanes() == 4) {
os << "int16_t"; return;
} else if (t.lanes() == 8) {
// directly 8 4-bit int in integer.
os << "int"; return;
} else if (t.lanes() == 16) {
os << "int2"; return;
} else if (t.lanes() == 32) {
os << "int4"; return;
} else if (t.lanes() == 64) {
os << "int8"; return;
} else {
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
}
}
case 8: {
if (t.lanes() == 4) {
// directly 4 8 bit int in integer.
Expand Down Expand Up @@ -182,7 +213,6 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
os << "long"; break;
}
}
case 1: os << "int"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) {
Expand Down Expand Up @@ -371,6 +401,16 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", ": ")");
}
} else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 8U);
os << "nvcuda::wmma::bmma_sync(";
for (int i = 0; i < 4; ++i) {
this->PrintExpr(op->args[i * 2], os);
os << "[";
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", ": ")");
}
} else {
CodeGenC::VisitExpr_(op, os);
}
Expand Down Expand Up @@ -410,8 +450,12 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
CHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Int(8) ||
op->dtype == DataType::UInt(8))
<< "Matrix_a and matrix_b only support half or char or unsigned char type for now";
op->dtype == DataType::UInt(8) ||
op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1))
<< "Matrix_a and matrix_b only support half or char or unsigned char "
<< "or uint4 or int4 or int1 type for now";
} else {
CHECK(op->dtype == DataType::Float(16) ||
op->dtype == DataType::Float(32) ||
Expand All @@ -425,6 +469,11 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
stream << ' ';
PrintType(op->dtype, stream);
}
if ((op->dtype == DataType::Int(4) ||
op->dtype == DataType::UInt(4) ||
op->dtype == DataType::Int(1)) && scope == "shared") {
constant_size = constant_size / (32 / op->dtype.bits());
}
stream << ' '<< vid << '['
<< constant_size << "];\n";
}
Expand Down Expand Up @@ -552,6 +601,24 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t,
std::stringstream type;
PrintType(t, type);
std::string shape_str = fragment_shapes[variable];
if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
type.str(std::string());
if (t.is_int()) {
if (t.bits() == 4) {
type << "nvcuda::wmma::experimental::precision::s4";
} else if (t.bits() == 1) {
type << "nvcuda::wmma::experimental::precision::b1";
} else {
LOG(FATAL) << "Unhandled interger type for wmma fragment!";
}
} else if (t.is_uint()) {
if (t.bits() == 4) {
type << "nvcuda::wmma::experimental::precision::u4";
} else {
LOG(FATAL) << "Unhandled interger type for wmma fragment!";
}
}
}
if (scope == "wmma.matrix_a") {
need_mma_h_ = true;
std::string layout_str = fragment_layouts[variable];
Expand Down
11 changes: 10 additions & 1 deletion src/tir/pass/arg_binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
IntImm(DataType::UInt(8), dtype.bits()) &&
TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) ==
IntImm(DataType::UInt(16), dtype.lanes()));
asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
if (!(dtype == DataType::Int(4) ||
dtype == DataType::UInt(4) ||
dtype == DataType::Int(1))) {
asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
}
// data field
if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
arg_name + ".data", true)) {
Expand All @@ -201,6 +205,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
init_nest_.emplace_back(LetStmtNode::make(
v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
if (dtype == DataType::Int(4) ||
dtype == DataType::UInt(4) ||
dtype == DataType::Int(1)) {
break;
}
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
Bind_(buffer->shape[k],
Expand Down
3 changes: 2 additions & 1 deletion src/tir/pass/infer_fragment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ class FragmentChecker : public StmtExprVisitor {
void VisitExpr_(const CallNode* op) final {
StmtExprVisitor::VisitExpr_(op);
// Check shape when calling tvm_mma_sync
if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
if (op->is_intrinsic(intrinsic::tvm_mma_sync) ||
op->is_intrinsic(intrinsic::tvm_bmma_sync)) {
CHECK_EQ(op->args.size(), 8U);
const VarNode* buffer_var_d = op->args[0].as<VarNode>();
const VarNode* buffer_var_a = op->args[2].as<VarNode>();
Expand Down
52 changes: 41 additions & 11 deletions src/tir/pass/tensor_core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,11 @@ class MMAMatcher: public StmtVisitor {
BufferInfo buffer_a;
if (!check_local_buffer_(load_a, &buffer_a)
|| !(buffer_a.dtype == DataType::Float(16) ||
buffer_a.dtype == DataType::Int(8))) {
buffer_a.dtype == DataType::Int(8) ||
buffer_a.dtype == DataType::UInt(8) ||
buffer_a.dtype == DataType::Int(4) ||
buffer_a.dtype == DataType::UInt(4) ||
buffer_a.dtype == DataType::Int(1))) {
return false;
}

Expand All @@ -208,7 +212,11 @@ class MMAMatcher: public StmtVisitor {
BufferInfo buffer_b;
if (!check_local_buffer_(load_b, &buffer_b)
|| !(buffer_b.dtype == DataType::Float(16) ||
buffer_b.dtype == DataType::Int(8))) {
buffer_b.dtype == DataType::Int(8) ||
buffer_b.dtype == DataType::UInt(8) ||
buffer_b.dtype == DataType::Int(4) ||
buffer_a.dtype == DataType::UInt(4) ||
buffer_a.dtype == DataType::Int(1))) {
return false;
}

Expand Down Expand Up @@ -736,6 +744,17 @@ class BufferAnalyser : public StmtExprVisitor {
warp_tile_.k == 16) {
return true;
}
if (warp_tile_.m == 8 &&
warp_tile_.n == 8 &&
warp_tile_.k == 32) {
return true;
}
if (warp_tile_.m == 8 &&
warp_tile_.n == 8 &&
warp_tile_.k == 128) {
return true;
}

return false;
}

Expand Down Expand Up @@ -869,18 +888,29 @@ class TensorCoreIRMutator : public StmtExprMutator {
ObjectPtr<BufferNode> buffer_node_c = make_object<BufferNode>();

auto mma_sync_call =
[&buffer_node_a, &buffer_node_b]
[&buffer_node_a, &buffer_node_b, &ca, &cb]
(const Buffer &buffer) {
Buffer buffer_a(buffer_node_a);
Buffer buffer_b(buffer_node_b);
return EvaluateNode::make(
CallNode::make(DataType::Handle(),
intrinsic::tvm_mma_sync,
{buffer->data, buffer->elem_offset,
buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset,
buffer->data, buffer->elem_offset},
CallNode::Intrinsic));
if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) {
return EvaluateNode::make(
CallNode::make(DataType::Handle(),
intrinsic::tvm_bmma_sync,
{buffer->data, buffer->elem_offset,
buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset,
buffer->data, buffer->elem_offset},
CallNode::Intrinsic));
} else {
return EvaluateNode::make(
CallNode::make(DataType::Handle(),
intrinsic::tvm_mma_sync,
{buffer->data, buffer->elem_offset,
buffer_a->data, buffer_a->elem_offset,
buffer_b->data, buffer_b->elem_offset,
buffer->data, buffer->elem_offset},
CallNode::Intrinsic));
}
};

auto call_add_c =
Expand Down
Loading

0 comments on commit 0125add

Please sign in to comment.