diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index cc767847..8adaa591 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -1703,6 +1703,9 @@ namespace hnswlib // Do nothing for the first element enterpoint_node_ = 0; maxlevel_ = curlevel; + + // mark cur_c as dirty + markElementToPersist(cur_c); } // Releasing lock for the maximum level diff --git a/tests/cpp/persistent_test.cpp b/tests/cpp/persistent_test.cpp index 666999b5..d93e66af 100644 --- a/tests/cpp/persistent_test.cpp +++ b/tests/cpp/persistent_test.cpp @@ -300,6 +300,231 @@ namespace } } +void test_persist_empty() { + int d = 1; + idx_t n = 0; + idx_t nq = 1; + + std::vector data(n * d); + std::vector query(nq * d); + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; i++) + { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) + { + query[i] = distrib(rng); + } + + hnswlib::InnerProductSpace space(d); + hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, n, 16, 200, 100, false, false, true, "."); + + alg_hnsw->persistDirty(); + + hnswlib::HierarchicalNSW *alg_hnsw2 = new hnswlib::HierarchicalNSW(&space, ".", false, n, false, false, true); + // query and expect no result + std::priority_queue> result = alg_hnsw2->searchKnn(query.data(), 10); + assert(result.size() == 0); +} + +void test_persist_size(int n) { + int d = 1; + idx_t nq = 1; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + + for (idx_t i = 0; i < n * d; i++) + { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) + { + query[i] = distrib(rng); + } + + hnswlib::InnerProductSpace space(d); + hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, n, 16, 200, 100, false, false, true, "."); + + for (size_t i = 0; i < n; i++) + { + alg_hnsw->addPoint(data.data() + d * i, i); + alg_hnsw->persistDirty(); + } + alg_hnsw->persistDirty(); + + hnswlib::HierarchicalNSW *alg_hnsw2 = new hnswlib::HierarchicalNSW(&space, ".", false, n, false, false, true); + + + // Check that all data is the same + for (size_t i = 0; i < n; i++) + { + std::vector actual = alg_hnsw2->template getDataByLabel(i); + for (size_t j = 0; j < d; j++) + { + // Check that abs difference is less than 1e-6 + if (!(std::abs(actual[j] - data[d * i + j]) < 1e-6)) + { + std::cout << "actual: " << actual[j] << " expected: " << data[d * i + j] << std::endl; + } + assert(std::abs(actual[j] - data[d * i + j]) < 1e-6); + } + } + + // Compare to in-memory index + for (size_t j = 0; j < nq; ++j) + { + const void *p = query.data() + j * d; + auto gd = alg_hnsw->searchKnn(p, k); + auto res = alg_hnsw2->searchKnn(p, k); + assert(gd.size() == res.size()); + int missed = 0; + for (size_t i = 0; i < gd.size(); i++) + { + assert(std::abs(gd.top().first - res.top().first) < 1e-6); + assert(gd.top().second == res.top().second); + gd.pop(); + res.pop(); + } + } +} + +void test_persist_then_delete_size(int n) { + int d = 1; + idx_t nq = 1; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; i++) + { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) + { + query[i] = distrib(rng); + } + + hnswlib::InnerProductSpace space(d); + hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, n, 16, 200, 100, false, false, true, "."); + + for (size_t i = 0; i < n; i++) + { + alg_hnsw->addPoint(data.data() + d * i, i); + alg_hnsw->persistDirty(); + } + alg_hnsw->persistDirty(); + + // Delete the inserted data and then persist + for (size_t i = 0; i < n; i++) + { + alg_hnsw->markDelete(i); + alg_hnsw->persistDirty(); + } + alg_hnsw->persistDirty(); + + hnswlib::HierarchicalNSW *alg_hnsw2 = new hnswlib::HierarchicalNSW(&space, ".", false, n, false, false, true); + + // query and expect no result + std::priority_queue> result = alg_hnsw2->searchKnn(query.data(), 10); + assert(result.size() == 0); +} + +void test_persist_size_then_add(int n, int second_n) { + int d = 1536; + idx_t nq = 1; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + + for (idx_t i = 0; i < n * d; i++) + { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) + { + query[i] = distrib(rng); + } + + hnswlib::InnerProductSpace space(d); + hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, n, 16, 200, 100, false, false, true, "."); + + for (size_t i = 0; i < n; i++) + { + alg_hnsw->addPoint(data.data() + d * i, i); + alg_hnsw->persistDirty(); + } + alg_hnsw->persistDirty(); + + hnswlib::HierarchicalNSW *alg_hnsw2 = new hnswlib::HierarchicalNSW(&space, ".", false, n + second_n, false, false, true); + + std::vector data2(second_n * d); + for (idx_t i = 0; i < second_n * d; i++) + { + data2[i] = distrib(rng); + } + + for (size_t i = n; i < n + second_n; i++) + { + alg_hnsw2->addPoint(data2.data() + d * (i - n), i); + alg_hnsw2->persistDirty(); + } + alg_hnsw2->persistDirty(); + + // Load alg_hnsw3 + hnswlib::HierarchicalNSW *alg_hnsw3 = new hnswlib::HierarchicalNSW(&space, ".", false, n + second_n, false, false, true); + + // Check that all data is the same + for (size_t i = 0; i < n + second_n; i++) + { + std::vector actual = alg_hnsw3->template getDataByLabel(i); + for (size_t j = 0; j < d; j++) + { + // Check that abs difference is less than 1e-6 + if (!(std::abs(actual[j] - (i < n ? data[d * i + j] : data2[d * (i - n) + j])) < 1e-6)) + { + std::cout << "actual: " << actual[j] << " expected: " << (i < n ? data[d * i + j] : data2[d * (i - n) + j]) << std::endl; + } + assert(std::abs(actual[j] - (i < n ? data[d * i + j] : data2[d * (i - n) + j]) < 1e-6)); + } + } + + // Compare to in-memory index + for (size_t j = 0; j < nq; ++j) + { + const void *p = query.data() + j * d; + auto gd = alg_hnsw2->searchKnn(p, k); + auto res = alg_hnsw3->searchKnn(p, k); + assert(gd.size() == res.size()); + int missed = 0; + for (size_t i = 0; i < gd.size(); i++) + { + assert(std::abs(gd.top().first - res.top().first) < 1e-6); + assert(gd.top().second == res.top().second); + gd.pop(); + res.pop(); + } + } +} + int main() { std::cout << "Testing ..." << std::endl; @@ -311,5 +536,25 @@ int main() std::cout << "Test testAddUpdatePersistentIndex ok" << std::endl; testDeletePersistentIndex(); std::cout << "Test testDeletePersistentIndex ok" << std::endl; + test_persist_empty(); + std::cout << "Test test_persist_empty ok" << std::endl; + test_persist_size(1); + std::cout << "Test test_persist_size(1) ok" << std::endl; + test_persist_size(2); + std::cout << "Test test_persist_size(2) ok" << std::endl; + test_persist_size(3); + std::cout << "Test test_persist_size(3) ok" << std::endl; + test_persist_then_delete_size(1); + std::cout << "Test test_persist_then_delete_size(1) ok" << std::endl; + test_persist_then_delete_size(2); + std::cout << "Test test_persist_then_delete_size(2) ok" << std::endl; + test_persist_size_then_add(1, 1); + std::cout << "Test test_persist_size_then_add(1, 1) ok" << std::endl; + test_persist_size_then_add(2, 1); + std::cout << "Test test_persist_size_then_add(2, 1) ok" << std::endl; + test_persist_size_then_add(1, 1000); + std::cout << "Test test_persist_size_then_add(1, 1000) ok" << std::endl; + test_persist_size_then_add(2, 1000); + std::cout << "Test test_persist_size_then_add(2, 1000) ok" << std::endl; return 0; }