Skip to content

Commit

Permalink
[Optimization] Warp-level reduction support for CUDA
Browse files Browse the repository at this point in the history
- Added warp-level reduction support

- Upgraded the shfl intrinsics to their _sync versions

- This is the building block for scheduling softmax-like operations.

Signed-off-by: Wei Pan <[email protected]>
  • Loading branch information
wpan11nv committed May 3, 2020
1 parent 6347406 commit e491050
Show file tree
Hide file tree
Showing 10 changed files with 357 additions and 75 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/dmlc-core
Submodule dmlc-core updated 54 files
+38 −0 .github/workflows/githubci.yml
+0 −1 .gitignore
+0 −82 .travis.yml
+51 −31 CMakeLists.txt
+201 −13 LICENSE
+1 −1 README.md
+19 −6 appveyor.yml
+13 −0 cmake/Modules/FindASan.cmake
+13 −0 cmake/Modules/FindLSan.cmake
+13 −0 cmake/Modules/FindTSan.cmake
+13 −0 cmake/Modules/FindUBSan.cmake
+63 −0 cmake/Sanitizer.cmake
+4 −1 cmake/build_config.h.in
+1 −1 cmake/gtest_cmake.in
+1 −16 doc/Doxyfile
+16 −1 include/dmlc/base.h
+4 −1 include/dmlc/build_config_default.h
+4 −0 include/dmlc/concurrency.h
+18 −18 include/dmlc/concurrentqueue.h
+3 −2 include/dmlc/json.h
+20 −3 include/dmlc/logging.h
+1 −1 include/dmlc/omp.h
+10 −0 include/dmlc/optional.h
+106 −23 include/dmlc/parameter.h
+1 −3 include/dmlc/thread_group.h
+4 −2 include/dmlc/thread_local.h
+74 −46 include/dmlc/threadediter.h
+0 −2 make/dmlc.mk
+2 −2 scripts/lint.py
+12 −19 scripts/packages.mk
+0 −0 scripts/s390x/Dockerfile
+0 −0 scripts/s390x/build_via_cmake.sh
+1 −1 scripts/s390x/ci_build.sh
+0 −0 scripts/s390x/entrypoint.sh
+0 −32 scripts/setup_nvcc.sh
+65 −0 scripts/test_script.sh
+0 −3 scripts/travis/travis_before_cache.sh
+0 −9 scripts/travis/travis_osx_install.sh
+0 −57 scripts/travis/travis_script.sh
+0 −40 scripts/travis/travis_setup_env.sh
+0 −16 src/build_config.cc
+7 −3 src/data/csv_parser.h
+1 −1 test/logging_test.cc
+4 −0 test/unittest/CMakeLists.txt
+2 −1 test/unittest/unittest_env.cc
+30 −0 test/unittest/unittest_param.cc
+80 −56 test/unittest/unittest_parser.cc
+0 −1 test/unittest/unittest_thread_group.cc
+2 −2 test/unittest/unittest_threaditer.cc
+19 −15 test/unittest/unittest_threaditer_exc_handling.cc
+4 −0 tracker/dmlc_tracker/launcher.py
+7 −0 tracker/dmlc_tracker/ssh.py
+13 −0 tracker/dmlc_tracker/util.py
+4 −2 tracker/dmlc_tracker/yarn.py
22 changes: 19 additions & 3 deletions include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1234,22 +1234,38 @@ constexpr const char *tvm_call_trace_packed_lowered =
* }
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";

/*!
 * \brief Warp-level shuffle intrinsics. See pseudo code:
 *
 * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) {
 *   return (value passed in by the lane indicated by warp_id);
 * }
 *
 * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) {
 *   return (value passed in by the lane indicated by this_warp_id - offset);
 * }
 *
 * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) {
 *   return (value passed in by the lane indicated by this_warp_id + offset);
 * }
 *
 * Parameter mask selects the participating threads; it is forwarded to
 * the CUDA __shfl_*_sync intrinsics by the CUDA backend.
 *
 * Parameter warp_id indicates the source thread ID in a warp.
 *
 * Parameter offset indicates the relative distance to this_warp_id.
 *
 * Parameter width indicates the number of threads involved in one
 * shuffle. See the CUDA documentation for __shfl_sync, __shfl_up_sync,
 * and __shfl_down_sync.
 *
 * Parameter warp_size is the size of a warp, which helps a backend
 * to determine whether the width parameter is legal.
 *
 */
constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up";
constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down";

/*!
* \brief Initialize the global barrier.
* Call this at beginning of kernel that need global barrier.
Expand Down
36 changes: 34 additions & 2 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ std::string CodeGenCUDA::Finish() {
decl_stream << _cuda_half_util;
}

if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
}

if (enable_int8_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
decl_stream << "#include <sm_61_intrinsics.h>\n";
Expand Down Expand Up @@ -395,8 +399,36 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
os << sret;
}

void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
// Translate a warp shuffle intrinsic call into the name of the
// corresponding CUDA intrinsic; return nullptr for any other call.
static const char* check_warp_shuffle(const CallNode* op) {
  return op->is_intrinsic(intrinsic::tvm_warp_shuffle)      ? "__shfl_sync"
       : op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)   ? "__shfl_up_sync"
       : op->is_intrinsic(intrinsic::tvm_warp_shuffle_down) ? "__shfl_down_sync"
       : nullptr;
}

void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
if (const char *shfl_name = check_warp_shuffle(op)) {
enable_warp_shuffle_ = true;
// mask, value, warp_id, width, warp_size
CHECK_EQ(op->args.size(), 5U);
os << shfl_name << '(';
this->PrintExpr(op->args[0], os);
os << ", ";
this->PrintExpr(op->args[1], os);
os << ", ";
this->PrintExpr(op->args[2], os);
os << ", ";
this->PrintExpr(op->args[3], os);
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 6U);
os << "nvcuda::wmma::fill_fragment(";
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class CodeGenCUDA final : public CodeGenC {
bool enable_fp16_{false};
// whether enable int8
bool enable_int8_{false};
// whether enable warp shuffle intrinsics
bool enable_warp_shuffle_{false};
// whether need math_constants.h
bool need_math_constants_h_{false};
// whether need mma.h
Expand Down
13 changes: 0 additions & 13 deletions src/target/source/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,6 @@ struct CUDAPopcount {
}
};

static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
CHECK_EQ(call->args.size(), 4); // value, warp_id, width, warp_size
Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
*rv = CallNode::make(
call->dtype, "__shfl", cuda_args, CallNode::PureExtern);
}

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
.set_body(DispatchExtern<CUDAMath>);

Expand Down Expand Up @@ -157,9 +147,6 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
.set_body(DispatchExtern<CUDAPopcount>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
.set_body(DispatchCUDAShuffle);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
.set_body(DispatchExtern<CUDAMath>);

Expand Down
10 changes: 5 additions & 5 deletions src/target/source/intrin_rule_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
CHECK_EQ(call->args.size(), 4); // value, warp_id, width, warp_size
CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
arith::Analyzer analyzer;
CHECK(analyzer.CanProve(call->args[2] == call->args[3]))
CHECK(analyzer.CanProve(call->args[3] == call->args[4]))
<< "Intel warp shuffle dose not support width != warp_size";
Array<PrimExpr> cuda_args{{call->args[0], call->args[1]}};
*rv = CallNode::make(
call->dtype, "intel_sub_group_shuffle", cuda_args, CallNode::PureExtern);
Array<PrimExpr> opencl_args{{call->args[1], call->args[2]}};
*rv = CallNode::make(call->dtype, "intel_sub_group_shuffle",
opencl_args, CallNode::PureExtern);
}

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
Expand Down
14 changes: 14 additions & 0 deletions src/target/source/literal/cuda_half_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,18 @@ __pack_half2(const half x, const half y) {
}
)";

// Preamble emitted into generated CUDA source when warp shuffle
// intrinsics are used: below sm_70 the *_sync shuffle variants are
// mapped back to the legacy __shfl family, discarding the mask
// argument.
// NOTE(review): the guard keys on __CUDA_ARCH__ < 700 — presumably to
// cover toolchains/targets where the *_sync intrinsics are unavailable;
// confirm this matches the minimum supported CUDA toolkit version.
static constexpr const char* _cuda_warp_intrinsic_util = R"(
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
#define __shfl_sync(mask, var, lane, width) \
__shfl((var), (lane), (width))
#define __shfl_down_sync(mask, var, offset, width) \
__shfl_down((var), (offset), (width))
#define __shfl_up_sync(mask, var, offset, width) \
__shfl_up((var), (offset), (width))
#endif
)";

#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
Loading

0 comments on commit e491050

Please sign in to comment.