From db36789289ef838427a5678e2db21a36b55f7f0d Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Tue, 2 Apr 2024 12:45:47 -0700 Subject: [PATCH] Add some missing _Float16 support (Changes extracted from https://github.com/halide/Halide/pull/8169, which may or may not land in its current form) Some missing support for _Float16 that will likely be handy: - Allow _Float16 to be detected for Clang 15 (since my local Xcode Clang 15 definitely supports it) - Expr(_Float16) - HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(_Float16); - Add _Float16 to the convert matrix in halide_image_io.h --- src/Expr.h | 5 ++ src/Type.h | 3 + src/runtime/HalideRuntime.h | 2 +- tools/halide_image_io.h | 118 ++++++++++++++++++++++++++++++++++++ 4 files changed, 127 insertions(+), 1 deletion(-) diff --git a/src/Expr.h b/src/Expr.h index 31850fc56001..b9832c104de8 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -298,6 +298,11 @@ struct Expr : public Internal::IRHandle { Expr(bfloat16_t x) : IRHandle(Internal::FloatImm::make(BFloat(16), (double)x)) { } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 + explicit Expr(_Float16 x) + : IRHandle(Internal::FloatImm::make(Float(16), (double)x)) { + } +#endif Expr(float x) : IRHandle(Internal::FloatImm::make(Float(32), x)) { } diff --git a/src/Type.h b/src/Type.h index af5447350810..c8a397b3f0a7 100644 --- a/src/Type.h +++ b/src/Type.h @@ -166,6 +166,9 @@ HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(Halide::float16_t); HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(Halide::bfloat16_t); HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(halide_task_t); HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(halide_loop_task_t); +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(_Float16); +#endif HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(float); HALIDE_DECLARE_EXTERN_SIMPLE_TYPE(double); HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_buffer_t); diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 1d0843be0329..0379c1f9ab47 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ 
-91,7 +91,7 @@ extern "C" { // Ideally there would be a better way to detect if the type // is supported, even in a compiler independent fashion, but // coming up with one has proven elusive. -#if defined(__clang__) && (__clang_major__ >= 16) && !defined(__EMSCRIPTEN__) && !defined(__i386__) +#if defined(__clang__) && (__clang_major__ >= 15) && !defined(__EMSCRIPTEN__) && !defined(__i386__) #if defined(__is_identifier) #if !__is_identifier(_Float16) #define HALIDE_CPP_COMPILER_HAS_FLOAT16 diff --git a/tools/halide_image_io.h b/tools/halide_image_io.h index e039f7c2e798..1e0cbff01897 100644 --- a/tools/halide_image_io.h +++ b/tools/halide_image_io.h @@ -116,6 +116,12 @@ template<> inline bool convert(const int64_t &in) { return in != 0; } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline bool convert(const _Float16 &in) { + return (float)in != 0; +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline bool convert(const float &in) { return in != 0; @@ -165,6 +171,12 @@ template<> inline uint8_t convert(const int64_t &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline uint8_t convert(const _Float16 &in) { + return (uint8_t)std::lround((float)in * 255.0f); +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline uint8_t convert(const float &in) { return (uint8_t)std::lround(in * 255.0f); @@ -211,6 +223,12 @@ template<> inline uint16_t convert(const int64_t &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline uint16_t convert(const _Float16 &in) { + return (uint16_t)std::lround((float)in * 65535.0f); +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline uint16_t convert(const float &in) { return (uint16_t)std::lround(in * 65535.0f); @@ -257,6 +275,12 @@ template<> inline uint32_t convert(const int64_t &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline uint32_t convert(const _Float16 &in) { + return 
(uint32_t)std::llround((float)in * 4294967295.0); +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline uint32_t convert(const float &in) { return (uint32_t)std::llround(in * 4294967295.0); @@ -303,6 +327,12 @@ template<> inline uint64_t convert(const int64_t &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline uint64_t convert(const _Float16 &in) { + return convert((uint32_t)std::llround((float)in * 4294967295.0)); +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline uint64_t convert(const float &in) { return convert((uint32_t)std::llround(in * 4294967295.0)); @@ -349,6 +379,12 @@ template<> inline int8_t convert(const int64_t &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline int8_t convert(const _Float16 &in) { + return convert((float)in); +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline int8_t convert(const float &in) { return convert(in); @@ -395,6 +431,12 @@ template<> inline int16_t convert(const int64_t &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline int16_t convert(const _Float16 &in) { + return convert((float)in); +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline int16_t convert(const float &in) { return convert(in); @@ -441,6 +483,12 @@ template<> inline int32_t convert(const int64_t &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline int32_t convert(const _Float16 &in) { + return convert((float)in); +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline int32_t convert(const float &in) { return convert(in); @@ -487,6 +535,12 @@ template<> inline int64_t convert(const int64_t &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline int64_t convert(const _Float16 &in) { + return convert((float)in); +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline int64_t convert(const float &in) { return convert(in); 
@@ -496,6 +550,58 @@ inline int64_t convert(const double &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +// Convert to f16 +template<> +inline _Float16 convert(const bool &in) { + return in; +} +template<> +inline _Float16 convert(const uint8_t &in) { + return (_Float16)(in / 255.0f); +} +template<> +inline _Float16 convert(const uint16_t &in) { + return (_Float16)(in / 65535.0f); +} +template<> +inline _Float16 convert(const uint32_t &in) { + return (_Float16)(in / 4294967295.0); +} +template<> +inline _Float16 convert(const uint64_t &in) { + return convert<_Float16, uint32_t>(uint32_t(in >> 32)); +} +template<> +inline _Float16 convert(const int8_t &in) { + return convert<_Float16, uint8_t>(in); +} +template<> +inline _Float16 convert(const int16_t &in) { + return convert<_Float16, uint16_t>(in); +} +template<> +inline _Float16 convert(const int32_t &in) { + return convert<_Float16, uint64_t>(in); +} +template<> +inline _Float16 convert(const int64_t &in) { + return convert<_Float16, uint64_t>(in); +} +template<> +inline _Float16 convert(const _Float16 &in) { + return in; +} +template<> +inline _Float16 convert(const float &in) { + return (_Float16)in; +} +template<> +inline _Float16 convert(const double &in) { + return (_Float16)in; +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 + // Convert to f32 template<> inline float convert(const bool &in) { @@ -533,6 +639,12 @@ template<> inline float convert(const int64_t &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline float convert(const _Float16 &in) { + return (float)in; +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline float convert(const float &in) { return in; @@ -579,6 +691,12 @@ template<> inline double convert(const int64_t &in) { return convert(in); } +#ifdef HALIDE_CPP_COMPILER_HAS_FLOAT16 +template<> +inline double convert(const _Float16 &in) { + return (double)in; +} +#endif // HALIDE_CPP_COMPILER_HAS_FLOAT16 template<> inline double 
convert(const float &in) { return (double)in;