From c0d11fd11ccf44dbcad919781aae8dc7b8f6f575 Mon Sep 17 00:00:00 2001 From: cloud-mxd Date: Thu, 13 Aug 2020 23:36:59 +0800 Subject: [PATCH] fix cuda half math function is undefined: hpow, htanh (#6253) --- src/target/source/literal/cuda_half_t.h | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index baf4ba733dce..f8e92d508d88 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -293,6 +293,22 @@ __pack_half2(const half x, const half y) { unsigned v1 = *((unsigned short *)&y); return (v1 << 16) | v0; } + +// fix undefined fp16 math functions (hpow, htanh): evaluate in fp32 via powf/tanhf, then convert back to half +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) +static inline __device__ __host__ half hpow(half x, half y) { + float tmp_x = __half2float(x); + float tmp_y = __half2float(y); + float result = powf(tmp_x, tmp_y); + return __float2half(result); +} + +static inline __device__ __host__ half htanh(half x) { + float tmp_x = __half2float(x); + float result = tanhf(tmp_x); + return __float2half(result); +} +#endif )"; static constexpr const char* _cuda_warp_intrinsic_util = R"(