From f69127758dc68bcb34c1c82ea8082fd3723f3050 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A9=AC=E8=B4=A4=E8=BF=AA?= Date: Thu, 6 Aug 2020 22:11:04 +0800 Subject: [PATCH] fix cuda half math function is undefined: hpow, htanh --- src/target/source/literal/cuda_half_t.h | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index baf4ba733dce..422d2c0d0f72 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -293,6 +293,19 @@ __pack_half2(const half x, const half y) { unsigned v1 = *((unsigned short *)&y); return (v1 << 16) | v0; } + +static inline __device__ __host__ half hpow(half x, half y) { + float tmp_x = __half2float(x); + float tmp_y = __half2float(y); + float result = powf(tmp_x, tmp_y); + return __float2half(result); +} + +static inline __device__ __host__ half htanh(half x) { + float tmp_x = __half2float(x); + float result = tanhf(tmp_x); + return __float2half(result); +} )"; static constexpr const char* _cuda_warp_intrinsic_util = R"(