Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix cuda half math function is undefined: hpow, htanh #6225

Merged
merged 1 commit into from
Aug 10, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/target/source/literal/cuda_half_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,19 @@ __pack_half2(const half x, const half y) {
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}

static inline __device__ __host__ half hpow(half x, half y) {
float tmp_x = __half2float(x);
float tmp_y = __half2float(y);
float result = powf(tmp_x, tmp_y);
return __float2half(result);
}

static inline __device__ __host__ half htanh(half x) {
float tmp_x = __half2float(x);
float result = tanhf(tmp_x);
return __float2half(result);
}
)";

static constexpr const char* _cuda_warp_intrinsic_util = R"(
Expand Down