Skip to content

Commit

Permalink
Fix compilation error on HIP due to KERNEL_FLOAT_FAST_F32_MAP
Browse files Browse the repository at this point in the history
  • Loading branch information
stijnh committed Nov 26, 2024
1 parent 5c859b9 commit 76c695a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
13 changes: 7 additions & 6 deletions include/kernel_float/unops.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,13 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan)

KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2)

KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt)

// This PTX is only supported on CUDA
#if KERNEL_FLOAT_IS_CUDA && KERNEL_FLOAT_IS_DEVICE
#if KERNEL_FLOAT_IS_DEVICE
#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \
namespace detail { \
template<> \
Expand All @@ -245,6 +242,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, sin, __sinf(input))
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, cos, __cosf(input))
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input))

// This PTX is only supported on CUDA
#if KERNEL_FLOAT_IS_CUDA
#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \
namespace detail { \
template<> \
Expand All @@ -261,7 +260,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32", "f")
#endif

#define KERNEL_FLOAT_FAST_F32_MAP(F) \
F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt)
Expand All @@ -270,7 +270,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f")
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f")
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f")
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f")

#else
#define KERNEL_FLOAT_FAST_F32_MAP(F)
#endif

} // namespace kernel_float
Expand Down
17 changes: 9 additions & 8 deletions single_include/kernel_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

//================================================================================
// this file has been auto-generated, do not modify its contents!
// date: 2024-11-20 10:36:45.284577
// git hash: 76501fda40df9e396998d11840bc8f10b11ea47b
// date: 2024-11-26 13:52:06.286983
// git hash: c4c6ac09808d14b5407afb06ecdecd235cd50ed3
//================================================================================

#ifndef KERNEL_FLOAT_MACROS_H
Expand Down Expand Up @@ -1397,16 +1397,13 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan)

KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2)

KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp)
KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt)

// This PTX is only supported on CUDA
#if KERNEL_FLOAT_IS_CUDA && KERNEL_FLOAT_IS_DEVICE
#if KERNEL_FLOAT_IS_DEVICE
#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \
namespace detail { \
template<> \
Expand All @@ -1430,6 +1427,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, sin, __sinf(input))
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, cos, __cosf(input))
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input))

// This PTX is only supported on CUDA
#if KERNEL_FLOAT_IS_CUDA
#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \
namespace detail { \
template<> \
Expand All @@ -1446,7 +1445,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32", "f")
#endif

#define KERNEL_FLOAT_FAST_F32_MAP(F) \
F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt)
Expand All @@ -1455,7 +1455,8 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f")
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f")
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f")
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f")

#else
#define KERNEL_FLOAT_FAST_F32_MAP(F)
#endif

} // namespace kernel_float
Expand Down

0 comments on commit 76c695a

Please sign in to comment.