Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make device contains take a template key parameter #174

Merged
merged 16 commits into from
Jun 13, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions include/cuco/detail/static_map.inl
Original file line number Diff line number Diff line change
Expand Up @@ -716,9 +716,11 @@ static_map<Key, Value, Scope, Allocator>::device_view::find(CG g,
}

template <typename Key, typename Value, cuda::thread_scope Scope, typename Allocator>
template <typename Hash, typename KeyEqual>
__device__ bool static_map<Key, Value, Scope, Allocator>::device_view::contains(
Key const& k, Hash hash, KeyEqual key_equal) const noexcept
template <typename ProbeKey, typename Hash, typename KeyEqual>
__device__ std::enable_if_t<std::is_convertible_v<ProbeKey, Key>, bool>
static_map<Key, Value, Scope, Allocator>::device_view::contains(ProbeKey const& k,
PointKernel marked this conversation as resolved.
Show resolved Hide resolved
Hash hash,
KeyEqual key_equal) const noexcept
{
auto current_slot = initial_slot(k, hash);

Expand All @@ -734,9 +736,12 @@ __device__ bool static_map<Key, Value, Scope, Allocator>::device_view::contains(
}

template <typename Key, typename Value, cuda::thread_scope Scope, typename Allocator>
template <typename CG, typename Hash, typename KeyEqual>
__device__ bool static_map<Key, Value, Scope, Allocator>::device_view::contains(
CG g, Key const& k, Hash hash, KeyEqual key_equal) const noexcept
template <typename CG, typename ProbeKey, typename Hash, typename KeyEqual>
__device__ std::enable_if_t<std::is_convertible_v<ProbeKey, Key>, bool>
static_map<Key, Value, Scope, Allocator>::device_view::contains(CG const& g,
ProbeKey const& k,
Hash hash,
KeyEqual key_equal) const noexcept
{
auto current_slot = initial_slot(g, k, hash);

Expand Down
12 changes: 8 additions & 4 deletions include/cuco/detail/static_multimap/device_view_impl.inl
Original file line number Diff line number Diff line change
Expand Up @@ -569,17 +569,19 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
*
* @tparam uses_vector_load Boolean flag indicating whether vector loads are used
* @tparam CG Cooperative Group type
* @tparam ProbeKey Probe key type that is convertible to the map's `key_type`
PointKernel marked this conversation as resolved.
Show resolved Hide resolved
* @tparam KeyEqual Binary callable type
*
* @param g The Cooperative Group used to perform the contains operation
* @param k The key to search for
* @param key_equal The binary callable used to compare two keys
* for equality
* @return A boolean indicating whether the key/value pair
* containing `k` was inserted
*/
template <bool uses_vector_load, typename CG, typename KeyEqual>
template <bool uses_vector_load, typename CG, typename ProbeKey, typename KeyEqual>
__device__ __forceinline__ std::enable_if_t<uses_vector_load, bool> contains(
CG g, Key const& k, KeyEqual key_equal) noexcept
CG const& g, ProbeKey const& k, KeyEqual key_equal) noexcept
{
auto current_slot = initial_slot(g, k);

Expand Down Expand Up @@ -617,17 +619,19 @@ class static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view_
*
* @tparam uses_vector_load Boolean flag indicating whether vector loads are used
* @tparam CG Cooperative Group type
* @tparam ProbeKey Probe key type that is convertible to the map's `key_type`
* @tparam KeyEqual Binary callable type
*
* @param g The Cooperative Group used to perform the contains operation
* @param k The key to search for
* @param key_equal The binary callable used to compare two keys
* for equality
* @return A boolean indicating whether the key/value pair
* containing `k` was inserted
*/
template <bool uses_vector_load, typename CG, typename KeyEqual>
template <bool uses_vector_load, typename CG, typename ProbeKey, typename KeyEqual>
__device__ __forceinline__ std::enable_if_t<not uses_vector_load, bool> contains(
CG g, Key const& k, KeyEqual key_equal) noexcept
CG const& g, ProbeKey const& k, KeyEqual key_equal) noexcept
{
auto current_slot = initial_slot(g, k);

Expand Down
4 changes: 2 additions & 2 deletions include/cuco/detail/static_multimap/static_multimap.inl
Original file line number Diff line number Diff line change
Expand Up @@ -536,11 +536,11 @@ template <typename Key,
cuda::thread_scope Scope,
typename Allocator,
class ProbeSequence>
template <typename KeyEqual>
template <typename ProbeKey, typename KeyEqual>
__device__ __forceinline__ bool
static_multimap<Key, Value, Scope, Allocator, ProbeSequence>::device_view::contains(
cooperative_groups::thread_block_tile<ProbeSequence::cg_size> const& g,
Key const& k,
ProbeKey const& k,
KeyEqual key_equal) noexcept
{
return impl_.contains<uses_vector_load()>(g, k, key_equal);
Expand Down
31 changes: 18 additions & 13 deletions include/cuco/static_map.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1211,33 +1211,36 @@ class static_map {
* If the key `k` was inserted into the map, find returns
* true. Otherwise, it returns false.
*
* @tparam ProbeKey Probe key type that is convertible to the map's `key_type`
* @tparam Hash Unary callable type
* @tparam KeyEqual Binary callable type
*
* @param k The key to search for
* @param hash The unary callable used to hash the key
* @param key_equal The binary callable used to compare two keys
* for equality
* @return A boolean indicating whether the key/value pair
* containing `k` was inserted
*/
template <typename Hash = cuco::detail::MurmurHash3_32<key_type>,
template <typename ProbeKey,
typename Hash = cuco::detail::MurmurHash3_32<key_type>,
typename KeyEqual = thrust::equal_to<key_type>>
__device__ bool contains(Key const& k,
Hash hash = Hash{},
KeyEqual key_equal = KeyEqual{}) const noexcept;
__device__ std::enable_if_t<std::is_convertible_v<ProbeKey, Key>, bool> contains(
ProbeKey const& k, Hash hash = Hash{}, KeyEqual key_equal = KeyEqual{}) const noexcept;

/**
* @brief Indicates whether the key `k` was inserted into the map.
*
* If the key `k` was inserted into the map, find returns
* true. Otherwise, it returns false. Uses the CUDA Cooperative Groups API to
* to leverage multiple threads to perform a single contains operation. This provides a
* significant boost in throughput compared to the non Cooperative Group
* `contains` at moderate to high load factors.
* If the key `k` was inserted into the map, find returns true. Otherwise, it returns false.
* Uses the CUDA Cooperative Groups API to to leverage multiple threads to perform a single
* contains operation. This provides a significant boost in throughput compared to the non
* Cooperative Group `contains` at moderate to high load factors.
*
* @tparam CG Cooperative Group type
* @tparam ProbeKey Probe key type that is convertible to the map's `key_type`
* @tparam Hash Unary callable type
* @tparam KeyEqual Binary callable type
*
* @param g The Cooperative Group used to perform the contains operation
* @param k The key to search for
* @param hash The unary callable used to hash the key
Expand All @@ -1247,12 +1250,14 @@ class static_map {
* containing `k` was inserted
*/
template <typename CG,
typename ProbeKey,
typename Hash = cuco::detail::MurmurHash3_32<key_type>,
typename KeyEqual = thrust::equal_to<key_type>>
__device__ bool contains(CG g,
Key const& k,
Hash hash = Hash{},
KeyEqual key_equal = KeyEqual{}) const noexcept;
__device__ std::enable_if_t<std::is_convertible_v<ProbeKey, Key>, bool> contains(
CG const& g,
ProbeKey const& k,
Hash hash = Hash{},
KeyEqual key_equal = KeyEqual{}) const noexcept;
}; // class device_view

/**
Expand Down
6 changes: 4 additions & 2 deletions include/cuco/static_multimap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -820,18 +820,20 @@ class static_multimap {
* significant boost in throughput compared to the non Cooperative Group
* `contains` at moderate to high load factors.
*
* @tparam ProbeKey Probe key type that is convertible to the map's `key_type`
* @tparam KeyEqual Binary callable type
*
* @param g The Cooperative Group used to perform the contains operation
* @param k The key to search for
* @param key_equal The binary callable used to compare two keys
* for equality
* @return A boolean indicating whether the key/value pair
* containing `k` was inserted
*/
template <typename KeyEqual = thrust::equal_to<key_type>>
template <typename ProbeKey, typename KeyEqual = thrust::equal_to<key_type>>
__device__ __forceinline__ bool contains(
cooperative_groups::thread_block_tile<ProbeSequence::cg_size> const& g,
Key const& k,
ProbeKey const& k,
KeyEqual key_equal = KeyEqual{}) noexcept;

/**
Expand Down
6 changes: 3 additions & 3 deletions tests/static_map/custom_type_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ struct hash_custom_key {

// User-defined device key equality
struct custom_key_equals {
template <typename custom_type>
__device__ bool operator()(custom_type lhs, custom_type rhs)
template <typename lhs_type, typename rhs_type>
PointKernel marked this conversation as resolved.
Show resolved Hide resolved
__device__ bool operator()(lhs_type lhs, rhs_type rhs)
{
return std::tie(lhs.a, lhs.b) == std::tie(rhs.a, rhs.b);
return lhs == static_cast<lhs_type>(rhs);
}
};

Expand Down