Skip to content

Commit

Permalink
fix cuda half math function is undefined: hpow, htanh (apache#6225)
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-mxd authored and trevor-m committed Sep 3, 2020
1 parent bfc1d97 commit 9ef4b07
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/target/source/literal/cuda_half_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,19 @@ __pack_half2(const half x, const half y) {
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
static inline __device__ __host__ half hpow(half x, half y) {
float tmp_x = __half2float(x);
float tmp_y = __half2float(y);
float result = powf(tmp_x, tmp_y);
return __float2half(result);
}
static inline __device__ __host__ half htanh(half x) {
float tmp_x = __half2float(x);
float result = tanhf(tmp_x);
return __float2half(result);
}
)";

static constexpr const char* _cuda_warp_intrinsic_util = R"(
Expand Down

0 comments on commit 9ef4b07

Please sign in to comment.