Skip to content

Commit

Permalink
add volatile override back
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Nov 14, 2019
1 parent 3486e2c commit 41d657c
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,18 @@ std::string CodeGenCUDA::Finish() {
<< "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
decl_stream << "__device__ half min(half a, half b)\n"
<< "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
// FIXME(tvm-team): add "volatile" to enable cross-thread reduction,
// which is needed by operations such as softmax.
// However, volatile overloading is not supported in NVRTC or in CUDA < 9.2.
// We need to find a solution that works in all of these environments.
decl_stream << "__device__ half operator<="
<< "(__half a, __half b)\n"
<< "(const volatile __half &a, const volatile __half &b)\n"
<< "{\n return __hlt(a, b);\n}\n";
decl_stream << "__device__ half operator+"
<< "(__half a, __half &b)\n"
<< "(const volatile __half &a, const volatile __half &b)\n"
<<"{\n return __hadd(a, b);\n}\n";
decl_stream << "__device__ half operator*"
<< "(__half a, __half b)\n"
<< "(const volatile __half &a, const volatile __half &b)\n"
<< "{\n return __hmul(a, b);\n}\n";
// otherwise simulate computation via float32
decl_stream << "#else\n";
Expand Down

0 comments on commit 41d657c

Please sign in to comment.