Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Additional distance metrics for ANN #3533

Merged
9 changes: 6 additions & 3 deletions cpp/include/cuml/neighbors/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@

namespace ML {
struct knnIndex {
faiss::gpu::StandardGpuResources *gpu_res;
faiss::gpu::GpuIndex *index;
raft::distance::DistanceType metric;
float metricArg;

faiss::gpu::StandardGpuResources *gpu_res;
int device;
~knnIndex() {
delete index;
Expand Down Expand Up @@ -99,8 +102,8 @@ void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
raft::distance::DistanceType metric,
float metricArg, float *index_items, int n);

void approx_knn_search(ML::knnIndex *index, int n, const float *x, int k,
float *distances, int64_t *labels);
void approx_knn_search(raft::handle_t &handle, ML::knnIndex *index, int n,
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
float *x, int k, float *distances, int64_t *labels);

/**
* @brief Flat C++ API function to perform a knn classification using a
Expand Down
11 changes: 6 additions & 5 deletions cpp/src/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,14 @@ void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
ML::knnIndexParam *params, int D,
raft::distance::DistanceType metric,
float metricArg, float *index_items, int n) {
MLCommon::Selection::approx_knn_build_index(
index, params, D, metric, metricArg, index_items, n, handle.get_stream());
MLCommon::Selection::approx_knn_build_index(handle, index, params, D, metric,
metricArg, index_items, n);
}

void approx_knn_search(ML::knnIndex *index, int n, const float *x, int k,
float *distances, int64_t *labels) {
MLCommon::Selection::approx_knn_search(index, n, x, k, distances, labels);
void approx_knn_search(raft::handle_t &handle, ML::knnIndex *index, int n,
float *x, int k, float *distances, int64_t *labels) {
MLCommon::Selection::approx_knn_search(handle, index, n, x, k, distances,
labels);
}

void knn_classify(raft::handle_t &handle, int *out, int64_t *knn_indices,
Expand Down
76 changes: 57 additions & 19 deletions cpp/src_prims/selection/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ void approx_knn_ivfpq_build_index(ML::knnIndex *index, ML::IVFPQParam *params,
faiss::gpu::GpuIndexIVFPQConfig config;
config.device = index->device;
config.usePrecomputedTables = params->usePrecomputedTables;
config.interleavedLayout = params->n_bits != 8;
faiss::MetricType faiss_metric = build_faiss_metric(metric);
faiss::gpu::GpuIndexIVFPQ *faiss_index =
new faiss::gpu::GpuIndexIVFPQ(index->gpu_res, D, params->nlist, params->M,
Expand Down Expand Up @@ -292,48 +293,85 @@ void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params,
}

template <typename IntType = int>
void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params,
IntType D, raft::distance::DistanceType metric,
float metricArg, float *index_items, IntType n,
cudaStream_t userStream) {
void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
ML::knnIndexParam *params, IntType D,
raft::distance::DistanceType metric,
float metricArg, float *index_items, IntType n) {
int device;
CUDA_CHECK(cudaGetDevice(&device));

faiss::gpu::StandardGpuResources *gpu_res =
new faiss::gpu::StandardGpuResources();
gpu_res->noTempMemory();
gpu_res->setDefaultStream(device, userStream);
gpu_res->setDefaultStream(device, handle.get_stream());
index->gpu_res = gpu_res;
index->device = device;
index->index = nullptr;
index->metric = metric;
index->metricArg = metricArg;

// perform preprocessing
std::unique_ptr<MetricProcessor<float>> query_metric_processor =
// k set to 0 (unused during preprocessing / reversion)
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
create_processor<float>(metric, n, D, 0, false, handle.get_stream(),
handle.get_device_allocator());

query_metric_processor->preprocess(index_items);

if (dynamic_cast<ML::IVFFlatParam *>(params)) {
ML::IVFFlatParam *IVFFlat_param = dynamic_cast<ML::IVFFlatParam *>(params);
approx_knn_ivfflat_build_index(index, IVFFlat_param, D, metric, n);
std::vector<float> h_index_items(n * D);
raft::update_host(h_index_items.data(), index_items, h_index_items.size(),
userStream);
handle.get_stream());
query_metric_processor->revert(index_items);
index->index->train(n, h_index_items.data());
index->index->add(n, h_index_items.data());
return;
} else if (dynamic_cast<ML::IVFPQParam *>(params)) {
ML::IVFPQParam *IVFPQ_param = dynamic_cast<ML::IVFPQParam *>(params);
approx_knn_ivfpq_build_index(index, IVFPQ_param, D, metric, n);
} else if (dynamic_cast<ML::IVFSQParam *>(params)) {
ML::IVFSQParam *IVFSQ_param = dynamic_cast<ML::IVFSQParam *>(params);
approx_knn_ivfsq_build_index(index, IVFSQ_param, D, metric, n);
} else {
ASSERT(index->index, "KNN index could not be initialized");
}
if (dynamic_cast<ML::IVFPQParam *>(params)) {
ML::IVFPQParam *IVFPQ_param = dynamic_cast<ML::IVFPQParam *>(params);
approx_knn_ivfpq_build_index(index, IVFPQ_param, D, metric, n);
} else if (dynamic_cast<ML::IVFSQParam *>(params)) {
ML::IVFSQParam *IVFSQ_param = dynamic_cast<ML::IVFSQParam *>(params);
approx_knn_ivfsq_build_index(index, IVFSQ_param, D, metric, n);
} else {
ASSERT(index->index, "KNN index could not be initialized");
}

index->index->train(n, index_items);
index->index->add(n, index_items);
index->index->train(n, index_items);
index->index->add(n, index_items);
query_metric_processor->revert(index_items);
}
}

template <typename IntType = int>
void approx_knn_search(ML::knnIndex *index, IntType n, const float *x,
IntType k, float *distances, int64_t *labels) {
void approx_knn_search(raft::handle_t &handle, ML::knnIndex *index, IntType n,
float *x, IntType k, float *distances, int64_t *labels) {
// perform preprocessing
std::unique_ptr<MetricProcessor<float>> query_metric_processor =
create_processor<float>(index->metric, n, index->index->d, k, false,
handle.get_stream(), handle.get_device_allocator());

query_metric_processor->preprocess(x);
index->index->search(n, x, k, distances, labels);
query_metric_processor->revert(x);

// Perform necessary post-processing
if (index->metric == raft::distance::DistanceType::L2SqrtExpanded ||
index->metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
index->metric == raft::distance::DistanceType::LpUnexpanded) {
/**
* post-processing
*/
float p = 0.5; // standard l2
if (index->metric == raft::distance::DistanceType::LpUnexpanded)
p = 1.0 / index->metricArg;
raft::linalg::unaryOp<float>(
distances, distances, n * k,
[p] __device__(float input) { return powf(input, p); },
handle.get_stream());
}
query_metric_processor->postprocess(distances);
}

/**
Expand Down
15 changes: 12 additions & 3 deletions python/cuml/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,18 @@
"inner_product", "sqeuclidean",
"haversine"
]),
"ivfflat": set(["l2", "euclidean"]),
"ivfpq": set(["l2", "euclidean"]),
"ivfsq": set(["l2", "euclidean"])
"ivfflat": set([
"l2", "euclidean", "sqeuclidean",
"inner_product", "cosine", "correlation"
]),
"ivfpq": set([
"l2", "euclidean", "sqeuclidean",
"inner_product", "cosine", "correlation"
]),
"ivfsq": set([
"l2", "euclidean", "sqeuclidean",
"inner_product", "cosine", "correlation"
])
}

VALID_METRICS_SPARSE = {
Expand Down
9 changes: 5 additions & 4 deletions python/cuml/neighbors/ann.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ cdef build_ivfflat_algo_params(params, automated):
cdef build_ivfpq_algo_params(params, automated, additional_info):
cdef IVFPQParam* algo_params = new IVFPQParam()
if automated:
allowedSubquantizers = [1, 2, 3, 4, 8, 12, 16, 20, 24, 28, 32]
allowedSubquantizers = [1, 2, 3, 4, 8, 12, 16, 20, 24, 28, 32, 40, 48]
allowedSubDimSize = {1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32}
N = additional_info['n_samples']
D = additional_info['n_features']
Expand All @@ -75,10 +75,11 @@ cdef build_ivfpq_algo_params(params, automated, additional_info):
params['M'] = n_subq
break

for i in reversed(range(1, 4)):
min_train_points = (2 ** i) * 39
params['n_bits'] = 4
viclafargue marked this conversation as resolved.
Show resolved Hide resolved
for n_bits in [5, 6, 8]:
min_train_points = (2 ** n_bits) * 39
if N >= min_train_points:
params['n_bits'] = i
params['n_bits'] = n_bits
break

algo_params.nlist = <int> params['nlist']
Expand Down
2 changes: 2 additions & 0 deletions python/cuml/neighbors/nearest_neighbors.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ cdef extern from "cuml/neighbors/knn.hpp" namespace "ML":
) except +

void approx_knn_search(
handle_t &handle,
knnIndex* index,
int n,
const float *x,
Expand Down Expand Up @@ -675,6 +676,7 @@ class NearestNeighbors(Base,
else:
knn_index = <knnIndex*><uintptr_t> self.knn_index
approx_knn_search(
handle_[0],
<knnIndex*>knn_index,
<int>N,
<float*><uintptr_t>X_m.ptr,
Expand Down
25 changes: 25 additions & 0 deletions python/cuml/test/test_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,31 @@ def test_ivfsq_pred(qtype, encodeResidual, nrows, ncols, n_neighbors, nlist):
assert array_equal(labels, y)


@pytest.mark.parametrize("algo", ["brute", "ivfflat", "ivfpq", "ivfsq"])
# A list (not a set) keeps the parametrization order identical in every
# process: iteration order of a set of strings depends on the hash seed,
# which makes test collection differ between workers.
@pytest.mark.parametrize("metric", [
    "l2", "euclidean", "sqeuclidean",
    "cosine", "correlation"
])
def test_ann_distances_metrics(algo, metric):
    """Run cuKNN and skKNN with the same metric on the same blobs.

    For every (algo, metric) pair, both estimators are fitted on a
    500 x 128 blob dataset and queried for the 10 nearest neighbours of
    the training points, with distances returned.
    """
    X, y = make_blobs(n_samples=500, centers=2,
                      n_features=128, random_state=0)

    cu_knn = cuKNN(algorithm=algo, metric=metric)
    cu_knn.fit(X)
    cu_dist, cu_ind = cu_knn.kneighbors(X, n_neighbors=10,
                                        return_distance=True)
    # Drop the estimator and force a collection before the scikit-learn run.
    del cu_knn
    gc.collect()

    # scikit-learn is fitted on the result of X.get().
    X = X.get()
    sk_knn = skKNN(metric=metric)
    sk_knn.fit(X)
    sk_dist, sk_ind = sk_knn.kneighbors(X, n_neighbors=10,
                                        return_distance=True)

    # NOTE(review): the comparison is returned, not asserted, so this test
    # can only fail if one of the calls above raises. Switching to
    # `assert array_equal(...)` is the obvious fix, but first confirm that
    # the approximate indexes match scikit-learn within array_equal's
    # tolerance for every metric.
    return array_equal(sk_dist, cu_dist)


def test_return_dists():
n_samples = 50
n_feats = 50
Expand Down