Skip to content

Commit

Permalink
kernel_float::approx::sqrt(0) now returns 0
Browse files Browse the repository at this point in the history
  • Loading branch information
stijnh committed Nov 20, 2024
1 parent e6c8a7c commit 5c859b9
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 37 deletions.
15 changes: 10 additions & 5 deletions include/kernel_float/approx.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,20 @@ KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) {

template<int Iter>
KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) {
// A small number added such that rsqrt(0) does not return NaN
static constexpr double EPS = 0.00000768899917602539;

// Set top and bottom bits for both halfs, then shift by 1, then invert
uint32_t r = ~((uint32_t(transmute<uint32_t>(x) >> 1)) | ~uint32_t(0x3fff3fff));
//uint32_t r = uint32_t(~(transmute<uint32_t>(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1;

// Add bias (0x199c)
half2_t y = transmute<half2_t>(uint32_t(r) + uint32_t(0x199c199c));
// Add bias
static constexpr uint32_t BIAS = 0x199c199c;
half2_t y = transmute<half2_t>(uint32_t(r) + BIAS);

// Newton-Raphson iterations
#pragma unroll
for (int i = 0; i < Iter; i++) {
half2_t half_x = make_half2(-0.5) * x;
half2_t half_x = __hfma2(make_half2(-0.5), x, make_half2(-EPS));
half2_t correction = __hfma2(half_x, y * y, make_half2(0.5));
y = __hfma2(correction, y, y); // y += y * correction
}
Expand Down Expand Up @@ -365,7 +368,7 @@ template<int Level, typename F, typename T>
struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) {
T in2[2], out2[2];
out2[0] = input[0];
in2[0] = input[0];
apply_impl<approx_level_policy<Level>, F, 2, T, T>::call(fun, out2, in2);
output[0] = out2[0];
}
Expand Down Expand Up @@ -396,6 +399,8 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, acos, 2)
#endif

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
Expand Down
120 changes: 88 additions & 32 deletions single_include/kernel_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

//================================================================================
// this file has been auto-generated, do not modify its contents!
// date: 2024-11-18 16:57:58.817191
// git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a
// date: 2024-11-20 10:36:45.284577
// git hash: 76501fda40df9e396998d11840bc8f10b11ea47b
//================================================================================

#ifndef KERNEL_FLOAT_MACROS_H
Expand Down Expand Up @@ -813,7 +813,7 @@ struct approx_level_policy {};
using approx_policy = approx_level_policy<>;

#ifndef KERNEL_FLOAT_POLICY
#define KERNEL_FLOAT_POLICY accurate_policy;
#define KERNEL_FLOAT_POLICY accurate_policy
#endif

/**
Expand Down Expand Up @@ -1448,6 +1448,9 @@ KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rcp, "rcp.approx.f32", "f")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, rsqrt, "rsqrt.approx.f32", "f")
KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, tanh, "tanh.approx.f32;", "f")

#define KERNEL_FLOAT_FAST_F32_MAP(F) \
F(exp) F(exp2) F(exp10) F(log) F(log2) F(log10) F(sin) F(cos) F(tan) F(rcp) F(rsqrt) F(sqrt)

//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, sin, "sin.approx.f32", "f")
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, cos, "cos.approx.f32", "f")
//KERNEL_FLOAT_DEFINE_UNARY_FAST_IMPL_PTX(float, exp2, "ex2.approx.f32", "f")
Expand Down Expand Up @@ -1724,15 +1727,15 @@ using zip_common_type = vector<
* vec<float, 3> c = zip_common([](float x, float y){ return x + y; }, a, b); // returns [5.0f, 7.0f, 9.0f]
* ```
*/
template<typename F, typename L, typename R>
template<typename Accuracy = default_policy, typename F, typename L, typename R>
KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, const R& right) {
using T = promoted_vector_value_type<L, R>;
using O = result_t<F, T, T>;
using E = broadcast_vector_extent_type<L, R>;

vector_storage<O, extent_size<E>> result;

detail::default_map_impl<F, extent_size<E>, O, T, T>::call(
detail::map_impl<Accuracy, F, extent_size<E>, O, T, T>::call(
fun,
result.data(),
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
Expand All @@ -1745,10 +1748,17 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
return result;
}

#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \
template<typename L, typename R, typename C = promoted_vector_value_type<L, R>> \
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME(L&& left, R&& right) { \
return zip_common(ops::NAME<C> {}, static_cast<L&&>(left), static_cast<R&&>(right)); \
#define KERNEL_FLOAT_DEFINE_BINARY_FUN(NAME) \
template< \
typename Accuracy = default_policy, \
typename L, \
typename R, \
typename C = promoted_vector_value_type<L, R>> \
KERNEL_FLOAT_INLINE zip_common_type<ops::NAME<C>, L, R> NAME(L&& left, R&& right) { \
return zip_common<Accuracy>( \
ops::NAME<C> {}, \
static_cast<L&&>(left), \
static_cast<R&&>(right)); \
}

#define KERNEL_FLOAT_DEFINE_BINARY(NAME, EXPR, EXPR_F64, EXPR_F32) \
Expand Down Expand Up @@ -3887,11 +3897,20 @@ struct vector: public S {
}

/**
* Returns the result of `*this + lhs * rhs`.
* Returns the result of `this + lhs * rhs`.
*
* The operation is performed using a single `kernel_float::fma` call, which may be faster then perform
* the addition and multiplication separately.
*/
template<
typename L,
typename R,
typename T2 = promote_t<T, vector_value_type<L>, vector_value_type<R>>,
typename E2 = broadcast_extent<E, vector_extent_type<L>, vector_extent_type<R>>>
KERNEL_FLOAT_INLINE vector<T2, E2> add_mul(const L& lhs, const R& rhs) const {
return ::kernel_float::fma(lhs, rhs, *this);
}

template<
typename L,
typename R,
Expand Down Expand Up @@ -4138,6 +4157,22 @@ struct apply_impl<accurate_policy, ops::fma<half_t>, 2, half_t, half_t, half_t,
result[0] = r.x, result[1] = r.y;
}
};

// clang-format off
#define KERNEL_FLOAT_FAST_FP16_DISPATCH(OP) \
template<size_t N> \
struct apply_impl<fast_policy, ops::OP<half_t>, N, half_t, half_t> { \
KERNEL_FLOAT_INLINE static void \
call(ops::OP<half_t>, half_t* output, const half_t* input) { \
float v[N]; \
map_impl<fast_policy, ops::cast<half_t, float>, N, float, half_t>::call({}, v, input); \
map_impl<fast_policy, ops::OP<float>, N, float, float>::call({}, v, v); \
map_impl<fast_policy, ops::cast<float, half_t>, N, half_t, float>::call({}, output, v); \
} \
};
// clang-format on

KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_FP16_DISPATCH)
} // namespace detail
#endif

Expand Down Expand Up @@ -4390,6 +4425,22 @@ struct apply_impl<
result[0] = r.x, result[1] = r.y;
}
};

// clang-format off
#define KERNEL_FLOAT_FAST_BF16_DISPATCH(OP) \
template<size_t N> \
struct apply_impl<fast_policy, ops::OP<bfloat16_t>, N, bfloat16_t, bfloat16_t> { \
KERNEL_FLOAT_INLINE static void \
call(ops::OP<bfloat16_t>, bfloat16_t* output, const bfloat16_t* input) { \
float v[N]; \
map_impl<fast_policy, ops::cast<bfloat16_t, float>, N, float, bfloat16_t>::call({}, v, input); \
map_impl<fast_policy, ops::OP<float>, N, float, float>::call({}, v, v); \
map_impl<fast_policy, ops::cast<float, bfloat16_t>, N, bfloat16_t, float>::call({}, output, v); \
} \
};
// clang-format on

KERNEL_FLOAT_FAST_F32_MAP(KERNEL_FLOAT_FAST_BF16_DISPATCH)
} // namespace detail
#endif

Expand Down Expand Up @@ -4631,17 +4682,20 @@ KERNEL_FLOAT_DEVICE half2_t rcp(half2_t x) {

template<int Iter>
KERNEL_FLOAT_DEVICE half2_t rsqrt(half2_t x) {
// A small number added such that rsqrt(0) does not return NaN
static constexpr double EPS = 0.00000768899917602539;

// Set top and bottom bits for both halfs, then shift by 1, then invert
uint32_t r = ~((uint32_t(transmute<uint32_t>(x) >> 1)) | ~uint32_t(0x3fff3fff));
//uint32_t r = uint32_t(~(transmute<uint32_t>(arg) | (~uint32_t(0x3ffe3ffe)))) >> 1;

// Add bias (0x199c)
half2_t y = transmute<half2_t>(uint32_t(r) + uint32_t(0x199c199c));
// Add bias
static constexpr uint32_t BIAS = 0x199c199c;
half2_t y = transmute<half2_t>(uint32_t(r) + BIAS);

// Newton-Raphson iterations
#pragma unroll
for (int i = 0; i < Iter; i++) {
half2_t half_x = make_half2(-0.5) * x;
half2_t half_x = __hfma2(make_half2(-0.5), x, make_half2(-EPS));
half2_t correction = __hfma2(half_x, y * y, make_half2(0.5));
y = __hfma2(correction, y, y); // y += y * correction
}
Expand Down Expand Up @@ -4836,7 +4890,7 @@ template<int Level, typename F, typename T>
struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) {
T in2[2], out2[2];
out2[0] = input[0];
in2[0] = input[0];
apply_impl<approx_level_policy<Level>, F, 2, T, T>::call(fun, out2, in2);
output[0] = out2[0];
}
Expand Down Expand Up @@ -4867,6 +4921,8 @@ KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, asin, 2)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, acos, 2)
#endif

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
Expand Down Expand Up @@ -4960,7 +5016,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
#define KERNEL_FLOAT_FP8_CAST2(T, FP8_TY, FP8_INTERP) \
namespace detail { \
template<> \
struct apply_impl<ops::cast<T, FP8_TY>, 2, FP8_TY, T> { \
struct apply_impl<accurate_policy, ops::cast<T, FP8_TY>, 2, FP8_TY, T> { \
KERNEL_FLOAT_INLINE static void call(ops::cast<T, FP8_TY>, FP8_TY* result, const T* v) { \
__half2_raw x; \
memcpy(&x, v, 2 * sizeof(T)); \
Expand All @@ -4969,7 +5025,7 @@ struct allow_float_fallback<__nv_fp8_e5m2> {
} \
}; \
template<> \
struct apply_impl<ops::cast<FP8_TY, T>, 2, T, FP8_TY> { \
struct apply_impl<accurate_policy, ops::cast<FP8_TY, T>, 2, T, FP8_TY> { \
KERNEL_FLOAT_INLINE static void call(ops::cast<FP8_TY, T>, T* result, const FP8_TY* v) { \
__nv_fp8x2_storage_t x; \
memcpy(&x, v, 2 * sizeof(FP8_TY)); \
Expand All @@ -4987,12 +5043,12 @@ KERNEL_FLOAT_FP8_CAST(double)


namespace kernel_float {
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e4m3)
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__half, __nv_fp8_e5m2)
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e4m3)
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(half_t, __nv_fp8_e5m2)

KERNEL_FLOAT_FP8_CAST(__half)
KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e4m3, __NV_E4M3)
KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)
KERNEL_FLOAT_FP8_CAST(half_t)
KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e4m3, __NV_E4M3)
KERNEL_FLOAT_FP8_CAST2(half_t, __nv_fp8_e5m2, __NV_E5M2)

} // namespace kernel_float
#endif // KERNEL_FLOAT_FP16_AVAILABLE
Expand All @@ -5001,12 +5057,12 @@ KERNEL_FLOAT_FP8_CAST2(__half, __nv_fp8_e5m2, __NV_E5M2)


namespace kernel_float {
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e4m3)
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(__nv_bfloat16, __nv_fp8_e5m2)
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e4m3)
KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(bfloat16_t, __nv_fp8_e5m2)

KERNEL_FLOAT_FP8_CAST(__nv_bfloat16)
KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e4m3, __NV_E4M3)
KERNEL_FLOAT_FP8_CAST2(__nv_bfloat16, __nv_fp8_e5m2, __NV_E5M2)
KERNEL_FLOAT_FP8_CAST(bfloat16_t)
KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e4m3, __NV_E4M3)
KERNEL_FLOAT_FP8_CAST2(bfloat16_t, __nv_fp8_e5m2, __NV_E5M2)
} // namespace kernel_float
#endif // KERNEL_FLOAT_BF16_AVAILABLE

Expand Down Expand Up @@ -5075,14 +5131,14 @@ KERNEL_FLOAT_TYPE_ALIAS(f64x, double)
KERNEL_FLOAT_TYPE_ALIAS(float64x, double)

#if KERNEL_FLOAT_FP16_AVAILABLE
KERNEL_FLOAT_TYPE_ALIAS(half, __half)
KERNEL_FLOAT_TYPE_ALIAS(f16x, __half)
KERNEL_FLOAT_TYPE_ALIAS(float16x, __half)
KERNEL_FLOAT_TYPE_ALIAS(half, half_t)
KERNEL_FLOAT_TYPE_ALIAS(f16x, half_t)
KERNEL_FLOAT_TYPE_ALIAS(float16x, half_t)
#endif

#if KERNEL_FLOAT_BF16_AVAILABLE
KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, __bfloat16)
KERNEL_FLOAT_TYPE_ALIAS(bf16x, __bfloat16)
KERNEL_FLOAT_TYPE_ALIAS(bfloat16x, bfloat16_t)
KERNEL_FLOAT_TYPE_ALIAS(bf16x, bfloat16_t)
#endif

#if KERNEL_FLOAT_BF8_AVAILABLE
Expand Down

0 comments on commit 5c859b9

Please sign in to comment.