diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 22e8d842e424..2a412823d6ef 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -58,15 +58,19 @@ std::string CodeGenCUDA::Finish() { << "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n"; decl_stream << "__device__ half min(half a, half b)\n" << "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n"; - decl_stream << "__device__ half operator<=" - << "(__half a, __half b)\n" - << "{\n return __hlt(a, b);\n}\n"; - decl_stream << "__device__ half operator+" - << "(__half a, __half &b)\n" - <<"{\n return __hadd(a, b);\n}\n"; - decl_stream << "__device__ half operator*" - << "(__half a, __half b)\n" - << "{\n return __hmul(a, b);\n}\n"; + // FIXME(tvm-team): "volatile" is used to enable cross-thread reduction, + // which is needed by operations such as softmax. + // However, volatile overloading is not supported in NVRTC and CUDA < 9.2. + // We need to figure out a solution which can satisfy both scenarios. 
+ // decl_stream << "__device__ half operator<=" + // << "(const volatile __half &a, const volatile __half &b)\n" + // << "{\n return __hlt(a, b);\n}\n"; + // decl_stream << "__device__ half operator+" + // << "(const volatile __half &a, const volatile __half &b)\n" + // <<"{\n return __hadd(a, b);\n}\n"; + // decl_stream << "__device__ half operator*" + // << "(const volatile __half &a, const volatile __half &b)\n" + // << "{\n return __hmul(a, b);\n}\n"; // otherwise simulate computation via float32 decl_stream << "#else\n"; decl_stream << _cuda_half_t_def; diff --git a/src/codegen/literal/cuda_half_t.h b/src/codegen/literal/cuda_half_t.h index 23075b0b6e76..0889032aadd4 100644 --- a/src/codegen/literal/cuda_half_t.h +++ b/src/codegen/literal/cuda_half_t.h @@ -28,6 +28,7 @@ static constexpr const char* _cuda_half_t_def = R"( typedef unsigned short uint16_t; typedef unsigned char uint8_t; +typedef signed char int8_t; typedef int int32_t; typedef unsigned long long uint64_t; typedef unsigned int uint32_t; @@ -76,7 +77,7 @@ class TVM_ALIGNED(2) half { TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); } TVM_XINLINE explicit half(const int32_t& value) { constructor(value); } TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); } - TVM_XINLINE explicit half(const int64_t& value) { constructor(value); } + TVM_XINLINE explicit half(const long long& value) { constructor(value); } TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); } TVM_XINLINE operator float() const { \