Skip to content

Commit

Permalink
Add: Exact search shortcut
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Aug 5, 2023
1 parent 05e908f commit a005084
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 24 deletions.
17 changes: 16 additions & 1 deletion include/usearch/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,21 @@ template <typename allocator_at = std::allocator<byte_t>> class visits_bitset_gt
}

#endif

class lock_t {
visits_bitset_gt& bitset_;
std::size_t bit_offset_;

public:
inline ~lock_t() noexcept { bitset_.atomic_reset(bit_offset_); }
inline lock_t(visits_bitset_gt& bitset, std::size_t bit_offset) noexcept
: bitset_(bitset), bit_offset_(bit_offset) {
while (bitset_.atomic_set(bit_offset_))
;
}
};

inline lock_t lock(std::size_t i) noexcept { return {*this, i}; }
};

using visits_bitset_t = visits_bitset_gt<>;
Expand Down Expand Up @@ -2010,7 +2025,7 @@ class index_gt {
std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t));
keys[offset] = result.member.key;
distances[offset] = result.distance;
merged_count = (std::min)(merged_count + 1u, max_count);
merged_count += merged_count != max_count;
}
return merged_count;
}
Expand Down
140 changes: 134 additions & 6 deletions python/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ static void search_typed( //
if (!threads)
threads = std::thread::hardware_concurrency();

std::vector<std::mutex> vectors_mutexes(static_cast<std::size_t>(vectors_count));
std::vector<std::mutex> query_mutexes(static_cast<std::size_t>(vectors_count));
executor_default_t{threads}.execute_bulk(indexes.shards_.size(), [&](std::size_t, std::size_t task_idx) {
dense_index_py_t& index = *indexes.shards_[task_idx].get();

Expand All @@ -318,7 +318,7 @@ static void search_typed( //
dense_search_result_t result = index.search(vector, wanted, config);
result.error.raise();
{
std::unique_lock<std::mutex> lock(vectors_mutexes[vector_idx]);
std::unique_lock<std::mutex> lock(query_mutexes[vector_idx]);
counts_py1d(vector_idx) = static_cast<Py_ssize_t>(result.merge_into( //
&keys_py2d(vector_idx, 0), //
&distances_py2d(vector_idx, 0), //
Expand Down Expand Up @@ -348,7 +348,7 @@ static py::tuple search_many_in_index( //
index_at& index, py::buffer vectors, std::size_t wanted, bool exact, std::size_t threads) {

if (wanted == 0)
return py::tuple(3);
return py::tuple(5);

if (index.limits().threads_search < threads)
throw std::invalid_argument("Can't use that many threads!");
Expand Down Expand Up @@ -388,6 +388,123 @@ static py::tuple search_many_in_index( //
return results;
}

template <typename scalar_at>
static void search_typed_brute_force( //
py::buffer_info& dataset_info, py::buffer_info& queries_info, //
std::size_t wanted, std::size_t threads, metric_t const& metric, //
py::array_t<key_t>& keys_py, py::array_t<distance_t>& distances_py, py::array_t<Py_ssize_t>& counts_py) {

auto keys_py2d = keys_py.template mutable_unchecked<2>();
auto distances_py2d = distances_py.template mutable_unchecked<2>();
auto counts_py1d = counts_py.template mutable_unchecked<1>();

std::size_t dataset_count = static_cast<std::size_t>(dataset_info.shape[0]);
std::size_t queries_count = static_cast<std::size_t>(queries_info.shape[0]);
std::size_t dimensions = static_cast<std::size_t>(dataset_info.shape[1]);

byte_t const* dataset_data = reinterpret_cast<byte_t const*>(dataset_info.ptr);
byte_t const* queries_data = reinterpret_cast<byte_t const*>(queries_info.ptr);
for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx)
counts_py1d(query_idx) = 0;

if (!threads)
threads = std::thread::hardware_concurrency();

std::size_t tasks_count = static_cast<std::size_t>(dataset_count * queries_count);
visits_bitset_t query_mutexes(static_cast<std::size_t>(queries_count));
if (!query_mutexes)
throw std::bad_alloc();

executor_default_t{threads}.execute_bulk(tasks_count, [&](std::size_t, std::size_t task_idx) {
//
std::size_t dataset_idx = task_idx / queries_count;
std::size_t query_idx = task_idx % queries_count;

byte_t const* dataset = dataset_data + dataset_idx * dataset_info.strides[0];
byte_t const* query = queries_data + query_idx * queries_info.strides[0];
distance_t distance = metric(dataset, query);

{
auto lock = query_mutexes.lock(query_idx);
key_t* keys = &keys_py2d(query_idx, 0);
distance_t* distances = &distances_py2d(query_idx, 0);
std::size_t& matches = reinterpret_cast<std::size_t&>(counts_py1d(query_idx));
if (matches == wanted)
if (distances[wanted - 1] <= distance)
return;

std::size_t offset = std::lower_bound(distances, distances + matches, distance) - distances;

std::size_t count_worse = matches - offset - (wanted == matches);
std::memmove(keys + offset + 1, keys + offset, count_worse * sizeof(key_t));
std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t));
keys[offset] = static_cast<key_t>(dataset_idx);
distances[offset] = distance;
matches += matches != wanted;
}

if (PyErr_CheckSignals() != 0)
throw py::error_already_set();
});
}

static py::tuple search_many_brute_force( //
py::buffer dataset, py::buffer queries, //
std::size_t wanted, std::size_t threads, //
metric_kind_t metric_kind, //
metric_signature_t metric_signature, //
std::uintptr_t metric_uintptr) {

if (wanted == 0)
return py::tuple(5);

py::buffer_info dataset_info = dataset.request();
py::buffer_info queries_info = queries.request();
if (dataset_info.ndim != 2 || queries_info.ndim != 2)
throw std::invalid_argument("Expects a matrix of dataset to add!");

Py_ssize_t dataset_count = dataset_info.shape[0];
Py_ssize_t dataset_dimensions = dataset_info.shape[1];
Py_ssize_t queries_count = queries_info.shape[0];
Py_ssize_t queries_dimensions = queries_info.shape[1];
if (dataset_dimensions != queries_dimensions)
throw std::invalid_argument("The number of vector dimensions doesn't match!");

scalar_kind_t dataset_kind = numpy_string_to_kind(dataset_info.format);
scalar_kind_t queries_kind = numpy_string_to_kind(queries_info.format);
if (dataset_kind != queries_kind)
throw std::invalid_argument("The types of vectors don't match!");

py::array_t<key_t> keys_py({dataset_count, static_cast<Py_ssize_t>(wanted)});
py::array_t<distance_t> distances_py({dataset_count, static_cast<Py_ssize_t>(wanted)});
py::array_t<Py_ssize_t> counts_py(dataset_count);

std::size_t dimensions = static_cast<std::size_t>(queries_dimensions);
metric_t metric = //
metric_uintptr //
? udf(metric_kind, metric_signature, metric_uintptr, queries_kind, dimensions)
: metric_t(dimensions, metric_kind, queries_kind);

// clang-format off
switch (dataset_kind) {
case scalar_kind_t::b1x8_k: search_typed_brute_force<b1x8_t>(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break;
case scalar_kind_t::i8_k: search_typed_brute_force<i8_bits_t>(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break;
case scalar_kind_t::f16_k: search_typed_brute_force<f16_t>(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break;
case scalar_kind_t::f32_k: search_typed_brute_force<f32_t>(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break;
case scalar_kind_t::f64_k: search_typed_brute_force<f64_t>(dataset_info, queries_info, wanted, threads, metric, keys_py, distances_py, counts_py); break;
default: throw std::invalid_argument("Incompatible vector types: " + dataset_info.format);
}
// clang-format on

py::tuple results(5);
results[0] = keys_py;
results[1] = distances_py;
results[2] = counts_py;
results[3] = 0;
results[4] = static_cast<std::size_t>(dataset_count * queries_count);
return results;
}

static std::unordered_map<key_t, key_t> join_index( //
dense_index_py_t const& a, dense_index_py_t const& b, //
std::size_t max_proposals, bool exact) {
Expand Down Expand Up @@ -505,7 +622,7 @@ PYBIND11_MODULE(compiled, m) {
py::enum_<metric_kind_t>(m, "MetricKind")
.value("Unknown", metric_kind_t::unknown_k)

.value("IP", metric_kind_t::ip_k)
.value("IP", metric_kind_t::cos_k)
.value("Cos", metric_kind_t::cos_k)
.value("L2sq", metric_kind_t::l2sq_k)

Expand All @@ -517,7 +634,7 @@ PYBIND11_MODULE(compiled, m) {
.value("Sorensen", metric_kind_t::sorensen_k)

.value("Cosine", metric_kind_t::cos_k)
.value("InnerProduct", metric_kind_t::ip_k);
.value("InnerProduct", metric_kind_t::cos_k);

py::enum_<scalar_kind_t>(m, "ScalarKind")
.value("Unknown", scalar_kind_t::unknown_k)
Expand Down Expand Up @@ -562,13 +679,24 @@ PYBIND11_MODULE(compiled, m) {
return result;
});

m.def("exact_search", &search_many_brute_force, //
py::arg("dataset"), //
py::arg("queries"), //
py::arg("count") = 10, //
py::kw_only(), //
py::arg("threads") = 0, //
py::arg("metric_kind") = metric_kind_t::cos_k, //
py::arg("metric_signature") = metric_signature_t::array_array_k, //
py::arg("metric_pointer") = 0 //
);

auto i = py::class_<dense_index_py_t, std::shared_ptr<dense_index_py_t>>(m, "Index");

i.def(py::init(&make_index), //
py::kw_only(), //
py::arg("ndim") = 0, //
py::arg("dtype") = scalar_kind_t::f32_k, //
py::arg("metric_kind") = metric_kind_t::ip_k, //
py::arg("metric_kind") = metric_kind_t::cos_k, //
py::arg("connectivity") = default_connectivity(), //
py::arg("expansion_add") = default_expansion_add(), //
py::arg("expansion_search") = default_expansion_search(), //
Expand Down
22 changes: 22 additions & 0 deletions python/scripts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from usearch.io import load_matrix, save_matrix
from usearch.eval import random_vectors
from usearch.index import search

from usearch.index import (
Index,
Expand Down Expand Up @@ -81,6 +82,27 @@ def test_serializing_ibin_matrix(rows: int, cols: int):
os.remove(temporary_filename + ".ibin")


@pytest.mark.parametrize("rows", batch_sizes)
@pytest.mark.parametrize("cols", dimensions)
def test_exact_search(rows: int, cols: int):
"""
Test exact search.
:param int rows: The number of rows in the matrix.
:param int cols: The number of columns in the matrix.
"""
original = np.random.rand(rows, cols)
matches: BatchMatches = search(original, original, 10, exact=True)
top_matches = (
[int(m.keys[0]) for m in matches] if rows > 1 else int(matches.keys[0])
)
assert np.all(top_matches == np.arange(rows))

matches: Matches = search(original, original[0], 10, exact=True)
top_match = int(matches.keys[0])
assert top_match == 0


@pytest.mark.parametrize("ndim", dimensions)
@pytest.mark.parametrize("metric", continuous_metrics)
@pytest.mark.parametrize("index_type", index_types)
Expand Down
Loading

0 comments on commit a005084

Please sign in to comment.