Skip to content

Commit

Permalink
Merge pull request #17 from chroma-core/hammad/persist_single
Browse files Browse the repository at this point in the history
[BUG] Persist single item
  • Loading branch information
HammadB authored Jul 3, 2024
2 parents 7f51170 + daea153 commit 408c5d1
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 0 deletions.
3 changes: 3 additions & 0 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1703,6 +1703,9 @@ namespace hnswlib
// Do nothing for the first element
enterpoint_node_ = 0;
maxlevel_ = curlevel;

// mark cur_c as dirty
markElementToPersist(cur_c);
}

// Releasing lock for the maximum level
Expand Down
245 changes: 245 additions & 0 deletions tests/cpp/persistent_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,231 @@ namespace
}
}

void test_persist_empty() {
int d = 1;
idx_t n = 0;
idx_t nq = 1;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;

for (idx_t i = 0; i < n * d; i++)
{
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i)
{
query[i] = distrib(rng);
}

hnswlib::InnerProductSpace space(d);
hnswlib::HierarchicalNSW<float> *alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, n, 16, 200, 100, false, false, true, ".");

alg_hnsw->persistDirty();

hnswlib::HierarchicalNSW<float> *alg_hnsw2 = new hnswlib::HierarchicalNSW<float>(&space, ".", false, n, false, false, true);
// query and expect no result
std::priority_queue<std::pair<float, idx_t>> result = alg_hnsw2->searchKnn(query.data(), 10);
assert(result.size() == 0);
}

void test_persist_size(int n) {
int d = 1;
idx_t nq = 1;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;


for (idx_t i = 0; i < n * d; i++)
{
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i)
{
query[i] = distrib(rng);
}

hnswlib::InnerProductSpace space(d);
hnswlib::HierarchicalNSW<float> *alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, n, 16, 200, 100, false, false, true, ".");

for (size_t i = 0; i < n; i++)
{
alg_hnsw->addPoint(data.data() + d * i, i);
alg_hnsw->persistDirty();
}
alg_hnsw->persistDirty();

hnswlib::HierarchicalNSW<float> *alg_hnsw2 = new hnswlib::HierarchicalNSW<float>(&space, ".", false, n, false, false, true);


// Check that all data is the same
for (size_t i = 0; i < n; i++)
{
std::vector<float> actual = alg_hnsw2->template getDataByLabel<float>(i);
for (size_t j = 0; j < d; j++)
{
// Check that abs difference is less than 1e-6
if (!(std::abs(actual[j] - data[d * i + j]) < 1e-6))
{
std::cout << "actual: " << actual[j] << " expected: " << data[d * i + j] << std::endl;
}
assert(std::abs(actual[j] - data[d * i + j]) < 1e-6);
}
}

// Compare to in-memory index
for (size_t j = 0; j < nq; ++j)
{
const void *p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k);
auto res = alg_hnsw2->searchKnn(p, k);
assert(gd.size() == res.size());
int missed = 0;
for (size_t i = 0; i < gd.size(); i++)
{
assert(std::abs(gd.top().first - res.top().first) < 1e-6);
assert(gd.top().second == res.top().second);
gd.pop();
res.pop();
}
}
}

void test_persist_then_delete_size(int n) {
int d = 1;
idx_t nq = 1;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;

for (idx_t i = 0; i < n * d; i++)
{
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i)
{
query[i] = distrib(rng);
}

hnswlib::InnerProductSpace space(d);
hnswlib::HierarchicalNSW<float> *alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, n, 16, 200, 100, false, false, true, ".");

for (size_t i = 0; i < n; i++)
{
alg_hnsw->addPoint(data.data() + d * i, i);
alg_hnsw->persistDirty();
}
alg_hnsw->persistDirty();

// Delete the inserted data and then persist
for (size_t i = 0; i < n; i++)
{
alg_hnsw->markDelete(i);
alg_hnsw->persistDirty();
}
alg_hnsw->persistDirty();

hnswlib::HierarchicalNSW<float> *alg_hnsw2 = new hnswlib::HierarchicalNSW<float>(&space, ".", false, n, false, false, true);

// query and expect no result
std::priority_queue<std::pair<float, idx_t>> result = alg_hnsw2->searchKnn(query.data(), 10);
assert(result.size() == 0);
}

void test_persist_size_then_add(int n, int second_n) {
int d = 1536;
idx_t nq = 1;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;


for (idx_t i = 0; i < n * d; i++)
{
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i)
{
query[i] = distrib(rng);
}

hnswlib::InnerProductSpace space(d);
hnswlib::HierarchicalNSW<float> *alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, n, 16, 200, 100, false, false, true, ".");

for (size_t i = 0; i < n; i++)
{
alg_hnsw->addPoint(data.data() + d * i, i);
alg_hnsw->persistDirty();
}
alg_hnsw->persistDirty();

hnswlib::HierarchicalNSW<float> *alg_hnsw2 = new hnswlib::HierarchicalNSW<float>(&space, ".", false, n + second_n, false, false, true);

std::vector<float> data2(second_n * d);
for (idx_t i = 0; i < second_n * d; i++)
{
data2[i] = distrib(rng);
}

for (size_t i = n; i < n + second_n; i++)
{
alg_hnsw2->addPoint(data2.data() + d * (i - n), i);
alg_hnsw2->persistDirty();
}
alg_hnsw2->persistDirty();

// Load alg_hnsw3
hnswlib::HierarchicalNSW<float> *alg_hnsw3 = new hnswlib::HierarchicalNSW<float>(&space, ".", false, n + second_n, false, false, true);

// Check that all data is the same
for (size_t i = 0; i < n + second_n; i++)
{
std::vector<float> actual = alg_hnsw3->template getDataByLabel<float>(i);
for (size_t j = 0; j < d; j++)
{
// Check that abs difference is less than 1e-6
if (!(std::abs(actual[j] - (i < n ? data[d * i + j] : data2[d * (i - n) + j])) < 1e-6))
{
std::cout << "actual: " << actual[j] << " expected: " << (i < n ? data[d * i + j] : data2[d * (i - n) + j]) << std::endl;
}
assert(std::abs(actual[j] - (i < n ? data[d * i + j] : data2[d * (i - n) + j]) < 1e-6));
}
}

// Compare to in-memory index
for (size_t j = 0; j < nq; ++j)
{
const void *p = query.data() + j * d;
auto gd = alg_hnsw2->searchKnn(p, k);
auto res = alg_hnsw3->searchKnn(p, k);
assert(gd.size() == res.size());
int missed = 0;
for (size_t i = 0; i < gd.size(); i++)
{
assert(std::abs(gd.top().first - res.top().first) < 1e-6);
assert(gd.top().second == res.top().second);
gd.pop();
res.pop();
}
}
}

int main()
{
std::cout << "Testing ..." << std::endl;
Expand All @@ -311,5 +536,25 @@ int main()
std::cout << "Test testAddUpdatePersistentIndex ok" << std::endl;
testDeletePersistentIndex();
std::cout << "Test testDeletePersistentIndex ok" << std::endl;
test_persist_empty();
std::cout << "Test test_persist_empty ok" << std::endl;
test_persist_size(1);
std::cout << "Test test_persist_size(1) ok" << std::endl;
test_persist_size(2);
std::cout << "Test test_persist_size(2) ok" << std::endl;
test_persist_size(3);
std::cout << "Test test_persist_size(3) ok" << std::endl;
test_persist_then_delete_size(1);
std::cout << "Test test_persist_then_delete_size(1) ok" << std::endl;
test_persist_then_delete_size(2);
std::cout << "Test test_persist_then_delete_size(2) ok" << std::endl;
test_persist_size_then_add(1, 1);
std::cout << "Test test_persist_size_then_add(1, 1) ok" << std::endl;
test_persist_size_then_add(2, 1);
std::cout << "Test test_persist_size_then_add(2, 1) ok" << std::endl;
test_persist_size_then_add(1, 1000);
std::cout << "Test test_persist_size_then_add(1, 1000) ok" << std::endl;
test_persist_size_then_add(2, 1000);
std::cout << "Test test_persist_size_then_add(2, 1000) ok" << std::endl;
return 0;
}

0 comments on commit 408c5d1

Please sign in to comment.