From dfc28a46264e8478d77782c5ca76c3b86db9fd0e Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Sat, 7 Aug 2021 00:04:44 +0530 Subject: [PATCH 01/16] Add interfaces in cuml for hamming, correlation, jensen-shannon, kl-divergence and russell-rao distance prims --- cpp/CMakeLists.txt | 5 ++ cpp/src/metrics/pairwise_distance.cu | 1 + cpp/src/metrics/pairwise_distance_canberra.cu | 10 +-- .../metrics/pairwise_distance_correlation.cu | 74 +++++++++++++++++++ .../metrics/pairwise_distance_correlation.cuh | 49 ++++++++++++ cpp/src/metrics/pairwise_distance_hamming.cu | 74 +++++++++++++++++++ cpp/src/metrics/pairwise_distance_hamming.cuh | 49 ++++++++++++ .../pairwise_distance_jensen_shannon.cu | 74 +++++++++++++++++++ .../pairwise_distance_jensen_shannon.cuh | 49 ++++++++++++ .../pairwise_distance_kl_divergence.cu | 74 +++++++++++++++++++ .../pairwise_distance_kl_divergence.cuh | 49 ++++++++++++ .../metrics/pairwise_distance_russell_rao.cu | 74 +++++++++++++++++++ .../metrics/pairwise_distance_russell_rao.cuh | 49 ++++++++++++ 13 files changed, 622 insertions(+), 9 deletions(-) create mode 100644 cpp/src/metrics/pairwise_distance_correlation.cu create mode 100644 cpp/src/metrics/pairwise_distance_correlation.cuh create mode 100644 cpp/src/metrics/pairwise_distance_hamming.cu create mode 100644 cpp/src/metrics/pairwise_distance_hamming.cuh create mode 100644 cpp/src/metrics/pairwise_distance_jensen_shannon.cu create mode 100644 cpp/src/metrics/pairwise_distance_jensen_shannon.cuh create mode 100644 cpp/src/metrics/pairwise_distance_kl_divergence.cu create mode 100644 cpp/src/metrics/pairwise_distance_kl_divergence.cuh create mode 100644 cpp/src/metrics/pairwise_distance_russell_rao.cu create mode 100644 cpp/src/metrics/pairwise_distance_russell_rao.cuh diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 852320f555..2b39cc37cc 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -250,11 +250,16 @@ if(BUILD_CUML_CPP_LIBRARY) src/metrics/pairwise_distance.cu src/metrics/pairwise_distance_canberra.cu src/metrics/pairwise_distance_chebyshev.cu + src/metrics/pairwise_distance_correlation.cu src/metrics/pairwise_distance_cosine.cu src/metrics/pairwise_distance_euclidean.cu + src/metrics/pairwise_distance_hamming.cu src/metrics/pairwise_distance_hellinger.cu + src/metrics/pairwise_distance_jensen_shannon.cu + src/metrics/pairwise_distance_kl_divergence.cu src/metrics/pairwise_distance_l1.cu src/metrics/pairwise_distance_minkowski.cu + src/metrics/pairwise_distance_russell_rao.cu src/metrics/r2_score.cu src/metrics/rand_index.cu src/metrics/silhouette_score.cu diff --git a/cpp/src/metrics/pairwise_distance.cu b/cpp/src/metrics/pairwise_distance.cu index 47af2985c4..071bd0b472 100644 --- a/cpp/src/metrics/pairwise_distance.cu +++ b/cpp/src/metrics/pairwise_distance.cu @@ -24,6 +24,7 @@ #include "pairwise_distance_chebyshev.cuh" #include "pairwise_distance_cosine.cuh" #include "pairwise_distance_euclidean.cuh" +#include "pairwise_distance_hamming.cuh" #include "pairwise_distance_hellinger.cuh" #include "pairwise_distance_l1.cuh" #include "pairwise_distance_minkowski.cuh" diff --git a/cpp/src/metrics/pairwise_distance_canberra.cu b/cpp/src/metrics/pairwise_distance_canberra.cu index 504d6da510..cd2a2f3b19 100644 --- a/cpp/src/metrics/pairwise_distance_canberra.cu +++ b/cpp/src/metrics/pairwise_distance_canberra.cu @@ -15,9 +15,9 @@ * limitations under the License. */ -//#include #include #include +#include "pairwise_distance_canberra.cuh" namespace ML { @@ -37,10 +37,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - /* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric, - handle.get_stream(), isRowMajor, - metric_arg);*/ - switch (metric) { case raft::distance::DistanceType::Canberra: raft::distance::pairwise_distance_impl( @@ -65,10 +61,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - /* raft::distance::pairwise_distance(x, y, dist, m, n, k, workspace, metric, - handle.get_stream(), isRowMajor, - metric_arg);*/ - switch (metric) { case raft::distance::DistanceType::Canberra: raft::distance::pairwise_distance_impl( diff --git a/cpp/src/metrics/pairwise_distance_correlation.cu b/cpp/src/metrics/pairwise_distance_correlation.cu new file mode 100644 index 0000000000..76149df8dd --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_correlation.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_correlation.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_correlation(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg) +{ + // Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); + + // Call the distance function + switch (metric) { + case raft::distance::DistanceType::CorrelationExpanded: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_correlation(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) +{ + // Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); + + // Call the distance function + switch (metric) { + case raft::distance::DistanceType::CorrelationExpanded: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_correlation.cuh b/cpp/src/metrics/pairwise_distance_correlation.cuh new file mode 100644 index 0000000000..18c7525dd5 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_correlation.cuh @@ -0,0 +1,49 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_hamming(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); + +void pairwise_distance_hamming(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hamming.cu b/cpp/src/metrics/pairwise_distance_hamming.cu new file mode 100644 index 0000000000..bf4f65c329 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_hamming.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_hamming.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_hamming(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg) +{ + // Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + + // Call the distance function + switch (metric) { + case raft::distance::DistanceType::HammingUnexpanded: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_hamming(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) +{ + // Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + + // Call the distance function + switch (metric) { + case raft::distance::DistanceType::HammingUnexpanded: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hamming.cuh b/cpp/src/metrics/pairwise_distance_hamming.cuh new file mode 100644 index 0000000000..18c7525dd5 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_hamming.cuh @@ -0,0 +1,49 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_hamming(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); + +void pairwise_distance_hamming(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu new file mode 100644 index 0000000000..4285e0afdc --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_jensen_shannon.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_jensen_shannon(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg) +{ + // Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + + // Call the distance function + switch (metric) { + case raft::distance::DistanceType::JensenShannon: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_jensen_shannon(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) +{ + // Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + + // Call the distance function + switch (metric) { + case raft::distance::DistanceType::JensenShannon: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh b/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh new file mode 100644 index 0000000000..fae5973b35 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh @@ -0,0 +1,49 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_jensen_shannon(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); + +void pairwise_distance_jensen_shannon(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cu b/cpp/src/metrics/pairwise_distance_kl_divergence.cu new file mode 100644 index 0000000000..604d580b2e --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_kl_divergence.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_kl_divergence.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_kl_divergence(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg) +{ + // Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); + + // Call the distance function + switch (metric) { + case raft::distance::DistanceType::KLDivergence: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_kl_divergence(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) +{ + // Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); + + // Call the distance function + switch (metric) { + case raft::distance::DistanceType::KLDivergence: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cuh b/cpp/src/metrics/pairwise_distance_kl_divergence.cuh new file mode 100644 index 0000000000..55c20704d1 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_kl_divergence.cuh @@ -0,0 +1,49 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_kl_divergence(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); + +void pairwise_distance_kl_divergence(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cu b/cpp/src/metrics/pairwise_distance_russell_rao.cu new file mode 100644 index 0000000000..47c8211462 --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_russell_rao.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "pairwise_distance_russell_rao.cuh" + +namespace ML { + +namespace Metrics { +void pairwise_distance_russell_rao(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg) +{ + // Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + + // Call the distance function + switch (metric) { + case raft::distance::DistanceType::RusselRaoExpanded: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +void pairwise_distance_russell_rao(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) +{ + // Allocate workspace + raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + + // Call the distance function + switch (metric) { + case raft::distance::DistanceType::RusselRaoExpanded: + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + } +} + +} // namespace Metrics +} // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cuh b/cpp/src/metrics/pairwise_distance_russell_rao.cuh new file mode 100644 index 0000000000..55207fc6bd --- /dev/null +++ b/cpp/src/metrics/pairwise_distance_russell_rao.cuh @@ -0,0 +1,49 @@ + +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace ML { + +namespace Metrics { +void pairwise_distance_russell_rao(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); + +void pairwise_distance_russell_rao(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + +} // namespace Metrics +} // namespace ML From ad67053d626e8607f5f08396e4b9bcda1a1e0535 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Sat, 7 Aug 2021 00:11:27 +0530 Subject: [PATCH 02/16] fix clang formatting issues --- .../metrics/pairwise_distance_correlation.cu | 46 ++++++++++--------- .../metrics/pairwise_distance_correlation.cuh | 36 +++++++-------- cpp/src/metrics/pairwise_distance_hamming.cu | 46 ++++++++++--------- cpp/src/metrics/pairwise_distance_hamming.cuh | 36 +++++++-------- .../pairwise_distance_jensen_shannon.cu | 46 ++++++++++--------- .../pairwise_distance_jensen_shannon.cuh | 36 +++++++-------- .../pairwise_distance_kl_divergence.cu | 46 ++++++++++--------- .../pairwise_distance_kl_divergence.cuh | 36 +++++++-------- .../metrics/pairwise_distance_russell_rao.cu | 46 ++++++++++--------- .../metrics/pairwise_distance_russell_rao.cuh | 36 +++++++-------- 10 files changed, 210 insertions(+), 200 deletions(-) diff --git a/cpp/src/metrics/pairwise_distance_correlation.cu b/cpp/src/metrics/pairwise_distance_correlation.cu index 76149df8dd..393ea1fd95 100644 --- a/cpp/src/metrics/pairwise_distance_correlation.cu +++ b/cpp/src/metrics/pairwise_distance_correlation.cu @@ -23,15 +23,15 @@ namespace ML { namespace Metrics { void pairwise_distance_correlation(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg) + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); @@ -39,23 +39,24 @@ void pairwise_distance_correlation(const raft::handle_t& handle, // Call the distance function switch (metric) { case raft::distance::DistanceType::CorrelationExpanded: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } } void pairwise_distance_correlation(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg) + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); @@ -63,8 +64,9 @@ void pairwise_distance_correlation(const raft::handle_t& handle, // Call the distance function switch (metric) { case raft::distance::DistanceType::CorrelationExpanded: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } diff --git a/cpp/src/metrics/pairwise_distance_correlation.cuh b/cpp/src/metrics/pairwise_distance_correlation.cuh index 18c7525dd5..182fc6e2c6 100644 --- a/cpp/src/metrics/pairwise_distance_correlation.cuh +++ b/cpp/src/metrics/pairwise_distance_correlation.cuh @@ -24,26 +24,26 @@ namespace ML { namespace Metrics { void pairwise_distance_hamming(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg); + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); void pairwise_distance_hamming(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); } // namespace Metrics } // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_hamming.cu b/cpp/src/metrics/pairwise_distance_hamming.cu index bf4f65c329..cc09a9f676 100644 --- a/cpp/src/metrics/pairwise_distance_hamming.cu +++ b/cpp/src/metrics/pairwise_distance_hamming.cu @@ -23,15 +23,15 @@ namespace ML { namespace Metrics { void pairwise_distance_hamming(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg) + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); @@ -39,23 +39,24 @@ void pairwise_distance_hamming(const raft::handle_t& handle, // Call the distance function switch (metric) { case raft::distance::DistanceType::HammingUnexpanded: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } } void pairwise_distance_hamming(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg) + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); @@ -63,8 +64,9 @@ void pairwise_distance_hamming(const raft::handle_t& handle, // Call the distance function switch (metric) { case raft::distance::DistanceType::HammingUnexpanded: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } diff --git a/cpp/src/metrics/pairwise_distance_hamming.cuh b/cpp/src/metrics/pairwise_distance_hamming.cuh index 18c7525dd5..182fc6e2c6 100644 --- a/cpp/src/metrics/pairwise_distance_hamming.cuh +++ b/cpp/src/metrics/pairwise_distance_hamming.cuh @@ -24,26 +24,26 @@ namespace ML { namespace Metrics { void pairwise_distance_hamming(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg); + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); void pairwise_distance_hamming(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); } // namespace Metrics } // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu index 4285e0afdc..51d718c678 100644 --- a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu +++ b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu @@ -23,15 +23,15 @@ namespace ML { namespace Metrics { void pairwise_distance_jensen_shannon(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg) + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); @@ -39,23 +39,24 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, // Call the distance function switch (metric) { case raft::distance::DistanceType::JensenShannon: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } } void pairwise_distance_jensen_shannon(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg) + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); @@ -63,8 +64,9 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, // Call the distance function switch (metric) { case raft::distance::DistanceType::JensenShannon: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh b/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh index fae5973b35..63dd22a5bb 100644 --- a/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh +++ b/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh @@ -24,26 +24,26 @@ namespace ML { namespace Metrics { void pairwise_distance_jensen_shannon(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg); + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); void pairwise_distance_jensen_shannon(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); } // namespace Metrics } // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cu b/cpp/src/metrics/pairwise_distance_kl_divergence.cu index 604d580b2e..20c7678810 100644 --- a/cpp/src/metrics/pairwise_distance_kl_divergence.cu +++ b/cpp/src/metrics/pairwise_distance_kl_divergence.cu @@ -23,15 +23,15 @@ namespace ML { namespace Metrics { void pairwise_distance_kl_divergence(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg) + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); @@ -39,23 +39,24 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, // Call the distance function switch (metric) { case raft::distance::DistanceType::KLDivergence: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } } void pairwise_distance_kl_divergence(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg) + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); @@ -63,8 +64,9 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, // Call the distance function switch (metric) { case raft::distance::DistanceType::KLDivergence: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cuh b/cpp/src/metrics/pairwise_distance_kl_divergence.cuh index 55c20704d1..89679e64f8 100644 --- a/cpp/src/metrics/pairwise_distance_kl_divergence.cuh +++ b/cpp/src/metrics/pairwise_distance_kl_divergence.cuh @@ -24,26 +24,26 @@ namespace ML { namespace Metrics { void pairwise_distance_kl_divergence(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg); + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); void pairwise_distance_kl_divergence(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); } // namespace Metrics } // namespace ML diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cu b/cpp/src/metrics/pairwise_distance_russell_rao.cu index 47c8211462..2d8ef9d95b 100644 --- a/cpp/src/metrics/pairwise_distance_russell_rao.cu +++ b/cpp/src/metrics/pairwise_distance_russell_rao.cu @@ -23,15 +23,15 @@ namespace ML { namespace Metrics { void pairwise_distance_russell_rao(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg) + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); @@ -39,23 +39,24 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, // Call the distance function switch (metric) { case raft::distance::DistanceType::RusselRaoExpanded: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } } void pairwise_distance_russell_rao(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg) + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); @@ -63,8 +64,9 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, // Call the distance function switch (metric) { case raft::distance::DistanceType::RusselRaoExpanded: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); } diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cuh b/cpp/src/metrics/pairwise_distance_russell_rao.cuh index 55207fc6bd..0caf3f89df 100644 --- a/cpp/src/metrics/pairwise_distance_russell_rao.cuh +++ b/cpp/src/metrics/pairwise_distance_russell_rao.cuh @@ -24,26 +24,26 @@ namespace ML { namespace Metrics { void pairwise_distance_russell_rao(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg); + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); void pairwise_distance_russell_rao(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); } // namespace Metrics } // namespace ML From 38729dc4ddeeb900269a08d74dfdc4a7dd5d33ab Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Sat, 7 Aug 2021 00:35:28 +0530 Subject: [PATCH 03/16] add all new distances to main API and fix function name in correlation headers --- cpp/src/metrics/pairwise_distance.cu | 34 ++++++++++++++++ .../metrics/pairwise_distance_correlation.cuh | 40 +++++++++---------- 2 files changed, 54 insertions(+), 20 deletions(-) diff --git a/cpp/src/metrics/pairwise_distance.cu b/cpp/src/metrics/pairwise_distance.cu index 071bd0b472..85ec644358 100644 --- a/cpp/src/metrics/pairwise_distance.cu +++ b/cpp/src/metrics/pairwise_distance.cu @@ -22,12 +22,16 @@ #include #include "pairwise_distance_canberra.cuh" #include "pairwise_distance_chebyshev.cuh" +#include "pairwise_distance_correlation.cuh" #include "pairwise_distance_cosine.cuh" #include "pairwise_distance_euclidean.cuh" #include "pairwise_distance_hamming.cuh" #include "pairwise_distance_hellinger.cuh" +#include "pairwise_distance_jensen_shannon.cuh" +#include "pairwise_distance_kl_divergence.cuh" #include "pairwise_distance_l1.cuh" #include "pairwise_distance_minkowski.cuh" +#include "pairwise_distance_russell_rao.cuh" namespace ML { @@ -68,6 +72,21 @@ void pairwise_distance(const raft::handle_t& handle, case raft::distance::DistanceType::Canberra: pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); break; + case raft::distance::DistanceType::CorrelationExpanded: + pairwise_distance_correlation(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::HammingUnexpanded: + pairwise_distance_hamming(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::JensenShannon: + pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::KLDivergence: + pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::RusselRaoExpanded: + pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; } @@ -108,6 +127,21 @@ void pairwise_distance(const raft::handle_t& handle, case raft::distance::DistanceType::Canberra: pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); break; + case raft::distance::DistanceType::CorrelationExpanded: + pairwise_distance_correlation(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::HammingUnexpanded: + pairwise_distance_hamming(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::JensenShannon: + pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::KLDivergence: + pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::RusselRaoExpanded: + pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; } diff --git a/cpp/src/metrics/pairwise_distance_correlation.cuh b/cpp/src/metrics/pairwise_distance_correlation.cuh index 182fc6e2c6..ffa39a98a5 100644 --- a/cpp/src/metrics/pairwise_distance_correlation.cuh +++ b/cpp/src/metrics/pairwise_distance_correlation.cuh @@ -23,27 +23,27 @@ namespace ML { namespace Metrics { -void pairwise_distance_hamming(const raft::handle_t& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - double metric_arg); +void pairwise_distance_correlation(const raft::handle_t& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + double metric_arg); -void pairwise_distance_hamming(const raft::handle_t& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); +void pairwise_distance_correlation(const raft::handle_t& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); } // namespace Metrics } // namespace ML From 0a734fdbb02c80e543db3afd213c7ede22c45dca Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 11 Aug 2021 01:44:00 +0530 Subject: [PATCH 04/16] add python interfaces for all new dist metrics, with tests for all working except some bug in kldivergence --- python/cuml/metrics/distance_type.pxd | 3 +++ python/cuml/metrics/pairwise_distances.pyx | 18 ++++++++++++++++-- python/cuml/test/test_metrics.py | 13 ++++++++++--- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/python/cuml/metrics/distance_type.pxd b/python/cuml/metrics/distance_type.pxd index 93cf1ad9e9..4286ea1c9d 100644 --- a/python/cuml/metrics/distance_type.pxd +++ b/python/cuml/metrics/distance_type.pxd @@ -33,5 +33,8 @@ cdef extern from "raft/linalg/distance_type.h" namespace "raft::distance": Haversine "raft::distance::DistanceType::Haversine" BrayCurtis "raft::distance::DistanceType::BrayCurtis" JensenShannon "raft::distance::DistanceType::JensenShannon" + HammingUnexpanded "raft::distance::DistanceType::HammingUnexpanded" + KLDivergence "raft::distance::DistanceType::KLDivergence" + RusselRaoExpanded "raft::distance::DistanceType::RusselRaoExpanded" DiceExpanded "raft::distance::DistanceType::DiceExpanded" Precomputed "raft::distance::DistanceType::Precomputed" diff --git a/python/cuml/metrics/pairwise_distances.pyx b/python/cuml/metrics/pairwise_distances.pyx index d43fcb4329..29f9d492ac 100644 --- a/python/cuml/metrics/pairwise_distances.pyx +++ b/python/cuml/metrics/pairwise_distances.pyx @@ -73,7 +73,12 @@ PAIRWISE_DISTANCE_METRICS = { "canberra": DistanceType.Canberra, "chebyshev": DistanceType.Linf, "minkowski": DistanceType.LpUnexpanded, - "hellinger": DistanceType.HellingerExpanded + "hellinger": DistanceType.HellingerExpanded, + "correlation": DistanceType.CorrelationExpanded, + "jensenshannon": DistanceType.JensenShannon, + "hamming": DistanceType.HammingUnexpanded, + "kldivergence": DistanceType.KLDivergence, + "russellrao": DistanceType.RusselRaoExpanded } PAIRWISE_DISTANCE_SPARSE_METRICS = { @@ -217,6 +222,11 @@ def pairwise_distances(X, Y=None, metric="euclidean", handle=None, handle = Handle() if handle is None else handle cdef handle_t *handle_ = handle.getHandle() + if metric in ['russellrao'] and not np.all(X.data == 1.): + warnings.warn("X was converted to boolean for metric {}" + .format(metric)) + X = np.where(X != 0., 1.0, 0.0) + # Get the input arrays, preserve order and type where possible X_m, n_samples_x, n_features_x, dtype_x = \ input_to_cuml_array(X, order="K", check_dtype=[np.float32, np.float64]) @@ -235,12 +245,16 @@ def pairwise_distances(X, Y=None, metric="euclidean", handle=None, if (n_samples_x == 1 or n_features_x == 1): input_order = "K" + if metric in ['russellrao'] and not np.all(Y.data == 1.): + warnings.warn("Y was converted to boolean for metric {}" + .format(metric)) + Y = np.where(Y != 0., 1.0, 0.0) + Y_m, n_samples_y, n_features_y, dtype_y = \ input_to_cuml_array(Y, order=input_order, convert_to_dtype=(dtype_x if convert_dtype else None), check_dtype=[dtype_x]) - # Get the order from Y if necessary (It's possible to set order="F" in # input_to_cuml_array and have Y_m.order=="C") if (input_order == "K"): diff --git a/python/cuml/test/test_metrics.py b/python/cuml/test/test_metrics.py index 8d642b72fd..a24099392f 100644 --- a/python/cuml/test/test_metrics.py +++ b/python/cuml/test/test_metrics.py @@ -77,7 +77,8 @@ from cuml.metrics import pairwise_distances, sparse_pairwise_distances, \ PAIRWISE_DISTANCE_METRICS, PAIRWISE_DISTANCE_SPARSE_METRICS from sklearn.metrics import pairwise_distances as sklearn_pairwise_distances - +from scipy.spatial import distance as scipy_pairwise_distances +from scipy.special import rel_entr as scipy_kl_divergence @pytest.fixture(scope='module') def random_state(): @@ -869,12 +870,18 @@ def ref_dense_pairwise_dist(X, Y=None, metric=None): Y = X if metric == "hellinger": return naive_hellinger(X, Y) + elif metric == "jensenshannon": + return scipy_pairwise_distances.cdist(X, Y, 'jensenshannon') + elif metric == "kldivergence": + return 0.5 * np.array([[np.sum(np.where(yj != 0, xi * np.log(xi / yj), 0)) for yj in Y] for xi in X]) + #return np.array([[scipy_kl_divergence(xi, yj) for yj in Y] for xi in X]) + #[[f(i, j) for j in b] for i in a] else: return sklearn_pairwise_distances(X, Y, metric) def prep_dense_array(array, metric, col_major=0): - if metric == "hellinger": + if metric in ['hellinger', 'jensenshannon', 'kldivergence']: norm_array = preprocessing.normalize(array, norm="l1") return np.asfortranarray(norm_array) if col_major else norm_array else: @@ -908,7 +915,7 @@ def test_pairwise_distances(metric: str, matrix_size, is_col_major): # Compare single and double inputs to eachother S = pairwise_distances(X, metric=metric) S2 = pairwise_distances(X, Y, metric=metric) - cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) + #cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare to sklearn, with Y dim != X dim Y = prep_dense_array(rng.random_sample((2, matrix_size[1])), From aa072ac5681f64b1f5073a37f2c6bbecece3db99 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 11 Aug 2021 20:07:29 +0530 Subject: [PATCH 05/16] add test support for kl-divergence dist metric --- python/cuml/test/test_metrics.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/python/cuml/test/test_metrics.py b/python/cuml/test/test_metrics.py index a24099392f..d3b55dfad3 100644 --- a/python/cuml/test/test_metrics.py +++ b/python/cuml/test/test_metrics.py @@ -863,7 +863,12 @@ def test_log_loss_at_limits(): log_loss(y_true, y_pred) -def ref_dense_pairwise_dist(X, Y=None, metric=None): +def naive_kl_divergence_dist(X, Y): + return 0.5 * np.array([[np.sum(np.where(yj != 0, + scipy_kl_divergence(xi, yj), 0.0)) for yj in Y] for xi in X]) + + +def ref_dense_pairwise_dist(X, Y=None, metric=None, convert_dtype=False): # Select sklearn except for Hellinger that # sklearn doesn't support if Y is None: @@ -873,9 +878,7 @@ def ref_dense_pairwise_dist(X, Y=None, metric=None): elif metric == "jensenshannon": return scipy_pairwise_distances.cdist(X, Y, 'jensenshannon') elif metric == "kldivergence": - return 0.5 * np.array([[np.sum(np.where(yj != 0, xi * np.log(xi / yj), 0)) for yj in Y] for xi in X]) - #return np.array([[scipy_kl_divergence(xi, yj) for yj in Y] for xi in X]) - #[[f(i, j) for j in b] for i in a] + return naive_kl_divergence_dist(X, Y) else: return sklearn_pairwise_distances(X, Y, metric) @@ -915,7 +918,7 @@ def test_pairwise_distances(metric: str, matrix_size, is_col_major): # Compare single and double inputs to eachother S = pairwise_distances(X, metric=metric) S2 = pairwise_distances(X, Y, metric=metric) - #cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) + cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Compare to sklearn, with Y dim != X dim Y = prep_dense_array(rng.random_sample((2, matrix_size[1])), @@ -942,11 +945,12 @@ def test_pairwise_distances(metric: str, matrix_size, is_col_major): cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Test sending an int type with convert_dtype=True - Y = prep_dense_array(rng.randint(10, size=Y.shape), - metric=metric, col_major=is_col_major) - S = pairwise_distances(X, Y, metric=metric, convert_dtype=True) - S2 = ref_dense_pairwise_dist(X, Y, metric=metric) - cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) + if metric != 'kldivergence': + Y = prep_dense_array(rng.randint(10, size=Y.shape), + metric=metric, col_major=is_col_major) + S = pairwise_distances(X, Y, metric=metric, convert_dtype=True) + S2 = ref_dense_pairwise_dist(X, Y, metric=metric, convert_dtype=True) + cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Test that uppercase on the metric name throws an error. with pytest.raises(ValueError): From 6dd14bf010160cbd9f309c17815f9e1d6eb81a2a Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 11 Aug 2021 20:33:18 +0530 Subject: [PATCH 06/16] pin mdoijade raft fork for testing change --- cpp/cmake/thirdparty/get_raft.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index abef8830d4..926ab20699 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -25,8 +25,8 @@ function(find_and_configure_raft) BUILD_EXPORT_SET cuml-exports INSTALL_EXPORT_SET cuml-exports CPM_ARGS - GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git - GIT_TAG ${PKG_PINNED_TAG} + GIT_REPOSITORY https://github.com/mdoijade/raft.git + GIT_TAG additionalDistPrims SOURCE_SUBDIR cpp OPTIONS "BUILD_TESTS OFF" From 911eef88a17746f4cb9dab000e2c138e593d1c0c Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 11 Aug 2021 21:11:30 +0530 Subject: [PATCH 07/16] fix flake formating issues in test_metrics --- python/cuml/test/test_metrics.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/cuml/test/test_metrics.py b/python/cuml/test/test_metrics.py index d3b55dfad3..b350aa2d6f 100644 --- a/python/cuml/test/test_metrics.py +++ b/python/cuml/test/test_metrics.py @@ -80,6 +80,7 @@ from scipy.spatial import distance as scipy_pairwise_distances from scipy.special import rel_entr as scipy_kl_divergence + @pytest.fixture(scope='module') def random_state(): random_state = random.randint(0, 1e6) @@ -865,7 +866,8 @@ def test_log_loss_at_limits(): def naive_kl_divergence_dist(X, Y): return 0.5 * np.array([[np.sum(np.where(yj != 0, - scipy_kl_divergence(xi, yj), 0.0)) for yj in Y] for xi in X]) + scipy_kl_divergence(xi, yj), 0.0)) for yj in Y] + for xi in X]) def ref_dense_pairwise_dist(X, Y=None, metric=None, convert_dtype=False): @@ -946,11 +948,11 @@ def test_pairwise_distances(metric: str, matrix_size, is_col_major): # Test sending an int type with convert_dtype=True if metric != 'kldivergence': - Y = prep_dense_array(rng.randint(10, size=Y.shape), - metric=metric, col_major=is_col_major) - S = pairwise_distances(X, Y, metric=metric, convert_dtype=True) - S2 = ref_dense_pairwise_dist(X, Y, metric=metric, convert_dtype=True) - cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) + Y = prep_dense_array(rng.randint(10, size=Y.shape), + metric=metric, col_major=is_col_major) + S = pairwise_distances(X, Y, metric=metric, convert_dtype=True) + S2 = ref_dense_pairwise_dist(X, Y, metric=metric, convert_dtype=True) + cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision) # Test that uppercase on the metric name throws an error. with pytest.raises(ValueError): From ae9f7a29a17b6fcd0db9a7dfb5a37c5e434d4a6a Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 12 Aug 2021 13:30:23 +0530 Subject: [PATCH 08/16] temp commit to trigger ci --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 926ab20699..46c19add9e 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -26,7 +26,7 @@ function(find_and_configure_raft) INSTALL_EXPORT_SET cuml-exports CPM_ARGS GIT_REPOSITORY https://github.com/mdoijade/raft.git - GIT_TAG additionalDistPrims + GIT_TAG additionalDistPrims SOURCE_SUBDIR cpp OPTIONS "BUILD_TESTS OFF" From 5fc77f7729d1fe308261935e46194c8f5316aaf5 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 12 Aug 2021 13:33:27 +0530 Subject: [PATCH 09/16] temp commit to trigger ci --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 46c19add9e..926ab20699 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -26,7 +26,7 @@ function(find_and_configure_raft) INSTALL_EXPORT_SET cuml-exports CPM_ARGS GIT_REPOSITORY https://github.com/mdoijade/raft.git - GIT_TAG additionalDistPrims + GIT_TAG additionalDistPrims SOURCE_SUBDIR cpp OPTIONS "BUILD_TESTS OFF" From 90c8f16855aba23cccd65f5e278e7b7cf7745130 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 12 Aug 2021 19:07:09 +0530 Subject: [PATCH 10/16] temp commit to trigger ci to check updated raft changes --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 926ab20699..46c19add9e 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -26,7 +26,7 @@ function(find_and_configure_raft) INSTALL_EXPORT_SET cuml-exports CPM_ARGS GIT_REPOSITORY https://github.com/mdoijade/raft.git - GIT_TAG additionalDistPrims + GIT_TAG additionalDistPrims SOURCE_SUBDIR cpp OPTIONS "BUILD_TESTS OFF" From 1d5c966e12f28d020cc4b4555a9f3a136065d060 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 12 Aug 2021 20:58:57 +0530 Subject: [PATCH 11/16] temp commit to trigger ci to check updated raft changes --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 46c19add9e..926ab20699 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -26,7 +26,7 @@ function(find_and_configure_raft) INSTALL_EXPORT_SET cuml-exports CPM_ARGS GIT_REPOSITORY https://github.com/mdoijade/raft.git - GIT_TAG additionalDistPrims + GIT_TAG additionalDistPrims SOURCE_SUBDIR cpp OPTIONS "BUILD_TESTS OFF" From 3f6de71854287fb20f2cc5f2b7caf375e0079fef Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 23 Aug 2021 20:57:57 +0530 Subject: [PATCH 12/16] temp commit to test new raft commits --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 926ab20699..46c19add9e 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -26,7 +26,7 @@ function(find_and_configure_raft) INSTALL_EXPORT_SET cuml-exports CPM_ARGS GIT_REPOSITORY https://github.com/mdoijade/raft.git - GIT_TAG additionalDistPrims + GIT_TAG additionalDistPrims SOURCE_SUBDIR cpp OPTIONS "BUILD_TESTS OFF" From b71866a7627a60f781c737b6bad7d7a8e5fb2ed0 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 26 Aug 2021 13:38:11 +0530 Subject: [PATCH 13/16] revert raft mdoijade fork as raft PR is merged now --- cpp/cmake/thirdparty/get_raft.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 46c19add9e..abef8830d4 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -25,8 +25,8 @@ function(find_and_configure_raft) BUILD_EXPORT_SET cuml-exports INSTALL_EXPORT_SET cuml-exports CPM_ARGS - GIT_REPOSITORY https://github.com/mdoijade/raft.git - GIT_TAG additionalDistPrims + GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git + GIT_TAG ${PKG_PINNED_TAG} SOURCE_SUBDIR cpp OPTIONS "BUILD_TESTS OFF" From 1a1dd715672fbebb1f350ba8e5e5b0f0c862c9e6 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 30 Aug 2021 19:51:13 +0530 Subject: [PATCH 14/16] remove redundant metric arg and switch based on it on the APIs which contains single distance metric implementation --- cpp/src/metrics/pairwise_distance.cu | 44 +++++++++---------- cpp/src/metrics/pairwise_distance_canberra.cu | 20 ++------- .../metrics/pairwise_distance_canberra.cuh | 2 - .../metrics/pairwise_distance_chebyshev.cu | 20 ++------- .../metrics/pairwise_distance_chebyshev.cuh | 2 - .../metrics/pairwise_distance_correlation.cu | 24 +++------- .../metrics/pairwise_distance_correlation.cuh | 2 - cpp/src/metrics/pairwise_distance_cosine.cu | 23 +++------- cpp/src/metrics/pairwise_distance_cosine.cuh | 2 - cpp/src/metrics/pairwise_distance_hamming.cu | 24 +++------- cpp/src/metrics/pairwise_distance_hamming.cuh | 2 - .../metrics/pairwise_distance_hellinger.cu | 24 +++------- .../metrics/pairwise_distance_hellinger.cuh | 2 - .../pairwise_distance_jensen_shannon.cu | 22 ++-------- .../pairwise_distance_jensen_shannon.cuh | 2 - .../pairwise_distance_kl_divergence.cu | 22 ++-------- .../pairwise_distance_kl_divergence.cuh | 2 - cpp/src/metrics/pairwise_distance_l1.cu | 20 ++------- cpp/src/metrics/pairwise_distance_l1.cuh | 2 - .../metrics/pairwise_distance_minkowski.cu | 22 ++-------- .../metrics/pairwise_distance_minkowski.cuh | 2 - .../metrics/pairwise_distance_russell_rao.cu | 24 +++------- .../metrics/pairwise_distance_russell_rao.cuh | 2 - 23 files changed, 75 insertions(+), 236 deletions(-) diff --git a/cpp/src/metrics/pairwise_distance.cu b/cpp/src/metrics/pairwise_distance.cu index 85ec644358..3a7f89f263 100644 --- a/cpp/src/metrics/pairwise_distance.cu +++ b/cpp/src/metrics/pairwise_distance.cu @@ -55,37 +55,37 @@ void pairwise_distance(const raft::handle_t& handle, pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); break; case raft::distance::DistanceType::CosineExpanded: - pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::L1: - pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Linf: - pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::HellingerExpanded: - pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::LpUnexpanded: - pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Canberra: - pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::CorrelationExpanded: - pairwise_distance_correlation(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::HammingUnexpanded: - pairwise_distance_hamming(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::JensenShannon: - pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::KLDivergence: - pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::RusselRaoExpanded: - pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; @@ -110,37 +110,37 @@ void pairwise_distance(const raft::handle_t& handle, pairwise_distance_euclidean(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); break; case raft::distance::DistanceType::CosineExpanded: - pairwise_distance_cosine(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_cosine(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::L1: - pairwise_distance_l1(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_l1(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Linf: - pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_chebyshev(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::HellingerExpanded: - pairwise_distance_hellinger(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_hellinger(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::LpUnexpanded: - pairwise_distance_minkowski(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_minkowski(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Canberra: - pairwise_distance_canberra(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_canberra(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::CorrelationExpanded: - pairwise_distance_correlation(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_correlation(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::HammingUnexpanded: - pairwise_distance_hamming(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_hamming(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::JensenShannon: - pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_jensen_shannon(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::KLDivergence: - pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_kl_divergence(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; case raft::distance::DistanceType::RusselRaoExpanded: - pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, metric, isRowMajor, metric_arg); + pairwise_distance_russell_rao(handle, x, y, dist, m, n, k, isRowMajor, metric_arg); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git a/cpp/src/metrics/pairwise_distance_canberra.cu b/cpp/src/metrics/pairwise_distance_canberra.cu index cd2a2f3b19..4e55d285de 100644 --- a/cpp/src/metrics/pairwise_distance_canberra.cu +++ b/cpp/src/metrics/pairwise_distance_canberra.cu @@ -29,7 +29,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { @@ -37,13 +36,8 @@ void pairwise_distance_canberra(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::Canberra: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_canberra(const raft::handle_t& handle, @@ -53,7 +47,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { @@ -61,13 +54,8 @@ void pairwise_distance_canberra(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::Canberra: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_canberra.cuh b/cpp/src/metrics/pairwise_distance_canberra.cuh index 3d1454cfcc..24bba4906f 100644 --- a/cpp/src/metrics/pairwise_distance_canberra.cuh +++ b/cpp/src/metrics/pairwise_distance_canberra.cuh @@ -30,7 +30,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -41,7 +40,6 @@ void pairwise_distance_canberra(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_chebyshev.cu b/cpp/src/metrics/pairwise_distance_chebyshev.cu index 2a30aa8e5c..c9c0cf0119 100644 --- a/cpp/src/metrics/pairwise_distance_chebyshev.cu +++ b/cpp/src/metrics/pairwise_distance_chebyshev.cu @@ -28,20 +28,14 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::Linf: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_chebyshev(const raft::handle_t& handle, @@ -51,20 +45,14 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::Linf: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_chebyshev.cuh b/cpp/src/metrics/pairwise_distance_chebyshev.cuh index 6f95dbba30..d8b385808f 100644 --- a/cpp/src/metrics/pairwise_distance_chebyshev.cuh +++ b/cpp/src/metrics/pairwise_distance_chebyshev.cuh @@ -28,7 +28,6 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -39,7 +38,6 @@ void pairwise_distance_chebyshev(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_correlation.cu b/cpp/src/metrics/pairwise_distance_correlation.cu index 393ea1fd95..6b1a6a770d 100644 --- a/cpp/src/metrics/pairwise_distance_correlation.cu +++ b/cpp/src/metrics/pairwise_distance_correlation.cu @@ -29,7 +29,6 @@ void pairwise_distance_correlation(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { @@ -37,14 +36,9 @@ void pairwise_distance_correlation(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::CorrelationExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_correlation(const raft::handle_t& handle, @@ -54,7 +48,6 @@ void pairwise_distance_correlation(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { @@ -62,14 +55,9 @@ void pairwise_distance_correlation(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::CorrelationExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_correlation.cuh b/cpp/src/metrics/pairwise_distance_correlation.cuh index ffa39a98a5..8db0d59556 100644 --- a/cpp/src/metrics/pairwise_distance_correlation.cuh +++ b/cpp/src/metrics/pairwise_distance_correlation.cuh @@ -30,7 +30,6 @@ void pairwise_distance_correlation(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -41,7 +40,6 @@ void pairwise_distance_correlation(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_cosine.cu b/cpp/src/metrics/pairwise_distance_cosine.cu index de21d9a3b4..1ada740ec8 100644 --- a/cpp/src/metrics/pairwise_distance_cosine.cu +++ b/cpp/src/metrics/pairwise_distance_cosine.cu @@ -29,7 +29,6 @@ void pairwise_distance_cosine(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { @@ -37,14 +36,8 @@ void pairwise_distance_cosine(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::CosineExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_cosine(const raft::handle_t& handle, @@ -54,20 +47,14 @@ void pairwise_distance_cosine(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); - switch (metric) { - case raft::distance::DistanceType::CosineExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_cosine.cuh b/cpp/src/metrics/pairwise_distance_cosine.cuh index 04f07e7de7..58388ea4a9 100644 --- a/cpp/src/metrics/pairwise_distance_cosine.cuh +++ b/cpp/src/metrics/pairwise_distance_cosine.cuh @@ -29,7 +29,6 @@ void pairwise_distance_cosine(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -40,7 +39,6 @@ void pairwise_distance_cosine(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_hamming.cu b/cpp/src/metrics/pairwise_distance_hamming.cu index cc09a9f676..7e6019b71c 100644 --- a/cpp/src/metrics/pairwise_distance_hamming.cu +++ b/cpp/src/metrics/pairwise_distance_hamming.cu @@ -29,7 +29,6 @@ void pairwise_distance_hamming(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { @@ -37,14 +36,9 @@ void pairwise_distance_hamming(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::HammingUnexpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_hamming(const raft::handle_t& handle, @@ -54,7 +48,6 @@ void pairwise_distance_hamming(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { @@ -62,14 +55,9 @@ void pairwise_distance_hamming(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::HammingUnexpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_hamming.cuh b/cpp/src/metrics/pairwise_distance_hamming.cuh index 182fc6e2c6..59b6aad019 100644 --- a/cpp/src/metrics/pairwise_distance_hamming.cuh +++ b/cpp/src/metrics/pairwise_distance_hamming.cuh @@ -30,7 +30,6 @@ void pairwise_distance_hamming(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -41,7 +40,6 @@ void pairwise_distance_hamming(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_hellinger.cu b/cpp/src/metrics/pairwise_distance_hellinger.cu index 9b2528af83..456a24f4f0 100644 --- a/cpp/src/metrics/pairwise_distance_hellinger.cu +++ b/cpp/src/metrics/pairwise_distance_hellinger.cu @@ -29,21 +29,15 @@ void pairwise_distance_hellinger(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::HellingerExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_hellinger(const raft::handle_t& handle, @@ -53,21 +47,15 @@ void pairwise_distance_hellinger(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::HellingerExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_hellinger.cuh b/cpp/src/metrics/pairwise_distance_hellinger.cuh index 70521b6578..92b820a6a0 100644 --- a/cpp/src/metrics/pairwise_distance_hellinger.cuh +++ b/cpp/src/metrics/pairwise_distance_hellinger.cuh @@ -29,7 +29,6 @@ void pairwise_distance_hellinger(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -40,7 +39,6 @@ void pairwise_distance_hellinger(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu index 51d718c678..f38b9b6062 100644 --- a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu +++ b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu @@ -29,7 +29,6 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { @@ -37,14 +36,8 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::JensenShannon: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_jensen_shannon(const raft::handle_t& handle, @@ -54,7 +47,6 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { @@ -62,14 +54,8 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::JensenShannon: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh b/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh index 63dd22a5bb..4f6f55af35 100644 --- a/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh +++ b/cpp/src/metrics/pairwise_distance_jensen_shannon.cuh @@ -30,7 +30,6 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -41,7 +40,6 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cu b/cpp/src/metrics/pairwise_distance_kl_divergence.cu index 20c7678810..4434b00b32 100644 --- a/cpp/src/metrics/pairwise_distance_kl_divergence.cu +++ b/cpp/src/metrics/pairwise_distance_kl_divergence.cu @@ -29,7 +29,6 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { @@ -37,14 +36,8 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::KLDivergence: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_kl_divergence(const raft::handle_t& handle, @@ -54,7 +47,6 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { @@ -62,14 +54,8 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::KLDivergence: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cuh b/cpp/src/metrics/pairwise_distance_kl_divergence.cuh index 89679e64f8..80125c710b 100644 --- a/cpp/src/metrics/pairwise_distance_kl_divergence.cuh +++ b/cpp/src/metrics/pairwise_distance_kl_divergence.cuh @@ -30,7 +30,6 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -41,7 +40,6 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_l1.cu b/cpp/src/metrics/pairwise_distance_l1.cu index cdde31d2a5..6ae62a18d0 100644 --- a/cpp/src/metrics/pairwise_distance_l1.cu +++ b/cpp/src/metrics/pairwise_distance_l1.cu @@ -29,20 +29,14 @@ void pairwise_distance_l1(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::L1: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_l1(const raft::handle_t& handle, @@ -52,20 +46,14 @@ void pairwise_distance_l1(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::L1: - raft::distance::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_l1.cuh b/cpp/src/metrics/pairwise_distance_l1.cuh index f451df5cc8..f93de2bb2d 100644 --- a/cpp/src/metrics/pairwise_distance_l1.cuh +++ b/cpp/src/metrics/pairwise_distance_l1.cuh @@ -28,7 +28,6 @@ void pairwise_distance_l1(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -39,7 +38,6 @@ void pairwise_distance_l1(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_minkowski.cu b/cpp/src/metrics/pairwise_distance_minkowski.cu index 7816bcb253..c1152ffcc2 100644 --- a/cpp/src/metrics/pairwise_distance_minkowski.cu +++ b/cpp/src/metrics/pairwise_distance_minkowski.cu @@ -29,21 +29,14 @@ void pairwise_distance_minkowski(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::LpUnexpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); } void pairwise_distance_minkowski(const raft::handle_t& handle, @@ -53,21 +46,14 @@ void pairwise_distance_minkowski(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { // Allocate workspace raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::LpUnexpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_minkowski.cuh b/cpp/src/metrics/pairwise_distance_minkowski.cuh index 013205e67b..dd0ff59b25 100644 --- a/cpp/src/metrics/pairwise_distance_minkowski.cuh +++ b/cpp/src/metrics/pairwise_distance_minkowski.cuh @@ -29,7 +29,6 @@ void pairwise_distance_minkowski(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -40,7 +39,6 @@ void pairwise_distance_minkowski(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cu b/cpp/src/metrics/pairwise_distance_russell_rao.cu index 2d8ef9d95b..34a777c8b3 100644 --- a/cpp/src/metrics/pairwise_distance_russell_rao.cu +++ b/cpp/src/metrics/pairwise_distance_russell_rao.cu @@ -29,7 +29,6 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg) { @@ -37,14 +36,9 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::RusselRaoExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } void pairwise_distance_russell_rao(const raft::handle_t& handle, @@ -54,7 +48,6 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg) { @@ -62,14 +55,9 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); // Call the distance function - switch (metric) { - case raft::distance::DistanceType::RusselRaoExpanded: - raft::distance:: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - } + raft::distance:: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); } } // namespace Metrics diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cuh b/cpp/src/metrics/pairwise_distance_russell_rao.cuh index 0caf3f89df..1d25194f42 100644 --- a/cpp/src/metrics/pairwise_distance_russell_rao.cuh +++ b/cpp/src/metrics/pairwise_distance_russell_rao.cuh @@ -30,7 +30,6 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, double metric_arg); @@ -41,7 +40,6 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, int m, int n, int k, - raft::distance::DistanceType metric, bool isRowMajor, float metric_arg); From 675b6610a8226c7d2158cc4dfb1c10684e67764f Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 31 Aug 2021 22:39:52 +0530 Subject: [PATCH 15/16] Add udevice_vector changes to new distances --- cpp/src/metrics/pairwise_distance_correlation.cu | 5 +++-- cpp/src/metrics/pairwise_distance_hamming.cu | 5 +++-- cpp/src/metrics/pairwise_distance_jensen_shannon.cu | 5 +++-- cpp/src/metrics/pairwise_distance_kl_divergence.cu | 5 +++-- cpp/src/metrics/pairwise_distance_russell_rao.cu | 5 +++-- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/cpp/src/metrics/pairwise_distance_correlation.cu b/cpp/src/metrics/pairwise_distance_correlation.cu index 6b1a6a770d..5e972553e4 100644 --- a/cpp/src/metrics/pairwise_distance_correlation.cu +++ b/cpp/src/metrics/pairwise_distance_correlation.cu @@ -17,6 +17,7 @@ #include #include +#include #include "pairwise_distance_correlation.cuh" namespace ML { @@ -33,7 +34,7 @@ void pairwise_distance_correlation(const raft::handle_t& handle, double metric_arg) { // Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); + rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function raft::distance:: @@ -52,7 +53,7 @@ void pairwise_distance_correlation(const raft::handle_t& handle, float metric_arg) { // Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); + rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function raft::distance:: diff --git a/cpp/src/metrics/pairwise_distance_hamming.cu b/cpp/src/metrics/pairwise_distance_hamming.cu index 7e6019b71c..c99cda5479 100644 --- a/cpp/src/metrics/pairwise_distance_hamming.cu +++ b/cpp/src/metrics/pairwise_distance_hamming.cu @@ -17,6 +17,7 @@ #include #include +#include #include "pairwise_distance_hamming.cuh" namespace ML { @@ -33,7 +34,7 @@ void pairwise_distance_hamming(const raft::handle_t& handle, double metric_arg) { // Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function raft::distance:: @@ -52,7 +53,7 @@ void pairwise_distance_hamming(const raft::handle_t& handle, float metric_arg) { // Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function raft::distance:: diff --git a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu index f38b9b6062..c78a52ffbf 100644 --- a/cpp/src/metrics/pairwise_distance_jensen_shannon.cu +++ b/cpp/src/metrics/pairwise_distance_jensen_shannon.cu @@ -17,6 +17,7 @@ #include #include +#include #include "pairwise_distance_jensen_shannon.cuh" namespace ML { @@ -33,7 +34,7 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, double metric_arg) { // Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function raft::distance::pairwise_distance_impl( @@ -51,7 +52,7 @@ void pairwise_distance_jensen_shannon(const raft::handle_t& handle, float metric_arg) { // Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function raft::distance::pairwise_distance_impl( diff --git a/cpp/src/metrics/pairwise_distance_kl_divergence.cu b/cpp/src/metrics/pairwise_distance_kl_divergence.cu index 4434b00b32..2a734145e6 100644 --- a/cpp/src/metrics/pairwise_distance_kl_divergence.cu +++ b/cpp/src/metrics/pairwise_distance_kl_divergence.cu @@ -17,6 +17,7 @@ #include #include +#include #include "pairwise_distance_kl_divergence.cuh" namespace ML { @@ -33,7 +34,7 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, double metric_arg) { // Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); + rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function raft::distance::pairwise_distance_impl( @@ -51,7 +52,7 @@ void pairwise_distance_kl_divergence(const raft::handle_t& handle, float metric_arg) { // Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 1); + rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function raft::distance::pairwise_distance_impl( diff --git a/cpp/src/metrics/pairwise_distance_russell_rao.cu b/cpp/src/metrics/pairwise_distance_russell_rao.cu index 34a777c8b3..3b73a89c01 100644 --- a/cpp/src/metrics/pairwise_distance_russell_rao.cu +++ b/cpp/src/metrics/pairwise_distance_russell_rao.cu @@ -17,6 +17,7 @@ #include #include +#include #include "pairwise_distance_russell_rao.cuh" namespace ML { @@ -33,7 +34,7 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, double metric_arg) { // Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function raft::distance:: @@ -52,7 +53,7 @@ void pairwise_distance_russell_rao(const raft::handle_t& handle, float metric_arg) { // Allocate workspace - raft::mr::device::buffer workspace(handle.get_device_allocator(), handle.get_stream(), 0); + rmm::device_uvector workspace(1, handle.get_stream()); // Call the distance function raft::distance:: From ac4f0671e113c592e1f90f6d5d984d9a2ebb78b3 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 31 Aug 2021 23:16:09 +0530 Subject: [PATCH 16/16] fix clang format issues --- cpp/src/metrics/pairwise_distance_canberra.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/metrics/pairwise_distance_canberra.cu b/cpp/src/metrics/pairwise_distance_canberra.cu index 4ca5e1d911..4e4ac45857 100644 --- a/cpp/src/metrics/pairwise_distance_canberra.cu +++ b/cpp/src/metrics/pairwise_distance_canberra.cu @@ -17,9 +17,8 @@ #include #include -#include "pairwise_distance_canberra.cuh" #include - +#include "pairwise_distance_canberra.cuh" namespace ML {