diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 52fda1a23eef..5c9aa266a91f 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -570,12 +570,21 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { } void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) - ICHECK_NE(scope, "global") << "Cannot allocate global memory when targeting CUDA. You must pass " - "all global arrays as input instead"; - if (scope == "shared") { - os << "__shared__ "; - } else if (scope == "shared.dyn") { - os << "extern __shared__ "; + PrintStorageScope(runtime::StorageScope::Create(scope), os); +} + +void CodeGenCUDA::PrintStorageScope(const runtime::StorageScope& scope, std::ostream& os) { + ICHECK(scope.rank != runtime::StorageRank::kGlobal) + << "Cannot allocate global memory when targeting CUDA. You must pass " + "all global arrays as input instead"; + if (scope.rank == runtime::StorageRank::kShared) { + if (scope.tag == "") { + os << "__shared__ "; + } else if (scope.tag == ".dyn") { + os << "extern __shared__ "; + } else { + LOG(FATAL) << "Unknown shared memory scope tag: " << scope.tag; + } } } @@ -957,39 +966,42 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { std::string vid = AllocVarID(op->buffer_var.get()); this->PrintIndent(); - std::string scope = GetPtrStorageScope(op->buffer_var); + runtime::StorageScope scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + const VarNode* buffer = op->buffer_var.as(); - if (scope.find("wmma.") == 0) { - if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || - op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) || - op->dtype == DataType::BFloat(16)) - << "Matrix_a and matrix_b only support half or char or unsigned char " - << "or uint4 or int4 or int1 type for now"; - } else { - ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || - op->dtype == DataType::Int(32)) - << "Accumulator only support half, float and int type for now"; - } + if (scope.rank == runtime::StorageRank::kWMMAMatrixA || + scope.rank == runtime::StorageRank::kWMMAMatrixB) { + ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || + op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || + op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) || + op->dtype == DataType::BFloat(16)) + << "Matrix_a and matrix_b only support half or char or unsigned char " + << "or uint4 or int4 or int1 type for now"; + PrintWmmaScope(scope, op->dtype, buffer, stream); + } else if (scope.rank == runtime::StorageRank::kWMMAAccumulator) { + ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || + op->dtype == DataType::Int(32)) + << "Accumulator only support half, float and int type for now"; PrintWmmaScope(scope, op->dtype, buffer, stream); } else { PrintStorageScope(scope, stream); PrintType(op->dtype, stream); } - if (scope == "shared.dyn") { + if (scope.rank == runtime::StorageRank::kShared && scope.tag == ".dyn") { stream << ' ' << vid << "[];\n"; } else { size_t constant_size = op->ConstantAllocationSize(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - if (scope.find("wmma.") == 0) { + if (scope.rank == runtime::StorageRank::kWMMAMatrixA || + scope.rank == runtime::StorageRank::kWMMAMatrixB || + scope.rank == runtime::StorageRank::kWMMAAccumulator) { constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); } if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1)) && - scope == "shared") { + scope.rank == runtime::StorageRank::kShared && scope.tag == "") { constant_size = constant_size / (32 / op->dtype.bits()); } stream << ' ' << vid << '[' << constant_size << "];\n"; @@ -1225,8 +1237,8 @@ void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOL PrintConst(op, os, this); } -void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, - std::ostream& os) { +void CodeGenCUDA::PrintWmmaScope(const runtime::StorageScope& scope, DataType t, + const VarNode* variable, std::ostream& os) { std::stringstream type; PrintType(t, type); std::string shape_str = fragment_shapes.at(variable); @@ -1238,29 +1250,29 @@ void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const Var } else if (t.bits() == 1) { type << "nvcuda::wmma::experimental::precision::b1"; } else { - LOG(FATAL) << "Unhandled interger type for wmma fragment!"; + LOG(FATAL) << "Unhandled integer type for wmma fragment!"; } } else if (t.is_uint()) { if (t.bits() == 4) { type << "nvcuda::wmma::experimental::precision::u4"; } else { - LOG(FATAL) << "Unhandled interger type for wmma fragment!"; + LOG(FATAL) << "Unhandled integer type for wmma fragment!"; } } } - if (scope == "wmma.matrix_a") { + if (scope.rank == runtime::StorageRank::kWMMAMatrixA) { need_mma_h_ = true; std::string layout_str = fragment_layouts[variable]; ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a"; os << "nvcuda::wmma::fragment"; - } else if (scope == "wmma.matrix_b") { + } else if (scope.rank == runtime::StorageRank::kWMMAMatrixB) { need_mma_h_ = true; std::string layout_str = fragment_layouts[variable]; ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b"; os << "nvcuda::wmma::fragment"; - } else if (scope == "wmma.accumulator") { + } else if (scope.rank == runtime::StorageRank::kWMMAAccumulator) { need_mma_h_ = true; os << "nvcuda::wmma::fragment"; @@ -1276,8 +1288,8 @@ int stoi(const std::string& str) { } } -int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, - int32_t size) { +int32_t CodeGenCUDA::GetWmmaFragmentSize(const runtime::StorageScope& scope, + const VarNode* variable, int32_t size) { std::string shape_str = fragment_shapes.at(variable); size_t m, n, k; size_t last_pos = 0, pos = 0; @@ -1288,11 +1300,11 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode n = tvm::codegen::stoi(shape_str.substr(last_pos, pos - last_pos)); last_pos = pos + 2; k = tvm::codegen::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos)); - if (scope == "wmma.matrix_a") { + if (scope.rank == runtime::StorageRank::kWMMAMatrixA) { return size / m / k; - } else if (scope == "wmma.matrix_b") { + } else if (scope.rank == runtime::StorageRank::kWMMAMatrixB) { return size / n / k; - } else if (scope == "wmma.accumulator") { + } else if (scope.rank == runtime::StorageRank::kWMMAAccumulator) { return size / m / n; } return 0; diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 673753c470ae..a938745ccbea 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -31,6 +31,7 @@ #include #include +#include "../../runtime/thread_storage_scope.h" #include "codegen_c.h" namespace tvm { @@ -49,7 +50,8 @@ class CodeGenCUDA final : public CodeGenC { void PrintExtraAttrs(const PrimFunc& f) final; void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; - void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintStorageScope(const runtime::StorageScope& scope, std::ostream& os); // NOLINT(*) void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) @@ -106,9 +108,10 @@ class CodeGenCUDA final : public CodeGenC { std::unordered_map fragment_shapes; std::unordered_map fragment_layouts; friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p); - void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, + void PrintWmmaScope(const runtime::StorageScope& scope, DataType t, const VarNode* variable, std::ostream& os); - int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size); + int32_t GetWmmaFragmentSize(const runtime::StorageScope& scope, const VarNode* variable, + int32_t size); }; } // namespace codegen