From 8eaba848d22d96577b34f3b0d07ee9328633546b Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 14 Feb 2023 14:34:25 -0800 Subject: [PATCH] use pw specializations in rbc --- cpp/CMakeLists.txt | 2 +- .../ball_cover_all_knn_query.cu | 4 +- .../specializations/ball_cover_build_index.cu | 4 +- .../specializations/ball_cover_knn_query.cu | 4 +- .../nn/specializations/knn_long_float_int.cu | 45 ------------------ .../nn/specializations/knn_long_float_uint.cu | 44 ------------------ .../nn/specializations/knn_uint_float_int.cu | 44 ------------------ .../nn/specializations/knn_uint_float_uint.cu | 46 ------------------- 8 files changed, 4 insertions(+), 189 deletions(-) delete mode 100644 cpp/src/nn/specializations/knn_long_float_int.cu delete mode 100644 cpp/src/nn/specializations/knn_long_float_uint.cu delete mode 100644 cpp/src/nn/specializations/knn_uint_float_int.cu delete mode 100644 cpp/src/nn/specializations/knn_uint_float_uint.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 41f196dbad..f7ae76b5e4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -455,7 +455,7 @@ if(RAFT_COMPILE_NN_LIBRARY) target_link_libraries( raft_nn_lib - PUBLIC raft::raft + PUBLIC raft::raft raft::raft_distance_lib PRIVATE nvidia::cutlass::cutlass ) target_compile_options( diff --git a/cpp/src/nn/specializations/ball_cover_all_knn_query.cu b/cpp/src/nn/specializations/ball_cover_all_knn_query.cu index da5cd8de4f..184e18e2ba 100644 --- a/cpp/src/nn/specializations/ball_cover_all_knn_query.cu +++ b/cpp/src/nn/specializations/ball_cover_all_knn_query.cu @@ -18,10 +18,8 @@ #include // Ignore upstream specializations to avoid unnecessary recompiling -#ifdef RAFT_DISTANCE_COMPILED +static_assert(RAFT_DISTANCE_COMPILED, "Requires distance specializations"); #include -#endif - #include #include #include diff --git a/cpp/src/nn/specializations/ball_cover_build_index.cu b/cpp/src/nn/specializations/ball_cover_build_index.cu index 70fcbec356..05b3beec73 100644 --- a/cpp/src/nn/specializations/ball_cover_build_index.cu +++ b/cpp/src/nn/specializations/ball_cover_build_index.cu @@ -18,10 +18,8 @@ #include // Ignore upstream specializations to avoid unnecessary recompiling -#ifdef RAFT_DISTANCE_COMPILED +static_assert(RAFT_DISTANCE_COMPILED, "Requires distance specializations"); #include -#endif - #include #include #include diff --git a/cpp/src/nn/specializations/ball_cover_knn_query.cu b/cpp/src/nn/specializations/ball_cover_knn_query.cu index d5ca1cbc1c..a11f6ba2d2 100644 --- a/cpp/src/nn/specializations/ball_cover_knn_query.cu +++ b/cpp/src/nn/specializations/ball_cover_knn_query.cu @@ -18,10 +18,8 @@ #include // Ignore upstream specializations to avoid unnecessary recompiling -#ifdef RAFT_DISTANCE_COMPILED +static_assert(RAFT_DISTANCE_COMPILED, "Requires distance specializations"); #include -#endif - #include #include #include diff --git a/cpp/src/nn/specializations/knn_long_float_int.cu b/cpp/src/nn/specializations/knn_long_float_int.cu deleted file mode 100644 index 1360430132..0000000000 --- a/cpp/src/nn/specializations/knn_long_float_int.cu +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -// include specializations for pairwise_distance and fusedl2knn -// to avoid recompiling again here -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - long* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/knn_long_float_uint.cu b/cpp/src/nn/specializations/knn_long_float_uint.cu deleted file mode 100644 index a84a9e9456..0000000000 --- a/cpp/src/nn/specializations/knn_long_float_uint.cu +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -// include specializations for pairwise_distance and fusedl2knn -// to avoid recompiling again here -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - long* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/knn_uint_float_int.cu b/cpp/src/nn/specializations/knn_uint_float_int.cu deleted file mode 100644 index da8bf0eeec..0000000000 --- a/cpp/src/nn/specializations/knn_uint_float_int.cu +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -// include specializations for pairwise_distance and fusedl2knn -// to avoid recompiling again here -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - uint32_t* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/knn_uint_float_uint.cu b/cpp/src/nn/specializations/knn_uint_float_uint.cu deleted file mode 100644 index b2a482a868..0000000000 --- a/cpp/src/nn/specializations/knn_uint_float_uint.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -// include specializations for pairwise_distance and fusedl2knn -// to avoid recompiling again here -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - uint32_t* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft