diff --git a/docs/changelog/117994.yaml b/docs/changelog/117994.yaml
new file mode 100644
index 0000000000000..603f2d855a11a
--- /dev/null
+++ b/docs/changelog/117994.yaml
@@ -0,0 +1,5 @@
+pr: 117994
+summary: Even better(er) binary quantization
+area: Vector Search
+type: enhancement
+issues: []
diff --git a/rest-api-spec/build.gradle b/rest-api-spec/build.gradle
index e2af894eb0939..7347d9c1312dd 100644
--- a/rest-api-spec/build.gradle
+++ b/rest-api-spec/build.gradle
@@ -67,4 +67,6 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task ->
task.skipTest("logsdb/20_source_mapping/include/exclude is supported with stored _source", "no longer serialize source_mode")
task.skipTest("logsdb/20_source_mapping/synthetic _source is default", "no longer serialize source_mode")
task.skipTest("search/520_fetch_fields/fetch _seq_no via fields", "error code is changed from 5xx to 400 in 9.0")
+ task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions")
+ task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions")
})
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml
index 188c155e4a836..5767c895fbe7e 100644
--- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml
+++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml
@@ -11,20 +11,11 @@ setup:
number_of_shards: 1
mappings:
properties:
- name:
- type: keyword
vector:
type: dense_vector
dims: 64
index: true
- similarity: l2_norm
- index_options:
- type: bbq_hnsw
- another_vector:
- type: dense_vector
- dims: 64
- index: true
- similarity: l2_norm
+ similarity: max_inner_product
index_options:
type: bbq_hnsw
@@ -33,9 +24,14 @@ setup:
index: bbq_hnsw
id: "1"
body:
- name: cow.jpg
- vector: [300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0]
- another_vector: [115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0]
+ vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313,
+ 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272,
+ 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132,
+ -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265,
+ -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475,
+ -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242,
+ -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45,
+ -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176]
# Flush in order to provoke a merge later
- do:
indices.flush:
@@ -46,9 +42,14 @@ setup:
index: bbq_hnsw
id: "2"
body:
- name: moose.jpg
- vector: [100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0]
- another_vector: [50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120]
+ vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348,
+ -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048,
+ 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438,
+ -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138,
+ -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429,
+ -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166,
+ 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569,
+ -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013]
# Flush in order to provoke a merge later
- do:
indices.flush:
@@ -60,8 +61,14 @@ setup:
id: "3"
body:
name: rabbit.jpg
- vector: [111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0]
- another_vector: [11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0]
+ vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 ,
+ 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049,
+ 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246,
+ -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058,
+ -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399,
+ -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017,
+ 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615,
+ -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058]
# Flush in order to provoke a merge later
- do:
indices.flush:
@@ -73,20 +80,33 @@ setup:
max_num_segments: 1
---
"Test knn search":
+ - requires:
+ capabilities:
+ - method: POST
+ path: /_search
+ capabilities: [ optimized_scalar_quantization_bbq ]
+ test_runner_features: capabilities
+ reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq"
- do:
search:
index: bbq_hnsw
body:
knn:
field: vector
- query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0]
+ query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393,
+ 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015,
+ 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259,
+ -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 ,
+ -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232,
+ -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034,
+ -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582,
+ -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158]
k: 3
num_candidates: 3
- # Depending on how things are distributed, docs 2 and 3 might be swapped
- # here we verify that are last hit is always the worst one
- - match: { hits.hits.2._id: "1" }
-
+ - match: { hits.hits.0._id: "1" }
+ - match: { hits.hits.1._id: "3" }
+ - match: { hits.hits.2._id: "2" }
---
"Test bad quantization parameters":
- do:
diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml
index ed7a8dd5df65d..dcdae04aeabb4 100644
--- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml
+++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml
@@ -11,20 +11,11 @@ setup:
number_of_shards: 1
mappings:
properties:
- name:
- type: keyword
vector:
type: dense_vector
dims: 64
index: true
- similarity: l2_norm
- index_options:
- type: bbq_flat
- another_vector:
- type: dense_vector
- dims: 64
- index: true
- similarity: l2_norm
+ similarity: max_inner_product
index_options:
type: bbq_flat
@@ -33,9 +24,14 @@ setup:
index: bbq_flat
id: "1"
body:
- name: cow.jpg
- vector: [300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0]
- another_vector: [115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0]
+ vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313,
+ 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272,
+ 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132,
+ -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265,
+ -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475,
+ -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242,
+ -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45,
+ -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176]
# Flush in order to provoke a merge later
- do:
indices.flush:
@@ -46,9 +42,14 @@ setup:
index: bbq_flat
id: "2"
body:
- name: moose.jpg
- vector: [100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0]
- another_vector: [50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120]
+ vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348,
+ -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048,
+ 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438,
+ -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138,
+ -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429,
+ -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166,
+ 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569,
+ -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013]
# Flush in order to provoke a merge later
- do:
indices.flush:
@@ -59,9 +60,14 @@ setup:
index: bbq_flat
id: "3"
body:
- name: rabbit.jpg
- vector: [111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0]
- another_vector: [11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0]
+ vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 ,
+ 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049,
+ 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246,
+ -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058,
+ -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399,
+ -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017,
+ 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615,
+ -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058]
# Flush in order to provoke a merge later
- do:
indices.flush:
@@ -73,19 +79,33 @@ setup:
max_num_segments: 1
---
"Test knn search":
+ - requires:
+ capabilities:
+ - method: POST
+ path: /_search
+ capabilities: [ optimized_scalar_quantization_bbq ]
+ test_runner_features: capabilities
+ reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq"
- do:
search:
index: bbq_flat
body:
knn:
field: vector
- query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0]
+ query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393,
+ 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015,
+ 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259,
+ -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 ,
+ -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232,
+ -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034,
+ -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582,
+ -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158]
k: 3
num_candidates: 3
- # Depending on how things are distributed, docs 2 and 3 might be swapped
- # here we verify that are last hit is always the worst one
- - match: { hits.hits.2._id: "1" }
+ - match: { hits.hits.0._id: "1" }
+ - match: { hits.hits.1._id: "3" }
+ - match: { hits.hits.2._id: "2" }
---
"Test bad parameters":
- do:
diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java
index 331a2bc0dddac..ff902dbede007 100644
--- a/server/src/main/java/module-info.java
+++ b/server/src/main/java/module-info.java
@@ -459,7 +459,9 @@
org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat,
org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat,
org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat,
- org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat;
+ org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat,
+ org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat,
+ org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
provides org.apache.lucene.codecs.Codec
with
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java
index 5201e57179cc7..1aff06a175967 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java
@@ -40,6 +40,20 @@ public static boolean isUnitVector(float[] v) {
return Math.abs(l1norm - 1.0d) <= EPSILON;
}
+ public static void packAsBinary(byte[] vector, byte[] packed) {
+ for (int i = 0; i < vector.length;) {
+ byte result = 0;
+ for (int j = 7; j >= 0 && i < vector.length; j--) {
+ assert vector[i] == 0 || vector[i] == 1;
+ result |= (byte) ((vector[i] & 1) << j);
+ ++i;
+ }
+ int index = ((i + 7) / 8) - 1;
+ assert index < packed.length;
+ packed[index] = result;
+ }
+ }
+
public static int discretize(int value, int bucket) {
return ((value + (bucket - 1)) / bucket) * bucket;
}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java
index 445bdadab2354..e85079e998c61 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java
@@ -48,13 +48,8 @@ class ES816BinaryFlatVectorsScorer implements FlatVectorsScorer {
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction,
KnnVectorValues vectorValues
- ) throws IOException {
- if (vectorValues instanceof BinarizedByteVectorValues) {
- throw new UnsupportedOperationException(
- "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format"
- );
- }
- return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
+ ) {
+ throw new UnsupportedOperationException();
}
@Override
@@ -90,61 +85,11 @@ public RandomVectorScorer getRandomVectorScorer(
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
}
- RandomVectorScorerSupplier getRandomVectorScorerSupplier(
- VectorSimilarityFunction similarityFunction,
- ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors,
- BinarizedByteVectorValues targetVectors
- ) {
- return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction);
- }
-
@Override
public String toString() {
return "ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")";
}
- /** Vector scorer supplier over binarized vector values */
- static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
- private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors;
- private final BinarizedByteVectorValues targetVectors;
- private final VectorSimilarityFunction similarityFunction;
-
- BinarizedRandomVectorScorerSupplier(
- ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors,
- BinarizedByteVectorValues targetVectors,
- VectorSimilarityFunction similarityFunction
- ) {
- this.queryVectors = queryVectors;
- this.targetVectors = targetVectors;
- this.similarityFunction = similarityFunction;
- }
-
- @Override
- public RandomVectorScorer scorer(int ord) throws IOException {
- byte[] vector = queryVectors.vectorValue(ord);
- int quantizedSum = queryVectors.sumQuantizedValues(ord);
- float distanceToCentroid = queryVectors.getCentroidDistance(ord);
- float lower = queryVectors.getLower(ord);
- float width = queryVectors.getWidth(ord);
- float normVmC = 0f;
- float vDotC = 0f;
- if (similarityFunction != EUCLIDEAN) {
- normVmC = queryVectors.getNormVmC(ord);
- vDotC = queryVectors.getVDotC(ord);
- }
- BinaryQueryVector binaryQueryVector = new BinaryQueryVector(
- vector,
- new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC)
- );
- return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction);
- }
-
- @Override
- public RandomVectorScorerSupplier copy() throws IOException {
- return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction);
- }
- }
-
/** A binarized query representing its quantized form along with factors */
record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java
index d864ec5dee8c5..61b6edc474d1f 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java
@@ -62,7 +62,7 @@ public ES816BinaryQuantizedVectorsFormat() {
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
- return new ES816BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state);
+ throw new UnsupportedOperationException();
}
@Override
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java
index 52f9f14b7bf97..1dbb4e432b188 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java
@@ -25,10 +25,8 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
-import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
-import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.hnsw.HnswGraph;
import java.io.IOException;
@@ -52,21 +50,18 @@ public class ES816HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat {
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
* {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
*/
- private final int maxConn;
+ protected final int maxConn;
/**
* The number of candidate neighbors to track while searching the graph for each newly inserted
* node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
* for details.
*/
- private final int beamWidth;
+ protected final int beamWidth;
/** The format for storing, reading, merging vectors on disk */
private static final FlatVectorsFormat flatVectorsFormat = new ES816BinaryQuantizedVectorsFormat();
- private final int numMergeWorkers;
- private final TaskExecutor mergeExec;
-
/** Constructs a format using default graph construction parameters */
public ES816HnswBinaryQuantizedVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
@@ -109,17 +104,11 @@ public ES816HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int num
if (numMergeWorkers == 1 && mergeExec != null) {
throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
}
- this.numMergeWorkers = numMergeWorkers;
- if (mergeExec != null) {
- this.mergeExec = new TaskExecutor(mergeExec);
- } else {
- this.mergeExec = null;
- }
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
- return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec);
+ throw new UnsupportedOperationException();
}
@Override
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java
new file mode 100644
index 0000000000000..cc1f7b85e0f78
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java
@@ -0,0 +1,87 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
+
+import java.io.IOException;
+
+/**
+ * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
+ */
+abstract class BinarizedByteVectorValues extends ByteVectorValues {
+
+ /**
+ * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of
+ * distances, the corrective terms are, in order
+ *
+ * <ul>
+ *   <li>the lower optimized interval
+ *   <li>the upper optimized interval
+ *   <li>the dot-product of the non-centered vector with the centroid
+ *   <li>the sum of quantized components
+ * </ul>
+ *
+ * For euclidean:
+ *
+ * <ul>
+ *   <li>the lower optimized interval
+ *   <li>the upper optimized interval
+ *   <li>the l2norm of the centered vector
+ *   <li>the sum of quantized components
+ * </ul>
+ *
+ * @param vectorOrd the vector ordinal
+ * @return the corrective terms
+ * @throws IOException if an I/O error occurs
+ */
+ public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd) throws IOException;
+
+ /**
+ * @return the quantizer used to quantize the vectors
+ */
+ public abstract OptimizedScalarQuantizer getQuantizer();
+
+ public abstract float[] getCentroid() throws IOException;
+
+ int discretizedDimensions() {
+ return BQVectorUtils.discretize(dimension(), 64);
+ }
+
+ /**
+ * Return a {@link VectorScorer} for the given query vector.
+ *
+ * @param query the query vector
+ * @return a {@link VectorScorer} instance or null
+ */
+ public abstract VectorScorer scorer(float[] query) throws IOException;
+
+ @Override
+ public abstract BinarizedByteVectorValues copy() throws IOException;
+
+ float getCentroidDP() throws IOException {
+ // this only gets executed on-merge
+ float[] centroid = getCentroid();
+ return VectorUtil.dotProduct(centroid, centroid);
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java
new file mode 100644
index 0000000000000..7c7e470909eb3
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java
@@ -0,0 +1,188 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
+import org.elasticsearch.simdvec.ESVectorUtil;
+
+import java.io.IOException;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/** Vector scorer over binarized vector values */
+public class ES818BinaryFlatVectorsScorer implements FlatVectorsScorer {
+ private final FlatVectorsScorer nonQuantizedDelegate;
+ private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
+
+ public ES818BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
+ this.nonQuantizedDelegate = nonQuantizedDelegate;
+ }
+
+ @Override
+ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction,
+ KnnVectorValues vectorValues
+ ) throws IOException {
+ if (vectorValues instanceof BinarizedByteVectorValues) {
+ throw new UnsupportedOperationException(
+ "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format"
+ );
+ }
+ return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction,
+ KnnVectorValues vectorValues,
+ float[] target
+ ) throws IOException {
+ if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) {
+ OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer();
+ float[] centroid = binarizedVectors.getCentroid();
+ // We make a copy as the quantization process mutates the input
+ float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
+ if (similarityFunction == COSINE) {
+ VectorUtil.l2normalize(copy);
+ }
+ target = copy;
+ byte[] initial = new byte[target.length];
+ byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8];
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(target, initial, (byte) 4, centroid);
+ BQSpaceUtils.transposeHalfByte(initial, quantized);
+ BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections);
+ return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction);
+ }
+ return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction,
+ KnnVectorValues vectorValues,
+ byte[] target
+ ) throws IOException {
+ return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ // Package-private supplier used by the writer during merge: queries are the
+ // already half-byte-quantized off-heap vectors, targets the binarized index vectors.
+ RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction,
+ ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors,
+ BinarizedByteVectorValues targetVectors
+ ) {
+ return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction);
+ }
+
+ @Override
+ public String toString() {
+ // Include the delegate so segment/codec dumps show the full scorer chain.
+ return "ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")";
+ }
+
+ /** Vector scorer supplier over binarized vector values */
+ static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
+ // Half-byte quantized query-side vectors (with their corrective terms), read off-heap.
+ private final ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors;
+ // Single-bit binarized index-side vectors.
+ private final BinarizedByteVectorValues targetVectors;
+ private final VectorSimilarityFunction similarityFunction;
+
+ BinarizedRandomVectorScorerSupplier(
+ ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors,
+ BinarizedByteVectorValues targetVectors,
+ VectorSimilarityFunction similarityFunction
+ ) {
+ this.queryVectors = queryVectors;
+ this.targetVectors = targetVectors;
+ this.similarityFunction = similarityFunction;
+ }
+
+ @Override
+ public RandomVectorScorer scorer(int ord) throws IOException {
+ // Build a scorer whose query is the stored (already quantized) vector at `ord`.
+ byte[] vector = queryVectors.vectorValue(ord);
+ OptimizedScalarQuantizer.QuantizationResult correctiveTerms = queryVectors.getCorrectiveTerms(ord);
+ BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms);
+ return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction);
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ // Deep-copy both sides so the clone can be used concurrently (independent readers).
+ return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction);
+ }
+ }
+
+ /** A binarized query representing its quantized form along with factors */
+ // `vector` is the transposed half-byte (4-bit) quantized query; `quantizationResult`
+ // carries the interval bounds, additional correction and component sum needed to
+ // de-bias the integer dot product back to a float similarity.
+ public record BinaryQueryVector(byte[] vector, OptimizedScalarQuantizer.QuantizationResult quantizationResult) {}
+
+ /** Vector scorer over binarized vector values */
+ public static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
+ private final BinaryQueryVector queryVector;
+ private final BinarizedByteVectorValues targetVectors;
+ private final VectorSimilarityFunction similarityFunction;
+
+ public BinarizedRandomVectorScorer(
+ BinaryQueryVector queryVectors,
+ BinarizedByteVectorValues targetVectors,
+ VectorSimilarityFunction similarityFunction
+ ) {
+ super(targetVectors);
+ this.queryVector = queryVectors;
+ this.targetVectors = targetVectors;
+ this.similarityFunction = similarityFunction;
+ }
+
+ @Override
+ public float score(int targetOrd) throws IOException {
+ // Integer dot product between the 4-bit transposed query and the 1-bit target.
+ byte[] quantizedQuery = queryVector.vector();
+ byte[] binaryCode = targetVectors.vectorValue(targetOrd);
+ float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode);
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections = queryVector.quantizationResult();
+ OptimizedScalarQuantizer.QuantizationResult indexCorrections = targetVectors.getCorrectiveTerms(targetOrd);
+ // Reconstruct an estimate of the (centroid-centered) dot product from the
+ // quantized integer product plus per-vector corrective factors:
+ // x1/y1 = sums of quantized components, ax/ay = lower interval bounds,
+ // lx/ly = interval widths (query side rescaled from 4-bit range).
+ float x1 = indexCorrections.quantizedComponentSum();
+ float ax = indexCorrections.lowerInterval();
+ // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
+ float lx = indexCorrections.upperInterval() - ax;
+ float ay = queryCorrections.lowerInterval();
+ float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+ float y1 = queryCorrections.quantizedComponentSum();
+ float score = ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist;
+ // For euclidean, we need to invert the score and apply the additional correction, which is
+ // assumed to be the squared l2norm of the centroid centered vectors.
+ if (similarityFunction == EUCLIDEAN) {
+ score = queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - 2 * score;
+ // Standard Lucene distance-to-score transform; clamp at 0 defensively.
+ return Math.max(1 / (1f + score), 0);
+ } else {
+ // For cosine and max inner product, we need to apply the additional correction, which is
+ // assumed to be the non-centered dot-product between the vector and the centroid
+ score += queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - targetVectors.getCentroidDP();
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ return VectorUtil.scaleMaxInnerProductScore(score);
+ }
+ // Cosine/dot-product: map [-1, 1] to [0, 1]; clamp at 0 for quantization error.
+ return Math.max((1f + score) / 2f, 0);
+ }
+ }
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java
new file mode 100644
index 0000000000000..1dee9599f985f
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java
@@ -0,0 +1,132 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+
+import java.io.IOException;
+
+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
+
+/**
+ * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
+ * Codec for encoding/decoding binary quantized vectors The binary quantization format used here
+ * is a per-vector optimized scalar quantization. Also see {@link
+ * org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer}. Some of key features are:
+ *
+ * <ul>
+ *   <li>Estimating the distance between two vectors using their centroid normalized distance. This
+ *       requires some additional corrective factors, but allows for centroid normalization to occur.
+ *   <li>Optimized scalar quantization to bit level of centroid normalized vectors.
+ *   <li>Asymmetric quantization of vectors, where query vectors are quantized to half-byte
+ *       precision (normalized to the centroid) and then compared directly against the single bit
+ *       quantized vectors in the index.
+ *   <li>Transforming the half-byte quantized query vectors in such a way that the comparison with
+ *       single bit vectors can be done with bit arithmetic.
+ * </ul>
+ *
+ * The format is stored in two files:
+ *
+ * .veb (vector data) file
+ *
+ * Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's
+ * corrective factors. At the end of the file, additional information is stored for vector ordinal
+ * to centroid ordinal mapping and sparse vector information.
+ *
+ *
+ * - For each vector:
+ *
+ * - [byte] the binary quantized values, each byte holds 8 bits.
+ *
- [float] the optimized quantiles and an additional similarity dependent corrective factor.
+ *
- short the sum of the quantized components
+ *
+ * - After the vectors, sparse vector information keeping track of monotonic blocks.
+ *
+ *
+ * .vemb (vector metadata) file
+ *
+ * Stores the metadata for the vectors. This includes the number of vectors, the number of
+ * dimensions, and file offset information.
+ *
+ *
+ * - int the field number
+ *
- int the vector encoding ordinal
+ *
- int the vector similarity ordinal
+ *
- vint the vector dimensions
+ *
- vlong the offset to the vector data in the .veb file
+ *
- vlong the length of the vector data in the .veb file
+ *
- vint the number of vectors
+ *
- [float] the centroid
+ * - float the centroid square magnitude
+ * - The sparse vector information, if required, mapping vector ordinal to doc ID
+ *
+ */
+public class ES818BinaryQuantizedVectorsFormat extends FlatVectorsFormat {
+
+ // Info-stream component name used when logging binarization progress.
+ public static final String BINARIZED_VECTOR_COMPONENT = "BVEC";
+ public static final String NAME = "ES818BinaryQuantizedVectorsFormat";
+
+ static final int VERSION_START = 0;
+ static final int VERSION_CURRENT = VERSION_START;
+ static final String META_CODEC_NAME = "ES818BinaryQuantizedVectorsFormatMeta";
+ static final String VECTOR_DATA_CODEC_NAME = "ES818BinaryQuantizedVectorsFormatData";
+ static final String META_EXTENSION = "vemb";
+ static final String VECTOR_DATA_EXTENSION = "veb";
+ // Block shift for the DirectMonotonic ord-to-doc mapping of sparse fields.
+ static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
+
+ // Raw (float32) vectors are stored alongside the quantized ones via the
+ // stock Lucene99 flat format; reads that need full precision go through it.
+ private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(
+ FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
+ );
+
+ private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer(
+ FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
+ );
+
+ /** Creates a new instance with the default number of vectors per cluster. */
+ public ES818BinaryQuantizedVectorsFormat() {
+ super(NAME);
+ }
+
+ @Override
+ public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ return new ES818BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state);
+ }
+
+ @Override
+ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+ return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer);
+ }
+
+ @Override
+ public int getMaxDimensions(String fieldName) {
+ // Elasticsearch-wide dense_vector dimension cap, not a per-field limit.
+ return MAX_DIMS_COUNT;
+ }
+
+ @Override
+ public String toString() {
+ return "ES818BinaryQuantizedVectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer + ")";
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java
new file mode 100644
index 0000000000000..8036b8314cdc1
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java
@@ -0,0 +1,412 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.CorruptIndexException;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FieldInfos;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.KnnCollector;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.ChecksumIndexInput;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.ReadAdvice;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.SuppressForbidden;
+import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding;
+
+/**
+ * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
+ */
+@SuppressForbidden(reason = "Lucene classes")
+class ES818BinaryQuantizedVectorsReader extends FlatVectorsReader {
+
+ private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES818BinaryQuantizedVectorsReader.class);
+
+ // Per-field metadata, keyed by field name.
+ // NOTE(review): generics appear stripped by extraction here — presumably Map<String, FieldEntry>; confirm against upstream.
+ private final Map fields = new HashMap<>();
+ private final IndexInput quantizedVectorData;
+ // Delegate reader for the full-precision raw vectors stored alongside.
+ private final FlatVectorsReader rawVectorsReader;
+ private final ES818BinaryFlatVectorsScorer vectorScorer;
+
+ ES818BinaryQuantizedVectorsReader(
+ SegmentReadState state,
+ FlatVectorsReader rawVectorsReader,
+ ES818BinaryFlatVectorsScorer vectorsScorer
+ ) throws IOException {
+ super(vectorsScorer);
+ this.vectorScorer = vectorsScorer;
+ this.rawVectorsReader = rawVectorsReader;
+ int versionMeta = -1;
+ String metaFileName = IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.META_EXTENSION
+ );
+ boolean success = false;
+ try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
+ Throwable priorE = null;
+ try {
+ versionMeta = CodecUtil.checkIndexHeader(
+ meta,
+ ES818BinaryQuantizedVectorsFormat.META_CODEC_NAME,
+ ES818BinaryQuantizedVectorsFormat.VERSION_START,
+ ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix
+ );
+ readFields(meta, state.fieldInfos);
+ } catch (Throwable exception) {
+ priorE = exception;
+ } finally {
+ // checkFooter verifies the checksum and rethrows priorE (if any) with context.
+ CodecUtil.checkFooter(meta, priorE);
+ }
+ quantizedVectorData = openDataInput(
+ state,
+ versionMeta,
+ ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION,
+ ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
+ // Quantized vectors are accessed randomly from their node ID stored in the HNSW
+ // graph.
+ state.context.withReadAdvice(ReadAdvice.RANDOM)
+ );
+ success = true;
+ } finally {
+ if (success == false) {
+ // Partial construction: release anything already opened.
+ IOUtils.closeWhileHandlingException(this);
+ }
+ }
+ }
+
+ // Reads field entries from the meta stream until the -1 terminator.
+ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
+ for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
+ FieldInfo info = infos.fieldInfo(fieldNumber);
+ if (info == null) {
+ throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
+ }
+ FieldEntry fieldEntry = readField(meta, info);
+ validateFieldEntry(info, fieldEntry);
+ fields.put(info.name, fieldEntry);
+ }
+ }
+
+ // Sanity-checks a field entry against the segment's FieldInfo: dimensions must
+ // match and the data length must equal size * (packed bits + corrective factors).
+ static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
+ int dimension = info.getVectorDimension();
+ if (dimension != fieldEntry.dimension) {
+ throw new IllegalStateException(
+ "Inconsistent vector dimension for field=\"" + info.name + "\"; " + dimension + " != " + fieldEntry.dimension
+ );
+ }
+
+ // Packed 1-bit vector bytes per record, after discretizing dims to a multiple of 64.
+ int binaryDims = BQVectorUtils.discretize(dimension, 64) / 8;
+ // Per-vector record: packed bits + 3 float corrections + 1 short component sum (14 bytes).
+ long numQuantizedVectorBytes = Math.multiplyExact((binaryDims + (Float.BYTES * 3) + Short.BYTES), (long) fieldEntry.size);
+ if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
+ throw new IllegalStateException(
+ "Binarized vector data length "
+ + fieldEntry.vectorDataLength
+ + " not matching size = "
+ + fieldEntry.size
+ + " * (binaryBytes="
+ + binaryDims
+ + " + 14"
+ + ") = "
+ + numQuantizedVectorBytes
+ );
+ }
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
+ FieldEntry fi = fields.get(field);
+ if (fi == null) {
+ return null;
+ }
+ // Score against the off-heap binarized values; the scorer quantizes `target` itself.
+ return vectorScorer.getRandomVectorScorer(
+ fi.similarityFunction,
+ OffHeapBinarizedVectorValues.load(
+ fi.ordToDocDISIReaderConfiguration,
+ fi.dimension,
+ fi.size,
+ new OptimizedScalarQuantizer(fi.similarityFunction),
+ fi.similarityFunction,
+ vectorScorer,
+ fi.centroid,
+ fi.centroidDP,
+ fi.vectorDataOffset,
+ fi.vectorDataLength,
+ quantizedVectorData
+ ),
+ target
+ );
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
+ // Byte vectors are not binarized by this format; defer to the raw reader.
+ return rawVectorsReader.getRandomVectorScorer(field, target);
+ }
+
+ @Override
+ public void checkIntegrity() throws IOException {
+ rawVectorsReader.checkIntegrity();
+ CodecUtil.checksumEntireFile(quantizedVectorData);
+ }
+
+ @Override
+ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
+ FieldEntry fi = fields.get(field);
+ if (fi == null) {
+ return null;
+ }
+ if (fi.vectorEncoding != VectorEncoding.FLOAT32) {
+ throw new IllegalArgumentException(
+ "field=\"" + field + "\" is encoded as: " + fi.vectorEncoding + " expected: " + VectorEncoding.FLOAT32
+ );
+ }
+ OffHeapBinarizedVectorValues bvv = OffHeapBinarizedVectorValues.load(
+ fi.ordToDocDISIReaderConfiguration,
+ fi.dimension,
+ fi.size,
+ new OptimizedScalarQuantizer(fi.similarityFunction),
+ fi.similarityFunction,
+ vectorScorer,
+ fi.centroid,
+ fi.centroidDP,
+ fi.vectorDataOffset,
+ fi.vectorDataLength,
+ quantizedVectorData
+ );
+ // Wrap raw + quantized: values come from the raw reader, scoring from the quantized side.
+ return new BinarizedVectorValues(rawVectorsReader.getFloatVectorValues(field), bvv);
+ }
+
+ @Override
+ public ByteVectorValues getByteVectorValues(String field) throws IOException {
+ return rawVectorsReader.getByteVectorValues(field);
+ }
+
+ @Override
+ public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
+ rawVectorsReader.search(field, target, knnCollector, acceptDocs);
+ }
+
+ @Override
+ public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
+ if (knnCollector.k() == 0) return;
+ final RandomVectorScorer scorer = getRandomVectorScorer(field, target);
+ if (scorer == null) return;
+ OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
+ Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
+ // Brute-force scan over all ordinals (flat format: no graph to navigate).
+ for (int i = 0; i < scorer.maxOrd(); i++) {
+ if (acceptedOrds == null || acceptedOrds.get(i)) {
+ collector.collect(i, scorer.score(i));
+ collector.incVisitedCount(1);
+ }
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(quantizedVectorData, rawVectorsReader);
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long size = SHALLOW_SIZE;
+ size += RamUsageEstimator.sizeOfMap(fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
+ size += rawVectorsReader.ramBytesUsed();
+ return size;
+ }
+
+ // Returns the stored centroid for `field`, or null if the field is unknown.
+ public float[] getCentroid(String field) {
+ FieldEntry fieldEntry = fields.get(field);
+ if (fieldEntry != null) {
+ return fieldEntry.centroid;
+ }
+ return null;
+ }
+
+ // Opens the .veb data file, validating its header and version against the meta file.
+ private static IndexInput openDataInput(
+ SegmentReadState state,
+ int versionMeta,
+ String fileExtension,
+ String codecName,
+ IOContext context
+ ) throws IOException {
+ String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
+ IndexInput in = state.directory.openInput(fileName, context);
+ boolean success = false;
+ try {
+ int versionVectorData = CodecUtil.checkIndexHeader(
+ in,
+ codecName,
+ ES818BinaryQuantizedVectorsFormat.VERSION_START,
+ ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix
+ );
+ if (versionMeta != versionVectorData) {
+ throw new CorruptIndexException(
+ "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData,
+ in
+ );
+ }
+ // Validates the footer exists without reading the whole file.
+ CodecUtil.retrieveChecksum(in);
+ success = true;
+ return in;
+ } finally {
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(in);
+ }
+ }
+ }
+
+ private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
+ VectorEncoding vectorEncoding = readVectorEncoding(input);
+ VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
+ if (similarityFunction != info.getVectorSimilarityFunction()) {
+ throw new IllegalStateException(
+ "Inconsistent vector similarity function for field=\""
+ + info.name
+ + "\"; "
+ + similarityFunction
+ + " != "
+ + info.getVectorSimilarityFunction()
+ );
+ }
+ return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction());
+ }
+
+ // Per-field metadata as stored in the .vemb file.
+ // NOTE(review): "descritizedDimension" looks like a typo for "discretizedDimension" — kept as-is (private to this class).
+ private record FieldEntry(
+ VectorSimilarityFunction similarityFunction,
+ VectorEncoding vectorEncoding,
+ int dimension,
+ int descritizedDimension,
+ long vectorDataOffset,
+ long vectorDataLength,
+ int size,
+ float[] centroid,
+ float centroidDP,
+ OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration
+ ) {
+
+ static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction)
+ throws IOException {
+ int dimension = input.readVInt();
+ long vectorDataOffset = input.readVLong();
+ long vectorDataLength = input.readVLong();
+ int size = input.readVInt();
+ final float[] centroid;
+ float centroidDP = 0;
+ // Centroid and its dot product are only written when the field has vectors.
+ if (size > 0) {
+ centroid = new float[dimension];
+ input.readFloats(centroid, 0, dimension);
+ centroidDP = Float.intBitsToFloat(input.readInt());
+ } else {
+ centroid = null;
+ }
+ OrdToDocDISIReaderConfiguration conf = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
+ return new FieldEntry(
+ similarityFunction,
+ vectorEncoding,
+ dimension,
+ BQVectorUtils.discretize(dimension, 64),
+ vectorDataOffset,
+ vectorDataLength,
+ size,
+ centroid,
+ centroidDP,
+ conf
+ );
+ }
+ }
+
+ /** Binarized vector values holding row and quantized vector values */
+ // Delegates value access/iteration to the raw float vectors while routing
+ // query scoring through the quantized representation.
+ protected static final class BinarizedVectorValues extends FloatVectorValues {
+ private final FloatVectorValues rawVectorValues;
+ private final BinarizedByteVectorValues quantizedVectorValues;
+
+ BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) {
+ this.rawVectorValues = rawVectorValues;
+ this.quantizedVectorValues = quantizedVectorValues;
+ }
+
+ @Override
+ public int dimension() {
+ return rawVectorValues.dimension();
+ }
+
+ @Override
+ public int size() {
+ return rawVectorValues.size();
+ }
+
+ @Override
+ public float[] vectorValue(int ord) throws IOException {
+ return rawVectorValues.vectorValue(ord);
+ }
+
+ @Override
+ public BinarizedVectorValues copy() throws IOException {
+ return new BinarizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy());
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return rawVectorValues.getAcceptOrds(acceptDocs);
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return rawVectorValues.ordToDoc(ord);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return rawVectorValues.iterator();
+ }
+
+ @Override
+ public VectorScorer scorer(float[] query) throws IOException {
+ // Scoring goes through the quantized side for speed.
+ return quantizedVectorValues.scorer(query);
+ }
+
+ BinarizedByteVectorValues getQuantizedVectorValues() throws IOException {
+ return quantizedVectorValues;
+ }
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java
new file mode 100644
index 0000000000000..02dda6a4a9da1
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java
@@ -0,0 +1,944 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.codecs.CodecUtil;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
+import org.apache.lucene.index.DocsWithFieldSet;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.internal.hppc.FloatArrayList;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.elasticsearch.core.SuppressForbidden;
+import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
+import static org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.BINARIZED_VECTOR_COMPONENT;
+import static org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
+
+/**
+ * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
+ */
+@SuppressForbidden(reason = "Lucene classes")
+public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
+ private static final long SHALLOW_RAM_BYTES_USED = shallowSizeOfInstance(ES818BinaryQuantizedVectorsWriter.class);
+
+ private final SegmentWriteState segmentWriteState;
+ private final List fields = new ArrayList<>();
+ private final IndexOutput meta, binarizedVectorData;
+ private final FlatVectorsWriter rawVectorDelegate;
+ private final ES818BinaryFlatVectorsScorer vectorsScorer;
+ private boolean finished;
+
+ /**
+ * Sole constructor
+ *
+ * @param vectorsScorer the scorer to use for scoring vectors
+ * @param rawVectorDelegate writer for the full-precision raw vectors stored alongside
+ * @param state segment write state providing directory, names and IO context
+ */
+ protected ES818BinaryQuantizedVectorsWriter(
+ ES818BinaryFlatVectorsScorer vectorsScorer,
+ FlatVectorsWriter rawVectorDelegate,
+ SegmentWriteState state
+ ) throws IOException {
+ super(vectorsScorer);
+ this.vectorsScorer = vectorsScorer;
+ this.segmentWriteState = state;
+ String metaFileName = IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ ES818BinaryQuantizedVectorsFormat.META_EXTENSION
+ );
+
+ String binarizedVectorDataFileName = IndexFileNames.segmentFileName(
+ state.segmentInfo.name,
+ state.segmentSuffix,
+ ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION
+ );
+ this.rawVectorDelegate = rawVectorDelegate;
+ boolean success = false;
+ try {
+ // Open .vemb (meta) and .veb (data) outputs and stamp their codec headers.
+ meta = state.directory.createOutput(metaFileName, state.context);
+ binarizedVectorData = state.directory.createOutput(binarizedVectorDataFileName, state.context);
+
+ CodecUtil.writeIndexHeader(
+ meta,
+ ES818BinaryQuantizedVectorsFormat.META_CODEC_NAME,
+ ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix
+ );
+ CodecUtil.writeIndexHeader(
+ binarizedVectorData,
+ ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
+ ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+ state.segmentInfo.getId(),
+ state.segmentSuffix
+ );
+ success = true;
+ } finally {
+ if (success == false) {
+ // Partial construction: close whatever outputs were created.
+ IOUtils.closeWhileHandlingException(this);
+ }
+ }
+ }
+
+ @Override
+ public FlatFieldVectorsWriter> addField(FieldInfo fieldInfo) throws IOException {
+ // NOTE(review): generic parameters look stripped by extraction on this signature — confirm against upstream.
+ FlatFieldVectorsWriter> rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo);
+ // Only float32 fields are binarized; byte fields pass straight through to the raw writer.
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ @SuppressWarnings("unchecked")
+ FieldWriter fieldWriter = new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate);
+ fields.add(fieldWriter);
+ return fieldWriter;
+ }
+ return rawVectorDelegate;
+ }
+
+ @Override
+ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
+ // Flush raw vectors first, then binarize each float field around its centroid.
+ rawVectorDelegate.flush(maxDoc, sortMap);
+ for (FieldWriter field : fields) {
+ // after raw vectors are written, normalize vectors for clustering and quantization
+ if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) {
+ field.normalizeVectors();
+ }
+ // Centroid = mean of all vectors (zero vector when the field is empty).
+ final float[] clusterCenter;
+ int vectorCount = field.flatFieldVectorsWriter.getVectors().size();
+ clusterCenter = new float[field.dimensionSums.length];
+ if (vectorCount > 0) {
+ for (int i = 0; i < field.dimensionSums.length; i++) {
+ clusterCenter[i] = field.dimensionSums[i] / vectorCount;
+ }
+ if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) {
+ VectorUtil.l2normalize(clusterCenter);
+ }
+ }
+ if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction());
+ // Sorted segments need ord remapping before the quantized vectors are written.
+ if (sortMap == null) {
+ writeField(field, clusterCenter, maxDoc, quantizer);
+ } else {
+ writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer);
+ }
+ field.finish();
+ }
+ }
+
+ // Writes one field's binarized vectors (unsorted segment) followed by its meta entry.
+ private void writeField(FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer)
+ throws IOException {
+ // write vector values
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ writeBinarizedVectors(fieldData, clusterCenter, quantizer);
+ long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+ // Centroid self dot product is a corrective term used at query time; 0 when no vectors.
+ float centroidDp = fieldData.getVectors().size() > 0 ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0;
+
+ writeMeta(
+ fieldData.fieldInfo,
+ maxDoc,
+ vectorDataOffset,
+ vectorDataLength,
+ clusterCenter,
+ centroidDp,
+ fieldData.getDocsWithFieldSet()
+ );
+ }
+
+ // Quantizes each buffered vector to 1 bit per (discretized) dimension and writes
+ // the packed bits plus its corrective factors to the .veb data file.
+ private void writeBinarizedVectors(FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer)
+ throws IOException {
+ // Dimensions rounded up to a multiple of 64 so the packed form is word-aligned.
+ int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64);
+ byte[] quantizationScratch = new byte[discreteDims];
+ byte[] vector = new byte[discreteDims / 8];
+ for (int i = 0; i < fieldData.getVectors().size(); i++) {
+ float[] v = fieldData.getVectors().get(i);
+ // 1-bit quantization, centered on the cluster centroid.
+ OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize(
+ v,
+ quantizationScratch,
+ (byte) 1,
+ clusterCenter
+ );
+ BQVectorUtils.packAsBinary(quantizationScratch, vector);
+ binarizedVectorData.writeBytes(vector, vector.length);
+ // Per-vector record layout: packed bits, 3 float corrections, short component sum.
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
+ assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff;
+ binarizedVectorData.writeShort((short) corrections.quantizedComponentSum());
+ }
+ }
+
+ // Same as writeField, but for index-sorted segments: vectors are written in the order induced by
+ // the sort map rather than insertion order.
+ private void writeSortingField(
+ FieldWriter fieldData,
+ float[] clusterCenter,
+ int maxDoc,
+ Sorter.DocMap sortMap,
+ OptimizedScalarQuantizer scalarQuantizer
+ ) throws IOException {
+ final int[] ordMap = new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord
+
+ DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
+ mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField);
+
+ // write vector values
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer);
+ long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+
+ // guard the empty-field case the same way writeField does: only compute the centroid
+ // self dot-product when the field actually holds vectors
+ float centroidDp = newDocsWithField.cardinality() > 0 ? VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0;
+ writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, quantizedVectorLength, clusterCenter, centroidDp, newDocsWithField);
+ }
+
+ // Like writeBinarizedVectors, but iterates vectors in the sorted order given by ordMap
+ // (new ord -> old ord). The per-vector record layout is identical.
+ private void writeSortedBinarizedVectors(
+ FieldWriter fieldData,
+ float[] clusterCenter,
+ int[] ordMap,
+ OptimizedScalarQuantizer scalarQuantizer
+ ) throws IOException {
+ int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64);
+ byte[] quantizationScratch = new byte[discreteDims];
+ byte[] vector = new byte[discreteDims / 8];
+ for (int ordinal : ordMap) {
+ float[] v = fieldData.getVectors().get(ordinal);
+ OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize(
+ v,
+ quantizationScratch,
+ (byte) 1,
+ clusterCenter
+ );
+ BQVectorUtils.packAsBinary(quantizationScratch, vector);
+ binarizedVectorData.writeBytes(vector, vector.length);
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
+ assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff;
+ // stored as an unsigned short; the assert above guarantees the value fits
+ binarizedVectorData.writeShort((short) corrections.quantizedComponentSum());
+ }
+ }
+
+ // Writes the per-field metadata record: field number, encoding, similarity, dimension,
+ // data offset/length, vector count, and — only when the field is non-empty — the centroid
+ // (little-endian floats) and its self dot-product, followed by the ord-to-doc configuration.
+ private void writeMeta(
+ FieldInfo field,
+ int maxDoc,
+ long vectorDataOffset,
+ long vectorDataLength,
+ float[] clusterCenter,
+ float centroidDp,
+ DocsWithFieldSet docsWithField
+ ) throws IOException {
+ meta.writeInt(field.number);
+ meta.writeInt(field.getVectorEncoding().ordinal());
+ meta.writeInt(field.getVectorSimilarityFunction().ordinal());
+ meta.writeVInt(field.getVectorDimension());
+ meta.writeVLong(vectorDataOffset);
+ meta.writeVLong(vectorDataLength);
+ int count = docsWithField.cardinality();
+ meta.writeVInt(count);
+ if (count > 0) {
+ // centroid is serialized little-endian so readers can map it directly
+ final ByteBuffer buffer = ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+ buffer.asFloatBuffer().put(clusterCenter);
+ meta.writeBytes(buffer.array(), buffer.array().length);
+ meta.writeInt(Float.floatToIntBits(centroidDp));
+ }
+ OrdToDocDISIReaderConfiguration.writeStoredMeta(
+ DIRECT_MONOTONIC_BLOCK_SHIFT,
+ meta,
+ binarizedVectorData,
+ count,
+ maxDoc,
+ docsWithField
+ );
+ }
+
+ // Finalizes the writer: delegates to the raw vector writer, then terminates the metadata
+ // stream with a -1 field marker and writes codec footers. May only be called once.
+ @Override
+ public void finish() throws IOException {
+ if (finished) {
+ throw new IllegalStateException("already finished");
+ }
+ finished = true;
+ rawVectorDelegate.finish();
+ if (meta != null) {
+ // write end of fields marker
+ meta.writeInt(-1);
+ CodecUtil.writeFooter(meta);
+ }
+ if (binarizedVectorData != null) {
+ CodecUtil.writeFooter(binarizedVectorData);
+ }
+ }
+
+ // Merges one field: float32 fields are re-binarized against a centroid recomputed (or averaged)
+ // from the merging segments; any other encoding is passed straight to the raw delegate.
+ @Override
+ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ final float[] centroid;
+ final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()];
+ int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid);
+ // Don't need access to the random vectors, we can just use the merged
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ centroid = mergedCentroid;
+ if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+ // cosine similarity is served by normalizing vectors up front and scoring with dot product
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
+ }
+ BinarizedFloatVectorValues binarizedVectorValues = new BinarizedFloatVectorValues(
+ floatVectorValues,
+ new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()),
+ centroid
+ );
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ DocsWithFieldSet docsWithField = writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues);
+ long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+ float centroidDp = docsWithField.cardinality() > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
+ writeMeta(
+ fieldInfo,
+ segmentWriteState.segmentInfo.maxDoc(),
+ vectorDataOffset,
+ vectorDataLength,
+ centroid,
+ centroidDp,
+ docsWithField
+ );
+ } else {
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ }
+ }
+
+ // Quantizes every vector twice in one pass — 1 bit/dim for the index side and 4 bits/dim
+ // (transposed half-byte layout) for the query side — writing a record per vector to each
+ // output. Returns the set of docs that actually carry the field.
+ static DocsWithFieldSet writeBinarizedVectorAndQueryData(
+ IndexOutput binarizedVectorData,
+ IndexOutput binarizedQueryData,
+ FloatVectorValues floatVectorValues,
+ float[] centroid,
+ OptimizedScalarQuantizer binaryQuantizer
+ ) throws IOException {
+ int discretizedDimension = BQVectorUtils.discretize(floatVectorValues.dimension(), 64);
+ DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+ // scratch[0] holds the 1-bit quantization, scratch[1] the 4-bit quantization
+ byte[][] quantizationScratch = new byte[2][floatVectorValues.dimension()];
+ byte[] toIndex = new byte[discretizedDimension / 8];
+ byte[] toQuery = new byte[(discretizedDimension / 8) * BQSpaceUtils.B_QUERY];
+ KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
+ for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
+ // write index vector
+ OptimizedScalarQuantizer.QuantizationResult[] r = binaryQuantizer.multiScalarQuantize(
+ floatVectorValues.vectorValue(iterator.index()),
+ quantizationScratch,
+ new byte[] { 1, 4 },
+ centroid
+ );
+ // pack and store document bit vector
+ BQVectorUtils.packAsBinary(quantizationScratch[0], toIndex);
+ binarizedVectorData.writeBytes(toIndex, toIndex.length);
+ binarizedVectorData.writeInt(Float.floatToIntBits(r[0].lowerInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(r[0].upperInterval()));
+ binarizedVectorData.writeInt(Float.floatToIntBits(r[0].additionalCorrection()));
+ assert r[0].quantizedComponentSum() >= 0 && r[0].quantizedComponentSum() <= 0xffff;
+ binarizedVectorData.writeShort((short) r[0].quantizedComponentSum());
+ docsWithField.add(docV);
+
+ // pack and store the 4bit query vector
+ BQSpaceUtils.transposeHalfByte(quantizationScratch[1], toQuery);
+ binarizedQueryData.writeBytes(toQuery, toQuery.length);
+ binarizedQueryData.writeInt(Float.floatToIntBits(r[1].lowerInterval()));
+ binarizedQueryData.writeInt(Float.floatToIntBits(r[1].upperInterval()));
+ binarizedQueryData.writeInt(Float.floatToIntBits(r[1].additionalCorrection()));
+ assert r[1].quantizedComponentSum() >= 0 && r[1].quantizedComponentSum() <= 0xffff;
+ binarizedQueryData.writeShort((short) r[1].quantizedComponentSum());
+ }
+ return docsWithField;
+ }
+
+ // Streams already-binarized vectors (with their corrective terms) to the output, one record per
+ // doc, and returns the set of docs that carry the field.
+ static DocsWithFieldSet writeBinarizedVectorData(IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues)
+ throws IOException {
+ DocsWithFieldSet docsWithField = new DocsWithFieldSet();
+ KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator();
+ for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
+ // write vector
+ byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index());
+ output.writeBytes(binaryValue, binaryValue.length);
+ // corrective terms must be fetched after vectorValue so they match the same ordinal
+ OptimizedScalarQuantizer.QuantizationResult corrections = binarizedByteVectorValues.getCorrectiveTerms(iterator.index());
+ output.writeInt(Float.floatToIntBits(corrections.lowerInterval()));
+ output.writeInt(Float.floatToIntBits(corrections.upperInterval()));
+ output.writeInt(Float.floatToIntBits(corrections.additionalCorrection()));
+ assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff;
+ output.writeShort((short) corrections.quantizedComponentSum());
+ docsWithField.add(docV);
+ }
+ return docsWithField;
+ }
+
+ // Merge entry point used when a scorer over the merged (binarized) vectors is needed, e.g. for
+ // HNSW graph construction. Non-float32 fields go straight to the raw delegate.
+ @Override
+ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
+ if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
+ final float[] centroid;
+ final float cDotC;
+ final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()];
+ int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid);
+
+ // Don't need access to the random vectors, we can just use the merged
+ rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
+ centroid = mergedCentroid;
+ cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0;
+ if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) {
+ segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount);
+ }
+ return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC);
+ }
+ return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState);
+ }
+
+ // Writes the merged field's binarized doc vectors into the main data file (via a checksummed temp
+ // file) and the 4-bit query-side vectors into a second temp file, then returns a scorer supplier
+ // backed by both. On success, the returned supplier owns the opened inputs and deletes the temp
+ // files when closed; on failure, the finally block closes and deletes everything.
+ private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
+ SegmentWriteState segmentWriteState,
+ FieldInfo fieldInfo,
+ MergeState mergeState,
+ float[] centroid,
+ float cDotC
+ ) throws IOException {
+ long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES);
+ // temp file for the 1-bit document vectors, later copied into binarizedVectorData
+ final IndexOutput tempQuantizedVectorData = segmentWriteState.directory.createTempOutput(
+ binarizedVectorData.getName(),
+ "temp",
+ segmentWriteState.context
+ );
+ // temp file for the 4-bit query-side vectors used only while building the graph
+ final IndexOutput tempScoreQuantizedVectorData = segmentWriteState.directory.createTempOutput(
+ binarizedVectorData.getName(),
+ "score_temp",
+ segmentWriteState.context
+ );
+ IndexInput binarizedDataInput = null;
+ IndexInput binarizedScoreDataInput = null;
+ boolean success = false;
+ OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+ try {
+ FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
+ }
+ DocsWithFieldSet docsWithField = writeBinarizedVectorAndQueryData(
+ tempQuantizedVectorData,
+ tempScoreQuantizedVectorData,
+ floatVectorValues,
+ centroid,
+ quantizer
+ );
+ CodecUtil.writeFooter(tempQuantizedVectorData);
+ IOUtils.close(tempQuantizedVectorData);
+ binarizedDataInput = segmentWriteState.directory.openInput(tempQuantizedVectorData.getName(), segmentWriteState.context);
+ // copy everything except the codec footer into the real data file
+ binarizedVectorData.copyBytes(binarizedDataInput, binarizedDataInput.length() - CodecUtil.footerLength());
+ long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset;
+ CodecUtil.retrieveChecksum(binarizedDataInput);
+ CodecUtil.writeFooter(tempScoreQuantizedVectorData);
+ IOUtils.close(tempScoreQuantizedVectorData);
+ binarizedScoreDataInput = segmentWriteState.directory.openInput(
+ tempScoreQuantizedVectorData.getName(),
+ segmentWriteState.context
+ );
+ writeMeta(
+ fieldInfo,
+ segmentWriteState.segmentInfo.maxDoc(),
+ vectorDataOffset,
+ vectorDataLength,
+ centroid,
+ cDotC,
+ docsWithField
+ );
+ success = true;
+ final IndexInput finalBinarizedDataInput = binarizedDataInput;
+ final IndexInput finalBinarizedScoreDataInput = binarizedScoreDataInput;
+ OffHeapBinarizedVectorValues vectorValues = new OffHeapBinarizedVectorValues.DenseOffHeapVectorValues(
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality(),
+ centroid,
+ cDotC,
+ quantizer,
+ fieldInfo.getVectorSimilarityFunction(),
+ vectorsScorer,
+ finalBinarizedDataInput
+ );
+ RandomVectorScorerSupplier scorerSupplier = vectorsScorer.getRandomVectorScorerSupplier(
+ fieldInfo.getVectorSimilarityFunction(),
+ new OffHeapBinarizedQueryVectorValues(
+ finalBinarizedScoreDataInput,
+ fieldInfo.getVectorDimension(),
+ docsWithField.cardinality()
+ ),
+ vectorValues
+ );
+ // cleanup of inputs and temp files is deferred to the supplier's close
+ return new BinarizedCloseableRandomVectorScorerSupplier(scorerSupplier, vectorValues, () -> {
+ IOUtils.close(finalBinarizedDataInput, finalBinarizedScoreDataInput);
+ IOUtils.deleteFilesIgnoringExceptions(
+ segmentWriteState.directory,
+ tempQuantizedVectorData.getName(),
+ tempScoreQuantizedVectorData.getName()
+ );
+ });
+ } finally {
+ if (success == false) {
+ IOUtils.closeWhileHandlingException(
+ tempQuantizedVectorData,
+ tempScoreQuantizedVectorData,
+ binarizedDataInput,
+ binarizedScoreDataInput
+ );
+ IOUtils.deleteFilesIgnoringExceptions(
+ segmentWriteState.directory,
+ tempQuantizedVectorData.getName(),
+ tempScoreQuantizedVectorData.getName()
+ );
+ }
+ }
+ }
+
+ // Closes metadata, binarized data, and the raw delegate; IOUtils closes all even if one throws.
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(meta, binarizedVectorData, rawVectorDelegate);
+ }
+
+ // Returns the stored centroid for a field, unwrapping per-field readers; null when the reader
+ // is not an ES818 binary-quantized reader (callers then recompute the centroid).
+ static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) {
+ if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
+ vectorsReader = candidateReader.getFieldReader(fieldName);
+ }
+ if (vectorsReader instanceof ES818BinaryQuantizedVectorsReader reader) {
+ return reader.getCentroid(fieldName);
+ }
+ return null;
+ }
+
+ // Computes the merged centroid as the vector-count-weighted average of the per-segment centroids
+ // when possible, falling back to a full recalculation over all vectors otherwise. Returns the
+ // total number of merged vectors; mergedCentroid is filled as a side effect.
+ static int mergeAndRecalculateCentroids(MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException {
+ boolean recalculate = false;
+ int totalVectorCount = 0;
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader == null || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) {
+ continue;
+ }
+ float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name);
+ int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size();
+ if (vectorCount == 0) {
+ continue;
+ }
+ totalVectorCount += vectorCount;
+ // If a segment has no stored centroid, or has deleted docs (liveDocs present), the
+ // weighted average would be inexact, so fall back to a full recalculation
+ if (centroid == null || mergeState.liveDocs[i] != null) {
+ recalculate = true;
+ break;
+ }
+ for (int j = 0; j < centroid.length; j++) {
+ mergedCentroid[j] += centroid[j] * vectorCount;
+ }
+ }
+ if (recalculate) {
+ return calculateCentroid(mergeState, fieldInfo, mergedCentroid);
+ } else {
+ // NOTE(review): if totalVectorCount is 0 here this division yields NaN; downstream
+ // only writes the centroid when the doc count is > 0, so it appears benign — confirm
+ for (int j = 0; j < mergedCentroid.length; j++) {
+ mergedCentroid[j] = mergedCentroid[j] / totalVectorCount;
+ }
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ VectorUtil.l2normalize(mergedCentroid);
+ }
+ return totalVectorCount;
+ }
+ }
+
+ // Recomputes the centroid exactly by averaging every float vector across all merging segments.
+ // Fills `centroid` in place (cleared first) and returns the number of vectors seen.
+ static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) throws IOException {
+ assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32);
+ // clear out the centroid
+ Arrays.fill(centroid, 0);
+ int count = 0;
+ for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
+ KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
+ if (knnVectorsReader == null) continue;
+ FloatVectorValues vectorValues = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name);
+ if (vectorValues == null) {
+ continue;
+ }
+ KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+ for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
+ ++count;
+ float[] vector = vectorValues.vectorValue(iterator.index());
+ // TODO Panama sum
+ for (int j = 0; j < vector.length; j++) {
+ centroid[j] += vector[j];
+ }
+ }
+ }
+ if (count == 0) {
+ return count;
+ }
+ // TODO Panama div
+ for (int i = 0; i < centroid.length; i++) {
+ centroid[i] /= count;
+ }
+ // normalized centroid matches the up-front normalization applied to cosine vectors
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ VectorUtil.l2normalize(centroid);
+ }
+ return count;
+ }
+
+ // Heap accounting: shallow size of this writer plus each field writer (which includes its delegate).
+ @Override
+ public long ramBytesUsed() {
+ long total = SHALLOW_RAM_BYTES_USED;
+ for (FieldWriter field : fields) {
+ // the field tracks the delegate field usage
+ total += field.ramBytesUsed();
+ }
+ return total;
+ }
+
+ // Per-field buffering writer: delegates raw storage to a FlatFieldVectorsWriter while
+ // accumulating per-dimension sums (for the centroid) and, for cosine, per-vector magnitudes
+ // so the raw vectors can be normalized in place at flush time.
+ static class FieldWriter extends FlatFieldVectorsWriter {
+ private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
+ private final FieldInfo fieldInfo;
+ private boolean finished;
+ private final FlatFieldVectorsWriter flatFieldVectorsWriter;
+ // running per-dimension sum of (possibly normalized) vectors, used to derive the centroid
+ private final float[] dimensionSums;
+ // per-vector magnitudes, recorded only for cosine so normalizeVectors can divide later
+ private final FloatArrayList magnitudes = new FloatArrayList();
+
+ FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) {
+ this.fieldInfo = fieldInfo;
+ this.flatFieldVectorsWriter = flatFieldVectorsWriter;
+ this.dimensionSums = new float[fieldInfo.getVectorDimension()];
+ }
+
+ @Override
+ public List getVectors() {
+ return flatFieldVectorsWriter.getVectors();
+ }
+
+ // Divides each buffered vector by its recorded magnitude, mutating the delegate's vectors in
+ // place. Only meaningful for cosine; callers invoke it after the raw vectors are flushed.
+ public void normalizeVectors() {
+ for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) {
+ float[] vector = flatFieldVectorsWriter.getVectors().get(i);
+ float magnitude = magnitudes.get(i);
+ for (int j = 0; j < vector.length; j++) {
+ vector[j] /= magnitude;
+ }
+ }
+ }
+
+ @Override
+ public DocsWithFieldSet getDocsWithFieldSet() {
+ return flatFieldVectorsWriter.getDocsWithFieldSet();
+ }
+
+ @Override
+ public void finish() throws IOException {
+ if (finished) {
+ return;
+ }
+ assert flatFieldVectorsWriter.isFinished();
+ finished = true;
+ }
+
+ @Override
+ public boolean isFinished() {
+ return finished && flatFieldVectorsWriter.isFinished();
+ }
+
+ @Override
+ public void addValue(int docID, float[] vectorValue) throws IOException {
+ flatFieldVectorsWriter.addValue(docID, vectorValue);
+ if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+ // accumulate the normalized vector so the centroid lives on the unit sphere
+ float dp = VectorUtil.dotProduct(vectorValue, vectorValue);
+ float divisor = (float) Math.sqrt(dp);
+ magnitudes.add(divisor);
+ for (int i = 0; i < vectorValue.length; i++) {
+ dimensionSums[i] += (vectorValue[i] / divisor);
+ }
+ } else {
+ for (int i = 0; i < vectorValue.length; i++) {
+ dimensionSums[i] += vectorValue[i];
+ }
+ }
+ }
+
+ @Override
+ public float[] copyValue(float[] vectorValue) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long ramBytesUsed() {
+ long size = SHALLOW_SIZE;
+ size += flatFieldVectorsWriter.ramBytesUsed();
+ size += magnitudes.ramBytesUsed();
+ return size;
+ }
+ }
+
+ // When accessing the vectorValue method, targetOrd here means a row ordinal.
+ // Reads 4-bit query-side vectors and their corrective terms off heap, caching the most
+ // recently read ordinal. NOT thread-safe: binaryValue and correctiveValues are shared scratch.
+ static class OffHeapBinarizedQueryVectorValues {
+ private final IndexInput slice;
+ private final int dimension;
+ private final int size;
+ protected final byte[] binaryValue;
+ protected final ByteBuffer byteBuffer;
+ private final int byteSize;
+ protected final float[] correctiveValues;
+ // ordinal whose record currently occupies the scratch buffers; -1 before any read
+ private int lastOrd = -1;
+ private int quantizedComponentSum;
+
+ OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) {
+ this.slice = data;
+ this.dimension = dimension;
+ this.size = size;
+ // 4x the quantized binary dimensions
+ int binaryDimensions = (BQVectorUtils.discretize(dimension, 64) / 8) * BQSpaceUtils.B_QUERY;
+ this.byteBuffer = ByteBuffer.allocate(binaryDimensions);
+ this.binaryValue = byteBuffer.array();
+ // three float corrective terms; the quantized component sum is stored separately as a short
+ this.correctiveValues = new float[3];
+ this.byteSize = binaryDimensions + Float.BYTES * 3 + Short.BYTES;
+ }
+
+ // Returns the corrective terms for targetOrd, reading its record first if it is not cached.
+ public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) throws IOException {
+ if (lastOrd == targetOrd) {
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0],
+ correctiveValues[1],
+ correctiveValues[2],
+ quantizedComponentSum
+ );
+ }
+ // populate the scratch buffers (and lastOrd) as a side effect
+ vectorValue(targetOrd);
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0],
+ correctiveValues[1],
+ correctiveValues[2],
+ quantizedComponentSum
+ );
+ }
+
+ public int size() {
+ return size;
+ }
+
+ public int dimension() {
+ return dimension;
+ }
+
+ public OffHeapBinarizedQueryVectorValues copy() throws IOException {
+ return new OffHeapBinarizedQueryVectorValues(slice.clone(), dimension, size);
+ }
+
+ public IndexInput getSlice() {
+ return slice;
+ }
+
+ // Returns the packed query vector for targetOrd; the returned array is reused across calls.
+ public byte[] vectorValue(int targetOrd) throws IOException {
+ if (lastOrd == targetOrd) {
+ return binaryValue;
+ }
+ slice.seek((long) targetOrd * byteSize);
+ slice.readBytes(binaryValue, 0, binaryValue.length);
+ slice.readFloats(correctiveValues, 0, 3);
+ quantizedComponentSum = Short.toUnsignedInt(slice.readShort());
+ lastOrd = targetOrd;
+ return binaryValue;
+ }
+ }
+
+ // Lazily binarizes float vectors from a delegate, one ordinal at a time, caching the last
+ // result. getCorrectiveTerms is only valid for the ordinal most recently passed to vectorValue.
+ static class BinarizedFloatVectorValues extends BinarizedByteVectorValues {
+ private OptimizedScalarQuantizer.QuantizationResult corrections;
+ private final byte[] binarized;
+ private final byte[] initQuantized;
+ private final float[] centroid;
+ private final FloatVectorValues values;
+ private final OptimizedScalarQuantizer quantizer;
+
+ // ordinal whose quantization is currently held in `binarized`/`corrections`
+ private int lastOrd = -1;
+
+ BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) {
+ this.values = delegate;
+ this.quantizer = quantizer;
+ this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8];
+ this.initQuantized = new byte[delegate.dimension()];
+ this.centroid = centroid;
+ }
+
+ @Override
+ public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
+ if (ord != lastOrd) {
+ throw new IllegalStateException(
+ "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd
+ );
+ }
+ return corrections;
+ }
+
+ @Override
+ public byte[] vectorValue(int ord) throws IOException {
+ if (ord != lastOrd) {
+ binarize(ord);
+ lastOrd = ord;
+ }
+ return binarized;
+ }
+
+ @Override
+ public int dimension() {
+ return values.dimension();
+ }
+
+ @Override
+ public OptimizedScalarQuantizer getQuantizer() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float[] getCentroid() throws IOException {
+ return centroid;
+ }
+
+ @Override
+ public int size() {
+ return values.size();
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public BinarizedByteVectorValues copy() throws IOException {
+ return new BinarizedFloatVectorValues(values.copy(), quantizer, centroid);
+ }
+
+ // Quantizes ordinal `ord` to 1 bit/dim against the centroid, then bit-packs the result.
+ private void binarize(int ord) throws IOException {
+ corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, (byte) 1, centroid);
+ BQVectorUtils.packAsBinary(initQuantized, binarized);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return values.iterator();
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return values.ordToDoc(ord);
+ }
+ }
+
+ // Wraps a scorer supplier with a close hook so merge-time resources (open inputs, temp files)
+ // are released once the caller is done building the graph.
+ static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier {
+ private final RandomVectorScorerSupplier supplier;
+ private final KnnVectorValues vectorValues;
+ private final Closeable onClose;
+
+ BinarizedCloseableRandomVectorScorerSupplier(RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) {
+ this.supplier = supplier;
+ this.onClose = onClose;
+ this.vectorValues = vectorValues;
+ }
+
+ @Override
+ public RandomVectorScorer scorer(int ord) throws IOException {
+ return supplier.scorer(ord);
+ }
+
+ @Override
+ public RandomVectorScorerSupplier copy() throws IOException {
+ // copies share the close hook; only this instance's close() runs it
+ return supplier.copy();
+ }
+
+ @Override
+ public void close() throws IOException {
+ onClose.close();
+ }
+
+ @Override
+ public int totalVectorCount() {
+ return vectorValues.size();
+ }
+ }
+
+ // View over a FloatVectorValues that returns each vector l2-normalized, used for cosine fields.
+ // vectorValue returns a shared scratch array, overwritten on every call — callers must not hold it.
+ static final class NormalizedFloatVectorValues extends FloatVectorValues {
+ private final FloatVectorValues values;
+ private final float[] normalizedVector;
+
+ NormalizedFloatVectorValues(FloatVectorValues values) {
+ this.values = values;
+ this.normalizedVector = new float[values.dimension()];
+ }
+
+ @Override
+ public int dimension() {
+ return values.dimension();
+ }
+
+ @Override
+ public int size() {
+ return values.size();
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return values.ordToDoc(ord);
+ }
+
+ @Override
+ public float[] vectorValue(int ord) throws IOException {
+ // copy first so the delegate's internal array is never mutated
+ System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length);
+ VectorUtil.l2normalize(normalizedVector);
+ return normalizedVector;
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return values.iterator();
+ }
+
+ @Override
+ public NormalizedFloatVectorValues copy() throws IOException {
+ return new NormalizedFloatVectorValues(values.copy());
+ }
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java
new file mode 100644
index 0000000000000..56942017c3cef
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java
@@ -0,0 +1,145 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.search.TaskExecutor;
+import org.apache.lucene.util.hnsw.HnswGraph;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
+
+/**
+ * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
+ */
+public class ES818HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat {
+
+ public static final String NAME = "ES818HnswBinaryQuantizedVectorsFormat";
+
+ /**
+ * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
+ * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
+ */
+ private final int maxConn;
+
+ /**
+ * The number of candidate neighbors to track while searching the graph for each newly inserted
+ * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
+ * for details.
+ */
+ private final int beamWidth;
+
+ /** The format for storing, reading, merging vectors on disk */
+ private static final FlatVectorsFormat flatVectorsFormat = new ES818BinaryQuantizedVectorsFormat();
+
+ private final int numMergeWorkers;
+ private final TaskExecutor mergeExec;
+
+ /** Constructs a format using default graph construction parameters */
+ public ES818HnswBinaryQuantizedVectorsFormat() {
+ this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
+ }
+
+ /**
+ * Constructs a format using the given graph construction parameters.
+ *
+ * @param maxConn the maximum number of connections to a node in the HNSW graph
+ * @param beamWidth the size of the queue maintained during graph construction.
+ */
+ public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) {
+ this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
+ }
+
+ /**
+ * Constructs a format using the given graph construction parameters and scalar quantization.
+ *
+ * @param maxConn the maximum number of connections to a node in the HNSW graph
+ * @param beamWidth the size of the queue maintained during graph construction.
+ * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
+ * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
+ * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
+ * generated by this format to do the merge
+ * @throws IllegalArgumentException if maxConn or beamWidth is out of range, or if a merge
+ * executor is supplied with a single merge worker
+ */
+ public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
+ super(NAME);
+ // bounds mirror Lucene99HnswVectorsFormat's documented maximums
+ if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
+ throw new IllegalArgumentException(
+ "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
+ );
+ }
+ if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
+ throw new IllegalArgumentException(
+ "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
+ );
+ }
+ this.maxConn = maxConn;
+ this.beamWidth = beamWidth;
+ if (numMergeWorkers == 1 && mergeExec != null) {
+ throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
+ }
+ this.numMergeWorkers = numMergeWorkers;
+ if (mergeExec != null) {
+ this.mergeExec = new TaskExecutor(mergeExec);
+ } else {
+ this.mergeExec = null;
+ }
+ }
+
+ @Override
+ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+ // Lucene's HNSW writer over the ES818 binary-quantized flat format
+ return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec);
+ }
+
+ @Override
+ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+ return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
+ }
+
+ @Override
+ public int getMaxDimensions(String fieldName) {
+ return MAX_DIMS_COUNT;
+ }
+
+ @Override
+ public String toString() {
+ return "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn="
+ + maxConn
+ + ", beamWidth="
+ + beamWidth
+ + ", flatVectorFormat="
+ + flatVectorsFormat
+ + ")";
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java
new file mode 100644
index 0000000000000..72333169b39b5
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java
@@ -0,0 +1,371 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.codecs.lucene90.IndexedDISI;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.VectorScorer;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.packed.DirectMonotonicReader;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/** Binarized vector values loaded from off-heap */
+abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues {
+
+ final int dimension;
+ final int size;
+ final int numBytes;
+ final VectorSimilarityFunction similarityFunction;
+ final FlatVectorsScorer vectorsScorer;
+
+ final IndexInput slice;
+ final byte[] binaryValue;
+ final ByteBuffer byteBuffer;
+ final int byteSize;
+ private int lastOrd = -1;
+ final float[] correctiveValues;
+ int quantizedComponentSum;
+ final OptimizedScalarQuantizer binaryQuantizer;
+ final float[] centroid;
+ final float centroidDp;
+ private final int discretizedDimensions;
+
+ OffHeapBinarizedVectorValues(
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer quantizer,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice
+ ) {
+ this.dimension = dimension;
+ this.size = size;
+ this.similarityFunction = similarityFunction;
+ this.vectorsScorer = vectorsScorer;
+ this.slice = slice;
+ this.centroid = centroid;
+ this.centroidDp = centroidDp;
+ this.numBytes = BQVectorUtils.discretize(dimension, 64) / 8;
+ this.correctiveValues = new float[3];
+ this.byteSize = numBytes + (Float.BYTES * 3) + Short.BYTES;
+ this.byteBuffer = ByteBuffer.allocate(numBytes);
+ this.binaryValue = byteBuffer.array();
+ this.binaryQuantizer = quantizer;
+ this.discretizedDimensions = BQVectorUtils.discretize(dimension, 64);
+ }
+
+ @Override
+ public int dimension() {
+ return dimension;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public byte[] vectorValue(int targetOrd) throws IOException {
+ if (lastOrd == targetOrd) {
+ return binaryValue;
+ }
+ slice.seek((long) targetOrd * byteSize);
+ slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes);
+ slice.readFloats(correctiveValues, 0, 3);
+ quantizedComponentSum = Short.toUnsignedInt(slice.readShort());
+ lastOrd = targetOrd;
+ return binaryValue;
+ }
+
+ @Override
+ public int discretizedDimensions() {
+ return discretizedDimensions;
+ }
+
+ @Override
+ public float getCentroidDP() {
+ return centroidDp;
+ }
+
+ @Override
+ public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) throws IOException {
+ if (lastOrd == targetOrd) {
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0],
+ correctiveValues[1],
+ correctiveValues[2],
+ quantizedComponentSum
+ );
+ }
+ slice.seek(((long) targetOrd * byteSize) + numBytes);
+ slice.readFloats(correctiveValues, 0, 3);
+ quantizedComponentSum = Short.toUnsignedInt(slice.readShort());
+ return new OptimizedScalarQuantizer.QuantizationResult(
+ correctiveValues[0],
+ correctiveValues[1],
+ correctiveValues[2],
+ quantizedComponentSum
+ );
+ }
+
+ @Override
+ public OptimizedScalarQuantizer getQuantizer() {
+ return binaryQuantizer;
+ }
+
+ @Override
+ public float[] getCentroid() {
+ return centroid;
+ }
+
+ @Override
+ public int getVectorByteLength() {
+ return numBytes;
+ }
+
+ static OffHeapBinarizedVectorValues load(
+ OrdToDocDISIReaderConfiguration configuration,
+ int dimension,
+ int size,
+ OptimizedScalarQuantizer binaryQuantizer,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ float[] centroid,
+ float centroidDp,
+ long quantizedVectorDataOffset,
+ long quantizedVectorDataLength,
+ IndexInput vectorData
+ ) throws IOException {
+ if (configuration.isEmpty()) {
+ return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer);
+ }
+ assert centroid != null;
+ IndexInput bytesSlice = vectorData.slice("quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
+ if (configuration.isDense()) {
+ return new DenseOffHeapVectorValues(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ similarityFunction,
+ vectorsScorer,
+ bytesSlice
+ );
+ } else {
+ return new SparseOffHeapVectorValues(
+ configuration,
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ vectorData,
+ similarityFunction,
+ vectorsScorer,
+ bytesSlice
+ );
+ }
+ }
+
+ /** Dense off-heap binarized vector values */
+ static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues {
+ DenseOffHeapVectorValues(
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer binaryQuantizer,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice
+ ) {
+ super(dimension, size, centroid, centroidDp, binaryQuantizer, similarityFunction, vectorsScorer, slice);
+ }
+
+ @Override
+ public DenseOffHeapVectorValues copy() throws IOException {
+ return new DenseOffHeapVectorValues(
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ similarityFunction,
+ vectorsScorer,
+ slice.clone()
+ );
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return acceptDocs;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ DenseOffHeapVectorValues copy = copy();
+ DocIndexIterator iterator = copy.iterator();
+ RandomVectorScorer scorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
+ return new VectorScorer() {
+ @Override
+ public float score() throws IOException {
+ return scorer.score(iterator.index());
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return iterator;
+ }
+ };
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return createDenseIterator();
+ }
+ }
+
+ /** Sparse off-heap binarized vector values */
+ private static class SparseOffHeapVectorValues extends OffHeapBinarizedVectorValues {
+ private final DirectMonotonicReader ordToDoc;
+ private final IndexedDISI disi;
+        // dataIn is retained to init a new IndexedDISI for #randomAccess()
+ private final IndexInput dataIn;
+ private final OrdToDocDISIReaderConfiguration configuration;
+
+ SparseOffHeapVectorValues(
+ OrdToDocDISIReaderConfiguration configuration,
+ int dimension,
+ int size,
+ float[] centroid,
+ float centroidDp,
+ OptimizedScalarQuantizer binaryQuantizer,
+ IndexInput dataIn,
+ VectorSimilarityFunction similarityFunction,
+ FlatVectorsScorer vectorsScorer,
+ IndexInput slice
+ ) throws IOException {
+ super(dimension, size, centroid, centroidDp, binaryQuantizer, similarityFunction, vectorsScorer, slice);
+ this.configuration = configuration;
+ this.dataIn = dataIn;
+ this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
+ this.disi = configuration.getIndexedDISI(dataIn);
+ }
+
+ @Override
+ public SparseOffHeapVectorValues copy() throws IOException {
+ return new SparseOffHeapVectorValues(
+ configuration,
+ dimension,
+ size,
+ centroid,
+ centroidDp,
+ binaryQuantizer,
+ dataIn,
+ similarityFunction,
+ vectorsScorer,
+ slice.clone()
+ );
+ }
+
+ @Override
+ public int ordToDoc(int ord) {
+ return (int) ordToDoc.get(ord);
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ if (acceptDocs == null) {
+ return null;
+ }
+ return new Bits() {
+ @Override
+ public boolean get(int index) {
+ return acceptDocs.get(ordToDoc(index));
+ }
+
+ @Override
+ public int length() {
+ return size;
+ }
+ };
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return IndexedDISI.asDocIndexIterator(disi);
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) throws IOException {
+ SparseOffHeapVectorValues copy = copy();
+ DocIndexIterator iterator = copy.iterator();
+ RandomVectorScorer scorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
+ return new VectorScorer() {
+ @Override
+ public float score() throws IOException {
+ return scorer.score(iterator.index());
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return iterator;
+ }
+ };
+ }
+ }
+
+ private static class EmptyOffHeapVectorValues extends OffHeapBinarizedVectorValues {
+ EmptyOffHeapVectorValues(int dimension, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer) {
+ super(dimension, 0, null, Float.NaN, null, similarityFunction, vectorsScorer, null);
+ }
+
+ @Override
+ public DocIndexIterator iterator() {
+ return createDenseIterator();
+ }
+
+ @Override
+ public DenseOffHeapVectorValues copy() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Bits getAcceptOrds(Bits acceptDocs) {
+ return null;
+ }
+
+ @Override
+ public VectorScorer scorer(float[] target) {
+ return null;
+ }
+ }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java
new file mode 100644
index 0000000000000..d5ed38cb5a0e1
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java
@@ -0,0 +1,246 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.VectorUtil;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+
+class OptimizedScalarQuantizer {
+ // The initial interval is set to the minimum MSE grid for each number of bits
+ // these starting points are derived from the optimal MSE grid for a uniform distribution
+ static final float[][] MINIMUM_MSE_GRID = new float[][] {
+ { -0.798f, 0.798f },
+ { -1.493f, 1.493f },
+ { -2.051f, 2.051f },
+ { -2.514f, 2.514f },
+ { -2.916f, 2.916f },
+ { -3.278f, 3.278f },
+ { -3.611f, 3.611f },
+ { -3.922f, 3.922f } };
+ private static final float DEFAULT_LAMBDA = 0.1f;
+ private static final int DEFAULT_ITERS = 5;
+ private final VectorSimilarityFunction similarityFunction;
+ private final float lambda;
+ private final int iters;
+
+ OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) {
+ this.similarityFunction = similarityFunction;
+ this.lambda = lambda;
+ this.iters = iters;
+ }
+
+ OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
+ this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS);
+ }
+
+ public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {}
+
+ public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destinations, byte[] bits, float[] centroid) {
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
+ assert bits.length == destinations.length;
+ float[] intervalScratch = new float[2];
+ double vecMean = 0;
+ double vecVar = 0;
+ float norm2 = 0;
+ float centroidDot = 0;
+ float min = Float.MAX_VALUE;
+ float max = -Float.MAX_VALUE;
+ for (int i = 0; i < vector.length; ++i) {
+ if (similarityFunction != EUCLIDEAN) {
+ centroidDot += vector[i] * centroid[i];
+ }
+ vector[i] = vector[i] - centroid[i];
+ min = Math.min(min, vector[i]);
+ max = Math.max(max, vector[i]);
+ norm2 += (vector[i] * vector[i]);
+ double delta = vector[i] - vecMean;
+ vecMean += delta / (i + 1);
+ vecVar += delta * (vector[i] - vecMean);
+ }
+ vecVar /= vector.length;
+ double vecStd = Math.sqrt(vecVar);
+ QuantizationResult[] results = new QuantizationResult[bits.length];
+ for (int i = 0; i < bits.length; ++i) {
+ assert bits[i] > 0 && bits[i] <= 8;
+ int points = (1 << bits[i]);
+ // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
+ intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][0] + vecMean) * vecStd, min, max);
+ intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][1] + vecMean) * vecStd, min, max);
+ optimizeIntervals(intervalScratch, vector, norm2, points);
+ float nSteps = ((1 << bits[i]) - 1);
+ float a = intervalScratch[0];
+ float b = intervalScratch[1];
+ float step = (b - a) / nSteps;
+ int sumQuery = 0;
+ // Now we have the optimized intervals, quantize the vector
+ for (int h = 0; h < vector.length; h++) {
+ float xi = (float) clamp(vector[h], a, b);
+ int assignment = Math.round((xi - a) / step);
+ sumQuery += assignment;
+ destinations[i][h] = (byte) assignment;
+ }
+ results[i] = new QuantizationResult(
+ intervalScratch[0],
+ intervalScratch[1],
+ similarityFunction == EUCLIDEAN ? norm2 : centroidDot,
+ sumQuery
+ );
+ }
+ return results;
+ }
+
+ public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
+ assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
+ assert vector.length <= destination.length;
+ assert bits > 0 && bits <= 8;
+ float[] intervalScratch = new float[2];
+ int points = 1 << bits;
+ double vecMean = 0;
+ double vecVar = 0;
+ float norm2 = 0;
+ float centroidDot = 0;
+ float min = Float.MAX_VALUE;
+ float max = -Float.MAX_VALUE;
+ for (int i = 0; i < vector.length; ++i) {
+ if (similarityFunction != EUCLIDEAN) {
+ centroidDot += vector[i] * centroid[i];
+ }
+ vector[i] = vector[i] - centroid[i];
+ min = Math.min(min, vector[i]);
+ max = Math.max(max, vector[i]);
+ norm2 += (vector[i] * vector[i]);
+ double delta = vector[i] - vecMean;
+ vecMean += delta / (i + 1);
+ vecVar += delta * (vector[i] - vecMean);
+ }
+ vecVar /= vector.length;
+ double vecStd = Math.sqrt(vecVar);
+ // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
+ intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][0] + vecMean) * vecStd, min, max);
+ intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][1] + vecMean) * vecStd, min, max);
+ optimizeIntervals(intervalScratch, vector, norm2, points);
+ float nSteps = ((1 << bits) - 1);
+ // Now we have the optimized intervals, quantize the vector
+ float a = intervalScratch[0];
+ float b = intervalScratch[1];
+ float step = (b - a) / nSteps;
+ int sumQuery = 0;
+ for (int h = 0; h < vector.length; h++) {
+ float xi = (float) clamp(vector[h], a, b);
+ int assignment = Math.round((xi - a) / step);
+ sumQuery += assignment;
+ destination[h] = (byte) assignment;
+ }
+ return new QuantizationResult(
+ intervalScratch[0],
+ intervalScratch[1],
+ similarityFunction == EUCLIDEAN ? norm2 : centroidDot,
+ sumQuery
+ );
+ }
+
+ /**
+ * Compute the loss of the vector given the interval. Effectively, we are computing the MSE of a dequantized vector with the raw
+ * vector.
+ * @param vector raw vector
+ * @param interval interval to quantize the vector
+ * @param points number of quantization points
+ * @param norm2 squared norm of the vector
+ * @return the loss
+ */
+ private double loss(float[] vector, float[] interval, int points, float norm2) {
+ double a = interval[0];
+ double b = interval[1];
+ double step = ((b - a) / (points - 1.0F));
+ double stepInv = 1.0 / step;
+ double xe = 0.0;
+ double e = 0.0;
+ for (double xi : vector) {
+ // this is quantizing and then dequantizing the vector
+ double xiq = (a + step * Math.round((clamp(xi, a, b) - a) * stepInv));
+ // how much does the de-quantized value differ from the original value
+ xe += xi * (xi - xiq);
+ e += (xi - xiq) * (xi - xiq);
+ }
+ return (1.0 - lambda) * xe * xe / norm2 + lambda * e;
+ }
+
+ /**
+ * Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization
+ * loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the
+ * loss increases.
+ * @param initInterval initial interval, the optimized interval will be stored here
+ * @param vector raw vector
+ * @param norm2 squared norm of the vector
+ * @param points number of quantization points
+ */
+ private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) {
+ double initialLoss = loss(vector, initInterval, points, norm2);
+ final float scale = (1.0f - lambda) / norm2;
+ if (Float.isFinite(scale) == false) {
+ return;
+ }
+ for (int i = 0; i < iters; ++i) {
+ float a = initInterval[0];
+ float b = initInterval[1];
+ float stepInv = (points - 1.0f) / (b - a);
+ // calculate the grid points for coordinate descent
+ double daa = 0;
+ double dab = 0;
+ double dbb = 0;
+ double dax = 0;
+ double dbx = 0;
+ for (float xi : vector) {
+ float k = Math.round((clamp(xi, a, b) - a) * stepInv);
+ float s = k / (points - 1);
+ daa += (1.0 - s) * (1.0 - s);
+ dab += (1.0 - s) * s;
+ dbb += s * s;
+ dax += xi * (1.0 - s);
+ dbx += xi * s;
+ }
+ double m0 = scale * dax * dax + lambda * daa;
+ double m1 = scale * dax * dbx + lambda * dab;
+ double m2 = scale * dbx * dbx + lambda * dbb;
+            // it's possible that the determinant is 0, in which case we can't update the interval
+ double det = m0 * m2 - m1 * m1;
+ if (det == 0) {
+ return;
+ }
+ float aOpt = (float) ((m2 * dax - m1 * dbx) / det);
+ float bOpt = (float) ((m0 * dbx - m1 * dax) / det);
+ // If there is no change in the interval, we can stop
+ if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) {
+ return;
+ }
+ double newLoss = loss(vector, new float[] { aOpt, bOpt }, points, norm2);
+ // If the new loss is worse, don't update the interval and exit
+ // This optimization, unlike kMeans, does not always converge to better loss
+ // So exit if we are getting worse
+ if (newLoss > initialLoss) {
+ return;
+ }
+ // Update the interval and go again
+ initInterval[0] = aOpt;
+ initInterval[1] = bOpt;
+ initialLoss = newLoss;
+ }
+ }
+
+ private static double clamp(double x, double a, double b) {
+ return Math.min(Math.max(x, a), b);
+ }
+
+}
diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
index 0a6a24f727572..d780faad96f2d 100644
--- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
+++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
@@ -46,8 +46,8 @@
import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat;
import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat;
-import org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat;
-import org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat;
+import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat;
+import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
import org.elasticsearch.index.fielddata.FieldDataContext;
import org.elasticsearch.index.fielddata.IndexFieldData;
import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
@@ -1788,7 +1788,7 @@ static class BBQHnswIndexOptions extends IndexOptions {
@Override
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
- return new ES816HnswBinaryQuantizedVectorsFormat(m, efConstruction);
+ return new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction);
}
@Override
@@ -1836,7 +1836,7 @@ static class BBQFlatIndexOptions extends IndexOptions {
@Override
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
assert elementType == ElementType.FLOAT;
- return new ES816BinaryQuantizedVectorsFormat();
+ return new ES818BinaryQuantizedVectorsFormat();
}
@Override
diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java
index 794b30aa5aab2..57980321bdc3d 100644
--- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java
+++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java
@@ -44,6 +44,7 @@ private SearchCapabilities() {}
private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim_with_bugfix";
private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs";
+ private static final String OPTIMIZED_SCALAR_QUANTIZATION_BBQ = "optimized_scalar_quantization_bbq";
public static final Set CAPABILITIES;
static {
@@ -55,6 +56,7 @@ private SearchCapabilities() {}
capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER);
capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);
capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS);
+ capabilities.add(OPTIMIZED_SCALAR_QUANTIZATION_BBQ);
if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) {
capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER);
capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS);
diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
index 389555e60b43b..cef8d09980814 100644
--- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
+++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat
@@ -5,3 +5,5 @@ org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat
org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat
org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat
org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat
+org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat
+org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java
index 9f9114c70b6db..270ad54e9a962 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java
@@ -38,6 +38,32 @@ public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
private static float DELTA = Float.MIN_VALUE;
+ public void testPackAsBinary() {
+ // 5 bits
+ byte[] toPack = new byte[] { 1, 1, 0, 0, 1 };
+ byte[] packed = new byte[1];
+ BQVectorUtils.packAsBinary(toPack, packed);
+ assertArrayEquals(new byte[] { (byte) 0b11001000 }, packed);
+
+ // 8 bits
+ toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0 };
+ packed = new byte[1];
+ BQVectorUtils.packAsBinary(toPack, packed);
+ assertArrayEquals(new byte[] { (byte) 0b11001010 }, packed);
+
+ // 10 bits
+ toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1 };
+ packed = new byte[2];
+ BQVectorUtils.packAsBinary(toPack, packed);
+ assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11000000 }, packed);
+
+ // 16 bits
+ toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 };
+ packed = new byte[2];
+ BQVectorUtils.packAsBinary(toPack, packed);
+ assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11100110 }, packed);
+ }
+
public void testPadFloat() {
assertArrayEquals(new float[] { 1, 2, 3, 4 }, BQVectorUtils.pad(new float[] { 1, 2, 3, 4 }, 4), DELTA);
assertArrayEquals(new float[] { 1, 2, 3, 4 }, BQVectorUtils.pad(new float[] { 1, 2, 3, 4 }, 3), DELTA);
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java
new file mode 100644
index 0000000000000..0bebe16f468ce
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java
@@ -0,0 +1,256 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es816;
+
+import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
+import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
+import org.elasticsearch.simdvec.ESVectorUtil;
+
+import java.io.IOException;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/** Vector scorer over binarized vector values */
+class ES816BinaryFlatRWVectorsScorer implements FlatVectorsScorer {
+ private final FlatVectorsScorer nonQuantizedDelegate;
+
+ ES816BinaryFlatRWVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
+ this.nonQuantizedDelegate = nonQuantizedDelegate;
+ }
+
+ @Override
+ public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction,
+ KnnVectorValues vectorValues
+ ) throws IOException {
+ if (vectorValues instanceof BinarizedByteVectorValues) {
+ throw new UnsupportedOperationException(
+ "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format"
+ );
+ }
+ return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction,
+ KnnVectorValues vectorValues,
+ float[] target
+ ) throws IOException {
+ if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) {
+ BinaryQuantizer quantizer = binarizedVectors.getQuantizer();
+ float[] centroid = binarizedVectors.getCentroid();
+ // FIXME: precompute this once?
+ int discretizedDimensions = BQVectorUtils.discretize(target.length, 64);
+ if (similarityFunction == COSINE) {
+ float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length);
+ VectorUtil.l2normalize(copy);
+ target = copy;
+ }
+ byte[] quantized = new byte[BQSpaceUtils.B_QUERY * discretizedDimensions / 8];
+ BinaryQuantizer.QueryFactors factors = quantizer.quantizeForQuery(target, quantized, centroid);
+ BinaryQueryVector queryVector = new BinaryQueryVector(quantized, factors);
+ return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction);
+ }
+ return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ @Override
+ public RandomVectorScorer getRandomVectorScorer(
+ VectorSimilarityFunction similarityFunction,
+ KnnVectorValues vectorValues,
+ byte[] target
+ ) throws IOException {
+ return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
+ }
+
+ RandomVectorScorerSupplier getRandomVectorScorerSupplier(
+ VectorSimilarityFunction similarityFunction,
+ ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors,
+ BinarizedByteVectorValues targetVectors
+ ) {
+ return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction);
+ }
+
+ @Override
+ public String toString() {
+ return "ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")";
+ }
+
+    /**
+     * Vector scorer supplier over binarized vector values.
+     * Query-side vectors are previously quantized representations (with their corrective factors)
+     * read off-heap via the writer's {@code OffHeapBinarizedQueryVectorValues}; target-side vectors
+     * are the index-time binarized vectors. NOTE(review): presumably used during index build/merge,
+     * where both sides of the comparison are already quantized — confirm against the caller.
+     */
+    static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
+        // Quantized query vectors plus the per-vector corrective factors stored at quantization time.
+        private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors;
+        // Index-side (target) binarized vectors.
+        private final BinarizedByteVectorValues targetVectors;
+        private final VectorSimilarityFunction similarityFunction;
+
+        BinarizedRandomVectorScorerSupplier(
+            ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors,
+            BinarizedByteVectorValues targetVectors,
+            VectorSimilarityFunction similarityFunction
+        ) {
+            this.queryVectors = queryVectors;
+            this.targetVectors = targetVectors;
+            this.similarityFunction = similarityFunction;
+        }
+
+        @Override
+        public RandomVectorScorer scorer(int ord) throws IOException {
+            // Load the quantized query at `ord` together with its stored corrective factors.
+            byte[] vector = queryVectors.vectorValue(ord);
+            int quantizedSum = queryVectors.sumQuantizedValues(ord);
+            float distanceToCentroid = queryVectors.getCentroidDistance(ord);
+            float lower = queryVectors.getLower(ord);
+            float width = queryVectors.getWidth(ord);
+            float normVmC = 0f;
+            float vDotC = 0f;
+            // ||v - c|| and <v, c> are only read for non-Euclidean similarities; the Euclidean
+            // scoring path in BinarizedRandomVectorScorer never uses them (they stay 0 here).
+            if (similarityFunction != EUCLIDEAN) {
+                normVmC = queryVectors.getNormVmC(ord);
+                vDotC = queryVectors.getVDotC(ord);
+            }
+            BinaryQueryVector binaryQueryVector = new BinaryQueryVector(
+                vector,
+                new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC)
+            );
+            return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction);
+        }
+
+        @Override
+        public RandomVectorScorerSupplier copy() throws IOException {
+            // Copy both readers so the returned supplier can be iterated independently of this one.
+            return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction);
+        }
+    }
+
+    /** A binarized query representing its quantized form along with factors */
+    record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {}
+
+    /**
+     * Vector scorer over binarized vector values.
+     * Scores are estimated from an asymmetric comparison of the quantized query against one-bit
+     * binarized target vectors ({@code ESVectorUtil.ipByteBinByte}), corrected by the per-vector
+     * factors stored at quantization time, and adjusted by an error bound derived from
+     * {@code maxX1}. NOTE(review): structure matches RaBitQ-style estimation — confirm naming.
+     */
+    static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
+        private final BinaryQueryVector queryVector;
+        private final BinarizedByteVectorValues targetVectors;
+        private final VectorSimilarityFunction similarityFunction;
+
+        // Per-index constants shared by all targets: sqrt(dimensions) and the error-bound multiplier.
+        private final float sqrtDimensions;
+        private final float maxX1;
+
+        BinarizedRandomVectorScorer(
+            BinaryQueryVector queryVectors,
+            BinarizedByteVectorValues targetVectors,
+            VectorSimilarityFunction similarityFunction
+        ) {
+            super(targetVectors);
+            this.queryVector = queryVectors;
+            this.targetVectors = targetVectors;
+            this.similarityFunction = similarityFunction;
+            // FIXME: precompute this once?
+            this.sqrtDimensions = targetVectors.sqrtDimensions();
+            this.maxX1 = targetVectors.maxX1();
+        }
+
+        @Override
+        public float score(int targetOrd) throws IOException {
+            // Unpack the query's quantized form and corrective factors.
+            byte[] quantizedQuery = queryVector.vector();
+            int quantizedSum = queryVector.factors().quantizedSum();
+            float lower = queryVector.factors().lower();
+            float width = queryVector.factors().width();
+            float distanceToCentroid = queryVector.factors().distToC();
+            if (similarityFunction == EUCLIDEAN) {
+                return euclideanScore(targetOrd, sqrtDimensions, quantizedQuery, distanceToCentroid, lower, quantizedSum, width);
+            }
+
+            // Dot-product based similarities: reconstruct an estimate of <v, o> from
+            // centroid-relative quantities (v = query, o = target, c = centroid).
+            float vmC = queryVector.factors().normVmC();       // ||v - c||
+            float vDotC = queryVector.factors().vDotC();       // <v, c>
+            float cDotC = targetVectors.getCentroidDP();       // <c, c>
+            byte[] binaryCode = targetVectors.vectorValue(targetOrd);
+            float ooq = targetVectors.getOOQ(targetOrd);       // target/quantized-target correlation term
+            float normOC = targetVectors.getNormOC(targetOrd); // ||o - c||
+            float oDotC = targetVectors.getODotC(targetOrd);   // <o, c>
+
+            // Asymmetric quantized-query x one-bit-target inner product.
+            float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode);
+
+            // FIXME: pre-compute these only once for each target vector
+            // ... pull this out or use a similar cache mechanism as done in score
+            float xbSum = (float) BQVectorUtils.popcount(binaryCode);
+            final float dist;
+            // If ||o-c|| == 0 it's ok to throw the rest of the equation away
+            // and simply use `oDotC + vDotC - cDotC`, as centroid == doc vector
+            if (normOC == 0 || ooq == 0) {
+                dist = oDotC + vDotC - cDotC;
+            } else {
+                // If ||o-c|| != 0, we should assume that `ooq` is finite
+                assert Float.isFinite(ooq);
+                float estimatedDot = (2 * width / sqrtDimensions * qcDist + 2 * lower / sqrtDimensions * xbSum - width / sqrtDimensions
+                    * quantizedSum - sqrtDimensions * lower) / ooq;
+                dist = vmC * normOC * estimatedDot + oDotC + vDotC - cDotC;
+            }
+            assert Float.isFinite(dist);
+
+            // Subtract the estimation error bound so the score is a conservative estimate;
+            // if the bound is non-finite (e.g. ooq near 0) fall back to the raw estimate.
+            float ooqSqr = (float) Math.pow(ooq, 2);
+            float errorBound = (float) (vmC * normOC * (maxX1 * Math.sqrt((1 - ooqSqr) / ooqSqr)));
+            float score = Float.isFinite(errorBound) ? dist - errorBound : dist;
+            if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                return VectorUtil.scaleMaxInnerProductScore(score);
+            }
+            // Map the dot-product estimate into [0, 1] (Lucene's DOT_PRODUCT/COSINE score shape).
+            return Math.max((1f + score) / 2f, 0);
+        }
+
+        /**
+         * Euclidean path: estimates the squared distance between query and target from the
+         * quantized comparison plus stored correction factors, inflates it by an error bound,
+         * and converts to Lucene's 1/(1+dist) score shape.
+         */
+        private float euclideanScore(
+            int targetOrd,
+            float sqrtDimensions,
+            byte[] quantizedQuery,
+            float distanceToCentroid,
+            float lower,
+            int quantizedSum,
+            float width
+        ) throws IOException {
+            byte[] binaryCode = targetVectors.vectorValue(targetOrd);
+
+            // FIXME: pre-compute these only once for each target vector
+            // .. not sure how to enumerate the target ordinals but that's what we did in PoC
+            float targetDistToC = targetVectors.getCentroidDistance(targetOrd);
+            float x0 = targetVectors.getVectorMagnitude(targetOrd);
+            float sqrX = targetDistToC * targetDistToC;
+            double xX0 = targetDistToC / x0;
+
+            // TODO maybe store?
+            float xbSum = (float) BQVectorUtils.popcount(binaryCode);
+            float factorPPC = (float) (-2.0 / sqrtDimensions * xX0 * (xbSum * 2.0 - targetVectors.dimension()));
+            float factorIP = (float) (-2.0 / sqrtDimensions * xX0);
+
+            long qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode);
+            // Estimated squared distance between query and target.
+            float score = sqrX + distanceToCentroid + factorPPC * lower + (qcDist * 2 - quantizedSum) * factorIP * width;
+            // NOTE(review): the sqrt argument is non-negative only when xX0 >= targetDistToC
+            // (i.e. x0 <= 1) — presumably a quantizer invariant; confirm upstream.
+            float projectionDist = (float) Math.sqrt(xX0 * xX0 - targetDistToC * targetDistToC);
+            float error = 2.0f * maxX1 * projectionDist;
+            float y = (float) Math.sqrt(distanceToCentroid);
+            float errorBound = y * error;
+            // Inflate the distance by the error bound when finite (conservative, lower score).
+            if (Float.isFinite(errorBound)) {
+                score = score + errorBound;
+            }
+            // Distance -> similarity in (0, 1]; clamped at 0 for safety.
+            return Math.max(1 / (1f + score), 0);
+        }
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java
index a75b9bc6064d1..ffe007be9799d 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java
@@ -59,7 +59,7 @@ public void testScore() throws IOException {
short quantizedSum = (short) random().nextInt(0, 4097);
float normVmC = random().nextFloat(-1000f, 1000f);
float vDotC = random().nextFloat(-1000f, 1000f);
- ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector(
+ ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector(
vector,
new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, normVmC, vDotC)
);
@@ -134,7 +134,7 @@ public int dimension() {
}
};
- ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer(
+ ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer(
queryVector,
targetVectors,
similarityFunction
@@ -217,7 +217,7 @@ public void testScoreEuclidean() throws IOException {
float vl = -57.883f;
float width = 9.972266f;
short quantizedSum = 795;
- ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector(
+ ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector(
vector,
new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, 0f, 0f)
);
@@ -420,7 +420,7 @@ public int dimension() {
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
- ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer(
+ ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer(
queryVector,
targetVectors,
similarityFunction
@@ -824,7 +824,7 @@ public void testScoreMIP() throws IOException {
float normVmC = 9.766797f;
float vDotC = 133.56123f;
float cDotC = 132.20227f;
- ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector(
+ ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector(
vector,
new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, normVmC, vDotC)
);
@@ -1768,7 +1768,7 @@ public int dimension() {
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
- ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer(
+ ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer(
queryVector,
targetVectors,
similarityFunction
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java
new file mode 100644
index 0000000000000..c54903a94b54f
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java
@@ -0,0 +1,52 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es816;
+
+import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.index.SegmentWriteState;
+
+import java.io.IOException;
+
+/**
+ * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10.
+ * Test-only read/write variant of {@link ES816BinaryQuantizedVectorsFormat}: it re-enables the
+ * write path (this file lives under src/test) while the read path stays in the parent format.
+ */
+public class ES816BinaryQuantizedRWVectorsFormat extends ES816BinaryQuantizedVectorsFormat {
+
+    // Delegate for storing the raw (un-quantized) float vectors.
+    private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(
+        FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
+    );
+
+    // Write-capable scorer handed to the writer below.
+    private static final ES816BinaryFlatRWVectorsScorer scorer = new ES816BinaryFlatRWVectorsScorer(
+        FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
+    );
+
+    /** Creates a new instance with the default number of vectors per cluster. */
+    public ES816BinaryQuantizedRWVectorsFormat() {
+        super();
+    }
+
+    /** Returns the test-only writer that produces ES816 binary-quantized vector data. */
+    @Override
+    public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+        return new ES816BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state);
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java
index 681f615653d40..48ba566353f5d 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java
@@ -63,7 +63,7 @@ protected Codec getCodec() {
return new Lucene100Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new ES816BinaryQuantizedVectorsFormat();
+ return new ES816BinaryQuantizedRWVectorsFormat();
}
};
}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
similarity index 99%
rename from server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
rename to server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
index 31ae977e81118..4d97235c5fae5 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java
@@ -77,7 +77,7 @@ class ES816BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
private final List fields = new ArrayList<>();
private final IndexOutput meta, binarizedVectorData;
private final FlatVectorsWriter rawVectorDelegate;
- private final ES816BinaryFlatVectorsScorer vectorsScorer;
+ private final ES816BinaryFlatRWVectorsScorer vectorsScorer;
private boolean finished;
/**
@@ -86,7 +86,7 @@ class ES816BinaryQuantizedVectorsWriter extends FlatVectorsWriter {
* @param vectorsScorer the scorer to use for scoring vectors
*/
protected ES816BinaryQuantizedVectorsWriter(
- ES816BinaryFlatVectorsScorer vectorsScorer,
+ ES816BinaryFlatRWVectorsScorer vectorsScorer,
FlatVectorsWriter rawVectorDelegate,
SegmentWriteState state
) throws IOException {
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java
new file mode 100644
index 0000000000000..e9bace72b591c
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java
@@ -0,0 +1,55 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es816;
+
+import org.apache.lucene.codecs.KnnVectorsWriter;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
+import org.apache.lucene.index.SegmentWriteState;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
+
+/**
+ * Test-only read/write variant of {@link ES816HnswBinaryQuantizedVectorsFormat}: restores the
+ * HNSW write path on top of the test-only {@link ES816BinaryQuantizedRWVectorsFormat}.
+ */
+class ES816HnswBinaryQuantizedRWVectorsFormat extends ES816HnswBinaryQuantizedVectorsFormat {
+
+    // Flat (non-graph) storage layer: the write-capable ES816 binary-quantized format.
+    private static final FlatVectorsFormat flatVectorsFormat = new ES816BinaryQuantizedRWVectorsFormat();
+
+    /** Constructs a format using default graph construction parameters */
+    ES816HnswBinaryQuantizedRWVectorsFormat() {
+        this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
+    }
+
+    /** Constructs a format with the given graph parameters and default merge settings. */
+    ES816HnswBinaryQuantizedRWVectorsFormat(int maxConn, int beamWidth) {
+        this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
+    }
+
+    ES816HnswBinaryQuantizedRWVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
+        super(maxConn, beamWidth, numMergeWorkers, mergeExec);
+    }
+
+    /**
+     * Test-only write path. NOTE(review): numMergeWorkers/mergeExec from the constructor are not
+     * forwarded — the writer is hard-coded to 1 merge worker and no executor; confirm intentional.
+     */
+    @Override
+    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null);
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java
index a25fa2836ee34..03aa847f3a5d4 100644
--- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java
@@ -59,7 +59,7 @@ protected Codec getCodec() {
return new Lucene100Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
- return new ES816HnswBinaryQuantizedVectorsFormat();
+ return new ES816HnswBinaryQuantizedRWVectorsFormat();
}
};
}
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java
new file mode 100644
index 0000000000000..397cc472592b6
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java
@@ -0,0 +1,181 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.FilterCodec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene100.Lucene100Codec;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.KnnFloatVectorQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.elasticsearch.common.logging.LogConfigurator;
+import org.elasticsearch.index.codec.vectors.BQVectorUtils;
+
+import java.io.IOException;
+import java.util.Locale;
+
+import static java.lang.String.format;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.oneOf;
+
+public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
+
+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
+    @Override
+    protected Codec getCodec() {
+        // Default codec with the flat (non-graph) ES818 binary-quantized format for every field.
+        return new Lucene100Codec() {
+            @Override
+            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+                return new ES818BinaryQuantizedVectorsFormat();
+            }
+        };
+    }
+
+    /** Indexes random vectors and verifies a kNN query returns exactly {@code k} hits. */
+    public void testSearch() throws Exception {
+        String fieldName = "field";
+        int numVectors = random().nextInt(99, 500);
+        int dims = random().nextInt(4, 65);
+        float[] vector = randomVector(dims);
+        VectorSimilarityFunction similarityFunction = randomSimilarity();
+        KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction);
+        IndexWriterConfig iwc = newIndexWriterConfig();
+        try (Directory dir = newDirectory()) {
+            try (IndexWriter w = new IndexWriter(dir, iwc)) {
+                for (int i = 0; i < numVectors; i++) {
+                    Document doc = new Document();
+                    knnField.setVectorValue(randomVector(dims));
+                    doc.add(knnField);
+                    w.addDocument(doc);
+                }
+                w.commit();
+
+                try (IndexReader reader = DirectoryReader.open(w)) {
+                    IndexSearcher searcher = new IndexSearcher(reader);
+                    final int k = random().nextInt(5, 50);
+                    float[] queryVector = randomVector(dims);
+                    Query q = new KnnFloatVectorQuery(fieldName, queryVector, k);
+                    TopDocs collectedDocs = searcher.search(q, k);
+                    assertEquals(k, collectedDocs.totalHits.value());
+                    assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation());
+                }
+            }
+        }
+    }
+
+    /** toString must name the format and whichever flat-vector scorer implementation is active. */
+    public void testToString() {
+        FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) {
+            @Override
+            public KnnVectorsFormat knnVectorsFormat() {
+                return new ES818BinaryQuantizedVectorsFormat();
+            }
+        };
+        String expectedPattern = "ES818BinaryQuantizedVectorsFormat("
+            + "name=ES818BinaryQuantizedVectorsFormat, "
+            + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))";
+        // Either scorer is acceptable: the memory-segment one depends on JDK/platform support.
+        var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
+        var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
+        assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
+    }
+
+    @Override
+    public void testRandomWithUpdatesAndGraph() {
+        // graph not supported
+    }
+
+    @Override
+    public void testSearchWithVisitedLimit() {
+        // visited limit is not respected, as it is brute force search
+    }
+
+    /**
+     * Round-trips vectors through the codec and checks the stored binarized bytes and corrective
+     * terms match what re-quantizing the raw vectors with an {@link OptimizedScalarQuantizer}
+     * produces.
+     */
+    public void testQuantizedVectorsWriteAndRead() throws IOException {
+        String fieldName = "field";
+        int numVectors = random().nextInt(99, 500);
+        int dims = random().nextInt(4, 65);
+
+        float[] vector = randomVector(dims);
+        VectorSimilarityFunction similarityFunction = randomSimilarity();
+        KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction);
+        try (Directory dir = newDirectory()) {
+            try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+                for (int i = 0; i < numVectors; i++) {
+                    Document doc = new Document();
+                    knnField.setVectorValue(randomVector(dims));
+                    doc.add(knnField);
+                    w.addDocument(doc);
+                    // Periodic commits create multiple segments before the final merge.
+                    if (i % 101 == 0) {
+                        w.commit();
+                    }
+                }
+                w.commit();
+                // Single segment so all vectors are quantized against one centroid.
+                w.forceMerge(1);
+
+                try (IndexReader reader = DirectoryReader.open(w)) {
+                    LeafReader r = getOnlyLeafReader(reader);
+                    FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
+                    assertEquals(vectorValues.size(), numVectors);
+                    BinarizedByteVectorValues qvectorValues = ((ES818BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues)
+                        .getQuantizedVectorValues();
+                    float[] centroid = qvectorValues.getCentroid();
+                    assertEquals(centroid.length, dims);
+
+                    OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
+                    byte[] quantizedVector = new byte[dims];
+                    byte[] expectedVector = new byte[BQVectorUtils.discretize(dims, 64) / 8];
+                    if (similarityFunction == VectorSimilarityFunction.COSINE) {
+                        // The writer normalizes vectors for cosine, so mirror that before comparing.
+                        vectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues);
+                    }
+                    KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator();
+
+                    while (docIndexIterator.nextDoc() != NO_MORE_DOCS) {
+                        // Re-quantize the raw vector at 1 bit and compare bytes + corrective terms.
+                        OptimizedScalarQuantizer.QuantizationResult corrections = quantizer.scalarQuantize(
+                            vectorValues.vectorValue(docIndexIterator.index()),
+                            quantizedVector,
+                            (byte) 1,
+                            centroid
+                        );
+                        BQVectorUtils.packAsBinary(quantizedVector, expectedVector);
+                        assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index()));
+                        assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index()));
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java
new file mode 100644
index 0000000000000..b6ae3199bb896
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java
@@ -0,0 +1,132 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.codecs.FilterCodec;
+import org.apache.lucene.codecs.KnnVectorsFormat;
+import org.apache.lucene.codecs.lucene100.Lucene100Codec;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.KnnVectorValues;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
+import org.apache.lucene.util.SameThreadExecutorService;
+import org.elasticsearch.common.logging.LogConfigurator;
+
+import java.util.Arrays;
+import java.util.Locale;
+
+import static java.lang.String.format;
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.oneOf;
+
+public class ES818HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
+
+    static {
+        LogConfigurator.loadLog4jPlugins();
+        LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+    }
+
+    @Override
+    protected Codec getCodec() {
+        // Default codec with the HNSW + ES818 binary-quantized format for every field.
+        return new Lucene100Codec() {
+            @Override
+            public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+                return new ES818HnswBinaryQuantizedVectorsFormat();
+            }
+        };
+    }
+
+    /** toString must expose graph parameters and the wrapped flat format/scorer. */
+    public void testToString() {
+        FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) {
+            @Override
+            public KnnVectorsFormat knnVectorsFormat() {
+                return new ES818HnswBinaryQuantizedVectorsFormat(10, 20, 1, null);
+            }
+        };
+        String expectedPattern =
+            "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20,"
+                + " flatVectorFormat=ES818BinaryQuantizedVectorsFormat(name=ES818BinaryQuantizedVectorsFormat,"
+                + " flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))";
+
+        // Either scorer is acceptable: the memory-segment one depends on JDK/platform support.
+        var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
+        var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
+        assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
+    }
+
+    /** With a single vector in the segment the quantized score should approximate the true score. */
+    public void testSingleVectorCase() throws Exception {
+        float[] vector = randomVector(random().nextInt(12, 500));
+        for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
+            try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+                Document doc = new Document();
+                doc.add(new KnnFloatVectorField("f", vector, similarityFunction));
+                w.addDocument(doc);
+                w.commit();
+                try (IndexReader reader = DirectoryReader.open(w)) {
+                    LeafReader r = getOnlyLeafReader(reader);
+                    FloatVectorValues vectorValues = r.getFloatVectorValues("f");
+                    KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator();
+                    assert (vectorValues.size() == 1);
+                    while (docIndexIterator.nextDoc() != NO_MORE_DOCS) {
+                        assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f);
+                    }
+                    float[] randomVector = randomVector(vector.length);
+                    float trueScore = similarityFunction.compare(vector, randomVector);
+                    TopDocs td = r.searchNearestVectors("f", randomVector, 1, null, Integer.MAX_VALUE);
+                    assertEquals(1, td.totalHits.value());
+                    assertTrue(td.scoreDocs[0].score >= 0);
+                    // When it's the only vector in a segment, the score should be very close to the true score
+                    assertEquals(trueScore, td.scoreDocs[0].score, 0.0001f);
+                }
+            }
+        }
+    }
+
+    /** Constructor validation: out-of-range maxConn/beamWidth and a supplied executor must throw. */
+    public void testLimits() {
+        expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(-1, 20));
+        expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(0, 20));
+        expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 0));
+        expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, -1));
+        expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(512 + 1, 20));
+        expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 3201));
+        expectThrows(
+            IllegalArgumentException.class,
+            () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 100, 1, new SameThreadExecutorService())
+        );
+    }
+
+    // Ensures that all expected vector similarity functions are translatable in the format.
+    public void testVectorSimilarityFuncs() {
+        // This does not necessarily have to be all similarity functions, but
+        // differences should be considered carefully.
+        var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList();
+        assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues);
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java
new file mode 100644
index 0000000000000..e3e2d6caafe0e
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java
@@ -0,0 +1,136 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.test.ESTestCase;
+
+import static org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer.MINIMUM_MSE_GRID;
+
+public class OptimizedScalarQuantizerTests extends ESTestCase {
+
+ static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };
+
+ public void testAbusiveEdgeCases() {
+ // degenerate case: large all-zero vector and centroid (COSINE skipped, as a zero vector cannot be normalized)
+ for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) {
+ if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
+ continue;
+ }
+ float[] vector = new float[4096];
+ float[] centroid = new float[4096];
+ OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
+ byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096];
+ OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
+ assertEquals(MINIMUM_MSE_GRID.length, results.length);
+ assertValidResults(results);
+ for (byte[] destination : destinations) {
+ assertArrayEquals(new byte[4096], destination);
+ }
+ byte[] destination = new byte[4096];
+ for (byte bit : ALL_BITS) {
+ OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid);
+ assertValidResults(result);
+ assertArrayEquals(new byte[4096], destination);
+ }
+ }
+
+ // degenerate case: single-element vector and centroid
+ for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) {
+ float[] vector = new float[] { randomFloat() };
+ float[] centroid = new float[] { randomFloat() };
+ if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
+ VectorUtil.l2normalize(vector);
+ VectorUtil.l2normalize(centroid);
+ }
+ OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
+ byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1];
+ OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid);
+ assertEquals(MINIMUM_MSE_GRID.length, results.length);
+ assertValidResults(results);
+ for (int i = 0; i < ALL_BITS.length; i++) {
+ assertValidQuantizedRange(destinations[i], ALL_BITS[i]);
+ }
+ for (byte bit : ALL_BITS) {
+ vector = new float[] { randomFloat() };
+ centroid = new float[] { randomFloat() };
+ if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
+ VectorUtil.l2normalize(vector);
+ VectorUtil.l2normalize(centroid);
+ }
+ byte[] destination = new byte[1];
+ OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid);
+ assertValidResults(result);
+ assertValidQuantizedRange(destination, bit);
+ }
+ }
+
+ }
+
+ public void testMathematicalConsistency() {
+ int dims = randomIntBetween(1, 4096);
+ float[] vector = new float[dims];
+ for (int i = 0; i < dims; ++i) {
+ vector[i] = randomFloat();
+ }
+ float[] centroid = new float[dims];
+ for (int i = 0; i < dims; ++i) {
+ centroid[i] = randomFloat();
+ }
+ float[] copy = new float[dims];
+ for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) {
+ // copy the vector to avoid modifying it
+ System.arraycopy(vector, 0, copy, 0, dims);
+ if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
+ VectorUtil.l2normalize(copy);
+ VectorUtil.l2normalize(centroid);
+ }
+ OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction);
+ byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims];
+ OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid);
+ assertEquals(MINIMUM_MSE_GRID.length, results.length);
+ assertValidResults(results);
+ for (int i = 0; i < ALL_BITS.length; i++) {
+ assertValidQuantizedRange(destinations[i], ALL_BITS[i]);
+ }
+ for (byte bit : ALL_BITS) {
+ byte[] destination = new byte[dims];
+ System.arraycopy(vector, 0, copy, 0, dims);
+ if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
+ VectorUtil.l2normalize(copy);
+ VectorUtil.l2normalize(centroid);
+ }
+ OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(copy, destination, bit, centroid);
+ assertValidResults(result);
+ assertValidQuantizedRange(destination, bit);
+ }
+ }
+ }
+
+ static void assertValidQuantizedRange(byte[] quantized, byte bits) {
+ for (byte b : quantized) {
+ if (bits < 8) {
+ assertTrue(b >= 0);
+ }
+ assertTrue(b < 1 << bits);
+ }
+ }
+
+ static void assertValidResults(OptimizedScalarQuantizer.QuantizationResult... results) {
+ for (OptimizedScalarQuantizer.QuantizationResult result : results) {
+ assertTrue(Float.isFinite(result.lowerInterval()));
+ assertTrue(Float.isFinite(result.upperInterval()));
+ assertTrue(result.lowerInterval() <= result.upperInterval());
+ assertTrue(Float.isFinite(result.additionalCorrection()));
+ assertTrue(result.quantizedComponentSum() >= 0);
+ }
+ }
+}
diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
index de084cd4582e2..c043b9ffb381a 100644
--- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
+++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java
@@ -1970,13 +1970,13 @@ public void testKnnBBQHNSWVectorsFormat() throws IOException {
assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class));
knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field");
}
- String expectedString = "ES816HnswBinaryQuantizedVectorsFormat(name=ES816HnswBinaryQuantizedVectorsFormat, maxConn="
+ String expectedString = "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn="
+ m
+ ", beamWidth="
+ efConstruction
- + ", flatVectorFormat=ES816BinaryQuantizedVectorsFormat("
- + "name=ES816BinaryQuantizedVectorsFormat, "
- + "flatVectorScorer=ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())))";
+ + ", flatVectorFormat=ES818BinaryQuantizedVectorsFormat("
+ + "name=ES818BinaryQuantizedVectorsFormat, "
+ + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())))";
assertEquals(expectedString, knnVectorsFormat.toString());
}