From 0721c2c3bab6f87d98099ab48730369d57627d1d Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Mon, 1 Jul 2024 14:57:31 +0200 Subject: [PATCH] Changed TwiddleIn/Out implementation to make use of rocprim::radix_key_codec --- .../hipcub/backend/rocprim/util_type.hpp | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/hipcub/include/hipcub/backend/rocprim/util_type.hpp b/hipcub/include/hipcub/backend/rocprim/util_type.hpp index c02e0ad8..b3f89a06 100644 --- a/hipcub/include/hipcub/backend/rocprim/util_type.hpp +++ b/hipcub/include/hipcub/backend/rocprim/util_type.hpp @@ -33,6 +33,7 @@ #include "../../config.hpp" #include +#include #include #include @@ -454,15 +455,16 @@ struct BaseTraits NULL_TYPE = false, }; + using key_codec = rocprim::radix_key_codec; static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) { - return key; + return key_codec::encode(rocprim::detail::bit_cast(key)); } static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) { - return key; + return key_codec::decode(rocprim::detail::bit_cast(key)); } static HIPCUB_HOST_DEVICE __forceinline__ T Max() @@ -502,14 +504,16 @@ struct BaseTraits NULL_TYPE = false, }; + using key_codec = rocprim::radix_key_codec; + static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) { - return key ^ HIGH_BIT; + return key_codec::encode(rocprim::detail::bit_cast(key)); }; static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) { - return key ^ HIGH_BIT; + return key_codec::decode(rocprim::detail::bit_cast(key)); }; static HIPCUB_HOST_DEVICE __forceinline__ T Max() @@ -593,6 +597,8 @@ struct BaseTraits static const UnsignedBits LOWEST_KEY = UnsignedBits(-1); static const UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT; + using key_codec = rocprim::radix_key_codec; + enum { PRIMITIVE = true, @@ -601,14 +607,12 @@ struct BaseTraits static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) { - UnsignedBits mask = (key & HIGH_BIT) ? UnsignedBits(-1) : HIGH_BIT; - return key ^ mask; + return key_codec::encode(rocprim::detail::bit_cast(key)); }; static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) { - UnsignedBits mask = (key & HIGH_BIT) ? HIGH_BIT : UnsignedBits(-1); - return key ^ mask; + return key_codec::decode(rocprim::detail::bit_cast(key)); }; static HIPCUB_HOST_DEVICE __forceinline__ T Max() { @@ -655,14 +659,16 @@ struct NumericTraits<__uint128_t> static constexpr bool PRIMITIVE = false; static constexpr bool NULL_TYPE = false; + using key_codec = rocprim::radix_key_codec; + static __host__ __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) { - return key; + return key_codec::encode(rocprim::detail::bit_cast(key)); } static __host__ __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) { - return key; + return key_codec::decode(rocprim::detail::bit_cast(key)); } static __host__ __device__ __forceinline__ T Max() @@ -690,14 +696,16 @@ struct NumericTraits<__int128_t> static constexpr bool PRIMITIVE = false; static constexpr bool NULL_TYPE = false; + using key_codec = rocprim::radix_key_codec; + static __host__ __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) { - return key ^ HIGH_BIT; + return key_codec::encode(rocprim::detail::bit_cast(key)); }; static __host__ __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) { - return key ^ HIGH_BIT; + return key_codec::decode(rocprim::detail::bit_cast(key)); }; static __host__ __device__ __forceinline__ T Max()