Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Persist single item #17

Merged
merged 3 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,9 @@ namespace hnswlib
// Do nothing for the first element
enterpoint_node_ = 0;
maxlevel_ = curlevel;

// mark cur_c as dirty
markElementToPersist(cur_c);
}

// Releasing lock for the maximum level
Expand Down
245 changes: 245 additions & 0 deletions tests/cpp/persistent_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,231 @@ namespace
}
}

void test_persist_empty() {
int d = 1;
idx_t n = 0;
idx_t nq = 1;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;

for (idx_t i = 0; i < n * d; i++)
{
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i)
{
query[i] = distrib(rng);
}

hnswlib::InnerProductSpace space(d);
hnswlib::HierarchicalNSW<float> *alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, n, 16, 200, 100, false, false, true, ".");

alg_hnsw->persistDirty();

hnswlib::HierarchicalNSW<float> *alg_hnsw2 = new hnswlib::HierarchicalNSW<float>(&space, ".", false, n, false, false, true);
// query and expect no result
std::priority_queue<std::pair<float, idx_t>> result = alg_hnsw2->searchKnn(query.data(), 10);
assert(result.size() == 0);
}

void test_persist_size(int n) {
int d = 1;
idx_t nq = 1;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;


for (idx_t i = 0; i < n * d; i++)
{
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i)
{
query[i] = distrib(rng);
}

hnswlib::InnerProductSpace space(d);
hnswlib::HierarchicalNSW<float> *alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, n, 16, 200, 100, false, false, true, ".");

for (size_t i = 0; i < n; i++)
{
alg_hnsw->addPoint(data.data() + d * i, i);
alg_hnsw->persistDirty();
}
alg_hnsw->persistDirty();

hnswlib::HierarchicalNSW<float> *alg_hnsw2 = new hnswlib::HierarchicalNSW<float>(&space, ".", false, n, false, false, true);


// Check that all data is the same
for (size_t i = 0; i < n; i++)
{
std::vector<float> actual = alg_hnsw2->template getDataByLabel<float>(i);
for (size_t j = 0; j < d; j++)
{
// Check that abs difference is less than 1e-6
if (!(std::abs(actual[j] - data[d * i + j]) < 1e-6))
{
std::cout << "actual: " << actual[j] << " expected: " << data[d * i + j] << std::endl;
}
assert(std::abs(actual[j] - data[d * i + j]) < 1e-6);
}
}

// Compare to in-memory index
for (size_t j = 0; j < nq; ++j)
{
const void *p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k);
auto res = alg_hnsw2->searchKnn(p, k);
assert(gd.size() == res.size());
int missed = 0;
for (size_t i = 0; i < gd.size(); i++)
{
assert(std::abs(gd.top().first - res.top().first) < 1e-6);
assert(gd.top().second == res.top().second);
gd.pop();
res.pop();
}
}
}

void test_persist_then_delete_size(int n) {
int d = 1;
idx_t nq = 1;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;

for (idx_t i = 0; i < n * d; i++)
{
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i)
{
query[i] = distrib(rng);
}

hnswlib::InnerProductSpace space(d);
hnswlib::HierarchicalNSW<float> *alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, n, 16, 200, 100, false, false, true, ".");

for (size_t i = 0; i < n; i++)
{
alg_hnsw->addPoint(data.data() + d * i, i);
alg_hnsw->persistDirty();
}
alg_hnsw->persistDirty();

// Delete the inserted data and then persist
for (size_t i = 0; i < n; i++)
{
alg_hnsw->markDelete(i);
alg_hnsw->persistDirty();
}
alg_hnsw->persistDirty();

hnswlib::HierarchicalNSW<float> *alg_hnsw2 = new hnswlib::HierarchicalNSW<float>(&space, ".", false, n, false, false, true);

// query and expect no result
std::priority_queue<std::pair<float, idx_t>> result = alg_hnsw2->searchKnn(query.data(), 10);
assert(result.size() == 0);
}

void test_persist_size_then_add(int n, int second_n) {
int d = 1536;
idx_t nq = 1;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;


for (idx_t i = 0; i < n * d; i++)
{
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i)
{
query[i] = distrib(rng);
}

hnswlib::InnerProductSpace space(d);
hnswlib::HierarchicalNSW<float> *alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, n, 16, 200, 100, false, false, true, ".");

for (size_t i = 0; i < n; i++)
{
alg_hnsw->addPoint(data.data() + d * i, i);
alg_hnsw->persistDirty();
}
alg_hnsw->persistDirty();

hnswlib::HierarchicalNSW<float> *alg_hnsw2 = new hnswlib::HierarchicalNSW<float>(&space, ".", false, n + second_n, false, false, true);

std::vector<float> data2(second_n * d);
for (idx_t i = 0; i < second_n * d; i++)
{
data2[i] = distrib(rng);
}

for (size_t i = n; i < n + second_n; i++)
{
alg_hnsw2->addPoint(data2.data() + d * (i - n), i);
alg_hnsw2->persistDirty();
}
alg_hnsw2->persistDirty();

// Load alg_hnsw3
hnswlib::HierarchicalNSW<float> *alg_hnsw3 = new hnswlib::HierarchicalNSW<float>(&space, ".", false, n + second_n, false, false, true);

// Check that all data is the same
for (size_t i = 0; i < n + second_n; i++)
{
std::vector<float> actual = alg_hnsw3->template getDataByLabel<float>(i);
for (size_t j = 0; j < d; j++)
{
// Check that abs difference is less than 1e-6
if (!(std::abs(actual[j] - (i < n ? data[d * i + j] : data2[d * (i - n) + j])) < 1e-6))
{
std::cout << "actual: " << actual[j] << " expected: " << (i < n ? data[d * i + j] : data2[d * (i - n) + j]) << std::endl;
}
assert(std::abs(actual[j] - (i < n ? data[d * i + j] : data2[d * (i - n) + j]) < 1e-6));
}
}

// Compare to in-memory index
for (size_t j = 0; j < nq; ++j)
{
const void *p = query.data() + j * d;
auto gd = alg_hnsw2->searchKnn(p, k);
auto res = alg_hnsw3->searchKnn(p, k);
assert(gd.size() == res.size());
int missed = 0;
for (size_t i = 0; i < gd.size(); i++)
{
assert(std::abs(gd.top().first - res.top().first) < 1e-6);
assert(gd.top().second == res.top().second);
gd.pop();
res.pop();
}
}
}

int main()
{
std::cout << "Testing ..." << std::endl;
Expand All @@ -311,5 +536,25 @@ int main()
std::cout << "Test testAddUpdatePersistentIndex ok" << std::endl;
testDeletePersistentIndex();
std::cout << "Test testDeletePersistentIndex ok" << std::endl;
test_persist_empty();
std::cout << "Test test_persist_empty ok" << std::endl;
test_persist_size(1);
std::cout << "Test test_persist_size(1) ok" << std::endl;
test_persist_size(2);
std::cout << "Test test_persist_size(2) ok" << std::endl;
test_persist_size(3);
std::cout << "Test test_persist_size(3) ok" << std::endl;
test_persist_then_delete_size(1);
std::cout << "Test test_persist_then_delete_size(1) ok" << std::endl;
test_persist_then_delete_size(2);
std::cout << "Test test_persist_then_delete_size(2) ok" << std::endl;
test_persist_size_then_add(1, 1);
std::cout << "Test test_persist_size_then_add(1, 1) ok" << std::endl;
test_persist_size_then_add(2, 1);
std::cout << "Test test_persist_size_then_add(2, 1) ok" << std::endl;
test_persist_size_then_add(1, 1000);
std::cout << "Test test_persist_size_then_add(1, 1000) ok" << std::endl;
test_persist_size_then_add(2, 1000);
std::cout << "Test test_persist_size_then_add(2, 1000) ok" << std::endl;
return 0;
}
Loading