diff --git a/docs/changelog/117575.yaml b/docs/changelog/117575.yaml new file mode 100644 index 0000000000000..781444ae97be5 --- /dev/null +++ b/docs/changelog/117575.yaml @@ -0,0 +1,5 @@ +pr: 117575 +summary: Fix enrich cache size setting name +area: Ingest Node +type: bug +issues: [] diff --git a/docs/changelog/117994.yaml b/docs/changelog/117994.yaml new file mode 100644 index 0000000000000..603f2d855a11a --- /dev/null +++ b/docs/changelog/117994.yaml @@ -0,0 +1,5 @@ +pr: 117994 +summary: Even better(er) binary quantization +area: Vector Search +type: enhancement +issues: [] diff --git a/docs/changelog/118192.yaml b/docs/changelog/118192.yaml new file mode 100644 index 0000000000000..03542048761d3 --- /dev/null +++ b/docs/changelog/118192.yaml @@ -0,0 +1,11 @@ +pr: 118192 +summary: Remove `client.type` setting +area: Infra/Core +type: breaking +issues: [104574] +breaking: + title: Remove `client.type` setting + area: Cluster and node setting + details: The node setting `client.type` has been ignored since the node client was removed in 8.0. The setting is now removed. + impact: Remove the `client.type` setting from `elasticsearch.yml` + notable: false diff --git a/docs/reference/migration/migrate_9_0/transforms-migration-guide.asciidoc b/docs/reference/migration/migrate_9_0/transforms-migration-guide.asciidoc new file mode 100644 index 0000000000000..d41c524d68d5c --- /dev/null +++ b/docs/reference/migration/migrate_9_0/transforms-migration-guide.asciidoc @@ -0,0 +1,9 @@ +[[transforms-migration-guide]] +== {transforms-cap} migration guide +This migration guide helps you upgrade your {transforms} to work with the 9.0 release. Each section outlines a breaking change and any manual steps needed to upgrade your {transforms} to be compatible with 9.0. + + +=== Updating deprecated {transform} roles (`data_frame_transforms_admin` and `data_frame_transforms_user`) +If you have existing {transforms} that use deprecated {transform} roles (`data_frame_transforms_admin` or `data_frame_transforms_user`), you must update them to use the new equivalent {transform} roles (`transform_admin` or `transform_user`). To update your {transform} roles: +1. Switch to a user with the `transform_admin` role (to replace `data_frame_transforms_admin`) or the `transform_user` role (to replace `data_frame_transforms_user`). +2. Call the <> with that user.
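For illustration only, here is a minimal sketch of step 2 above using the low-level Java REST client. This snippet is not part of the change set; the host, the transform id `my-transform`, and the credentials are placeholder assumptions, and an empty update body is assumed to be enough to re-record the calling user's roles on the transform.

[source,java]
----
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class UpdateTransformRoles {
    public static void main(String[] args) throws Exception {
        // Authenticate as a user holding the new transform_admin role
        // (replacing data_frame_transforms_admin). Credentials are placeholders.
        BasicCredentialsProvider credentials = new BasicCredentialsProvider();
        credentials.setCredentials(AuthScope.ANY,
            new UsernamePasswordCredentials("transform-admin-user", "password"));
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http"))
            .setHttpClientConfigCallback(http -> http.setDefaultCredentialsProvider(credentials))
            .build()) {
            // Calling the update transforms API as the new user re-associates
            // the transform with that user's roles; no config changes needed.
            Request update = new Request("POST", "/_transform/my-transform/_update");
            update.setJsonEntity("{}");
            Response response = client.performRequest(update);
            System.out.println(response.getStatusLine());
        }
    }
}
----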
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java index 73936d82fc204..08bdc2051b9e3 100644 --- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java +++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java @@ -144,7 +144,7 @@ public DeleteResult delete(OperationPurpose purpose) throws IOException { @Override public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, new Iterator<>() { + blobStore.deleteBlobs(purpose, new Iterator<>() { @Override public boolean hasNext() { return blobNames.hasNext(); diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java index e4f973fb73a4e..3cac0dc4bb6db 100644 --- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java +++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java @@ -264,8 +264,7 @@ public DeleteResult deleteBlobDirectory(OperationPurpose purpose, String path) t return new DeleteResult(blobsDeleted.get(), bytesDeleted.get()); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { + void deleteBlobs(OperationPurpose purpose, Iterator blobNames) { if (blobNames.hasNext() == false) { return; } diff --git a/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureBlobContainerStatsTests.java b/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureBlobContainerStatsTests.java index f6e97187222e7..8979507230bdd 100644 --- a/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureBlobContainerStatsTests.java +++ b/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureBlobContainerStatsTests.java @@ -72,7 +72,7 @@ public void testRetriesAndOperationsAreTrackedSeparately() throws IOException { false ); case LIST_BLOBS -> blobStore.listBlobsByPrefix(purpose, randomIdentifier(), randomIdentifier()); - case BLOB_BATCH -> blobStore.deleteBlobsIgnoringIfNotExists( + case BLOB_BATCH -> blobStore.deleteBlobs( purpose, List.of(randomIdentifier(), randomIdentifier(), randomIdentifier()).iterator() ); @@ -113,7 +113,7 @@ public void testOperationPurposeIsReflectedInBlobStoreStats() throws IOException os.flush(); }); // BLOB_BATCH - blobStore.deleteBlobsIgnoringIfNotExists(purpose, List.of(randomIdentifier(), randomIdentifier(), randomIdentifier()).iterator()); + blobStore.deleteBlobs(purpose, List.of(randomIdentifier(), randomIdentifier(), randomIdentifier()).iterator()); Map stats = blobStore.stats(); String statsMapString = stats.toString(); @@ -148,10 +148,7 @@ public void testOperationPurposeIsNotReflectedInBlobStoreStatsWhenNotServerless( os.flush(); }); // BLOB_BATCH - blobStore.deleteBlobsIgnoringIfNotExists( - purpose, - List.of(randomIdentifier(), randomIdentifier(), randomIdentifier()).iterator() - ); + blobStore.deleteBlobs(purpose, List.of(randomIdentifier(), randomIdentifier(), randomIdentifier()).iterator()); } Map stats = blobStore.stats(); diff --git 
a/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobContainer.java b/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobContainer.java index 047549cc893ed..edcf03580da09 100644 --- a/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobContainer.java +++ b/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobContainer.java @@ -114,12 +114,12 @@ public void writeBlobAtomic(OperationPurpose purpose, String blobName, BytesRefe @Override public DeleteResult delete(OperationPurpose purpose) throws IOException { - return blobStore.deleteDirectory(purpose, path().buildAsString()); + return blobStore.deleteDirectory(path().buildAsString()); } @Override public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, new Iterator<>() { + blobStore.deleteBlobs(new Iterator<>() { @Override public boolean hasNext() { return blobNames.hasNext(); diff --git a/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStore.java b/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStore.java index 9cbf64e7e0146..c68217a1a3738 100644 --- a/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStore.java +++ b/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStore.java @@ -29,7 +29,6 @@ import org.elasticsearch.common.blobstore.BlobStore; import org.elasticsearch.common.blobstore.BlobStoreActionStats; import org.elasticsearch.common.blobstore.DeleteResult; -import org.elasticsearch.common.blobstore.OperationPurpose; import org.elasticsearch.common.blobstore.OptionalBytesReference; import org.elasticsearch.common.blobstore.support.BlobContainerUtils; import org.elasticsearch.common.blobstore.support.BlobMetadata; @@ -491,10 +490,9 @@ private void writeBlobMultipart(BlobInfo blobInfo, byte[] buffer, int offset, in /** * Deletes the given path and all its children. 
* - * @param purpose The purpose of the delete operation * @param pathStr Name of path to delete */ - DeleteResult deleteDirectory(OperationPurpose purpose, String pathStr) throws IOException { + DeleteResult deleteDirectory(String pathStr) throws IOException { return SocketAccess.doPrivilegedIOException(() -> { DeleteResult deleteResult = DeleteResult.ZERO; Page page = client().list(bucketName, BlobListOption.prefix(pathStr)); @@ -502,7 +500,7 @@ DeleteResult deleteDirectory(OperationPurpose purpose, String pathStr) throws IO final AtomicLong blobsDeleted = new AtomicLong(0L); final AtomicLong bytesDeleted = new AtomicLong(0L); final Iterator blobs = page.getValues().iterator(); - deleteBlobsIgnoringIfNotExists(purpose, new Iterator<>() { + deleteBlobs(new Iterator<>() { @Override public boolean hasNext() { return blobs.hasNext(); @@ -526,11 +524,9 @@ public String next() { /** * Deletes multiple blobs from the specific bucket using a batch request * - * @param purpose the purpose of the delete operation * @param blobNames names of the blobs to delete */ - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { + void deleteBlobs(Iterator blobNames) throws IOException { if (blobNames.hasNext() == false) { return; } diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java index e13cc40dd3e0f..f527dcd42814c 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java @@ -342,10 +342,10 @@ public DeleteResult delete(OperationPurpose purpose) throws IOException { return summary.getKey(); }); if (list.isTruncated()) { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, blobNameIterator); + blobStore.deleteBlobs(purpose, blobNameIterator); prevListing = list; } else { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, Iterators.concat(blobNameIterator, Iterators.single(keyPath))); + blobStore.deleteBlobs(purpose, Iterators.concat(blobNameIterator, Iterators.single(keyPath))); break; } } @@ -357,7 +357,7 @@ public DeleteResult delete(OperationPurpose purpose) throws IOException { @Override public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, Iterators.map(blobNames, this::buildKey)); + blobStore.deleteBlobs(purpose, Iterators.map(blobNames, this::buildKey)); } @Override diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java index 4f2b0f213e448..4bd54aa37077f 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java @@ -340,8 +340,7 @@ public BlobContainer blobContainer(BlobPath path) { return new S3BlobContainer(path, this); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { + void deleteBlobs(OperationPurpose purpose, Iterator blobNames) throws IOException { if (blobNames.hasNext() == false) { return; } diff --git a/modules/repository-url/src/main/java/org/elasticsearch/common/blobstore/url/URLBlobStore.java 
b/modules/repository-url/src/main/java/org/elasticsearch/common/blobstore/url/URLBlobStore.java index 8538d2ba673bc..0e9c735b22fd6 100644 --- a/modules/repository-url/src/main/java/org/elasticsearch/common/blobstore/url/URLBlobStore.java +++ b/modules/repository-url/src/main/java/org/elasticsearch/common/blobstore/url/URLBlobStore.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; import org.elasticsearch.common.blobstore.BlobStoreException; -import org.elasticsearch.common.blobstore.OperationPurpose; import org.elasticsearch.common.blobstore.url.http.HttpURLBlobContainer; import org.elasticsearch.common.blobstore.url.http.URLHttpClient; import org.elasticsearch.common.blobstore.url.http.URLHttpClientSettings; @@ -23,10 +22,8 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.CheckedFunction; -import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; -import java.util.Iterator; import java.util.List; /** @@ -109,11 +106,6 @@ public BlobContainer blobContainer(BlobPath blobPath) { } } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - throw new UnsupportedOperationException("Bulk deletes are not supported in URL repositories"); - } - @Override public void close() { // nothing to do here... diff --git a/plugins/repository-hdfs/src/main/java/org/elasticsearch/repositories/hdfs/HdfsBlobStore.java b/plugins/repository-hdfs/src/main/java/org/elasticsearch/repositories/hdfs/HdfsBlobStore.java index eaf2429ae6258..e817384d95c04 100644 --- a/plugins/repository-hdfs/src/main/java/org/elasticsearch/repositories/hdfs/HdfsBlobStore.java +++ b/plugins/repository-hdfs/src/main/java/org/elasticsearch/repositories/hdfs/HdfsBlobStore.java @@ -16,10 +16,8 @@ import org.elasticsearch.common.blobstore.BlobContainer; import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; -import org.elasticsearch.common.blobstore.OperationPurpose; import java.io.IOException; -import java.util.Iterator; final class HdfsBlobStore implements BlobStore { @@ -72,11 +70,6 @@ public BlobContainer blobContainer(BlobPath path) { return new HdfsBlobContainer(path, this, buildHdfsPath(path), bufferSize, securityContext, replicationFactor); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - throw new UnsupportedOperationException("Bulk deletes are not supported in Hdfs repositories"); - } - private Path buildHdfsPath(BlobPath blobPath) { final Path path = translateToHdfsPath(blobPath); if (readOnly == false) { diff --git a/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java b/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java index 17927b02a08dc..3e1c112a4d9f7 100644 --- a/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java +++ b/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java @@ -46,11 +46,6 @@ public void testSnapshotAndRestore() throws Exception { testSnapshotAndRestore(false); } - @Override - public void testBlobStoreBulkDeletion() throws Exception { - // HDFS does not implement bulk deletion from different BlobContainers - } - @Override protected Collection> nodePlugins() { 
return Collections.singletonList(HdfsPlugin.class); diff --git a/rest-api-spec/build.gradle b/rest-api-spec/build.gradle index e2af894eb0939..7347d9c1312dd 100644 --- a/rest-api-spec/build.gradle +++ b/rest-api-spec/build.gradle @@ -67,4 +67,6 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task -> task.skipTest("logsdb/20_source_mapping/include/exclude is supported with stored _source", "no longer serialize source_mode") task.skipTest("logsdb/20_source_mapping/synthetic _source is default", "no longer serialize source_mode") task.skipTest("search/520_fetch_fields/fetch _seq_no via fields", "error code is changed from 5xx to 400 in 9.0") + task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions") + task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions") }) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 188c155e4a836..5767c895fbe7e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -11,20 +11,11 @@ setup: number_of_shards: 1 mappings: properties: - name: - type: keyword vector: type: dense_vector dims: 64 index: true - similarity: l2_norm - index_options: - type: bbq_hnsw - another_vector: - type: dense_vector - dims: 64 - index: true - similarity: l2_norm + similarity: max_inner_product index_options: type: bbq_hnsw @@ -33,9 +24,14 @@ setup: index: bbq_hnsw id: "1" body: - name: cow.jpg - vector: [300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0] - another_vector: [115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0] + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] # Flush in order to provoke a merge later - do: indices.flush: @@ -46,9 +42,14 @@ setup: index: bbq_hnsw id: "2" body: - name: moose.jpg - vector: [100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, 
-0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0] - another_vector: [50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120] + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] # Flush in order to provoke a merge later - do: indices.flush: @@ -60,8 +61,14 @@ setup: id: "3" body: name: rabbit.jpg - vector: [111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0] - another_vector: [11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0] + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] # Flush in order to provoke a merge later - do: indices.flush: @@ -73,20 +80,33 @@ setup: max_num_segments: 1 --- "Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" - do: search: index: bbq_hnsw body: knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 
0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] k: 3 num_candidates: 3 - # Depending on how things are distributed, docs 2 and 3 might be swapped - # here we verify that are last hit is always the worst one - - match: { hits.hits.2._id: "1" } - + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } --- "Test bad quantization parameters": - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index ed7a8dd5df65d..dcdae04aeabb4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -11,20 +11,11 @@ setup: number_of_shards: 1 mappings: properties: - name: - type: keyword vector: type: dense_vector dims: 64 index: true - similarity: l2_norm - index_options: - type: bbq_flat - another_vector: - type: dense_vector - dims: 64 - index: true - similarity: l2_norm + similarity: max_inner_product index_options: type: bbq_flat @@ -33,9 +24,14 @@ setup: index: bbq_flat id: "1" body: - name: cow.jpg - vector: [300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0] - another_vector: [115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0] + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] # Flush in order to provoke a merge later - do: indices.flush: @@ -46,9 +42,14 @@ setup: index: bbq_flat id: "2" body: - name: moose.jpg - vector: [100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, 
-156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0] - another_vector: [50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120] + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] # Flush in order to provoke a merge later - do: indices.flush: @@ -59,9 +60,14 @@ setup: index: bbq_flat id: "3" body: - name: rabbit.jpg - vector: [111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0] - another_vector: [11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0] + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] # Flush in order to provoke a merge later - do: indices.flush: @@ -73,19 +79,33 @@ setup: max_num_segments: 1 --- "Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" - do: search: index: bbq_flat body: knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + 
-0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] k: 3 num_candidates: 3 - # Depending on how things are distributed, docs 2 and 3 might be swapped - # here we verify that are last hit is always the worst one - - match: { hits.hits.2._id: "1" } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } --- "Test bad parameters": - do: diff --git a/server/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryOperationPurposeIT.java b/server/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryOperationPurposeIT.java index c0a2c83f7fe1e..b2e02b2f4c271 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryOperationPurposeIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryOperationPurposeIT.java @@ -36,7 +36,6 @@ import java.io.InputStream; import java.io.OutputStream; import java.util.Collection; -import java.util.Iterator; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -136,11 +135,6 @@ public BlobContainer blobContainer(BlobPath path) { return new AssertingBlobContainer(delegateBlobStore.blobContainer(path)); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - delegateBlobStore.deleteBlobsIgnoringIfNotExists(purpose, blobNames); - } - @Override public void close() throws IOException { delegateBlobStore.close(); diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 331a2bc0dddac..ff902dbede007 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -459,7 +459,9 @@ org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat, org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat, org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat, - org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat; + org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; provides org.apache.lucene.codecs.Codec with diff --git a/server/src/main/java/org/elasticsearch/client/internal/Client.java b/server/src/main/java/org/elasticsearch/client/internal/Client.java index 4158bbfb27cda..2d1cbe0cce7f7 100644 --- a/server/src/main/java/org/elasticsearch/client/internal/Client.java +++ b/server/src/main/java/org/elasticsearch/client/internal/Client.java @@ -52,8 +52,6 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.action.update.UpdateResponse; -import org.elasticsearch.common.settings.Setting; -import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.transport.RemoteClusterService; @@ -74,14 +72,6 @@ */ public interface Client extends ElasticsearchClient { - // Note: This setting is registered only for bwc. The value is never read. 
- Setting CLIENT_TYPE_SETTING_S = new Setting<>("client.type", "node", (s) -> { - return switch (s) { - case "node", "transport" -> s; - default -> throw new IllegalArgumentException("Can't parse [client.type] must be one of [node, transport]"); - }; - }, Property.NodeScope, Property.Deprecated); - /** * The admin client that can be used to perform administrative operations. */ diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/BlobStore.java b/server/src/main/java/org/elasticsearch/common/blobstore/BlobStore.java index d67c034fd3e27..f1fe028f60f6e 100644 --- a/server/src/main/java/org/elasticsearch/common/blobstore/BlobStore.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/BlobStore.java @@ -9,9 +9,7 @@ package org.elasticsearch.common.blobstore; import java.io.Closeable; -import java.io.IOException; import java.util.Collections; -import java.util.Iterator; import java.util.Map; /** @@ -28,14 +26,6 @@ public interface BlobStore extends Closeable { */ BlobContainer blobContainer(BlobPath path); - /** - * Delete all the provided blobs from the blob store. Each blob could belong to a different {@code BlobContainer} - * - * @param purpose the purpose of the delete operation - * @param blobNames the blobs to be deleted - */ - void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException; - /** * Returns statistics on the count of operations that have been performed on this blob store */ diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobContainer.java b/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobContainer.java index b5118d8a289a9..7d40008231292 100644 --- a/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobContainer.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobContainer.java @@ -177,7 +177,7 @@ public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IO @Override public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, Iterators.map(blobNames, blobName -> path.resolve(blobName).toString())); + blobStore.deleteBlobs(Iterators.map(blobNames, blobName -> path.resolve(blobName).toString())); } @Override diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java b/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java index 53e3b4b4796dc..9a368483d46c0 100644 --- a/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.blobstore.BlobContainer; import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; -import org.elasticsearch.common.blobstore.OperationPurpose; import org.elasticsearch.core.IOUtils; import java.io.IOException; @@ -70,8 +69,7 @@ public BlobContainer blobContainer(BlobPath path) { return new FsBlobContainer(this, path, f); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { + void deleteBlobs(Iterator blobNames) throws IOException { IOException ioe = null; long suppressedExceptions = 0; while (blobNames.hasNext()) { diff --git a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java 
b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java index a9a9411de8e1f..16af7ca2915d4 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java @@ -20,7 +20,6 @@ import org.elasticsearch.action.support.DestructiveOperations; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.bootstrap.BootstrapSettings; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.InternalClusterInfoService; @@ -483,7 +482,6 @@ public void apply(Settings value, Settings current, Settings previous) { AutoCreateIndex.AUTO_CREATE_INDEX_SETTING, BaseRestHandler.MULTI_ALLOW_EXPLICIT_INDEX, ClusterName.CLUSTER_NAME_SETTING, - Client.CLIENT_TYPE_SETTING_S, ClusterModule.SHARDS_ALLOCATOR_TYPE_SETTING, EsExecutors.NODE_PROCESSORS_SETTING, ThreadContext.DEFAULT_HEADERS_SETTING, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java index 5201e57179cc7..1aff06a175967 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java @@ -40,6 +40,20 @@ public static boolean isUnitVector(float[] v) { return Math.abs(l1norm - 1.0d) <= EPSILON; } + public static void packAsBinary(byte[] vector, byte[] packed) { + for (int i = 0; i < vector.length;) { + byte result = 0; + for (int j = 7; j >= 0 && i < vector.length; j--) { + assert vector[i] == 0 || vector[i] == 1; + result |= (byte) ((vector[i] & 1) << j); + ++i; + } + int index = ((i + 7) / 8) - 1; + assert index < packed.length; + packed[index] = result; + } + } + public static int discretize(int value, int bucket) { return ((value + (bucket - 1)) / bucket) * bucket; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java index 445bdadab2354..e85079e998c61 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java @@ -48,13 +48,8 @@ class ES816BinaryFlatVectorsScorer implements FlatVectorsScorer { public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues - ) throws IOException { - if (vectorValues instanceof BinarizedByteVectorValues) { - throw new UnsupportedOperationException( - "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format" - ); - } - return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + ) { + throw new UnsupportedOperationException(); } @Override @@ -90,61 +85,11 @@ public RandomVectorScorer getRandomVectorScorer( return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } - RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, - ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, - BinarizedByteVectorValues targetVectors - 
) { - return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction); - } - @Override public String toString() { return "ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; } - /** Vector scorer supplier over binarized vector values */ - static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; - private final BinarizedByteVectorValues targetVectors; - private final VectorSimilarityFunction similarityFunction; - - BinarizedRandomVectorScorerSupplier( - ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, - BinarizedByteVectorValues targetVectors, - VectorSimilarityFunction similarityFunction - ) { - this.queryVectors = queryVectors; - this.targetVectors = targetVectors; - this.similarityFunction = similarityFunction; - } - - @Override - public RandomVectorScorer scorer(int ord) throws IOException { - byte[] vector = queryVectors.vectorValue(ord); - int quantizedSum = queryVectors.sumQuantizedValues(ord); - float distanceToCentroid = queryVectors.getCentroidDistance(ord); - float lower = queryVectors.getLower(ord); - float width = queryVectors.getWidth(ord); - float normVmC = 0f; - float vDotC = 0f; - if (similarityFunction != EUCLIDEAN) { - normVmC = queryVectors.getNormVmC(ord); - vDotC = queryVectors.getVDotC(ord); - } - BinaryQueryVector binaryQueryVector = new BinaryQueryVector( - vector, - new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC) - ); - return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); - } - - @Override - public RandomVectorScorerSupplier copy() throws IOException { - return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); - } - } - /** A binarized query representing its quantized form along with factors */ record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java index d864ec5dee8c5..61b6edc474d1f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java @@ -62,7 +62,7 @@ public ES816BinaryQuantizedVectorsFormat() { @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new ES816BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + throw new UnsupportedOperationException(); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java index 52f9f14b7bf97..1dbb4e432b188 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java @@ -25,10 +25,8 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import 
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.hnsw.HnswGraph; import java.io.IOException; @@ -52,21 +50,18 @@ public class ES816HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat { * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. */ - private final int maxConn; + protected final int maxConn; /** * The number of candidate neighbors to track while searching the graph for each newly inserted * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} * for details. */ - private final int beamWidth; + protected final int beamWidth; /** The format for storing, reading, merging vectors on disk */ private static final FlatVectorsFormat flatVectorsFormat = new ES816BinaryQuantizedVectorsFormat(); - private final int numMergeWorkers; - private final TaskExecutor mergeExec; - /** Constructs a format using default graph construction parameters */ public ES816HnswBinaryQuantizedVectorsFormat() { this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); @@ -109,17 +104,11 @@ public ES816HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int num if (numMergeWorkers == 1 && mergeExec != null) { throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); } - this.numMergeWorkers = numMergeWorkers; - if (mergeExec != null) { - this.mergeExec = new TaskExecutor(mergeExec); - } else { - this.mergeExec = null; - } } @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + throw new UnsupportedOperationException(); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java new file mode 100644 index 0000000000000..cc1f7b85e0f78 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java @@ -0,0 +1,87 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +abstract class BinarizedByteVectorValues extends ByteVectorValues { + + /** + * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of + * distances, the corrective terms are, in order + * + *
<ul> + *
  <li> the lower optimized interval + *
  <li> the upper optimized interval + *
  <li> the dot-product of the non-centered vector with the centroid + *
  <li> the sum of quantized components + * </ul>
+ * + * For euclidean: + * + *
<ul> + *
  <li> the lower optimized interval + *
  <li> the upper optimized interval + *
  <li> the l2norm of the centered vector + *
  <li> the sum of quantized components + * </ul>
+ * + * @param vectorOrd the vector ordinal + * @return the corrective terms + * @throws IOException if an I/O error occurs + */ + public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd) throws IOException; + + /** + * @return the quantizer used to quantize the vectors + */ + public abstract OptimizedScalarQuantizer getQuantizer(); + + public abstract float[] getCentroid() throws IOException; + + int discretizedDimensions() { + return BQVectorUtils.discretize(dimension(), 64); + } + + /** + * Return a {@link VectorScorer} for the given query vector. + * + * @param query the query vector + * @return a {@link VectorScorer} instance or null + */ + public abstract VectorScorer scorer(float[] query) throws IOException; + + @Override + public abstract BinarizedByteVectorValues copy() throws IOException; + + float getCentroidDP() throws IOException { + // this only gets executed on-merge + float[] centroid = getCentroid(); + return VectorUtil.dotProduct(centroid, centroid); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java new file mode 100644 index 0000000000000..7c7e470909eb3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -0,0 +1,188 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.simdvec.ESVectorUtil; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Vector scorer over binarized vector values */ +public class ES818BinaryFlatVectorsScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + + public ES818BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues) { + throw new UnsupportedOperationException( + "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format" + ); + } + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + float[] target + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) { + OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer(); + float[] centroid = binarizedVectors.getCentroid(); + // We make a copy as the quantization process mutates the input + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + if (similarityFunction == COSINE) { + VectorUtil.l2normalize(copy); + } + target = copy; + byte[] initial = new byte[target.length]; + byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8]; + OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(target, initial, (byte) 4, centroid); + BQSpaceUtils.transposeHalfByte(initial, quantized); + BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections); + return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction); + } + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + byte[] target + ) throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, + BinarizedByteVectorValues targetVectors + ) { + return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, 
similarityFunction); + } + + @Override + public String toString() { + return "ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + } + + /** Vector scorer supplier over binarized vector values */ + static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + private final ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + BinarizedRandomVectorScorerSupplier( + ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + this.queryVectors = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] vector = queryVectors.vectorValue(ord); + OptimizedScalarQuantizer.QuantizationResult correctiveTerms = queryVectors.getCorrectiveTerms(ord); + BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms); + return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); + } + } + + /** A binarized query representing its quantized form along with factors */ + public record BinaryQueryVector(byte[] vector, OptimizedScalarQuantizer.QuantizationResult quantizationResult) {} + + /** Vector scorer over binarized vector values */ + public static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + private final BinaryQueryVector queryVector; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + public BinarizedRandomVectorScorer( + BinaryQueryVector queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + super(targetVectors); + this.queryVector = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public float score(int targetOrd) throws IOException { + byte[] quantizedQuery = queryVector.vector(); + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + OptimizedScalarQuantizer.QuantizationResult queryCorrections = queryVector.quantizationResult(); + OptimizedScalarQuantizer.QuantizationResult indexCorrections = targetVectors.getCorrectiveTerms(targetOrd); + float x1 = indexCorrections.quantizedComponentSum(); + float ax = indexCorrections.lowerInterval(); + // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + float lx = indexCorrections.upperInterval() - ax; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + float score = ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist; + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. 
+            // For euclidean, we need to invert the score and apply the additional correction, which is
+            // assumed to be the squared l2norm of the centroid centered vectors.
+            if (similarityFunction == EUCLIDEAN) {
+                score = queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - 2 * score;
+                return Math.max(1 / (1f + score), 0);
+            } else {
+                // For cosine and max inner product, we need to apply the additional correction, which is
+                // assumed to be the non-centered dot-product between the vector and the centroid
+                score += queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - targetVectors.getCentroidDP();
+                if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                    return VectorUtil.scaleMaxInnerProductScore(score);
+                }
+                return Math.max((1f + score) / 2f, 0);
+            }
+        }
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java
new file mode 100644
index 0000000000000..1dee9599f985f
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java
@@ -0,0 +1,132 @@
+/*
+ * @notice
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * Modifications copyright (C) 2024 Elasticsearch B.V.
+ */
+package org.elasticsearch.index.codec.vectors.es818;
+
+import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
+import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
+import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
+import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
+import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
+import org.apache.lucene.index.SegmentReadState;
+import org.apache.lucene.index.SegmentWriteState;
+
+import java.io.IOException;
+
+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
+
+/**
+ * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
+ * Codec for encoding/decoding binary quantized vectors. The binary quantization format used here
+ * is a per-vector optimized scalar quantization. Also see {@link
+ * org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer}. Some of the key features are:
+ *
+ * <ul>
+ *   <li>Estimating the distance between two vectors using their centroid normalized distance. This
+ *       requires some additional corrective factors, but allows for centroid normalization to occur.
+ *   <li>Optimized scalar quantization to bit level of centroid normalized vectors.
+ *   <li>Asymmetric quantization of vectors, where query vectors are quantized to half-byte
+ *       precision (normalized to the centroid) and then compared directly against the single bit
+ *       quantized vectors in the index.
+ *   <li>Transforming the half-byte quantized query vectors in such a way that the comparison with
+ *       single bit vectors can be done with bit arithmetic.
+ * </ul>
+ *
+ * The format is stored in two files:
+ *
+ * <h2>.veb (vector data) file</h2>
+ *
+ * <p>Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's
+ * corrective factors. At the end of the file, additional information is stored for vector ordinal
+ * to centroid ordinal mapping and sparse vector information.
+ *
+ * <ul>
+ *   <li>For each vector:
+ *       <ul>
+ *         <li>[byte] the binary quantized values, each byte holds 8 bits.
+ *         <li>[float] the optimized quantiles and an additional similarity dependent corrective factor.
+ *         <li>short the sum of the quantized components
+ *       </ul>
+ *   <li>After the vectors, sparse vector information keeping track of monotonic blocks.
+ * </ul>
+ *
+ * <h2>.vemb (vector metadata) file</h2>
+ *
+ * <p>Stores the metadata for the vectors. This includes the number of vectors, the number of
+ * dimensions, and file offset information.
+ *
+ * <ul>
+ *   <li>int the field number
+ *   <li>int the vector encoding ordinal
+ *   <li>int the vector similarity ordinal
+ *   <li>vint the vector dimensions
+ *   <li>vlong the offset to the vector data in the .veb file
+ *   <li>vlong the length of the vector data in the .veb file
+ *   <li>vint the number of vectors
+ *   <li>[float] the centroid
+ *   <li>float the centroid square magnitude
+ *   <li>The sparse vector information, if required, mapping vector ordinal to doc ID
+ * </ul>
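+ * <p>As a worked example derived from the layout above (illustrative only, not an extra
+ * format guarantee): a 768-dimension vector discretizes to 768 bits, so each indexed vector
+ * costs 96 bytes of packed bits plus three corrective floats and one short, i.e.
+ * 96 + 12 + 2 = 110 bytes in the .veb file.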
+ */ +public class ES818BinaryQuantizedVectorsFormat extends FlatVectorsFormat { + + public static final String BINARIZED_VECTOR_COMPONENT = "BVEC"; + public static final String NAME = "ES818BinaryQuantizedVectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "ES818BinaryQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "ES818BinaryQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemb"; + static final String VECTOR_DATA_EXTENSION = "veb"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + /** Creates a new instance with the default number of vectors per cluster. */ + public ES818BinaryQuantizedVectorsFormat() { + super(NAME); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES818BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + + @Override + public String toString() { + return "ES818BinaryQuantizedVectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java new file mode 100644 index 0000000000000..8036b8314cdc1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java @@ -0,0 +1,412 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.SuppressForbidden; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +@SuppressForbidden(reason = "Lucene classes") +class ES818BinaryQuantizedVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES818BinaryQuantizedVectorsReader.class); + + private final Map fields = new HashMap<>(); + private final IndexInput quantizedVectorData; + private final FlatVectorsReader rawVectorsReader; + private final ES818BinaryFlatVectorsScorer vectorScorer; + + ES818BinaryQuantizedVectorsReader( + SegmentReadState state, + FlatVectorsReader rawVectorsReader, + ES818BinaryFlatVectorsScorer vectorsScorer + ) throws IOException { + super(vectorsScorer); + this.vectorScorer = vectorsScorer; + this.rawVectorsReader = rawVectorsReader; + int versionMeta = -1; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.META_EXTENSION + ); + boolean success = false; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + meta, + ES818BinaryQuantizedVectorsFormat.META_CODEC_NAME, + ES818BinaryQuantizedVectorsFormat.VERSION_START, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + quantizedVectorData = openDataInput( + state, + versionMeta, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + // Quantized vectors are accessed randomly from their node ID stored in the HNSW + // graph. 
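+                // (ReadAdvice.RANDOM discourages OS readahead where the underlying Directory
+                // honors madvise-style hints; the exact effect is Directory-specific.)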
+ state.context.withReadAdvice(ReadAdvice.RANDOM) + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta, info); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + int dimension = info.getVectorDimension(); + if (dimension != fieldEntry.dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + info.name + "\"; " + dimension + " != " + fieldEntry.dimension + ); + } + + int binaryDims = BQVectorUtils.discretize(dimension, 64) / 8; + long numQuantizedVectorBytes = Math.multiplyExact((binaryDims + (Float.BYTES * 3) + Short.BYTES), (long) fieldEntry.size); + if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { + throw new IllegalStateException( + "Binarized vector data length " + + fieldEntry.vectorDataLength + + " not matching size = " + + fieldEntry.size + + " * (binaryBytes=" + + binaryDims + + " + 14" + + ") = " + + numQuantizedVectorBytes + ); + } + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + return vectorScorer.getRandomVectorScorer( + fi.similarityFunction, + OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData + ), + target + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(quantizedVectorData); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + field + "\" is encoded as: " + fi.vectorEncoding + " expected: " + VectorEncoding.FLOAT32 + ); + } + OffHeapBinarizedVectorValues bvv = OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData + ); + return new BinarizedVectorValues(rawVectorsReader.getFloatVectorValues(field), bvv); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + } 
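+
+    // Note on the float[] search below: unlike the byte[] variant above, it cannot delegate to
+    // the raw reader, because queries must be scored against the quantized data. This flat
+    // reader has no HNSW graph, so it visits every accepted ordinal exhaustively.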
+    @Override
+    public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
+        if (knnCollector.k() == 0) return;
+        final RandomVectorScorer scorer = getRandomVectorScorer(field, target);
+        if (scorer == null) return;
+        OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
+        Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs);
+        for (int i = 0; i < scorer.maxOrd(); i++) {
+            if (acceptedOrds == null || acceptedOrds.get(i)) {
+                collector.collect(i, scorer.score(i));
+                collector.incVisitedCount(1);
+            }
+        }
+    }
+
+    @Override
+    public void close() throws IOException {
+        IOUtils.close(quantizedVectorData, rawVectorsReader);
+    }
+
+    @Override
+    public long ramBytesUsed() {
+        long size = SHALLOW_SIZE;
+        size += RamUsageEstimator.sizeOfMap(fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
+        size += rawVectorsReader.ramBytesUsed();
+        return size;
+    }
+
+    public float[] getCentroid(String field) {
+        FieldEntry fieldEntry = fields.get(field);
+        if (fieldEntry != null) {
+            return fieldEntry.centroid;
+        }
+        return null;
+    }
+
+    private static IndexInput openDataInput(
+        SegmentReadState state,
+        int versionMeta,
+        String fileExtension,
+        String codecName,
+        IOContext context
+    ) throws IOException {
+        String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
+        IndexInput in = state.directory.openInput(fileName, context);
+        boolean success = false;
+        try {
+            int versionVectorData = CodecUtil.checkIndexHeader(
+                in,
+                codecName,
+                ES818BinaryQuantizedVectorsFormat.VERSION_START,
+                ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT,
+                state.segmentInfo.getId(),
+                state.segmentSuffix
+            );
+            if (versionMeta != versionVectorData) {
+                throw new CorruptIndexException(
+                    "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData,
+                    in
+                );
+            }
+            CodecUtil.retrieveChecksum(in);
+            success = true;
+            return in;
+        } finally {
+            if (success == false) {
+                IOUtils.closeWhileHandlingException(in);
+            }
+        }
+    }
+
+    private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException {
+        VectorEncoding vectorEncoding = readVectorEncoding(input);
+        VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
+        if (similarityFunction != info.getVectorSimilarityFunction()) {
+            throw new IllegalStateException(
+                "Inconsistent vector similarity function for field=\""
+                    + info.name
+                    + "\"; "
+                    + similarityFunction
+                    + " != "
+                    + info.getVectorSimilarityFunction()
+            );
+        }
+        return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction());
+    }
+
+    private record FieldEntry(
+        VectorSimilarityFunction similarityFunction,
+        VectorEncoding vectorEncoding,
+        int dimension,
+        int discretizedDimension,
+        long vectorDataOffset,
+        long vectorDataLength,
+        int size,
+        float[] centroid,
+        float centroidDP,
+        OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration
+    ) {
+
+        static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction)
+            throws IOException {
+            int dimension = input.readVInt();
+            long vectorDataOffset = input.readVLong();
+            long vectorDataLength = input.readVLong();
+            int size = input.readVInt();
+            final float[] centroid;
+            float centroidDP = 0;
+            if (size > 0) {
+                centroid = new float[dimension];
+                input.readFloats(centroid, 0, dimension);
+                centroidDP = Float.intBitsToFloat(input.readInt());
+            } else {
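+                // An empty field (size == 0) stores no centroid bytes, so the centroid is read back as null.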
centroid = null; + } + OrdToDocDISIReaderConfiguration conf = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry( + similarityFunction, + vectorEncoding, + dimension, + BQVectorUtils.discretize(dimension, 64), + vectorDataOffset, + vectorDataLength, + size, + centroid, + centroidDP, + conf + ); + } + } + + /** Binarized vector values holding row and quantized vector values */ + protected static final class BinarizedVectorValues extends FloatVectorValues { + private final FloatVectorValues rawVectorValues; + private final BinarizedByteVectorValues quantizedVectorValues; + + BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) { + this.rawVectorValues = rawVectorValues; + this.quantizedVectorValues = quantizedVectorValues; + } + + @Override + public int dimension() { + return rawVectorValues.dimension(); + } + + @Override + public int size() { + return rawVectorValues.size(); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + return rawVectorValues.vectorValue(ord); + } + + @Override + public BinarizedVectorValues copy() throws IOException { + return new BinarizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return rawVectorValues.getAcceptOrds(acceptDocs); + } + + @Override + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return rawVectorValues.iterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.scorer(query); + } + + BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { + return quantizedVectorValues; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java new file mode 100644 index 0000000000000..02dda6a4a9da1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java @@ -0,0 +1,944 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.FloatArrayList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.BINARIZED_VECTOR_COMPONENT; +import static org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +@SuppressForbidden(reason = "Lucene classes") +public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = shallowSizeOfInstance(ES818BinaryQuantizedVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final List fields = new ArrayList<>(); + private final IndexOutput meta, binarizedVectorData; + private final FlatVectorsWriter rawVectorDelegate; + private final ES818BinaryFlatVectorsScorer vectorsScorer; + private boolean finished; + + /** + * Sole constructor + * + * @param vectorsScorer the scorer to use for scoring vectors + */ + protected ES818BinaryQuantizedVectorsWriter( + ES818BinaryFlatVectorsScorer vectorsScorer, + FlatVectorsWriter rawVectorDelegate, + SegmentWriteState state + ) throws IOException { + super(vectorsScorer); + this.vectorsScorer = vectorsScorer; + this.segmentWriteState = state; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES818BinaryQuantizedVectorsFormat.META_EXTENSION + ); + + String binarizedVectorDataFileName = IndexFileNames.segmentFileName( 
+ state.segmentInfo.name, + state.segmentSuffix, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION + ); + this.rawVectorDelegate = rawVectorDelegate; + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + binarizedVectorData = state.directory.createOutput(binarizedVectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + ES818BinaryQuantizedVectorsFormat.META_CODEC_NAME, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + CodecUtil.writeIndexHeader( + binarizedVectorData, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + FieldWriter fieldWriter = new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); + fields.add(fieldWriter); + return fieldWriter; + } + return rawVectorDelegate; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + // after raw vectors are written, normalize vectors for clustering and quantization + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + field.normalizeVectors(); + } + final float[] clusterCenter; + int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); + } + } + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction()); + if (sortMap == null) { + writeField(field, clusterCenter, maxDoc, quantizer); + } else { + writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer); + } + field.finish(); + } + } + + private void writeField(FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer) + throws IOException { + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeBinarizedVectors(fieldData, clusterCenter, quantizer); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = fieldData.getVectors().size() > 0 ? 
VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + clusterCenter, + centroidDp, + fieldData.getDocsWithFieldSet() + ); + } + + private void writeBinarizedVectors(FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); + byte[] quantizationScratch = new byte[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int i = 0; i < fieldData.getVectors().size(); i++) { + float[] v = fieldData.getVectors().get(i); + OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize( + v, + quantizationScratch, + (byte) 1, + clusterCenter + ); + BQVectorUtils.packAsBinary(quantizationScratch, vector); + binarizedVectorData.writeBytes(vector, vector.length); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); + } + } + + private void writeSortingField( + FieldWriter fieldData, + float[] clusterCenter, + int maxDoc, + Sorter.DocMap sortMap, + OptimizedScalarQuantizer scalarQuantizer + ) throws IOException { + final int[] ordMap = new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer); + long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + + float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, quantizedVectorLength, clusterCenter, centroidDp, newDocsWithField); + } + + private void writeSortedBinarizedVectors( + FieldWriter fieldData, + float[] clusterCenter, + int[] ordMap, + OptimizedScalarQuantizer scalarQuantizer + ) throws IOException { + int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); + byte[] quantizationScratch = new byte[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int ordinal : ordMap) { + float[] v = fieldData.getVectors().get(ordinal); + OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize( + v, + quantizationScratch, + (byte) 1, + clusterCenter + ); + BQVectorUtils.packAsBinary(quantizationScratch, vector); + binarizedVectorData.writeBytes(vector, vector.length); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); + } + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long 
vectorDataOffset, + long vectorDataLength, + float[] clusterCenter, + float centroidDp, + DocsWithFieldSet docsWithField + ) throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVInt(field.getVectorDimension()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + int count = docsWithField.cardinality(); + meta.writeVInt(count); + if (count > 0) { + final ByteBuffer buffer = ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(clusterCenter); + meta.writeBytes(buffer.array(), buffer.array().length); + meta.writeInt(Float.floatToIntBits(centroidDp)); + } + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, + meta, + binarizedVectorData, + count, + maxDoc, + docsWithField + ); + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (binarizedVectorData != null) { + CodecUtil.writeFooter(binarizedVectorData); + } + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + BinarizedFloatVectorValues binarizedVectorValues = new BinarizedFloatVectorValues( + floatVectorValues, + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + centroid + ); + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = docsWithField.cardinality() > 0 ? 
VectorUtil.dotProduct(centroid, centroid) : 0; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + centroidDp, + docsWithField + ); + } else { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + } + } + + static DocsWithFieldSet writeBinarizedVectorAndQueryData( + IndexOutput binarizedVectorData, + IndexOutput binarizedQueryData, + FloatVectorValues floatVectorValues, + float[] centroid, + OptimizedScalarQuantizer binaryQuantizer + ) throws IOException { + int discretizedDimension = BQVectorUtils.discretize(floatVectorValues.dimension(), 64); + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + byte[][] quantizationScratch = new byte[2][floatVectorValues.dimension()]; + byte[] toIndex = new byte[discretizedDimension / 8]; + byte[] toQuery = new byte[(discretizedDimension / 8) * BQSpaceUtils.B_QUERY]; + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write index vector + OptimizedScalarQuantizer.QuantizationResult[] r = binaryQuantizer.multiScalarQuantize( + floatVectorValues.vectorValue(iterator.index()), + quantizationScratch, + new byte[] { 1, 4 }, + centroid + ); + // pack and store document bit vector + BQVectorUtils.packAsBinary(quantizationScratch[0], toIndex); + binarizedVectorData.writeBytes(toIndex, toIndex.length); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].additionalCorrection())); + assert r[0].quantizedComponentSum() >= 0 && r[0].quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) r[0].quantizedComponentSum()); + docsWithField.add(docV); + + // pack and store the 4bit query vector + BQSpaceUtils.transposeHalfByte(quantizationScratch[1], toQuery); + binarizedQueryData.writeBytes(toQuery, toQuery.length); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].lowerInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].upperInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].additionalCorrection())); + assert r[1].quantizedComponentSum() >= 0 && r[1].quantizedComponentSum() <= 0xffff; + binarizedQueryData.writeShort((short) r[1].quantizedComponentSum()); + } + return docsWithField; + } + + static DocsWithFieldSet writeBinarizedVectorData(IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues) + throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write vector + byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index()); + output.writeBytes(binaryValue, binaryValue.length); + OptimizedScalarQuantizer.QuantizationResult corrections = binarizedByteVectorValues.getCorrectiveTerms(iterator.index()); + output.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + output.writeInt(Float.floatToIntBits(corrections.upperInterval())); + output.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + output.writeShort((short) corrections.quantizedComponentSum()); + docsWithField.add(docV); + } + return 
docsWithField; + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float cDotC; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC); + } + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + float[] centroid, + float cDotC + ) throws IOException { + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + final IndexOutput tempQuantizedVectorData = segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), + "temp", + segmentWriteState.context + ); + final IndexOutput tempScoreQuantizedVectorData = segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), + "score_temp", + segmentWriteState.context + ); + IndexInput binarizedDataInput = null; + IndexInput binarizedScoreDataInput = null; + boolean success = false; + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + try { + FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + DocsWithFieldSet docsWithField = writeBinarizedVectorAndQueryData( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + floatVectorValues, + centroid, + quantizer + ); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + binarizedDataInput = segmentWriteState.directory.openInput(tempQuantizedVectorData.getName(), segmentWriteState.context); + binarizedVectorData.copyBytes(binarizedDataInput, binarizedDataInput.length() - CodecUtil.footerLength()); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + CodecUtil.retrieveChecksum(binarizedDataInput); + CodecUtil.writeFooter(tempScoreQuantizedVectorData); + IOUtils.close(tempScoreQuantizedVectorData); + binarizedScoreDataInput = segmentWriteState.directory.openInput( + tempScoreQuantizedVectorData.getName(), + segmentWriteState.context + ); + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + cDotC, + docsWithField + ); + success = true; + final IndexInput finalBinarizedDataInput = binarizedDataInput; + final IndexInput finalBinarizedScoreDataInput = binarizedScoreDataInput; + OffHeapBinarizedVectorValues vectorValues = new OffHeapBinarizedVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + centroid, + cDotC, + 
quantizer, + fieldInfo.getVectorSimilarityFunction(), + vectorsScorer, + finalBinarizedDataInput + ); + RandomVectorScorerSupplier scorerSupplier = vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapBinarizedQueryVectorValues( + finalBinarizedScoreDataInput, + fieldInfo.getVectorDimension(), + docsWithField.cardinality() + ), + vectorValues + ); + return new BinarizedCloseableRandomVectorScorerSupplier(scorerSupplier, vectorValues, () -> { + IOUtils.close(finalBinarizedDataInput, finalBinarizedScoreDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorData.getName(), + tempScoreQuantizedVectorData.getName() + ); + }); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + binarizedDataInput, + binarizedScoreDataInput + ); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorData.getName(), + tempScoreQuantizedVectorData.getName() + ); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, binarizedVectorData, rawVectorDelegate); + } + + static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof ES818BinaryQuantizedVectorsReader reader) { + return reader.getCentroid(fieldName); + } + return null; + } + + static int mergeAndRecalculateCentroids(MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException { + boolean recalculate = false; + int totalVectorCount = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) { + continue; + } + float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name); + int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size(); + if (vectorCount == 0) { + continue; + } + totalVectorCount += vectorCount; + // If there aren't centroids, or previously clustered with more than one cluster + // or if there are deleted docs, we must recalculate the centroid + if (centroid == null || mergeState.liveDocs[i] != null) { + recalculate = true; + break; + } + for (int j = 0; j < centroid.length; j++) { + mergedCentroid[j] += centroid[j] * vectorCount; + } + } + if (recalculate) { + return calculateCentroid(mergeState, fieldInfo, mergedCentroid); + } else { + for (int j = 0; j < mergedCentroid.length; j++) { + mergedCentroid[j] = mergedCentroid[j] / totalVectorCount; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(mergedCentroid); + } + return totalVectorCount; + } + } + + static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) throws IOException { + assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); + // clear out the centroid + Arrays.fill(centroid, 0); + int count = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null) continue; + FloatVectorValues vectorValues = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + continue; + } + 
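+            // Accumulate a component-wise running sum across all readers; it becomes a mean
+            // (re-normalized for COSINE) only after every live vector has been visited.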
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
+            for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
+                ++count;
+                float[] vector = vectorValues.vectorValue(iterator.index());
+                // TODO Panama sum
+                for (int j = 0; j < vector.length; j++) {
+                    centroid[j] += vector[j];
+                }
+            }
+        }
+        if (count == 0) {
+            return count;
+        }
+        // TODO Panama div
+        for (int i = 0; i < centroid.length; i++) {
+            centroid[i] /= count;
+        }
+        if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+            VectorUtil.l2normalize(centroid);
+        }
+        return count;
+    }
+
+    @Override
+    public long ramBytesUsed() {
+        long total = SHALLOW_RAM_BYTES_USED;
+        for (FieldWriter field : fields) {
+            // the field tracks the delegate field usage
+            total += field.ramBytesUsed();
+        }
+        return total;
+    }
+
+    static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
+        private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
+        private final FieldInfo fieldInfo;
+        private boolean finished;
+        private final FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter;
+        private final float[] dimensionSums;
+        private final FloatArrayList magnitudes = new FloatArrayList();
+
+        FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]> flatFieldVectorsWriter) {
+            this.fieldInfo = fieldInfo;
+            this.flatFieldVectorsWriter = flatFieldVectorsWriter;
+            this.dimensionSums = new float[fieldInfo.getVectorDimension()];
+        }
+
+        @Override
+        public List<float[]> getVectors() {
+            return flatFieldVectorsWriter.getVectors();
+        }
+
+        public void normalizeVectors() {
+            for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) {
+                float[] vector = flatFieldVectorsWriter.getVectors().get(i);
+                float magnitude = magnitudes.get(i);
+                for (int j = 0; j < vector.length; j++) {
+                    vector[j] /= magnitude;
+                }
+            }
+        }
+
+        @Override
+        public DocsWithFieldSet getDocsWithFieldSet() {
+            return flatFieldVectorsWriter.getDocsWithFieldSet();
+        }
+
+        @Override
+        public void finish() throws IOException {
+            if (finished) {
+                return;
+            }
+            assert flatFieldVectorsWriter.isFinished();
+            finished = true;
+        }
+
+        @Override
+        public boolean isFinished() {
+            return finished && flatFieldVectorsWriter.isFinished();
+        }
+
+        @Override
+        public void addValue(int docID, float[] vectorValue) throws IOException {
+            flatFieldVectorsWriter.addValue(docID, vectorValue);
+            if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
+                float dp = VectorUtil.dotProduct(vectorValue, vectorValue);
+                float divisor = (float) Math.sqrt(dp);
+                magnitudes.add(divisor);
+                for (int i = 0; i < vectorValue.length; i++) {
+                    dimensionSums[i] += (vectorValue[i] / divisor);
+                }
+            } else {
+                for (int i = 0; i < vectorValue.length; i++) {
+                    dimensionSums[i] += vectorValue[i];
+                }
+            }
+        }
+
+        @Override
+        public float[] copyValue(float[] vectorValue) {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public long ramBytesUsed() {
+            long size = SHALLOW_SIZE;
+            size += flatFieldVectorsWriter.ramBytesUsed();
+            size += magnitudes.ramBytesUsed();
+            return size;
+        }
+    }
+
+    // When accessing the vectorValue method, targetOrd refers to a row ordinal.
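+    // On disk each query entry stores B_QUERY (4) times the single-bit packed vector length,
+    // followed by three floats (lower interval, upper interval, additional correction) and one
+    // short (the quantized component sum); byteSize in the constructor below reflects this.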
+ static class OffHeapBinarizedQueryVectorValues { + private final IndexInput slice; + private final int dimension; + private final int size; + protected final byte[] binaryValue; + protected final ByteBuffer byteBuffer; + private final int byteSize; + protected final float[] correctiveValues; + private int lastOrd = -1; + private int quantizedComponentSum; + + OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) { + this.slice = data; + this.dimension = dimension; + this.size = size; + // 4x the quantized binary dimensions + int binaryDimensions = (BQVectorUtils.discretize(dimension, 64) / 8) * BQSpaceUtils.B_QUERY; + this.byteBuffer = ByteBuffer.allocate(binaryDimensions); + this.binaryValue = byteBuffer.array(); + // + 1 for the quantized sum + this.correctiveValues = new float[3]; + this.byteSize = binaryDimensions + Float.BYTES * 3 + Short.BYTES; + } + + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + vectorValue(targetOrd); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + + public int size() { + return size; + } + + public int dimension() { + return dimension; + } + + public OffHeapBinarizedQueryVectorValues copy() throws IOException { + return new OffHeapBinarizedQueryVectorValues(slice.clone(), dimension, size); + } + + public IndexInput getSlice() { + return slice; + } + + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(binaryValue, 0, binaryValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + lastOrd = targetOrd; + return binaryValue; + } + } + + static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final byte[] binarized; + private final byte[] initQuantized; + private final float[] centroid; + private final FloatVectorValues values; + private final OptimizedScalarQuantizer quantizer; + + private int lastOrd = -1; + + BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) { + this.values = delegate; + this.quantizer = quantizer; + this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8]; + this.initQuantized = new byte[delegate.dimension()]; + this.centroid = centroid; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd + ); + } + return corrections; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + binarize(ord); + lastOrd = ord; + } + return binarized; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public 
int size() { + return values.size(); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public BinarizedByteVectorValues copy() throws IOException { + return new BinarizedFloatVectorValues(values.copy(), quantizer, centroid); + } + + private void binarize(int ord) throws IOException { + corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, (byte) 1, centroid); + BQVectorUtils.packAsBinary(initQuantized, binarized); + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + } + + static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + private final RandomVectorScorerSupplier supplier; + private final KnnVectorValues vectorValues; + private final Closeable onClose; + + BinarizedCloseableRandomVectorScorerSupplier(RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) { + this.supplier = supplier; + this.onClose = onClose; + this.vectorValues = vectorValues; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + return supplier.scorer(ord); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return vectorValues.size(); + } + } + + static final class NormalizedFloatVectorValues extends FloatVectorValues { + private final FloatVectorValues values; + private final float[] normalizedVector; + + NormalizedFloatVectorValues(FloatVectorValues values) { + this.values = values; + this.normalizedVector = new float[values.dimension()]; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java new file mode 100644 index 0000000000000..56942017c3cef --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java @@ -0,0 +1,145 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.util.hnsw.HnswGraph; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +public class ES818HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat { + + public static final String NAME = "ES818HnswBinaryQuantizedVectorsFormat"; + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. + */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} + * for details. + */ + private final int beamWidth; + + /** The format for storing, reading, merging vectors on disk */ + private static final FlatVectorsFormat flatVectorsFormat = new ES818BinaryQuantizedVectorsFormat(); + + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + /** Constructs a format using default graph construction parameters */ + public ES818HnswBinaryQuantizedVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters and scalar quantization. 
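+ * For instance, {@code new ES818HnswBinaryQuantizedVectorsFormat(16, 100, 1, null)} matches
+ * the no-argument constructor, assuming Lucene's stock defaults of maxConn=16 and
+ * beamWidth=100.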
+ * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + + @Override + public String toString() { + return "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java new file mode 100644 index 0000000000000..72333169b39b5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java @@ -0,0 +1,371 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; +import java.nio.ByteBuffer; + +/** Binarized vector values loaded from off-heap */ +abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues { + + final int dimension; + final int size; + final int numBytes; + final VectorSimilarityFunction similarityFunction; + final FlatVectorsScorer vectorsScorer; + + final IndexInput slice; + final byte[] binaryValue; + final ByteBuffer byteBuffer; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final OptimizedScalarQuantizer binaryQuantizer; + final float[] centroid; + final float centroidDp; + private final int discretizedDimensions; + + OffHeapBinarizedVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) { + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.numBytes = BQVectorUtils.discretize(dimension, 64) / 8; + this.correctiveValues = new float[3]; + this.byteSize = numBytes + (Float.BYTES * 3) + Short.BYTES; + this.byteBuffer = ByteBuffer.allocate(numBytes); + this.binaryValue = byteBuffer.array(); + this.binaryQuantizer = quantizer; + this.discretizedDimensions = BQVectorUtils.discretize(dimension, 64); + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + lastOrd = targetOrd; + return binaryValue; + } + + @Override + public int discretizedDimensions() { + return discretizedDimensions; + } + + @Override + public float getCentroidDP() { + return centroidDp; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + slice.seek(((long) targetOrd * byteSize) + numBytes); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + + @Override + public 
OptimizedScalarQuantizer getQuantizer() { + return binaryQuantizer; + } + + @Override + public float[] getCentroid() { + return centroid; + } + + @Override + public int getVectorByteLength() { + return numBytes; + } + + static OffHeapBinarizedVectorValues load( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + OptimizedScalarQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + float[] centroid, + float centroidDp, + long quantizedVectorDataOffset, + long quantizedVectorDataLength, + IndexInput vectorData + ) throws IOException { + if (configuration.isEmpty()) { + return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); + } + assert centroid != null; + IndexInput bytesSlice = vectorData.slice("quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); + if (configuration.isDense()) { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + bytesSlice + ); + } else { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice + ); + } + } + + /** Dense off-heap binarized vector values */ + static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + DenseOffHeapVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) { + super(dimension, size, centroid, centroidDp, binaryQuantizer, similarityFunction, vectorsScorer, slice); + } + + @Override + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice.clone() + ); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + } + + /** Sparse off-heap binarized vector values */ + private static class SparseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn is used to init a new IndexedDISI for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer binaryQuantizer, + IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) throws IOException { + super(dimension, size, centroid, centroidDp, binaryQuantizer, similarityFunction, vectorsScorer, slice); + this.configuration
= configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + dataIn, + similarityFunction, + vectorsScorer, + slice.clone() + ); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapBinarizedVectorValues { + EmptyOffHeapVectorValues(int dimension, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer) { + super(dimension, 0, null, Float.NaN, null, similarityFunction, vectorsScorer, null); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public DenseOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] target) { + return null; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java new file mode 100644 index 0000000000000..d5ed38cb5a0e1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java @@ -0,0 +1,246 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; + +class OptimizedScalarQuantizer { + // The initial interval is set to the minimum MSE grid for each number of bits + // these starting points are derived from the optimal MSE grid for a uniform distribution + static final float[][] MINIMUM_MSE_GRID = new float[][] { + { -0.798f, 0.798f }, + { -1.493f, 1.493f }, + { -2.051f, 2.051f }, + { -2.514f, 2.514f }, + { -2.916f, 2.916f }, + { -3.278f, 3.278f }, + { -3.611f, 3.611f }, + { -3.922f, 3.922f } }; + private static final float DEFAULT_LAMBDA = 0.1f; + private static final int DEFAULT_ITERS = 5; + private final VectorSimilarityFunction similarityFunction; + private final float lambda; + private final int iters; + + OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) { + this.similarityFunction = similarityFunction; + this.lambda = lambda; + this.iters = iters; + } + + OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) { + this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS); + } + + public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {} + + public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destinations, byte[] bits, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert bits.length == destinations.length; + float[] intervalScratch = new float[2]; + double vecMean = 0; + double vecVar = 0; + float norm2 = 0; + float centroidDot = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + for (int i = 0; i < vector.length; ++i) { + if (similarityFunction != EUCLIDEAN) { + centroidDot += vector[i] * centroid[i]; + } + vector[i] = vector[i] - centroid[i]; + min = Math.min(min, vector[i]); + max = Math.max(max, vector[i]); + norm2 += (vector[i] * vector[i]); + double delta = vector[i] - vecMean; + vecMean += delta / (i + 1); + vecVar += delta * (vector[i] - vecMean); + } + vecVar /= vector.length; + double vecStd = Math.sqrt(vecVar); + QuantizationResult[] results = new QuantizationResult[bits.length]; + for (int i = 0; i < bits.length; ++i) { + assert bits[i] > 0 && bits[i] <= 8; + int points = (1 << bits[i]); + // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds + intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][0] + vecMean) * vecStd, min, max); + intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][1] + vecMean) * vecStd, min, max); + optimizeIntervals(intervalScratch, vector, norm2, points); + float nSteps = ((1 << bits[i]) - 1); + float a = intervalScratch[0]; + float b = intervalScratch[1]; + float step = (b - a) / nSteps; + int sumQuery = 0; + // Now we have the optimized intervals, quantize the vector + for (int h = 0; h < vector.length; h++) { + float xi = (float) clamp(vector[h], a, b); + int assignment = Math.round((xi - a) / step); + sumQuery += assignment; + destinations[i][h] = (byte) assignment; + } + results[i] = new QuantizationResult( + intervalScratch[0], + intervalScratch[1], + similarityFunction == EUCLIDEAN ? 
norm2 : centroidDot, + sumQuery + ); + } + return results; + } + + public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert vector.length <= destination.length; + assert bits > 0 && bits <= 8; + float[] intervalScratch = new float[2]; + int points = 1 << bits; + double vecMean = 0; + double vecVar = 0; + float norm2 = 0; + float centroidDot = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + for (int i = 0; i < vector.length; ++i) { + if (similarityFunction != EUCLIDEAN) { + centroidDot += vector[i] * centroid[i]; + } + vector[i] = vector[i] - centroid[i]; + min = Math.min(min, vector[i]); + max = Math.max(max, vector[i]); + norm2 += (vector[i] * vector[i]); + double delta = vector[i] - vecMean; + vecMean += delta / (i + 1); + vecVar += delta * (vector[i] - vecMean); + } + vecVar /= vector.length; + double vecStd = Math.sqrt(vecVar); + // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds + intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][0] + vecMean) * vecStd, min, max); + intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][1] + vecMean) * vecStd, min, max); + optimizeIntervals(intervalScratch, vector, norm2, points); + float nSteps = ((1 << bits) - 1); + // Now we have the optimized intervals, quantize the vector + float a = intervalScratch[0]; + float b = intervalScratch[1]; + float step = (b - a) / nSteps; + int sumQuery = 0; + for (int h = 0; h < vector.length; h++) { + float xi = (float) clamp(vector[h], a, b); + int assignment = Math.round((xi - a) / step); + sumQuery += assignment; + destination[h] = (byte) assignment; + } + return new QuantizationResult( + intervalScratch[0], + intervalScratch[1], + similarityFunction == EUCLIDEAN ? norm2 : centroidDot, + sumQuery + ); + } + + /** + * Compute the loss of the vector given the interval. Effectively, we are computing the MSE of a dequantized vector with the raw + * vector. + * @param vector raw vector + * @param interval interval to quantize the vector + * @param points number of quantization points + * @param norm2 squared norm of the vector + * @return the loss + */ + private double loss(float[] vector, float[] interval, int points, float norm2) { + double a = interval[0]; + double b = interval[1]; + double step = ((b - a) / (points - 1.0F)); + double stepInv = 1.0 / step; + double xe = 0.0; + double e = 0.0; + for (double xi : vector) { + // this is quantizing and then dequantizing the vector + double xiq = (a + step * Math.round((clamp(xi, a, b) - a) * stepInv)); + // how much does the de-quantized value differ from the original value + xe += xi * (xi - xiq); + e += (xi - xiq) * (xi - xiq); + } + return (1.0 - lambda) * xe * xe / norm2 + lambda * e; + } + + /** + * Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization + * loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the + * loss increases. 
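+ * <p>Concretely, each iteration freezes the current grid assignments s_i = k_i / (points - 1), accumulates the sums daa, dab, dbb, dax and dbx over the vector, and then solves the resulting 2x2 linear system for the new endpoints (a, b) in closed form via Cramer's rule; the system's coefficients blend the anisotropic term (weighted by (1 - lambda) / norm2) with the plain MSE term (weighted by lambda).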
+ * @param initInterval initial interval; the optimized interval will be stored here + * @param vector raw vector + * @param norm2 squared norm of the vector + * @param points number of quantization points + */ + private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) { + double initialLoss = loss(vector, initInterval, points, norm2); + final float scale = (1.0f - lambda) / norm2; + if (Float.isFinite(scale) == false) { + return; + } + for (int i = 0; i < iters; ++i) { + float a = initInterval[0]; + float b = initInterval[1]; + float stepInv = (points - 1.0f) / (b - a); + // calculate the grid points for coordinate descent + double daa = 0; + double dab = 0; + double dbb = 0; + double dax = 0; + double dbx = 0; + for (float xi : vector) { + float k = Math.round((clamp(xi, a, b) - a) * stepInv); + float s = k / (points - 1); + daa += (1.0 - s) * (1.0 - s); + dab += (1.0 - s) * s; + dbb += s * s; + dax += xi * (1.0 - s); + dbx += xi * s; + } + double m0 = scale * dax * dax + lambda * daa; + double m1 = scale * dax * dbx + lambda * dab; + double m2 = scale * dbx * dbx + lambda * dbb; + // it's possible that the determinant is 0, in which case we can't update the interval + double det = m0 * m2 - m1 * m1; + if (det == 0) { + return; + } + float aOpt = (float) ((m2 * dax - m1 * dbx) / det); + float bOpt = (float) ((m0 * dbx - m1 * dax) / det); + // If there is no change in the interval, we can stop + if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) { + return; + } + double newLoss = loss(vector, new float[] { aOpt, bOpt }, points, norm2); + // If the new loss is worse, don't update the interval and exit + // This optimization, unlike k-means, does not always converge to a better loss + // So exit if we are getting worse + if (newLoss > initialLoss) { + return; + } + // Update the interval and go again + initInterval[0] = aOpt; + initInterval[1] = bOpt; + initialLoss = newLoss; + } + } + + private static double clamp(double x, double a, double b) { + return Math.min(Math.max(x, a), b); + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 0a6a24f727572..d780faad96f2d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -46,8 +46,8 @@ import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat; -import org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.ArraySourceValueFetcher; @@ -1788,7 +1788,7 @@ static class BBQHnswIndexOptions extends IndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT; - return new ES816HnswBinaryQuantizedVectorsFormat(m,
efConstruction); + return new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction); } @Override @@ -1836,7 +1836,7 @@ static class BBQFlatIndexOptions extends IndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT; - return new ES816BinaryQuantizedVectorsFormat(); + return new ES818BinaryQuantizedVectorsFormat(); } @Override diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 794b30aa5aab2..57980321bdc3d 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -44,6 +44,7 @@ private SearchCapabilities() {} private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim_with_bugfix"; private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs"; + private static final String OPTIMIZED_SCALAR_QUANTIZATION_BBQ = "optimized_scalar_quantization_bbq"; public static final Set CAPABILITIES; static { @@ -55,6 +56,7 @@ private SearchCapabilities() {} capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER); capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT); capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS); + capabilities.add(OPTIMIZED_SCALAR_QUANTIZATION_BBQ); if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) { capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER); capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS); diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 389555e60b43b..cef8d09980814 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -5,3 +5,5 @@ org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java index 9f9114c70b6db..270ad54e9a962 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java @@ -38,6 +38,32 @@ public static int popcount(byte[] a, int aOffset, byte[] b, int length) { private static float DELTA = Float.MIN_VALUE; + public void testPackAsBinary() { + // 5 bits + byte[] toPack = new byte[] { 1, 1, 0, 0, 1 }; + byte[] packed = new byte[1]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001000 }, packed); + + // 8 bits + toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0 }; + packed = new byte[1]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001010 }, packed); + + // 10 bits + toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 
0, 1, 1 }; + packed = new byte[2]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11000000 }, packed); + + // 16 bits + toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 }; + packed = new byte[2]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11100110 }, packed); + } + public void testPadFloat() { assertArrayEquals(new float[] { 1, 2, 3, 4 }, BQVectorUtils.pad(new float[] { 1, 2, 3, 4 }, 4), DELTA); assertArrayEquals(new float[] { 1, 2, 3, 4 }, BQVectorUtils.pad(new float[] { 1, 2, 3, 4 }, 3), DELTA); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java new file mode 100644 index 0000000000000..0bebe16f468ce --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java @@ -0,0 +1,256 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es816; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.simdvec.ESVectorUtil; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Vector scorer over binarized vector values */ +class ES816BinaryFlatRWVectorsScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + + ES816BinaryFlatRWVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues) { + throw new UnsupportedOperationException( + "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format" + ); + } + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + float[] target + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) { + BinaryQuantizer quantizer = binarizedVectors.getQuantizer(); + float[] centroid = binarizedVectors.getCentroid(); + // FIXME: precompute this once? 
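+ // Dimensions are rounded up to a multiple of 64 so packed bit codes fill whole 64-bit words: e.g. 100 dims + // discretize to 128, i.e. 128 / 8 = 16 bytes for a 1-bit doc code, and B_QUERY (4 bits per query component) + // times that for the quantized query allocated below.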
+ int discretizedDimensions = BQVectorUtils.discretize(target.length, 64); + if (similarityFunction == COSINE) { + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + VectorUtil.l2normalize(copy); + target = copy; + } + byte[] quantized = new byte[BQSpaceUtils.B_QUERY * discretizedDimensions / 8]; + BinaryQuantizer.QueryFactors factors = quantizer.quantizeForQuery(target, quantized, centroid); + BinaryQueryVector queryVector = new BinaryQueryVector(quantized, factors); + return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction); + } + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + byte[] target + ) throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, + BinarizedByteVectorValues targetVectors + ) { + return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction); + } + + @Override + public String toString() { + return "ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + } + + /** Vector scorer supplier over binarized vector values */ + static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + BinarizedRandomVectorScorerSupplier( + ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + this.queryVectors = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] vector = queryVectors.vectorValue(ord); + int quantizedSum = queryVectors.sumQuantizedValues(ord); + float distanceToCentroid = queryVectors.getCentroidDistance(ord); + float lower = queryVectors.getLower(ord); + float width = queryVectors.getWidth(ord); + float normVmC = 0f; + float vDotC = 0f; + if (similarityFunction != EUCLIDEAN) { + normVmC = queryVectors.getNormVmC(ord); + vDotC = queryVectors.getVDotC(ord); + } + BinaryQueryVector binaryQueryVector = new BinaryQueryVector( + vector, + new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC) + ); + return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); + } + } + + /** A binarized query representing its quantized form along with factors */ + record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {} + + /** Vector scorer over binarized vector values */ + static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + private final BinaryQueryVector queryVector; + private final 
BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + private final float sqrtDimensions; + private final float maxX1; + + BinarizedRandomVectorScorer( + BinaryQueryVector queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + super(targetVectors); + this.queryVector = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + // FIXME: precompute this once? + this.sqrtDimensions = targetVectors.sqrtDimensions(); + this.maxX1 = targetVectors.maxX1(); + } + + @Override + public float score(int targetOrd) throws IOException { + byte[] quantizedQuery = queryVector.vector(); + int quantizedSum = queryVector.factors().quantizedSum(); + float lower = queryVector.factors().lower(); + float width = queryVector.factors().width(); + float distanceToCentroid = queryVector.factors().distToC(); + if (similarityFunction == EUCLIDEAN) { + return euclideanScore(targetOrd, sqrtDimensions, quantizedQuery, distanceToCentroid, lower, quantizedSum, width); + } + + float vmC = queryVector.factors().normVmC(); + float vDotC = queryVector.factors().vDotC(); + float cDotC = targetVectors.getCentroidDP(); + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + float ooq = targetVectors.getOOQ(targetOrd); + float normOC = targetVectors.getNormOC(targetOrd); + float oDotC = targetVectors.getODotC(targetOrd); + + float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + + // FIXME: pre-compute these only once for each target vector + // ... pull this out or use a similar cache mechanism as we do in score + float xbSum = (float) BQVectorUtils.popcount(binaryCode); + final float dist; + // If ||o-c|| == 0, it's ok to throw the rest of the equation away + // and simply use `oDotC + vDotC - cDotC` as centroid == doc vector + if (normOC == 0 || ooq == 0) { + dist = oDotC + vDotC - cDotC; + } else { + // If ||o-c|| != 0, we should assume that `ooq` is finite + assert Float.isFinite(ooq); + float estimatedDot = (2 * width / sqrtDimensions * qcDist + 2 * lower / sqrtDimensions * xbSum - width / sqrtDimensions + * quantizedSum - sqrtDimensions * lower) / ooq; + dist = vmC * normOC * estimatedDot + oDotC + vDotC - cDotC; + } + assert Float.isFinite(dist); + + float ooqSqr = (float) Math.pow(ooq, 2); + float errorBound = (float) (vmC * normOC * (maxX1 * Math.sqrt((1 - ooqSqr) / ooqSqr))); + float score = Float.isFinite(errorBound) ? dist - errorBound : dist; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + + private float euclideanScore( + int targetOrd, + float sqrtDimensions, + byte[] quantizedQuery, + float distanceToCentroid, + float lower, + int quantizedSum, + float width + ) throws IOException { + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + + // FIXME: pre-compute these only once for each target vector + // ... not sure how to enumerate the target ordinals, but that's what we did in the PoC + float targetDistToC = targetVectors.getCentroidDistance(targetOrd); + float x0 = targetVectors.getVectorMagnitude(targetOrd); + float sqrX = targetDistToC * targetDistToC; + double xX0 = targetDistToC / x0; + + // TODO maybe store?
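+ // xbSum is the popcount of the doc's binary code. The factors below fold the stored per-vector corrections + // (distance to centroid, x0 projection) into the RaBitQ-style distance estimate, which combines ||o-c||^2, + // ||q-c||^2 and the bitwise query/doc dot product (qcDist * 2 - quantizedSum).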
+ float xbSum = (float) BQVectorUtils.popcount(binaryCode); + float factorPPC = (float) (-2.0 / sqrtDimensions * xX0 * (xbSum * 2.0 - targetVectors.dimension())); + float factorIP = (float) (-2.0 / sqrtDimensions * xX0); + + long qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + float score = sqrX + distanceToCentroid + factorPPC * lower + (qcDist * 2 - quantizedSum) * factorIP * width; + float projectionDist = (float) Math.sqrt(xX0 * xX0 - targetDistToC * targetDistToC); + float error = 2.0f * maxX1 * projectionDist; + float y = (float) Math.sqrt(distanceToCentroid); + float errorBound = y * error; + if (Float.isFinite(errorBound)) { + score = score + errorBound; + } + return Math.max(1 / (1f + score), 0); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java index a75b9bc6064d1..ffe007be9799d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java @@ -59,7 +59,7 @@ public void testScore() throws IOException { short quantizedSum = (short) random().nextInt(0, 4097); float normVmC = random().nextFloat(-1000f, 1000f); float vDotC = random().nextFloat(-1000f, 1000f); - ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector( + ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector( vector, new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, normVmC, vDotC) ); @@ -134,7 +134,7 @@ public int dimension() { } }; - ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer( queryVector, targetVectors, similarityFunction @@ -217,7 +217,7 @@ public void testScoreEuclidean() throws IOException { float vl = -57.883f; float width = 9.972266f; short quantizedSum = 795; - ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector( + ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector( vector, new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, 0f, 0f) ); @@ -420,7 +420,7 @@ public int dimension() { VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN; - ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer( queryVector, targetVectors, similarityFunction @@ -824,7 +824,7 @@ public void testScoreMIP() throws IOException { float normVmC = 9.766797f; float vDotC = 133.56123f; float cDotC = 132.20227f; - ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector( + ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector( vector, new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, normVmC, vDotC) ); @@ 
-1768,7 +1768,7 @@ public int dimension() { VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; - ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer( queryVector, targetVectors, similarityFunction diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java new file mode 100644 index 0000000000000..c54903a94b54f --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java @@ -0,0 +1,52 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es816; + +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +public class ES816BinaryQuantizedRWVectorsFormat extends ES816BinaryQuantizedVectorsFormat { + + private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final ES816BinaryFlatRWVectorsScorer scorer = new ES816BinaryFlatRWVectorsScorer( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + /** Creates a new instance with the default number of vectors per cluster. 
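+ * <p>(With this change the ES816 writer lives in the test tree, so this test-only subclass restores write support by pairing it with the ES816 scorer and a Lucene99 flat vectors writer.)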
*/ + public ES816BinaryQuantizedRWVectorsFormat() { + super(); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES816BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java index 681f615653d40..48ba566353f5d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java @@ -63,7 +63,7 @@ protected Codec getCodec() { return new Lucene100Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new ES816BinaryQuantizedVectorsFormat(); + return new ES816BinaryQuantizedRWVectorsFormat(); } }; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java similarity index 99% rename from server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java rename to server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java index 31ae977e81118..4d97235c5fae5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java @@ -77,7 +77,7 @@ class ES816BinaryQuantizedVectorsWriter extends FlatVectorsWriter { private final List fields = new ArrayList<>(); private final IndexOutput meta, binarizedVectorData; private final FlatVectorsWriter rawVectorDelegate; - private final ES816BinaryFlatVectorsScorer vectorsScorer; + private final ES816BinaryFlatRWVectorsScorer vectorsScorer; private boolean finished; /** @@ -86,7 +86,7 @@ class ES816BinaryQuantizedVectorsWriter extends FlatVectorsWriter { * @param vectorsScorer the scorer to use for scoring vectors */ protected ES816BinaryQuantizedVectorsWriter( - ES816BinaryFlatVectorsScorer vectorsScorer, + ES816BinaryFlatRWVectorsScorer vectorsScorer, FlatVectorsWriter rawVectorDelegate, SegmentWriteState state ) throws IOException { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java new file mode 100644 index 0000000000000..e9bace72b591c --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java @@ -0,0 +1,55 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es816; + +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; + +class ES816HnswBinaryQuantizedRWVectorsFormat extends ES816HnswBinaryQuantizedVectorsFormat { + + private static final FlatVectorsFormat flatVectorsFormat = new ES816BinaryQuantizedRWVectorsFormat(); + + /** Constructs a format using default graph construction parameters */ + ES816HnswBinaryQuantizedRWVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH); + } + + ES816HnswBinaryQuantizedRWVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + ES816HnswBinaryQuantizedRWVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(maxConn, beamWidth, numMergeWorkers, mergeExec); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java index a25fa2836ee34..03aa847f3a5d4 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java @@ -59,7 +59,7 @@ protected Codec getCodec() { return new Lucene100Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new ES816HnswBinaryQuantizedVectorsFormat(); + return new ES816HnswBinaryQuantizedRWVectorsFormat(); } }; } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java new file mode 100644 index 0000000000000..397cc472592b6 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java @@ -0,0 +1,181 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene100.Lucene100Codec; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; +import java.util.Locale; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Override + protected Codec getCodec() { + return new Lucene100Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new ES818BinaryQuantizedVectorsFormat(); + } + }; + } + + public void testSearch() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = 
searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES818BinaryQuantizedVectorsFormat(); + } + }; + String expectedPattern = "ES818BinaryQuantizedVectorsFormat(" + + "name=ES818BinaryQuantizedVectorsFormat, " + + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is not respected, as it is brute force search + } + + public void testQuantizedVectorsWriteAndRead() throws IOException { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + if (i % 101 == 0) { + w.commit(); + } + } + w.commit(); + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); + assertEquals(vectorValues.size(), numVectors); + BinarizedByteVectorValues qvectorValues = ((ES818BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues) + .getQuantizedVectorValues(); + float[] centroid = qvectorValues.getCentroid(); + assertEquals(centroid.length, dims); + + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); + byte[] quantizedVector = new byte[dims]; + byte[] expectedVector = new byte[BQVectorUtils.discretize(dims, 64) / 8]; + if (similarityFunction == VectorSimilarityFunction.COSINE) { + vectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); + } + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + OptimizedScalarQuantizer.QuantizationResult corrections = quantizer.scalarQuantize( + vectorValues.vectorValue(docIndexIterator.index()), + quantizedVector, + (byte) 1, + centroid + ); + BQVectorUtils.packAsBinary(quantizedVector, expectedVector); + assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); + assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index())); + } + } + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java new file mode 100644 index 
0000000000000..b6ae3199bb896 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java @@ -0,0 +1,132 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene100.Lucene100Codec; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.util.SameThreadExecutorService; +import org.elasticsearch.common.logging.LogConfigurator; + +import java.util.Arrays; +import java.util.Locale; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES818HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Override + protected Codec getCodec() { + return new Lucene100Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new ES818HnswBinaryQuantizedVectorsFormat(); + } + }; + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES818HnswBinaryQuantizedVectorsFormat(10, 20, 1, null); + } + }; + String expectedPattern = + "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=ES818BinaryQuantizedVectorsFormat(name=ES818BinaryQuantizedVectorsFormat," + + " flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))"; + + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, 
"Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + assert (vectorValues.size() == 1); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); + } + float[] randomVector = randomVector(vector.length); + float trueScore = similarityFunction.compare(vector, randomVector); + TopDocs td = r.searchNearestVectors("f", randomVector, 1, null, Integer.MAX_VALUE); + assertEquals(1, td.totalHits.value()); + assertTrue(td.scoreDocs[0].score >= 0); + // When it's the only vector in a segment, the score should be very close to the true score + assertEquals(trueScore, td.scoreDocs[0].score, 0.0001f); + } + } + } + } + + public void testLimits() { + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(-1, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(0, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 0)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, -1)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(512 + 1, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 3201)); + expectThrows( + IllegalArgumentException.class, + () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 100, 1, new SameThreadExecutorService()) + ); + } + + // Ensures that all expected vector similarity functions are translatable in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. + var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java new file mode 100644 index 0000000000000..e3e2d6caafe0e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.test.ESTestCase; + +import static org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer.MINIMUM_MSE_GRID; + +public class OptimizedScalarQuantizerTests extends ESTestCase { + + static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + + public void testAbusiveEdgeCases() { + // large zero array + for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) { + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + continue; + } + float[] vector = new float[4096]; + float[] centroid = new float[4096]; + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096]; + OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (byte[] destination : destinations) { + assertArrayEquals(new byte[4096], destination); + } + byte[] destination = new byte[4096]; + for (byte bit : ALL_BITS) { + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid); + assertValidResults(result); + assertArrayEquals(new byte[4096], destination); + } + } + + // single value array + for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) { + float[] vector = new float[] { randomFloat() }; + float[] centroid = new float[] { randomFloat() }; + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1]; + OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (int i = 0; i < ALL_BITS.length; i++) { + assertValidQuantizedRange(destinations[i], ALL_BITS[i]); + } + for (byte bit : ALL_BITS) { + vector = new float[] { randomFloat() }; + centroid = new float[] { randomFloat() }; + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + VectorUtil.l2normalize(centroid); + } + byte[] destination = new byte[1]; + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid); + assertValidResults(result); + assertValidQuantizedRange(destination, bit); + } + } + + } + + public void testMathematicalConsistency() { + int dims = randomIntBetween(1, 4096); + float[] vector = new float[dims]; + for (int i = 0; i < dims; ++i) { + vector[i] = randomFloat(); + } + float[] centroid = new float[dims]; + for (int i = 0; i < dims; ++i) { + centroid[i] = randomFloat(); + } + float[] copy = new float[dims]; + for (VectorSimilarityFunction 
vectorSimilarityFunction : VectorSimilarityFunction.values()) { + // copy the vector to avoid modifying it + System.arraycopy(vector, 0, copy, 0, dims); + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(copy); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims]; + OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (int i = 0; i < ALL_BITS.length; i++) { + assertValidQuantizedRange(destinations[i], ALL_BITS[i]); + } + for (byte bit : ALL_BITS) { + byte[] destination = new byte[dims]; + System.arraycopy(vector, 0, copy, 0, dims); + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(copy); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(copy, destination, bit, centroid); + assertValidResults(result); + assertValidQuantizedRange(destination, bit); + } + } + } + + static void assertValidQuantizedRange(byte[] quantized, byte bits) { + for (byte b : quantized) { + if (bits < 8) { + assertTrue(b >= 0); + } + assertTrue(b < 1 << bits); + } + } + + static void assertValidResults(OptimizedScalarQuantizer.QuantizationResult... results) { + for (OptimizedScalarQuantizer.QuantizationResult result : results) { + assertTrue(Float.isFinite(result.lowerInterval())); + assertTrue(Float.isFinite(result.upperInterval())); + assertTrue(result.lowerInterval() <= result.upperInterval()); + assertTrue(Float.isFinite(result.additionalCorrection())); + assertTrue(result.quantizedComponentSum() >= 0); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index de084cd4582e2..c043b9ffb381a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1970,13 +1970,13 @@ public void testKnnBBQHNSWVectorsFormat() throws IOException { assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class)); knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } - String expectedString = "ES816HnswBinaryQuantizedVectorsFormat(name=ES816HnswBinaryQuantizedVectorsFormat, maxConn=" + String expectedString = "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=" + m + ", beamWidth=" + efConstruction - + ", flatVectorFormat=ES816BinaryQuantizedVectorsFormat(" - + "name=ES816BinaryQuantizedVectorsFormat, " - + "flatVectorScorer=ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())))"; + + ", flatVectorFormat=ES818BinaryQuantizedVectorsFormat(" + + "name=ES818BinaryQuantizedVectorsFormat, " + + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())))"; assertEquals(expectedString, knnVectorsFormat.toString()); } diff --git a/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryDeleteThrottlingTests.java b/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryDeleteThrottlingTests.java index 
0b5999b614050..4facaa391ec24 100644 --- a/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryDeleteThrottlingTests.java +++ b/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryDeleteThrottlingTests.java @@ -35,7 +35,6 @@ import java.io.OutputStream; import java.util.Collection; import java.util.Collections; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; @@ -100,11 +99,6 @@ public BlobContainer blobContainer(BlobPath path) { return new ConcurrencyLimitingBlobContainer(delegate.blobContainer(path), activeIndices, countDownLatch); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - delegate.deleteBlobsIgnoringIfNotExists(purpose, blobNames); - } - @Override public void close() throws IOException { delegate.close(); diff --git a/test/external-modules/latency-simulating-directory/src/main/java/org/elasticsearch/test/simulatedlatencyrepo/LatencySimulatingBlobStoreRepository.java b/test/external-modules/latency-simulating-directory/src/main/java/org/elasticsearch/test/simulatedlatencyrepo/LatencySimulatingBlobStoreRepository.java index f360a6c012cb7..cd2812a95cfac 100644 --- a/test/external-modules/latency-simulating-directory/src/main/java/org/elasticsearch/test/simulatedlatencyrepo/LatencySimulatingBlobStoreRepository.java +++ b/test/external-modules/latency-simulating-directory/src/main/java/org/elasticsearch/test/simulatedlatencyrepo/LatencySimulatingBlobStoreRepository.java @@ -24,7 +24,6 @@ import java.io.IOException; import java.io.InputStream; -import java.util.Iterator; class LatencySimulatingBlobStoreRepository extends FsRepository { @@ -53,11 +52,6 @@ public BlobContainer blobContainer(BlobPath path) { return new LatencySimulatingBlobContainer(blobContainer); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - fsBlobStore.deleteBlobsIgnoringIfNotExists(purpose, blobNames); - } - @Override public void close() throws IOException { fsBlobStore.close(); diff --git a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java index b85ee970664e2..c982f36e5ccb3 100644 --- a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java @@ -48,7 +48,6 @@ import java.io.IOException; import java.io.InputStream; import java.nio.file.NoSuchFileException; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; @@ -70,7 +69,6 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; @@ -524,39 +522,6 @@ public void testIndicesDeletedFromRepository() throws Exception { assertAcked(clusterAdmin().prepareDeleteSnapshot(TEST_REQUEST_TIMEOUT, repoName, "test-snap2").get()); } - public void testBlobStoreBulkDeletion() throws Exception { - Map> expectedBlobsPerContainer = new 
HashMap<>(); - try (BlobStore store = newBlobStore()) { - List blobsToDelete = new ArrayList<>(); - int numberOfContainers = randomIntBetween(2, 5); - for (int i = 0; i < numberOfContainers; i++) { - BlobPath containerPath = BlobPath.EMPTY.add(randomIdentifier()); - final BlobContainer container = store.blobContainer(containerPath); - int numberOfBlobsPerContainer = randomIntBetween(5, 10); - for (int j = 0; j < numberOfBlobsPerContainer; j++) { - byte[] bytes = randomBytes(randomInt(100)); - String blobName = randomAlphaOfLength(10); - container.writeBlob(randomPurpose(), blobName, new BytesArray(bytes), false); - if (randomBoolean()) { - blobsToDelete.add(containerPath.buildAsString() + blobName); - } else { - expectedBlobsPerContainer.computeIfAbsent(containerPath, unused -> new ArrayList<>()).add(blobName); - } - } - } - - store.deleteBlobsIgnoringIfNotExists(randomPurpose(), blobsToDelete.iterator()); - for (var containerEntry : expectedBlobsPerContainer.entrySet()) { - BlobContainer blobContainer = store.blobContainer(containerEntry.getKey()); - Map blobsInContainer = blobContainer.listBlobs(randomPurpose()); - for (String expectedBlob : containerEntry.getValue()) { - assertThat(blobsInContainer, hasKey(expectedBlob)); - } - blobContainer.delete(randomPurpose()); - } - } - } - public void testDanglingShardLevelBlobCleanup() throws Exception { final var repoName = createRepository(randomRepositoryName()); final var client = client(); diff --git a/test/framework/src/main/java/org/elasticsearch/snapshots/mockstore/BlobStoreWrapper.java b/test/framework/src/main/java/org/elasticsearch/snapshots/mockstore/BlobStoreWrapper.java index 5803c2a825671..54af75fc584d6 100644 --- a/test/framework/src/main/java/org/elasticsearch/snapshots/mockstore/BlobStoreWrapper.java +++ b/test/framework/src/main/java/org/elasticsearch/snapshots/mockstore/BlobStoreWrapper.java @@ -12,15 +12,13 @@ import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; import org.elasticsearch.common.blobstore.BlobStoreActionStats; -import org.elasticsearch.common.blobstore.OperationPurpose; import java.io.IOException; -import java.util.Iterator; import java.util.Map; public class BlobStoreWrapper implements BlobStore { - private BlobStore delegate; + private final BlobStore delegate; public BlobStoreWrapper(BlobStore delegate) { this.delegate = delegate; @@ -31,11 +29,6 @@ public BlobContainer blobContainer(BlobPath path) { return delegate.blobContainer(path); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - delegate.deleteBlobsIgnoringIfNotExists(purpose, blobNames); - } - @Override public void close() throws IOException { delegate.close(); diff --git a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java index 1a68ada60b6f1..d46639d700420 100644 --- a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java +++ b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java @@ -14,6 +14,8 @@ import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.logging.DeprecationCategory; +import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.ClusterSettings; import 
org.elasticsearch.common.settings.IndexScopedSettings; import org.elasticsearch.common.settings.Setting; @@ -23,6 +25,7 @@ import org.elasticsearch.common.unit.MemorySizeValue; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.core.UpdateForV10; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.indices.SystemIndexDescriptor; import org.elasticsearch.ingest.Processor; @@ -74,6 +77,8 @@ public class EnrichPlugin extends Plugin implements SystemIndexPlugin, IngestPlugin { + private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(EnrichPlugin.class); + static final Setting ENRICH_FETCH_SIZE_SETTING = Setting.intSetting( "enrich.fetch_size", 10000, @@ -126,9 +131,9 @@ public class EnrichPlugin extends Plugin implements SystemIndexPlugin, IngestPlu return String.valueOf(maxConcurrentRequests * maxLookupsPerRequest); }, val -> Setting.parseInt(val, 1, Integer.MAX_VALUE, QUEUE_CAPACITY_SETTING_NAME), Setting.Property.NodeScope); - public static final String CACHE_SIZE_SETTING_NAME = "enrich.cache.size"; + public static final String CACHE_SIZE_SETTING_NAME = "enrich.cache_size"; public static final Setting CACHE_SIZE = new Setting<>( - "enrich.cache.size", + CACHE_SIZE_SETTING_NAME, (String) null, (String s) -> FlatNumberOrByteSizeValue.parse( s, @@ -138,16 +143,59 @@ public class EnrichPlugin extends Plugin implements SystemIndexPlugin, IngestPlu Setting.Property.NodeScope ); + /** + * This setting solely exists because the original setting was accidentally renamed in + * https://github.com/elastic/elasticsearch/pull/111412. + */ + @UpdateForV10(owner = UpdateForV10.Owner.DATA_MANAGEMENT) + public static final String CACHE_SIZE_SETTING_BWC_NAME = "enrich.cache.size"; + public static final Setting CACHE_SIZE_BWC = new Setting<>( + CACHE_SIZE_SETTING_BWC_NAME, + (String) null, + (String s) -> FlatNumberOrByteSizeValue.parse( + s, + CACHE_SIZE_SETTING_BWC_NAME, + new FlatNumberOrByteSizeValue(ByteSizeValue.ofBytes((long) (0.01 * JvmInfo.jvmInfo().getConfiguredMaxHeapSize()))) + ), + Setting.Property.NodeScope, + Setting.Property.Deprecated + ); + private final Settings settings; private final EnrichCache enrichCache; + private final long maxCacheSize; public EnrichPlugin(final Settings settings) { this.settings = settings; - FlatNumberOrByteSizeValue maxSize = CACHE_SIZE.get(settings); + FlatNumberOrByteSizeValue maxSize; + if (settings.hasValue(CACHE_SIZE_SETTING_BWC_NAME)) { + if (settings.hasValue(CACHE_SIZE_SETTING_NAME)) { + throw new IllegalArgumentException( + Strings.format( + "Both [%s] and [%s] are set, please use [%s]", + CACHE_SIZE_SETTING_NAME, + CACHE_SIZE_SETTING_BWC_NAME, + CACHE_SIZE_SETTING_NAME + ) + ); + } + deprecationLogger.warn( + DeprecationCategory.SETTINGS, + "enrich_cache_size_name", + "The [{}] setting is deprecated and will be removed in a future version. 
Please use [{}] instead.", + CACHE_SIZE_SETTING_BWC_NAME, + CACHE_SIZE_SETTING_NAME + ); + maxSize = CACHE_SIZE_BWC.get(settings); + } else { + maxSize = CACHE_SIZE.get(settings); + } if (maxSize.byteSizeValue() != null) { this.enrichCache = new EnrichCache(maxSize.byteSizeValue()); + this.maxCacheSize = maxSize.byteSizeValue().getBytes(); } else { this.enrichCache = new EnrichCache(maxSize.flatNumber()); + this.maxCacheSize = maxSize.flatNumber(); } } @@ -286,6 +334,11 @@ public String getFeatureDescription() { return "Manages data related to Enrich policies"; } + // Visible for testing + long getMaxCacheSize() { + return maxCacheSize; + } + /** * A class that specifies either a flat (unit-less) number or a byte size value. */ diff --git a/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPluginTests.java b/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPluginTests.java new file mode 100644 index 0000000000000..07de0e0967448 --- /dev/null +++ b/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPluginTests.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.enrich; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +public class EnrichPluginTests extends ESTestCase { + + public void testConstructWithByteSize() { + final var size = randomNonNegativeInt(); + Settings settings = Settings.builder().put(EnrichPlugin.CACHE_SIZE_SETTING_NAME, size + "b").build(); + EnrichPlugin plugin = new EnrichPlugin(settings); + assertEquals(size, plugin.getMaxCacheSize()); + } + + public void testConstructWithFlatNumber() { + final var size = randomNonNegativeInt(); + Settings settings = Settings.builder().put(EnrichPlugin.CACHE_SIZE_SETTING_NAME, size).build(); + EnrichPlugin plugin = new EnrichPlugin(settings); + assertEquals(size, plugin.getMaxCacheSize()); + } + + public void testConstructWithByteSizeBwc() { + final var size = randomNonNegativeInt(); + Settings settings = Settings.builder().put(EnrichPlugin.CACHE_SIZE_SETTING_BWC_NAME, size + "b").build(); + EnrichPlugin plugin = new EnrichPlugin(settings); + assertEquals(size, plugin.getMaxCacheSize()); + } + + public void testConstructWithFlatNumberBwc() { + final var size = randomNonNegativeInt(); + Settings settings = Settings.builder().put(EnrichPlugin.CACHE_SIZE_SETTING_BWC_NAME, size).build(); + EnrichPlugin plugin = new EnrichPlugin(settings); + assertEquals(size, plugin.getMaxCacheSize()); + } + + public void testConstructWithBothSettings() { + Settings settings = Settings.builder() + .put(EnrichPlugin.CACHE_SIZE_SETTING_NAME, randomNonNegativeInt()) + .put(EnrichPlugin.CACHE_SIZE_SETTING_BWC_NAME, randomNonNegativeInt()) + .build(); + assertThrows(IllegalArgumentException.class, () -> new EnrichPlugin(settings)); + } + + @Override + protected List filteredWarnings() { + final var warnings = super.filteredWarnings(); + warnings.add("[enrich.cache.size] setting was deprecated in Elasticsearch and will be removed in a future release."); + warnings.add( + "The [enrich.cache.size] setting is deprecated and will be removed in a future version. Please use [enrich.cache_size] instead." 
+ ); + return warnings; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java index af7c706c807ec..b8908ee139c29 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java @@ -12,5 +12,7 @@ import java.util.List; public interface Chunker { - List chunk(String input, ChunkingSettings chunkingSettings); + record ChunkOffset(int start, int end) {} + + List chunk(String input, ChunkingSettings chunkingSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index c5897f32d6eb8..2aef54e56f4b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -68,7 +68,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El private final EmbeddingType embeddingType; private final ChunkingSettings chunkingSettings; - private List> chunkedInputs; + private List chunkedOffsets; private List>> floatResults; private List>> byteResults; private List>> sparseResults; @@ -109,7 +109,7 @@ public EmbeddingRequestChunker( } private void splitIntoBatchedRequests(List inputs) { - Function> chunkFunction; + Function> chunkFunction; if (chunkingSettings != null) { var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); chunkFunction = input -> chunker.chunk(input, chunkingSettings); @@ -118,7 +118,7 @@ private void splitIntoBatchedRequests(List inputs) { chunkFunction = input -> chunker.chunk(input, wordsPerChunk, chunkOverlap); } - chunkedInputs = new ArrayList<>(inputs.size()); + chunkedOffsets = new ArrayList<>(inputs.size()); switch (embeddingType) { case FLOAT -> floatResults = new ArrayList<>(inputs.size()); case BYTE -> byteResults = new ArrayList<>(inputs.size()); @@ -128,18 +128,19 @@ for (int i = 0; i < inputs.size(); i++) { var chunks = chunkFunction.apply(inputs.get(i)); - int numberOfSubBatches = addToBatches(chunks, i); + var offsetsAndInput = new ChunkOffsetsAndInput(chunks, inputs.get(i)); + int numberOfSubBatches = addToBatches(offsetsAndInput, i); // size the results array with the expected number of request/responses switch (embeddingType) { case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches)); case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches)); case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches)); } - chunkedInputs.add(chunks); + chunkedOffsets.add(offsetsAndInput); } } - private int addToBatches(List chunks, int inputIndex) { + private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) { BatchRequest lastBatch; if (batchedRequests.isEmpty()) { lastBatch = new BatchRequest(new ArrayList<>()); @@ -157,16 +158,24 @@ private int addToBatches(List chunks, int inputIndex) { if (freeSpace > 0) { // use any free space in the previous batch before creating new batches - int toAdd = Math.min(freeSpace, chunks.size()); - lastBatch.addSubBatch(new
SubBatch(chunks.subList(0, toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd))); + int toAdd = Math.min(freeSpace, chunk.offsets().size()); + lastBatch.addSubBatch( + new SubBatch( + new ChunkOffsetsAndInput(chunk.offsets().subList(0, toAdd), chunk.input()), + new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd) + ) + ); } int start = freeSpace; - while (start < chunks.size()) { - int toAdd = Math.min(maxNumberOfInputsPerBatch, chunks.size() - start); + while (start < chunk.offsets().size()) { + int toAdd = Math.min(maxNumberOfInputsPerBatch, chunk.offsets().size() - start); var batch = new BatchRequest(new ArrayList<>()); batch.addSubBatch( - new SubBatch(chunks.subList(start, start + toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)) + new SubBatch( + new ChunkOffsetsAndInput(chunk.offsets().subList(start, start + toAdd), chunk.input()), + new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd) + ) ); batchedRequests.add(batch); start += toAdd; @@ -333,8 +342,8 @@ public void onFailure(Exception e) { } private void sendResponse() { - var response = new ArrayList(chunkedInputs.size()); - for (int i = 0; i < chunkedInputs.size(); i++) { + var response = new ArrayList(chunkedOffsets.size()); + for (int i = 0; i < chunkedOffsets.size(); i++) { if (errors.get(i) != null) { response.add(errors.get(i)); } else { @@ -348,9 +357,9 @@ private void sendResponse() { private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) { return switch (embeddingType) { - case FLOAT -> mergeFloatResultsWithInputs(chunkedInputs.get(resultIndex), floatResults.get(resultIndex)); - case BYTE -> mergeByteResultsWithInputs(chunkedInputs.get(resultIndex), byteResults.get(resultIndex)); - case SPARSE -> mergeSparseResultsWithInputs(chunkedInputs.get(resultIndex), sparseResults.get(resultIndex)); + case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), floatResults.get(resultIndex)); + case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), byteResults.get(resultIndex)); + case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), sparseResults.get(resultIndex)); }; } @@ -428,7 +437,7 @@ public void addSubBatch(SubBatch sb) { } public List inputs() { - return subBatches.stream().flatMap(s -> s.requests().stream()).collect(Collectors.toList()); + return subBatches.stream().flatMap(s -> s.requests().toChunkText().stream()).collect(Collectors.toList()); } } @@ -441,9 +450,15 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener requests, SubBatchPositionsAndCount positions) { - public int size() { - return requests.size(); + record SubBatch(ChunkOffsetsAndInput requests, SubBatchPositionsAndCount positions) { + int size() { + return requests.offsets().size(); + } + } + + record ChunkOffsetsAndInput(List offsets, String input) { + List toChunkText() { + return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList()); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java index 5df940d6a3fba..b2d6c83b89211 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java @@ -34,7 +34,6 @@ public class SentenceBoundaryChunker implements Chunker { public SentenceBoundaryChunker() { sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT); wordIterator = BreakIterator.getWordInstance(Locale.ROOT); - } /** @@ -45,7 +44,7 @@ public SentenceBoundaryChunker() { * @return The input text chunked */ @Override - public List chunk(String input, ChunkingSettings chunkingSettings) { + public List chunk(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) { return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize, sentenceBoundaryChunkingSettings.sentenceOverlap > 0); } else { @@ -65,8 +64,8 @@ public List chunk(String input, ChunkingSettings chunkingSettings) { * @param maxNumberWordsPerChunk Maximum size of the chunk * @return The input text chunked */ - public List chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) { - var chunks = new ArrayList(); + public List chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) { + var chunks = new ArrayList(); sentenceIterator.setText(input); wordIterator.setText(input); @@ -91,7 +90,7 @@ public List chunk(String input, int maxNumberWordsPerChunk, boolean incl int nextChunkWordCount = wordsInSentenceCount; if (chunkWordCount > 0) { // add a new chunk containing all the input up to this sentence - chunks.add(input.substring(chunkStart, chunkEnd)); + chunks.add(new ChunkOffset(chunkStart, chunkEnd)); if (includePrecedingSentence) { if (wordsInPrecedingSentenceCount + wordsInSentenceCount > maxNumberWordsPerChunk) { @@ -127,12 +126,17 @@ public List chunk(String input, int maxNumberWordsPerChunk, boolean incl for (; i < sentenceSplits.size() - 1; i++) { // Because the substring was passed to splitLongSentence() // the returned positions need to be offset by chunkStart - chunks.add(input.substring(chunkStart + sentenceSplits.get(i).start(), chunkStart + sentenceSplits.get(i).end())); + chunks.add( + new ChunkOffset( + chunkStart + sentenceSplits.get(i).offsets().start(), + chunkStart + sentenceSplits.get(i).offsets().end() + ) + ); } // The final split is partially filled. // Set the next chunk start to the beginning of the // final split of the long sentence. 
- chunkStart = chunkStart + sentenceSplits.get(i).start(); // start pos needs to be offset by chunkStart + chunkStart = chunkStart + sentenceSplits.get(i).offsets().start(); // start pos needs to be offset by chunkStart chunkWordCount = sentenceSplits.get(i).wordCount(); } } else { @@ -151,7 +155,7 @@ public List chunk(String input, int maxNumberWordsPerChunk, boolean incl } if (chunkWordCount > 0) { - chunks.add(input.substring(chunkStart)); + chunks.add(new ChunkOffset(chunkStart, input.length())); } return chunks; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java index c9c752b9aabbc..b15e2134f4cf7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java @@ -15,6 +15,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; /** * Breaks text into smaller strings or chunks on Word boundaries. @@ -35,7 +36,7 @@ public WordBoundaryChunker() { wordIterator = BreakIterator.getWordInstance(Locale.ROOT); } - record ChunkPosition(int start, int end, int wordCount) {} + record ChunkPosition(ChunkOffset offsets, int wordCount) {} /** * Break the input text into small chunks as dictated @@ -45,7 +46,7 @@ record ChunkPosition(int start, int end, int wordCount) {} * @return List of chunked text */ @Override - public List chunk(String input, ChunkingSettings chunkingSettings) { + public List chunk(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings instanceof WordBoundaryChunkingSettings wordBoundaryChunkerSettings) { return chunk(input, wordBoundaryChunkerSettings.maxChunkSize, wordBoundaryChunkerSettings.overlap); } else { @@ -64,18 +65,9 @@ public List chunk(String input, ChunkingSettings chunkingSettings) { * Can be 0 but must be non-negative. 
* @return List of chunked text */ - public List chunk(String input, int chunkSize, int overlap) { - - if (input.isEmpty()) { - return List.of(""); - } - + public List chunk(String input, int chunkSize, int overlap) { var chunkPositions = chunkPositions(input, chunkSize, overlap); - var chunks = new ArrayList(chunkPositions.size()); - for (var pos : chunkPositions) { - chunks.add(input.substring(pos.start, pos.end)); - } - return chunks; + return chunkPositions.stream().map(ChunkPosition::offsets).collect(Collectors.toList()); } /** @@ -127,7 +119,7 @@ List chunkPositions(String input, int chunkSize, int overlap) { wordsSinceStartWindowWasMarked++; if (wordsInChunkCountIncludingOverlap >= chunkSize) { - chunkPositions.add(new ChunkPosition(windowStart, boundary, wordsInChunkCountIncludingOverlap)); + chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, boundary), wordsInChunkCountIncludingOverlap)); wordsInChunkCountIncludingOverlap = overlap; if (overlap == 0) { @@ -149,7 +141,7 @@ List chunkPositions(String input, int chunkSize, int overlap) { // if it ends on a boundary than the count should equal overlap in which case // we can ignore it, unless this is the first chunk in which case we want to add it if (wordsInChunkCountIncludingOverlap > overlap || chunkPositions.isEmpty()) { - chunkPositions.add(new ChunkPosition(windowStart, input.length(), wordsInChunkCountIncludingOverlap)); + chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, input.length()), wordsInChunkCountIncludingOverlap)); } return chunkPositions; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 4fdf254101d3e..a82d2f474ca4a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -62,7 +62,7 @@ public void testMultipleShortInputsAreSingleBatch() { var subBatches = batches.get(0).batch().subBatches(); for (int i = 0; i < inputs.size(); i++) { var subBatch = subBatches.get(i); - assertThat(subBatch.requests(), contains(inputs.get(i))); + assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i))); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(i, subBatch.positions().inputIndex()); assertEquals(1, subBatch.positions().embeddingCount()); @@ -102,7 +102,7 @@ public void testManyInputsMakeManyBatches() { var subBatches = batches.get(0).batch().subBatches(); for (int i = 0; i < batches.size(); i++) { var subBatch = subBatches.get(i); - assertThat(subBatch.requests(), contains(inputs.get(i))); + assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i))); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(inputIndex, subBatch.positions().inputIndex()); assertEquals(1, subBatch.positions().embeddingCount()); @@ -146,7 +146,7 @@ public void testChunkingSettingsProvided() { var subBatches = batches.get(0).batch().subBatches(); for (int i = 0; i < batches.size(); i++) { var subBatch = subBatches.get(i); - assertThat(subBatch.requests(), contains(inputs.get(i))); + assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i))); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(inputIndex, subBatch.positions().inputIndex()); assertEquals(1, 
subBatch.positions().embeddingCount()); @@ -184,17 +184,17 @@ public void testLongInputChunkedOverMultipleBatches() { assertEquals(0, subBatch.positions().inputIndex()); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(1, subBatch.positions().embeddingCount()); - assertThat(subBatch.requests(), contains("1st small")); + assertThat(subBatch.requests().toChunkText(), contains("1st small")); } { var subBatch = batch.subBatches().get(1); assertEquals(1, subBatch.positions().inputIndex()); // 2nd input assertEquals(0, subBatch.positions().chunkIndex()); // 1st part of the 2nd input assertEquals(4, subBatch.positions().embeddingCount()); // 4 chunks - assertThat(subBatch.requests().get(0), startsWith("passage_input0 ")); - assertThat(subBatch.requests().get(1), startsWith(" passage_input20 ")); - assertThat(subBatch.requests().get(2), startsWith(" passage_input40 ")); - assertThat(subBatch.requests().get(3), startsWith(" passage_input60 ")); + assertThat(subBatch.requests().toChunkText().get(0), startsWith("passage_input0 ")); + assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input20 ")); + assertThat(subBatch.requests().toChunkText().get(2), startsWith(" passage_input40 ")); + assertThat(subBatch.requests().toChunkText().get(3), startsWith(" passage_input60 ")); } } { @@ -207,22 +207,22 @@ public void testLongInputChunkedOverMultipleBatches() { assertEquals(1, subBatch.positions().inputIndex()); // 2nd input assertEquals(1, subBatch.positions().chunkIndex()); // 2nd part of the 2nd input assertEquals(2, subBatch.positions().embeddingCount()); - assertThat(subBatch.requests().get(0), startsWith(" passage_input80 ")); - assertThat(subBatch.requests().get(1), startsWith(" passage_input100 ")); + assertThat(subBatch.requests().toChunkText().get(0), startsWith(" passage_input80 ")); + assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input100 ")); } { var subBatch = batch.subBatches().get(1); assertEquals(2, subBatch.positions().inputIndex()); // 3rd input assertEquals(0, subBatch.positions().chunkIndex()); // 1st and only part assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk - assertThat(subBatch.requests(), contains("2nd small")); + assertThat(subBatch.requests().toChunkText(), contains("2nd small")); } { var subBatch = batch.subBatches().get(2); assertEquals(3, subBatch.positions().inputIndex()); // 4th input assertEquals(0, subBatch.positions().chunkIndex()); // 1st and only part assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk - assertThat(subBatch.requests(), contains("3rd small")); + assertThat(subBatch.requests().toChunkText(), contains("3rd small")); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java index afce8c57e0350..de943f7f57ab8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java @@ -15,7 +15,9 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT; import static org.hamcrest.Matchers.containsString; @@ -27,10 +29,24 @@ 
public class SentenceBoundaryChunkerTests extends ESTestCase { + /** + * Utility method for testing. + * Use the chunk functions that return offsets where possible + */ + private List textChunks( + SentenceBoundaryChunker chunker, + String input, + int maxNumberWordsPerChunk, + boolean includePrecedingSentence + ) { + var chunkPositions = chunker.chunk(input, maxNumberWordsPerChunk, includePrecedingSentence); + return chunkPositions.stream().map(offset -> input.substring(offset.start(), offset.end())).collect(Collectors.toList()); + } + public void testChunkSplitLargeChunkSizes() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks)); @@ -48,7 +64,7 @@ public void testChunkSplitLargeChunkSizes_withOverlap() { boolean overlap = true; for (int maxWordsPerChunk : new int[] { 70, 80, 100, 120, 150, 200 }) { var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, overlap); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, overlap); int[] overlaps = chunkOverlaps(sentenceSizes(TEST_TEXT), maxWordsPerChunk, overlap); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(overlaps.length)); @@ -107,7 +123,7 @@ public void testWithOverlap_SentencesFitInChunks() { } var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(sb.toString(), chunkSize, true); + var chunks = textChunks(chunker, sb.toString(), chunkSize, true); assertThat(chunks, hasSize(numChunks)); for (int i = 0; i < numChunks; i++) { assertThat("num sentences " + numSentences, chunks.get(i), startsWith("SStart" + sentenceStartIndexes[i])); @@ -128,10 +144,10 @@ private String makeSentence(int numWords, int sentenceIndex) { public void testChunk_ChunkSizeLargerThanText() { int maxWordsPerChunk = 500; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); assertEquals(chunks.get(0), TEST_TEXT); - chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true); + chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, true); assertEquals(chunks.get(0), TEST_TEXT); } @@ -142,7 +158,7 @@ public void testChunkSplit_SentencesLongerThanChunkSize() { for (int i = 0; i < chunkSizes.length; i++) { int maxWordsPerChunk = chunkSizes[i]; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(expectedNumberOFChunks[i])); for (var chunk : chunks) { @@ -171,7 +187,7 @@ public void testChunkSplit_SentencesLongerThanChunkSize_WithOverlap() { for (int i = 0; i < chunkSizes.length; i++) { int maxWordsPerChunk = chunkSizes[i]; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, true); assertThat(chunks.get(0), containsString("Word segmentation is the problem of dividing")); assertThat(chunks.get(chunks.size() - 1), containsString(", with solidification being a stronger norm.")); } @@ -190,7 +206,7 
@@ public void testShortLongShortSentences_WithOverlap() { } var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(sb.toString(), maxWordsPerChunk, true); + var chunks = textChunks(chunker, sb.toString(), maxWordsPerChunk, true); assertThat(chunks, hasSize(5)); assertTrue(chunks.get(0).trim().startsWith("SStart0")); // Entire sentence assertTrue(chunks.get(0).trim().endsWith(".")); // Entire sentence @@ -303,7 +319,7 @@ public void testChunkSplitLargeChunkSizesWithChunkingSettings() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); SentenceBoundaryChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(maxWordsPerChunk, 0); - var chunks = chunker.chunk(TEST_TEXT, chunkingSettings); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java index ef643a4b36fdc..2ef28f2cf2e77 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java @@ -14,6 +14,7 @@ import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; @@ -65,9 +66,22 @@ public class WordBoundaryChunkerTests extends ESTestCase { NUM_WORDS_IN_TEST_TEXT = wordCount; } + /** + * Utility method for testing. 
+ * Use the chunk functions that return offsets where possible + */ + List textChunks(WordBoundaryChunker chunker, String input, int chunkSize, int overlap) { + if (input.isEmpty()) { + return List.of(""); + } + + var chunkPositions = chunker.chunk(input, chunkSize, overlap); + return chunkPositions.stream().map(p -> input.substring(p.start(), p.end())).collect(Collectors.toList()); + } + public void testSingleSplit() { var chunker = new WordBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, 10_000, 0); + var chunks = textChunks(chunker, TEST_TEXT, 10_000, 0); assertThat(chunks, hasSize(1)); assertEquals(TEST_TEXT, chunks.get(0)); } @@ -168,11 +182,11 @@ public void testWindowSpanningWords() { } var whiteSpacedText = input.toString().stripTrailing(); - var chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 20, 10); + var chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 20, 10); assertChunkContents(chunks, numWords, 20, 10); - chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 10, 4); + chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 10, 4); assertChunkContents(chunks, numWords, 10, 4); - chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 15, 3); + chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 15, 3); assertChunkContents(chunks, numWords, 15, 3); } @@ -217,28 +231,28 @@ public void testWindowSpanning_TextShorterThanWindow() { } public void testEmptyString() { - var chunks = new WordBoundaryChunker().chunk("", 10, 5); - assertThat(chunks, contains("")); + var chunks = textChunks(new WordBoundaryChunker(), "", 10, 5); + assertThat(chunks.toString(), chunks, contains("")); } public void testWhitespace() { - var chunks = new WordBoundaryChunker().chunk(" ", 10, 5); + var chunks = textChunks(new WordBoundaryChunker(), " ", 10, 5); assertThat(chunks, contains(" ")); } public void testPunctuation() { int chunkSize = 1; - var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0); + var chunks = textChunks(new WordBoundaryChunker(), "Comma, separated", chunkSize, 0); assertThat(chunks, contains("Comma", ", separated")); - chunks = new WordBoundaryChunker().chunk("Mme. Thénardier", chunkSize, 0); + chunks = textChunks(new WordBoundaryChunker(), "Mme. Thénardier", chunkSize, 0); assertThat(chunks, contains("Mme", ". 
Thénardier")); - chunks = new WordBoundaryChunker().chunk("Won't you chunk", chunkSize, 0); + chunks = textChunks(new WordBoundaryChunker(), "Won't you chunk", chunkSize, 0); assertThat(chunks, contains("Won't", " you", " chunk")); chunkSize = 10; - chunks = new WordBoundaryChunker().chunk("Won't you chunk", chunkSize, 0); + chunks = textChunks(new WordBoundaryChunker(), "Won't you chunk", chunkSize, 0); assertThat(chunks, contains("Won't you chunk")); } diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/SearchableSnapshotsPrewarmingIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/SearchableSnapshotsPrewarmingIntegTests.java index ab38a89870500..c955457b78d60 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/SearchableSnapshotsPrewarmingIntegTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/SearchableSnapshotsPrewarmingIntegTests.java @@ -67,7 +67,6 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; @@ -466,12 +465,6 @@ public BlobContainer blobContainer(BlobPath path) { return new TrackingFilesBlobContainer(delegate.blobContainer(path)); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) - throws IOException { - delegate.deleteBlobsIgnoringIfNotExists(purpose, blobNames); - } - @Override public void close() throws IOException { delegate.close(); diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java index b8acd9808a35e..6a638f53a6330 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java @@ -569,11 +569,6 @@ public BlobContainer blobContainer(BlobPath path) { } } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) { - assertPurpose(purpose); - } - private void deleteContainer(DisruptableBlobContainer container) { blobContainer = null; } diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisSuccessIT.java b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisSuccessIT.java index 1f8b247e76176..c24a254d34ace 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisSuccessIT.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisSuccessIT.java @@ -287,11 +287,6 @@ private void deleteContainer(AssertingBlobContainer 
container) { } } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) { - assertPurpose(purpose); - } - @Override public void close() {}
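
For context on the chunking changes above: Chunker now returns ChunkOffset ranges into the original input rather than pre-cut substrings, and chunk text is only materialized when batch requests are built (via ChunkOffsetsAndInput#toChunkText), which appears to avoid holding duplicated copies of the input across sub-batches. Below is a minimal, self-contained sketch of that pattern; the ChunkOffsetDemo class and the hard-coded offsets are hypothetical and purely illustrative (the real chunkers in this PR derive offsets from BreakIterator word and sentence boundaries).

    import java.util.List;

    public class ChunkOffsetDemo {
        // Mirrors Chunker.ChunkOffset from this PR: a [start, end) range into the original input.
        record ChunkOffset(int start, int end) {}

        // Same idea as ChunkOffsetsAndInput#toChunkText: resolve each offset back to text on demand.
        static List<String> toChunkText(String input, List<ChunkOffset> offsets) {
            return offsets.stream().map(o -> input.substring(o.start(), o.end())).toList();
        }

        public static void main(String[] args) {
            String input = "First sentence. Second sentence.";
            // Hypothetical offsets; production chunkers compute these from BreakIterator boundaries.
            List<ChunkOffset> offsets = List.of(new ChunkOffset(0, 15), new ChunkOffset(16, 32));
            System.out.println(toChunkText(input, offsets)); // prints [First sentence., Second sentence.]
        }
    }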