From 25befef03f84e1562c9144276abb324bb123de6c Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Fri, 3 Jun 2022 20:14:00 -0400 Subject: [PATCH 01/15] Add template contains for static_map --- include/cuco/detail/static_map.inl | 17 ++++++++++------ include/cuco/static_map.cuh | 29 +++++++++++++++------------- tests/static_map/custom_type_test.cu | 6 +++--- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/include/cuco/detail/static_map.inl b/include/cuco/detail/static_map.inl index 23d797cae..8f1beab8c 100644 --- a/include/cuco/detail/static_map.inl +++ b/include/cuco/detail/static_map.inl @@ -716,9 +716,11 @@ static_map::device_view::find(CG g, } template -template -__device__ bool static_map::device_view::contains( - Key const& k, Hash hash, KeyEqual key_equal) const noexcept +template +__device__ std::enable_if_t, bool> +static_map::device_view::contains(ProbeKey const& k, + Hash hash, + KeyEqual key_equal) const noexcept { auto current_slot = initial_slot(k, hash); @@ -734,9 +736,12 @@ __device__ bool static_map::device_view::contains( } template -template -__device__ bool static_map::device_view::contains( - CG g, Key const& k, Hash hash, KeyEqual key_equal) const noexcept +template +__device__ std::enable_if_t, bool> +static_map::device_view::contains(CG const& g, + ProbeKey const& k, + Hash hash, + KeyEqual key_equal) const noexcept { auto current_slot = initial_slot(g, k, hash); diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index 2bfb6a5a8..d52b2f412 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -1211,6 +1211,7 @@ class static_map { * If the key `k` was inserted into the map, find returns * true. Otherwise, it returns false. * + * @tparam ProbeKey Probe key type that can be convertible to map's `key_type` * @tparam Hash Unary callable type * @tparam KeyEqual Binary callable type * @param k The key to search for @@ -1220,22 +1221,22 @@ class static_map { * @return A boolean indicating whether the key/value pair * containing `k` was inserted */ - template , + template , typename KeyEqual = thrust::equal_to> - __device__ bool contains(Key const& k, - Hash hash = Hash{}, - KeyEqual key_equal = KeyEqual{}) const noexcept; + __device__ std::enable_if_t, bool> contains( + ProbeKey const& k, Hash hash = Hash{}, KeyEqual key_equal = KeyEqual{}) const noexcept; /** * @brief Indicates whether the key `k` was inserted into the map. * - * If the key `k` was inserted into the map, find returns - * true. Otherwise, it returns false. Uses the CUDA Cooperative Groups API to - * to leverage multiple threads to perform a single contains operation. This provides a - * significant boost in throughput compared to the non Cooperative Group - * `contains` at moderate to high load factors. + * If the key `k` was inserted into the map, find returns true. Otherwise, it returns false. + * Uses the CUDA Cooperative Groups API to to leverage multiple threads to perform a single + * contains operation. This provides a significant boost in throughput compared to the non + * Cooperative Group `contains` at moderate to high load factors. * * @tparam CG Cooperative Group type + * @tparam ProbeKey Probe key type that can be convertible to map's `key_type` * @tparam Hash Unary callable type * @tparam KeyEqual Binary callable type * @param g The Cooperative Group used to perform the contains operation @@ -1247,12 +1248,14 @@ class static_map { * containing `k` was inserted */ template , typename KeyEqual = thrust::equal_to> - __device__ bool contains(CG g, - Key const& k, - Hash hash = Hash{}, - KeyEqual key_equal = KeyEqual{}) const noexcept; + __device__ std::enable_if_t, bool> contains( + CG const& g, + ProbeKey const& k, + Hash hash = Hash{}, + KeyEqual key_equal = KeyEqual{}) const noexcept; }; // class device_view /** diff --git a/tests/static_map/custom_type_test.cu b/tests/static_map/custom_type_test.cu index c5722b03d..c68deb14d 100644 --- a/tests/static_map/custom_type_test.cu +++ b/tests/static_map/custom_type_test.cu @@ -89,10 +89,10 @@ struct hash_custom_key { // User-defined device key equality struct custom_key_equals { - template - __device__ bool operator()(custom_type lhs, custom_type rhs) + template + __device__ bool operator()(lhs_type lhs, rhs_type rhs) { - return std::tie(lhs.a, lhs.b) == std::tie(rhs.a, rhs.b); + return lhs == static_cast(rhs); } }; From 8103c797b4a91b628677373d3c04e032d66c156b Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Fri, 3 Jun 2022 20:31:11 -0400 Subject: [PATCH 02/15] Add template contains for static_multimap --- .../cuco/detail/static_multimap/device_view_impl.inl | 12 ++++++++---- .../cuco/detail/static_multimap/static_multimap.inl | 4 ++-- include/cuco/static_map.cuh | 6 ++++-- include/cuco/static_multimap.cuh | 6 ++++-- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/include/cuco/detail/static_multimap/device_view_impl.inl b/include/cuco/detail/static_multimap/device_view_impl.inl index 8d0add2da..963f80041 100644 --- a/include/cuco/detail/static_multimap/device_view_impl.inl +++ b/include/cuco/detail/static_multimap/device_view_impl.inl @@ -569,7 +569,9 @@ class static_multimap::device_view_ * * @tparam uses_vector_load Boolean flag indicating whether vector loads are used * @tparam CG Cooperative Group type + * @tparam ProbeKey Probe key type that is convertible to the map's `key_type` * @tparam KeyEqual Binary callable type + * * @param g The Cooperative Group used to perform the contains operation * @param k The key to search for * @param key_equal The binary callable used to compare two keys @@ -577,9 +579,9 @@ class static_multimap::device_view_ * @return A boolean indicating whether the key/value pair * containing `k` was inserted */ - template + template __device__ __forceinline__ std::enable_if_t contains( - CG g, Key const& k, KeyEqual key_equal) noexcept + CG const& g, ProbeKey const& k, KeyEqual key_equal) noexcept { auto current_slot = initial_slot(g, k); @@ -617,7 +619,9 @@ class static_multimap::device_view_ * * @tparam uses_vector_load Boolean flag indicating whether vector loads are used * @tparam CG Cooperative Group type + * @tparam ProbeKey Probe key type that is convertible to the map's `key_type` * @tparam KeyEqual Binary callable type + * * @param g The Cooperative Group used to perform the contains operation * @param k The key to search for * @param key_equal The binary callable used to compare two keys @@ -625,9 +629,9 @@ class static_multimap::device_view_ * @return A boolean indicating whether the key/value pair * containing `k` was inserted */ - template + template __device__ __forceinline__ std::enable_if_t contains( - CG g, Key const& k, KeyEqual key_equal) noexcept + CG const& g, ProbeKey const& k, KeyEqual key_equal) noexcept { auto current_slot = initial_slot(g, k); diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index a3920b181..f5c86fc15 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -536,11 +536,11 @@ template -template +template __device__ __forceinline__ bool static_multimap::device_view::contains( cooperative_groups::thread_block_tile const& g, - Key const& k, + ProbeKey const& k, KeyEqual key_equal) noexcept { return impl_.contains(g, k, key_equal); diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index d52b2f412..1c304ba32 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -1211,9 +1211,10 @@ class static_map { * If the key `k` was inserted into the map, find returns * true. Otherwise, it returns false. * - * @tparam ProbeKey Probe key type that can be convertible to map's `key_type` + * @tparam ProbeKey Probe key type that is convertible to the map's `key_type` * @tparam Hash Unary callable type * @tparam KeyEqual Binary callable type + * * @param k The key to search for * @param hash The unary callable used to hash the key * @param key_equal The binary callable used to compare two keys @@ -1236,9 +1237,10 @@ class static_map { * Cooperative Group `contains` at moderate to high load factors. * * @tparam CG Cooperative Group type - * @tparam ProbeKey Probe key type that can be convertible to map's `key_type` + * @tparam ProbeKey Probe key type that is convertible to the map's `key_type` * @tparam Hash Unary callable type * @tparam KeyEqual Binary callable type + * * @param g The Cooperative Group used to perform the contains operation * @param k The key to search for * @param hash The unary callable used to hash the key diff --git a/include/cuco/static_multimap.cuh b/include/cuco/static_multimap.cuh index 179126d7c..4c0119fbb 100644 --- a/include/cuco/static_multimap.cuh +++ b/include/cuco/static_multimap.cuh @@ -820,7 +820,9 @@ class static_multimap { * significant boost in throughput compared to the non Cooperative Group * `contains` at moderate to high load factors. * + * @tparam ProbeKey Probe key type that is convertible to the map's `key_type` * @tparam KeyEqual Binary callable type + * * @param g The Cooperative Group used to perform the contains operation * @param k The key to search for * @param key_equal The binary callable used to compare two keys @@ -828,10 +830,10 @@ class static_multimap { * @return A boolean indicating whether the key/value pair * containing `k` was inserted */ - template > + template > __device__ __forceinline__ bool contains( cooperative_groups::thread_block_tile const& g, - Key const& k, + ProbeKey const& k, KeyEqual key_equal = KeyEqual{}) noexcept; /** From 3b0adf597ed828f3813030452fb39f9b1735c90b Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 7 Jun 2022 11:26:59 -0400 Subject: [PATCH 03/15] Make initial_slot take template key parameter --- include/cuco/detail/probe_sequence_impl.cuh | 18 ++++++++---- .../static_multimap/device_view_impl.inl | 20 +++++++++---- include/cuco/static_map.cuh | 28 ++++++++++++------- tests/static_map/custom_type_test.cu | 2 +- 4 files changed, 45 insertions(+), 23 deletions(-) diff --git a/include/cuco/detail/probe_sequence_impl.cuh b/include/cuco/detail/probe_sequence_impl.cuh index 747df89c7..62bbcfa12 100644 --- a/include/cuco/detail/probe_sequence_impl.cuh +++ b/include/cuco/detail/probe_sequence_impl.cuh @@ -21,6 +21,8 @@ #include +#include + namespace cuco { namespace detail { @@ -186,13 +188,15 @@ class linear_probing_impl * * If vector-load is enabled, the return slot is always even to avoid illegal memory access. * - * @tparam CG CUDA Cooperative Groups type + * @tparam ProbeKey Probe key type + * * @param g the Cooperative Group for which the initial slot is needed * @param k The key to get the slot for * @return Pointer to the initial slot for `k` */ - template - __device__ __forceinline__ iterator initial_slot(CG const& g, Key const k) noexcept + template + __device__ __forceinline__ iterator + initial_slot(cooperative_groups::thread_block_tile const& g, ProbeKey const k) noexcept { auto const hash_value = [&]() { auto const tmp = hash_(k); @@ -307,13 +311,15 @@ class double_hashing_impl * If vector-load is enabled, the return slot is always a multiple of (`cg_size` * `vector_width`) * to avoid illegal memory access. * - * @tparam CG CUDA Cooperative Groups type + * @tparam ProbeKey Probe key type + * * @param g the Cooperative Group for which the initial slot is needed * @param k The key to get the slot for * @return Pointer to the initial slot for `k` */ - template - __device__ __forceinline__ iterator initial_slot(CG const& g, Key const k) noexcept + template + __device__ __forceinline__ iterator + initial_slot(cooperative_groups::thread_block_tile const& g, ProbeKey const k) noexcept { std::size_t index; auto const hash_value = hash1_(k); diff --git a/include/cuco/detail/static_multimap/device_view_impl.inl b/include/cuco/detail/static_multimap/device_view_impl.inl index 963f80041..be85f5f1d 100644 --- a/include/cuco/detail/static_multimap/device_view_impl.inl +++ b/include/cuco/detail/static_multimap/device_view_impl.inl @@ -20,6 +20,8 @@ #include #include +#include + namespace cuco { template ::device_view_ * * To be used for Cooperative Group based probing. * - * @tparam CG Cooperative Group type + * @tparam ProbeKey Probe key type + * * @param g the Cooperative Group for which the initial slot is needed * @param k The key to get the slot for * @return Pointer to the initial slot for `k` */ - template - __device__ __forceinline__ iterator initial_slot(CG const& g, Key const& k) noexcept + template + __device__ __forceinline__ iterator + initial_slot(cooperative_groups::thread_block_tile const& g, + ProbeKey const& k) noexcept { return probe_sequence_.initial_slot(g, k); } @@ -85,13 +90,16 @@ class static_multimap::device_view_ * * To be used for Cooperative Group based probing. * - * @tparam CG Cooperative Group type + * @tparam ProbeKey Probe key type + * * @param g the Cooperative Group for which the initial slot is needed * @param k The key to get the slot for * @return Pointer to the initial slot for `k` */ - template - __device__ __forceinline__ const_iterator initial_slot(CG g, Key const& k) const noexcept + template + __device__ __forceinline__ const_iterator + initial_slot(cooperative_groups::thread_block_tile const& g, + ProbeKey const& k) const noexcept { return probe_sequence_.initial_slot(g, k); } diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index 1c304ba32..449bcfa23 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -454,13 +454,15 @@ class static_map { /** * @brief Returns the initial slot for a given key `k` * + * @tparam ProbeKey Probe key type * @tparam Hash Unary callable type + * * @param k The key to get the slot for * @param hash The unary callable used to hash the key * @return Pointer to the initial slot for `k` */ - template - __device__ iterator initial_slot(Key const& k, Hash hash) noexcept + template + __device__ iterator initial_slot(ProbeKey const& k, Hash hash) noexcept { return &slots_[hash(k) % capacity_]; } @@ -468,13 +470,15 @@ class static_map { /** * @brief Returns the initial slot for a given key `k` * + * @tparam ProbeKey Probe key type * @tparam Hash Unary callable type + * * @param k The key to get the slot for * @param hash The unary callable used to hash the key * @return Pointer to the initial slot for `k` */ - template - __device__ const_iterator initial_slot(Key const& k, Hash hash) const noexcept + template + __device__ const_iterator initial_slot(ProbeKey const& k, Hash hash) const noexcept { return &slots_[hash(k) % capacity_]; } @@ -485,14 +489,16 @@ class static_map { * To be used for Cooperative Group based probing. * * @tparam CG Cooperative Group type + * @tparam ProbeKey Probe key type * @tparam Hash Unary callable type + * * @param g the Cooperative Group for which the initial slot is needed * @param k The key to get the slot for * @param hash The unary callable used to hash the key * @return Pointer to the initial slot for `k` */ - template - __device__ iterator initial_slot(CG g, Key const& k, Hash hash) noexcept + template + __device__ iterator initial_slot(CG const& g, ProbeKey const& k, Hash hash) noexcept { return &slots_[(hash(k) + g.thread_rank()) % capacity_]; } @@ -503,14 +509,16 @@ class static_map { * To be used for Cooperative Group based probing. * * @tparam CG Cooperative Group type + * @tparam ProbeKey Probe key type * @tparam Hash Unary callable type + * * @param g the Cooperative Group for which the initial slot is needed * @param k The key to get the slot for * @param hash The unary callable used to hash the key * @return Pointer to the initial slot for `k` */ - template - __device__ const_iterator initial_slot(CG g, Key const& k, Hash hash) const noexcept + template + __device__ const_iterator initial_slot(CG const& g, ProbeKey const& k, Hash hash) const noexcept { return &slots_[(hash(k) + g.thread_rank()) % capacity_]; } @@ -550,7 +558,7 @@ class static_map { * @return The next slot after `s` */ template - __device__ iterator next_slot(CG g, iterator s) noexcept + __device__ iterator next_slot(CG const& g, iterator s) noexcept { uint32_t index = s - slots_; return &slots_[(index + g.size()) % capacity_]; @@ -568,7 +576,7 @@ class static_map { * @return The next slot after `s` */ template - __device__ const_iterator next_slot(CG g, const_iterator s) const noexcept + __device__ const_iterator next_slot(CG const& g, const_iterator s) const noexcept { uint32_t index = s - slots_; return &slots_[(index + g.size()) % capacity_]; diff --git a/tests/static_map/custom_type_test.cu b/tests/static_map/custom_type_test.cu index c68deb14d..801618d6e 100644 --- a/tests/static_map/custom_type_test.cu +++ b/tests/static_map/custom_type_test.cu @@ -83,7 +83,7 @@ struct hash_custom_key { template __device__ uint32_t operator()(custom_type k) { - return k.a; + return thrust::raw_reference_cast(k).a; }; }; From b0a95ca12081e23ad3fdf3b57fd9fe07c0f5223f Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 7 Jun 2022 13:56:03 -0400 Subject: [PATCH 04/15] Add static_assert + update docs --- include/cuco/detail/hash_functions.cuh | 4 +++- include/cuco/detail/static_map.inl | 22 +++++++++++++---- .../static_multimap/static_multimap.inl | 24 +++++++++++++++++++ include/cuco/probe_sequences.cuh | 7 ++++++ include/cuco/static_map.cuh | 17 +++++++++---- include/cuco/static_multimap.cuh | 7 ++++-- 6 files changed, 68 insertions(+), 13 deletions(-) diff --git a/include/cuco/detail/hash_functions.cuh b/include/cuco/detail/hash_functions.cuh index d5dcd0f64..151cf99e6 100644 --- a/include/cuco/detail/hash_functions.cuh +++ b/include/cuco/detail/hash_functions.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021, NVIDIA CORPORATION. + * Copyright (c) 2017-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ namespace cuco { +using hash_value_type = uint32_t; + namespace detail { // MurmurHash3_32 implementation from diff --git a/include/cuco/detail/static_map.inl b/include/cuco/detail/static_map.inl index dc3f60b6c..12087f875 100644 --- a/include/cuco/detail/static_map.inl +++ b/include/cuco/detail/static_map.inl @@ -716,11 +716,16 @@ static_map::device_view::find(CG g, template template -__device__ std::enable_if_t, bool> -static_map::device_view::contains(ProbeKey const& k, - Hash hash, - KeyEqual key_equal) const noexcept +__device__ bool static_map::device_view::contains( + ProbeKey const& k, Hash hash, KeyEqual key_equal) const noexcept { + static_assert(std::is_invocable_r_v, + "KeyEqual(ProbeKey{}, Key{}) must be a valid callable."); + static_assert(std::is_invocable_r_v, + "Hash(Key{}) must be a valid callable."); + static_assert(std::is_invocable_r_v, + "Hash(ProbeKey{}) must be a valid callable."); + auto current_slot = initial_slot(k, hash); while (true) { @@ -736,12 +741,19 @@ static_map::device_view::contains(ProbeKey const& template template -__device__ std::enable_if_t, bool> +__device__ std::enable_if_t, bool> static_map::device_view::contains(CG const& g, ProbeKey const& k, Hash hash, KeyEqual key_equal) const noexcept { + static_assert(std::is_invocable_r_v, + "KeyEqual(ProbeKey{}, Key{}) must be a valid callable."); + static_assert(std::is_invocable_r_v, + "Hash(Key{}) must be a valid callable."); + static_assert(std::is_invocable_r_v, + "Hash(ProbeKey{}) must be a valid callable."); + auto current_slot = initial_slot(g, k, hash); while (true) { diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index f5c86fc15..433547a71 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -543,6 +543,30 @@ static_multimap::device_view::conta ProbeKey const& k, KeyEqual key_equal) noexcept { + static_assert(std::is_invocable_r_v, + "KeyEqual(ProbeKey{}, Key{}) must be a valid callable."); + + if constexpr (ProbeSequence::is_linear_probing) { + static_assert(std::is_invocable_r_v, + "ProbeSequence::hasher(Key{}) must be a valid callable."); + static_assert( + std::is_invocable_r_v, + "ProbeSequence::hasher(ProbeKey{}) must be a valid callable."); + } else { + static_assert( + std::is_invocable_r_v, + "ProbeSequence::hasher1(Key{}) must be a valid callable."); + static_assert( + std::is_invocable_r_v, + "ProbeSequence::hasher2(Key{}) must be a valid callable."); + static_assert( + std::is_invocable_r_v, + "ProbeSequence::hasher1(ProbeKey{}) must be a valid callable."); + static_assert( + std::is_invocable_r_v, + "ProbeSequence::hasher2(ProbeKey{}) must be a valid callable."); + } + return impl_.contains(g, k, key_equal); } diff --git a/include/cuco/probe_sequences.cuh b/include/cuco/probe_sequences.cuh index f923f9df6..2b79a9cf0 100644 --- a/include/cuco/probe_sequences.cuh +++ b/include/cuco/probe_sequences.cuh @@ -35,9 +35,12 @@ namespace cuco { template class linear_probing : public detail::probe_sequence_base { public: + static constexpr bool is_linear_probing = true; + using probe_sequence_base_type = detail::probe_sequence_base; using probe_sequence_base_type::cg_size; using probe_sequence_base_type::vector_width; + using hasher = Hash; template using impl = detail::linear_probing_impl; @@ -61,9 +64,13 @@ class linear_probing : public detail::probe_sequence_base { template class double_hashing : public detail::probe_sequence_base { public: + static constexpr bool is_linear_probing = false; + using probe_sequence_base_type = detail::probe_sequence_base; using probe_sequence_base_type::cg_size; using probe_sequence_base_type::vector_width; + using hasher1 = Hash1; + using hasher2 = Hash2; template using impl = detail::double_hashing_impl; diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index 449bcfa23..37d9568af 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -1219,7 +1219,10 @@ class static_map { * If the key `k` was inserted into the map, find returns * true. Otherwise, it returns false. * - * @tparam ProbeKey Probe key type that is convertible to the map's `key_type` + * If `key_equal(probe_key, slot_key)` returns true, `hash(probe_key) == hash(slot_key)` must + * also be true. + * + * @tparam ProbeKey Probe key type * @tparam Hash Unary callable type * @tparam KeyEqual Binary callable type * @@ -1233,8 +1236,9 @@ class static_map { template , typename KeyEqual = thrust::equal_to> - __device__ std::enable_if_t, bool> contains( - ProbeKey const& k, Hash hash = Hash{}, KeyEqual key_equal = KeyEqual{}) const noexcept; + __device__ bool contains(ProbeKey const& k, + Hash hash = Hash{}, + KeyEqual key_equal = KeyEqual{}) const noexcept; /** * @brief Indicates whether the key `k` was inserted into the map. @@ -1244,8 +1248,11 @@ class static_map { * contains operation. This provides a significant boost in throughput compared to the non * Cooperative Group `contains` at moderate to high load factors. * + * If `key_equal(probe_key, slot_key)` returns true, `hash(probe_key) == hash(slot_key)` must + * also be true. + * * @tparam CG Cooperative Group type - * @tparam ProbeKey Probe key type that is convertible to the map's `key_type` + * @tparam ProbeKey Probe key type * @tparam Hash Unary callable type * @tparam KeyEqual Binary callable type * @@ -1261,7 +1268,7 @@ class static_map { typename ProbeKey, typename Hash = cuco::detail::MurmurHash3_32, typename KeyEqual = thrust::equal_to> - __device__ std::enable_if_t, bool> contains( + __device__ std::enable_if_t, bool> contains( CG const& g, ProbeKey const& k, Hash hash = Hash{}, diff --git a/include/cuco/static_multimap.cuh b/include/cuco/static_multimap.cuh index 4c0119fbb..684adde36 100644 --- a/include/cuco/static_multimap.cuh +++ b/include/cuco/static_multimap.cuh @@ -155,7 +155,7 @@ class static_multimap { static_assert( std::is_base_of_v, ProbeSequence>, "ProbeSequence must be a specialization of either cuco::double_hashing or " - "cuco::linear_probing"); + "cuco::linear_probing."); public: using value_type = cuco::pair_type; @@ -820,7 +820,10 @@ class static_multimap { * significant boost in throughput compared to the non Cooperative Group * `contains` at moderate to high load factors. * - * @tparam ProbeKey Probe key type that is convertible to the map's `key_type` + * If `key_equal(probe_key, slot_key)` returns true, `hash(probe_key) == hash(slot_key)` must + * also be true. + * + * @tparam ProbeKey Probe key type * @tparam KeyEqual Binary callable type * * @param g The Cooperative Group used to perform the contains operation From ad321f4d664317fcba1cc0d7505090cbc7aa6e28 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 7 Jun 2022 13:58:36 -0400 Subject: [PATCH 05/15] Minor fix: initial_slot take const ref --- include/cuco/detail/probe_sequence_impl.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/cuco/detail/probe_sequence_impl.cuh b/include/cuco/detail/probe_sequence_impl.cuh index 62bbcfa12..a7349dcdb 100644 --- a/include/cuco/detail/probe_sequence_impl.cuh +++ b/include/cuco/detail/probe_sequence_impl.cuh @@ -196,7 +196,7 @@ class linear_probing_impl */ template __device__ __forceinline__ iterator - initial_slot(cooperative_groups::thread_block_tile const& g, ProbeKey const k) noexcept + initial_slot(cooperative_groups::thread_block_tile const& g, ProbeKey const& k) noexcept { auto const hash_value = [&]() { auto const tmp = hash_(k); @@ -319,7 +319,7 @@ class double_hashing_impl */ template __device__ __forceinline__ iterator - initial_slot(cooperative_groups::thread_block_tile const& g, ProbeKey const k) noexcept + initial_slot(cooperative_groups::thread_block_tile const& g, ProbeKey const& k) noexcept { std::size_t index; auto const hash_value = hash1_(k); From aadf4a2282320f7dbf514c3d651fbaad9c289c83 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 7 Jun 2022 14:02:09 -0400 Subject: [PATCH 06/15] Minor doc cleanup --- include/cuco/detail/static_multimap/device_view_impl.inl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cuco/detail/static_multimap/device_view_impl.inl b/include/cuco/detail/static_multimap/device_view_impl.inl index be85f5f1d..57797f6d5 100644 --- a/include/cuco/detail/static_multimap/device_view_impl.inl +++ b/include/cuco/detail/static_multimap/device_view_impl.inl @@ -577,7 +577,7 @@ class static_multimap::device_view_ * * @tparam uses_vector_load Boolean flag indicating whether vector loads are used * @tparam CG Cooperative Group type - * @tparam ProbeKey Probe key type that is convertible to the map's `key_type` + * @tparam ProbeKey Probe key type * @tparam KeyEqual Binary callable type * * @param g The Cooperative Group used to perform the contains operation From fcd9d4588cab0207dfbd2da839cd3605a71efba2 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 7 Jun 2022 15:18:09 -0400 Subject: [PATCH 07/15] Add unit tests --- .../static_multimap/static_multimap.inl | 2 +- include/cuco/static_multimap.cuh | 7 +- tests/CMakeLists.txt | 4 +- tests/static_map/heterogeneous_lookup_test.cu | 128 +++++++++++++++++ tests/static_multimap/custom_type_test.cu | 4 +- .../heterogeneous_lookup_test.cu | 131 ++++++++++++++++++ 6 files changed, 269 insertions(+), 7 deletions(-) create mode 100644 tests/static_map/heterogeneous_lookup_test.cu create mode 100644 tests/static_multimap/heterogeneous_lookup_test.cu diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index 433547a71..6119907c3 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -108,7 +108,7 @@ template template void static_multimap::contains( - InputIt first, InputIt last, OutputIt output_begin, cudaStream_t stream, KeyEqual key_equal) const + InputIt first, InputIt last, OutputIt output_begin, KeyEqual key_equal, cudaStream_t stream) const { auto const num_keys = std::distance(first, last); if (num_keys == 0) { return; } diff --git a/include/cuco/static_multimap.cuh b/include/cuco/static_multimap.cuh index 684adde36..9273f7e7d 100644 --- a/include/cuco/static_multimap.cuh +++ b/include/cuco/static_multimap.cuh @@ -277,18 +277,19 @@ class static_multimap { * @tparam OutputIt Device accessible output iterator whose `value_type` is convertible from * `bool` * @tparam KeyEqual Binary callable type used to compare two keys for equality + * * @param first Beginning of the sequence of keys * @param last End of the sequence of keys * @param output_begin Beginning of the output sequence indicating whether each key is present - * @param stream CUDA stream used for contains * @param key_equal The binary function to compare two keys for equality + * @param stream CUDA stream used for contains */ template > void contains(InputIt first, InputIt last, OutputIt output_begin, - cudaStream_t stream = 0, - KeyEqual key_equal = KeyEqual{}) const; + KeyEqual key_equal = KeyEqual{}, + cudaStream_t stream = 0) const; /** * @brief Counts the occurrences of keys in `[first, last)` contained in the multimap. diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b7b955362..dec5170c4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -55,9 +55,10 @@ endfunction(ConfigureTest) ################################################################################################### # - static_map tests ------------------------------------------------------------------------------ ConfigureTest(STATIC_MAP_TEST - static_map/erase_test.cu static_map/custom_type_test.cu static_map/duplicate_keys_test.cu + static_map/erase_test.cu + static_map/heterogeneous_lookup_test.cu static_map/key_sentinel_test.cu static_map/shared_memory_test.cu static_map/stream_test.cu @@ -80,6 +81,7 @@ ConfigureTest(DYNAMIC_MAP_TEST ConfigureTest(STATIC_MULTIMAP_TEST static_multimap/custom_pair_retrieve_test.cu static_multimap/custom_type_test.cu + static_multimap/heterogeneous_lookup_test.cu static_multimap/insert_if_test.cu static_multimap/multiplicity_test.cu static_multimap/non_match_test.cu diff --git a/tests/static_map/heterogeneous_lookup_test.cu b/tests/static_map/heterogeneous_lookup_test.cu new file mode 100644 index 000000000..ba31eaf8d --- /dev/null +++ b/tests/static_map/heterogeneous_lookup_test.cu @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +#include + +// insert key type +template +struct key_pair { + T a; + T b; + + __host__ __device__ key_pair() {} + __host__ __device__ key_pair(T x) : a{x}, b{x} {} + + // Device equality operator is mandatory due to libcudacxx bug: + // https://github.com/NVIDIA/libcudacxx/issues/223 + __device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; } +}; + +// probe key type +template +struct key_triplet { + T a; + T b; + T c; + + __host__ __device__ key_triplet() {} + __host__ __device__ key_triplet(T x) : a{x}, b{x}, c{x} {} + + // Device equality operator is mandatory due to libcudacxx bug: + // https://github.com/NVIDIA/libcudacxx/issues/223 + __device__ bool operator==(key_triplet const& other) const + { + return a == other.a and b == other.b and c == other.c; + } +}; + +// User-defined device hasher +struct custom_hasher { + template + __device__ uint32_t operator()(CustomKey const& k) + { + return thrust::raw_reference_cast(k).a; + }; +}; + +// User-defined device key equality +struct custom_key_equal { + template + __device__ bool operator()(LHS const& lhs, RHS const& rhs) + { + return thrust::raw_reference_cast(lhs).a == thrust::raw_reference_cast(rhs).a; + } +}; + +TEMPLATE_TEST_CASE_SIG("User defined key and value type", + "", + ((typename Key, typename Value), Key, Value), +#ifndef CUCO_NO_INDEPENDENT_THREADS // Key type larger than 8B only supported for sm_70 and up + (key_pair, int64_t), +#endif + (key_pair, int32_t)) +{ + auto const sentinel_key = Key{-1}; + auto const sentinel_value = Value{-1}; + + constexpr std::size_t num = 100; + constexpr std::size_t capacity = num * 2; + cuco::static_map map{capacity, + cuco::sentinel::empty_key{sentinel_key}, + cuco::sentinel::empty_value{sentinel_value}}; + + thrust::device_vector> probe_keys(num); + + thrust::transform(thrust::device, + thrust::counting_iterator(0), + thrust::counting_iterator(num), + probe_keys.begin(), + [] __device__(int i) { return key_triplet{i}; }); + + auto insert_keys = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + [] __device__(auto i) { return cuco::pair_type(i, i); }); + + SECTION("All inserted keys-value pairs should be contained") + { + thrust::device_vector contained(num); + map.insert(insert_keys, insert_keys + num, custom_hasher{}, custom_key_equal{}); + map.contains( + probe_keys.begin(), probe_keys.end(), contained.begin(), custom_hasher{}, custom_key_equal{}); + REQUIRE(cuco::test::all_of( + contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + } + + SECTION("Non-inserted keys-value pairs should not be contained") + { + thrust::device_vector contained(num); + map.contains( + probe_keys.begin(), probe_keys.end(), contained.begin(), custom_hasher{}, custom_key_equal{}); + REQUIRE(cuco::test::none_of( + contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + } +} diff --git a/tests/static_multimap/custom_type_test.cu b/tests/static_multimap/custom_type_test.cu index cb3136ce2..ff8eed7ac 100644 --- a/tests/static_multimap/custom_type_test.cu +++ b/tests/static_multimap/custom_type_test.cu @@ -195,7 +195,7 @@ __inline__ void test_custom_key_value_type(Map& map, std::size_t num_pairs) REQUIRE(size == num_pairs); thrust::device_vector contained(num_pairs); - map.contains(key_begin, key_begin + num_pairs, contained.begin(), stream, key_pair_equals{}); + map.contains(key_begin, key_begin + num_pairs, contained.begin(), key_pair_equals{}, stream); REQUIRE(cuco::test::all_of( contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); } @@ -206,7 +206,7 @@ __inline__ void test_custom_key_value_type(Map& map, std::size_t num_pairs) REQUIRE(size == 0); thrust::device_vector contained(num_pairs); - map.contains(key_begin, key_begin + num_pairs, contained.begin(), stream, key_pair_equals{}); + map.contains(key_begin, key_begin + num_pairs, contained.begin(), key_pair_equals{}, stream); REQUIRE(cuco::test::none_of( contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); diff --git a/tests/static_multimap/heterogeneous_lookup_test.cu b/tests/static_multimap/heterogeneous_lookup_test.cu new file mode 100644 index 000000000..0e3c4afc4 --- /dev/null +++ b/tests/static_multimap/heterogeneous_lookup_test.cu @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +#include + +// insert key type +template +struct key_pair { + T a; + T b; + + __host__ __device__ key_pair() {} + __host__ __device__ key_pair(T x) : a{x}, b{x} {} + + // Device equality operator is mandatory due to libcudacxx bug: + // https://github.com/NVIDIA/libcudacxx/issues/223 + __device__ bool operator==(key_pair const& other) const { return a == other.a and b == other.b; } +}; + +// probe key type +template +struct key_triplet { + T a; + T b; + T c; + + __host__ __device__ key_triplet() {} + __host__ __device__ key_triplet(T x) : a{x}, b{x}, c{x} {} + + // Device equality operator is mandatory due to libcudacxx bug: + // https://github.com/NVIDIA/libcudacxx/issues/223 + __device__ bool operator==(key_triplet const& other) const + { + return a == other.a and b == other.b and c == other.c; + } +}; + +// User-defined device hasher +struct custom_hasher { + template + __device__ uint32_t operator()(CustomKey const& k) + { + return thrust::raw_reference_cast(k).a; + }; +}; + +// User-defined device key equality +struct custom_key_equal { + template + __device__ bool operator()(LHS const& lhs, RHS const& rhs) + { + return thrust::raw_reference_cast(lhs).a == thrust::raw_reference_cast(rhs).a; + } +}; + +TEMPLATE_TEST_CASE_SIG("User defined key and value type", + "", + ((typename Key, typename Value), Key, Value), +#ifndef CUCO_NO_INDEPENDENT_THREADS // Key type larger than 8B only supported for sm_70 and up + (key_pair, int64_t), +#endif + (key_pair, int32_t)) +{ + auto const sentinel_key = Key{-1}; + auto const sentinel_value = Value{-1}; + + constexpr std::size_t num = 100; + constexpr std::size_t capacity = num * 2; + cuco::static_multimap, + cuco::linear_probing<1, custom_hasher>> + map{capacity, + cuco::sentinel::empty_key{sentinel_key}, + cuco::sentinel::empty_value{sentinel_value}}; + + thrust::device_vector> probe_keys(num); + + thrust::transform(thrust::device, + thrust::counting_iterator(0), + thrust::counting_iterator(num), + probe_keys.begin(), + [] __device__(int i) { return key_triplet{i}; }); + + auto insert_keys = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + [] __device__(auto i) { return cuco::pair_type(i, i); }); + + SECTION("All inserted keys-value pairs should be contained") + { + thrust::device_vector contained(num); + map.insert(insert_keys, insert_keys + num); + map.contains(probe_keys.begin(), probe_keys.end(), contained.begin(), custom_key_equal{}); + REQUIRE(cuco::test::all_of( + contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + } + + SECTION("Non-inserted keys-value pairs should not be contained") + { + thrust::device_vector contained(num); + map.contains(probe_keys.begin(), probe_keys.end(), contained.begin(), custom_key_equal{}); + REQUIRE(cuco::test::none_of( + contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + } +} From c1ba451beebea0c82530cfbdfa3192771ab4f24c Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 7 Jun 2022 15:26:10 -0400 Subject: [PATCH 08/15] Cleanups: use make_transform_iterator --- tests/static_map/heterogeneous_lookup_test.cu | 16 +++++----------- .../static_multimap/heterogeneous_lookup_test.cu | 16 +++++----------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/tests/static_map/heterogeneous_lookup_test.cu b/tests/static_map/heterogeneous_lookup_test.cu index ba31eaf8d..4d1eef960 100644 --- a/tests/static_map/heterogeneous_lookup_test.cu +++ b/tests/static_map/heterogeneous_lookup_test.cu @@ -95,24 +95,18 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", cuco::sentinel::empty_key{sentinel_key}, cuco::sentinel::empty_value{sentinel_value}}; - thrust::device_vector> probe_keys(num); - - thrust::transform(thrust::device, - thrust::counting_iterator(0), - thrust::counting_iterator(num), - probe_keys.begin(), - [] __device__(int i) { return key_triplet{i}; }); - auto insert_keys = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), + thrust::counting_iterator(0), [] __device__(auto i) { return cuco::pair_type(i, i); }); + auto probe_keys = thrust::make_transform_iterator( + thrust::counting_iterator(0), [] __device__(auto i) { return key_triplet(i); }); SECTION("All inserted keys-value pairs should be contained") { thrust::device_vector contained(num); map.insert(insert_keys, insert_keys + num, custom_hasher{}, custom_key_equal{}); map.contains( - probe_keys.begin(), probe_keys.end(), contained.begin(), custom_hasher{}, custom_key_equal{}); + probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{}); REQUIRE(cuco::test::all_of( contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); } @@ -121,7 +115,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", { thrust::device_vector contained(num); map.contains( - probe_keys.begin(), probe_keys.end(), contained.begin(), custom_hasher{}, custom_key_equal{}); + probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{}); REQUIRE(cuco::test::none_of( contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); } diff --git a/tests/static_multimap/heterogeneous_lookup_test.cu b/tests/static_multimap/heterogeneous_lookup_test.cu index 0e3c4afc4..a70199616 100644 --- a/tests/static_multimap/heterogeneous_lookup_test.cu +++ b/tests/static_multimap/heterogeneous_lookup_test.cu @@ -100,23 +100,17 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", cuco::sentinel::empty_key{sentinel_key}, cuco::sentinel::empty_value{sentinel_value}}; - thrust::device_vector> probe_keys(num); - - thrust::transform(thrust::device, - thrust::counting_iterator(0), - thrust::counting_iterator(num), - probe_keys.begin(), - [] __device__(int i) { return key_triplet{i}; }); - auto insert_keys = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), + thrust::counting_iterator(0), [] __device__(auto i) { return cuco::pair_type(i, i); }); + auto probe_keys = thrust::make_transform_iterator( + thrust::counting_iterator(0), [] __device__(auto i) { return key_triplet(i); }); SECTION("All inserted keys-value pairs should be contained") { thrust::device_vector contained(num); map.insert(insert_keys, insert_keys + num); - map.contains(probe_keys.begin(), probe_keys.end(), contained.begin(), custom_key_equal{}); + map.contains(probe_keys, probe_keys + num, contained.begin(), custom_key_equal{}); REQUIRE(cuco::test::all_of( contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); } @@ -124,7 +118,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", SECTION("Non-inserted keys-value pairs should not be contained") { thrust::device_vector contained(num); - map.contains(probe_keys.begin(), probe_keys.end(), contained.begin(), custom_key_equal{}); + map.contains(probe_keys, probe_keys + num, contained.begin(), custom_key_equal{}); REQUIRE(cuco::test::none_of( contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); } From 55bba1fc6028afb230a2df43cb3fa428940efc3a Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Tue, 7 Jun 2022 15:42:29 -0400 Subject: [PATCH 09/15] Update CI build script --- ci/gpu/build.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh index 7ab4e5c5b..8ae26bcf4 100644 --- a/ci/gpu/build.sh +++ b/ci/gpu/build.sh @@ -37,8 +37,6 @@ conda activate cuda gpuci_logger "Check versions" python --version -$CC --version -$CXX --version gpuci_logger "Check conda environment" conda info From faedd758516d1e3bfae9f0ba213511d9a8a86258 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Wed, 8 Jun 2022 17:43:31 -0400 Subject: [PATCH 10/15] Add more static_assert + variable renaming --- include/cuco/detail/static_map.inl | 4 ++++ include/cuco/detail/static_multimap/static_multimap.inl | 2 ++ tests/static_map/heterogeneous_lookup_test.cu | 4 ++-- tests/static_multimap/heterogeneous_lookup_test.cu | 4 ++-- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/include/cuco/detail/static_map.inl b/include/cuco/detail/static_map.inl index 12087f875..ef9768175 100644 --- a/include/cuco/detail/static_map.inl +++ b/include/cuco/detail/static_map.inl @@ -721,6 +721,8 @@ __device__ bool static_map::device_view::contains( { static_assert(std::is_invocable_r_v, "KeyEqual(ProbeKey{}, Key{}) must be a valid callable."); + static_assert(std::is_invocable_r_v, + "KeyEqual(Key{}, ProbeKey{}) must be a valid callable."); static_assert(std::is_invocable_r_v, "Hash(Key{}) must be a valid callable."); static_assert(std::is_invocable_r_v, @@ -749,6 +751,8 @@ static_map::device_view::contains(CG const& g, { static_assert(std::is_invocable_r_v, "KeyEqual(ProbeKey{}, Key{}) must be a valid callable."); + static_assert(std::is_invocable_r_v, + "KeyEqual(Key{}, ProbeKey{}) must be a valid callable."); static_assert(std::is_invocable_r_v, "Hash(Key{}) must be a valid callable."); static_assert(std::is_invocable_r_v, diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index 6119907c3..5d6d30623 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -545,6 +545,8 @@ static_multimap::device_view::conta { static_assert(std::is_invocable_r_v, "KeyEqual(ProbeKey{}, Key{}) must be a valid callable."); + static_assert(std::is_invocable_r_v, + "KeyEqual(Key{}, ProbeKey{}) must be a valid callable."); if constexpr (ProbeSequence::is_linear_probing) { static_assert(std::is_invocable_r_v, diff --git a/tests/static_map/heterogeneous_lookup_test.cu b/tests/static_map/heterogeneous_lookup_test.cu index 4d1eef960..d0c0e9834 100644 --- a/tests/static_map/heterogeneous_lookup_test.cu +++ b/tests/static_map/heterogeneous_lookup_test.cu @@ -95,7 +95,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", cuco::sentinel::empty_key{sentinel_key}, cuco::sentinel::empty_value{sentinel_value}}; - auto insert_keys = thrust::make_transform_iterator( + auto insert_pairs = thrust::make_transform_iterator( thrust::counting_iterator(0), [] __device__(auto i) { return cuco::pair_type(i, i); }); auto probe_keys = thrust::make_transform_iterator( @@ -104,7 +104,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", SECTION("All inserted keys-value pairs should be contained") { thrust::device_vector contained(num); - map.insert(insert_keys, insert_keys + num, custom_hasher{}, custom_key_equal{}); + map.insert(insert_pairs, insert_pairs + num, custom_hasher{}, custom_key_equal{}); map.contains( probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{}); REQUIRE(cuco::test::all_of( diff --git a/tests/static_multimap/heterogeneous_lookup_test.cu b/tests/static_multimap/heterogeneous_lookup_test.cu index a70199616..d0fdadcb0 100644 --- a/tests/static_multimap/heterogeneous_lookup_test.cu +++ b/tests/static_multimap/heterogeneous_lookup_test.cu @@ -100,7 +100,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", cuco::sentinel::empty_key{sentinel_key}, cuco::sentinel::empty_value{sentinel_value}}; - auto insert_keys = thrust::make_transform_iterator( + auto insert_pairs = thrust::make_transform_iterator( thrust::counting_iterator(0), [] __device__(auto i) { return cuco::pair_type(i, i); }); auto probe_keys = thrust::make_transform_iterator( @@ -109,7 +109,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", SECTION("All inserted keys-value pairs should be contained") { thrust::device_vector contained(num); - map.insert(insert_keys, insert_keys + num); + map.insert(insert_pairs, insert_pairs + num); map.contains(probe_keys, probe_keys + num, contained.begin(), custom_key_equal{}); REQUIRE(cuco::test::all_of( contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); From e4341bf390438e05351031a83cf8448abd1f0bd4 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Sun, 12 Jun 2022 12:02:35 -0400 Subject: [PATCH 11/15] Update docs --- .../cuco/detail/static_multimap/device_view_impl.inl | 2 +- include/cuco/static_map.cuh | 11 +++++------ include/cuco/static_multimap.cuh | 3 +-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/include/cuco/detail/static_multimap/device_view_impl.inl b/include/cuco/detail/static_multimap/device_view_impl.inl index 57797f6d5..8089d6149 100644 --- a/include/cuco/detail/static_multimap/device_view_impl.inl +++ b/include/cuco/detail/static_multimap/device_view_impl.inl @@ -627,7 +627,7 @@ class static_multimap::device_view_ * * @tparam uses_vector_load Boolean flag indicating whether vector loads are used * @tparam CG Cooperative Group type - * @tparam ProbeKey Probe key type that is convertible to the map's `key_type` + * @tparam ProbeKey Probe key type * @tparam KeyEqual Binary callable type * * @param g The Cooperative Group used to perform the contains operation diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index 37d9568af..321622477 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -380,17 +380,16 @@ class static_map { cudaStream_t stream = 0); /** - * @brief Indicates whether the keys in the range - * `[first, last)` are contained in the map. + * @brief Indicates whether the keys in the range `[first, last)` are contained in the map. * * Writes a `bool` to `(output + i)` indicating if the key `*(first + i)` exists in the map. * - * @tparam InputIt Device accessible input iterator whose `value_type` is - * convertible to the map's `key_type` - * @tparam OutputIt Device accessible output iterator whose `value_type` is - * convertible to the map's `mapped_type` + * @tparam InputIt Device accessible input iterator + * @tparam OutputIt Device accessible output iterator whose `value_type` is convertible from + * `bool` * @tparam Hash Unary callable type * @tparam KeyEqual Binary callable type + * * @param first Beginning of the sequence of keys * @param last End of the sequence of keys * @param output_begin Beginning of the sequence of booleans for the presence of each key diff --git a/include/cuco/static_multimap.cuh b/include/cuco/static_multimap.cuh index 9273f7e7d..40a1f2fec 100644 --- a/include/cuco/static_multimap.cuh +++ b/include/cuco/static_multimap.cuh @@ -272,8 +272,7 @@ class static_multimap { * Stores `true` or `false` to `(output + i)` indicating if the key `*(first + i)` exists in the * map. * - * @tparam InputIt Device accessible input iterator whose `value_type` is - * convertible to the map's `key_type` + * @tparam InputIt Device accessible input iterator * @tparam OutputIt Device accessible output iterator whose `value_type` is convertible from * `bool` * @tparam KeyEqual Binary callable type used to compare two keys for equality From c3fc8dedebb47f8fa1212076d7689d6417f538c5 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Mon, 13 Jun 2022 11:28:45 -0400 Subject: [PATCH 12/15] Cleanups --- .../static_multimap/device_view_impl.inl | 31 +++++++++---------- include/cuco/static_multimap.cuh | 1 - 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/include/cuco/detail/static_multimap/device_view_impl.inl b/include/cuco/detail/static_multimap/device_view_impl.inl index 8089d6149..d1dc017a3 100644 --- a/include/cuco/detail/static_multimap/device_view_impl.inl +++ b/include/cuco/detail/static_multimap/device_view_impl.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,15 +15,13 @@ */ #include +#include #include #include #include -#include - namespace cuco { - template ::device_view_ * @return Pointer to the initial slot for `k` */ template - __device__ __forceinline__ iterator - initial_slot(cooperative_groups::thread_block_tile const& g, - ProbeKey const& k) noexcept + __device__ __forceinline__ iterator initial_slot( + detail::cg::thread_block_tile const& g, ProbeKey const& k) noexcept { return probe_sequence_.initial_slot(g, k); } @@ -98,7 +95,7 @@ class static_multimap::device_view_ */ template __device__ __forceinline__ const_iterator - initial_slot(cooperative_groups::thread_block_tile const& g, + initial_slot(detail::cg::thread_block_tile const& g, ProbeKey const& k) const noexcept { return probe_sequence_.initial_slot(g, k); @@ -500,13 +497,13 @@ class static_multimap::device_view_ if constexpr (thrust::is_contiguous_iterator_v) { #if defined(CUCO_HAS_CG_MEMCPY_ASYNC) #if defined(CUCO_HAS_CUDA_BARRIER) - cooperative_groups::memcpy_async( + detail::cg::memcpy_async( g, output_begin + offset, output_buffer, cuda::aligned_size_t(sizeof(value_type) * num_outputs)); #else - cooperative_groups::memcpy_async( + detail::cg::memcpy_async( g, output_begin + offset, output_buffer, sizeof(value_type) * num_outputs); #endif // end CUCO_HAS_CUDA_BARRIER return; @@ -576,7 +573,6 @@ class static_multimap::device_view_ * `contains` at moderate to high load factors. * * @tparam uses_vector_load Boolean flag indicating whether vector loads are used - * @tparam CG Cooperative Group type * @tparam ProbeKey Probe key type * @tparam KeyEqual Binary callable type * @@ -587,9 +583,11 @@ class static_multimap::device_view_ * @return A boolean indicating whether the key/value pair * containing `k` was inserted */ - template + template __device__ __forceinline__ std::enable_if_t contains( - CG const& g, ProbeKey const& k, KeyEqual key_equal) noexcept + detail::cg::thread_block_tile const& g, + ProbeKey const& k, + KeyEqual key_equal) noexcept { auto current_slot = initial_slot(g, k); @@ -626,7 +624,6 @@ class static_multimap::device_view_ * `contains` at moderate to high load factors. * * @tparam uses_vector_load Boolean flag indicating whether vector loads are used - * @tparam CG Cooperative Group type * @tparam ProbeKey Probe key type * @tparam KeyEqual Binary callable type * @@ -637,9 +634,11 @@ class static_multimap::device_view_ * @return A boolean indicating whether the key/value pair * containing `k` was inserted */ - template + template __device__ __forceinline__ std::enable_if_t contains( - CG const& g, ProbeKey const& k, KeyEqual key_equal) noexcept + detail::cg::thread_block_tile const& g, + ProbeKey const& k, + KeyEqual key_equal) noexcept { auto current_slot = initial_slot(g, k); diff --git a/include/cuco/static_multimap.cuh b/include/cuco/static_multimap.cuh index 40a1f2fec..1f4feb51e 100644 --- a/include/cuco/static_multimap.cuh +++ b/include/cuco/static_multimap.cuh @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include From 111734aa762169de662049f11f97322f9302b908 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Mon, 13 Jun 2022 11:52:52 -0400 Subject: [PATCH 13/15] Clean up tests by using thrust::identity --- tests/dynamic_map/unique_sequence_test.cu | 7 +++--- tests/static_map/custom_type_test.cu | 7 +++--- tests/static_map/duplicate_keys_test.cu | 10 ++++---- tests/static_map/heterogeneous_lookup_test.cu | 24 ++++++++++--------- tests/static_map/stream_test.cu | 3 +-- tests/static_map/unique_sequence_test.cu | 7 +++--- tests/static_multimap/custom_type_test.cu | 7 +++--- .../heterogeneous_lookup_test.cu | 24 ++++++++++--------- tests/static_multimap/multiplicity_test.cu | 7 +++--- 9 files changed, 46 insertions(+), 50 deletions(-) diff --git a/tests/dynamic_map/unique_sequence_test.cu b/tests/dynamic_map/unique_sequence_test.cu index 1306daac3..de26bb3dc 100644 --- a/tests/dynamic_map/unique_sequence_test.cu +++ b/tests/dynamic_map/unique_sequence_test.cu @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -78,15 +79,13 @@ TEMPLATE_TEST_CASE_SIG("Unique sequence of keys", map.insert(pairs_begin, pairs_begin + num_keys); map.contains(d_keys.begin(), d_keys.end(), d_contained.begin()); - REQUIRE(cuco::test::all_of( - d_contained.begin(), d_contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), thrust::identity{})); } SECTION("Non-inserted keys-value pairs should not be contained") { map.contains(d_keys.begin(), d_keys.end(), d_contained.begin()); - REQUIRE(cuco::test::none_of( - d_contained.begin(), d_contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::none_of(d_contained.begin(), d_contained.end(), thrust::identity{})); } } diff --git a/tests/static_map/custom_type_test.cu b/tests/static_map/custom_type_test.cu index 801618d6e..c32a1bdc4 100644 --- a/tests/static_map/custom_type_test.cu +++ b/tests/static_map/custom_type_test.cu @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -164,8 +165,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", contained.begin(), hash_custom_key{}, custom_key_equals{}); - REQUIRE(cuco::test::all_of( - contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{})); } SECTION("All conditionally inserted keys-value pairs should be contained") @@ -203,8 +203,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", contained.begin(), hash_custom_key{}, custom_key_equals{}); - REQUIRE(cuco::test::none_of( - contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::none_of(contained.begin(), contained.end(), thrust::identity{})); } SECTION("All inserted keys-value pairs should be contained") diff --git a/tests/static_map/duplicate_keys_test.cu b/tests/static_map/duplicate_keys_test.cu index 014cf0e19..34a315a1c 100644 --- a/tests/static_map/duplicate_keys_test.cu +++ b/tests/static_map/duplicate_keys_test.cu @@ -88,12 +88,10 @@ TEMPLATE_TEST_CASE_SIG("Duplicate keys", map.insert(pairs_begin, pairs_begin + num_keys); map.contains(d_keys.begin(), d_keys.end(), d_contained.begin()); - REQUIRE(cuco::test::all_of(d_contained.begin(), - d_contained.begin() + num_keys / 2, - [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::all_of( + d_contained.begin(), d_contained.begin() + num_keys / 2, thrust::identity{})); - REQUIRE(cuco::test::none_of(d_contained.begin() + num_keys / 2, - d_contained.end(), - [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::none_of( + d_contained.begin() + num_keys / 2, d_contained.end(), thrust::identity{})); } } diff --git a/tests/static_map/heterogeneous_lookup_test.cu b/tests/static_map/heterogeneous_lookup_test.cu index d0c0e9834..905d6a898 100644 --- a/tests/static_map/heterogeneous_lookup_test.cu +++ b/tests/static_map/heterogeneous_lookup_test.cu @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -78,14 +79,17 @@ struct custom_key_equal { } }; -TEMPLATE_TEST_CASE_SIG("User defined key and value type", - "", - ((typename Key, typename Value), Key, Value), +TEMPLATE_TEST_CASE("Heterogeneous lookup", + "", #ifndef CUCO_NO_INDEPENDENT_THREADS // Key type larger than 8B only supported for sm_70 and up - (key_pair, int64_t), + int64_t, #endif - (key_pair, int32_t)) + int32_t) { + using Key = key_pair; + using Value = TestType; + using ProbeKey = key_triplet; + auto const sentinel_key = Key{-1}; auto const sentinel_value = Value{-1}; @@ -98,8 +102,8 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", auto insert_pairs = thrust::make_transform_iterator( thrust::counting_iterator(0), [] __device__(auto i) { return cuco::pair_type(i, i); }); - auto probe_keys = thrust::make_transform_iterator( - thrust::counting_iterator(0), [] __device__(auto i) { return key_triplet(i); }); + auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), + [] __device__(auto i) { return ProbeKey(i); }); SECTION("All inserted keys-value pairs should be contained") { @@ -107,8 +111,7 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", map.insert(insert_pairs, insert_pairs + num, custom_hasher{}, custom_key_equal{}); map.contains( probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{}); - REQUIRE(cuco::test::all_of( - contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{})); } SECTION("Non-inserted keys-value pairs should not be contained") @@ -116,7 +119,6 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", thrust::device_vector contained(num); map.contains( probe_keys, probe_keys + num, contained.begin(), custom_hasher{}, custom_key_equal{}); - REQUIRE(cuco::test::none_of( - contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::none_of(contained.begin(), contained.end(), thrust::identity{})); } } diff --git a/tests/static_map/stream_test.cu b/tests/static_map/stream_test.cu index 639701764..5f816410e 100644 --- a/tests/static_map/stream_test.cu +++ b/tests/static_map/stream_test.cu @@ -84,8 +84,7 @@ TEMPLATE_TEST_CASE_SIG("Unique sequence of keys on given stream", map.insert(pairs_begin, pairs_begin + num_keys, hash_fn, equal_fn, stream); map.contains(d_keys.begin(), d_keys.end(), d_contained.begin(), hash_fn, equal_fn, stream); - REQUIRE(cuco::test::all_of( - d_contained.begin(), d_contained.end(), [] __device__(bool const& b) { return b; }, stream)); + REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), thrust::identity{}, stream)); } cudaStreamDestroy(stream); diff --git a/tests/static_map/unique_sequence_test.cu b/tests/static_map/unique_sequence_test.cu index 17ccc024a..75bb67d61 100644 --- a/tests/static_map/unique_sequence_test.cu +++ b/tests/static_map/unique_sequence_test.cu @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -74,16 +75,14 @@ TEMPLATE_TEST_CASE_SIG("Unique sequence of keys", map.insert(pairs_begin, pairs_begin + num_keys); map.contains(d_keys.begin(), d_keys.end(), d_contained.begin()); - REQUIRE(cuco::test::all_of( - d_contained.begin(), d_contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), thrust::identity{})); } SECTION("Non-inserted keys-value pairs should not be contained") { map.contains(d_keys.begin(), d_keys.end(), d_contained.begin()); - REQUIRE(cuco::test::none_of( - d_contained.begin(), d_contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::none_of(d_contained.begin(), d_contained.end(), thrust::identity{})); } SECTION("Inserting unique keys should return insert success.") diff --git a/tests/static_multimap/custom_type_test.cu b/tests/static_multimap/custom_type_test.cu index ff8eed7ac..a12e4ec9a 100644 --- a/tests/static_multimap/custom_type_test.cu +++ b/tests/static_multimap/custom_type_test.cu @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -196,8 +197,7 @@ __inline__ void test_custom_key_value_type(Map& map, std::size_t num_pairs) thrust::device_vector contained(num_pairs); map.contains(key_begin, key_begin + num_pairs, contained.begin(), key_pair_equals{}, stream); - REQUIRE(cuco::test::all_of( - contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{})); } SECTION("Non-inserted keys-value pairs should not be contained") @@ -208,8 +208,7 @@ __inline__ void test_custom_key_value_type(Map& map, std::size_t num_pairs) thrust::device_vector contained(num_pairs); map.contains(key_begin, key_begin + num_pairs, contained.begin(), key_pair_equals{}, stream); - REQUIRE(cuco::test::none_of( - contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::none_of(contained.begin(), contained.end(), thrust::identity{})); } } diff --git a/tests/static_multimap/heterogeneous_lookup_test.cu b/tests/static_multimap/heterogeneous_lookup_test.cu index d0fdadcb0..6283af7ce 100644 --- a/tests/static_multimap/heterogeneous_lookup_test.cu +++ b/tests/static_multimap/heterogeneous_lookup_test.cu @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -78,14 +79,17 @@ struct custom_key_equal { } }; -TEMPLATE_TEST_CASE_SIG("User defined key and value type", - "", - ((typename Key, typename Value), Key, Value), +TEMPLATE_TEST_CASE("Heterogeneous lookup", + "", #ifndef CUCO_NO_INDEPENDENT_THREADS // Key type larger than 8B only supported for sm_70 and up - (key_pair, int64_t), + int64_t, #endif - (key_pair, int32_t)) + int32_t) { + using Key = key_pair; + using Value = TestType; + using ProbeKey = key_triplet; + auto const sentinel_key = Key{-1}; auto const sentinel_value = Value{-1}; @@ -103,23 +107,21 @@ TEMPLATE_TEST_CASE_SIG("User defined key and value type", auto insert_pairs = thrust::make_transform_iterator( thrust::counting_iterator(0), [] __device__(auto i) { return cuco::pair_type(i, i); }); - auto probe_keys = thrust::make_transform_iterator( - thrust::counting_iterator(0), [] __device__(auto i) { return key_triplet(i); }); + auto probe_keys = thrust::make_transform_iterator(thrust::counting_iterator(0), + [] __device__(auto i) { return ProbeKey(i); }); SECTION("All inserted keys-value pairs should be contained") { thrust::device_vector contained(num); map.insert(insert_pairs, insert_pairs + num); map.contains(probe_keys, probe_keys + num, contained.begin(), custom_key_equal{}); - REQUIRE(cuco::test::all_of( - contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::all_of(contained.begin(), contained.end(), thrust::identity{})); } SECTION("Non-inserted keys-value pairs should not be contained") { thrust::device_vector contained(num); map.contains(probe_keys, probe_keys + num, contained.begin(), custom_key_equal{}); - REQUIRE(cuco::test::none_of( - contained.begin(), contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::none_of(contained.begin(), contained.end(), thrust::identity{})); } } diff --git a/tests/static_multimap/multiplicity_test.cu b/tests/static_multimap/multiplicity_test.cu index 8039c3a0e..3f5581b03 100644 --- a/tests/static_multimap/multiplicity_test.cu +++ b/tests/static_multimap/multiplicity_test.cu @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -61,8 +62,7 @@ __inline__ void test_multiplicity_two(Map& map, std::size_t num_items) REQUIRE(size == 0); map.contains(key_begin, key_begin + num_keys, d_contained.begin()); - REQUIRE(cuco::test::none_of( - d_contained.begin(), d_contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::none_of(d_contained.begin(), d_contained.end(), thrust::identity{})); } map.insert(pair_begin, pair_begin + num_items); @@ -74,8 +74,7 @@ __inline__ void test_multiplicity_two(Map& map, std::size_t num_items) map.contains(key_begin, key_begin + num_keys, d_contained.begin()); - REQUIRE(cuco::test::all_of( - d_contained.begin(), d_contained.end(), [] __device__(bool const& b) { return b; })); + REQUIRE(cuco::test::all_of(d_contained.begin(), d_contained.end(), thrust::identity{})); } SECTION("Total count should be equal to the number of inserted pairs.") From b35f7b6f132c5375cfe25668d08c79f97670feb8 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Mon, 13 Jun 2022 12:25:38 -0400 Subject: [PATCH 14/15] Remove static_assert and add pre-conditions in docs instead --- include/cuco/detail/static_map.inl | 18 ------------- .../static_multimap/static_multimap.inl | 26 ------------------- include/cuco/probe_sequences.cuh | 7 ----- include/cuco/static_map.cuh | 10 +++++++ include/cuco/static_multimap.cuh | 7 +++++ 5 files changed, 17 insertions(+), 51 deletions(-) diff --git a/include/cuco/detail/static_map.inl b/include/cuco/detail/static_map.inl index ef9768175..9eb31db14 100644 --- a/include/cuco/detail/static_map.inl +++ b/include/cuco/detail/static_map.inl @@ -719,15 +719,6 @@ template __device__ bool static_map::device_view::contains( ProbeKey const& k, Hash hash, KeyEqual key_equal) const noexcept { - static_assert(std::is_invocable_r_v, - "KeyEqual(ProbeKey{}, Key{}) must be a valid callable."); - static_assert(std::is_invocable_r_v, - "KeyEqual(Key{}, ProbeKey{}) must be a valid callable."); - static_assert(std::is_invocable_r_v, - "Hash(Key{}) must be a valid callable."); - static_assert(std::is_invocable_r_v, - "Hash(ProbeKey{}) must be a valid callable."); - auto current_slot = initial_slot(k, hash); while (true) { @@ -749,15 +740,6 @@ static_map::device_view::contains(CG const& g, Hash hash, KeyEqual key_equal) const noexcept { - static_assert(std::is_invocable_r_v, - "KeyEqual(ProbeKey{}, Key{}) must be a valid callable."); - static_assert(std::is_invocable_r_v, - "KeyEqual(Key{}, ProbeKey{}) must be a valid callable."); - static_assert(std::is_invocable_r_v, - "Hash(Key{}) must be a valid callable."); - static_assert(std::is_invocable_r_v, - "Hash(ProbeKey{}) must be a valid callable."); - auto current_slot = initial_slot(g, k, hash); while (true) { diff --git a/include/cuco/detail/static_multimap/static_multimap.inl b/include/cuco/detail/static_multimap/static_multimap.inl index 5d6d30623..89edb434e 100644 --- a/include/cuco/detail/static_multimap/static_multimap.inl +++ b/include/cuco/detail/static_multimap/static_multimap.inl @@ -543,32 +543,6 @@ static_multimap::device_view::conta ProbeKey const& k, KeyEqual key_equal) noexcept { - static_assert(std::is_invocable_r_v, - "KeyEqual(ProbeKey{}, Key{}) must be a valid callable."); - static_assert(std::is_invocable_r_v, - "KeyEqual(Key{}, ProbeKey{}) must be a valid callable."); - - if constexpr (ProbeSequence::is_linear_probing) { - static_assert(std::is_invocable_r_v, - "ProbeSequence::hasher(Key{}) must be a valid callable."); - static_assert( - std::is_invocable_r_v, - "ProbeSequence::hasher(ProbeKey{}) must be a valid callable."); - } else { - static_assert( - std::is_invocable_r_v, - "ProbeSequence::hasher1(Key{}) must be a valid callable."); - static_assert( - std::is_invocable_r_v, - "ProbeSequence::hasher2(Key{}) must be a valid callable."); - static_assert( - std::is_invocable_r_v, - "ProbeSequence::hasher1(ProbeKey{}) must be a valid callable."); - static_assert( - std::is_invocable_r_v, - "ProbeSequence::hasher2(ProbeKey{}) must be a valid callable."); - } - return impl_.contains(g, k, key_equal); } diff --git a/include/cuco/probe_sequences.cuh b/include/cuco/probe_sequences.cuh index 2b79a9cf0..f923f9df6 100644 --- a/include/cuco/probe_sequences.cuh +++ b/include/cuco/probe_sequences.cuh @@ -35,12 +35,9 @@ namespace cuco { template class linear_probing : public detail::probe_sequence_base { public: - static constexpr bool is_linear_probing = true; - using probe_sequence_base_type = detail::probe_sequence_base; using probe_sequence_base_type::cg_size; using probe_sequence_base_type::vector_width; - using hasher = Hash; template using impl = detail::linear_probing_impl; @@ -64,13 +61,9 @@ class linear_probing : public detail::probe_sequence_base { template class double_hashing : public detail::probe_sequence_base { public: - static constexpr bool is_linear_probing = false; - using probe_sequence_base_type = detail::probe_sequence_base; using probe_sequence_base_type::cg_size; using probe_sequence_base_type::vector_width; - using hasher1 = Hash1; - using hasher2 = Hash2; template using impl = detail::double_hashing_impl; diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh index 321622477..dfb75cc49 100644 --- a/include/cuco/static_map.cuh +++ b/include/cuco/static_map.cuh @@ -384,6 +384,10 @@ class static_map { * * Writes a `bool` to `(output + i)` indicating if the key `*(first + i)` exists in the map. * + * Hash should be callable with both `std::iterator_traits::value_type` and Key type. + * `std::invoke_result::value_type, Key>` must be + * well-formed. + * * @tparam InputIt Device accessible input iterator * @tparam OutputIt Device accessible output iterator whose `value_type` is convertible from * `bool` @@ -1218,6 +1222,9 @@ class static_map { * If the key `k` was inserted into the map, find returns * true. Otherwise, it returns false. * + * Hash should be callable with both ProbeKey and Key type. `std::invoke_result` must be well-formed. + * * If `key_equal(probe_key, slot_key)` returns true, `hash(probe_key) == hash(slot_key)` must * also be true. * @@ -1247,6 +1254,9 @@ class static_map { * contains operation. This provides a significant boost in throughput compared to the non * Cooperative Group `contains` at moderate to high load factors. * + * Hash should be callable with both ProbeKey and Key type. `std::invoke_result` must be well-formed. + * * If `key_equal(probe_key, slot_key)` returns true, `hash(probe_key) == hash(slot_key)` must * also be true. * diff --git a/include/cuco/static_multimap.cuh b/include/cuco/static_multimap.cuh index 1f4feb51e..e69f2afb2 100644 --- a/include/cuco/static_multimap.cuh +++ b/include/cuco/static_multimap.cuh @@ -271,6 +271,10 @@ class static_multimap { * Stores `true` or `false` to `(output + i)` indicating if the key `*(first + i)` exists in the * map. * + * ProbeSequence hashers should be callable with both `std::iterator_traits::value_type` + * and Key type. `std::invoke_result::value_type, Key>` + * must be well-formed. + * * @tparam InputIt Device accessible input iterator * @tparam OutputIt Device accessible output iterator whose `value_type` is convertible from * `bool` @@ -819,6 +823,9 @@ class static_multimap { * significant boost in throughput compared to the non Cooperative Group * `contains` at moderate to high load factors. * + * ProbeSequence hashers should be callable with both ProbeKey and Key type. + * `std::invoke_result` must be well-formed. + * * If `key_equal(probe_key, slot_key)` returns true, `hash(probe_key) == hash(slot_key)` must * also be true. * From d19d09d5c273446b3cf3dc2cff16647014a865a0 Mon Sep 17 00:00:00 2001 From: Yunsong Wang Date: Mon, 13 Jun 2022 16:32:48 -0400 Subject: [PATCH 15/15] Revert cg aliases --- .../detail/static_multimap/device_view_impl.inl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/include/cuco/detail/static_multimap/device_view_impl.inl b/include/cuco/detail/static_multimap/device_view_impl.inl index d1dc017a3..d0c9bc7b9 100644 --- a/include/cuco/detail/static_multimap/device_view_impl.inl +++ b/include/cuco/detail/static_multimap/device_view_impl.inl @@ -21,6 +21,8 @@ #include #include +#include + namespace cuco { template ::device_view_ * @return Pointer to the initial slot for `k` */ template - __device__ __forceinline__ iterator initial_slot( - detail::cg::thread_block_tile const& g, ProbeKey const& k) noexcept + __device__ __forceinline__ iterator + initial_slot(cooperative_groups::thread_block_tile const& g, + ProbeKey const& k) noexcept { return probe_sequence_.initial_slot(g, k); } @@ -95,7 +98,7 @@ class static_multimap::device_view_ */ template __device__ __forceinline__ const_iterator - initial_slot(detail::cg::thread_block_tile const& g, + initial_slot(cooperative_groups::thread_block_tile const& g, ProbeKey const& k) const noexcept { return probe_sequence_.initial_slot(g, k); @@ -497,13 +500,13 @@ class static_multimap::device_view_ if constexpr (thrust::is_contiguous_iterator_v) { #if defined(CUCO_HAS_CG_MEMCPY_ASYNC) #if defined(CUCO_HAS_CUDA_BARRIER) - detail::cg::memcpy_async( + cooperative_groups::memcpy_async( g, output_begin + offset, output_buffer, cuda::aligned_size_t(sizeof(value_type) * num_outputs)); #else - detail::cg::memcpy_async( + cooperative_groups::memcpy_async( g, output_begin + offset, output_buffer, sizeof(value_type) * num_outputs); #endif // end CUCO_HAS_CUDA_BARRIER return; @@ -585,7 +588,7 @@ class static_multimap::device_view_ */ template __device__ __forceinline__ std::enable_if_t contains( - detail::cg::thread_block_tile const& g, + cooperative_groups::thread_block_tile const& g, ProbeKey const& k, KeyEqual key_equal) noexcept { @@ -636,7 +639,7 @@ class static_multimap::device_view_ */ template __device__ __forceinline__ std::enable_if_t contains( - detail::cg::thread_block_tile const& g, + cooperative_groups::thread_block_tile const& g, ProbeKey const& k, KeyEqual key_equal) noexcept {