refactor string-based scope to runtime::StorageScope #8

Merged · 2 commits · Sep 21, 2022
82 changes: 47 additions & 35 deletions src/target/source/codegen_cuda.cc
@@ -570,12 +570,21 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
 }
 
 void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) {  // NOLINT(*)
-  ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass "
-                                "all global arrays as input instead";
-  if (scope == "shared") {
-    os << "__shared__ ";
-  } else if (scope == "shared.dyn") {
-    os << "extern __shared__ ";
-  }
+  PrintStorageScope(runtime::StorageScope::Create(scope), os);
 }
 
+void CodeGenCUDA::PrintStorageScope(const runtime::StorageScope& scope, std::ostream& os) {
+  ICHECK(scope.rank != runtime::StorageRank::kGlobal)
+      << "Cannot allocate global memory when targeting CUDA. You must pass "
+         "all global arrays as input instead";
+  if (scope.rank == runtime::StorageRank::kShared) {
+    if (scope.tag == "") {
+      os << "__shared__ ";
+    } else if (scope.tag == ".dyn") {
+      os << "extern __shared__ ";
+    } else {
+      LOG(FATAL) << "Unknown shared memory scope tag: " << scope.tag;
+    }
+  }
+}
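Note: the refactor leans on the (rank, tag) decomposition that runtime::StorageScope::Create performs on scope strings. The following is a minimal illustrative sketch of that decomposition, not the actual TVM implementation in ../../runtime/thread_storage_scope.h (which covers more ranks and validates its input); all names here are hypothetical stand-ins:

#include <cassert>
#include <string>

// Hypothetical stand-ins for runtime::StorageRank / runtime::StorageScope.
enum class StorageRank { kGlobal, kShared, kWMMAMatrixA, kWMMAMatrixB, kWMMAAccumulator };

struct Scope {
  StorageRank rank;
  std::string tag;  // suffix after the rank prefix, e.g. ".dyn"
};

// Sketch of the string -> (rank, tag) split assumed by the code above.
Scope Create(const std::string& s) {
  if (s.rfind("shared", 0) == 0) return {StorageRank::kShared, s.substr(6)};
  if (s == "wmma.matrix_a") return {StorageRank::kWMMAMatrixA, ""};
  if (s == "wmma.matrix_b") return {StorageRank::kWMMAMatrixB, ""};
  if (s == "wmma.accumulator") return {StorageRank::kWMMAAccumulator, ""};
  return {StorageRank::kGlobal, ""};
}

int main() {
  assert(Create("shared").tag == "");          // printed as "__shared__ "
  assert(Create("shared.dyn").tag == ".dyn");  // printed as "extern __shared__ "
  return 0;
}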

@@ -957,39 +966,42 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
   std::string vid = AllocVarID(op->buffer_var.get());
 
   this->PrintIndent();
-  std::string scope = GetPtrStorageScope(op->buffer_var);
+  runtime::StorageScope scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
 
   const VarNode* buffer = op->buffer_var.as<VarNode>();
-  if (scope.find("wmma.") == 0) {
-    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
-      ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
-             op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
-             op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
-             op->dtype == DataType::BFloat(16))
-          << "Matrix_a and matrix_b only support half or char or unsigned char "
-          << "or uint4 or int4 or int1 type for now";
-    } else {
-      ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
-             op->dtype == DataType::Int(32))
-          << "Accumulator only support half, float and int type for now";
-    }
+  if (scope.rank == runtime::StorageRank::kWMMAMatrixA ||
+      scope.rank == runtime::StorageRank::kWMMAMatrixB) {
+    ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
+           op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
+           op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
+           op->dtype == DataType::BFloat(16))
+        << "Matrix_a and matrix_b only support half or char or unsigned char "
+        << "or uint4 or int4 or int1 type for now";
     PrintWmmaScope(scope, op->dtype, buffer, stream);
+  } else if (scope.rank == runtime::StorageRank::kWMMAAccumulator) {
+    ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
+           op->dtype == DataType::Int(32))
+        << "Accumulator only support half, float and int type for now";
+    PrintWmmaScope(scope, op->dtype, buffer, stream);
   } else {
     PrintStorageScope(scope, stream);
     PrintType(op->dtype, stream);
   }
 
-  if (scope == "shared.dyn") {
+  if (scope.rank == runtime::StorageRank::kShared && scope.tag == ".dyn") {
     stream << ' ' << vid << "[];\n";
   } else {
     size_t constant_size = op->ConstantAllocationSize();
     ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
 
-    if (scope.find("wmma.") == 0) {
+    if (scope.rank == runtime::StorageRank::kWMMAMatrixA ||
+        scope.rank == runtime::StorageRank::kWMMAMatrixB ||
+        scope.rank == runtime::StorageRank::kWMMAAccumulator) {
       constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
     }
     if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
          op->dtype == DataType::Int(1)) &&
-        scope == "shared") {
+        scope.rank == runtime::StorageRank::kShared && scope.tag == "") {
      constant_size = constant_size / (32 / op->dtype.bits());
    }
    stream << ' ' << vid << '[' << constant_size << "];\n";
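Note: for context, the allocation paths above emit CUDA along these lines (a hand-written illustration, not output copied from this PR; the buffer names and the assumption that a packed sub-byte buffer prints as int are mine):

// scope "shared.dyn" (rank kShared, tag ".dyn"): the extent is omitted because
// the size comes from the kernel launch's dynamic shared memory argument.
extern __shared__ half x_shared[];

// scope "shared" (rank kShared, tag ""): the constant extent is emitted inline.
__shared__ half y_shared[1024];

// sub-byte dtypes in plain "shared" scope are packed into 32-bit words, so a
// 2048-element int4 allocation shrinks to 2048 / (32 / 4) = 256 storage slots.
__shared__ int z_shared[256];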
Expand Down Expand Up @@ -1225,8 +1237,8 @@ void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOL
PrintConst(op, os, this);
}

void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
std::ostream& os) {
void CodeGenCUDA::PrintWmmaScope(const runtime::StorageScope& scope, DataType t,
const VarNode* variable, std::ostream& os) {
std::stringstream type;
PrintType(t, type);
std::string shape_str = fragment_shapes.at(variable);
@@ -1238,29 +1250,29 @@ void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
       } else if (t.bits() == 1) {
         type << "nvcuda::wmma::experimental::precision::b1";
       } else {
-        LOG(FATAL) << "Unhandled interger type for wmma fragment!";
+        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
       }
     } else if (t.is_uint()) {
       if (t.bits() == 4) {
         type << "nvcuda::wmma::experimental::precision::u4";
       } else {
-        LOG(FATAL) << "Unhandled interger type for wmma fragment!";
+        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
       }
     }
   }
-  if (scope == "wmma.matrix_a") {
+  if (scope.rank == runtime::StorageRank::kWMMAMatrixA) {
     need_mma_h_ = true;
     std::string layout_str = fragment_layouts[variable];
     ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
     os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str()
        << ", nvcuda::wmma::" << layout_str << ">";
-  } else if (scope == "wmma.matrix_b") {
+  } else if (scope.rank == runtime::StorageRank::kWMMAMatrixB) {
     need_mma_h_ = true;
     std::string layout_str = fragment_layouts[variable];
     ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
     os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str()
        << ", nvcuda::wmma::" << layout_str << ">";
-  } else if (scope == "wmma.accumulator") {
+  } else if (scope.rank == runtime::StorageRank::kWMMAAccumulator) {
     need_mma_h_ = true;
     os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str()
        << ">";
@@ -1276,8 +1288,8 @@ int stoi(const std::string& str) {
   }
 }
 
-int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable,
-                                         int32_t size) {
+int32_t CodeGenCUDA::GetWmmaFragmentSize(const runtime::StorageScope& scope,
+                                         const VarNode* variable, int32_t size) {
   std::string shape_str = fragment_shapes.at(variable);
   size_t m, n, k;
   size_t last_pos = 0, pos = 0;
@@ -1288,11 +1300,11 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable,
   n = tvm::codegen::stoi(shape_str.substr(last_pos, pos - last_pos));
   last_pos = pos + 2;
   k = tvm::codegen::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
-  if (scope == "wmma.matrix_a") {
+  if (scope.rank == runtime::StorageRank::kWMMAMatrixA) {
     return size / m / k;
-  } else if (scope == "wmma.matrix_b") {
+  } else if (scope.rank == runtime::StorageRank::kWMMAMatrixB) {
     return size / n / k;
-  } else if (scope == "wmma.accumulator") {
+  } else if (scope.rank == runtime::StorageRank::kWMMAAccumulator) {
     return size / m / n;
   }
   return 0;
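Note: a worked example of the division above, assuming the same "16, 16, 16" shape string: a wmma.matrix_a allocation of 512 elements yields 512 / 16 / 16 = 2 fragments, which is exactly the [2] extent in the earlier illustration. For matrix_b the divisor is n * k, and for the accumulator it is m * n, matching each fragment's logical dimensions.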
9 changes: 6 additions & 3 deletions src/target/source/codegen_cuda.h
@@ -31,6 +31,7 @@
 #include <string>
 #include <unordered_map>
 
+#include "../../runtime/thread_storage_scope.h"
 #include "codegen_c.h"
 
 namespace tvm {
@@ -49,7 +50,8 @@ class CodeGenCUDA final : public CodeGenC {
   void PrintExtraAttrs(const PrimFunc& f) final;
   void VisitStmt_(const ForNode* op) final;
   void PrintStorageSync(const CallNode* op) final;
-  void PrintStorageScope(const std::string& scope, std::ostream& os) final;  // NOLINT(*)
+  void PrintStorageScope(const std::string& scope, std::ostream& os) final;      // NOLINT(*)
+  void PrintStorageScope(const runtime::StorageScope& scope, std::ostream& os);  // NOLINT(*)
   void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                         std::ostream& os) final;  // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final;  // NOLINT(*)
@@ -106,9 +108,10 @@ class CodeGenCUDA final : public CodeGenC {
   std::unordered_map<const VarNode*, std::string> fragment_shapes;
   std::unordered_map<const VarNode*, std::string> fragment_layouts;
   friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
-  void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
+  void PrintWmmaScope(const runtime::StorageScope& scope, DataType t, const VarNode* variable,
                       std::ostream& os);
-  int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size);
+  int32_t GetWmmaFragmentSize(const runtime::StorageScope& scope, const VarNode* variable,
+                              int32_t size);
 };
 
 }  // namespace codegen