From 35e93ac0afb6ae31887038d8a73cd1d9348dc6f4 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 6 Nov 2019 16:36:43 -0800 Subject: [PATCH 1/5] check __CUDA_ARCH__ in cuda codegen for fp16 --- src/codegen/codegen_cuda.cc | 249 ++++++++++++++++++++++++++++++++++++ 1 file changed, 249 insertions(+) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 39a3ab7df0cc..63ab072f98ad 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -50,6 +50,7 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) { std::string CodeGenCUDA::Finish() { if (enable_fp16_) { + decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n"; decl_stream << "#include \n"; decl_stream << "__device__ half max" \ "(const half a, const half b)\n" @@ -65,10 +66,258 @@ std::string CodeGenCUDA::Finish() { decl_stream << "__device__ half operator*" \ "(const volatile __half &a, const volatile __half &b)\n" "{\n return __hmul(a, b);\n}\n"; + decl_stream << "#else\n"; + decl_stream << "typedef unsigned short uint16_t;\n"; + decl_stream << "typedef unsigned char uint8_t;\n"; + decl_stream << "typedef int int32_t;\n"; + decl_stream << "typedef unsigned long long uint64_t;\n"; + decl_stream << "typedef unsigned int uint32_t;\n"; + decl_stream << "#define TVM_FORCE_INLINE inline __attribute__((always_inline))\n"; + decl_stream << "#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__\n"; + decl_stream << "#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))\n"; + decl_stream << "#define TVM_HALF_OPERATOR(RTYPE, OP) \\\n" + " TVM_XINLINE RTYPE operator OP (half a, half b) { \\\n" + " return RTYPE(float(a) OP float(b)); \\\n" + " } \\\n" + " template \\\n" + " TVM_XINLINE RTYPE operator OP (half a, T b) { \\\n" + " return RTYPE(float(a) OP float(b)); \\\n" + " } \\\n" + " template \\\n" + " TVM_XINLINE RTYPE operator OP (T a, half b) { \\\n" + " return RTYPE(float(a) OP float(b)); \\\n" + " }\n" + "\n"; + decl_stream << "#define TVM_HALF_ASSIGNOP(AOP, OP) \\\n" + " template \\\n" + " TVM_XINLINE half operator AOP (const T& a) { \\\n" + " return *this = half(float(*this) OP float(a)); \\\n" + " } \\\n" + " template \\\n" + " TVM_XINLINE half operator AOP (const volatile T& a) volatile { \\\n" + " return *this = half(float(*this) OP float(a)); \\\n" + " }\n\n"; + decl_stream << "class TVM_ALIGNED(2) half {\n" + " public:\n" + " uint16_t half_;\n" + "\n" + " static TVM_XINLINE half Binary(uint16_t value) {\n" + " half res;\n" + " res.half_ = value;\n" + " return res;\n" + " }\n" + "\n" + " TVM_XINLINE half() {}\n" + "\n" + " TVM_XINLINE half(const float& value) { constructor(value); }\n" + " TVM_XINLINE explicit half(const double& value) { constructor(value); }\n" + " TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }\n" + " TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }\n" + " TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }\n" + " TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }\n" + " TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }\n" + " TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }\n" + "\n" + " TVM_XINLINE operator float() const { \\\n" + " return float(half2float(half_)); \\\n" + " } \\\n" + " TVM_XINLINE operator float() const volatile { \\\n" + " return float(half2float(half_)); \\\n" + " }\n\n" + "\n" + " TVM_HALF_ASSIGNOP(+=, +)\n" + " TVM_HALF_ASSIGNOP(-=, -)\n" + " TVM_HALF_ASSIGNOP(*=, *)\n" + " TVM_HALF_ASSIGNOP(/=, /)\n" + "\n" + " TVM_XINLINE half operator+() {\n" + " return *this;\n" + " }\n" + "\n" + " TVM_XINLINE half operator-() {\n" + " return half(-float(*this)); \n" + " }\n" + "\n" + " TVM_XINLINE half operator=(const half& a) {\n" + " half_ = a.half_;\n" + " return a;\n" + " }\n" + "\n" + " template\n" + " TVM_XINLINE half operator=(const T& a) {\n" + " return *this = half(a); \n" + " }\n" + "\n" + " TVM_XINLINE half operator=(const half& a) volatile {\n" + " half_ = a.half_;\n" + " return a;\n" + " }\n" + "\n" + " template\n" + " TVM_XINLINE half operator=(const T& a) volatile {\n" + " return *this = half(a); \n" + " }\n" + "\n" + " private:\n" + " union Bits {\n" + " float f;\n" + " int32_t si;\n" + " uint32_t ui;\n" + " };\n" + "\n" + " static int const fp16FractionBits = 10;\n" + " static int const fp32FractionBits = 23;\n" + " static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff\n" + " static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000\n" + " static int const shift = fp32FractionBits - fp16FractionBits; // == 13\n" + " static int const shiftSign = 16;\n" + " // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)\n" + " static int32_t const expAdjust = 127 - 15;\n" + "\n" + " static int32_t const infN = 0x7F800000; // flt32 infinity\n" + " static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift\n" + " static int32_t const minN = 0x38800000; // min flt16 normal as a flt32\n" + " static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16\n" + " static int32_t const signN = 0x80000000; // flt32 sign bit\n" + "\n" + " static int32_t const infC = infN >> shift;\n" + " static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32\n" + " static int32_t const maxC = maxN >> shift;\n" + " static int32_t const minC = minN >> shift;\n" + " static int32_t const signC = signN >> shiftSign; // flt16 sign bit\n" + "\n" + " static int32_t const mulN = 0x52000000; // (1 << 23) / minN\n" + " static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))\n" + "\n" + " static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted\n" + " static int32_t const norC = 0x00400; // min flt32 normal down shifted\n" + "\n" + " static int32_t const maxD = infC - maxC - 1;\n" + " static int32_t const minD = minC - subC - 1;\n" + "\n" + " TVM_XINLINE uint16_t float2half(const float& value) const {\n" + " Bits v;\n" + " v.f = value;\n" + " uint32_t sign = v.si & signN; // grab sign bit\n" + " v.si ^= sign; // clear sign bit from v\n" + " sign >>= shiftSign; // logical shift sign to fp16 position\n" + "\n" + " if (v.si <= maxZ) {\n" + " // Handle eventual zeros here to ensure vshift will not exceed 32 below.\n" + " v.ui = 0;\n" + " } else if (v.si < minN) {\n" + " // Handle denorms\n" + " uint32_t exp32 = v.ui >> fp32FractionBits;\n" + " int32_t exp16 = exp32 - expAdjust;\n" + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n" + " // Smaller (so negative) exp16 values should result in greater right shifts.\n" + " uint32_t vshift = 1 - exp16;\n" + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n" + " v.ui = significand >> vshift;\n" + " } else if (v.si <= maxN) {\n" + " // Handle norms\n" + " v.ui -= expAdjust << fp32FractionBits;\n" + " } else if (v.si <= infN) {\n" + " v.si = infN;\n" + " } else if (v.si < nanN) {\n" + " v.si = nanN;\n" + " }\n" + "\n" + " v.ui >>= shift;\n" + " return sign | (v.ui & 0x7fff);\n" + " }\n" + "\n" + " // Same as above routine, except for addition of volatile keyword\n" + " TVM_XINLINE uint16_t float2half(const volatile float& value) const volatile { \n" + " Bits v;\n" + " v.f = value;\n" + " uint32_t sign = v.si & signN; // grab sign bit\n" + " v.si ^= sign; // clear sign bit from v\n" + " sign >>= shiftSign; // logical shift sign to fp16 position\n" + "\n" + " if (v.si <= maxZ) {\n" + " // Handle eventual zeros here to ensure vshift will not exceed 32 below.\n" + " v.ui = 0;\n" + " } else if (v.si < minN) {\n" + " // Handle denorms\n" + " uint32_t exp32 = v.ui >> fp32FractionBits;\n" + " int32_t exp16 = exp32 - expAdjust;\n" + " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n" + " // Smaller (so negative) exp16 values should result in greater right shifts.\n" + " uint32_t vshift = 1 - exp16;\n" + " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n" + " v.ui = significand >> vshift;\n" + " } else if (v.si <= maxN) {\n" + " // Handle norms\n" + " v.ui -= expAdjust << fp32FractionBits;\n" + " } else if (v.si <= infN) {\n" + " v.si = infN;\n" + " } else if (v.si < nanN) {\n" + " v.si = nanN;\n" + " }\n" + "\n" + " v.ui >>= shift;\n" + " return sign | (v.ui & 0x7fff);\n" + " }\n" + "\n" + " TVM_XINLINE float half2float(const uint16_t& value) const {\n" + " Bits v;\n" + " v.ui = value;\n" + " int32_t sign = v.si & signC;\n" + " v.si ^= sign;\n" + " sign <<= shiftSign;\n" + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n" + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n" + " Bits s;\n" + " s.si = mulC;\n" + " s.f *= v.si;\n" + " int32_t mask = -(norC > v.si);\n" + " v.si <<= shift;\n" + " v.si ^= (s.si ^ v.si) & mask;\n" + " v.si |= sign;\n" + " return v.f;\n" + " }\n" + "\n" + " TVM_XINLINE float half2float(const volatile uint16_t& value) const volatile { \n" + " Bits v;\n" + " v.ui = value;\n" + " int32_t sign = v.si & signC;\n" + " v.si ^= sign;\n" + " sign <<= shiftSign;\n" + " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n" + " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n" + " Bits s;\n" + " s.si = mulC;\n" + " s.f *= v.si;\n" + " int32_t mask = -(norC > v.si);\n" + " v.si <<= shift;\n" + " v.si ^= (s.si ^ v.si) & mask;\n" + " v.si |= sign;\n" + " return v.f;\n" + " }\n" + "\n" + " template\n" + " TVM_XINLINE void constructor(const T& value) {\n" + " half_ = float2half(float(value)); \n" + " }\n" + "};\n" + "\n" + + "TVM_HALF_OPERATOR(half, +)\n" + "TVM_HALF_OPERATOR(half, -)\n" + "TVM_HALF_OPERATOR(half, *)\n" + "TVM_HALF_OPERATOR(half, /)\n" + "TVM_HALF_OPERATOR(bool, >)\n" + "TVM_HALF_OPERATOR(bool, <)\n" + "TVM_HALF_OPERATOR(bool, >=)\n" + "TVM_HALF_OPERATOR(bool, <=)\n"; + decl_stream << "#endif\n\n"; } if (enable_int8_) { + decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n"; decl_stream << "#include \n"; + decl_stream << "#endif\n"; } if (need_math_constants_h_) { From 50b86844fd61990680b2df48ca77f2b950e73415 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 6 Nov 2019 17:00:14 -0800 Subject: [PATCH 2/5] fix lint --- src/codegen/codegen_cuda.cc | 50 ++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 39701defe0d8..c3eebabf2e2e 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -168,21 +168,29 @@ std::string CodeGenCUDA::Finish() { "\n" " static int const fp16FractionBits = 10;\n" " static int const fp32FractionBits = 23;\n" - " static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff\n" - " static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000\n" - " static int const shift = fp32FractionBits - fp16FractionBits; // == 13\n" + " static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);" + " // == 0x7fffff\n" + " static int32_t const fp32HiddenBit = 1 << fp32FractionBits;" + " // == 0x800000\n" + " static int const shift = fp32FractionBits - fp16FractionBits;" + " // == 13\n" " static int const shiftSign = 16;\n" - " // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)\n" - " static int32_t const expAdjust = 127 - 15;\n" + " static int32_t const expAdjust = 127 - 15;" + " // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)\n" "\n" - " static int32_t const infN = 0x7F800000; // flt32 infinity\n" - " static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift\n" - " static int32_t const minN = 0x38800000; // min flt16 normal as a flt32\n" - " static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16\n" + " static int32_t const infN = 0x7F800000;" + " // flt32 infinity\n" + " static int32_t const maxN = 0x477FFFFF;" + " // max flt32 that's a flt16 normal after >> by shift\n" + " static int32_t const minN = 0x38800000;" + " // min flt16 normal as a flt32\n" + " static int32_t const maxZ = 0x33000000;" + " // max fp32 number that's still rounded to zero in fp16\n" " static int32_t const signN = 0x80000000; // flt32 sign bit\n" "\n" " static int32_t const infC = infN >> shift;\n" - " static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32\n" + " static int32_t const nanN = (infC + 1) << shift;" + " // minimum flt16 nan as a flt32\n" " static int32_t const maxC = maxN >> shift;\n" " static int32_t const minC = minN >> shift;\n" " static int32_t const signC = signN >> shiftSign; // flt16 sign bit\n" @@ -204,14 +212,16 @@ std::string CodeGenCUDA::Finish() { " sign >>= shiftSign; // logical shift sign to fp16 position\n" "\n" " if (v.si <= maxZ) {\n" - " // Handle eventual zeros here to ensure vshift will not exceed 32 below.\n" + " // Handle eventual zeros here to ensure\n" + " // vshift will not exceed 32 below.\n" " v.ui = 0;\n" " } else if (v.si < minN) {\n" " // Handle denorms\n" " uint32_t exp32 = v.ui >> fp32FractionBits;\n" " int32_t exp16 = exp32 - expAdjust;\n" - " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n" - " // Smaller (so negative) exp16 values should result in greater right shifts.\n" + // If exp16 == 0 (just into the denorm range), + // then significant should be shifted right 1. + // Smaller (so negative) exp16 values should result in greater right shifts. " uint32_t vshift = 1 - exp16;\n" " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n" " v.ui = significand >> vshift;\n" @@ -229,7 +239,8 @@ std::string CodeGenCUDA::Finish() { " }\n" "\n" " // Same as above routine, except for addition of volatile keyword\n" - " TVM_XINLINE uint16_t float2half(const volatile float& value) const volatile { \n" + " TVM_XINLINE uint16_t float2half(\n" + " const volatile float& value) const volatile {\n" " Bits v;\n" " v.f = value;\n" " uint32_t sign = v.si & signN; // grab sign bit\n" @@ -237,14 +248,16 @@ std::string CodeGenCUDA::Finish() { " sign >>= shiftSign; // logical shift sign to fp16 position\n" "\n" " if (v.si <= maxZ) {\n" - " // Handle eventual zeros here to ensure vshift will not exceed 32 below.\n" + " // Handle eventual zeros here to ensure\n" + " // vshift will not exceed 32 below.\n" " v.ui = 0;\n" " } else if (v.si < minN) {\n" " // Handle denorms\n" " uint32_t exp32 = v.ui >> fp32FractionBits;\n" " int32_t exp16 = exp32 - expAdjust;\n" - " // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.\n" - " // Smaller (so negative) exp16 values should result in greater right shifts.\n" + // If exp16 == 0 (just into the denorm range), + // then significant should be shifted right 1. + // Smaller (so negative) exp16 values should result in greater right shifts. " uint32_t vshift = 1 - exp16;\n" " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n" " v.ui = significand >> vshift;\n" @@ -279,7 +292,8 @@ std::string CodeGenCUDA::Finish() { " return v.f;\n" " }\n" "\n" - " TVM_XINLINE float half2float(const volatile uint16_t& value) const volatile { \n" + " TVM_XINLINE float half2float(\n" + " const volatile uint16_t& value) const volatile {\n" " Bits v;\n" " v.ui = value;\n" " int32_t sign = v.si & signC;\n" From 406c4f5000d87c625372095a0f43d09e156750ad Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 7 Nov 2019 19:15:17 -0800 Subject: [PATCH 3/5] use c++ literal --- src/codegen/codegen_cuda.cc | 262 +--------------------------- src/codegen/literal/cuda_half_t.txt | 251 ++++++++++++++++++++++++++ 2 files changed, 255 insertions(+), 258 deletions(-) create mode 100644 src/codegen/literal/cuda_half_t.txt diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index c3eebabf2e2e..ce15922a1400 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -50,6 +50,9 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) { std::string CodeGenCUDA::Finish() { if (enable_fp16_) { + static constexpr const char* _cuda_half_t_def = + #include "literal/cuda_half_t.txt" + ; decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n"; decl_stream << "#include \n"; decl_stream << "__device__ half max" @@ -68,264 +71,7 @@ std::string CodeGenCUDA::Finish() { << "{\n return __hmul(a, b);\n}\n"; // otherwise simulate computation via float32 decl_stream << "#else\n"; - decl_stream << "typedef unsigned short uint16_t;\n"; - decl_stream << "typedef unsigned char uint8_t;\n"; - decl_stream << "typedef int int32_t;\n"; - decl_stream << "typedef unsigned long long uint64_t;\n"; - decl_stream << "typedef unsigned int uint32_t;\n"; - decl_stream << "#define TVM_FORCE_INLINE inline __attribute__((always_inline))\n"; - decl_stream << "#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__\n"; - decl_stream << "#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))\n"; - decl_stream << "#define TVM_HALF_OPERATOR(RTYPE, OP) \\\n" - " TVM_XINLINE RTYPE operator OP (half a, half b) { \\\n" - " return RTYPE(float(a) OP float(b)); \\\n" - " } \\\n" - " template \\\n" - " TVM_XINLINE RTYPE operator OP (half a, T b) { \\\n" - " return RTYPE(float(a) OP float(b)); \\\n" - " } \\\n" - " template \\\n" - " TVM_XINLINE RTYPE operator OP (T a, half b) { \\\n" - " return RTYPE(float(a) OP float(b)); \\\n" - " }\n" - "\n"; - decl_stream << "#define TVM_HALF_ASSIGNOP(AOP, OP) \\\n" - " template \\\n" - " TVM_XINLINE half operator AOP (const T& a) { \\\n" - " return *this = half(float(*this) OP float(a)); \\\n" - " } \\\n" - " template \\\n" - " TVM_XINLINE half operator AOP (const volatile T& a) volatile { \\\n" - " return *this = half(float(*this) OP float(a)); \\\n" - " }\n\n"; - decl_stream << "class TVM_ALIGNED(2) half {\n" - " public:\n" - " uint16_t half_;\n" - "\n" - " static TVM_XINLINE half Binary(uint16_t value) {\n" - " half res;\n" - " res.half_ = value;\n" - " return res;\n" - " }\n" - "\n" - " TVM_XINLINE half() {}\n" - "\n" - " TVM_XINLINE half(const float& value) { constructor(value); }\n" - " TVM_XINLINE explicit half(const double& value) { constructor(value); }\n" - " TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }\n" - " TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }\n" - " TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }\n" - " TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }\n" - " TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }\n" - " TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }\n" - "\n" - " TVM_XINLINE operator float() const { \\\n" - " return float(half2float(half_)); \\\n" - " } \\\n" - " TVM_XINLINE operator float() const volatile { \\\n" - " return float(half2float(half_)); \\\n" - " }\n\n" - "\n" - " TVM_HALF_ASSIGNOP(+=, +)\n" - " TVM_HALF_ASSIGNOP(-=, -)\n" - " TVM_HALF_ASSIGNOP(*=, *)\n" - " TVM_HALF_ASSIGNOP(/=, /)\n" - "\n" - " TVM_XINLINE half operator+() {\n" - " return *this;\n" - " }\n" - "\n" - " TVM_XINLINE half operator-() {\n" - " return half(-float(*this)); \n" - " }\n" - "\n" - " TVM_XINLINE half operator=(const half& a) {\n" - " half_ = a.half_;\n" - " return a;\n" - " }\n" - "\n" - " template\n" - " TVM_XINLINE half operator=(const T& a) {\n" - " return *this = half(a); \n" - " }\n" - "\n" - " TVM_XINLINE half operator=(const half& a) volatile {\n" - " half_ = a.half_;\n" - " return a;\n" - " }\n" - "\n" - " template\n" - " TVM_XINLINE half operator=(const T& a) volatile {\n" - " return *this = half(a); \n" - " }\n" - "\n" - " private:\n" - " union Bits {\n" - " float f;\n" - " int32_t si;\n" - " uint32_t ui;\n" - " };\n" - "\n" - " static int const fp16FractionBits = 10;\n" - " static int const fp32FractionBits = 23;\n" - " static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);" - " // == 0x7fffff\n" - " static int32_t const fp32HiddenBit = 1 << fp32FractionBits;" - " // == 0x800000\n" - " static int const shift = fp32FractionBits - fp16FractionBits;" - " // == 13\n" - " static int const shiftSign = 16;\n" - " static int32_t const expAdjust = 127 - 15;" - " // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)\n" - "\n" - " static int32_t const infN = 0x7F800000;" - " // flt32 infinity\n" - " static int32_t const maxN = 0x477FFFFF;" - " // max flt32 that's a flt16 normal after >> by shift\n" - " static int32_t const minN = 0x38800000;" - " // min flt16 normal as a flt32\n" - " static int32_t const maxZ = 0x33000000;" - " // max fp32 number that's still rounded to zero in fp16\n" - " static int32_t const signN = 0x80000000; // flt32 sign bit\n" - "\n" - " static int32_t const infC = infN >> shift;\n" - " static int32_t const nanN = (infC + 1) << shift;" - " // minimum flt16 nan as a flt32\n" - " static int32_t const maxC = maxN >> shift;\n" - " static int32_t const minC = minN >> shift;\n" - " static int32_t const signC = signN >> shiftSign; // flt16 sign bit\n" - "\n" - " static int32_t const mulN = 0x52000000; // (1 << 23) / minN\n" - " static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift))\n" - "\n" - " static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted\n" - " static int32_t const norC = 0x00400; // min flt32 normal down shifted\n" - "\n" - " static int32_t const maxD = infC - maxC - 1;\n" - " static int32_t const minD = minC - subC - 1;\n" - "\n" - " TVM_XINLINE uint16_t float2half(const float& value) const {\n" - " Bits v;\n" - " v.f = value;\n" - " uint32_t sign = v.si & signN; // grab sign bit\n" - " v.si ^= sign; // clear sign bit from v\n" - " sign >>= shiftSign; // logical shift sign to fp16 position\n" - "\n" - " if (v.si <= maxZ) {\n" - " // Handle eventual zeros here to ensure\n" - " // vshift will not exceed 32 below.\n" - " v.ui = 0;\n" - " } else if (v.si < minN) {\n" - " // Handle denorms\n" - " uint32_t exp32 = v.ui >> fp32FractionBits;\n" - " int32_t exp16 = exp32 - expAdjust;\n" - // If exp16 == 0 (just into the denorm range), - // then significant should be shifted right 1. - // Smaller (so negative) exp16 values should result in greater right shifts. - " uint32_t vshift = 1 - exp16;\n" - " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n" - " v.ui = significand >> vshift;\n" - " } else if (v.si <= maxN) {\n" - " // Handle norms\n" - " v.ui -= expAdjust << fp32FractionBits;\n" - " } else if (v.si <= infN) {\n" - " v.si = infN;\n" - " } else if (v.si < nanN) {\n" - " v.si = nanN;\n" - " }\n" - "\n" - " v.ui >>= shift;\n" - " return sign | (v.ui & 0x7fff);\n" - " }\n" - "\n" - " // Same as above routine, except for addition of volatile keyword\n" - " TVM_XINLINE uint16_t float2half(\n" - " const volatile float& value) const volatile {\n" - " Bits v;\n" - " v.f = value;\n" - " uint32_t sign = v.si & signN; // grab sign bit\n" - " v.si ^= sign; // clear sign bit from v\n" - " sign >>= shiftSign; // logical shift sign to fp16 position\n" - "\n" - " if (v.si <= maxZ) {\n" - " // Handle eventual zeros here to ensure\n" - " // vshift will not exceed 32 below.\n" - " v.ui = 0;\n" - " } else if (v.si < minN) {\n" - " // Handle denorms\n" - " uint32_t exp32 = v.ui >> fp32FractionBits;\n" - " int32_t exp16 = exp32 - expAdjust;\n" - // If exp16 == 0 (just into the denorm range), - // then significant should be shifted right 1. - // Smaller (so negative) exp16 values should result in greater right shifts. - " uint32_t vshift = 1 - exp16;\n" - " uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);\n" - " v.ui = significand >> vshift;\n" - " } else if (v.si <= maxN) {\n" - " // Handle norms\n" - " v.ui -= expAdjust << fp32FractionBits;\n" - " } else if (v.si <= infN) {\n" - " v.si = infN;\n" - " } else if (v.si < nanN) {\n" - " v.si = nanN;\n" - " }\n" - "\n" - " v.ui >>= shift;\n" - " return sign | (v.ui & 0x7fff);\n" - " }\n" - "\n" - " TVM_XINLINE float half2float(const uint16_t& value) const {\n" - " Bits v;\n" - " v.ui = value;\n" - " int32_t sign = v.si & signC;\n" - " v.si ^= sign;\n" - " sign <<= shiftSign;\n" - " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n" - " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n" - " Bits s;\n" - " s.si = mulC;\n" - " s.f *= v.si;\n" - " int32_t mask = -(norC > v.si);\n" - " v.si <<= shift;\n" - " v.si ^= (s.si ^ v.si) & mask;\n" - " v.si |= sign;\n" - " return v.f;\n" - " }\n" - "\n" - " TVM_XINLINE float half2float(\n" - " const volatile uint16_t& value) const volatile {\n" - " Bits v;\n" - " v.ui = value;\n" - " int32_t sign = v.si & signC;\n" - " v.si ^= sign;\n" - " sign <<= shiftSign;\n" - " v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);\n" - " v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);\n" - " Bits s;\n" - " s.si = mulC;\n" - " s.f *= v.si;\n" - " int32_t mask = -(norC > v.si);\n" - " v.si <<= shift;\n" - " v.si ^= (s.si ^ v.si) & mask;\n" - " v.si |= sign;\n" - " return v.f;\n" - " }\n" - "\n" - " template\n" - " TVM_XINLINE void constructor(const T& value) {\n" - " half_ = float2half(float(value)); \n" - " }\n" - "};\n" - "\n" - - "TVM_HALF_OPERATOR(half, +)\n" - "TVM_HALF_OPERATOR(half, -)\n" - "TVM_HALF_OPERATOR(half, *)\n" - "TVM_HALF_OPERATOR(half, /)\n" - "TVM_HALF_OPERATOR(bool, >)\n" - "TVM_HALF_OPERATOR(bool, <)\n" - "TVM_HALF_OPERATOR(bool, >=)\n" - "TVM_HALF_OPERATOR(bool, <=)\n"; + decl_stream << _cuda_half_t_def; decl_stream << "#endif\n\n"; } diff --git a/src/codegen/literal/cuda_half_t.txt b/src/codegen/literal/cuda_half_t.txt new file mode 100644 index 000000000000..72b8c1fe9149 --- /dev/null +++ b/src/codegen/literal/cuda_half_t.txt @@ -0,0 +1,251 @@ +R"( +typedef unsigned short uint16_t; +typedef unsigned char uint8_t; +typedef int int32_t; +typedef unsigned long long uint64_t; +typedef unsigned int uint32_t; + +#define TVM_FORCE_INLINE inline __attribute__((always_inline)) +#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__ +#define TVM_ALIGNED(x) __attribute__ ((aligned(x))) +#define TVM_HALF_OPERATOR(RTYPE, OP) \ + TVM_XINLINE RTYPE operator OP (half a, half b) { \ + return RTYPE(float(a) OP float(b)); \ + } \ + template \ + TVM_XINLINE RTYPE operator OP (half a, T b) { \ + return RTYPE(float(a) OP float(b)); \ + } \ + template \ + TVM_XINLINE RTYPE operator OP (T a, half b) { \ + return RTYPE(float(a) OP float(b)); \ + } + +#define TVM_HALF_ASSIGNOP(AOP, OP) \ + template \ + TVM_XINLINE half operator AOP (const T& a) { \ + return *this = half(float(*this) OP float(a)); \ + } \ + template \ + TVM_XINLINE half operator AOP (const volatile T& a) volatile { \ + return *this = half(float(*this) OP float(a)); \ + } + +class TVM_ALIGNED(2) half { + public: + uint16_t half_; + + static TVM_XINLINE half Binary(uint16_t value) { + half res; + res.half_ = value; + return res; + } + + TVM_XINLINE half() {} + + TVM_XINLINE half(const float& value) { constructor(value); } + TVM_XINLINE explicit half(const double& value) { constructor(value); } + TVM_XINLINE explicit half(const int8_t& value) { constructor(value); } + TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); } + TVM_XINLINE explicit half(const int32_t& value) { constructor(value); } + TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); } + TVM_XINLINE explicit half(const int64_t& value) { constructor(value); } + TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); } + + TVM_XINLINE operator float() const { \ + return float(half2float(half_)); \ + } \ + TVM_XINLINE operator float() const volatile { \ + return float(half2float(half_)); \ + } + + + TVM_HALF_ASSIGNOP(+=, +) + TVM_HALF_ASSIGNOP(-=, -) + TVM_HALF_ASSIGNOP(*=, *) + TVM_HALF_ASSIGNOP(/=, /) + + TVM_XINLINE half operator+() { + return *this; + } + + TVM_XINLINE half operator-() { + return half(-float(*this)); + } + + TVM_XINLINE half operator=(const half& a) { + half_ = a.half_; + return a; + } + + template + TVM_XINLINE half operator=(const T& a) { + return *this = half(a); + } + + TVM_XINLINE half operator=(const half& a) volatile { + half_ = a.half_; + return a; + } + + template + TVM_XINLINE half operator=(const T& a) volatile { + return *this = half(a); + } + + private: + union Bits { + float f; + int32_t si; + uint32_t ui; + }; + + static int const fp16FractionBits = 10; + static int const fp32FractionBits = 23; + static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits); // == 0x7fffff + static int32_t const fp32HiddenBit = 1 << fp32FractionBits; // == 0x800000 + static int const shift = fp32FractionBits - fp16FractionBits; // == 13 + static int const shiftSign = 16; + static int32_t const expAdjust = 127 - 15; // exp32-127 = exp16-15, so exp16 = exp32 - (127-15) + + static int32_t const infN = 0x7F800000; // flt32 infinity + static int32_t const maxN = 0x477FFFFF; // max flt32 that's a flt16 normal after >> by shift + static int32_t const minN = 0x38800000; // min flt16 normal as a flt32 + static int32_t const maxZ = 0x33000000; // max fp32 number that's still rounded to zero in fp16 + static int32_t const signN = 0x80000000; // flt32 sign bit + + static int32_t const infC = infN >> shift; + static int32_t const nanN = (infC + 1) << shift; // minimum flt16 nan as a flt32 + static int32_t const maxC = maxN >> shift; + static int32_t const minC = minN >> shift; + static int32_t const signC = signN >> shiftSign; // flt16 sign bit + + static int32_t const mulN = 0x52000000; // (1 << 23) / minN + static int32_t const mulC = 0x33800000; // minN / (1 << (23 - shift)) + + static int32_t const subC = 0x003FF; // max flt32 subnormal down shifted + static int32_t const norC = 0x00400; // min flt32 normal down shifted + + static int32_t const maxD = infC - maxC - 1; + static int32_t const minD = minC - subC - 1; + + TVM_XINLINE uint16_t float2half(const float& value) const { + Bits v; + v.f = value; + uint32_t sign = v.si & signN; // grab sign bit + v.si ^= sign; // clear sign bit from v + sign >>= shiftSign; // logical shift sign to fp16 position + + if (v.si <= maxZ) { + // Handle eventual zeros here to ensure + // vshift will not exceed 32 below. + v.ui = 0; + } else if (v.si < minN) { + // Handle denorms + uint32_t exp32 = v.ui >> fp32FractionBits; + int32_t exp16 = exp32 - expAdjust; + // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1. + // Smaller (so negative) exp16 values should result in greater right shifts. + uint32_t vshift = 1 - exp16; + uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask); + v.ui = significand >> vshift; + } else if (v.si <= maxN) { + // Handle norms + v.ui -= expAdjust << fp32FractionBits; + } else if (v.si <= infN) { + v.si = infN; + } else if (v.si < nanN) { + v.si = nanN; + } + + v.ui >>= shift; + return sign | (v.ui & 0x7fff); + } + + // Same as above routine, except for addition of volatile keyword + TVM_XINLINE uint16_t float2half( + const volatile float& value) const volatile { + Bits v; + v.f = value; + uint32_t sign = v.si & signN; // grab sign bit + v.si ^= sign; // clear sign bit from v + sign >>= shiftSign; // logical shift sign to fp16 position + + if (v.si <= maxZ) { + // Handle eventual zeros here to ensure + // vshift will not exceed 32 below. + v.ui = 0; + } else if (v.si < minN) { + // Handle denorms + uint32_t exp32 = v.ui >> fp32FractionBits; + int32_t exp16 = exp32 - expAdjust; + // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1. + // Smaller (so negative) exp16 values should result in greater right shifts. + uint32_t vshift = 1 - exp16; + uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask); + v.ui = significand >> vshift; + } else if (v.si <= maxN) { + // Handle norms + v.ui -= expAdjust << fp32FractionBits; + } else if (v.si <= infN) { + v.si = infN; + } else if (v.si < nanN) { + v.si = nanN; + } + + v.ui >>= shift; + return sign | (v.ui & 0x7fff); + } + + TVM_XINLINE float half2float(const uint16_t& value) const { + Bits v; + v.ui = value; + int32_t sign = v.si & signC; + v.si ^= sign; + sign <<= shiftSign; + v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); + v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); + Bits s; + s.si = mulC; + s.f *= v.si; + int32_t mask = -(norC > v.si); + v.si <<= shift; + v.si ^= (s.si ^ v.si) & mask; + v.si |= sign; + return v.f; + } + + TVM_XINLINE float half2float( + const volatile uint16_t& value) const volatile { + Bits v; + v.ui = value; + int32_t sign = v.si & signC; + v.si ^= sign; + sign <<= shiftSign; + v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC); + v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC); + Bits s; + s.si = mulC; + s.f *= v.si; + int32_t mask = -(norC > v.si); + v.si <<= shift; + v.si ^= (s.si ^ v.si) & mask; + v.si |= sign; + return v.f; + } + + template + TVM_XINLINE void constructor(const T& value) { + half_ = float2half(float(value)); + } +}; + +TVM_HALF_OPERATOR(half, +) +TVM_HALF_OPERATOR(half, -) +TVM_HALF_OPERATOR(half, *) +TVM_HALF_OPERATOR(half, /) +TVM_HALF_OPERATOR(bool, >) +TVM_HALF_OPERATOR(bool, <) +TVM_HALF_OPERATOR(bool, >=) +TVM_HALF_OPERATOR(bool, <=) +)" \ No newline at end of file From b105241697c4e2ab4c1ce1de66f9c0a293579361 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 7 Nov 2019 19:21:04 -0800 Subject: [PATCH 4/5] fix lint --- src/codegen/codegen_cuda.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index ce15922a1400..c42c452e3aeb 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -52,7 +52,7 @@ std::string CodeGenCUDA::Finish() { if (enable_fp16_) { static constexpr const char* _cuda_half_t_def = #include "literal/cuda_half_t.txt" - ; + ; // NOLINT(*) decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n"; decl_stream << "#include \n"; decl_stream << "__device__ half max" From 8625032b35d42c85416790b5095a0183657f7cda Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 8 Nov 2019 09:20:26 -0800 Subject: [PATCH 5/5] str literal move to .h file --- src/codegen/codegen_cuda.cc | 4 +-- .../{cuda_half_t.txt => cuda_half_t.h} | 33 +++++++++++++++++-- 2 files changed, 32 insertions(+), 5 deletions(-) rename src/codegen/literal/{cuda_half_t.txt => cuda_half_t.h} (88%) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index c42c452e3aeb..22e8d842e424 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -27,6 +27,7 @@ #include #include #include +#include "literal/cuda_half_t.h" #include "codegen_cuda.h" namespace tvm { @@ -50,9 +51,6 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) { std::string CodeGenCUDA::Finish() { if (enable_fp16_) { - static constexpr const char* _cuda_half_t_def = - #include "literal/cuda_half_t.txt" - ; // NOLINT(*) decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n"; decl_stream << "#include \n"; decl_stream << "__device__ half max" diff --git a/src/codegen/literal/cuda_half_t.txt b/src/codegen/literal/cuda_half_t.h similarity index 88% rename from src/codegen/literal/cuda_half_t.txt rename to src/codegen/literal/cuda_half_t.h index 72b8c1fe9149..23075b0b6e76 100644 --- a/src/codegen/literal/cuda_half_t.txt +++ b/src/codegen/literal/cuda_half_t.h @@ -1,4 +1,31 @@ -R"( +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file cuda_half_t.h + * \brief half_t (fp16) definition for cuda codegen. + */ +#ifndef TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_ +#define TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_ + +static constexpr const char* _cuda_half_t_def = R"( typedef unsigned short uint16_t; typedef unsigned char uint8_t; typedef int int32_t; @@ -248,4 +275,6 @@ TVM_HALF_OPERATOR(bool, >) TVM_HALF_OPERATOR(bool, <) TVM_HALF_OPERATOR(bool, >=) TVM_HALF_OPERATOR(bool, <=) -)" \ No newline at end of file +)"; + +#endif // TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_