diff --git a/include/kernel_float.h b/include/kernel_float.h index ee098fc..be7e21c 100644 --- a/include/kernel_float.h +++ b/include/kernel_float.h @@ -1,6 +1,7 @@ #ifndef KERNEL_FLOAT_H #define KERNEL_FLOAT_H +#include "kernel_float/approx.h" #include "kernel_float/base.h" #include "kernel_float/bf16.h" #include "kernel_float/binops.h" diff --git a/include/kernel_float/apply.h b/include/kernel_float/apply.h index 3a7e02c..1421132 100644 --- a/include/kernel_float/apply.h +++ b/include/kernel_float/apply.h @@ -130,6 +130,9 @@ struct apply_impl { template struct apply_fastmath_impl: apply_impl {}; + +template +struct apply_approx_impl: apply_fastmath_impl {}; } // namespace detail struct accurate_policy { @@ -142,6 +145,14 @@ struct fast_policy { using type = detail::apply_fastmath_impl; }; +template +struct approximate_policy { + template + using type = detail::apply_approx_impl; +}; + +using default_approximate_policy = approximate_policy<>; + #ifdef KERNEL_FLOAT_POLICY using default_policy = KERNEL_FLOAT_POLICY; #else diff --git a/include/kernel_float/approx.h b/include/kernel_float/approx.h new file mode 100644 index 0000000..6725faf --- /dev/null +++ b/include/kernel_float/approx.h @@ -0,0 +1,391 @@ +#pragma once + +#include "apply.h" +#include "bf16.h" +#include "fp16.h" +#include "macros.h" + +namespace kernel_float { + +namespace approx { + +static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int"); +using uint32_t = unsigned int; + +template +KERNEL_FLOAT_DEVICE T transmute(const U& input) { + static_assert(sizeof(T) == sizeof(U), "types must have equal size"); + T result {}; + ::memcpy(&result, &input, sizeof(T)); + return result; +} + +KERNEL_FLOAT_DEVICE uint32_t +bitwise_if_else(uint32_t condition, uint32_t if_true, uint32_t if_false) { + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + // equivalent to (condition & if_true) | ((~condition) & if_false) + asm("lop3.b32 %0, %1, %2, %3, 0xCA;" + : "=r"(result) + : "r"(condition), "r"(if_true), "r"(if_false)); +#else + result = (condition & if_true) | ((~condition) & if_false); +#endif + return result; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x) { + return y; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x, T coef, TRest... coefs) { + y = __hfma2(x, y, T2 {coef, coef}); + return eval_poly_recur(y, x, coefs...); +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly(T2 x, T coef, TRest... coefs) { + return eval_poly_recur(T2 {coef, coef}, x, coefs...); +} + +#define KERNEL_FLOAT_DEFINE_POLY(NAME, N, ...) \ + template \ + struct NAME { \ + template \ + static KERNEL_FLOAT_DEVICE T2 call(T2 x) { \ + return eval_poly(x, __VA_ARGS__); \ + } \ + }; + +template +struct sin_poly: sin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 1, 1.365) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 2, -21.56, 5.18) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 3, 53.53, -38.06, 6.184) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 4, -56.1, 77.94, -41.1, 6.277) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 5, 32.78, -74.5, 81.4, -41.34, 6.28) + +template +struct cos_poly: cos_poly {}; +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 1, 0.0) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 2, -8.0, 0.6943) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 3, 38.94, -17.5, 0.9707) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 4, -59.66, 61.12, -19.56, 0.9985) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 5, 45.66, -82.4, 64.7, -19.73, 1.0) + +template +struct asin_poly: asin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 1, 1.531) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 2, -0.169, 1.567) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 3, 0.05167, -0.2057, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57) + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) { + // Flip signbit of input when sign<0 + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + asm("lop3.b32 %0, %1, %2, %3, 0x6A;" + : "=r"(result) + : "r"(0x80008000), "r"(transmute(sign)), "r"(transmute(input))); +#else + result = uint32_t(transmute(sign) & 0x80008000) ^ transmute(input); +#endif + + return transmute<__half2>(result); +} + +KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) { + uint32_t val; +#if KERNEL_FLOAT_IS_CUDA + uint32_t ai = *(reinterpret_cast(&a)); + uint32_t bi = *(reinterpret_cast(&b)); + asm("{ set.gt.u32.f16x2 %0,%1,%2;\n}" : "=r"(val) : "r"(ai), "r"(bi)); +#else + val = transmute(make_short2(a.x > b.x ? ~0 : 0, a.y > b.y ? ~0 : 0)); +#endif + return val; +} + +KERNEL_FLOAT_INLINE __half2 make_half2(half x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE __half2 normalize_trig_input(__half2 x) { + /* Using rint is too slow. Round using floating-point magic instead. */ + // __half2 x = arg * make_half2(-0.15915494309); + // return __hfma2(arg, make_half2(0.15915494309), h2rint(x)); + + // 1/(2pi) = 0.15915494309189535 + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + __half2 ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET); + return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE __half2 cos(__half2 x) { + __half2 xf = normalize_trig_input(x); + return cos_poly::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE __half2 sin(__half2 x) { + __half2 xf = normalize_trig_input(x); + return sin_poly::call(__hmul2(xf, xf)) * xf; +} + +template +KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) { + // Flip bits + uint32_t m = ~transmute(x); + + // Multiply by bias (add contant) + __half2 y = transmute<__half2>(uint32_t(0x776d776d) + m); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + // y += y * (1 - x * y) + y = __hfma2(y, __hfma2(-x, y, make_half2(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __half2 rsqrt(__half2 x) { + // Set top and bottom bits for both halfs, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + //uint32_t r = uint32_t(~(transmute(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1; + + // Add bias (0x199c) + __half2 y = transmute<__half2>(uint32_t(r) + uint32_t(0x199c199c)); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + __half2 half_x = make_half2(-0.5) * x; + __half2 correction = __hfma2(half_x, y * y, make_half2(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __half2 sqrt(__half2 x) { + if (Iter == 1) { + __half2 y = rsqrt<0>(x); + + // This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)` + __half2 xy = x * y; + return xy * __hfma2(make_half2(-0.5) * y, xy, make_half2(1.5)); + } + + return x * rsqrt(x); +} + +template +KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) { + static constexpr double HALF_PI = 1.57079632679; + auto abs_x = __habs2(x); + auto v = asin_poly::call(abs_x); + auto abs_y = __hfma2(-v, sqrt(make_half2(1) - abs_x), make_half2(HALF_PI)); + return flipsign(abs_y, x); +} + +template +KERNEL_FLOAT_DEVICE __half2 acos(__half2 x) { + static constexpr double HALF_PI = 1.57079632679; + return make_half2(HALF_PI) - asin(x); +} + +template +KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) { + __half2 y; + + if (Deg == 0) { + // Bring the value to range [32, 64] + // 1.442 = 1/log(2) + // 46.969 = 32.5/log(2) + __half2 m = __hfma2(x, make_half2(1.442), make_half2(46.9375)); + + // Transmute to int, shift higher mantissa bits into exponent field. + y = transmute<__half2>((transmute(m) & 0x03ff03ff) << 5); + } else { + // Add a large number to round to an integer + __half2 v = __hfma2(x, make_half2(1.442), make_half2(1231.0)); + + // The exponent is now in the lower 5 bits. Shift that into the exponent field. + __half2 exp = transmute<__half2>((transmute(v) & 0x001f001f) << 10); + + // The fractional part can be obtained from "1231-v". + // 0.6934 = log(2) + __half2 frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x); + + // This is the Taylor expansion of "exp(x)-1" around 0 + __half2 adjust; + if (Deg == 1) { + adjust = frac; + } else if (Deg == 2) { + // adjust = frac + 0.5 * frac^2 + adjust = __hfma2(frac, __hmul2(frac, make_half2(0.5)), frac); + } else /* if (Deg == 2) */ { + // adjust = frac + 0.5 * frac^2 + 1/6 * frac^3 + adjust = __hfma2( + frac, + __hmul2(__hfma2(frac, make_half2(0.1666), make_half2(0.5)), frac), + frac); + } + + // result = exp * (adjust + 1) + y = __hfma2(exp, adjust, exp); + } + + // Values below -10.39 (= -15*log(2)) become zero + uint32_t zero_mask = half2_gt_mask(x, make_half2(-10.390625)); + return transmute<__half2>(zero_mask & transmute(y)); +} + +template +KERNEL_FLOAT_DEVICE __half2 log(__half2 arg) { + // Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0) + uint32_t bits = bitwise_if_else(0x03ff03ff, transmute(arg) >> 5, 0x50005000); + + // 0.6934 = log(2) + // 32.53 = 46.969*log(2) + return __hfma2(transmute<__half2>(bits), make_half2(0.6934), make_half2(-32.53125)); +} + +template +KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) { + if (Deg == 0) { + return x * rcp<0>(make_half2(0.2869) + __habs2(x)); + } else { + auto c0 = make_half2(0.4531); + auto c1 = make_half2(0.5156); + auto x2b = __hfma2(x, x, c1); + return (x * x2b) * rcp(__hfma2(x2b, __habs2(x), c0)); + } +} + +#endif // KERNEL_FLOAT_FP16_AVAILABLE + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(__bfloat16 x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(double x) { + return {__double2bfloat16(x), __double2bfloat16(x)}; +} + +KERNEL_FLOAT_DEVICE __bfloat162 normalize_trig_input(__nv_bfloat162 x) { + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + __bfloat162 ws = __hadd2( + __hfma2(x, make_bfloat162(-ONE_OVER_TWOPI), make_bfloat162(-OFFSET)), + make_bfloat162(OFFSET)); + return __hfma2(x, make_bfloat162(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 cos(__bfloat162 x) { + __bfloat162 xf = normalize_trig_input(x); + return cos_poly<__bfloat16, Iter + 1>::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 sin(__bfloat162 x) { + __bfloat162 xf = normalize_trig_input(x); + return __hmul2(sin_poly<__bfloat16, Iter>::call(__hmul2(xf, xf)), xf); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) { + __bfloat162 y = transmute<__bfloat162>(uint32_t(0x7ef07ef0) + ~transmute(x)); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + y = __hfma2(y, __hfma2(__hneg2(x), y, make_bfloat162(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 rsqrt(__bfloat162 x) { + // Set top and bottom bits for both halfs, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + + // Add bias (0x1f36) + __bfloat162 y = transmute<__bfloat162>(uint32_t(r) + uint32_t(0x1f361f36)); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + __bfloat162 half_x = __hmul2(make_bfloat162(-0.5), x); + __bfloat162 correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 sqrt(__bfloat162 x) { + return __hmul2(x, rsqrt(x)); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) { + static constexpr float SCALE = 1.44272065994f / 256.0f; + static constexpr float OFFSET = 382.4958400542335; + + auto a = fmaf(__bfloat162float(arg.x), SCALE, OFFSET); + auto b = fmaf(__bfloat162float(arg.y), SCALE, OFFSET); + + return { + transmute<__bfloat16>(uint16_t(transmute(a))), + transmute<__bfloat16>(uint16_t(transmute(b)))}; +} +#endif +} // namespace approx + +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ + namespace detail { \ + template \ + struct apply_approx_impl, 2, __half, __half> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::FUN<__half> fun, __half* output, const __half* input) { \ + __half2 res = approx::FUN(__half2 {input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + template<> \ + struct apply_approx_impl<-1, ops::FUN<__half>, 2, __half, __half>: \ + apply_approx_impl, 2, __half, __half> {}; \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ + } + +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0) + +} // namespace kernel_float diff --git a/include/kernel_float/bf16.h b/include/kernel_float/bf16.h index bad5f75..6784e89 100644 --- a/include/kernel_float/bf16.h +++ b/include/kernel_float/bf16.h @@ -85,8 +85,10 @@ KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp) +KERNEL_FLOAT_BF16_UNARY_FUN(exp2, ::hexp2, ::h2exp2) KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log) +KERNEL_FLOAT_BF16_UNARY_FUN(log2, ::hlog2, ::h2log2) KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2) KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) diff --git a/include/kernel_float/binops.h b/include/kernel_float/binops.h index 5ac6ed9..a19ca64 100644 --- a/include/kernel_float/binops.h +++ b/include/kernel_float/binops.h @@ -359,8 +359,7 @@ template< typename L, typename R, typename T = promoted_vector_value_type, - typename = - enable_if_t> && is_vector_broadcastable>>> + typename = enable_if_t<(vector_size == 3 && vector_size == 3)>> KERNEL_FLOAT_INLINE vector> cross(const L& left, const R& right) { return detail::cross_impl::call(convert_storage(left), convert_storage(right)); } diff --git a/include/kernel_float/fp16.h b/include/kernel_float/fp16.h index 0e90d8d..8d94c51 100644 --- a/include/kernel_float/fp16.h +++ b/include/kernel_float/fp16.h @@ -73,8 +73,10 @@ KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos) KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(exp2, hexp2, h2exp2) KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10) KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(log2, hlog2, h2log2) KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log2) KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt) diff --git a/include/kernel_float/macros.h b/include/kernel_float/macros.h index 01b0254..68be6e5 100644 --- a/include/kernel_float/macros.h +++ b/include/kernel_float/macros.h @@ -8,6 +8,7 @@ // clang-format off #ifdef __CUDACC__ #define KERNEL_FLOAT_IS_CUDA (1) + #define KERNEL_FLOAT_DEVICE __forceinline__ __device__ #ifdef __CUDA_ARCH__ #define KERNEL_FLOAT_INLINE __forceinline__ __device__ @@ -18,6 +19,7 @@ #endif // __CUDA_ARCH__ #elif defined(__HIPCC__) #define KERNEL_FLOAT_IS_HIP (1) + #define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__ #ifdef __HIP_DEVICE_COMPILE__ #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ diff --git a/include/kernel_float/unops.h b/include/kernel_float/unops.h index 9e5fe42..739f795 100644 --- a/include/kernel_float/unops.h +++ b/include/kernel_float/unops.h @@ -178,7 +178,6 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(cbrt) KERNEL_FLOAT_DEFINE_UNARY_MATH(rcbrt) KERNEL_FLOAT_DEFINE_UNARY_MATH(abs) -KERNEL_FLOAT_DEFINE_UNARY_MATH(fabs) KERNEL_FLOAT_DEFINE_UNARY_MATH(floor) KERNEL_FLOAT_DEFINE_UNARY_MATH(round) KERNEL_FLOAT_DEFINE_UNARY_MATH(ceil) @@ -208,31 +207,43 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN(rcp) return ::kernel_float::map(ops::NAME> {}, input); \ } -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sin) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) + +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) + // This PTX is only supported on CUDA #if KERNEL_FLOAT_IS_CUDA && KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, FAST_FUN) \ +#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \ namespace detail { \ template<> \ struct apply_fastmath_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F, T* result, const T* inputs) { \ - *result = FAST_FUN(*inputs); \ + T input = inputs[0]; \ + *result = EXPR_F32; \ } \ }; \ } -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf) -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp2, __exp2f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp10, __exp10f(input)) + +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log2, __log2f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log10, __log10f(input)) + +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, sin, __sinf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, cos, __cosf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \ namespace detail { \ @@ -250,12 +261,13 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") - -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") + +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") + #endif } // namespace kernel_float diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index ea9787f..7dc3fed 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-18 11:10:30.225884 -// git hash: 5490ea756b41c688b66dd69e776a58c9ce5b1ef2 +// date: 2024-11-18 12:11:06.609851 +// git hash: de62ad0ced81f2d5129b31bfb621fcbc0ce161e9 //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -30,6 +30,7 @@ // clang-format off #ifdef __CUDACC__ #define KERNEL_FLOAT_IS_CUDA (1) + #define KERNEL_FLOAT_DEVICE __forceinline__ __device__ #ifdef __CUDA_ARCH__ #define KERNEL_FLOAT_INLINE __forceinline__ __device__ @@ -40,6 +41,7 @@ #endif // __CUDA_ARCH__ #elif defined(__HIPCC__) #define KERNEL_FLOAT_IS_HIP (1) + #define KERNEL_FLOAT_DEVICE __attribute__((always_inline)) __device__ #ifdef __HIP_DEVICE_COMPILE__ #define KERNEL_FLOAT_INLINE __attribute__((always_inline)) __host__ __device__ @@ -795,6 +797,9 @@ struct apply_impl { template struct apply_fastmath_impl: apply_impl {}; + +template +struct apply_approx_impl: apply_fastmath_impl {}; } // namespace detail struct accurate_policy { @@ -807,6 +812,14 @@ struct fast_policy { using type = detail::apply_fastmath_impl; }; +template +struct approximate_policy { + template + using type = detail::apply_approx_impl; +}; + +using default_approximate_policy = approximate_policy<>; + #ifdef KERNEL_FLOAT_POLICY using default_policy = KERNEL_FLOAT_POLICY; #else @@ -1307,7 +1320,6 @@ KERNEL_FLOAT_DEFINE_UNARY_MATH(cbrt) KERNEL_FLOAT_DEFINE_UNARY_MATH(rcbrt) KERNEL_FLOAT_DEFINE_UNARY_MATH(abs) -KERNEL_FLOAT_DEFINE_UNARY_MATH(fabs) KERNEL_FLOAT_DEFINE_UNARY_MATH(floor) KERNEL_FLOAT_DEFINE_UNARY_MATH(round) KERNEL_FLOAT_DEFINE_UNARY_MATH(ceil) @@ -1337,31 +1349,43 @@ KERNEL_FLOAT_DEFINE_UNARY_FUN(rcp) return ::kernel_float::map(ops::NAME> {}, input); \ } -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) -KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sin) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(cos) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(tan) + +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(exp2) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log) KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(log2) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(sqrt) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rcp) +KERNEL_FLOAT_DEFINE_UNARY_FUN_FAST(rsqrt) + // This PTX is only supported on CUDA #if KERNEL_FLOAT_IS_CUDA && KERNEL_FLOAT_IS_DEVICE -#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, FAST_FUN) \ +#define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(T, F, EXPR_F32) \ namespace detail { \ template<> \ struct apply_fastmath_impl, 1, T, T> { \ KERNEL_FLOAT_INLINE static void call(ops::F, T* result, const T* inputs) { \ - *result = FAST_FUN(*inputs); \ + T input = inputs[0]; \ + *result = EXPR_F32; \ } \ }; \ } -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf) -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp, __expf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp2, __exp2f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, exp10, __exp10f(input)) + +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log, __logf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log2, __log2f(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, log10, __log10f(input)) + +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, sin, __sinf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, cos, __cosf(input)) +KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_FUN(float, tan, __tanf(input)) #define KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(T, F, INSTR, REG) \ namespace detail { \ @@ -1379,12 +1403,13 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(double, rsqrt, "rsqrt.approx.f64", "d") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sqrt, "sqrt.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") - -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") -KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f") + +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f") +//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, log2, "lg2.approx.f32", "f") + #endif } // namespace kernel_float @@ -1968,8 +1993,7 @@ template< typename L, typename R, typename T = promoted_vector_value_type, - typename = - enable_if_t> && is_vector_broadcastable>>> + typename = enable_if_t<(vector_size == 3 && vector_size == 3)>> KERNEL_FLOAT_INLINE vector> cross(const L& left, const R& right) { return detail::cross_impl::call(convert_storage(left), convert_storage(right)); } @@ -3972,8 +3996,10 @@ KERNEL_FLOAT_FP16_UNARY_FUN(sin, hsin, h2sin) KERNEL_FLOAT_FP16_UNARY_FUN(cos, hcos, h2cos) KERNEL_FLOAT_FP16_UNARY_FUN(exp, hexp, h2exp) +KERNEL_FLOAT_FP16_UNARY_FUN(exp2, hexp2, h2exp2) KERNEL_FLOAT_FP16_UNARY_FUN(exp10, hexp10, h2exp10) KERNEL_FLOAT_FP16_UNARY_FUN(log, hlog, h2log) +KERNEL_FLOAT_FP16_UNARY_FUN(log2, hlog2, h2log2) KERNEL_FLOAT_FP16_UNARY_FUN(log10, hlog10, h2log2) KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, hsqrt, h2sqrt) @@ -4205,8 +4231,10 @@ KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin) KERNEL_FLOAT_BF16_UNARY_FUN(cos, ::hcos, ::h2cos) KERNEL_FLOAT_BF16_UNARY_FUN(exp, ::hexp, ::h2exp) +KERNEL_FLOAT_BF16_UNARY_FUN(exp2, ::hexp2, ::h2exp2) KERNEL_FLOAT_BF16_UNARY_FUN(exp10, ::hexp10, ::h2exp10) KERNEL_FLOAT_BF16_UNARY_FUN(log, ::hlog, ::h2log) +KERNEL_FLOAT_BF16_UNARY_FUN(log2, ::hlog2, ::h2log2) KERNEL_FLOAT_BF16_UNARY_FUN(log10, ::hlog10, ::h2log2) KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt) @@ -4364,6 +4392,397 @@ struct promote_type<__half, __bfloat16> { #endif #endif //KERNEL_FLOAT_BF16_H +#pragma once + + + + + + +namespace kernel_float { + +namespace approx { + +static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int"); +using uint32_t = unsigned int; + +template +KERNEL_FLOAT_DEVICE T transmute(const U& input) { + static_assert(sizeof(T) == sizeof(U), "types must have equal size"); + T result {}; + ::memcpy(&result, &input, sizeof(T)); + return result; +} + +KERNEL_FLOAT_DEVICE uint32_t +bitwise_if_else(uint32_t condition, uint32_t if_true, uint32_t if_false) { + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + // equivalent to (condition & if_true) | ((~condition) & if_false) + asm("lop3.b32 %0, %1, %2, %3, 0xCA;" + : "=r"(result) + : "r"(condition), "r"(if_true), "r"(if_false)); +#else + result = (condition & if_true) | ((~condition) & if_false); +#endif + return result; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x) { + return y; +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly_recur(T2 y, T2 x, T coef, TRest... coefs) { + y = __hfma2(x, y, T2 {coef, coef}); + return eval_poly_recur(y, x, coefs...); +} + +template +KERNEL_FLOAT_DEVICE T2 eval_poly(T2 x, T coef, TRest... coefs) { + return eval_poly_recur(T2 {coef, coef}, x, coefs...); +} + +#define KERNEL_FLOAT_DEFINE_POLY(NAME, N, ...) \ + template \ + struct NAME { \ + template \ + static KERNEL_FLOAT_DEVICE T2 call(T2 x) { \ + return eval_poly(x, __VA_ARGS__); \ + } \ + }; + +template +struct sin_poly: sin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 1, 1.365) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 2, -21.56, 5.18) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 3, 53.53, -38.06, 6.184) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 4, -56.1, 77.94, -41.1, 6.277) +KERNEL_FLOAT_DEFINE_POLY(sin_poly, 5, 32.78, -74.5, 81.4, -41.34, 6.28) + +template +struct cos_poly: cos_poly {}; +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 1, 0.0) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 2, -8.0, 0.6943) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 3, 38.94, -17.5, 0.9707) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 4, -59.66, 61.12, -19.56, 0.9985) +KERNEL_FLOAT_DEFINE_POLY(cos_poly, 5, 45.66, -82.4, 64.7, -19.73, 1.0) + +template +struct asin_poly: asin_poly {}; +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 1, 1.531) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 2, -0.169, 1.567) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 3, 0.05167, -0.2057, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 4, -0.02103, 0.077, -0.2129, 1.57) +KERNEL_FLOAT_DEFINE_POLY(asin_poly, 5, 0.009796, -0.03772, 0.0857, -0.2142, 1.57) + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEVICE __half2 flipsign(__half2 input, __half2 sign) { + // Flip signbit of input when sign<0 + uint32_t result; + +#if KERNEL_FLOAT_IS_CUDA + asm("lop3.b32 %0, %1, %2, %3, 0x6A;" + : "=r"(result) + : "r"(0x80008000), "r"(transmute(sign)), "r"(transmute(input))); +#else + result = uint32_t(transmute(sign) & 0x80008000) ^ transmute(input); +#endif + + return transmute<__half2>(result); +} + +KERNEL_FLOAT_DEVICE uint32_t half2_gt_mask(__half2 a, __half2 b) { + uint32_t val; +#if KERNEL_FLOAT_IS_CUDA + uint32_t ai = *(reinterpret_cast(&a)); + uint32_t bi = *(reinterpret_cast(&b)); + asm("{ set.gt.u32.f16x2 %0,%1,%2;\n}" : "=r"(val) : "r"(ai), "r"(bi)); +#else + val = transmute(make_short2(a.x > b.x ? ~0 : 0, a.y > b.y ? ~0 : 0)); +#endif + return val; +} + +KERNEL_FLOAT_INLINE __half2 make_half2(half x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE __half2 normalize_trig_input(__half2 x) { + /* Using rint is too slow. Round using floating-point magic instead. */ + // __half2 x = arg * make_half2(-0.15915494309); + // return __hfma2(arg, make_half2(0.15915494309), h2rint(x)); + + // 1/(2pi) = 0.15915494309189535 + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + __half2 ws = __hfma2(x, make_half2(-ONE_OVER_TWOPI), make_half2(-OFFSET)) + make_half2(OFFSET); + return __hfma2(x, make_half2(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE __half2 cos(__half2 x) { + __half2 xf = normalize_trig_input(x); + return cos_poly::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE __half2 sin(__half2 x) { + __half2 xf = normalize_trig_input(x); + return sin_poly::call(__hmul2(xf, xf)) * xf; +} + +template +KERNEL_FLOAT_DEVICE __half2 rcp(__half2 x) { + // Flip bits + uint32_t m = ~transmute(x); + + // Multiply by bias (add contant) + __half2 y = transmute<__half2>(uint32_t(0x776d776d) + m); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + // y += y * (1 - x * y) + y = __hfma2(y, __hfma2(-x, y, make_half2(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __half2 rsqrt(__half2 x) { + // Set top and bottom bits for both halfs, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + //uint32_t r = uint32_t(~(transmute(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1; + + // Add bias (0x199c) + __half2 y = transmute<__half2>(uint32_t(r) + uint32_t(0x199c199c)); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + __half2 half_x = make_half2(-0.5) * x; + __half2 correction = __hfma2(half_x, y * y, make_half2(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __half2 sqrt(__half2 x) { + if (Iter == 1) { + __half2 y = rsqrt<0>(x); + + // This method uses only 4 muls, instead of 5 muls when using `arg * approx_rsqrt<1>(arg)` + __half2 xy = x * y; + return xy * __hfma2(make_half2(-0.5) * y, xy, make_half2(1.5)); + } + + return x * rsqrt(x); +} + +template +KERNEL_FLOAT_DEVICE __half2 asin(__half2 x) { + static constexpr double HALF_PI = 1.57079632679; + auto abs_x = __habs2(x); + auto v = asin_poly::call(abs_x); + auto abs_y = __hfma2(-v, sqrt(make_half2(1) - abs_x), make_half2(HALF_PI)); + return flipsign(abs_y, x); +} + +template +KERNEL_FLOAT_DEVICE __half2 acos(__half2 x) { + static constexpr double HALF_PI = 1.57079632679; + return make_half2(HALF_PI) - asin(x); +} + +template +KERNEL_FLOAT_DEVICE __half2 exp(__half2 x) { + __half2 y; + + if (Deg == 0) { + // Bring the value to range [32, 64] + // 1.442 = 1/log(2) + // 46.969 = 32.5/log(2) + __half2 m = __hfma2(x, make_half2(1.442), make_half2(46.9375)); + + // Transmute to int, shift higher mantissa bits into exponent field. + y = transmute<__half2>((transmute(m) & 0x03ff03ff) << 5); + } else { + // Add a large number to round to an integer + __half2 v = __hfma2(x, make_half2(1.442), make_half2(1231.0)); + + // The exponent is now in the lower 5 bits. Shift that into the exponent field. + __half2 exp = transmute<__half2>((transmute(v) & 0x001f001f) << 10); + + // The fractional part can be obtained from "1231-v". + // 0.6934 = log(2) + __half2 frac = __hfma2(make_half2(1231.0) - v, make_half2(0.6934), x); + + // This is the Taylor expansion of "exp(x)-1" around 0 + __half2 adjust; + if (Deg == 1) { + adjust = frac; + } else if (Deg == 2) { + // adjust = frac + 0.5 * frac^2 + adjust = __hfma2(frac, __hmul2(frac, make_half2(0.5)), frac); + } else /* if (Deg == 2) */ { + // adjust = frac + 0.5 * frac^2 + 1/6 * frac^3 + adjust = __hfma2( + frac, + __hmul2(__hfma2(frac, make_half2(0.1666), make_half2(0.5)), frac), + frac); + } + + // result = exp * (adjust + 1) + y = __hfma2(exp, adjust, exp); + } + + // Values below -10.39 (= -15*log(2)) become zero + uint32_t zero_mask = half2_gt_mask(x, make_half2(-10.390625)); + return transmute<__half2>(zero_mask & transmute(y)); +} + +template +KERNEL_FLOAT_DEVICE __half2 log(__half2 arg) { + // Shift exponent field into mantissa bits. Fill exponent bits with 0x5000 (= 32.0) + uint32_t bits = bitwise_if_else(0x03ff03ff, transmute(arg) >> 5, 0x50005000); + + // 0.6934 = log(2) + // 32.53 = 46.969*log(2) + return __hfma2(transmute<__half2>(bits), make_half2(0.6934), make_half2(-32.53125)); +} + +template +KERNEL_FLOAT_DEVICE __half2 tanh(__half2 x) { + if (Deg == 0) { + return x * rcp<0>(make_half2(0.2869) + __habs2(x)); + } else { + auto c0 = make_half2(0.4531); + auto c1 = make_half2(0.5156); + auto x2b = __hfma2(x, x, c1); + return (x * x2b) * rcp(__hfma2(x2b, __habs2(x), c0)); + } +} + +#endif // KERNEL_FLOAT_FP16_AVAILABLE + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(__bfloat16 x) { + return {x, x}; +} + +KERNEL_FLOAT_DEVICE __bfloat162 make_bfloat162(double x) { + return {__double2bfloat16(x), __double2bfloat16(x)}; +} + +KERNEL_FLOAT_DEVICE __bfloat162 normalize_trig_input(__nv_bfloat162 x) { + static constexpr double ONE_OVER_TWOPI = 0.15915494309189535; + static constexpr double OFFSET = -2042.0; + + __bfloat162 ws = __hadd2( + __hfma2(x, make_bfloat162(-ONE_OVER_TWOPI), make_bfloat162(-OFFSET)), + make_bfloat162(OFFSET)); + return __hfma2(x, make_bfloat162(ONE_OVER_TWOPI), ws); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 cos(__bfloat162 x) { + __bfloat162 xf = normalize_trig_input(x); + return cos_poly<__bfloat16, Iter + 1>::call(__hmul2(xf, xf)); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 sin(__bfloat162 x) { + __bfloat162 xf = normalize_trig_input(x); + return __hmul2(sin_poly<__bfloat16, Iter>::call(__hmul2(xf, xf)), xf); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 rcp(__bfloat162 x) { + __bfloat162 y = transmute<__bfloat162>(uint32_t(0x7ef07ef0) + ~transmute(x)); + +#pragma unroll + for (int i = 0; i < Iter; i++) { + y = __hfma2(y, __hfma2(__hneg2(x), y, make_bfloat162(1.0)), y); + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 rsqrt(__bfloat162 x) { + // Set top and bottom bits for both halfs, then shift by 1, then invert + uint32_t r = ~((uint32_t(transmute(x) >> 1)) | ~uint32_t(0x3fff3fff)); + + // Add bias (0x1f36) + __bfloat162 y = transmute<__bfloat162>(uint32_t(r) + uint32_t(0x1f361f36)); + + // Newton-Raphson iterations +#pragma unroll + for (int i = 0; i < Iter; i++) { + __bfloat162 half_x = __hmul2(make_bfloat162(-0.5), x); + __bfloat162 correction = __hfma2(half_x, __hmul2(y, y), make_bfloat162(0.5)); + y = __hfma2(correction, y, y); // y += y * correction + } + + return y; +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 sqrt(__bfloat162 x) { + return __hmul2(x, rsqrt(x)); +} + +template +KERNEL_FLOAT_DEVICE __bfloat162 exp(__bfloat162 arg) { + static constexpr float SCALE = 1.44272065994f / 256.0f; + static constexpr float OFFSET = 382.4958400542335; + + auto a = fmaf(__bfloat162float(arg.x), SCALE, OFFSET); + auto b = fmaf(__bfloat162float(arg.y), SCALE, OFFSET); + + return { + transmute<__bfloat16>(uint16_t(transmute(a))), + transmute<__bfloat16>(uint16_t(transmute(b)))}; +} +#endif +} // namespace approx + +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ + namespace detail { \ + template \ + struct apply_approx_impl, 2, __half, __half> { \ + KERNEL_FLOAT_INLINE static void \ + call(ops::FUN<__half> fun, __half* output, const __half* input) { \ + __half2 res = approx::FUN(__half2 {input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + template<> \ + struct apply_approx_impl<-1, ops::FUN<__half>, 2, __half, __half>: \ + apply_approx_impl, 2, __half, __half> {}; \ + } \ + \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ + } + +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0) +KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0) + +} // namespace kernel_float #ifndef KERNEL_FLOAT_FP8_H #define KERNEL_FLOAT_FP8_H