From e40b27f7a15e1b8b6c63c892aabc37c53f54f3f6 Mon Sep 17 00:00:00 2001 From: Sameh Gobriel <75963591+s-gobriel@users.noreply.github.com> Date: Tue, 24 Sep 2024 16:55:09 -0700 Subject: [PATCH] adding faisshnswflat indexing (#212) * adding faisshnswflat indexing * address comments * Automated updates: Format and/or coverage --------- Co-authored-by: sys_vdms --- .../coverage/cpp.develop.coverage_report.txt | 8 +- .../coverage/cpp.develop.coverage_value.txt | 2 +- client/cpp/DescriptorSetQueryParser.h | 3 +- include/vcl/DescriptorSet.h | 3 +- src/DescriptorsCommand.cc | 2 + src/vcl/DescriptorSet.cc | 4 + src/vcl/FaissDescriptorSet.cc | 75 ++++++++++- src/vcl/FaissDescriptorSet.h | 13 ++ tests/unit_tests/DescriptorSetAdd_test.cc | 125 ++++++++++++++++++ tests/unit_tests/helpers.cc | 1 + 10 files changed, 228 insertions(+), 8 deletions(-) diff --git a/.github/coverage/cpp.develop.coverage_report.txt b/.github/coverage/cpp.develop.coverage_report.txt index c60a3f56..94b3f776 100644 --- a/.github/coverage/cpp.develop.coverage_report.txt +++ b/.github/coverage/cpp.develop.coverage_report.txt @@ -11,7 +11,7 @@ src/BackendNeo4j.cc 121 0 0% 4,6-17,20,24,2 src/BlobCommand.cc 87 66 75% 76,130-132,136-139,145,147,165,186-189,192-196,202 src/BoundingBoxCommand.cc 180 4 2% 45,49,51,53-54,56-59,62,64-67,70-73,76,83,87,90-91,93-97,101,103,105,114,118,122-123,125-132,137-138,140-144,147-150,152,154-160,162-165,167-169,171-173,176-177,179-181,183-184,186-187,190,193,196-197,199,201-204,206-210,213,215-219,222-223,225-227,229-237,240-244,246,251-256,259-261,263,265-266,268,270,272-274,276-277,281-283,286,288,292-294,296,298,300-303,307-308,310-313,316-319,321-326,329-330,335,338-339,341 src/CommunicationManager.cc 46 0 0% 42-43,46-47,49-50,52-54,57,61-66,68-71,73-81,84,86-88,93,96-97,100-101,105,107-108,110,113-116 -src/DescriptorsCommand.cc 668 107 16% 56,63-68,73,75-79,81-85,87,89,92-93,96-98,101,103-104,107-113,115,118,121,168-170,174,188-192,232-243,253,267-269,273,288-295,297,309-312,317,322-326,343,351,353-356,359,362-363,365-369,372,375,377-378,381-383,385-386,388-389,392-393,395-397,403,408-409,411,413-414,417,419-422,425-430,436-437,440,442-443,445,447-448,451,454,456,458,461,463-468,470-474,476,479,482-483,486-487,498,501,506,512-513,518-519,521-524,527,533-534,539-540,543-549,552,554-555,557,559-560,562-563,567,569-574,577-582,585,588-593,595,621-622,625-626,629,632,639-640,642-644,646,649-650,652,655,658,660-662,670,674,676,679,681-684,687-692,694-698,700,703,705,707,710,713-714,716,718,720-725,728,730-732,735-736,738,740,742-743,745,747,749-751,753,755-760,765-766,769,771,773,780-781,784,788,790,793,795-798,801-805,807-811,813,815-818,820-822,825,829-838,843-844,847,853,855,858-859,862,866,871,873,876,879-880,884,889,891,893,896,899-900,902-905,909,913-914,916,918-920,922,924-925,927,929-931,933-935,940-942,945-946,948,951-955,959,962,966,968-969,971-972,974,976-977,979-981,983,988,990-991,993-995,997-998,1002,1004-1007,1009-1012,1017,1020-1022,1024,1026,1028-1031,1033-1037,1040-1041,1044,1046,1048,1050-1054,1057-1058,1061-1064,1066,1068-1075,1080,1082,1084-1085,1088-1091,1093,1095,1097-1104,1108-1113,1119,1122,1124,1126-1129,1132,1134,1137-1140,1142-1146,1152,1154,1156-1159,1164,1166,1168-1169,1172-1174,1176,1178,1180-1183,1186-1187,1189-1190,1192,1194-1195,1197-1203,1207-1208,1212-1213,1215-1216,1224-1225,1228-1229,1231,1233-1240,1244,1248-1249,1253-1257,1262-1263,1265 +src/DescriptorsCommand.cc 670 107 16% 56,63-68,73,75-79,81-85,87,89,92-93,96-98,101,103-104,107-113,115,118,121,168-170,174,188-192,232-243,253,267-269,273,288-297,299,311-314,319,324-328,345,353,355-358,361,364-365,367-371,374,377,379-380,383-385,387-388,390-391,394-395,397-399,405,410-411,413,415-416,419,421-424,427-432,438-439,442,444-445,447,449-450,453,456,458,460,463,465-470,472-476,478,481,484-485,488-489,500,503,508,514-515,520-521,523-526,529,535-536,541-542,545-551,554,556-557,559,561-562,564-565,569,571-576,579-584,587,590-595,597,623-624,627-628,631,634,641-642,644-646,648,651-652,654,657,660,662-664,672,676,678,681,683-686,689-694,696-700,702,705,707,709,712,715-716,718,720,722-727,730,732-734,737-738,740,742,744-745,747,749,751-753,755,757-762,767-768,771,773,775,782-783,786,790,792,795,797-800,803-807,809-813,815,817-820,822-824,827,831-840,845-846,849,855,857,860-861,864,868,873,875,878,881-882,886,891,893,895,898,901-902,904-907,911,915-916,918,920-922,924,926-927,929,931-933,935-937,942-944,947-948,950,953-957,961,964,968,970-971,973-974,976,978-979,981-983,985,990,992-993,995-997,999-1000,1004,1006-1009,1011-1014,1019,1022-1024,1026,1028,1030-1033,1035-1039,1042-1043,1046,1048,1050,1052-1056,1059-1060,1063-1066,1068,1070-1077,1082,1084,1086-1087,1090-1093,1095,1097,1099-1106,1110-1115,1121,1124,1126,1128-1131,1134,1136,1139-1142,1144-1148,1154,1156,1158-1161,1166,1168,1170-1171,1174-1176,1178,1180,1182-1185,1188-1189,1191-1192,1194,1196-1197,1199-1205,1209-1210,1214-1215,1217-1218,1226-1227,1230-1231,1233,1235-1242,1246,1250-1251,1255-1259,1264-1265,1267 src/DescriptorsManager.cc 24 19 79% 49-50,57-58,73 src/ExceptionsCommand.cc 6 0 0% 35-40 src/ImageCommand.cc 322 157 48% 55,59,63,65,67-69,71,73-76,78,81,86,88-89,97,99,106,109,111-112,114-115,117-118,120-121,124,151,162-163,174-175,177,182-185,195-196,198,203-206,221-229,231-233,246-247,257-267,269-270,272-273,278,286,297,304,308,311,313,315,337,339-340,343-348,350,352,374-376,379-381,385-388,394,396,403-406,420,427,433-436,440-441,452-455,458-463,468-470,481-484,489-493,498-499,501-502,504-508,511,513-517,520-523,526-527,530,532,537 @@ -31,10 +31,10 @@ src/RSCommand.cc 144 105 72% 65-67,73-74,98 src/SearchExpression.cc 99 38 38% 59,132-133,135,137-139,143,146,148-153,157,160,168-170,177,180-181,183-185,188,192-195,197,201,217-222,224-225,227,235-240,243,247-249,252-256,263,276,284-285 src/Server.cc 138 0 0% 57-58,60,64,68-70,72-73,77-78,80,85,88,90,92,95,97-98,103-106,108-109,112,116,118,122,125,128,131,133-134,136-138,140,142,145-152,154-155,159,162-167,170,172-173,176-177,181,183,185-186,188-190,192,194,197-199,203-206,208-209,212-216,218-219,221,223,225,228,231-232,236-237,239-240,242,246-250,253,255-257,260-263,265-268,272-273,276,280-281,284-287,292,294-301,304-309 src/vcl/CustomVCL.cc 51 22 43% 55,57-58,60-63,66,69-70,72,74,76-78,82-83,89-93,95,98-99,102,104-105,110 -src/vcl/DescriptorSet.cc 205 150 73% 64-65,67-68,89-90,111-112,129,131,133,186-189,192,217-218,220-221,224-227,236-240,252,264,315-316,319,321-324,327,341-342,344-346,350-352,354,361-363,365-366,369-370 +src/vcl/DescriptorSet.cc 209 155 74% 65,69-70,93-94,115-116,133,135,137,190-193,196,221-222,224-225,228-231,240-244,256,268,319-320,323,325-328,331,345-346,348-350,354-356,358,365-367,369-370,373-374 src/vcl/DescriptorSetData.cc 55 47 85% 48,58,64,67,114,116-118 src/vcl/Exception.cc 7 6 85% 38 -src/vcl/FaissDescriptorSet.cc 182 157 86% 83,115-116,132,167,187-188,204-205,224-225,238-239,245,258-259,261,272-273,279,299-300,302-303,305 +src/vcl/FaissDescriptorSet.cc 206 177 85% 83,115-116,132,167,187-188,204-205,224-225,238-239,245,258-259,261,272-273,279,303-304,306-307,309,372-373,379,397 src/vcl/FlinngDescriptorSet.cc 149 109 73% 60-66,89,109-111,113-114,118-121,124,126,128,130,132,134-137,140-141,143-144,170-171,176-177,182,206,208,228,248,279 src/vcl/Image.cc 910 689 75% 62,73-74,76-78,81-84,86,92,101,122-123,125,132-133,135,147,165,170,193,196-199,223,246,249-252,264,273,276-279,291,323,326-329,341,347,349-352,360-362,369,393-396,415,417,425,427,432,436,441,445,459,462,467-468,471-472,474,490,500,513,531,553-556,594,605-606,608,615,619,624,627-630,658-660,712,757-758,809,838-842,844-850,852,854-855,896,899-900,939-940,944-945,966,985-986,988,1028-1030,1032-1036,1038-1042,1044-1048,1050-1054,1056-1060,1062-1065,1088,1109,1128-1136,1147-1148,1167-1186,1198-1199,1207,1218,1220-1222,1224-1226,1228,1242,1246-1247,1249,1254-1255,1257,1278,1282,1285,1292,1307,1313,1322,1336,1361,1379,1462,1481 src/vcl/KeyFrame.cc 303 244 80% 58,62,86,90,95,97,102,105-107,109-111,113,119,139,148,154,172,186,190,216,220,224,235,239,249,255,274,284,288,307,315,341,345,347,359,367,369,394,396,405,430,442,449,465,469,478,483,495,500,507,514,518,525,541,547,557,563 @@ -57,5 +57,5 @@ utils/src/comm/Exception.cc 6 0 0% 35-40 utils/src/stats/SystemStats.cc 250 249 99% 453 utils/src/timers/TimerMap.cc 82 75 91% 126,151,153,155-158 ------------------------------------------------------------------------------ -TOTAL 10142 6488 64% +TOTAL 10172 6513 64% ------------------------------------------------------------------------------ diff --git a/.github/coverage/cpp.develop.coverage_value.txt b/.github/coverage/cpp.develop.coverage_value.txt index 50b05238..e525df55 100644 --- a/.github/coverage/cpp.develop.coverage_value.txt +++ b/.github/coverage/cpp.develop.coverage_value.txt @@ -1 +1 @@ -63.9716 +64.0287 diff --git a/client/cpp/DescriptorSetQueryParser.h b/client/cpp/DescriptorSetQueryParser.h index 27640b14..dbe2b554 100644 --- a/client/cpp/DescriptorSetQueryParser.h +++ b/client/cpp/DescriptorSetQueryParser.h @@ -50,5 +50,6 @@ bool VDMS::DescriptorSetQueryParser::isValidMetric(string &metric) { bool VDMS::DescriptorSetQueryParser::isValidEngine(string &engine) { return (engine == "TileDBDense" || engine == "TileDBSparse" || - engine == "FaissFlat" || engine == "FaissIVFFlat"); + engine == "FaissFlat" || engine == "FaissIVFFlat" || + engine == "Flinng" || engine == "FaissHNSWFlat"); } diff --git a/include/vcl/DescriptorSet.h b/include/vcl/DescriptorSet.h index be4cfb5f..c35820a0 100644 --- a/include/vcl/DescriptorSet.h +++ b/include/vcl/DescriptorSet.h @@ -53,7 +53,8 @@ enum DescriptorSetEngine { FaissIVFFlat, TileDBDense, TileDBSparse, - Flinng + Flinng, + FaissHNSWFlat }; enum DistanceMetric { L2, IP }; diff --git a/src/DescriptorsCommand.cc b/src/DescriptorsCommand.cc index df92b699..23907185 100644 --- a/src/DescriptorsCommand.cc +++ b/src/DescriptorsCommand.cc @@ -293,6 +293,8 @@ Json::Value AddDescriptorSet::construct_responses( _eng = VCL::TileDBSparse; else if (eng_str == "Flinng") _eng = VCL::Flinng; + else if (eng_str == "FaissHNSWFlat") + _eng = VCL::FaissHNSWFlat; else throw ExceptionCommand(DescriptorSetError, "Engine not supported"); diff --git a/src/vcl/DescriptorSet.cc b/src/vcl/DescriptorSet.cc index 72a62a5b..e8f5cadd 100644 --- a/src/vcl/DescriptorSet.cc +++ b/src/vcl/DescriptorSet.cc @@ -63,6 +63,8 @@ DescriptorSet::DescriptorSet(const std::string &set_path) { _set = new TDBSparseDescriptorSet(set_path); else if (_eng == DescriptorSetEngine(Flinng)) _set = new FlinngDescriptorSet(set_path); + else if (_eng == DescriptorSetEngine(FaissHNSWFlat)) + _set = new FaissHNSWFlatDescriptorSet(set_path); else { std::cerr << "Index Not supported" << std::endl; throw VCLException(UnsupportedIndex, "Index not supported"); @@ -85,6 +87,8 @@ DescriptorSet::DescriptorSet(const std::string &set_path, unsigned dim, _set = new TDBSparseDescriptorSet(set_path, dim, metric); else if (eng == DescriptorSetEngine(Flinng)) _set = new FlinngDescriptorSet(set_path, dim, metric, param); + else if (eng == DescriptorSetEngine(FaissHNSWFlat)) + _set = new FaissHNSWFlatDescriptorSet(set_path, dim, metric); else { std::cerr << "Index Not supported" << std::endl; throw VCLException(UnsupportedIndex, "Index not supported"); diff --git a/src/vcl/FaissDescriptorSet.cc b/src/vcl/FaissDescriptorSet.cc index 22c1f33f..4adb5e52 100644 --- a/src/vcl/FaissDescriptorSet.cc +++ b/src/vcl/FaissDescriptorSet.cc @@ -289,7 +289,11 @@ FaissIVFFlatDescriptorSet::FaissIVFFlatDescriptorSet( // TODO: Revise nlist param for future optimizations. // 4 is a suggested value by faiss for the IVFFlat index, // that's why we leave it for now. - int nlist = 4; + // int nlist = 4; + + // default value of 4 is too low for any sizeable dataset + + int nlist = 16; if (metric == L2) { faiss::IndexFlatL2 *quantizer = new faiss::IndexFlatL2(_dimensions); @@ -339,3 +343,72 @@ long FaissIVFFlatDescriptorSet::add(float *descriptors, unsigned n, return id_first; } + +// FaissHNSWFlat +// Note: +// setting value of hnsw m= 48 +// M is number of connections each vertex will have +// i.e. number of nearest neighbors that each vertex will connect to. +// Total memory usage is (dim * 4 + M * 2 * 4) bytes per vector. + +// Rule of thumb to support 1M entries +// according to +// https://github.com/facebookresearch/faiss/wiki/Indexing-1M-vectors HNSW_M= 48 +// efConstruction = 2 x m +// efSearch = 32 +// Default values from hnsw.h +// int efConstruction = 40; +// int efSearch = 16; +// efsearch will be defined in the search params when IndexHNSW::search() is +// called (can dynamically change during runtime) efconstruction is a compile +// time paramerter will be defined in IndexHNSW.hnsw struct + +FaissHNSWFlatDescriptorSet::FaissHNSWFlatDescriptorSet( + const std::string &set_path) + : FaissDescriptorSet(set_path) { + try { + _index = faiss::read_index(_faiss_file.c_str()); + + } catch (faiss::FaissException &e) { + throw VCLException(OpenFailed, "Problem reading: " + _faiss_file); + } + + // Faiss will sometimes throw, or sometimes set _index = NULL, + // we check both just in case. + if (!_index) { + throw VCLException(OpenFailed, "Problem reading: " + _faiss_file); + } + + _dimensions = _index->d; + _n_total = _index->ntotal; +} + +FaissHNSWFlatDescriptorSet::FaissHNSWFlatDescriptorSet( + const std::string &set_path, unsigned dim, DistanceMetric metric) + : FaissDescriptorSet(set_path, dim) { + + int hnsw_M = 48; + if (metric == L2) { + _index = new faiss::IndexHNSWFlat(dim, hnsw_M, faiss::METRIC_L2); + ((faiss::IndexHNSWFlat *)_index)->hnsw.efConstruction = 96; + } else { + // only metric L2 is supported for HNSWFLAT for FAISS v1.7.4 + // newer version of Faiss e.g. V1.8.0 supports I.P. metric for HNSW + throw VCLException(UnsupportedIndex, "Metric Not implemented"); + } +} + +void FaissHNSWFlatDescriptorSet::search(float *query, unsigned n_queries, + unsigned k, long *descriptors, + float *distances) { + ((faiss::IndexHNSWFlat *)_index)->hnsw.efSearch = 64; + // set according to + // https://github.com/facebookresearch/faiss/wiki/Indexing-1M-vectors for R@1 + // accuracy of 0.9779 + // The higher the value the slower the search is but better accuracy + // efsearch is a runtime parameter. + // ToDO - VDMS should expose an API to set runtime parameters to users of the + // different indices + + _index->search(n_queries, query, k, distances, descriptors); +} diff --git a/src/vcl/FaissDescriptorSet.h b/src/vcl/FaissDescriptorSet.h index d12badcf..d484969e 100644 --- a/src/vcl/FaissDescriptorSet.h +++ b/src/vcl/FaissDescriptorSet.h @@ -44,6 +44,7 @@ #include "DescriptorSetData.h" #include +#include #include namespace VCL { @@ -109,4 +110,16 @@ class FaissIVFFlatDescriptorSet : public FaissDescriptorSet { long add(float *descriptors, unsigned n_descriptors, long *classes); }; + +class FaissHNSWFlatDescriptorSet : public FaissDescriptorSet { + +public: + FaissHNSWFlatDescriptorSet(const std::string &set_path); + FaissHNSWFlatDescriptorSet(const std::string &set_path, unsigned dim, + DistanceMetric metric); + + void search(float *query, unsigned n_queries, unsigned k, long *descriptors, + float *distances); +}; + }; // namespace VCL diff --git a/tests/unit_tests/DescriptorSetAdd_test.cc b/tests/unit_tests/DescriptorSetAdd_test.cc index 958aafbc..a90ebac5 100644 --- a/tests/unit_tests/DescriptorSetAdd_test.cc +++ b/tests/unit_tests/DescriptorSetAdd_test.cc @@ -218,6 +218,131 @@ TEST(Descriptors_Add, add_flatl2_100d_2add) { delete[] xb; } +// HNSW Tests +TEST(Descriptors_Add, add_hnswflatl2_100d) { + + // test to add 1K descriptors of 100D each + // descriptors are created by incrementing a random initial value as follows + // init init ... init (D times) + // init+1 init+1 ... init+1 (D times) + // ... + // init+nb-1 init+nb-1 ... init+nb-1 (d times) + // hence, nearest neigbor of any query descriptor are the IDs that is around + // the query ID + + int d = 100; + int nb = 10000; + float *xb = generate_desc_linear_increase(d, nb); + + std::string index_filename = "dbs/add_hnswflatl2_100d"; + VCL::DescriptorSet index(index_filename, unsigned(d), VCL::FaissHNSWFlat); + + std::vector classes(nb); + + for (auto &str : classes) { + str = 1; + } + + index.add(xb, nb, classes); + + std::vector distances; + std::vector desc_ids; + index.search(xb, 1, 4, desc_ids, distances); + + int exp = 0; + for (auto &desc : desc_ids) { + EXPECT_EQ(desc, exp++); + } + + float results[] = {float(std::pow(0, 2) * d), float(std::pow(1, 2) * d), + float(std::pow(2, 2) * d), float(std::pow(3, 2) * d)}; + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(distances[i], results[i]); + } + + index.store(); + + delete[] xb; +} + +TEST(Descriptors_Add, add_recons_hnswflatl2_100d) { + + // test to add 1K descriptors of 100D each + // Same as last previous test case but addes classes as labels + // classes will be searched and checked of nearest neighbors + + int d = 100; + int nb = 10000; + float *xb = generate_desc_linear_increase(d, nb); + + std::string index_filename = "dbs/add_recons_hnswflatl2_100d"; + VCL::DescriptorSet index(index_filename, unsigned(d), VCL::FaissHNSWFlat); + + std::vector classes(nb); + + for (auto &cl : classes) { + cl = 1; + } + + index.add(xb, nb, classes); + + std::vector distances; + std::vector desc_ids; + index.search(xb, 1, 4, desc_ids, distances); + desc_ids.clear(); + + float *recons = new float[d * nb]; + for (int i = 0; i < nb; ++i) { + desc_ids.push_back(i); + } + + index.get_descriptors(desc_ids, recons); + + for (int i = 0; i < nb * d; ++i) { + EXPECT_EQ(xb[i], recons[i]); + } + + index.store(); + + delete[] xb; +} + +TEST(Descriptors_Add, add_hnswflatl2_100d_2add) { + // test to add 2K descriptors of 100D each + // this is done in 2 steps + // first 1K and then the index is stored to a file + // second 1K are added after the index is read from a file + // the test case is to test file i/o of the index + + int d = 100; + int nb = 10000; + float *xb = generate_desc_linear_increase(d, nb); + + std::string index_filename = "dbs/add_hnswflatl2_100d_2add"; + VCL::DescriptorSet index(index_filename, unsigned(d), VCL::FaissHNSWFlat); + + index.add(xb, nb); + + generate_desc_linear_increase(d, nb, xb, .6); + + index.add(xb, nb); + + generate_desc_linear_increase(d, 4, xb, 0); + + std::vector distances; + std::vector desc_ids; + index.search(xb, 1, 4, desc_ids, distances); + + float results[] = {float(std::pow(0, 2) * d), float(std::pow(.6, 2) * d), + float(std::pow(1, 2) * d), float(std::pow(1.6, 2) * d)}; + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(std::round(distances[i]), std::round(results[i])); + } + + index.store(); + delete[] xb; +} + // Flinng Tests TEST(Descriptors_Add, add_flinngIP_100d) { diff --git a/tests/unit_tests/helpers.cc b/tests/unit_tests/helpers.cc index 2644019c..bb4ffdb3 100644 --- a/tests/unit_tests/helpers.cc +++ b/tests/unit_tests/helpers.cc @@ -334,6 +334,7 @@ std::vector get_engines() { std::vector engs; engs.push_back(VCL::FaissFlat); engs.push_back(VCL::FaissIVFFlat); + engs.push_back(VCL::FaissHNSWFlat); engs.push_back(VCL::TileDBDense); engs.push_back(VCL::TileDBSparse); // engs.push_back(VCL::Flinng);