Skip to content

Commit

Permalink
Update CUB to latest master from https://github.com/NVlabs/cub
Browse files Browse the repository at this point in the history
  • Loading branch information
Spudz76 committed Oct 6, 2021
1 parent b045739 commit 138005d
Show file tree
Hide file tree
Showing 84 changed files with 27,236 additions and 411 deletions.
778 changes: 778 additions & 0 deletions src/3rdparty/cub/agent/agent_histogram.cuh

Large diffs are not rendered by default.

750 changes: 750 additions & 0 deletions src/3rdparty/cub/agent/agent_merge_sort.cuh

Large diffs are not rendered by default.

102 changes: 47 additions & 55 deletions src/3rdparty/cub/agent/agent_radix_sort_downsweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include "../block/block_store.cuh"
#include "../block/block_radix_rank.cuh"
#include "../block/block_exchange.cuh"
#include "../block/radix_rank_sort_operations.cuh"
#include "../config.cuh"
#include "../util_type.cuh"
#include "../iterator/cache_modified_input_iterator.cuh"
Expand All @@ -56,16 +57,6 @@ namespace cub {
* Tuning policy types
******************************************************************************/

/**
* Radix ranking algorithm
*/
enum RadixRankAlgorithm
{
RADIX_RANK_BASIC,
RADIX_RANK_MEMOIZE,
RADIX_RANK_MATCH
};

/**
* Parameterizable tuning policy type for AgentRadixSortDownsweep
*/
Expand Down Expand Up @@ -137,6 +128,9 @@ struct AgentRadixSortDownsweep

RADIX_DIGITS = 1 << RADIX_BITS,
KEYS_ONLY = Equals<ValueT, NullType>::VALUE,
LOAD_WARP_STRIPED = RANK_ALGORITHM == RADIX_RANK_MATCH ||
RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ANY ||
RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR,
};

// Input iterator wrapper type (for applying cache modifiers)
Expand All @@ -148,10 +142,22 @@ struct AgentRadixSortDownsweep
BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, false, SCAN_ALGORITHM>,
typename If<(RANK_ALGORITHM == RADIX_RANK_MEMOIZE),
BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, true, SCAN_ALGORITHM>,
BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>
typename If<(RANK_ALGORITHM == RADIX_RANK_MATCH),
BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>,
typename If<(RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ANY),
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING,
SCAN_ALGORITHM, WARP_MATCH_ANY>,
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING,
SCAN_ALGORITHM, WARP_MATCH_ATOMIC_OR>
>::Type
>::Type
>::Type
>::Type BlockRadixRankT;

// Digit extractor type
typedef BFEDigitExtractor<KeyT> DigitExtractorT;


enum
{
/// Number of bin-starting offsets tracked per thread
Expand Down Expand Up @@ -184,11 +190,11 @@ struct AgentRadixSortDownsweep
typename BlockLoadValuesT::TempStorage load_values;
typename BlockRadixRankT::TempStorage radix_rank;

struct
struct KeysAndOffsets
{
UnsignedBits exchange_keys[TILE_ITEMS];
OffsetT relative_bin_offsets[RADIX_DIGITS];
};
} keys_and_offsets;

Uninitialized<ValueExchangeT> exchange_values;

Expand Down Expand Up @@ -216,11 +222,8 @@ struct AgentRadixSortDownsweep
// The global scatter base offset for each digit (valid in the first RADIX_DIGITS threads)
OffsetT bin_offset[BINS_TRACKED_PER_THREAD];

// The least-significant bit position of the current digit to extract
int current_bit;

// Number of bits in current digit
int num_bits;
// Digit extractor
DigitExtractorT digit_extractor;

// Whether to short-circuit
int short_circuit;
Expand All @@ -243,17 +246,17 @@ struct AgentRadixSortDownsweep
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
temp_storage.exchange_keys[ranks[ITEM]] = twiddled_keys[ITEM];
temp_storage.keys_and_offsets.exchange_keys[ranks[ITEM]] = twiddled_keys[ITEM];
}

CTA_SYNC();

#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
UnsignedBits key = temp_storage.exchange_keys[threadIdx.x + (ITEM * BLOCK_THREADS)];
UnsignedBits digit = BFE(key, current_bit, num_bits);
relative_bin_offsets[ITEM] = temp_storage.relative_bin_offsets[digit];
UnsignedBits key = temp_storage.keys_and_offsets.exchange_keys[threadIdx.x + (ITEM * BLOCK_THREADS)];
UnsignedBits digit = digit_extractor.Digit(key);
relative_bin_offsets[ITEM] = temp_storage.keys_and_offsets.relative_bin_offsets[digit];

// Un-twiddle
key = Traits<KeyT>::TwiddleOut(key);
Expand Down Expand Up @@ -303,16 +306,15 @@ struct AgentRadixSortDownsweep
}

/**
* Load a tile of keys (specialized for full tile, any ranking algorithm)
* Load a tile of keys (specialized for full tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<true> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
BlockLoadKeysT(temp_storage.load_keys).Load(
d_keys_in + block_offset, keys);
Expand All @@ -322,16 +324,15 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of keys (specialized for partial tile, any ranking algorithm)
* Load a tile of keys (specialized for partial tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<false> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -345,30 +346,29 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of keys (specialized for full tile, match ranking algorithm)
* Load a tile of keys (specialized for full tile, warp-striped load)
*/
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<true> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys);
}


/**
* Load a tile of keys (specialized for partial tile, match ranking algorithm)
* Load a tile of keys (specialized for partial tile, warp-striped load)
*/
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<false> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -377,17 +377,15 @@ struct AgentRadixSortDownsweep
LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys, valid_items, oob_item);
}


/**
* Load a tile of values (specialized for full tile, any ranking algorithm)
* Load a tile of values (specialized for full tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<true> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
BlockLoadValuesT(temp_storage.load_values).Load(
d_values_in + block_offset, values);
Expand All @@ -397,15 +395,14 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of values (specialized for partial tile, any ranking algorithm)
* Load a tile of values (specialized for partial tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<false> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -419,28 +416,27 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of items (specialized for full tile, match ranking algorithm)
* Load a tile of items (specialized for full tile, warp-striped load)
*/
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<true> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values);
}


/**
* Load a tile of items (specialized for partial tile, match ranking algorithm)
* Load a tile of items (specialized for partial tile, warp-striped load)
*/
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<false> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
Expand All @@ -449,7 +445,6 @@ struct AgentRadixSortDownsweep
LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values, valid_items);
}


/**
* Truck along associated values
*/
Expand All @@ -470,7 +465,7 @@ struct AgentRadixSortDownsweep
block_offset,
valid_items,
Int2Type<FULL_TILE>(),
Int2Type<RANK_ALGORITHM>());
Int2Type<LOAD_WARP_STRIPED>());

ScatterValues<FULL_TILE>(
values,
Expand Down Expand Up @@ -515,7 +510,7 @@ struct AgentRadixSortDownsweep
valid_items,
default_key,
Int2Type<FULL_TILE>(),
Int2Type<RANK_ALGORITHM>());
Int2Type<LOAD_WARP_STRIPED>());

// Twiddle key bits if necessary
#pragma unroll
Expand All @@ -529,8 +524,7 @@ struct AgentRadixSortDownsweep
BlockRadixRankT(temp_storage.radix_rank).RankKeys(
keys,
ranks,
current_bit,
num_bits,
digit_extractor,
exclusive_digit_prefix);

CTA_SYNC();
Expand Down Expand Up @@ -586,7 +580,7 @@ struct AgentRadixSortDownsweep
if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
{
bin_offset[track] -= exclusive_digit_prefix[track];
temp_storage.relative_bin_offsets[bin_idx] = bin_offset[track];
temp_storage.keys_and_offsets.relative_bin_offsets[bin_idx] = bin_offset[track];
bin_offset[track] += inclusive_digit_prefix[track];
}
}
Expand Down Expand Up @@ -677,8 +671,7 @@ struct AgentRadixSortDownsweep
d_values_in(d_values_in),
d_keys_out(reinterpret_cast<UnsignedBits*>(d_keys_out)),
d_values_out(d_values_out),
current_bit(current_bit),
num_bits(num_bits),
digit_extractor(current_bit, num_bits),
short_circuit(1)
{
#pragma unroll
Expand Down Expand Up @@ -717,8 +710,7 @@ struct AgentRadixSortDownsweep
d_values_in(d_values_in),
d_keys_out(reinterpret_cast<UnsignedBits*>(d_keys_out)),
d_values_out(d_values_out),
current_bit(current_bit),
num_bits(num_bits),
digit_extractor(current_bit, num_bits),
short_circuit(1)
{
#pragma unroll
Expand Down
Loading

0 comments on commit 138005d

Please sign in to comment.