diff --git a/include/kernel_float/approx.h b/include/kernel_float/approx.h index df81d30..c1e7836 100644 --- a/include/kernel_float/approx.h +++ b/include/kernel_float/approx.h @@ -9,7 +9,7 @@ namespace kernel_float { namespace approx { -static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int"); +static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int"); using uint32_t = unsigned int; template @@ -346,11 +346,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) { template KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { - static constexpr float SCALE = 1.44272065994f / 256.0f; + static constexpr float SCALE = 1.44272065994 / 256.0; static constexpr float OFFSET = 382.4958400542335; + static constexpr float MINIMUM = 382; - auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET); - auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET); + float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM); + float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM); return { transmute<__bfloat16>(uint16_t(transmute(a))), @@ -359,33 +360,66 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { #endif } // namespace approx -#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ - namespace detail { \ - template \ - struct apply_impl, ops::FUN, 2, half_t, half_t> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::FUN fun, half_t* output, const half_t* input) { \ - half2_t res = approx::FUN(half2_t {input[0], input[1]}); \ - output[0] = res.x; \ - output[1] = res.y; \ - } \ - }; \ - template<> \ - struct apply_impl, 2, half_t, half_t>: \ - apply_impl, ops::FUN, 2, half_t, half_t> {}; \ - } \ - \ - template \ - KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ - return map>(ops::FUN> {}, args); \ +namespace detail { +template +struct apply_impl, F, 1, T, T> { + KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) { + T in2[2], out2[2]; + out2[0] = input[0]; + apply_impl, F, 2, T, T>::call(fun, out2, in2); + output[0] = out2[0]; + } +}; +} // namespace detail + +#define KERNEL_FLOAT_DEFINE_APPROX_IMPL(T, FUN, DEFAULT_LEVEL) \ + namespace detail { \ + template \ + struct apply_impl, ops::FUN, 2, T, T> { \ + KERNEL_FLOAT_INLINE static void call(ops::FUN, T* output, const T* input) { \ + auto res = approx::FUN({input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + \ + template<> \ + struct apply_impl, 2, T, T>: \ + apply_impl, ops::FUN, 2, T, T> {}; \ + } + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +#endif + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, exp, 0) +//KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +#endif + +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FUN) \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ } -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0) +KERNEL_FLOAT_DEFINE_APPROX_FUN(sin) +KERNEL_FLOAT_DEFINE_APPROX_FUN(cos) +KERNEL_FLOAT_DEFINE_APPROX_FUN(rsqrt) +KERNEL_FLOAT_DEFINE_APPROX_FUN(sqrt) +KERNEL_FLOAT_DEFINE_APPROX_FUN(rcp) +KERNEL_FLOAT_DEFINE_APPROX_FUN(exp) +KERNEL_FLOAT_DEFINE_APPROX_FUN(log) } // namespace kernel_float diff --git a/single_include/kernel_float.h b/single_include/kernel_float.h index 31fcb0f..0e66057 100644 --- a/single_include/kernel_float.h +++ b/single_include/kernel_float.h @@ -16,8 +16,8 @@ //================================================================================ // this file has been auto-generated, do not modify its contents! -// date: 2024-11-18 13:50:24.614671 -// git hash: f89cf98f79e78ab6013063dea4b4b516ce163855 +// date: 2024-11-18 16:57:58.817191 +// git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a //================================================================================ #ifndef KERNEL_FLOAT_MACROS_H @@ -824,31 +824,53 @@ using default_policy = KERNEL_FLOAT_POLICY; namespace detail { +// template -struct apply_base_impl { +struct apply_fallback_impl { KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { -#pragma unroll - for (size_t i = 0; i < N; i++) { - output[i] = fun(args[i]...); - } + static_assert(N > 0, "operation not implemented"); } }; +template +struct apply_base_impl: apply_fallback_impl {}; + template struct apply_impl: apply_base_impl {}; +// `fast_policy` falls back to `accurate_policy` template -struct apply_base_impl: +struct apply_fallback_impl: apply_impl {}; +// `approx_policy` falls back to `fast_policy` template -struct apply_base_impl: +struct apply_fallback_impl: apply_impl {}; +// `approx_level_policy` falls back to `approx_policy` template -struct apply_base_impl, F, N, Output, Args...>: +struct apply_fallback_impl, F, N, Output, Args...>: apply_impl {}; +template +struct invoke_impl { + KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) { + return fun(args...); + } +}; + +// Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`. +template +struct apply_impl { + KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) { +#pragma unroll + for (size_t i = 0; i < N; i++) { + output[i] = invoke_impl::call(fun, args[i]...); + } + } +}; + template struct map_impl { static constexpr size_t packet_size = preferred_vector_size::value; @@ -1949,7 +1971,7 @@ struct multiply { namespace detail { template -struct apply_impl, N, T, T, T> { +struct apply_base_impl, N, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { T rhs_rcp[N]; @@ -1959,10 +1981,6 @@ struct apply_impl, N, T, T, T> { } }; -template -struct apply_impl, N, T, T, T>: - apply_base_impl, N, T, T, T> {}; - #if KERNEL_FLOAT_IS_DEVICE template<> struct apply_impl, 1, float, float, float> { @@ -1977,7 +1995,7 @@ struct apply_impl, 1, float, float, float> { namespace detail { // Override `pow` using `log2` and `exp2` template -struct apply_impl, N, T, T, T> { +struct apply_base_impl, N, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::divide, T* result, const T* lhs, const T* rhs) { T lhs_log[N]; T result_log[N]; @@ -1988,10 +2006,6 @@ struct apply_impl, N, T, T, T> { apply_impl, N, T, T, T>::call({}, result, result_log); } }; - -template -struct apply_impl, N, T, T, T>: - apply_base_impl, N, T, T, T> {}; } // namespace detail template> @@ -3218,13 +3232,13 @@ struct fma { } // namespace ops namespace detail { -template -struct apply_impl, N, T, T, T, T> { +template +struct apply_impl, N, T, T, T, T> { KERNEL_FLOAT_INLINE static void call(ops::fma, T* output, const T* a, const T* b, const T* c) { T temp[N]; - apply_impl, N, T, T, T>::call({}, temp, a, b); - apply_impl, N, T, T, T>::call({}, output, temp, c); + apply_impl, N, T, T, T>::call({}, temp, a, b); + apply_impl, N, T, T, T>::call({}, output, temp, c); } }; } // namespace detail @@ -3992,9 +4006,6 @@ namespace kernel_float { using half_t = ::__half; using half2_t = ::__half2; -using __half = void; -using __half2 = void; - template<> struct preferred_vector_size { static constexpr size_t value = 2; @@ -4020,7 +4031,7 @@ template<> struct allow_float_fallback { static constexpr bool value = true; }; -}; // namespace detail +} // namespace detail #if KERNEL_FLOAT_IS_DEVICE #define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \ @@ -4469,7 +4480,7 @@ namespace kernel_float { namespace approx { -static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int"); +static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int"); using uint32_t = unsigned int; template @@ -4806,11 +4817,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) { template KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { - static constexpr float SCALE = 1.44272065994f / 256.0f; + static constexpr float SCALE = 1.44272065994 / 256.0; static constexpr float OFFSET = 382.4958400542335; + static constexpr float MINIMUM = 382; - auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET); - auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET); + float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM); + float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM); return { transmute<__bfloat16>(uint16_t(transmute(a))), @@ -4819,34 +4831,67 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) { #endif } // namespace approx -#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \ - namespace detail { \ - template \ - struct apply_impl, ops::FUN, 2, half_t, half_t> { \ - KERNEL_FLOAT_INLINE static void \ - call(ops::FUN fun, half_t* output, const half_t* input) { \ - half2_t res = approx::FUN(half2_t {input[0], input[1]}); \ - output[0] = res.x; \ - output[1] = res.y; \ - } \ - }; \ - template<> \ - struct apply_impl, 2, half_t, half_t>: \ - apply_impl, ops::FUN, 2, half_t, half_t> {}; \ - } \ - \ - template \ - KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ - return map>(ops::FUN> {}, args); \ - } - -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0) -KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0) +namespace detail { +template +struct apply_impl, F, 1, T, T> { + KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) { + T in2[2], out2[2]; + out2[0] = input[0]; + apply_impl, F, 2, T, T>::call(fun, out2, in2); + output[0] = out2[0]; + } +}; +} // namespace detail + +#define KERNEL_FLOAT_DEFINE_APPROX_IMPL(T, FUN, DEFAULT_LEVEL) \ + namespace detail { \ + template \ + struct apply_impl, ops::FUN, 2, T, T> { \ + KERNEL_FLOAT_INLINE static void call(ops::FUN, T* output, const T* input) { \ + auto res = approx::FUN({input[0], input[1]}); \ + output[0] = res.x; \ + output[1] = res.y; \ + } \ + }; \ + \ + template<> \ + struct apply_impl, 2, T, T>: \ + apply_impl, ops::FUN, 2, T, T> {}; \ + } + +#if KERNEL_FLOAT_FP16_AVAILABLE +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +#endif + +#if KERNEL_FLOAT_BF16_OPS_SUPPORTED +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rsqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sqrt, 1) +KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, exp, 0) +//KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0) +#endif + +#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FUN) \ + template \ + KERNEL_FLOAT_INLINE into_vector_type approx_##FUN(const V& args) { \ + return map>(ops::FUN> {}, args); \ + } + +KERNEL_FLOAT_DEFINE_APPROX_FUN(sin) +KERNEL_FLOAT_DEFINE_APPROX_FUN(cos) +KERNEL_FLOAT_DEFINE_APPROX_FUN(rsqrt) +KERNEL_FLOAT_DEFINE_APPROX_FUN(sqrt) +KERNEL_FLOAT_DEFINE_APPROX_FUN(rcp) +KERNEL_FLOAT_DEFINE_APPROX_FUN(exp) +KERNEL_FLOAT_DEFINE_APPROX_FUN(log) } // namespace kernel_float #ifndef KERNEL_FLOAT_FP8_H