Skip to content

Commit

Permalink
Add approx_* functions
Browse files Browse the repository at this point in the history
  • Loading branch information
stijnh committed Nov 18, 2024
1 parent 003ce36 commit 76501fd
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 90 deletions.
94 changes: 64 additions & 30 deletions include/kernel_float/approx.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace kernel_float {

namespace approx {

static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int");
static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int");
using uint32_t = unsigned int;

template<typename T, typename U>
Expand Down Expand Up @@ -346,11 +346,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) {

template<int = 0>
KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
static constexpr float SCALE = 1.44272065994f / 256.0f;
static constexpr float SCALE = 1.44272065994 / 256.0;
static constexpr float OFFSET = 382.4958400542335;
static constexpr float MINIMUM = 382;

auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET);
auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET);
float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM);
float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM);

return {
transmute<__bfloat16>(uint16_t(transmute<uint32_t>(a))),
Expand All @@ -359,33 +360,66 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
#endif
} // namespace approx

#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \
namespace detail { \
template<int Degree> \
struct apply_impl<approx_level_policy<Degree>, ops::FUN<half_t>, 2, half_t, half_t> { \
KERNEL_FLOAT_INLINE static void \
call(ops::FUN<half_t> fun, half_t* output, const half_t* input) { \
half2_t res = approx::FUN<Degree>(half2_t {input[0], input[1]}); \
output[0] = res.x; \
output[1] = res.y; \
} \
}; \
template<> \
struct apply_impl<approx_policy, ops::FUN<half_t>, 2, half_t, half_t>: \
apply_impl<approx_level_policy<DEG>, ops::FUN<half_t>, 2, half_t, half_t> {}; \
} \
\
template<int Level = -1, typename V> \
KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
namespace detail {
template<int Level, typename F, typename T>
struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) {
T in2[2], out2[2];
out2[0] = input[0];
apply_impl<approx_level_policy<Level>, F, 2, T, T>::call(fun, out2, in2);
output[0] = out2[0];
}
};
} // namespace detail

#define KERNEL_FLOAT_DEFINE_APPROX_IMPL(T, FUN, DEFAULT_LEVEL) \
namespace detail { \
template<int Degree> \
struct apply_impl<approx_level_policy<Degree>, ops::FUN<T>, 2, T, T> { \
KERNEL_FLOAT_INLINE static void call(ops::FUN<T>, T* output, const T* input) { \
auto res = approx::FUN<Degree>({input[0], input[1]}); \
output[0] = res.x; \
output[1] = res.y; \
} \
}; \
\
template<> \
struct apply_impl<approx_policy, ops::FUN<T>, 2, T, T>: \
apply_impl<approx_level_policy<DEFAULT_LEVEL>, ops::FUN<T>, 2, T, T> {}; \
}

#if KERNEL_FLOAT_FP16_AVAILABLE
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
#endif

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rsqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, exp, 0)
//KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
#endif

#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FUN) \
template<int Level = -1, typename V> \
KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
}

KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0)
KERNEL_FLOAT_DEFINE_APPROX_FUN(sin)
KERNEL_FLOAT_DEFINE_APPROX_FUN(cos)
KERNEL_FLOAT_DEFINE_APPROX_FUN(rsqrt)
KERNEL_FLOAT_DEFINE_APPROX_FUN(sqrt)
KERNEL_FLOAT_DEFINE_APPROX_FUN(rcp)
KERNEL_FLOAT_DEFINE_APPROX_FUN(exp)
KERNEL_FLOAT_DEFINE_APPROX_FUN(log)

} // namespace kernel_float
165 changes: 105 additions & 60 deletions single_include/kernel_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

//================================================================================
// this file has been auto-generated, do not modify its contents!
// date: 2024-11-18 13:50:24.614671
// git hash: f89cf98f79e78ab6013063dea4b4b516ce163855
// date: 2024-11-18 16:57:58.817191
// git hash: 003ce3677ecb97dc1602e38a3e774c103d05aa1a
//================================================================================

#ifndef KERNEL_FLOAT_MACROS_H
Expand Down Expand Up @@ -824,31 +824,53 @@ using default_policy = KERNEL_FLOAT_POLICY;

namespace detail {

//
template<typename Policy, typename F, size_t N, typename Output, typename... Args>
struct apply_base_impl {
struct apply_fallback_impl {
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
#pragma unroll
for (size_t i = 0; i < N; i++) {
output[i] = fun(args[i]...);
}
static_assert(N > 0, "operation not implemented");
}
};

template<typename Policy, typename F, size_t N, typename Output, typename... Args>
struct apply_base_impl: apply_fallback_impl<Policy, F, N, Output, Args...> {};

template<typename Policy, typename F, size_t N, typename Output, typename... Args>
struct apply_impl: apply_base_impl<Policy, F, N, Output, Args...> {};

// `fast_policy` falls back to `accurate_policy`
template<typename F, size_t N, typename Output, typename... Args>
struct apply_base_impl<fast_policy, F, N, Output, Args...>:
struct apply_fallback_impl<fast_policy, F, N, Output, Args...>:
apply_impl<accurate_policy, F, N, Output, Args...> {};

// `approx_policy` falls back to `fast_policy`
template<typename F, size_t N, typename Output, typename... Args>
struct apply_base_impl<approx_policy, F, N, Output, Args...>:
struct apply_fallback_impl<approx_policy, F, N, Output, Args...>:
apply_impl<fast_policy, F, N, Output, Args...> {};

// `approx_level_policy` falls back to `approx_policy`
template<int Level, typename F, size_t N, typename Output, typename... Args>
struct apply_base_impl<approx_level_policy<Level>, F, N, Output, Args...>:
struct apply_fallback_impl<approx_level_policy<Level>, F, N, Output, Args...>:
apply_impl<approx_policy, F, N, Output, Args...> {};

template<typename F, typename Output, typename... Args>
struct invoke_impl {
KERNEL_FLOAT_INLINE static Output call(F fun, Args... args) {
return fun(args...);
}
};

// Only for `accurate_policy` do we implement `apply_impl`, the others will fall back to `apply_base_impl`.
template<typename F, size_t N, typename Output, typename... Args>
struct apply_impl<accurate_policy, F, N, Output, Args...> {
KERNEL_FLOAT_INLINE static void call(F fun, Output* output, const Args*... args) {
#pragma unroll
for (size_t i = 0; i < N; i++) {
output[i] = invoke_impl<F, Output, Args...>::call(fun, args[i]...);
}
}
};

template<typename Policy, typename F, size_t N, typename Output, typename... Args>
struct map_impl {
static constexpr size_t packet_size = preferred_vector_size<Output>::value;
Expand Down Expand Up @@ -1949,7 +1971,7 @@ struct multiply<bool> {

namespace detail {
template<typename Policy, typename T, size_t N>
struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
struct apply_base_impl<Policy, ops::divide<T>, N, T, T, T> {
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
T rhs_rcp[N];

Expand All @@ -1959,10 +1981,6 @@ struct apply_impl<Policy, ops::divide<T>, N, T, T, T> {
}
};

template<typename T, size_t N>
struct apply_impl<accurate_policy, ops::divide<T>, N, T, T, T>:
apply_base_impl<accurate_policy, ops::divide<T>, N, T, T, T> {};

#if KERNEL_FLOAT_IS_DEVICE
template<>
struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
Expand All @@ -1977,7 +1995,7 @@ struct apply_impl<fast_policy, ops::divide<float>, 1, float, float, float> {
namespace detail {
// Override `pow` using `log2` and `exp2`
template<typename Policy, typename T, size_t N>
struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
struct apply_base_impl<Policy, ops::pow<T>, N, T, T, T> {
KERNEL_FLOAT_INLINE static void call(ops::divide<T>, T* result, const T* lhs, const T* rhs) {
T lhs_log[N];
T result_log[N];
Expand All @@ -1988,10 +2006,6 @@ struct apply_impl<Policy, ops::pow<T>, N, T, T, T> {
apply_impl<Policy, ops::exp2<T>, N, T, T, T>::call({}, result, result_log);
}
};

template<typename T, size_t N>
struct apply_impl<accurate_policy, ops::pow<T>, N, T, T, T>:
apply_base_impl<accurate_policy, ops::pow<T>, N, T, T, T> {};
} // namespace detail

template<typename L, typename R, typename T = promoted_vector_value_type<L, R>>
Expand Down Expand Up @@ -3218,13 +3232,13 @@ struct fma {
} // namespace ops

namespace detail {
template<typename Policy, typename T, size_t N>
struct apply_impl<Policy, ops::fma<T>, N, T, T, T, T> {
template<typename T, size_t N>
struct apply_impl<accurate_policy, ops::fma<T>, N, T, T, T, T> {
KERNEL_FLOAT_INLINE
static void call(ops::fma<T>, T* output, const T* a, const T* b, const T* c) {
T temp[N];
apply_impl<Policy, ops::multiply<T>, N, T, T, T>::call({}, temp, a, b);
apply_impl<Policy, ops::add<T>, N, T, T, T>::call({}, output, temp, c);
apply_impl<accurate_policy, ops::multiply<T>, N, T, T, T>::call({}, temp, a, b);
apply_impl<accurate_policy, ops::add<T>, N, T, T, T>::call({}, output, temp, c);
}
};
} // namespace detail
Expand Down Expand Up @@ -3992,9 +4006,6 @@ namespace kernel_float {
using half_t = ::__half;
using half2_t = ::__half2;

using __half = void;
using __half2 = void;

template<>
struct preferred_vector_size<half_t> {
static constexpr size_t value = 2;
Expand All @@ -4020,7 +4031,7 @@ template<>
struct allow_float_fallback<half_t> {
static constexpr bool value = true;
};
}; // namespace detail
} // namespace detail

#if KERNEL_FLOAT_IS_DEVICE
#define KERNEL_FLOAT_FP16_UNARY_FUN(NAME, FUN1, FUN2) \
Expand Down Expand Up @@ -4469,7 +4480,7 @@ namespace kernel_float {

namespace approx {

static_assert(sizeof(unsigned int) * 8 == 32, "invalid side of unsigned int");
static_assert(sizeof(unsigned int) * 8 == 32, "invalid size of unsigned int");
using uint32_t = unsigned int;

template<typename T, typename U>
Expand Down Expand Up @@ -4806,11 +4817,12 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t sqrt(bfloat16x2_t x) {

template<int = 0>
KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
static constexpr float SCALE = 1.44272065994f / 256.0f;
static constexpr float SCALE = 1.44272065994 / 256.0;
static constexpr float OFFSET = 382.4958400542335;
static constexpr float MINIMUM = 382;

auto a = fmaf(bfloat16x2_tfloat(arg.x), SCALE, OFFSET);
auto b = fmaf(bfloat16x2_tfloat(arg.y), SCALE, OFFSET);
float a = fmaxf(fmaf(bfloat162float(arg.x), SCALE, OFFSET), MINIMUM);
float b = fmaxf(fmaf(bfloat162float(arg.y), SCALE, OFFSET), MINIMUM);

return {
transmute<__bfloat16>(uint16_t(transmute<uint32_t>(a))),
Expand All @@ -4819,34 +4831,67 @@ KERNEL_FLOAT_DEVICE bfloat16x2_t exp(bfloat16x2_t arg) {
#endif
} // namespace approx

#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FULL_NAME, FUN, DEG) \
namespace detail { \
template<int Degree> \
struct apply_impl<approx_level_policy<Degree>, ops::FUN<half_t>, 2, half_t, half_t> { \
KERNEL_FLOAT_INLINE static void \
call(ops::FUN<half_t> fun, half_t* output, const half_t* input) { \
half2_t res = approx::FUN<Degree>(half2_t {input[0], input[1]}); \
output[0] = res.x; \
output[1] = res.y; \
} \
}; \
template<> \
struct apply_impl<approx_policy, ops::FUN<half_t>, 2, half_t, half_t>: \
apply_impl<approx_level_policy<DEG>, ops::FUN<half_t>, 2, half_t, half_t> {}; \
} \
\
template<int Level = -1, typename V> \
KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
}

KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sin, sin, 4)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_cos, cos, 4)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rsqrt, rsqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_sqrt, sqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_rcp, rcp, 1)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_exp, exp, 0)
KERNEL_FLOAT_DEFINE_APPROX_FUN(approx_log, log, 0)
namespace detail {
template<int Level, typename F, typename T>
struct apply_impl<approx_level_policy<Level>, F, 1, T, T> {
KERNEL_FLOAT_INLINE static void call(F fun, T* output, const T* input) {
T in2[2], out2[2];
out2[0] = input[0];
apply_impl<approx_level_policy<Level>, F, 2, T, T>::call(fun, out2, in2);
output[0] = out2[0];
}
};
} // namespace detail

#define KERNEL_FLOAT_DEFINE_APPROX_IMPL(T, FUN, DEFAULT_LEVEL) \
namespace detail { \
template<int Degree> \
struct apply_impl<approx_level_policy<Degree>, ops::FUN<T>, 2, T, T> { \
KERNEL_FLOAT_INLINE static void call(ops::FUN<T>, T* output, const T* input) { \
auto res = approx::FUN<Degree>({input[0], input[1]}); \
output[0] = res.x; \
output[1] = res.y; \
} \
}; \
\
template<> \
struct apply_impl<approx_policy, ops::FUN<T>, 2, T, T>: \
apply_impl<approx_level_policy<DEFAULT_LEVEL>, ops::FUN<T>, 2, T, T> {}; \
}

#if KERNEL_FLOAT_FP16_AVAILABLE
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sin, 4)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, cos, 4)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rsqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, sqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, rcp, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, exp, 0)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
#endif

#if KERNEL_FLOAT_BF16_OPS_SUPPORTED
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, cos, 4)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sin, 4)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rcp, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, rsqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, sqrt, 1)
KERNEL_FLOAT_DEFINE_APPROX_IMPL(bfloat16_t, exp, 0)
//KERNEL_FLOAT_DEFINE_APPROX_IMPL(half_t, log, 0)
#endif

#define KERNEL_FLOAT_DEFINE_APPROX_FUN(FUN) \
template<int Level = -1, typename V> \
KERNEL_FLOAT_INLINE into_vector_type<V> approx_##FUN(const V& args) { \
return map<approx_level_policy<Level>>(ops::FUN<vector_value_type<V>> {}, args); \
}

KERNEL_FLOAT_DEFINE_APPROX_FUN(sin)
KERNEL_FLOAT_DEFINE_APPROX_FUN(cos)
KERNEL_FLOAT_DEFINE_APPROX_FUN(rsqrt)
KERNEL_FLOAT_DEFINE_APPROX_FUN(sqrt)
KERNEL_FLOAT_DEFINE_APPROX_FUN(rcp)
KERNEL_FLOAT_DEFINE_APPROX_FUN(exp)
KERNEL_FLOAT_DEFINE_APPROX_FUN(log)

} // namespace kernel_float
#ifndef KERNEL_FLOAT_FP8_H
Expand Down

0 comments on commit 76501fd

Please sign in to comment.