Skip to content

Commit

Permalink
Changed TwiddleIn/Out implementation to make use of rocprim::radix_ke…
Browse files Browse the repository at this point in the history
…y_codec
  • Loading branch information
NB4444 authored and Beanavil committed Jul 31, 2024
1 parent 4e408ac commit 0721c2c
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions hipcub/include/hipcub/backend/rocprim/util_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "../../config.hpp"

#include <rocprim/detail/various.hpp>
#include <rocprim/thread/radix_key_codec.hpp>
#include <rocprim/types/future_value.hpp>

#include <hip/hip_fp16.h>
Expand Down Expand Up @@ -454,15 +455,16 @@ struct BaseTraits<UNSIGNED_INTEGER, true, false, _UnsignedBits, T>
NULL_TYPE = false,
};

using key_codec = rocprim::radix_key_codec<T>;

static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
{
return key;
return key_codec::encode(rocprim::detail::bit_cast<T>(key));
}

static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
{
return key;
return key_codec::decode(rocprim::detail::bit_cast<T>(key));
}

static HIPCUB_HOST_DEVICE __forceinline__ T Max()
Expand Down Expand Up @@ -502,14 +504,16 @@ struct BaseTraits<SIGNED_INTEGER, true, false, _UnsignedBits, T>
NULL_TYPE = false,
};

using key_codec = rocprim::radix_key_codec<T>;

static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
{
return key ^ HIGH_BIT;
return key_codec::encode(rocprim::detail::bit_cast<T>(key));
};

static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
{
return key ^ HIGH_BIT;
return key_codec::decode(rocprim::detail::bit_cast<T>(key));
};

static HIPCUB_HOST_DEVICE __forceinline__ T Max()
Expand Down Expand Up @@ -593,6 +597,8 @@ struct BaseTraits<FLOATING_POINT, true, false, _UnsignedBits, T>
static const UnsignedBits LOWEST_KEY = UnsignedBits(-1);
static const UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT;

using key_codec = rocprim::radix_key_codec<T>;

enum
{
PRIMITIVE = true,
Expand All @@ -601,14 +607,12 @@ struct BaseTraits<FLOATING_POINT, true, false, _UnsignedBits, T>

static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
{
UnsignedBits mask = (key & HIGH_BIT) ? UnsignedBits(-1) : HIGH_BIT;
return key ^ mask;
return key_codec::encode(rocprim::detail::bit_cast<T>(key));
};

static HIPCUB_HOST_DEVICE __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
{
UnsignedBits mask = (key & HIGH_BIT) ? HIGH_BIT : UnsignedBits(-1);
return key ^ mask;
return key_codec::decode(rocprim::detail::bit_cast<T>(key));
};

static HIPCUB_HOST_DEVICE __forceinline__ T Max() {
Expand Down Expand Up @@ -655,14 +659,16 @@ struct NumericTraits<__uint128_t>
static constexpr bool PRIMITIVE = false;
static constexpr bool NULL_TYPE = false;

using key_codec = rocprim::radix_key_codec<T>;

static __host__ __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
{
return key;
return key_codec::encode(rocprim::detail::bit_cast<T>(key));
}

static __host__ __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
{
return key;
return key_codec::decode(rocprim::detail::bit_cast<T>(key));
}

static __host__ __device__ __forceinline__ T Max()
Expand Down Expand Up @@ -690,14 +696,16 @@ struct NumericTraits<__int128_t>
static constexpr bool PRIMITIVE = false;
static constexpr bool NULL_TYPE = false;

using key_codec = rocprim::radix_key_codec<T>;

static __host__ __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
{
return key ^ HIGH_BIT;
return key_codec::encode(rocprim::detail::bit_cast<T>(key));
};

static __host__ __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
{
return key ^ HIGH_BIT;
return key_codec::decode(rocprim::detail::bit_cast<T>(key));
};

static __host__ __device__ __forceinline__ T Max()
Expand Down

0 comments on commit 0721c2c

Please sign in to comment.