diff --git a/docs/changelog/117994.yaml b/docs/changelog/117994.yaml new file mode 100644 index 0000000000000..603f2d855a11a --- /dev/null +++ b/docs/changelog/117994.yaml @@ -0,0 +1,5 @@ +pr: 117994 +summary: Even better(er) binary quantization +area: Vector Search +type: enhancement +issues: [] diff --git a/rest-api-spec/build.gradle b/rest-api-spec/build.gradle index e2af894eb0939..7347d9c1312dd 100644 --- a/rest-api-spec/build.gradle +++ b/rest-api-spec/build.gradle @@ -67,4 +67,6 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task -> task.skipTest("logsdb/20_source_mapping/include/exclude is supported with stored _source", "no longer serialize source_mode") task.skipTest("logsdb/20_source_mapping/synthetic _source is default", "no longer serialize source_mode") task.skipTest("search/520_fetch_fields/fetch _seq_no via fields", "error code is changed from 5xx to 400 in 9.0") + task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions") + task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions") }) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 188c155e4a836..5767c895fbe7e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -11,20 +11,11 @@ setup: number_of_shards: 1 mappings: properties: - name: - type: keyword vector: type: dense_vector dims: 64 index: true - similarity: l2_norm - index_options: - type: bbq_hnsw - another_vector: - type: dense_vector - dims: 64 - index: true - similarity: l2_norm + similarity: max_inner_product index_options: type: bbq_hnsw @@ -33,9 +24,14 @@ setup: index: bbq_hnsw id: "1" body: - name: cow.jpg - vector: [300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0] - another_vector: [115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0] + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] # Flush in order to provoke a merge later - do: indices.flush: @@ -46,9 +42,14 @@ setup: index: bbq_hnsw id: "2" body: - name: moose.jpg - vector: [100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0] - another_vector: [50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120] + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] # Flush in order to provoke a merge later - do: indices.flush: @@ -60,8 +61,14 @@ setup: id: "3" body: name: rabbit.jpg - vector: [111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0] - another_vector: [11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0] + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] # Flush in order to provoke a merge later - do: indices.flush: @@ -73,20 +80,33 @@ setup: max_num_segments: 1 --- "Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" - do: search: index: bbq_hnsw body: knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] k: 3 num_candidates: 3 - # Depending on how things are distributed, docs 2 and 3 might be swapped - # here we verify that are last hit is always the worst one - - match: { hits.hits.2._id: "1" } - + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } --- "Test bad quantization parameters": - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index ed7a8dd5df65d..dcdae04aeabb4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -11,20 +11,11 @@ setup: number_of_shards: 1 mappings: properties: - name: - type: keyword vector: type: dense_vector dims: 64 index: true - similarity: l2_norm - index_options: - type: bbq_flat - another_vector: - type: dense_vector - dims: 64 - index: true - similarity: l2_norm + similarity: max_inner_product index_options: type: bbq_flat @@ -33,9 +24,14 @@ setup: index: bbq_flat id: "1" body: - name: cow.jpg - vector: [300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0] - another_vector: [115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0] + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] # Flush in order to provoke a merge later - do: indices.flush: @@ -46,9 +42,14 @@ setup: index: bbq_flat id: "2" body: - name: moose.jpg - vector: [100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0] - another_vector: [50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120] + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] # Flush in order to provoke a merge later - do: indices.flush: @@ -59,9 +60,14 @@ setup: index: bbq_flat id: "3" body: - name: rabbit.jpg - vector: [111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0] - another_vector: [11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0] + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] # Flush in order to provoke a merge later - do: indices.flush: @@ -73,19 +79,33 @@ setup: max_num_segments: 1 --- "Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" - do: search: index: bbq_flat body: knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] k: 3 num_candidates: 3 - # Depending on how things are distributed, docs 2 and 3 might be swapped - # here we verify that are last hit is always the worst one - - match: { hits.hits.2._id: "1" } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } --- "Test bad parameters": - do: diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 331a2bc0dddac..ff902dbede007 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -459,7 +459,9 @@ org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat, org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat, org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat, - org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat; + org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; provides org.apache.lucene.codecs.Codec with diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java index 5201e57179cc7..1aff06a175967 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java @@ -40,6 +40,20 @@ public static boolean isUnitVector(float[] v) { return Math.abs(l1norm - 1.0d) <= EPSILON; } + public static void packAsBinary(byte[] vector, byte[] packed) { + for (int i = 0; i < vector.length;) { + byte result = 0; + for (int j = 7; j >= 0 && i < vector.length; j--) { + assert vector[i] == 0 || vector[i] == 1; + result |= (byte) ((vector[i] & 1) << j); + ++i; + } + int index = ((i + 7) / 8) - 1; + assert index < packed.length; + packed[index] = result; + } + } + public static int discretize(int value, int bucket) { return ((value + (bucket - 1)) / bucket) * bucket; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java index 445bdadab2354..e85079e998c61 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java @@ -48,13 +48,8 @@ class ES816BinaryFlatVectorsScorer implements FlatVectorsScorer { public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues - ) throws IOException { - if (vectorValues instanceof BinarizedByteVectorValues) { - throw new UnsupportedOperationException( - "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format" - ); - } - return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + ) { + throw new UnsupportedOperationException(); } @Override @@ -90,61 +85,11 @@ public RandomVectorScorer getRandomVectorScorer( return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } - RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, - ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, - BinarizedByteVectorValues targetVectors - ) { - return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction); - } - @Override public String toString() { return "ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; } - /** Vector scorer supplier over binarized vector values */ - static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; - private final BinarizedByteVectorValues targetVectors; - private final VectorSimilarityFunction similarityFunction; - - BinarizedRandomVectorScorerSupplier( - ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, - BinarizedByteVectorValues targetVectors, - VectorSimilarityFunction similarityFunction - ) { - this.queryVectors = queryVectors; - this.targetVectors = targetVectors; - this.similarityFunction = similarityFunction; - } - - @Override - public RandomVectorScorer scorer(int ord) throws IOException { - byte[] vector = queryVectors.vectorValue(ord); - int quantizedSum = queryVectors.sumQuantizedValues(ord); - float distanceToCentroid = queryVectors.getCentroidDistance(ord); - float lower = queryVectors.getLower(ord); - float width = queryVectors.getWidth(ord); - float normVmC = 0f; - float vDotC = 0f; - if (similarityFunction != EUCLIDEAN) { - normVmC = queryVectors.getNormVmC(ord); - vDotC = queryVectors.getVDotC(ord); - } - BinaryQueryVector binaryQueryVector = new BinaryQueryVector( - vector, - new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC) - ); - return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); - } - - @Override - public RandomVectorScorerSupplier copy() throws IOException { - return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); - } - } - /** A binarized query representing its quantized form along with factors */ record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java index d864ec5dee8c5..61b6edc474d1f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java @@ -62,7 +62,7 @@ public ES816BinaryQuantizedVectorsFormat() { @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new ES816BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + throw new UnsupportedOperationException(); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java index 52f9f14b7bf97..1dbb4e432b188 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java @@ -25,10 +25,8 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.hnsw.HnswGraph; import java.io.IOException; @@ -52,21 +50,18 @@ public class ES816HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat { * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. */ - private final int maxConn; + protected final int maxConn; /** * The number of candidate neighbors to track while searching the graph for each newly inserted * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} * for details. */ - private final int beamWidth; + protected final int beamWidth; /** The format for storing, reading, merging vectors on disk */ private static final FlatVectorsFormat flatVectorsFormat = new ES816BinaryQuantizedVectorsFormat(); - private final int numMergeWorkers; - private final TaskExecutor mergeExec; - /** Constructs a format using default graph construction parameters */ public ES816HnswBinaryQuantizedVectorsFormat() { this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); @@ -109,17 +104,11 @@ public ES816HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int num if (numMergeWorkers == 1 && mergeExec != null) { throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); } - this.numMergeWorkers = numMergeWorkers; - if (mergeExec != null) { - this.mergeExec = new TaskExecutor(mergeExec); - } else { - this.mergeExec = null; - } } @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + throw new UnsupportedOperationException(); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java new file mode 100644 index 0000000000000..cc1f7b85e0f78 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java @@ -0,0 +1,87 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +abstract class BinarizedByteVectorValues extends ByteVectorValues { + + /** + * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of + * distances, the corrective terms are, in order + * + * + * + * For euclidean: + * + * + * + * @param vectorOrd the vector ordinal + * @return the corrective terms + * @throws IOException if an I/O error occurs + */ + public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd) throws IOException; + + /** + * @return the quantizer used to quantize the vectors + */ + public abstract OptimizedScalarQuantizer getQuantizer(); + + public abstract float[] getCentroid() throws IOException; + + int discretizedDimensions() { + return BQVectorUtils.discretize(dimension(), 64); + } + + /** + * Return a {@link VectorScorer} for the given query vector. + * + * @param query the query vector + * @return a {@link VectorScorer} instance or null + */ + public abstract VectorScorer scorer(float[] query) throws IOException; + + @Override + public abstract BinarizedByteVectorValues copy() throws IOException; + + float getCentroidDP() throws IOException { + // this only gets executed on-merge + float[] centroid = getCentroid(); + return VectorUtil.dotProduct(centroid, centroid); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java new file mode 100644 index 0000000000000..7c7e470909eb3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -0,0 +1,188 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.simdvec.ESVectorUtil; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Vector scorer over binarized vector values */ +public class ES818BinaryFlatVectorsScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + + public ES818BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues) { + throw new UnsupportedOperationException( + "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format" + ); + } + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + float[] target + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) { + OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer(); + float[] centroid = binarizedVectors.getCentroid(); + // We make a copy as the quantization process mutates the input + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + if (similarityFunction == COSINE) { + VectorUtil.l2normalize(copy); + } + target = copy; + byte[] initial = new byte[target.length]; + byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8]; + OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(target, initial, (byte) 4, centroid); + BQSpaceUtils.transposeHalfByte(initial, quantized); + BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections); + return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction); + } + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + byte[] target + ) throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, + BinarizedByteVectorValues targetVectors + ) { + return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction); + } + + @Override + public String toString() { + return "ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + } + + /** Vector scorer supplier over binarized vector values */ + static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + private final ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + BinarizedRandomVectorScorerSupplier( + ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + this.queryVectors = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] vector = queryVectors.vectorValue(ord); + OptimizedScalarQuantizer.QuantizationResult correctiveTerms = queryVectors.getCorrectiveTerms(ord); + BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms); + return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); + } + } + + /** A binarized query representing its quantized form along with factors */ + public record BinaryQueryVector(byte[] vector, OptimizedScalarQuantizer.QuantizationResult quantizationResult) {} + + /** Vector scorer over binarized vector values */ + public static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + private final BinaryQueryVector queryVector; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + public BinarizedRandomVectorScorer( + BinaryQueryVector queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + super(targetVectors); + this.queryVector = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public float score(int targetOrd) throws IOException { + byte[] quantizedQuery = queryVector.vector(); + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + OptimizedScalarQuantizer.QuantizationResult queryCorrections = queryVector.quantizationResult(); + OptimizedScalarQuantizer.QuantizationResult indexCorrections = targetVectors.getCorrectiveTerms(targetOrd); + float x1 = indexCorrections.quantizedComponentSum(); + float ax = indexCorrections.lowerInterval(); + // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + float lx = indexCorrections.upperInterval() - ax; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + float score = ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist; + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. + if (similarityFunction == EUCLIDEAN) { + score = queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - 2 * score; + return Math.max(1 / (1f + score), 0); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + score += queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - targetVectors.getCentroidDP(); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java new file mode 100644 index 0000000000000..1dee9599f985f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java @@ -0,0 +1,132 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; + +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + * Codec for encoding/decoding binary quantized vectors The binary quantization format used here + * is a per-vector optimized scalar quantization. Also see {@link + * org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer}. Some of key features are: + * + * + * + * The format is stored in two files: + * + *

.veb (vector data) file

+ * + *

Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's + * corrective factors. At the end of the file, additional information is stored for vector ordinal + * to centroid ordinal mapping and sparse vector information. + * + *

+ * + *

.vemb (vector metadata) file

+ * + *

Stores the metadata for the vectors. This includes the number of vectors, the number of + * dimensions, and file offset information. + * + *

+ */ +public class ES818BinaryQuantizedVectorsFormat extends FlatVectorsFormat { + + public static final String BINARIZED_VECTOR_COMPONENT = "BVEC"; + public static final String NAME = "ES818BinaryQuantizedVectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "ES818BinaryQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "ES818BinaryQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemb"; + static final String VECTOR_DATA_EXTENSION = "veb"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + /** Creates a new instance with the default number of vectors per cluster. */ + public ES818BinaryQuantizedVectorsFormat() { + super(NAME); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES818BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + + @Override + public String toString() { + return "ES818BinaryQuantizedVectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java new file mode 100644 index 0000000000000..8036b8314cdc1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java @@ -0,0 +1,412 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.SuppressForbidden; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +@SuppressForbidden(reason = "Lucene classes") +class ES818BinaryQuantizedVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES818BinaryQuantizedVectorsReader.class); + + private final Map fields = new HashMap<>(); + private final IndexInput quantizedVectorData; + private final FlatVectorsReader rawVectorsReader; + private final ES818BinaryFlatVectorsScorer vectorScorer; + + ES818BinaryQuantizedVectorsReader( + SegmentReadState state, + FlatVectorsReader rawVectorsReader, + ES818BinaryFlatVectorsScorer vectorsScorer + ) throws IOException { + super(vectorsScorer); + this.vectorScorer = vectorsScorer; + this.rawVectorsReader = rawVectorsReader; + int versionMeta = -1; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.META_EXTENSION + ); + boolean success = false; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + meta, + ES818BinaryQuantizedVectorsFormat.META_CODEC_NAME, + ES818BinaryQuantizedVectorsFormat.VERSION_START, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + quantizedVectorData = openDataInput( + state, + versionMeta, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + // Quantized vectors are accessed randomly from their node ID stored in the HNSW + // graph. + state.context.withReadAdvice(ReadAdvice.RANDOM) + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta, info); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + int dimension = info.getVectorDimension(); + if (dimension != fieldEntry.dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + info.name + "\"; " + dimension + " != " + fieldEntry.dimension + ); + } + + int binaryDims = BQVectorUtils.discretize(dimension, 64) / 8; + long numQuantizedVectorBytes = Math.multiplyExact((binaryDims + (Float.BYTES * 3) + Short.BYTES), (long) fieldEntry.size); + if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { + throw new IllegalStateException( + "Binarized vector data length " + + fieldEntry.vectorDataLength + + " not matching size = " + + fieldEntry.size + + " * (binaryBytes=" + + binaryDims + + " + 14" + + ") = " + + numQuantizedVectorBytes + ); + } + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + return vectorScorer.getRandomVectorScorer( + fi.similarityFunction, + OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData + ), + target + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(quantizedVectorData); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + field + "\" is encoded as: " + fi.vectorEncoding + " expected: " + VectorEncoding.FLOAT32 + ); + } + OffHeapBinarizedVectorValues bvv = OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData + ); + return new BinarizedVectorValues(rawVectorsReader.getFloatVectorValues(field), bvv); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + } + + @Override + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + if (knnCollector.k() == 0) return; + final RandomVectorScorer scorer = getRandomVectorScorer(field, target); + if (scorer == null) return; + OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); + for (int i = 0; i < scorer.maxOrd(); i++) { + if (acceptedOrds == null || acceptedOrds.get(i)) { + collector.collect(i, scorer.score(i)); + collector.incVisitedCount(1); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(quantizedVectorData, rawVectorsReader); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOfMap(fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + size += rawVectorsReader.ramBytesUsed(); + return size; + } + + public float[] getCentroid(String field) { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry != null) { + return fieldEntry.centroid; + } + return null; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context + ) throws IOException { + String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + int versionVectorData = CodecUtil.checkIndexHeader( + in, + codecName, + ES818BinaryQuantizedVectorsFormat.VERSION_START, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, + in + ); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + VectorEncoding vectorEncoding = readVectorEncoding(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction() + ); + } + return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction()); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + int dimension, + int descritizedDimension, + long vectorDataOffset, + long vectorDataLength, + int size, + float[] centroid, + float centroidDP, + OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration + ) { + + static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) + throws IOException { + int dimension = input.readVInt(); + long vectorDataOffset = input.readVLong(); + long vectorDataLength = input.readVLong(); + int size = input.readVInt(); + final float[] centroid; + float centroidDP = 0; + if (size > 0) { + centroid = new float[dimension]; + input.readFloats(centroid, 0, dimension); + centroidDP = Float.intBitsToFloat(input.readInt()); + } else { + centroid = null; + } + OrdToDocDISIReaderConfiguration conf = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry( + similarityFunction, + vectorEncoding, + dimension, + BQVectorUtils.discretize(dimension, 64), + vectorDataOffset, + vectorDataLength, + size, + centroid, + centroidDP, + conf + ); + } + } + + /** Binarized vector values holding row and quantized vector values */ + protected static final class BinarizedVectorValues extends FloatVectorValues { + private final FloatVectorValues rawVectorValues; + private final BinarizedByteVectorValues quantizedVectorValues; + + BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) { + this.rawVectorValues = rawVectorValues; + this.quantizedVectorValues = quantizedVectorValues; + } + + @Override + public int dimension() { + return rawVectorValues.dimension(); + } + + @Override + public int size() { + return rawVectorValues.size(); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + return rawVectorValues.vectorValue(ord); + } + + @Override + public BinarizedVectorValues copy() throws IOException { + return new BinarizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return rawVectorValues.getAcceptOrds(acceptDocs); + } + + @Override + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return rawVectorValues.iterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.scorer(query); + } + + BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { + return quantizedVectorValues; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java new file mode 100644 index 0000000000000..02dda6a4a9da1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java @@ -0,0 +1,944 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.FloatArrayList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.BINARIZED_VECTOR_COMPONENT; +import static org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +@SuppressForbidden(reason = "Lucene classes") +public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = shallowSizeOfInstance(ES818BinaryQuantizedVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final List fields = new ArrayList<>(); + private final IndexOutput meta, binarizedVectorData; + private final FlatVectorsWriter rawVectorDelegate; + private final ES818BinaryFlatVectorsScorer vectorsScorer; + private boolean finished; + + /** + * Sole constructor + * + * @param vectorsScorer the scorer to use for scoring vectors + */ + protected ES818BinaryQuantizedVectorsWriter( + ES818BinaryFlatVectorsScorer vectorsScorer, + FlatVectorsWriter rawVectorDelegate, + SegmentWriteState state + ) throws IOException { + super(vectorsScorer); + this.vectorsScorer = vectorsScorer; + this.segmentWriteState = state; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES818BinaryQuantizedVectorsFormat.META_EXTENSION + ); + + String binarizedVectorDataFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION + ); + this.rawVectorDelegate = rawVectorDelegate; + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + binarizedVectorData = state.directory.createOutput(binarizedVectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + ES818BinaryQuantizedVectorsFormat.META_CODEC_NAME, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + CodecUtil.writeIndexHeader( + binarizedVectorData, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + FieldWriter fieldWriter = new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); + fields.add(fieldWriter); + return fieldWriter; + } + return rawVectorDelegate; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + // after raw vectors are written, normalize vectors for clustering and quantization + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + field.normalizeVectors(); + } + final float[] clusterCenter; + int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); + } + } + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction()); + if (sortMap == null) { + writeField(field, clusterCenter, maxDoc, quantizer); + } else { + writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer); + } + field.finish(); + } + } + + private void writeField(FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer) + throws IOException { + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeBinarizedVectors(fieldData, clusterCenter, quantizer); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = fieldData.getVectors().size() > 0 ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + clusterCenter, + centroidDp, + fieldData.getDocsWithFieldSet() + ); + } + + private void writeBinarizedVectors(FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); + byte[] quantizationScratch = new byte[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int i = 0; i < fieldData.getVectors().size(); i++) { + float[] v = fieldData.getVectors().get(i); + OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize( + v, + quantizationScratch, + (byte) 1, + clusterCenter + ); + BQVectorUtils.packAsBinary(quantizationScratch, vector); + binarizedVectorData.writeBytes(vector, vector.length); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); + } + } + + private void writeSortingField( + FieldWriter fieldData, + float[] clusterCenter, + int maxDoc, + Sorter.DocMap sortMap, + OptimizedScalarQuantizer scalarQuantizer + ) throws IOException { + final int[] ordMap = new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer); + long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + + float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, quantizedVectorLength, clusterCenter, centroidDp, newDocsWithField); + } + + private void writeSortedBinarizedVectors( + FieldWriter fieldData, + float[] clusterCenter, + int[] ordMap, + OptimizedScalarQuantizer scalarQuantizer + ) throws IOException { + int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); + byte[] quantizationScratch = new byte[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int ordinal : ordMap) { + float[] v = fieldData.getVectors().get(ordinal); + OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize( + v, + quantizationScratch, + (byte) 1, + clusterCenter + ); + BQVectorUtils.packAsBinary(quantizationScratch, vector); + binarizedVectorData.writeBytes(vector, vector.length); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); + } + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long vectorDataOffset, + long vectorDataLength, + float[] clusterCenter, + float centroidDp, + DocsWithFieldSet docsWithField + ) throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVInt(field.getVectorDimension()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + int count = docsWithField.cardinality(); + meta.writeVInt(count); + if (count > 0) { + final ByteBuffer buffer = ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(clusterCenter); + meta.writeBytes(buffer.array(), buffer.array().length); + meta.writeInt(Float.floatToIntBits(centroidDp)); + } + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, + meta, + binarizedVectorData, + count, + maxDoc, + docsWithField + ); + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (binarizedVectorData != null) { + CodecUtil.writeFooter(binarizedVectorData); + } + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + BinarizedFloatVectorValues binarizedVectorValues = new BinarizedFloatVectorValues( + floatVectorValues, + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + centroid + ); + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + centroidDp, + docsWithField + ); + } else { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + } + } + + static DocsWithFieldSet writeBinarizedVectorAndQueryData( + IndexOutput binarizedVectorData, + IndexOutput binarizedQueryData, + FloatVectorValues floatVectorValues, + float[] centroid, + OptimizedScalarQuantizer binaryQuantizer + ) throws IOException { + int discretizedDimension = BQVectorUtils.discretize(floatVectorValues.dimension(), 64); + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + byte[][] quantizationScratch = new byte[2][floatVectorValues.dimension()]; + byte[] toIndex = new byte[discretizedDimension / 8]; + byte[] toQuery = new byte[(discretizedDimension / 8) * BQSpaceUtils.B_QUERY]; + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write index vector + OptimizedScalarQuantizer.QuantizationResult[] r = binaryQuantizer.multiScalarQuantize( + floatVectorValues.vectorValue(iterator.index()), + quantizationScratch, + new byte[] { 1, 4 }, + centroid + ); + // pack and store document bit vector + BQVectorUtils.packAsBinary(quantizationScratch[0], toIndex); + binarizedVectorData.writeBytes(toIndex, toIndex.length); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].additionalCorrection())); + assert r[0].quantizedComponentSum() >= 0 && r[0].quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) r[0].quantizedComponentSum()); + docsWithField.add(docV); + + // pack and store the 4bit query vector + BQSpaceUtils.transposeHalfByte(quantizationScratch[1], toQuery); + binarizedQueryData.writeBytes(toQuery, toQuery.length); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].lowerInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].upperInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].additionalCorrection())); + assert r[1].quantizedComponentSum() >= 0 && r[1].quantizedComponentSum() <= 0xffff; + binarizedQueryData.writeShort((short) r[1].quantizedComponentSum()); + } + return docsWithField; + } + + static DocsWithFieldSet writeBinarizedVectorData(IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues) + throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write vector + byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index()); + output.writeBytes(binaryValue, binaryValue.length); + OptimizedScalarQuantizer.QuantizationResult corrections = binarizedByteVectorValues.getCorrectiveTerms(iterator.index()); + output.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + output.writeInt(Float.floatToIntBits(corrections.upperInterval())); + output.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + output.writeShort((short) corrections.quantizedComponentSum()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float cDotC; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC); + } + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + float[] centroid, + float cDotC + ) throws IOException { + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + final IndexOutput tempQuantizedVectorData = segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), + "temp", + segmentWriteState.context + ); + final IndexOutput tempScoreQuantizedVectorData = segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), + "score_temp", + segmentWriteState.context + ); + IndexInput binarizedDataInput = null; + IndexInput binarizedScoreDataInput = null; + boolean success = false; + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + try { + FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + DocsWithFieldSet docsWithField = writeBinarizedVectorAndQueryData( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + floatVectorValues, + centroid, + quantizer + ); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + binarizedDataInput = segmentWriteState.directory.openInput(tempQuantizedVectorData.getName(), segmentWriteState.context); + binarizedVectorData.copyBytes(binarizedDataInput, binarizedDataInput.length() - CodecUtil.footerLength()); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + CodecUtil.retrieveChecksum(binarizedDataInput); + CodecUtil.writeFooter(tempScoreQuantizedVectorData); + IOUtils.close(tempScoreQuantizedVectorData); + binarizedScoreDataInput = segmentWriteState.directory.openInput( + tempScoreQuantizedVectorData.getName(), + segmentWriteState.context + ); + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + cDotC, + docsWithField + ); + success = true; + final IndexInput finalBinarizedDataInput = binarizedDataInput; + final IndexInput finalBinarizedScoreDataInput = binarizedScoreDataInput; + OffHeapBinarizedVectorValues vectorValues = new OffHeapBinarizedVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + centroid, + cDotC, + quantizer, + fieldInfo.getVectorSimilarityFunction(), + vectorsScorer, + finalBinarizedDataInput + ); + RandomVectorScorerSupplier scorerSupplier = vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapBinarizedQueryVectorValues( + finalBinarizedScoreDataInput, + fieldInfo.getVectorDimension(), + docsWithField.cardinality() + ), + vectorValues + ); + return new BinarizedCloseableRandomVectorScorerSupplier(scorerSupplier, vectorValues, () -> { + IOUtils.close(finalBinarizedDataInput, finalBinarizedScoreDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorData.getName(), + tempScoreQuantizedVectorData.getName() + ); + }); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + binarizedDataInput, + binarizedScoreDataInput + ); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorData.getName(), + tempScoreQuantizedVectorData.getName() + ); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, binarizedVectorData, rawVectorDelegate); + } + + static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof ES818BinaryQuantizedVectorsReader reader) { + return reader.getCentroid(fieldName); + } + return null; + } + + static int mergeAndRecalculateCentroids(MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException { + boolean recalculate = false; + int totalVectorCount = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) { + continue; + } + float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name); + int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size(); + if (vectorCount == 0) { + continue; + } + totalVectorCount += vectorCount; + // If there aren't centroids, or previously clustered with more than one cluster + // or if there are deleted docs, we must recalculate the centroid + if (centroid == null || mergeState.liveDocs[i] != null) { + recalculate = true; + break; + } + for (int j = 0; j < centroid.length; j++) { + mergedCentroid[j] += centroid[j] * vectorCount; + } + } + if (recalculate) { + return calculateCentroid(mergeState, fieldInfo, mergedCentroid); + } else { + for (int j = 0; j < mergedCentroid.length; j++) { + mergedCentroid[j] = mergedCentroid[j] / totalVectorCount; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(mergedCentroid); + } + return totalVectorCount; + } + } + + static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) throws IOException { + assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); + // clear out the centroid + Arrays.fill(centroid, 0); + int count = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null) continue; + FloatVectorValues vectorValues = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + continue; + } + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) { + ++count; + float[] vector = vectorValues.vectorValue(iterator.index()); + // TODO Panama sum + for (int j = 0; j < vector.length; j++) { + centroid[j] += vector[j]; + } + } + } + if (count == 0) { + return count; + } + // TODO Panama div + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= count; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(centroid); + } + return count; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } + return total; + } + + static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private boolean finished; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; + private final float[] dimensionSums; + private final FloatArrayList magnitudes = new FloatArrayList(); + + FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) { + this.fieldInfo = fieldInfo; + this.flatFieldVectorsWriter = flatFieldVectorsWriter; + this.dimensionSums = new float[fieldInfo.getVectorDimension()]; + } + + @Override + public List getVectors() { + return flatFieldVectorsWriter.getVectors(); + } + + public void normalizeVectors() { + for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) { + float[] vector = flatFieldVectorsWriter.getVectors().get(i); + float magnitude = magnitudes.get(i); + for (int j = 0; j < vector.length; j++) { + vector[j] /= magnitude; + } + } + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + assert flatFieldVectorsWriter.isFinished(); + finished = true; + } + + @Override + public boolean isFinished() { + return finished && flatFieldVectorsWriter.isFinished(); + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + flatFieldVectorsWriter.addValue(docID, vectorValue); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + float dp = VectorUtil.dotProduct(vectorValue, vectorValue); + float divisor = (float) Math.sqrt(dp); + magnitudes.add(divisor); + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += (vectorValue[i] / divisor); + } + } else { + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += vectorValue[i]; + } + } + } + + @Override + public float[] copyValue(float[] vectorValue) { + throw new UnsupportedOperationException(); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += flatFieldVectorsWriter.ramBytesUsed(); + size += magnitudes.ramBytesUsed(); + return size; + } + } + + // When accessing vectorValue method, targerOrd here means a row ordinal. + static class OffHeapBinarizedQueryVectorValues { + private final IndexInput slice; + private final int dimension; + private final int size; + protected final byte[] binaryValue; + protected final ByteBuffer byteBuffer; + private final int byteSize; + protected final float[] correctiveValues; + private int lastOrd = -1; + private int quantizedComponentSum; + + OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) { + this.slice = data; + this.dimension = dimension; + this.size = size; + // 4x the quantized binary dimensions + int binaryDimensions = (BQVectorUtils.discretize(dimension, 64) / 8) * BQSpaceUtils.B_QUERY; + this.byteBuffer = ByteBuffer.allocate(binaryDimensions); + this.binaryValue = byteBuffer.array(); + // + 1 for the quantized sum + this.correctiveValues = new float[3]; + this.byteSize = binaryDimensions + Float.BYTES * 3 + Short.BYTES; + } + + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + vectorValue(targetOrd); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + + public int size() { + return size; + } + + public int dimension() { + return dimension; + } + + public OffHeapBinarizedQueryVectorValues copy() throws IOException { + return new OffHeapBinarizedQueryVectorValues(slice.clone(), dimension, size); + } + + public IndexInput getSlice() { + return slice; + } + + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(binaryValue, 0, binaryValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + lastOrd = targetOrd; + return binaryValue; + } + } + + static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final byte[] binarized; + private final byte[] initQuantized; + private final float[] centroid; + private final FloatVectorValues values; + private final OptimizedScalarQuantizer quantizer; + + private int lastOrd = -1; + + BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) { + this.values = delegate; + this.quantizer = quantizer; + this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8]; + this.initQuantized = new byte[delegate.dimension()]; + this.centroid = centroid; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd + ); + } + return corrections; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + binarize(ord); + lastOrd = ord; + } + return binarized; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public int size() { + return values.size(); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public BinarizedByteVectorValues copy() throws IOException { + return new BinarizedFloatVectorValues(values.copy(), quantizer, centroid); + } + + private void binarize(int ord) throws IOException { + corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, (byte) 1, centroid); + BQVectorUtils.packAsBinary(initQuantized, binarized); + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + } + + static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + private final RandomVectorScorerSupplier supplier; + private final KnnVectorValues vectorValues; + private final Closeable onClose; + + BinarizedCloseableRandomVectorScorerSupplier(RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) { + this.supplier = supplier; + this.onClose = onClose; + this.vectorValues = vectorValues; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + return supplier.scorer(ord); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return vectorValues.size(); + } + } + + static final class NormalizedFloatVectorValues extends FloatVectorValues { + private final FloatVectorValues values; + private final float[] normalizedVector; + + NormalizedFloatVectorValues(FloatVectorValues values) { + this.values = values; + this.normalizedVector = new float[values.dimension()]; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java new file mode 100644 index 0000000000000..56942017c3cef --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java @@ -0,0 +1,145 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.util.hnsw.HnswGraph; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +public class ES818HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat { + + public static final String NAME = "ES818HnswBinaryQuantizedVectorsFormat"; + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. + */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} + * for details. + */ + private final int beamWidth; + + /** The format for storing, reading, merging vectors on disk */ + private static final FlatVectorsFormat flatVectorsFormat = new ES818BinaryQuantizedVectorsFormat(); + + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + /** Constructs a format using default graph construction parameters */ + public ES818HnswBinaryQuantizedVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters and scalar quantization. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + + @Override + public String toString() { + return "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java new file mode 100644 index 0000000000000..72333169b39b5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java @@ -0,0 +1,371 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; +import java.nio.ByteBuffer; + +/** Binarized vector values loaded from off-heap */ +abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues { + + final int dimension; + final int size; + final int numBytes; + final VectorSimilarityFunction similarityFunction; + final FlatVectorsScorer vectorsScorer; + + final IndexInput slice; + final byte[] binaryValue; + final ByteBuffer byteBuffer; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final OptimizedScalarQuantizer binaryQuantizer; + final float[] centroid; + final float centroidDp; + private final int discretizedDimensions; + + OffHeapBinarizedVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) { + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.numBytes = BQVectorUtils.discretize(dimension, 64) / 8; + this.correctiveValues = new float[3]; + this.byteSize = numBytes + (Float.BYTES * 3) + Short.BYTES; + this.byteBuffer = ByteBuffer.allocate(numBytes); + this.binaryValue = byteBuffer.array(); + this.binaryQuantizer = quantizer; + this.discretizedDimensions = BQVectorUtils.discretize(dimension, 64); + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + lastOrd = targetOrd; + return binaryValue; + } + + @Override + public int discretizedDimensions() { + return discretizedDimensions; + } + + @Override + public float getCentroidDP() { + return centroidDp; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + slice.seek(((long) targetOrd * byteSize) + numBytes); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + return binaryQuantizer; + } + + @Override + public float[] getCentroid() { + return centroid; + } + + @Override + public int getVectorByteLength() { + return numBytes; + } + + static OffHeapBinarizedVectorValues load( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + OptimizedScalarQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + float[] centroid, + float centroidDp, + long quantizedVectorDataOffset, + long quantizedVectorDataLength, + IndexInput vectorData + ) throws IOException { + if (configuration.isEmpty()) { + return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); + } + assert centroid != null; + IndexInput bytesSlice = vectorData.slice("quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); + if (configuration.isDense()) { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + bytesSlice + ); + } else { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice + ); + } + } + + /** Dense off-heap binarized vector values */ + static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + DenseOffHeapVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) { + super(dimension, size, centroid, centroidDp, binaryQuantizer, similarityFunction, vectorsScorer, slice); + } + + @Override + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice.clone() + ); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + } + + /** Sparse off-heap binarized vector values */ + private static class SparseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn was used to init a new IndexedDIS for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer binaryQuantizer, + IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) throws IOException { + super(dimension, size, centroid, centroidDp, binaryQuantizer, similarityFunction, vectorsScorer, slice); + this.configuration = configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + dataIn, + similarityFunction, + vectorsScorer, + slice.clone() + ); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapBinarizedVectorValues { + EmptyOffHeapVectorValues(int dimension, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer) { + super(dimension, 0, null, Float.NaN, null, similarityFunction, vectorsScorer, null); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public DenseOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] target) { + return null; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java new file mode 100644 index 0000000000000..d5ed38cb5a0e1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java @@ -0,0 +1,246 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; + +class OptimizedScalarQuantizer { + // The initial interval is set to the minimum MSE grid for each number of bits + // these starting points are derived from the optimal MSE grid for a uniform distribution + static final float[][] MINIMUM_MSE_GRID = new float[][] { + { -0.798f, 0.798f }, + { -1.493f, 1.493f }, + { -2.051f, 2.051f }, + { -2.514f, 2.514f }, + { -2.916f, 2.916f }, + { -3.278f, 3.278f }, + { -3.611f, 3.611f }, + { -3.922f, 3.922f } }; + private static final float DEFAULT_LAMBDA = 0.1f; + private static final int DEFAULT_ITERS = 5; + private final VectorSimilarityFunction similarityFunction; + private final float lambda; + private final int iters; + + OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) { + this.similarityFunction = similarityFunction; + this.lambda = lambda; + this.iters = iters; + } + + OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) { + this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS); + } + + public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {} + + public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destinations, byte[] bits, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert bits.length == destinations.length; + float[] intervalScratch = new float[2]; + double vecMean = 0; + double vecVar = 0; + float norm2 = 0; + float centroidDot = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + for (int i = 0; i < vector.length; ++i) { + if (similarityFunction != EUCLIDEAN) { + centroidDot += vector[i] * centroid[i]; + } + vector[i] = vector[i] - centroid[i]; + min = Math.min(min, vector[i]); + max = Math.max(max, vector[i]); + norm2 += (vector[i] * vector[i]); + double delta = vector[i] - vecMean; + vecMean += delta / (i + 1); + vecVar += delta * (vector[i] - vecMean); + } + vecVar /= vector.length; + double vecStd = Math.sqrt(vecVar); + QuantizationResult[] results = new QuantizationResult[bits.length]; + for (int i = 0; i < bits.length; ++i) { + assert bits[i] > 0 && bits[i] <= 8; + int points = (1 << bits[i]); + // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds + intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][0] + vecMean) * vecStd, min, max); + intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][1] + vecMean) * vecStd, min, max); + optimizeIntervals(intervalScratch, vector, norm2, points); + float nSteps = ((1 << bits[i]) - 1); + float a = intervalScratch[0]; + float b = intervalScratch[1]; + float step = (b - a) / nSteps; + int sumQuery = 0; + // Now we have the optimized intervals, quantize the vector + for (int h = 0; h < vector.length; h++) { + float xi = (float) clamp(vector[h], a, b); + int assignment = Math.round((xi - a) / step); + sumQuery += assignment; + destinations[i][h] = (byte) assignment; + } + results[i] = new QuantizationResult( + intervalScratch[0], + intervalScratch[1], + similarityFunction == EUCLIDEAN ? norm2 : centroidDot, + sumQuery + ); + } + return results; + } + + public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert vector.length <= destination.length; + assert bits > 0 && bits <= 8; + float[] intervalScratch = new float[2]; + int points = 1 << bits; + double vecMean = 0; + double vecVar = 0; + float norm2 = 0; + float centroidDot = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + for (int i = 0; i < vector.length; ++i) { + if (similarityFunction != EUCLIDEAN) { + centroidDot += vector[i] * centroid[i]; + } + vector[i] = vector[i] - centroid[i]; + min = Math.min(min, vector[i]); + max = Math.max(max, vector[i]); + norm2 += (vector[i] * vector[i]); + double delta = vector[i] - vecMean; + vecMean += delta / (i + 1); + vecVar += delta * (vector[i] - vecMean); + } + vecVar /= vector.length; + double vecStd = Math.sqrt(vecVar); + // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds + intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][0] + vecMean) * vecStd, min, max); + intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][1] + vecMean) * vecStd, min, max); + optimizeIntervals(intervalScratch, vector, norm2, points); + float nSteps = ((1 << bits) - 1); + // Now we have the optimized intervals, quantize the vector + float a = intervalScratch[0]; + float b = intervalScratch[1]; + float step = (b - a) / nSteps; + int sumQuery = 0; + for (int h = 0; h < vector.length; h++) { + float xi = (float) clamp(vector[h], a, b); + int assignment = Math.round((xi - a) / step); + sumQuery += assignment; + destination[h] = (byte) assignment; + } + return new QuantizationResult( + intervalScratch[0], + intervalScratch[1], + similarityFunction == EUCLIDEAN ? norm2 : centroidDot, + sumQuery + ); + } + + /** + * Compute the loss of the vector given the interval. Effectively, we are computing the MSE of a dequantized vector with the raw + * vector. + * @param vector raw vector + * @param interval interval to quantize the vector + * @param points number of quantization points + * @param norm2 squared norm of the vector + * @return the loss + */ + private double loss(float[] vector, float[] interval, int points, float norm2) { + double a = interval[0]; + double b = interval[1]; + double step = ((b - a) / (points - 1.0F)); + double stepInv = 1.0 / step; + double xe = 0.0; + double e = 0.0; + for (double xi : vector) { + // this is quantizing and then dequantizing the vector + double xiq = (a + step * Math.round((clamp(xi, a, b) - a) * stepInv)); + // how much does the de-quantized value differ from the original value + xe += xi * (xi - xiq); + e += (xi - xiq) * (xi - xiq); + } + return (1.0 - lambda) * xe * xe / norm2 + lambda * e; + } + + /** + * Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization + * loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the + * loss increases. + * @param initInterval initial interval, the optimized interval will be stored here + * @param vector raw vector + * @param norm2 squared norm of the vector + * @param points number of quantization points + */ + private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) { + double initialLoss = loss(vector, initInterval, points, norm2); + final float scale = (1.0f - lambda) / norm2; + if (Float.isFinite(scale) == false) { + return; + } + for (int i = 0; i < iters; ++i) { + float a = initInterval[0]; + float b = initInterval[1]; + float stepInv = (points - 1.0f) / (b - a); + // calculate the grid points for coordinate descent + double daa = 0; + double dab = 0; + double dbb = 0; + double dax = 0; + double dbx = 0; + for (float xi : vector) { + float k = Math.round((clamp(xi, a, b) - a) * stepInv); + float s = k / (points - 1); + daa += (1.0 - s) * (1.0 - s); + dab += (1.0 - s) * s; + dbb += s * s; + dax += xi * (1.0 - s); + dbx += xi * s; + } + double m0 = scale * dax * dax + lambda * daa; + double m1 = scale * dax * dbx + lambda * dab; + double m2 = scale * dbx * dbx + lambda * dbb; + // its possible that the determinant is 0, in which case we can't update the interval + double det = m0 * m2 - m1 * m1; + if (det == 0) { + return; + } + float aOpt = (float) ((m2 * dax - m1 * dbx) / det); + float bOpt = (float) ((m0 * dbx - m1 * dax) / det); + // If there is no change in the interval, we can stop + if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) { + return; + } + double newLoss = loss(vector, new float[] { aOpt, bOpt }, points, norm2); + // If the new loss is worse, don't update the interval and exit + // This optimization, unlike kMeans, does not always converge to better loss + // So exit if we are getting worse + if (newLoss > initialLoss) { + return; + } + // Update the interval and go again + initInterval[0] = aOpt; + initInterval[1] = bOpt; + initialLoss = newLoss; + } + } + + private static double clamp(double x, double a, double b) { + return Math.min(Math.max(x, a), b); + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 0a6a24f727572..d780faad96f2d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -46,8 +46,8 @@ import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat; -import org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.ArraySourceValueFetcher; @@ -1788,7 +1788,7 @@ static class BBQHnswIndexOptions extends IndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT; - return new ES816HnswBinaryQuantizedVectorsFormat(m, efConstruction); + return new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction); } @Override @@ -1836,7 +1836,7 @@ static class BBQFlatIndexOptions extends IndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT; - return new ES816BinaryQuantizedVectorsFormat(); + return new ES818BinaryQuantizedVectorsFormat(); } @Override diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 794b30aa5aab2..57980321bdc3d 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -44,6 +44,7 @@ private SearchCapabilities() {} private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim_with_bugfix"; private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs"; + private static final String OPTIMIZED_SCALAR_QUANTIZATION_BBQ = "optimized_scalar_quantization_bbq"; public static final Set CAPABILITIES; static { @@ -55,6 +56,7 @@ private SearchCapabilities() {} capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER); capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT); capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS); + capabilities.add(OPTIMIZED_SCALAR_QUANTIZATION_BBQ); if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) { capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER); capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS); diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 389555e60b43b..cef8d09980814 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -5,3 +5,5 @@ org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java index 9f9114c70b6db..270ad54e9a962 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java @@ -38,6 +38,32 @@ public static int popcount(byte[] a, int aOffset, byte[] b, int length) { private static float DELTA = Float.MIN_VALUE; + public void testPackAsBinary() { + // 5 bits + byte[] toPack = new byte[] { 1, 1, 0, 0, 1 }; + byte[] packed = new byte[1]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001000 }, packed); + + // 8 bits + toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0 }; + packed = new byte[1]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001010 }, packed); + + // 10 bits + toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1 }; + packed = new byte[2]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11000000 }, packed); + + // 16 bits + toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 }; + packed = new byte[2]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11100110 }, packed); + } + public void testPadFloat() { assertArrayEquals(new float[] { 1, 2, 3, 4 }, BQVectorUtils.pad(new float[] { 1, 2, 3, 4 }, 4), DELTA); assertArrayEquals(new float[] { 1, 2, 3, 4 }, BQVectorUtils.pad(new float[] { 1, 2, 3, 4 }, 3), DELTA); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java new file mode 100644 index 0000000000000..0bebe16f468ce --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java @@ -0,0 +1,256 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es816; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.simdvec.ESVectorUtil; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Vector scorer over binarized vector values */ +class ES816BinaryFlatRWVectorsScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + + ES816BinaryFlatRWVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues) { + throw new UnsupportedOperationException( + "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format" + ); + } + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + float[] target + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) { + BinaryQuantizer quantizer = binarizedVectors.getQuantizer(); + float[] centroid = binarizedVectors.getCentroid(); + // FIXME: precompute this once? + int discretizedDimensions = BQVectorUtils.discretize(target.length, 64); + if (similarityFunction == COSINE) { + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + VectorUtil.l2normalize(copy); + target = copy; + } + byte[] quantized = new byte[BQSpaceUtils.B_QUERY * discretizedDimensions / 8]; + BinaryQuantizer.QueryFactors factors = quantizer.quantizeForQuery(target, quantized, centroid); + BinaryQueryVector queryVector = new BinaryQueryVector(quantized, factors); + return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction); + } + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + byte[] target + ) throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, + BinarizedByteVectorValues targetVectors + ) { + return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction); + } + + @Override + public String toString() { + return "ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + } + + /** Vector scorer supplier over binarized vector values */ + static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + BinarizedRandomVectorScorerSupplier( + ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + this.queryVectors = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] vector = queryVectors.vectorValue(ord); + int quantizedSum = queryVectors.sumQuantizedValues(ord); + float distanceToCentroid = queryVectors.getCentroidDistance(ord); + float lower = queryVectors.getLower(ord); + float width = queryVectors.getWidth(ord); + float normVmC = 0f; + float vDotC = 0f; + if (similarityFunction != EUCLIDEAN) { + normVmC = queryVectors.getNormVmC(ord); + vDotC = queryVectors.getVDotC(ord); + } + BinaryQueryVector binaryQueryVector = new BinaryQueryVector( + vector, + new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC) + ); + return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); + } + } + + /** A binarized query representing its quantized form along with factors */ + record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {} + + /** Vector scorer over binarized vector values */ + static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + private final BinaryQueryVector queryVector; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + private final float sqrtDimensions; + private final float maxX1; + + BinarizedRandomVectorScorer( + BinaryQueryVector queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + super(targetVectors); + this.queryVector = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + // FIXME: precompute this once? + this.sqrtDimensions = targetVectors.sqrtDimensions(); + this.maxX1 = targetVectors.maxX1(); + } + + @Override + public float score(int targetOrd) throws IOException { + byte[] quantizedQuery = queryVector.vector(); + int quantizedSum = queryVector.factors().quantizedSum(); + float lower = queryVector.factors().lower(); + float width = queryVector.factors().width(); + float distanceToCentroid = queryVector.factors().distToC(); + if (similarityFunction == EUCLIDEAN) { + return euclideanScore(targetOrd, sqrtDimensions, quantizedQuery, distanceToCentroid, lower, quantizedSum, width); + } + + float vmC = queryVector.factors().normVmC(); + float vDotC = queryVector.factors().vDotC(); + float cDotC = targetVectors.getCentroidDP(); + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + float ooq = targetVectors.getOOQ(targetOrd); + float normOC = targetVectors.getNormOC(targetOrd); + float oDotC = targetVectors.getODotC(targetOrd); + + float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + + // FIXME: pre-compute these only once for each target vector + // ... pull this out or use a similar cache mechanism as do in score + float xbSum = (float) BQVectorUtils.popcount(binaryCode); + final float dist; + // If ||o-c|| == 0, so, it's ok to throw the rest of the equation away + // and simply use `oDotC + vDotC - cDotC` as centroid == doc vector + if (normOC == 0 || ooq == 0) { + dist = oDotC + vDotC - cDotC; + } else { + // If ||o-c|| != 0, we should assume that `ooq` is finite + assert Float.isFinite(ooq); + float estimatedDot = (2 * width / sqrtDimensions * qcDist + 2 * lower / sqrtDimensions * xbSum - width / sqrtDimensions + * quantizedSum - sqrtDimensions * lower) / ooq; + dist = vmC * normOC * estimatedDot + oDotC + vDotC - cDotC; + } + assert Float.isFinite(dist); + + float ooqSqr = (float) Math.pow(ooq, 2); + float errorBound = (float) (vmC * normOC * (maxX1 * Math.sqrt((1 - ooqSqr) / ooqSqr))); + float score = Float.isFinite(errorBound) ? dist - errorBound : dist; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + + private float euclideanScore( + int targetOrd, + float sqrtDimensions, + byte[] quantizedQuery, + float distanceToCentroid, + float lower, + int quantizedSum, + float width + ) throws IOException { + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + + // FIXME: pre-compute these only once for each target vector + // .. not sure how to enumerate the target ordinals but that's what we did in PoC + float targetDistToC = targetVectors.getCentroidDistance(targetOrd); + float x0 = targetVectors.getVectorMagnitude(targetOrd); + float sqrX = targetDistToC * targetDistToC; + double xX0 = targetDistToC / x0; + + // TODO maybe store? + float xbSum = (float) BQVectorUtils.popcount(binaryCode); + float factorPPC = (float) (-2.0 / sqrtDimensions * xX0 * (xbSum * 2.0 - targetVectors.dimension())); + float factorIP = (float) (-2.0 / sqrtDimensions * xX0); + + long qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + float score = sqrX + distanceToCentroid + factorPPC * lower + (qcDist * 2 - quantizedSum) * factorIP * width; + float projectionDist = (float) Math.sqrt(xX0 * xX0 - targetDistToC * targetDistToC); + float error = 2.0f * maxX1 * projectionDist; + float y = (float) Math.sqrt(distanceToCentroid); + float errorBound = y * error; + if (Float.isFinite(errorBound)) { + score = score + errorBound; + } + return Math.max(1 / (1f + score), 0); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java index a75b9bc6064d1..ffe007be9799d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java @@ -59,7 +59,7 @@ public void testScore() throws IOException { short quantizedSum = (short) random().nextInt(0, 4097); float normVmC = random().nextFloat(-1000f, 1000f); float vDotC = random().nextFloat(-1000f, 1000f); - ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector( + ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector( vector, new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, normVmC, vDotC) ); @@ -134,7 +134,7 @@ public int dimension() { } }; - ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer( queryVector, targetVectors, similarityFunction @@ -217,7 +217,7 @@ public void testScoreEuclidean() throws IOException { float vl = -57.883f; float width = 9.972266f; short quantizedSum = 795; - ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector( + ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector( vector, new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, 0f, 0f) ); @@ -420,7 +420,7 @@ public int dimension() { VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN; - ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer( queryVector, targetVectors, similarityFunction @@ -824,7 +824,7 @@ public void testScoreMIP() throws IOException { float normVmC = 9.766797f; float vDotC = 133.56123f; float cDotC = 132.20227f; - ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector( + ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector( vector, new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, normVmC, vDotC) ); @@ -1768,7 +1768,7 @@ public int dimension() { VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; - ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer( queryVector, targetVectors, similarityFunction diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java new file mode 100644 index 0000000000000..c54903a94b54f --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java @@ -0,0 +1,52 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es816; + +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +public class ES816BinaryQuantizedRWVectorsFormat extends ES816BinaryQuantizedVectorsFormat { + + private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final ES816BinaryFlatRWVectorsScorer scorer = new ES816BinaryFlatRWVectorsScorer( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + /** Creates a new instance with the default number of vectors per cluster. */ + public ES816BinaryQuantizedRWVectorsFormat() { + super(); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES816BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java index 681f615653d40..48ba566353f5d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java @@ -63,7 +63,7 @@ protected Codec getCodec() { return new Lucene100Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new ES816BinaryQuantizedVectorsFormat(); + return new ES816BinaryQuantizedRWVectorsFormat(); } }; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java similarity index 99% rename from server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java rename to server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java index 31ae977e81118..4d97235c5fae5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java @@ -77,7 +77,7 @@ class ES816BinaryQuantizedVectorsWriter extends FlatVectorsWriter { private final List fields = new ArrayList<>(); private final IndexOutput meta, binarizedVectorData; private final FlatVectorsWriter rawVectorDelegate; - private final ES816BinaryFlatVectorsScorer vectorsScorer; + private final ES816BinaryFlatRWVectorsScorer vectorsScorer; private boolean finished; /** @@ -86,7 +86,7 @@ class ES816BinaryQuantizedVectorsWriter extends FlatVectorsWriter { * @param vectorsScorer the scorer to use for scoring vectors */ protected ES816BinaryQuantizedVectorsWriter( - ES816BinaryFlatVectorsScorer vectorsScorer, + ES816BinaryFlatRWVectorsScorer vectorsScorer, FlatVectorsWriter rawVectorDelegate, SegmentWriteState state ) throws IOException { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java new file mode 100644 index 0000000000000..e9bace72b591c --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java @@ -0,0 +1,55 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es816; + +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; + +class ES816HnswBinaryQuantizedRWVectorsFormat extends ES816HnswBinaryQuantizedVectorsFormat { + + private static final FlatVectorsFormat flatVectorsFormat = new ES816BinaryQuantizedRWVectorsFormat(); + + /** Constructs a format using default graph construction parameters */ + ES816HnswBinaryQuantizedRWVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH); + } + + ES816HnswBinaryQuantizedRWVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + ES816HnswBinaryQuantizedRWVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(maxConn, beamWidth, numMergeWorkers, mergeExec); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java index a25fa2836ee34..03aa847f3a5d4 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java @@ -59,7 +59,7 @@ protected Codec getCodec() { return new Lucene100Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new ES816HnswBinaryQuantizedVectorsFormat(); + return new ES816HnswBinaryQuantizedRWVectorsFormat(); } }; } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java new file mode 100644 index 0000000000000..397cc472592b6 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java @@ -0,0 +1,181 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene100.Lucene100Codec; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; +import java.util.Locale; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Override + protected Codec getCodec() { + return new Lucene100Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new ES818BinaryQuantizedVectorsFormat(); + } + }; + } + + public void testSearch() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES818BinaryQuantizedVectorsFormat(); + } + }; + String expectedPattern = "ES818BinaryQuantizedVectorsFormat(" + + "name=ES818BinaryQuantizedVectorsFormat, " + + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is not respected, as it is brute force search + } + + public void testQuantizedVectorsWriteAndRead() throws IOException { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + if (i % 101 == 0) { + w.commit(); + } + } + w.commit(); + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); + assertEquals(vectorValues.size(), numVectors); + BinarizedByteVectorValues qvectorValues = ((ES818BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues) + .getQuantizedVectorValues(); + float[] centroid = qvectorValues.getCentroid(); + assertEquals(centroid.length, dims); + + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); + byte[] quantizedVector = new byte[dims]; + byte[] expectedVector = new byte[BQVectorUtils.discretize(dims, 64) / 8]; + if (similarityFunction == VectorSimilarityFunction.COSINE) { + vectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); + } + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + OptimizedScalarQuantizer.QuantizationResult corrections = quantizer.scalarQuantize( + vectorValues.vectorValue(docIndexIterator.index()), + quantizedVector, + (byte) 1, + centroid + ); + BQVectorUtils.packAsBinary(quantizedVector, expectedVector); + assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); + assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index())); + } + } + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java new file mode 100644 index 0000000000000..b6ae3199bb896 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java @@ -0,0 +1,132 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene100.Lucene100Codec; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.util.SameThreadExecutorService; +import org.elasticsearch.common.logging.LogConfigurator; + +import java.util.Arrays; +import java.util.Locale; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES818HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Override + protected Codec getCodec() { + return new Lucene100Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new ES818HnswBinaryQuantizedVectorsFormat(); + } + }; + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES818HnswBinaryQuantizedVectorsFormat(10, 20, 1, null); + } + }; + String expectedPattern = + "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=ES818BinaryQuantizedVectorsFormat(name=ES818BinaryQuantizedVectorsFormat," + + " flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))"; + + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + assert (vectorValues.size() == 1); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); + } + float[] randomVector = randomVector(vector.length); + float trueScore = similarityFunction.compare(vector, randomVector); + TopDocs td = r.searchNearestVectors("f", randomVector, 1, null, Integer.MAX_VALUE); + assertEquals(1, td.totalHits.value()); + assertTrue(td.scoreDocs[0].score >= 0); + // When it's the only vector in a segment, the score should be very close to the true score + assertEquals(trueScore, td.scoreDocs[0].score, 0.0001f); + } + } + } + } + + public void testLimits() { + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(-1, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(0, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 0)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, -1)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(512 + 1, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 3201)); + expectThrows( + IllegalArgumentException.class, + () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 100, 1, new SameThreadExecutorService()) + ); + } + + // Ensures that all expected vector similarity functions are translatable in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. + var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java new file mode 100644 index 0000000000000..e3e2d6caafe0e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.test.ESTestCase; + +import static org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer.MINIMUM_MSE_GRID; + +public class OptimizedScalarQuantizerTests extends ESTestCase { + + static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + + public void testAbusiveEdgeCases() { + // large zero array + for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) { + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + continue; + } + float[] vector = new float[4096]; + float[] centroid = new float[4096]; + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096]; + OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (byte[] destination : destinations) { + assertArrayEquals(new byte[4096], destination); + } + byte[] destination = new byte[4096]; + for (byte bit : ALL_BITS) { + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid); + assertValidResults(result); + assertArrayEquals(new byte[4096], destination); + } + } + + // single value array + for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) { + float[] vector = new float[] { randomFloat() }; + float[] centroid = new float[] { randomFloat() }; + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1]; + OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (int i = 0; i < ALL_BITS.length; i++) { + assertValidQuantizedRange(destinations[i], ALL_BITS[i]); + } + for (byte bit : ALL_BITS) { + vector = new float[] { randomFloat() }; + centroid = new float[] { randomFloat() }; + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + VectorUtil.l2normalize(centroid); + } + byte[] destination = new byte[1]; + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid); + assertValidResults(result); + assertValidQuantizedRange(destination, bit); + } + } + + } + + public void testMathematicalConsistency() { + int dims = randomIntBetween(1, 4096); + float[] vector = new float[dims]; + for (int i = 0; i < dims; ++i) { + vector[i] = randomFloat(); + } + float[] centroid = new float[dims]; + for (int i = 0; i < dims; ++i) { + centroid[i] = randomFloat(); + } + float[] copy = new float[dims]; + for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) { + // copy the vector to avoid modifying it + System.arraycopy(vector, 0, copy, 0, dims); + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(copy); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims]; + OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (int i = 0; i < ALL_BITS.length; i++) { + assertValidQuantizedRange(destinations[i], ALL_BITS[i]); + } + for (byte bit : ALL_BITS) { + byte[] destination = new byte[dims]; + System.arraycopy(vector, 0, copy, 0, dims); + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(copy); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(copy, destination, bit, centroid); + assertValidResults(result); + assertValidQuantizedRange(destination, bit); + } + } + } + + static void assertValidQuantizedRange(byte[] quantized, byte bits) { + for (byte b : quantized) { + if (bits < 8) { + assertTrue(b >= 0); + } + assertTrue(b < 1 << bits); + } + } + + static void assertValidResults(OptimizedScalarQuantizer.QuantizationResult... results) { + for (OptimizedScalarQuantizer.QuantizationResult result : results) { + assertTrue(Float.isFinite(result.lowerInterval())); + assertTrue(Float.isFinite(result.upperInterval())); + assertTrue(result.lowerInterval() <= result.upperInterval()); + assertTrue(Float.isFinite(result.additionalCorrection())); + assertTrue(result.quantizedComponentSum() >= 0); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index de084cd4582e2..c043b9ffb381a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1970,13 +1970,13 @@ public void testKnnBBQHNSWVectorsFormat() throws IOException { assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class)); knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } - String expectedString = "ES816HnswBinaryQuantizedVectorsFormat(name=ES816HnswBinaryQuantizedVectorsFormat, maxConn=" + String expectedString = "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=" + m + ", beamWidth=" + efConstruction - + ", flatVectorFormat=ES816BinaryQuantizedVectorsFormat(" - + "name=ES816BinaryQuantizedVectorsFormat, " - + "flatVectorScorer=ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())))"; + + ", flatVectorFormat=ES818BinaryQuantizedVectorsFormat(" + + "name=ES818BinaryQuantizedVectorsFormat, " + + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())))"; assertEquals(expectedString, knnVectorsFormat.toString()); }