From f40dc99f9101e97f4100df2a63e73757be52ffa5 Mon Sep 17 00:00:00 2001 From: Dan Rubinstein Date: Mon, 9 Dec 2024 10:14:32 -0500 Subject: [PATCH 1/6] Adding transforms migration guide for 9.0 (#117353) * Adding transforms migration guide for 9.0 * Adding shared transform attribute and simplifying wording --------- Co-authored-by: Elastic Machine --- .../migrate_9_0/transforms-migration-guide.asciidoc | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 docs/reference/migration/migrate_9_0/transforms-migration-guide.asciidoc diff --git a/docs/reference/migration/migrate_9_0/transforms-migration-guide.asciidoc b/docs/reference/migration/migrate_9_0/transforms-migration-guide.asciidoc new file mode 100644 index 0000000000000..d41c524d68d5c --- /dev/null +++ b/docs/reference/migration/migrate_9_0/transforms-migration-guide.asciidoc @@ -0,0 +1,9 @@ +[[transforms-migration-guide]] +== {transforms-cap} migration guide +This migration guide helps you upgrade your {transforms} to work with the 9.0 release. Each section outlines a breaking change and any manual steps needed to upgrade your {transforms} to be compatible with 9.0. + + +=== Updating deprecated {transform} roles (`data_frame_transforms_admin` and `data_frame_transforms_user`) +If you have existing {transforms} that use deprecated {transform} roles (`data_frame_transforms_admin` or `data_frame_transforms_user`) you must update them to use the new equivalent {transform} roles (`transform_admin` or `transform_user`). To update your {transform} roles: +1. Switch to a user with the `transform_admin` role (to replace `data_frame_transforms_admin`) or the `transform_user` role (to replace `data_frame_transforms_user`). +2. Call the <> with that user. From 2ecf981f24ca5d480466a4fcc669aeb52d063657 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 9 Dec 2024 15:22:48 +0000 Subject: [PATCH 2/6] [ML] Refactor the Chunker classes to return offsets (#117977) --- .../xpack/inference/chunking/Chunker.java | 4 +- .../chunking/EmbeddingRequestChunker.java | 55 ++++++++++++------- .../chunking/SentenceBoundaryChunker.java | 20 ++++--- .../chunking/WordBoundaryChunker.java | 22 +++----- .../EmbeddingRequestChunkerTests.java | 24 ++++---- .../SentenceBoundaryChunkerTests.java | 34 +++++++++--- .../chunking/WordBoundaryChunkerTests.java | 36 ++++++++---- 7 files changed, 119 insertions(+), 76 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java index af7c706c807ec..b8908ee139c29 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/Chunker.java @@ -12,5 +12,7 @@ import java.util.List; public interface Chunker { - List chunk(String input, ChunkingSettings chunkingSettings); + record ChunkOffset(int start, int end) {}; + + List chunk(String input, ChunkingSettings chunkingSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index c5897f32d6eb8..2aef54e56f4b9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -68,7 +68,7 @@ public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.El private final EmbeddingType embeddingType; private final ChunkingSettings chunkingSettings; - private List> chunkedInputs; + private List chunkedOffsets; private List>> floatResults; private List>> byteResults; private List>> sparseResults; @@ -109,7 +109,7 @@ public EmbeddingRequestChunker( } private void splitIntoBatchedRequests(List inputs) { - Function> chunkFunction; + Function> chunkFunction; if (chunkingSettings != null) { var chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); chunkFunction = input -> chunker.chunk(input, chunkingSettings); @@ -118,7 +118,7 @@ private void splitIntoBatchedRequests(List inputs) { chunkFunction = input -> chunker.chunk(input, wordsPerChunk, chunkOverlap); } - chunkedInputs = new ArrayList<>(inputs.size()); + chunkedOffsets = new ArrayList<>(inputs.size()); switch (embeddingType) { case FLOAT -> floatResults = new ArrayList<>(inputs.size()); case BYTE -> byteResults = new ArrayList<>(inputs.size()); @@ -128,18 +128,19 @@ private void splitIntoBatchedRequests(List inputs) { for (int i = 0; i < inputs.size(); i++) { var chunks = chunkFunction.apply(inputs.get(i)); - int numberOfSubBatches = addToBatches(chunks, i); + var offSetsAndInput = new ChunkOffsetsAndInput(chunks, inputs.get(i)); + int numberOfSubBatches = addToBatches(offSetsAndInput, i); // size the results array with the expected number of request/responses switch (embeddingType) { case FLOAT -> floatResults.add(new AtomicArray<>(numberOfSubBatches)); case BYTE -> byteResults.add(new AtomicArray<>(numberOfSubBatches)); case SPARSE -> sparseResults.add(new AtomicArray<>(numberOfSubBatches)); } - chunkedInputs.add(chunks); + chunkedOffsets.add(offSetsAndInput); } } - private int addToBatches(List chunks, int inputIndex) { + private int addToBatches(ChunkOffsetsAndInput chunk, int inputIndex) { BatchRequest lastBatch; if (batchedRequests.isEmpty()) { lastBatch = new BatchRequest(new ArrayList<>()); @@ -157,16 +158,24 @@ private int addToBatches(List chunks, int inputIndex) { if (freeSpace > 0) { // use any free space in the previous batch before creating new batches - int toAdd = Math.min(freeSpace, chunks.size()); - lastBatch.addSubBatch(new SubBatch(chunks.subList(0, toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd))); + int toAdd = Math.min(freeSpace, chunk.offsets().size()); + lastBatch.addSubBatch( + new SubBatch( + new ChunkOffsetsAndInput(chunk.offsets().subList(0, toAdd), chunk.input()), + new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd) + ) + ); } int start = freeSpace; - while (start < chunks.size()) { - int toAdd = Math.min(maxNumberOfInputsPerBatch, chunks.size() - start); + while (start < chunk.offsets().size()) { + int toAdd = Math.min(maxNumberOfInputsPerBatch, chunk.offsets().size() - start); var batch = new BatchRequest(new ArrayList<>()); batch.addSubBatch( - new SubBatch(chunks.subList(start, start + toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)) + new SubBatch( + new ChunkOffsetsAndInput(chunk.offsets().subList(start, start + toAdd), chunk.input()), + new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd) + ) ); batchedRequests.add(batch); start += toAdd; @@ -333,8 +342,8 @@ public void onFailure(Exception e) { } private void sendResponse() { - var response = new 
ArrayList(chunkedInputs.size()); - for (int i = 0; i < chunkedInputs.size(); i++) { + var response = new ArrayList(chunkedOffsets.size()); + for (int i = 0; i < chunkedOffsets.size(); i++) { if (errors.get(i) != null) { response.add(errors.get(i)); } else { @@ -348,9 +357,9 @@ private void sendResponse() { private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) { return switch (embeddingType) { - case FLOAT -> mergeFloatResultsWithInputs(chunkedInputs.get(resultIndex), floatResults.get(resultIndex)); - case BYTE -> mergeByteResultsWithInputs(chunkedInputs.get(resultIndex), byteResults.get(resultIndex)); - case SPARSE -> mergeSparseResultsWithInputs(chunkedInputs.get(resultIndex), sparseResults.get(resultIndex)); + case FLOAT -> mergeFloatResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), floatResults.get(resultIndex)); + case BYTE -> mergeByteResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), byteResults.get(resultIndex)); + case SPARSE -> mergeSparseResultsWithInputs(chunkedOffsets.get(resultIndex).toChunkText(), sparseResults.get(resultIndex)); }; } @@ -428,7 +437,7 @@ public void addSubBatch(SubBatch sb) { } public List inputs() { - return subBatches.stream().flatMap(s -> s.requests().stream()).collect(Collectors.toList()); + return subBatches.stream().flatMap(s -> s.requests().toChunkText().stream()).collect(Collectors.toList()); } } @@ -441,9 +450,15 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener requests, SubBatchPositionsAndCount positions) { - public int size() { - return requests.size(); + record SubBatch(ChunkOffsetsAndInput requests, SubBatchPositionsAndCount positions) { + int size() { + return requests.offsets().size(); + } + } + + record ChunkOffsetsAndInput(List offsets, String input) { + List toChunkText() { + return offsets.stream().map(o -> input.substring(o.start(), o.end())).collect(Collectors.toList()); } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java index 5df940d6a3fba..b2d6c83b89211 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunker.java @@ -34,7 +34,6 @@ public class SentenceBoundaryChunker implements Chunker { public SentenceBoundaryChunker() { sentenceIterator = BreakIterator.getSentenceInstance(Locale.ROOT); wordIterator = BreakIterator.getWordInstance(Locale.ROOT); - } /** @@ -45,7 +44,7 @@ public SentenceBoundaryChunker() { * @return The input text chunked */ @Override - public List chunk(String input, ChunkingSettings chunkingSettings) { + public List chunk(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings instanceof SentenceBoundaryChunkingSettings sentenceBoundaryChunkingSettings) { return chunk(input, sentenceBoundaryChunkingSettings.maxChunkSize, sentenceBoundaryChunkingSettings.sentenceOverlap > 0); } else { @@ -65,8 +64,8 @@ public List chunk(String input, ChunkingSettings chunkingSettings) { * @param maxNumberWordsPerChunk Maximum size of the chunk * @return The input text chunked */ - public List chunk(String input, int maxNumberWordsPerChunk, boolean includePrecedingSentence) { - var chunks = new ArrayList(); + public List chunk(String input, int maxNumberWordsPerChunk, boolean 
includePrecedingSentence) { + var chunks = new ArrayList(); sentenceIterator.setText(input); wordIterator.setText(input); @@ -91,7 +90,7 @@ public List chunk(String input, int maxNumberWordsPerChunk, boolean incl int nextChunkWordCount = wordsInSentenceCount; if (chunkWordCount > 0) { // add a new chunk containing all the input up to this sentence - chunks.add(input.substring(chunkStart, chunkEnd)); + chunks.add(new ChunkOffset(chunkStart, chunkEnd)); if (includePrecedingSentence) { if (wordsInPrecedingSentenceCount + wordsInSentenceCount > maxNumberWordsPerChunk) { @@ -127,12 +126,17 @@ public List chunk(String input, int maxNumberWordsPerChunk, boolean incl for (; i < sentenceSplits.size() - 1; i++) { // Because the substring was passed to splitLongSentence() // the returned positions need to be offset by chunkStart - chunks.add(input.substring(chunkStart + sentenceSplits.get(i).start(), chunkStart + sentenceSplits.get(i).end())); + chunks.add( + new ChunkOffset( + chunkStart + sentenceSplits.get(i).offsets().start(), + chunkStart + sentenceSplits.get(i).offsets().end() + ) + ); } // The final split is partially filled. // Set the next chunk start to the beginning of the // final split of the long sentence. - chunkStart = chunkStart + sentenceSplits.get(i).start(); // start pos needs to be offset by chunkStart + chunkStart = chunkStart + sentenceSplits.get(i).offsets().start(); // start pos needs to be offset by chunkStart chunkWordCount = sentenceSplits.get(i).wordCount(); } } else { @@ -151,7 +155,7 @@ public List chunk(String input, int maxNumberWordsPerChunk, boolean incl } if (chunkWordCount > 0) { - chunks.add(input.substring(chunkStart)); + chunks.add(new ChunkOffset(chunkStart, input.length())); } return chunks; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java index c9c752b9aabbc..b15e2134f4cf7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunker.java @@ -15,6 +15,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; /** * Breaks text into smaller strings or chunks on Word boundaries. @@ -35,7 +36,7 @@ public WordBoundaryChunker() { wordIterator = BreakIterator.getWordInstance(Locale.ROOT); } - record ChunkPosition(int start, int end, int wordCount) {} + record ChunkPosition(ChunkOffset offsets, int wordCount) {} /** * Break the input text into small chunks as dictated @@ -45,7 +46,7 @@ record ChunkPosition(int start, int end, int wordCount) {} * @return List of chunked text */ @Override - public List chunk(String input, ChunkingSettings chunkingSettings) { + public List chunk(String input, ChunkingSettings chunkingSettings) { if (chunkingSettings instanceof WordBoundaryChunkingSettings wordBoundaryChunkerSettings) { return chunk(input, wordBoundaryChunkerSettings.maxChunkSize, wordBoundaryChunkerSettings.overlap); } else { @@ -64,18 +65,9 @@ public List chunk(String input, ChunkingSettings chunkingSettings) { * Can be 0 but must be non-negative. 
* @return List of chunked text */ - public List chunk(String input, int chunkSize, int overlap) { - - if (input.isEmpty()) { - return List.of(""); - } - + public List chunk(String input, int chunkSize, int overlap) { var chunkPositions = chunkPositions(input, chunkSize, overlap); - var chunks = new ArrayList(chunkPositions.size()); - for (var pos : chunkPositions) { - chunks.add(input.substring(pos.start, pos.end)); - } - return chunks; + return chunkPositions.stream().map(ChunkPosition::offsets).collect(Collectors.toList()); } /** @@ -127,7 +119,7 @@ List chunkPositions(String input, int chunkSize, int overlap) { wordsSinceStartWindowWasMarked++; if (wordsInChunkCountIncludingOverlap >= chunkSize) { - chunkPositions.add(new ChunkPosition(windowStart, boundary, wordsInChunkCountIncludingOverlap)); + chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, boundary), wordsInChunkCountIncludingOverlap)); wordsInChunkCountIncludingOverlap = overlap; if (overlap == 0) { @@ -149,7 +141,7 @@ List chunkPositions(String input, int chunkSize, int overlap) { // if it ends on a boundary than the count should equal overlap in which case // we can ignore it, unless this is the first chunk in which case we want to add it if (wordsInChunkCountIncludingOverlap > overlap || chunkPositions.isEmpty()) { - chunkPositions.add(new ChunkPosition(windowStart, input.length(), wordsInChunkCountIncludingOverlap)); + chunkPositions.add(new ChunkPosition(new ChunkOffset(windowStart, input.length()), wordsInChunkCountIncludingOverlap)); } return chunkPositions; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java index 4fdf254101d3e..a82d2f474ca4a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java @@ -62,7 +62,7 @@ public void testMultipleShortInputsAreSingleBatch() { var subBatches = batches.get(0).batch().subBatches(); for (int i = 0; i < inputs.size(); i++) { var subBatch = subBatches.get(i); - assertThat(subBatch.requests(), contains(inputs.get(i))); + assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i))); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(i, subBatch.positions().inputIndex()); assertEquals(1, subBatch.positions().embeddingCount()); @@ -102,7 +102,7 @@ public void testManyInputsMakeManyBatches() { var subBatches = batches.get(0).batch().subBatches(); for (int i = 0; i < batches.size(); i++) { var subBatch = subBatches.get(i); - assertThat(subBatch.requests(), contains(inputs.get(i))); + assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i))); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(inputIndex, subBatch.positions().inputIndex()); assertEquals(1, subBatch.positions().embeddingCount()); @@ -146,7 +146,7 @@ public void testChunkingSettingsProvided() { var subBatches = batches.get(0).batch().subBatches(); for (int i = 0; i < batches.size(); i++) { var subBatch = subBatches.get(i); - assertThat(subBatch.requests(), contains(inputs.get(i))); + assertThat(subBatch.requests().toChunkText(), contains(inputs.get(i))); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(inputIndex, subBatch.positions().inputIndex()); assertEquals(1, 
subBatch.positions().embeddingCount()); @@ -184,17 +184,17 @@ public void testLongInputChunkedOverMultipleBatches() { assertEquals(0, subBatch.positions().inputIndex()); assertEquals(0, subBatch.positions().chunkIndex()); assertEquals(1, subBatch.positions().embeddingCount()); - assertThat(subBatch.requests(), contains("1st small")); + assertThat(subBatch.requests().toChunkText(), contains("1st small")); } { var subBatch = batch.subBatches().get(1); assertEquals(1, subBatch.positions().inputIndex()); // 2nd input assertEquals(0, subBatch.positions().chunkIndex()); // 1st part of the 2nd input assertEquals(4, subBatch.positions().embeddingCount()); // 4 chunks - assertThat(subBatch.requests().get(0), startsWith("passage_input0 ")); - assertThat(subBatch.requests().get(1), startsWith(" passage_input20 ")); - assertThat(subBatch.requests().get(2), startsWith(" passage_input40 ")); - assertThat(subBatch.requests().get(3), startsWith(" passage_input60 ")); + assertThat(subBatch.requests().toChunkText().get(0), startsWith("passage_input0 ")); + assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input20 ")); + assertThat(subBatch.requests().toChunkText().get(2), startsWith(" passage_input40 ")); + assertThat(subBatch.requests().toChunkText().get(3), startsWith(" passage_input60 ")); } } { @@ -207,22 +207,22 @@ public void testLongInputChunkedOverMultipleBatches() { assertEquals(1, subBatch.positions().inputIndex()); // 2nd input assertEquals(1, subBatch.positions().chunkIndex()); // 2nd part of the 2nd input assertEquals(2, subBatch.positions().embeddingCount()); - assertThat(subBatch.requests().get(0), startsWith(" passage_input80 ")); - assertThat(subBatch.requests().get(1), startsWith(" passage_input100 ")); + assertThat(subBatch.requests().toChunkText().get(0), startsWith(" passage_input80 ")); + assertThat(subBatch.requests().toChunkText().get(1), startsWith(" passage_input100 ")); } { var subBatch = batch.subBatches().get(1); assertEquals(2, subBatch.positions().inputIndex()); // 3rd input assertEquals(0, subBatch.positions().chunkIndex()); // 1st and only part assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk - assertThat(subBatch.requests(), contains("2nd small")); + assertThat(subBatch.requests().toChunkText(), contains("2nd small")); } { var subBatch = batch.subBatches().get(2); assertEquals(3, subBatch.positions().inputIndex()); // 4th input assertEquals(0, subBatch.positions().chunkIndex()); // 1st and only part assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk - assertThat(subBatch.requests(), contains("3rd small")); + assertThat(subBatch.requests().toChunkText(), contains("3rd small")); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java index afce8c57e0350..de943f7f57ab8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkerTests.java @@ -15,7 +15,9 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkerTests.TEST_TEXT; import static org.hamcrest.Matchers.containsString; @@ -27,10 +29,24 @@ 
public class SentenceBoundaryChunkerTests extends ESTestCase { + /** + * Utility method for testing. + * Use the chunk functions that return offsets where possible + */ + private List textChunks( + SentenceBoundaryChunker chunker, + String input, + int maxNumberWordsPerChunk, + boolean includePrecedingSentence + ) { + var chunkPositions = chunker.chunk(input, maxNumberWordsPerChunk, includePrecedingSentence); + return chunkPositions.stream().map(offset -> input.substring(offset.start(), offset.end())).collect(Collectors.toList()); + } + public void testChunkSplitLargeChunkSizes() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks)); @@ -48,7 +64,7 @@ public void testChunkSplitLargeChunkSizes_withOverlap() { boolean overlap = true; for (int maxWordsPerChunk : new int[] { 70, 80, 100, 120, 150, 200 }) { var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, overlap); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, overlap); int[] overlaps = chunkOverlaps(sentenceSizes(TEST_TEXT), maxWordsPerChunk, overlap); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(overlaps.length)); @@ -107,7 +123,7 @@ public void testWithOverlap_SentencesFitInChunks() { } var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(sb.toString(), chunkSize, true); + var chunks = textChunks(chunker, sb.toString(), chunkSize, true); assertThat(chunks, hasSize(numChunks)); for (int i = 0; i < numChunks; i++) { assertThat("num sentences " + numSentences, chunks.get(i), startsWith("SStart" + sentenceStartIndexes[i])); @@ -128,10 +144,10 @@ private String makeSentence(int numWords, int sentenceIndex) { public void testChunk_ChunkSizeLargerThanText() { int maxWordsPerChunk = 500; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); assertEquals(chunks.get(0), TEST_TEXT); - chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true); + chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, true); assertEquals(chunks.get(0), TEST_TEXT); } @@ -142,7 +158,7 @@ public void testChunkSplit_SentencesLongerThanChunkSize() { for (int i = 0; i < chunkSizes.length; i++) { int maxWordsPerChunk = chunkSizes[i]; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, false); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(expectedNumberOFChunks[i])); for (var chunk : chunks) { @@ -171,7 +187,7 @@ public void testChunkSplit_SentencesLongerThanChunkSize_WithOverlap() { for (int i = 0; i < chunkSizes.length; i++) { int maxWordsPerChunk = chunkSizes[i]; var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, maxWordsPerChunk, true); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, true); assertThat(chunks.get(0), containsString("Word segmentation is the problem of dividing")); assertThat(chunks.get(chunks.size() - 1), containsString(", with solidification being a stronger norm.")); } @@ -190,7 +206,7 
@@ public void testShortLongShortSentences_WithOverlap() { } var chunker = new SentenceBoundaryChunker(); - var chunks = chunker.chunk(sb.toString(), maxWordsPerChunk, true); + var chunks = textChunks(chunker, sb.toString(), maxWordsPerChunk, true); assertThat(chunks, hasSize(5)); assertTrue(chunks.get(0).trim().startsWith("SStart0")); // Entire sentence assertTrue(chunks.get(0).trim().endsWith(".")); // Entire sentence @@ -303,7 +319,7 @@ public void testChunkSplitLargeChunkSizesWithChunkingSettings() { for (int maxWordsPerChunk : new int[] { 100, 200 }) { var chunker = new SentenceBoundaryChunker(); SentenceBoundaryChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(maxWordsPerChunk, 0); - var chunks = chunker.chunk(TEST_TEXT, chunkingSettings); + var chunks = textChunks(chunker, TEST_TEXT, maxWordsPerChunk, false); int numChunks = expectedNumberOfChunks(sentenceSizes(TEST_TEXT), maxWordsPerChunk); assertThat("words per chunk " + maxWordsPerChunk, chunks, hasSize(numChunks)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java index ef643a4b36fdc..2ef28f2cf2e77 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkerTests.java @@ -14,6 +14,7 @@ import java.util.List; import java.util.Locale; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; @@ -65,9 +66,22 @@ public class WordBoundaryChunkerTests extends ESTestCase { NUM_WORDS_IN_TEST_TEXT = wordCount; } + /** + * Utility method for testing. 
+ * Use the chunk functions that return offsets where possible + */ + List textChunks(WordBoundaryChunker chunker, String input, int chunkSize, int overlap) { + if (input.isEmpty()) { + return List.of(""); + } + + var chunkPositions = chunker.chunk(input, chunkSize, overlap); + return chunkPositions.stream().map(p -> input.substring(p.start(), p.end())).collect(Collectors.toList()); + } + public void testSingleSplit() { var chunker = new WordBoundaryChunker(); - var chunks = chunker.chunk(TEST_TEXT, 10_000, 0); + var chunks = textChunks(chunker, TEST_TEXT, 10_000, 0); assertThat(chunks, hasSize(1)); assertEquals(TEST_TEXT, chunks.get(0)); } @@ -168,11 +182,11 @@ public void testWindowSpanningWords() { } var whiteSpacedText = input.toString().stripTrailing(); - var chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 20, 10); + var chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 20, 10); assertChunkContents(chunks, numWords, 20, 10); - chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 10, 4); + chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 10, 4); assertChunkContents(chunks, numWords, 10, 4); - chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 15, 3); + chunks = textChunks(new WordBoundaryChunker(), whiteSpacedText, 15, 3); assertChunkContents(chunks, numWords, 15, 3); } @@ -217,28 +231,28 @@ public void testWindowSpanning_TextShorterThanWindow() { } public void testEmptyString() { - var chunks = new WordBoundaryChunker().chunk("", 10, 5); - assertThat(chunks, contains("")); + var chunks = textChunks(new WordBoundaryChunker(), "", 10, 5); + assertThat(chunks.toString(), chunks, contains("")); } public void testWhitespace() { - var chunks = new WordBoundaryChunker().chunk(" ", 10, 5); + var chunks = textChunks(new WordBoundaryChunker(), " ", 10, 5); assertThat(chunks, contains(" ")); } public void testPunctuation() { int chunkSize = 1; - var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0); + var chunks = textChunks(new WordBoundaryChunker(), "Comma, separated", chunkSize, 0); assertThat(chunks, contains("Comma", ", separated")); - chunks = new WordBoundaryChunker().chunk("Mme. Thénardier", chunkSize, 0); + chunks = textChunks(new WordBoundaryChunker(), "Mme. Thénardier", chunkSize, 0); assertThat(chunks, contains("Mme", ". Thénardier")); - chunks = new WordBoundaryChunker().chunk("Won't you chunk", chunkSize, 0); + chunks = textChunks(new WordBoundaryChunker(), "Won't you chunk", chunkSize, 0); assertThat(chunks, contains("Won't", " you", " chunk")); chunkSize = 10; - chunks = new WordBoundaryChunker().chunk("Won't you chunk", chunkSize, 0); + chunks = textChunks(new WordBoundaryChunker(), "Won't you chunk", chunkSize, 0); assertThat(chunks, contains("Won't you chunk")); } From 22a392f1b69224aac3894dd6a05b892ccbb6a75d Mon Sep 17 00:00:00 2001 From: Ryan Ernst Date: Mon, 9 Dec 2024 07:33:11 -0800 Subject: [PATCH 3/6] Remove client.type setting (#118192) The client.type setting is a holdover from the node client which was removed in 8.0. The setting has been a noop since 8.0. This commit removes the setting. 
relates #104574 --- docs/changelog/118192.yaml | 11 +++++++++++ .../org/elasticsearch/client/internal/Client.java | 10 ---------- .../common/settings/ClusterSettings.java | 2 -- 3 files changed, 11 insertions(+), 12 deletions(-) create mode 100644 docs/changelog/118192.yaml diff --git a/docs/changelog/118192.yaml b/docs/changelog/118192.yaml new file mode 100644 index 0000000000000..03542048761d3 --- /dev/null +++ b/docs/changelog/118192.yaml @@ -0,0 +1,11 @@ +pr: 118192 +summary: Remove `client.type` setting +area: Infra/Core +type: breaking +issues: [104574] +breaking: + title: Remove `client.type` setting + area: Cluster and node setting + details: The node setting `client.type` has been ignored since the node client was removed in 8.0. The setting is now removed. + impact: Remove the `client.type` setting from `elasticsearch.yml` + notable: false diff --git a/server/src/main/java/org/elasticsearch/client/internal/Client.java b/server/src/main/java/org/elasticsearch/client/internal/Client.java index 4158bbfb27cda..2d1cbe0cce7f7 100644 --- a/server/src/main/java/org/elasticsearch/client/internal/Client.java +++ b/server/src/main/java/org/elasticsearch/client/internal/Client.java @@ -52,8 +52,6 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.action.update.UpdateResponse; -import org.elasticsearch.common.settings.Setting; -import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.transport.RemoteClusterService; @@ -74,14 +72,6 @@ */ public interface Client extends ElasticsearchClient { - // Note: This setting is registered only for bwc. The value is never read. - Setting CLIENT_TYPE_SETTING_S = new Setting<>("client.type", "node", (s) -> { - return switch (s) { - case "node", "transport" -> s; - default -> throw new IllegalArgumentException("Can't parse [client.type] must be one of [node, transport]"); - }; - }, Property.NodeScope, Property.Deprecated); - /** * The admin client that can be used to perform administrative operations. 
*/ diff --git a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java index a9a9411de8e1f..16af7ca2915d4 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java +++ b/server/src/main/java/org/elasticsearch/common/settings/ClusterSettings.java @@ -20,7 +20,6 @@ import org.elasticsearch.action.support.DestructiveOperations; import org.elasticsearch.action.support.replication.TransportReplicationAction; import org.elasticsearch.bootstrap.BootstrapSettings; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterModule; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.InternalClusterInfoService; @@ -483,7 +482,6 @@ public void apply(Settings value, Settings current, Settings previous) { AutoCreateIndex.AUTO_CREATE_INDEX_SETTING, BaseRestHandler.MULTI_ALLOW_EXPLICIT_INDEX, ClusterName.CLUSTER_NAME_SETTING, - Client.CLIENT_TYPE_SETTING_S, ClusterModule.SHARDS_ALLOCATOR_TYPE_SETTING, EsExecutors.NODE_PROCESSORS_SETTING, ThreadContext.DEFAULT_HEADERS_SETTING, From 0586cbfb34c7201a996578db60d12fea8594261c Mon Sep 17 00:00:00 2001 From: David Turner Date: Mon, 9 Dec 2024 15:46:22 +0000 Subject: [PATCH 4/6] Remove unused `BlobStore#deleteBlobsIgnoringIfNotExists` (#118245) This method is never called against a general `BlobStore`, we only use it in certain implementations for which a bulk delete at the `BlobStore` level makes sense. This commit removes the unused interface method. --- .../azure/AzureBlobContainer.java | 2 +- .../repositories/azure/AzureBlobStore.java | 3 +- .../azure/AzureBlobContainerStatsTests.java | 9 ++--- .../gcs/GoogleCloudStorageBlobContainer.java | 4 +-- .../gcs/GoogleCloudStorageBlobStore.java | 10 ++---- .../repositories/s3/S3BlobContainer.java | 6 ++-- .../repositories/s3/S3BlobStore.java | 3 +- .../common/blobstore/url/URLBlobStore.java | 8 ----- .../repositories/hdfs/HdfsBlobStore.java | 7 ---- .../hdfs/HdfsBlobStoreRepositoryTests.java | 5 --- ...BlobStoreRepositoryOperationPurposeIT.java | 6 ---- .../common/blobstore/BlobStore.java | 10 ------ .../common/blobstore/fs/FsBlobContainer.java | 2 +- .../common/blobstore/fs/FsBlobStore.java | 4 +-- ...bStoreRepositoryDeleteThrottlingTests.java | 6 ---- .../LatencySimulatingBlobStoreRepository.java | 6 ---- .../ESBlobStoreRepositoryIntegTestCase.java | 35 ------------------- .../snapshots/mockstore/BlobStoreWrapper.java | 9 +---- ...archableSnapshotsPrewarmingIntegTests.java | 7 ---- .../analyze/RepositoryAnalysisFailureIT.java | 5 --- .../analyze/RepositoryAnalysisSuccessIT.java | 5 --- 21 files changed, 17 insertions(+), 135 deletions(-) diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java index 73936d82fc204..08bdc2051b9e3 100644 --- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java +++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java @@ -144,7 +144,7 @@ public DeleteResult delete(OperationPurpose purpose) throws IOException { @Override public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, new Iterator<>() { + blobStore.deleteBlobs(purpose, new Iterator<>() { 
@Override public boolean hasNext() { return blobNames.hasNext(); diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java index e4f973fb73a4e..3cac0dc4bb6db 100644 --- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java +++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java @@ -264,8 +264,7 @@ public DeleteResult deleteBlobDirectory(OperationPurpose purpose, String path) t return new DeleteResult(blobsDeleted.get(), bytesDeleted.get()); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { + void deleteBlobs(OperationPurpose purpose, Iterator blobNames) { if (blobNames.hasNext() == false) { return; } diff --git a/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureBlobContainerStatsTests.java b/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureBlobContainerStatsTests.java index f6e97187222e7..8979507230bdd 100644 --- a/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureBlobContainerStatsTests.java +++ b/modules/repository-azure/src/test/java/org/elasticsearch/repositories/azure/AzureBlobContainerStatsTests.java @@ -72,7 +72,7 @@ public void testRetriesAndOperationsAreTrackedSeparately() throws IOException { false ); case LIST_BLOBS -> blobStore.listBlobsByPrefix(purpose, randomIdentifier(), randomIdentifier()); - case BLOB_BATCH -> blobStore.deleteBlobsIgnoringIfNotExists( + case BLOB_BATCH -> blobStore.deleteBlobs( purpose, List.of(randomIdentifier(), randomIdentifier(), randomIdentifier()).iterator() ); @@ -113,7 +113,7 @@ public void testOperationPurposeIsReflectedInBlobStoreStats() throws IOException os.flush(); }); // BLOB_BATCH - blobStore.deleteBlobsIgnoringIfNotExists(purpose, List.of(randomIdentifier(), randomIdentifier(), randomIdentifier()).iterator()); + blobStore.deleteBlobs(purpose, List.of(randomIdentifier(), randomIdentifier(), randomIdentifier()).iterator()); Map stats = blobStore.stats(); String statsMapString = stats.toString(); @@ -148,10 +148,7 @@ public void testOperationPurposeIsNotReflectedInBlobStoreStatsWhenNotServerless( os.flush(); }); // BLOB_BATCH - blobStore.deleteBlobsIgnoringIfNotExists( - purpose, - List.of(randomIdentifier(), randomIdentifier(), randomIdentifier()).iterator() - ); + blobStore.deleteBlobs(purpose, List.of(randomIdentifier(), randomIdentifier(), randomIdentifier()).iterator()); } Map stats = blobStore.stats(); diff --git a/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobContainer.java b/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobContainer.java index 047549cc893ed..edcf03580da09 100644 --- a/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobContainer.java +++ b/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobContainer.java @@ -114,12 +114,12 @@ public void writeBlobAtomic(OperationPurpose purpose, String blobName, BytesRefe @Override public DeleteResult delete(OperationPurpose purpose) throws IOException { - return blobStore.deleteDirectory(purpose, path().buildAsString()); + return blobStore.deleteDirectory(path().buildAsString()); } @Override public void 
deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, new Iterator<>() { + blobStore.deleteBlobs(new Iterator<>() { @Override public boolean hasNext() { return blobNames.hasNext(); diff --git a/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStore.java b/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStore.java index 9cbf64e7e0146..c68217a1a3738 100644 --- a/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStore.java +++ b/modules/repository-gcs/src/main/java/org/elasticsearch/repositories/gcs/GoogleCloudStorageBlobStore.java @@ -29,7 +29,6 @@ import org.elasticsearch.common.blobstore.BlobStore; import org.elasticsearch.common.blobstore.BlobStoreActionStats; import org.elasticsearch.common.blobstore.DeleteResult; -import org.elasticsearch.common.blobstore.OperationPurpose; import org.elasticsearch.common.blobstore.OptionalBytesReference; import org.elasticsearch.common.blobstore.support.BlobContainerUtils; import org.elasticsearch.common.blobstore.support.BlobMetadata; @@ -491,10 +490,9 @@ private void writeBlobMultipart(BlobInfo blobInfo, byte[] buffer, int offset, in /** * Deletes the given path and all its children. * - * @param purpose The purpose of the delete operation * @param pathStr Name of path to delete */ - DeleteResult deleteDirectory(OperationPurpose purpose, String pathStr) throws IOException { + DeleteResult deleteDirectory(String pathStr) throws IOException { return SocketAccess.doPrivilegedIOException(() -> { DeleteResult deleteResult = DeleteResult.ZERO; Page page = client().list(bucketName, BlobListOption.prefix(pathStr)); @@ -502,7 +500,7 @@ DeleteResult deleteDirectory(OperationPurpose purpose, String pathStr) throws IO final AtomicLong blobsDeleted = new AtomicLong(0L); final AtomicLong bytesDeleted = new AtomicLong(0L); final Iterator blobs = page.getValues().iterator(); - deleteBlobsIgnoringIfNotExists(purpose, new Iterator<>() { + deleteBlobs(new Iterator<>() { @Override public boolean hasNext() { return blobs.hasNext(); @@ -526,11 +524,9 @@ public String next() { /** * Deletes multiple blobs from the specific bucket using a batch request * - * @param purpose the purpose of the delete operation * @param blobNames names of the blobs to delete */ - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { + void deleteBlobs(Iterator blobNames) throws IOException { if (blobNames.hasNext() == false) { return; } diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java index e13cc40dd3e0f..f527dcd42814c 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java @@ -342,10 +342,10 @@ public DeleteResult delete(OperationPurpose purpose) throws IOException { return summary.getKey(); }); if (list.isTruncated()) { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, blobNameIterator); + blobStore.deleteBlobs(purpose, blobNameIterator); prevListing = list; } else { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, Iterators.concat(blobNameIterator, Iterators.single(keyPath))); + 
blobStore.deleteBlobs(purpose, Iterators.concat(blobNameIterator, Iterators.single(keyPath))); break; } } @@ -357,7 +357,7 @@ public DeleteResult delete(OperationPurpose purpose) throws IOException { @Override public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, Iterators.map(blobNames, this::buildKey)); + blobStore.deleteBlobs(purpose, Iterators.map(blobNames, this::buildKey)); } @Override diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java index 4f2b0f213e448..4bd54aa37077f 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobStore.java @@ -340,8 +340,7 @@ public BlobContainer blobContainer(BlobPath path) { return new S3BlobContainer(path, this); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { + void deleteBlobs(OperationPurpose purpose, Iterator blobNames) throws IOException { if (blobNames.hasNext() == false) { return; } diff --git a/modules/repository-url/src/main/java/org/elasticsearch/common/blobstore/url/URLBlobStore.java b/modules/repository-url/src/main/java/org/elasticsearch/common/blobstore/url/URLBlobStore.java index 8538d2ba673bc..0e9c735b22fd6 100644 --- a/modules/repository-url/src/main/java/org/elasticsearch/common/blobstore/url/URLBlobStore.java +++ b/modules/repository-url/src/main/java/org/elasticsearch/common/blobstore/url/URLBlobStore.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; import org.elasticsearch.common.blobstore.BlobStoreException; -import org.elasticsearch.common.blobstore.OperationPurpose; import org.elasticsearch.common.blobstore.url.http.HttpURLBlobContainer; import org.elasticsearch.common.blobstore.url.http.URLHttpClient; import org.elasticsearch.common.blobstore.url.http.URLHttpClientSettings; @@ -23,10 +22,8 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.CheckedFunction; -import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; -import java.util.Iterator; import java.util.List; /** @@ -109,11 +106,6 @@ public BlobContainer blobContainer(BlobPath blobPath) { } } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - throw new UnsupportedOperationException("Bulk deletes are not supported in URL repositories"); - } - @Override public void close() { // nothing to do here... 
diff --git a/plugins/repository-hdfs/src/main/java/org/elasticsearch/repositories/hdfs/HdfsBlobStore.java b/plugins/repository-hdfs/src/main/java/org/elasticsearch/repositories/hdfs/HdfsBlobStore.java index eaf2429ae6258..e817384d95c04 100644 --- a/plugins/repository-hdfs/src/main/java/org/elasticsearch/repositories/hdfs/HdfsBlobStore.java +++ b/plugins/repository-hdfs/src/main/java/org/elasticsearch/repositories/hdfs/HdfsBlobStore.java @@ -16,10 +16,8 @@ import org.elasticsearch.common.blobstore.BlobContainer; import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; -import org.elasticsearch.common.blobstore.OperationPurpose; import java.io.IOException; -import java.util.Iterator; final class HdfsBlobStore implements BlobStore { @@ -72,11 +70,6 @@ public BlobContainer blobContainer(BlobPath path) { return new HdfsBlobContainer(path, this, buildHdfsPath(path), bufferSize, securityContext, replicationFactor); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - throw new UnsupportedOperationException("Bulk deletes are not supported in Hdfs repositories"); - } - private Path buildHdfsPath(BlobPath blobPath) { final Path path = translateToHdfsPath(blobPath); if (readOnly == false) { diff --git a/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java b/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java index 17927b02a08dc..3e1c112a4d9f7 100644 --- a/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java +++ b/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsBlobStoreRepositoryTests.java @@ -46,11 +46,6 @@ public void testSnapshotAndRestore() throws Exception { testSnapshotAndRestore(false); } - @Override - public void testBlobStoreBulkDeletion() throws Exception { - // HDFS does not implement bulk deletion from different BlobContainers - } - @Override protected Collection> nodePlugins() { return Collections.singletonList(HdfsPlugin.class); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryOperationPurposeIT.java b/server/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryOperationPurposeIT.java index c0a2c83f7fe1e..b2e02b2f4c271 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryOperationPurposeIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryOperationPurposeIT.java @@ -36,7 +36,6 @@ import java.io.InputStream; import java.io.OutputStream; import java.util.Collection; -import java.util.Iterator; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -136,11 +135,6 @@ public BlobContainer blobContainer(BlobPath path) { return new AssertingBlobContainer(delegateBlobStore.blobContainer(path)); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - delegateBlobStore.deleteBlobsIgnoringIfNotExists(purpose, blobNames); - } - @Override public void close() throws IOException { delegateBlobStore.close(); diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/BlobStore.java b/server/src/main/java/org/elasticsearch/common/blobstore/BlobStore.java index d67c034fd3e27..f1fe028f60f6e 100644 
--- a/server/src/main/java/org/elasticsearch/common/blobstore/BlobStore.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/BlobStore.java @@ -9,9 +9,7 @@ package org.elasticsearch.common.blobstore; import java.io.Closeable; -import java.io.IOException; import java.util.Collections; -import java.util.Iterator; import java.util.Map; /** @@ -28,14 +26,6 @@ public interface BlobStore extends Closeable { */ BlobContainer blobContainer(BlobPath path); - /** - * Delete all the provided blobs from the blob store. Each blob could belong to a different {@code BlobContainer} - * - * @param purpose the purpose of the delete operation - * @param blobNames the blobs to be deleted - */ - void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException; - /** * Returns statistics on the count of operations that have been performed on this blob store */ diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobContainer.java b/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobContainer.java index b5118d8a289a9..7d40008231292 100644 --- a/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobContainer.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobContainer.java @@ -177,7 +177,7 @@ public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IO @Override public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - blobStore.deleteBlobsIgnoringIfNotExists(purpose, Iterators.map(blobNames, blobName -> path.resolve(blobName).toString())); + blobStore.deleteBlobs(Iterators.map(blobNames, blobName -> path.resolve(blobName).toString())); } @Override diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java b/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java index 53e3b4b4796dc..9a368483d46c0 100644 --- a/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java +++ b/server/src/main/java/org/elasticsearch/common/blobstore/fs/FsBlobStore.java @@ -13,7 +13,6 @@ import org.elasticsearch.common.blobstore.BlobContainer; import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; -import org.elasticsearch.common.blobstore.OperationPurpose; import org.elasticsearch.core.IOUtils; import java.io.IOException; @@ -70,8 +69,7 @@ public BlobContainer blobContainer(BlobPath path) { return new FsBlobContainer(this, path, f); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { + void deleteBlobs(Iterator blobNames) throws IOException { IOException ioe = null; long suppressedExceptions = 0; while (blobNames.hasNext()) { diff --git a/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryDeleteThrottlingTests.java b/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryDeleteThrottlingTests.java index 0b5999b614050..4facaa391ec24 100644 --- a/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryDeleteThrottlingTests.java +++ b/server/src/test/java/org/elasticsearch/repositories/blobstore/BlobStoreRepositoryDeleteThrottlingTests.java @@ -35,7 +35,6 @@ import java.io.OutputStream; import java.util.Collection; import java.util.Collections; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; @@ -100,11 +99,6 @@ public BlobContainer 
blobContainer(BlobPath path) { return new ConcurrencyLimitingBlobContainer(delegate.blobContainer(path), activeIndices, countDownLatch); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - delegate.deleteBlobsIgnoringIfNotExists(purpose, blobNames); - } - @Override public void close() throws IOException { delegate.close(); diff --git a/test/external-modules/latency-simulating-directory/src/main/java/org/elasticsearch/test/simulatedlatencyrepo/LatencySimulatingBlobStoreRepository.java b/test/external-modules/latency-simulating-directory/src/main/java/org/elasticsearch/test/simulatedlatencyrepo/LatencySimulatingBlobStoreRepository.java index f360a6c012cb7..cd2812a95cfac 100644 --- a/test/external-modules/latency-simulating-directory/src/main/java/org/elasticsearch/test/simulatedlatencyrepo/LatencySimulatingBlobStoreRepository.java +++ b/test/external-modules/latency-simulating-directory/src/main/java/org/elasticsearch/test/simulatedlatencyrepo/LatencySimulatingBlobStoreRepository.java @@ -24,7 +24,6 @@ import java.io.IOException; import java.io.InputStream; -import java.util.Iterator; class LatencySimulatingBlobStoreRepository extends FsRepository { @@ -53,11 +52,6 @@ public BlobContainer blobContainer(BlobPath path) { return new LatencySimulatingBlobContainer(blobContainer); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - fsBlobStore.deleteBlobsIgnoringIfNotExists(purpose, blobNames); - } - @Override public void close() throws IOException { fsBlobStore.close(); diff --git a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java index b85ee970664e2..c982f36e5ccb3 100644 --- a/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/repositories/blobstore/ESBlobStoreRepositoryIntegTestCase.java @@ -48,7 +48,6 @@ import java.io.IOException; import java.io.InputStream; import java.nio.file.NoSuchFileException; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; @@ -70,7 +69,6 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertHitCount; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.hasKey; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; @@ -524,39 +522,6 @@ public void testIndicesDeletedFromRepository() throws Exception { assertAcked(clusterAdmin().prepareDeleteSnapshot(TEST_REQUEST_TIMEOUT, repoName, "test-snap2").get()); } - public void testBlobStoreBulkDeletion() throws Exception { - Map> expectedBlobsPerContainer = new HashMap<>(); - try (BlobStore store = newBlobStore()) { - List blobsToDelete = new ArrayList<>(); - int numberOfContainers = randomIntBetween(2, 5); - for (int i = 0; i < numberOfContainers; i++) { - BlobPath containerPath = BlobPath.EMPTY.add(randomIdentifier()); - final BlobContainer container = store.blobContainer(containerPath); - int numberOfBlobsPerContainer = randomIntBetween(5, 10); - for (int j = 0; j < numberOfBlobsPerContainer; j++) { - byte[] bytes = randomBytes(randomInt(100)); - String 
blobName = randomAlphaOfLength(10); - container.writeBlob(randomPurpose(), blobName, new BytesArray(bytes), false); - if (randomBoolean()) { - blobsToDelete.add(containerPath.buildAsString() + blobName); - } else { - expectedBlobsPerContainer.computeIfAbsent(containerPath, unused -> new ArrayList<>()).add(blobName); - } - } - } - - store.deleteBlobsIgnoringIfNotExists(randomPurpose(), blobsToDelete.iterator()); - for (var containerEntry : expectedBlobsPerContainer.entrySet()) { - BlobContainer blobContainer = store.blobContainer(containerEntry.getKey()); - Map blobsInContainer = blobContainer.listBlobs(randomPurpose()); - for (String expectedBlob : containerEntry.getValue()) { - assertThat(blobsInContainer, hasKey(expectedBlob)); - } - blobContainer.delete(randomPurpose()); - } - } - } - public void testDanglingShardLevelBlobCleanup() throws Exception { final var repoName = createRepository(randomRepositoryName()); final var client = client(); diff --git a/test/framework/src/main/java/org/elasticsearch/snapshots/mockstore/BlobStoreWrapper.java b/test/framework/src/main/java/org/elasticsearch/snapshots/mockstore/BlobStoreWrapper.java index 5803c2a825671..54af75fc584d6 100644 --- a/test/framework/src/main/java/org/elasticsearch/snapshots/mockstore/BlobStoreWrapper.java +++ b/test/framework/src/main/java/org/elasticsearch/snapshots/mockstore/BlobStoreWrapper.java @@ -12,15 +12,13 @@ import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; import org.elasticsearch.common.blobstore.BlobStoreActionStats; -import org.elasticsearch.common.blobstore.OperationPurpose; import java.io.IOException; -import java.util.Iterator; import java.util.Map; public class BlobStoreWrapper implements BlobStore { - private BlobStore delegate; + private final BlobStore delegate; public BlobStoreWrapper(BlobStore delegate) { this.delegate = delegate; @@ -31,11 +29,6 @@ public BlobContainer blobContainer(BlobPath path) { return delegate.blobContainer(path); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) throws IOException { - delegate.deleteBlobsIgnoringIfNotExists(purpose, blobNames); - } - @Override public void close() throws IOException { delegate.close(); diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/SearchableSnapshotsPrewarmingIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/SearchableSnapshotsPrewarmingIntegTests.java index ab38a89870500..c955457b78d60 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/SearchableSnapshotsPrewarmingIntegTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/cache/full/SearchableSnapshotsPrewarmingIntegTests.java @@ -67,7 +67,6 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; @@ -466,12 +465,6 @@ public BlobContainer blobContainer(BlobPath path) { return new TrackingFilesBlobContainer(delegate.blobContainer(path)); } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) - throws IOException { - delegate.deleteBlobsIgnoringIfNotExists(purpose, blobNames); - } - @Override public 
void close() throws IOException { delegate.close(); diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java index b8acd9808a35e..6a638f53a6330 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisFailureIT.java @@ -569,11 +569,6 @@ public BlobContainer blobContainer(BlobPath path) { } } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) { - assertPurpose(purpose); - } - private void deleteContainer(DisruptableBlobContainer container) { blobContainer = null; } diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisSuccessIT.java b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisSuccessIT.java index 1f8b247e76176..c24a254d34ace 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisSuccessIT.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/analyze/RepositoryAnalysisSuccessIT.java @@ -287,11 +287,6 @@ private void deleteContainer(AssertingBlobContainer container) { } } - @Override - public void deleteBlobsIgnoringIfNotExists(OperationPurpose purpose, Iterator blobNames) { - assertPurpose(purpose); - } - @Override public void close() {} From 5e859d9301ffe736548dfc2b6e72807a7f9006ff Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 9 Dec 2024 11:06:27 -0500 Subject: [PATCH 5/6] Even better(er) binary quantization (#117994) This measurably improves BBQ by adjusting the underlying algorithm to an optimized per vector scalar quantization. This is a brand new way to quantize vectors. Instead of there being a global set of upper and lower quantile bands, these are optimized and calculated per individual vector. Additionally, vectors are centered on a common centroid. This allows for an almost 32x reduction in memory, and even better recall than before at the cost of slightly increasing indexing time. Additionally, this new approach is easily generalizable to various other bit sizes (e.g. 2 bits, etc.). While not taken advantage of yet, we may update our scalar quantized indices in the future to use this new algorithm, giving significant boosts in recall. The recall gains spread from 2% to almost 10% for certain datasets with an additional 5-10% indexing cost when indexing with HNSW when compared with current BBQ. 
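To make the idea concrete for reviewers, here is a minimal, illustrative Java sketch of per-vector scalar quantization around a shared centroid. It is not the OptimizedScalarQuantizer added in this patch: the class and method names are hypothetical, it uses a plain per-vector min/max interval instead of the optimized one, and it omits the corrective factors the real code records for scoring.

    // Sketch only: center a vector on a shared centroid, then scalar-quantize it using a
    // per-vector interval rather than a global quantile band. The real implementation also
    // optimizes the interval per vector and stores corrective terms for distance estimation.
    final class PerVectorQuantizationSketch {
        static byte[] quantize(float[] vector, float[] centroid, int bits) {
            assert vector.length == centroid.length;
            float[] centered = new float[vector.length];
            float min = Float.POSITIVE_INFINITY;
            float max = Float.NEGATIVE_INFINITY;
            for (int i = 0; i < vector.length; i++) {
                centered[i] = vector[i] - centroid[i]; // center on the common centroid
                min = Math.min(min, centered[i]);
                max = Math.max(max, centered[i]);
            }
            int maxLevel = (1 << bits) - 1;            // e.g. 1 level step for single-bit docs, 15 for half-byte queries
            float width = max > min ? (max - min) : 1f;
            byte[] quantized = new byte[vector.length];
            for (int i = 0; i < vector.length; i++) {
                int q = Math.round(((centered[i] - min) / width) * maxLevel);
                quantized[i] = (byte) Math.max(0, Math.min(maxLevel, q));
            }
            return quantized;
        }
    }

With bits = 1 each dimension collapses to a single bit (the roughly 32x memory reduction mentioned above); in the asymmetric scheme described later in this patch, query vectors would instead be quantized with bits = 4.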
--- docs/changelog/117994.yaml | 5 + rest-api-spec/build.gradle | 2 + .../search.vectors/41_knn_search_bbq_hnsw.yml | 66 +- .../search.vectors/42_knn_search_bbq_flat.yml | 66 +- server/src/main/java/module-info.java | 4 +- .../index/codec/vectors/BQVectorUtils.java | 14 + .../es816/ES816BinaryFlatVectorsScorer.java | 59 +- .../ES816BinaryQuantizedVectorsFormat.java | 2 +- ...ES816HnswBinaryQuantizedVectorsFormat.java | 17 +- .../es818/BinarizedByteVectorValues.java | 87 ++ .../es818/ES818BinaryFlatVectorsScorer.java | 188 ++++ .../ES818BinaryQuantizedVectorsFormat.java | 132 +++ .../ES818BinaryQuantizedVectorsReader.java | 412 ++++++++ .../ES818BinaryQuantizedVectorsWriter.java | 944 ++++++++++++++++++ ...ES818HnswBinaryQuantizedVectorsFormat.java | 145 +++ .../es818/OffHeapBinarizedVectorValues.java | 371 +++++++ .../es818/OptimizedScalarQuantizer.java | 246 +++++ .../vectors/DenseVectorFieldMapper.java | 8 +- .../action/search/SearchCapabilities.java | 2 + .../org.apache.lucene.codecs.KnnVectorsFormat | 2 + .../codec/vectors/BQVectorUtilsTests.java | 26 + .../es816/ES816BinaryFlatRWVectorsScorer.java | 256 +++++ .../ES816BinaryFlatVectorsScorerTests.java | 12 +- .../ES816BinaryQuantizedRWVectorsFormat.java | 52 + ...S816BinaryQuantizedVectorsFormatTests.java | 2 +- .../ES816BinaryQuantizedVectorsWriter.java | 4 +- ...816HnswBinaryQuantizedRWVectorsFormat.java | 55 + ...HnswBinaryQuantizedVectorsFormatTests.java | 2 +- ...S818BinaryQuantizedVectorsFormatTests.java | 181 ++++ ...HnswBinaryQuantizedVectorsFormatTests.java | 132 +++ .../es818/OptimizedScalarQuantizerTests.java | 136 +++ .../vectors/DenseVectorFieldMapperTests.java | 8 +- 32 files changed, 3501 insertions(+), 137 deletions(-) create mode 100644 docs/changelog/117994.yaml create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java rename server/src/{main => test}/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java (99%) create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java create mode 100644 
server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java diff --git a/docs/changelog/117994.yaml b/docs/changelog/117994.yaml new file mode 100644 index 0000000000000..603f2d855a11a --- /dev/null +++ b/docs/changelog/117994.yaml @@ -0,0 +1,5 @@ +pr: 117994 +summary: Even better(er) binary quantization +area: Vector Search +type: enhancement +issues: [] diff --git a/rest-api-spec/build.gradle b/rest-api-spec/build.gradle index e2af894eb0939..7347d9c1312dd 100644 --- a/rest-api-spec/build.gradle +++ b/rest-api-spec/build.gradle @@ -67,4 +67,6 @@ tasks.named("yamlRestCompatTestTransform").configure ({ task -> task.skipTest("logsdb/20_source_mapping/include/exclude is supported with stored _source", "no longer serialize source_mode") task.skipTest("logsdb/20_source_mapping/synthetic _source is default", "no longer serialize source_mode") task.skipTest("search/520_fetch_fields/fetch _seq_no via fields", "error code is changed from 5xx to 400 in 9.0") + task.skipTest("search.vectors/41_knn_search_bbq_hnsw/Test knn search", "Scoring has changed in latest versions") + task.skipTest("search.vectors/42_knn_search_bbq_flat/Test knn search", "Scoring has changed in latest versions") }) diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml index 188c155e4a836..5767c895fbe7e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_bbq_hnsw.yml @@ -11,20 +11,11 @@ setup: number_of_shards: 1 mappings: properties: - name: - type: keyword vector: type: dense_vector dims: 64 index: true - similarity: l2_norm - index_options: - type: bbq_hnsw - another_vector: - type: dense_vector - dims: 64 - index: true - similarity: l2_norm + similarity: max_inner_product index_options: type: bbq_hnsw @@ -33,9 +24,14 @@ setup: index: bbq_hnsw id: "1" body: - name: cow.jpg - vector: [300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0] - another_vector: [115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0] + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, 
-0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] # Flush in order to provoke a merge later - do: indices.flush: @@ -46,9 +42,14 @@ setup: index: bbq_hnsw id: "2" body: - name: moose.jpg - vector: [100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0] - another_vector: [50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120] + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] # Flush in order to provoke a merge later - do: indices.flush: @@ -60,8 +61,14 @@ setup: id: "3" body: name: rabbit.jpg - vector: [111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0] - another_vector: [11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0] + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] # Flush in order to provoke a merge later - do: indices.flush: @@ -73,20 +80,33 @@ setup: max_num_segments: 1 --- "Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" - do: search: index: bbq_hnsw body: knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, 
-156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] k: 3 num_candidates: 3 - # Depending on how things are distributed, docs 2 and 3 might be swapped - # here we verify that are last hit is always the worst one - - match: { hits.hits.2._id: "1" } - + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } --- "Test bad quantization parameters": - do: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml index ed7a8dd5df65d..dcdae04aeabb4 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/42_knn_search_bbq_flat.yml @@ -11,20 +11,11 @@ setup: number_of_shards: 1 mappings: properties: - name: - type: keyword vector: type: dense_vector dims: 64 index: true - similarity: l2_norm - index_options: - type: bbq_flat - another_vector: - type: dense_vector - dims: 64 - index: true - similarity: l2_norm + similarity: max_inner_product index_options: type: bbq_flat @@ -33,9 +24,14 @@ setup: index: bbq_flat id: "1" body: - name: cow.jpg - vector: [300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0, 230.0, 300.33, -34.8988, 15.555, -200.0] - another_vector: [115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0, 130.0, 115.0, -1.02, 15.555, -100.0] + vector: [0.077, 0.32 , -0.205, 0.63 , 0.032, 0.201, 0.167, -0.313, + 0.176, 0.531, -0.375, 0.334, -0.046, 0.078, -0.349, 0.272, + 0.307, -0.083, 0.504, 0.255, -0.404, 0.289, -0.226, -0.132, + -0.216, 0.49 , 0.039, 0.507, -0.307, 0.107, 0.09 , -0.265, + -0.285, 0.336, -0.272, 0.369, -0.282, 0.086, -0.132, 0.475, + -0.224, 0.203, 0.439, 0.064, 0.246, -0.396, 0.297, 0.242, + -0.028, 0.321, -0.022, -0.009, -0.001 , 0.031, -0.533, 0.45, + -0.683, 1.331, 0.194, -0.157, -0.1 , -0.279, -0.098, -0.176] # Flush in order to provoke a merge later - do: indices.flush: @@ -46,9 
+42,14 @@ setup: index: bbq_flat id: "2" body: - name: moose.jpg - vector: [100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0, -0.5, 100.0, -13, 14.8, -156.0] - another_vector: [50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120, -0.5, 50.0, -1, 1, 120] + vector: [0.196, 0.514, 0.039, 0.555, -0.042, 0.242, 0.463, -0.348, + -0.08 , 0.442, -0.067, -0.05 , -0.001, 0.298, -0.377, 0.048, + 0.307, 0.159, 0.278, 0.119, -0.057, 0.333, -0.289, -0.438, + -0.014, 0.361, -0.169, 0.292, -0.229, 0.123, 0.031, -0.138, + -0.139, 0.315, -0.216, 0.322, -0.445, -0.059, 0.071, 0.429, + -0.602, -0.142, 0.11 , 0.192, 0.259, -0.241, 0.181, -0.166, + 0.082, 0.107, -0.05 , 0.155, 0.011, 0.161, -0.486, 0.569, + -0.489, 0.901, 0.208, 0.011, -0.209, -0.153, -0.27 , -0.013] # Flush in order to provoke a merge later - do: indices.flush: @@ -59,9 +60,14 @@ setup: index: bbq_flat id: "3" body: - name: rabbit.jpg - vector: [111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0, 0.5, 111.3, -13.0, 14.8, -156.0] - another_vector: [11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0, -0.5, 11.0, 0, 12, 111.0] + vector: [0.139, 0.178, -0.117, 0.399, 0.014, -0.139, 0.347, -0.33 , + 0.139, 0.34 , -0.052, -0.052, -0.249, 0.327, -0.288, 0.049, + 0.464, 0.338, 0.516, 0.247, -0.104, 0.259, -0.209, -0.246, + -0.11 , 0.323, 0.091, 0.442, -0.254, 0.195, -0.109, -0.058, + -0.279, 0.402, -0.107, 0.308, -0.273, 0.019, 0.082, 0.399, + -0.658, -0.03 , 0.276, 0.041, 0.187, -0.331, 0.165, 0.017, + 0.171, -0.203, -0.198, 0.115, -0.007, 0.337, -0.444, 0.615, + -0.657, 1.285, 0.2 , -0.062, 0.038, 0.089, -0.068, -0.058] # Flush in order to provoke a merge later - do: indices.flush: @@ -73,19 +79,33 @@ setup: max_num_segments: 1 --- "Test knn search": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ optimized_scalar_quantization_bbq ] + test_runner_features: capabilities + reason: "BBQ scoring improved and changed with optimized_scalar_quantization_bbq" - do: search: index: bbq_flat body: knn: field: vector - query_vector: [ 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 90.0, -10, 14.8, -156.0, -0.5, 
90.0, -10, 14.8, -156.0] + query_vector: [0.128, 0.067, -0.08 , 0.395, -0.11 , -0.259, 0.473, -0.393, + 0.292, 0.571, -0.491, 0.444, -0.288, 0.198, -0.343, 0.015, + 0.232, 0.088, 0.228, 0.151, -0.136, 0.236, -0.273, -0.259, + -0.217, 0.359, -0.207, 0.352, -0.142, 0.192, -0.061, -0.17 , + -0.343, 0.189, -0.221, 0.32 , -0.301, -0.1 , 0.005, 0.232, + -0.344, 0.136, 0.252, 0.157, -0.13 , -0.244, 0.193, -0.034, + -0.12 , -0.193, -0.102, 0.252, -0.185, -0.167, -0.575, 0.582, + -0.426, 0.983, 0.212, 0.204, 0.03 , -0.276, -0.425, -0.158] k: 3 num_candidates: 3 - # Depending on how things are distributed, docs 2 and 3 might be swapped - # here we verify that are last hit is always the worst one - - match: { hits.hits.2._id: "1" } + - match: { hits.hits.0._id: "1" } + - match: { hits.hits.1._id: "3" } + - match: { hits.hits.2._id: "2" } --- "Test bad parameters": - do: diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 331a2bc0dddac..ff902dbede007 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -459,7 +459,9 @@ org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat, org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat, org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat, - org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat; + org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; provides org.apache.lucene.codecs.Codec with diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java index 5201e57179cc7..1aff06a175967 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java @@ -40,6 +40,20 @@ public static boolean isUnitVector(float[] v) { return Math.abs(l1norm - 1.0d) <= EPSILON; } + public static void packAsBinary(byte[] vector, byte[] packed) { + for (int i = 0; i < vector.length;) { + byte result = 0; + for (int j = 7; j >= 0 && i < vector.length; j--) { + assert vector[i] == 0 || vector[i] == 1; + result |= (byte) ((vector[i] & 1) << j); + ++i; + } + int index = ((i + 7) / 8) - 1; + assert index < packed.length; + packed[index] = result; + } + } + public static int discretize(int value, int bucket) { return ((value + (bucket - 1)) / bucket) * bucket; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java index 445bdadab2354..e85079e998c61 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorer.java @@ -48,13 +48,8 @@ class ES816BinaryFlatVectorsScorer implements FlatVectorsScorer { public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues - ) throws IOException { - if (vectorValues instanceof BinarizedByteVectorValues) { - throw new UnsupportedOperationException( - 
"getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format" - ); - } - return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + ) { + throw new UnsupportedOperationException(); } @Override @@ -90,61 +85,11 @@ public RandomVectorScorer getRandomVectorScorer( return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } - RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, - ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, - BinarizedByteVectorValues targetVectors - ) { - return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction); - } - @Override public String toString() { return "ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; } - /** Vector scorer supplier over binarized vector values */ - static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; - private final BinarizedByteVectorValues targetVectors; - private final VectorSimilarityFunction similarityFunction; - - BinarizedRandomVectorScorerSupplier( - ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, - BinarizedByteVectorValues targetVectors, - VectorSimilarityFunction similarityFunction - ) { - this.queryVectors = queryVectors; - this.targetVectors = targetVectors; - this.similarityFunction = similarityFunction; - } - - @Override - public RandomVectorScorer scorer(int ord) throws IOException { - byte[] vector = queryVectors.vectorValue(ord); - int quantizedSum = queryVectors.sumQuantizedValues(ord); - float distanceToCentroid = queryVectors.getCentroidDistance(ord); - float lower = queryVectors.getLower(ord); - float width = queryVectors.getWidth(ord); - float normVmC = 0f; - float vDotC = 0f; - if (similarityFunction != EUCLIDEAN) { - normVmC = queryVectors.getNormVmC(ord); - vDotC = queryVectors.getVDotC(ord); - } - BinaryQueryVector binaryQueryVector = new BinaryQueryVector( - vector, - new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC) - ); - return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); - } - - @Override - public RandomVectorScorerSupplier copy() throws IOException { - return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); - } - } - /** A binarized query representing its quantized form along with factors */ record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java index d864ec5dee8c5..61b6edc474d1f 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormat.java @@ -62,7 +62,7 @@ public ES816BinaryQuantizedVectorsFormat() { @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new ES816BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + throw new 
UnsupportedOperationException(); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java index 52f9f14b7bf97..1dbb4e432b188 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormat.java @@ -25,10 +25,8 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.hnsw.HnswGraph; import java.io.IOException; @@ -52,21 +50,18 @@ public class ES816HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat { * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. */ - private final int maxConn; + protected final int maxConn; /** * The number of candidate neighbors to track while searching the graph for each newly inserted * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} * for details. */ - private final int beamWidth; + protected final int beamWidth; /** The format for storing, reading, merging vectors on disk */ private static final FlatVectorsFormat flatVectorsFormat = new ES816BinaryQuantizedVectorsFormat(); - private final int numMergeWorkers; - private final TaskExecutor mergeExec; - /** Constructs a format using default graph construction parameters */ public ES816HnswBinaryQuantizedVectorsFormat() { this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); @@ -109,17 +104,11 @@ public ES816HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int num if (numMergeWorkers == 1 && mergeExec != null) { throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); } - this.numMergeWorkers = numMergeWorkers; - if (mergeExec != null) { - this.mergeExec = new TaskExecutor(mergeExec); - } else { - this.mergeExec = null; - } } @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + throw new UnsupportedOperationException(); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java new file mode 100644 index 0000000000000..cc1f7b85e0f78 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/BinarizedByteVectorValues.java @@ -0,0 +1,87 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +abstract class BinarizedByteVectorValues extends ByteVectorValues { + + /** + * Retrieve the corrective terms for the given vector ordinal. For the dot-product family of + * distances, the corrective terms are, in order + * + *
+ * <ul>
+ *   <li>the lower optimized interval
+ *   <li>the upper optimized interval
+ *   <li>the dot-product of the non-centered vector with the centroid
+ *   <li>the sum of quantized components
+ * </ul>
+ *
+ * For euclidean:
+ *
+ * <ul>
+ *   <li>the lower optimized interval
+ *   <li>the upper optimized interval
+ *   <li>the l2norm of the centered vector
+ *   <li>the sum of quantized components
+ * </ul>
+ * + * @param vectorOrd the vector ordinal + * @return the corrective terms + * @throws IOException if an I/O error occurs + */ + public abstract OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int vectorOrd) throws IOException; + + /** + * @return the quantizer used to quantize the vectors + */ + public abstract OptimizedScalarQuantizer getQuantizer(); + + public abstract float[] getCentroid() throws IOException; + + int discretizedDimensions() { + return BQVectorUtils.discretize(dimension(), 64); + } + + /** + * Return a {@link VectorScorer} for the given query vector. + * + * @param query the query vector + * @return a {@link VectorScorer} instance or null + */ + public abstract VectorScorer scorer(float[] query) throws IOException; + + @Override + public abstract BinarizedByteVectorValues copy() throws IOException; + + float getCentroidDP() throws IOException { + // this only gets executed on-merge + float[] centroid = getCentroid(); + return VectorUtil.dotProduct(centroid, centroid); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java new file mode 100644 index 0000000000000..7c7e470909eb3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryFlatVectorsScorer.java @@ -0,0 +1,188 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.simdvec.ESVectorUtil; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Vector scorer over binarized vector values */ +public class ES818BinaryFlatVectorsScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + + public ES818BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues) { + throw new UnsupportedOperationException( + "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format" + ); + } + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + float[] target + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) { + OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer(); + float[] centroid = binarizedVectors.getCentroid(); + // We make a copy as the quantization process mutates the input + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + if (similarityFunction == COSINE) { + VectorUtil.l2normalize(copy); + } + target = copy; + byte[] initial = new byte[target.length]; + byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8]; + OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(target, initial, (byte) 4, centroid); + BQSpaceUtils.transposeHalfByte(initial, quantized); + BinaryQueryVector queryVector = new BinaryQueryVector(quantized, queryCorrections); + return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction); + } + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + byte[] target + ) throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, + BinarizedByteVectorValues targetVectors + ) { + return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, 
similarityFunction); + } + + @Override + public String toString() { + return "ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + } + + /** Vector scorer supplier over binarized vector values */ + static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + private final ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + BinarizedRandomVectorScorerSupplier( + ES818BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + this.queryVectors = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] vector = queryVectors.vectorValue(ord); + OptimizedScalarQuantizer.QuantizationResult correctiveTerms = queryVectors.getCorrectiveTerms(ord); + BinaryQueryVector binaryQueryVector = new BinaryQueryVector(vector, correctiveTerms); + return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); + } + } + + /** A binarized query representing its quantized form along with factors */ + public record BinaryQueryVector(byte[] vector, OptimizedScalarQuantizer.QuantizationResult quantizationResult) {} + + /** Vector scorer over binarized vector values */ + public static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + private final BinaryQueryVector queryVector; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + public BinarizedRandomVectorScorer( + BinaryQueryVector queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + super(targetVectors); + this.queryVector = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public float score(int targetOrd) throws IOException { + byte[] quantizedQuery = queryVector.vector(); + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + OptimizedScalarQuantizer.QuantizationResult queryCorrections = queryVector.quantizationResult(); + OptimizedScalarQuantizer.QuantizationResult indexCorrections = targetVectors.getCorrectiveTerms(targetOrd); + float x1 = indexCorrections.quantizedComponentSum(); + float ax = indexCorrections.lowerInterval(); + // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + float lx = indexCorrections.upperInterval() - ax; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + float score = ax * ay * targetVectors.dimension() + ay * lx * x1 + ax * ly * y1 + lx * ly * qcDist; + // For euclidean, we need to invert the score and apply the additional correction, which is + // assumed to be the squared l2norm of the centroid centered vectors. 
+ if (similarityFunction == EUCLIDEAN) { + score = queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - 2 * score; + return Math.max(1 / (1f + score), 0); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + score += queryCorrections.additionalCorrection() + indexCorrections.additionalCorrection() - targetVectors.getCentroidDP(); + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java new file mode 100644 index 0000000000000..1dee9599f985f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormat.java @@ -0,0 +1,132 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; + +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + * Codec for encoding/decoding binary quantized vectors The binary quantization format used here + * is a per-vector optimized scalar quantization. Also see {@link + * org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer}. Some of key features are: + * + *
+ * <ul>
+ *   <li>Estimating the distance between two vectors using their centroid normalized distance. This
+ *       requires some additional corrective factors, but allows for centroid normalization to occur.
+ *   <li>Optimized scalar quantization to bit level of centroid normalized vectors.
+ *   <li>Asymmetric quantization of vectors, where query vectors are quantized to half-byte
+ *       precision (normalized to the centroid) and then compared directly against the single bit
+ *       quantized vectors in the index.
+ *   <li>Transforming the half-byte quantized query vectors in such a way that the comparison with
+ *       single bit vectors can be done with bit arithmetic.
+ * </ul>
+ *
+ * The format is stored in two files:
+ *
+ * <h2>.veb (vector data) file</h2>
+ *
+ * <p>Stores the binary quantized vectors in a flat format. Additionally, it stores each vector's
+ * corrective factors. At the end of the file, additional information is stored for vector ordinal
+ * to centroid ordinal mapping and sparse vector information.
+ *
+ * <ul>
+ *   <li>For each vector:
+ *       <ul>
+ *         <li>[byte] the binary quantized values, each byte holds 8 bits.
+ *         <li>[float] the optimized quantiles and an additional similarity dependent corrective factor.
+ *         <li>short the sum of the quantized components
+ *       </ul>
+ *   <li>After the vectors, sparse vector information keeping track of monotonic blocks.
+ * </ul>
+ *
+ * <h2>.vemb (vector metadata) file</h2>
+ *
+ * <p>Stores the metadata for the vectors. This includes the number of vectors, the number of
+ * dimensions, and file offset information.
+ *
+ * <ul>
+ *   <li>int the field number
+ *   <li>int the vector encoding ordinal
+ *   <li>int the vector similarity ordinal
+ *   <li>vint the vector dimensions
+ *   <li>vlong the offset to the vector data in the .veb file
+ *   <li>vlong the length of the vector data in the .veb file
+ *   <li>vint the number of vectors
+ *   <li>[float] the centroid
+ *   <li>float the centroid square magnitude
+ *   <li>The sparse vector information, if required, mapping vector ordinal to doc ID
+ * </ul>
+ */ +public class ES818BinaryQuantizedVectorsFormat extends FlatVectorsFormat { + + public static final String BINARIZED_VECTOR_COMPONENT = "BVEC"; + public static final String NAME = "ES818BinaryQuantizedVectorsFormat"; + + static final int VERSION_START = 0; + static final int VERSION_CURRENT = VERSION_START; + static final String META_CODEC_NAME = "ES818BinaryQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "ES818BinaryQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemb"; + static final String VECTOR_DATA_EXTENSION = "veb"; + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + /** Creates a new instance with the default number of vectors per cluster. */ + public ES818BinaryQuantizedVectorsFormat() { + super(NAME); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES818BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new ES818BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer); + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + + @Override + public String toString() { + return "ES818BinaryQuantizedVectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java new file mode 100644 index 0000000000000..8036b8314cdc1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsReader.java @@ -0,0 +1,412 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.ReadAdvice; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.SuppressForbidden; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readSimilarityFunction; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.readVectorEncoding; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +@SuppressForbidden(reason = "Lucene classes") +class ES818BinaryQuantizedVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(ES818BinaryQuantizedVectorsReader.class); + + private final Map fields = new HashMap<>(); + private final IndexInput quantizedVectorData; + private final FlatVectorsReader rawVectorsReader; + private final ES818BinaryFlatVectorsScorer vectorScorer; + + ES818BinaryQuantizedVectorsReader( + SegmentReadState state, + FlatVectorsReader rawVectorsReader, + ES818BinaryFlatVectorsScorer vectorsScorer + ) throws IOException { + super(vectorsScorer); + this.vectorScorer = vectorsScorer; + this.rawVectorsReader = rawVectorsReader; + int versionMeta = -1; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.META_EXTENSION + ); + boolean success = false; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + meta, + ES818BinaryQuantizedVectorsFormat.META_CODEC_NAME, + ES818BinaryQuantizedVectorsFormat.VERSION_START, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + quantizedVectorData = openDataInput( + state, + versionMeta, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + // Quantized vectors are accessed randomly from their node ID stored in the HNSW + // graph. 
+ state.context.withReadAdvice(ReadAdvice.RANDOM) + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta, info); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + int dimension = info.getVectorDimension(); + if (dimension != fieldEntry.dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + info.name + "\"; " + dimension + " != " + fieldEntry.dimension + ); + } + + int binaryDims = BQVectorUtils.discretize(dimension, 64) / 8; + long numQuantizedVectorBytes = Math.multiplyExact((binaryDims + (Float.BYTES * 3) + Short.BYTES), (long) fieldEntry.size); + if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { + throw new IllegalStateException( + "Binarized vector data length " + + fieldEntry.vectorDataLength + + " not matching size = " + + fieldEntry.size + + " * (binaryBytes=" + + binaryDims + + " + 14" + + ") = " + + numQuantizedVectorBytes + ); + } + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + return vectorScorer.getRandomVectorScorer( + fi.similarityFunction, + OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData + ), + target + ); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(quantizedVectorData); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FieldEntry fi = fields.get(field); + if (fi == null) { + return null; + } + if (fi.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + field + "\" is encoded as: " + fi.vectorEncoding + " expected: " + VectorEncoding.FLOAT32 + ); + } + OffHeapBinarizedVectorValues bvv = OffHeapBinarizedVectorValues.load( + fi.ordToDocDISIReaderConfiguration, + fi.dimension, + fi.size, + new OptimizedScalarQuantizer(fi.similarityFunction), + fi.similarityFunction, + vectorScorer, + fi.centroid, + fi.centroidDP, + fi.vectorDataOffset, + fi.vectorDataLength, + quantizedVectorData + ); + return new BinarizedVectorValues(rawVectorsReader.getFloatVectorValues(field), bvv); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + @Override + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + } 
+ + @Override + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + if (knnCollector.k() == 0) return; + final RandomVectorScorer scorer = getRandomVectorScorer(field, target); + if (scorer == null) return; + OrdinalTranslatedKnnCollector collector = new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + Bits acceptedOrds = scorer.getAcceptOrds(acceptDocs); + for (int i = 0; i < scorer.maxOrd(); i++) { + if (acceptedOrds == null || acceptedOrds.get(i)) { + collector.collect(i, scorer.score(i)); + collector.incVisitedCount(1); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(quantizedVectorData, rawVectorsReader); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += RamUsageEstimator.sizeOfMap(fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + size += rawVectorsReader.ramBytesUsed(); + return size; + } + + public float[] getCentroid(String field) { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry != null) { + return fieldEntry.centroid; + } + return null; + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context + ) throws IOException { + String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + int versionVectorData = CodecUtil.checkIndexHeader( + in, + codecName, + ES818BinaryQuantizedVectorsFormat.VERSION_START, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, + in + ); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + VectorEncoding vectorEncoding = readVectorEncoding(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction() + ); + } + return FieldEntry.create(input, vectorEncoding, info.getVectorSimilarityFunction()); + } + + private record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + int dimension, + int descritizedDimension, + long vectorDataOffset, + long vectorDataLength, + int size, + float[] centroid, + float centroidDP, + OrdToDocDISIReaderConfiguration ordToDocDISIReaderConfiguration + ) { + + static FieldEntry create(IndexInput input, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) + throws IOException { + int dimension = input.readVInt(); + long vectorDataOffset = input.readVLong(); + long vectorDataLength = input.readVLong(); + int size = input.readVInt(); + final float[] centroid; + float centroidDP = 0; + if (size > 0) { + centroid = new float[dimension]; + input.readFloats(centroid, 0, dimension); + centroidDP = Float.intBitsToFloat(input.readInt()); + } else { + 
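+ // A field with no vectors stores no centroid bytes in the metadata (see writeMeta), so the
+ // entry keeps a null centroid and a zero centroid dot-product.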
centroid = null; + } + OrdToDocDISIReaderConfiguration conf = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + return new FieldEntry( + similarityFunction, + vectorEncoding, + dimension, + BQVectorUtils.discretize(dimension, 64), + vectorDataOffset, + vectorDataLength, + size, + centroid, + centroidDP, + conf + ); + } + } + + /** Binarized vector values holding row and quantized vector values */ + protected static final class BinarizedVectorValues extends FloatVectorValues { + private final FloatVectorValues rawVectorValues; + private final BinarizedByteVectorValues quantizedVectorValues; + + BinarizedVectorValues(FloatVectorValues rawVectorValues, BinarizedByteVectorValues quantizedVectorValues) { + this.rawVectorValues = rawVectorValues; + this.quantizedVectorValues = quantizedVectorValues; + } + + @Override + public int dimension() { + return rawVectorValues.dimension(); + } + + @Override + public int size() { + return rawVectorValues.size(); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + return rawVectorValues.vectorValue(ord); + } + + @Override + public BinarizedVectorValues copy() throws IOException { + return new BinarizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return rawVectorValues.getAcceptOrds(acceptDocs); + } + + @Override + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); + } + + @Override + public DocIndexIterator iterator() { + return rawVectorValues.iterator(); + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + return quantizedVectorValues.scorer(query); + } + + BinarizedByteVectorValues getQuantizedVectorValues() throws IOException { + return quantizedVectorValues; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java new file mode 100644 index 0000000000000..02dda6a4a9da1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsWriter.java @@ -0,0 +1,944 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.FloatArrayList; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; +import static org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.BINARIZED_VECTOR_COMPONENT; +import static org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +@SuppressForbidden(reason = "Lucene classes") +public class ES818BinaryQuantizedVectorsWriter extends FlatVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = shallowSizeOfInstance(ES818BinaryQuantizedVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final List fields = new ArrayList<>(); + private final IndexOutput meta, binarizedVectorData; + private final FlatVectorsWriter rawVectorDelegate; + private final ES818BinaryFlatVectorsScorer vectorsScorer; + private boolean finished; + + /** + * Sole constructor + * + * @param vectorsScorer the scorer to use for scoring vectors + */ + protected ES818BinaryQuantizedVectorsWriter( + ES818BinaryFlatVectorsScorer vectorsScorer, + FlatVectorsWriter rawVectorDelegate, + SegmentWriteState state + ) throws IOException { + super(vectorsScorer); + this.vectorsScorer = vectorsScorer; + this.segmentWriteState = state; + String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + ES818BinaryQuantizedVectorsFormat.META_EXTENSION + ); + + String binarizedVectorDataFileName = IndexFileNames.segmentFileName( 
+ state.segmentInfo.name, + state.segmentSuffix, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_EXTENSION + ); + this.rawVectorDelegate = rawVectorDelegate; + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + binarizedVectorData = state.directory.createOutput(binarizedVectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + ES818BinaryQuantizedVectorsFormat.META_CODEC_NAME, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + CodecUtil.writeIndexHeader( + binarizedVectorData, + ES818BinaryQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + ES818BinaryQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + FieldWriter fieldWriter = new FieldWriter(fieldInfo, (FlatFieldVectorsWriter) rawVectorDelegate); + fields.add(fieldWriter); + return fieldWriter; + } + return rawVectorDelegate; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + // after raw vectors are written, normalize vectors for clustering and quantization + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + field.normalizeVectors(); + } + final float[] clusterCenter; + int vectorCount = field.flatFieldVectorsWriter.getVectors().size(); + clusterCenter = new float[field.dimensionSums.length]; + if (vectorCount > 0) { + for (int i = 0; i < field.dimensionSums.length; i++) { + clusterCenter[i] = field.dimensionSums[i] / vectorCount; + } + if (VectorSimilarityFunction.COSINE == field.fieldInfo.getVectorSimilarityFunction()) { + VectorUtil.l2normalize(clusterCenter); + } + } + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(field.fieldInfo.getVectorSimilarityFunction()); + if (sortMap == null) { + writeField(field, clusterCenter, maxDoc, quantizer); + } else { + writeSortingField(field, clusterCenter, maxDoc, sortMap, quantizer); + } + field.finish(); + } + } + + private void writeField(FieldWriter fieldData, float[] clusterCenter, int maxDoc, OptimizedScalarQuantizer quantizer) + throws IOException { + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeBinarizedVectors(fieldData, clusterCenter, quantizer); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = fieldData.getVectors().size() > 0 ? 
VectorUtil.dotProduct(clusterCenter, clusterCenter) : 0; + + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + clusterCenter, + centroidDp, + fieldData.getDocsWithFieldSet() + ); + } + + private void writeBinarizedVectors(FieldWriter fieldData, float[] clusterCenter, OptimizedScalarQuantizer scalarQuantizer) + throws IOException { + int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); + byte[] quantizationScratch = new byte[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int i = 0; i < fieldData.getVectors().size(); i++) { + float[] v = fieldData.getVectors().get(i); + OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize( + v, + quantizationScratch, + (byte) 1, + clusterCenter + ); + BQVectorUtils.packAsBinary(quantizationScratch, vector); + binarizedVectorData.writeBytes(vector, vector.length); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); + } + } + + private void writeSortingField( + FieldWriter fieldData, + float[] clusterCenter, + int maxDoc, + Sorter.DocMap sortMap, + OptimizedScalarQuantizer scalarQuantizer + ) throws IOException { + final int[] ordMap = new int[fieldData.getDocsWithFieldSet().cardinality()]; // new ord to old ord + + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + mapOldOrdToNewOrd(fieldData.getDocsWithFieldSet(), sortMap, null, ordMap, newDocsWithField); + + // write vector values + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + writeSortedBinarizedVectors(fieldData, clusterCenter, ordMap, scalarQuantizer); + long quantizedVectorLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + + float centroidDp = VectorUtil.dotProduct(clusterCenter, clusterCenter); + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, quantizedVectorLength, clusterCenter, centroidDp, newDocsWithField); + } + + private void writeSortedBinarizedVectors( + FieldWriter fieldData, + float[] clusterCenter, + int[] ordMap, + OptimizedScalarQuantizer scalarQuantizer + ) throws IOException { + int discreteDims = BQVectorUtils.discretize(fieldData.fieldInfo.getVectorDimension(), 64); + byte[] quantizationScratch = new byte[discreteDims]; + byte[] vector = new byte[discreteDims / 8]; + for (int ordinal : ordMap) { + float[] v = fieldData.getVectors().get(ordinal); + OptimizedScalarQuantizer.QuantizationResult corrections = scalarQuantizer.scalarQuantize( + v, + quantizationScratch, + (byte) 1, + clusterCenter + ); + BQVectorUtils.packAsBinary(quantizationScratch, vector); + binarizedVectorData.writeBytes(vector, vector.length); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) corrections.quantizedComponentSum()); + } + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long 
vectorDataOffset, + long vectorDataLength, + float[] clusterCenter, + float centroidDp, + DocsWithFieldSet docsWithField + ) throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVInt(field.getVectorDimension()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + int count = docsWithField.cardinality(); + meta.writeVInt(count); + if (count > 0) { + final ByteBuffer buffer = ByteBuffer.allocate(field.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(clusterCenter); + meta.writeBytes(buffer.array(), buffer.array().length); + meta.writeInt(Float.floatToIntBits(centroidDp)); + } + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, + meta, + binarizedVectorData, + count, + maxDoc, + docsWithField + ); + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (binarizedVectorData != null) { + CodecUtil.writeFooter(binarizedVectorData); + } + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + BinarizedFloatVectorValues binarizedVectorValues = new BinarizedFloatVectorValues( + floatVectorValues, + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()), + centroid + ); + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = writeBinarizedVectorData(binarizedVectorData, binarizedVectorValues); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + float centroidDp = docsWithField.cardinality() > 0 ? 
VectorUtil.dotProduct(centroid, centroid) : 0; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + centroidDp, + docsWithField + ); + } else { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + } + } + + static DocsWithFieldSet writeBinarizedVectorAndQueryData( + IndexOutput binarizedVectorData, + IndexOutput binarizedQueryData, + FloatVectorValues floatVectorValues, + float[] centroid, + OptimizedScalarQuantizer binaryQuantizer + ) throws IOException { + int discretizedDimension = BQVectorUtils.discretize(floatVectorValues.dimension(), 64); + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + byte[][] quantizationScratch = new byte[2][floatVectorValues.dimension()]; + byte[] toIndex = new byte[discretizedDimension / 8]; + byte[] toQuery = new byte[(discretizedDimension / 8) * BQSpaceUtils.B_QUERY]; + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write index vector + OptimizedScalarQuantizer.QuantizationResult[] r = binaryQuantizer.multiScalarQuantize( + floatVectorValues.vectorValue(iterator.index()), + quantizationScratch, + new byte[] { 1, 4 }, + centroid + ); + // pack and store document bit vector + BQVectorUtils.packAsBinary(quantizationScratch[0], toIndex); + binarizedVectorData.writeBytes(toIndex, toIndex.length); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].lowerInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].upperInterval())); + binarizedVectorData.writeInt(Float.floatToIntBits(r[0].additionalCorrection())); + assert r[0].quantizedComponentSum() >= 0 && r[0].quantizedComponentSum() <= 0xffff; + binarizedVectorData.writeShort((short) r[0].quantizedComponentSum()); + docsWithField.add(docV); + + // pack and store the 4bit query vector + BQSpaceUtils.transposeHalfByte(quantizationScratch[1], toQuery); + binarizedQueryData.writeBytes(toQuery, toQuery.length); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].lowerInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].upperInterval())); + binarizedQueryData.writeInt(Float.floatToIntBits(r[1].additionalCorrection())); + assert r[1].quantizedComponentSum() >= 0 && r[1].quantizedComponentSum() <= 0xffff; + binarizedQueryData.writeShort((short) r[1].quantizedComponentSum()); + } + return docsWithField; + } + + static DocsWithFieldSet writeBinarizedVectorData(IndexOutput output, BinarizedByteVectorValues binarizedByteVectorValues) + throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIndexIterator iterator = binarizedByteVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + // write vector + byte[] binaryValue = binarizedByteVectorValues.vectorValue(iterator.index()); + output.writeBytes(binaryValue, binaryValue.length); + OptimizedScalarQuantizer.QuantizationResult corrections = binarizedByteVectorValues.getCorrectiveTerms(iterator.index()); + output.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + output.writeInt(Float.floatToIntBits(corrections.upperInterval())); + output.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + output.writeShort((short) corrections.quantizedComponentSum()); + docsWithField.add(docV); + } + return 
docsWithField; + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final float[] centroid; + final float cDotC; + final float[] mergedCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = mergeAndRecalculateCentroids(mergeState, fieldInfo, mergedCentroid); + + // Don't need access to the random vectors, we can just use the merged + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + centroid = mergedCentroid; + cDotC = vectorCount > 0 ? VectorUtil.dotProduct(centroid, centroid) : 0; + if (segmentWriteState.infoStream.isEnabled(BINARIZED_VECTOR_COMPONENT)) { + segmentWriteState.infoStream.message(BINARIZED_VECTOR_COMPONENT, "Vectors' count:" + vectorCount); + } + return mergeOneFieldToIndex(segmentWriteState, fieldInfo, mergeState, centroid, cDotC); + } + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + private CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + SegmentWriteState segmentWriteState, + FieldInfo fieldInfo, + MergeState mergeState, + float[] centroid, + float cDotC + ) throws IOException { + long vectorDataOffset = binarizedVectorData.alignFilePointer(Float.BYTES); + final IndexOutput tempQuantizedVectorData = segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), + "temp", + segmentWriteState.context + ); + final IndexOutput tempScoreQuantizedVectorData = segmentWriteState.directory.createTempOutput( + binarizedVectorData.getName(), + "score_temp", + segmentWriteState.context + ); + IndexInput binarizedDataInput = null; + IndexInput binarizedScoreDataInput = null; + boolean success = false; + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + try { + FloatVectorValues floatVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + } + DocsWithFieldSet docsWithField = writeBinarizedVectorAndQueryData( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + floatVectorValues, + centroid, + quantizer + ); + CodecUtil.writeFooter(tempQuantizedVectorData); + IOUtils.close(tempQuantizedVectorData); + binarizedDataInput = segmentWriteState.directory.openInput(tempQuantizedVectorData.getName(), segmentWriteState.context); + binarizedVectorData.copyBytes(binarizedDataInput, binarizedDataInput.length() - CodecUtil.footerLength()); + long vectorDataLength = binarizedVectorData.getFilePointer() - vectorDataOffset; + CodecUtil.retrieveChecksum(binarizedDataInput); + CodecUtil.writeFooter(tempScoreQuantizedVectorData); + IOUtils.close(tempScoreQuantizedVectorData); + binarizedScoreDataInput = segmentWriteState.directory.openInput( + tempScoreQuantizedVectorData.getName(), + segmentWriteState.context + ); + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + centroid, + cDotC, + docsWithField + ); + success = true; + final IndexInput finalBinarizedDataInput = binarizedDataInput; + final IndexInput finalBinarizedScoreDataInput = binarizedScoreDataInput; + OffHeapBinarizedVectorValues vectorValues = new OffHeapBinarizedVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + centroid, + cDotC, + 
quantizer, + fieldInfo.getVectorSimilarityFunction(), + vectorsScorer, + finalBinarizedDataInput + ); + RandomVectorScorerSupplier scorerSupplier = vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + new OffHeapBinarizedQueryVectorValues( + finalBinarizedScoreDataInput, + fieldInfo.getVectorDimension(), + docsWithField.cardinality() + ), + vectorValues + ); + return new BinarizedCloseableRandomVectorScorerSupplier(scorerSupplier, vectorValues, () -> { + IOUtils.close(finalBinarizedDataInput, finalBinarizedScoreDataInput); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorData.getName(), + tempScoreQuantizedVectorData.getName() + ); + }); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException( + tempQuantizedVectorData, + tempScoreQuantizedVectorData, + binarizedDataInput, + binarizedScoreDataInput + ); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, + tempQuantizedVectorData.getName(), + tempScoreQuantizedVectorData.getName() + ); + } + } + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, binarizedVectorData, rawVectorDelegate); + } + + static float[] getCentroid(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof ES818BinaryQuantizedVectorsReader reader) { + return reader.getCentroid(fieldName); + } + return null; + } + + static int mergeAndRecalculateCentroids(MergeState mergeState, FieldInfo fieldInfo, float[] mergedCentroid) throws IOException { + boolean recalculate = false; + int totalVectorCount = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null || knnVectorsReader.getFloatVectorValues(fieldInfo.name) == null) { + continue; + } + float[] centroid = getCentroid(knnVectorsReader, fieldInfo.name); + int vectorCount = knnVectorsReader.getFloatVectorValues(fieldInfo.name).size(); + if (vectorCount == 0) { + continue; + } + totalVectorCount += vectorCount; + // If there aren't centroids, or previously clustered with more than one cluster + // or if there are deleted docs, we must recalculate the centroid + if (centroid == null || mergeState.liveDocs[i] != null) { + recalculate = true; + break; + } + for (int j = 0; j < centroid.length; j++) { + mergedCentroid[j] += centroid[j] * vectorCount; + } + } + if (recalculate) { + return calculateCentroid(mergeState, fieldInfo, mergedCentroid); + } else { + for (int j = 0; j < mergedCentroid.length; j++) { + mergedCentroid[j] = mergedCentroid[j] / totalVectorCount; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(mergedCentroid); + } + return totalVectorCount; + } + } + + static int calculateCentroid(MergeState mergeState, FieldInfo fieldInfo, float[] centroid) throws IOException { + assert fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32); + // clear out the centroid + Arrays.fill(centroid, 0); + int count = 0; + for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { + KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; + if (knnVectorsReader == null) continue; + FloatVectorValues vectorValues = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + if (vectorValues == null) { + continue; + } + 
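+ // Accumulate per-dimension component sums over every live vector in this reader; the mean
+ // (and cosine re-normalization) is applied once all readers have been consumed.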
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) { + ++count; + float[] vector = vectorValues.vectorValue(iterator.index()); + // TODO Panama sum + for (int j = 0; j < vector.length; j++) { + centroid[j] += vector[j]; + } + } + } + if (count == 0) { + return count; + } + // TODO Panama div + for (int i = 0; i < centroid.length; i++) { + centroid[i] /= count; + } + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(centroid); + } + return count; + } + + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + // the field tracks the delegate field usage + total += field.ramBytesUsed(); + } + return total; + } + + static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private boolean finished; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; + private final float[] dimensionSums; + private final FloatArrayList magnitudes = new FloatArrayList(); + + FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter flatFieldVectorsWriter) { + this.fieldInfo = fieldInfo; + this.flatFieldVectorsWriter = flatFieldVectorsWriter; + this.dimensionSums = new float[fieldInfo.getVectorDimension()]; + } + + @Override + public List getVectors() { + return flatFieldVectorsWriter.getVectors(); + } + + public void normalizeVectors() { + for (int i = 0; i < flatFieldVectorsWriter.getVectors().size(); i++) { + float[] vector = flatFieldVectorsWriter.getVectors().get(i); + float magnitude = magnitudes.get(i); + for (int j = 0; j < vector.length; j++) { + vector[j] /= magnitude; + } + } + } + + @Override + public DocsWithFieldSet getDocsWithFieldSet() { + return flatFieldVectorsWriter.getDocsWithFieldSet(); + } + + @Override + public void finish() throws IOException { + if (finished) { + return; + } + assert flatFieldVectorsWriter.isFinished(); + finished = true; + } + + @Override + public boolean isFinished() { + return finished && flatFieldVectorsWriter.isFinished(); + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + flatFieldVectorsWriter.addValue(docID, vectorValue); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + float dp = VectorUtil.dotProduct(vectorValue, vectorValue); + float divisor = (float) Math.sqrt(dp); + magnitudes.add(divisor); + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += (vectorValue[i] / divisor); + } + } else { + for (int i = 0; i < vectorValue.length; i++) { + dimensionSums[i] += vectorValue[i]; + } + } + } + + @Override + public float[] copyValue(float[] vectorValue) { + throw new UnsupportedOperationException(); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += flatFieldVectorsWriter.ramBytesUsed(); + size += magnitudes.ramBytesUsed(); + return size; + } + } + + // When accessing vectorValue method, targerOrd here means a row ordinal. 
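+ // Each entry holds the transposed half-byte (4-bit) query vector, BQSpaceUtils.B_QUERY times the
+ // packed binary length, followed by three float corrective terms and an unsigned short component
+ // sum, mirroring what writeBinarizedVectorAndQueryData emits.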
+ static class OffHeapBinarizedQueryVectorValues { + private final IndexInput slice; + private final int dimension; + private final int size; + protected final byte[] binaryValue; + protected final ByteBuffer byteBuffer; + private final int byteSize; + protected final float[] correctiveValues; + private int lastOrd = -1; + private int quantizedComponentSum; + + OffHeapBinarizedQueryVectorValues(IndexInput data, int dimension, int size) { + this.slice = data; + this.dimension = dimension; + this.size = size; + // 4x the quantized binary dimensions + int binaryDimensions = (BQVectorUtils.discretize(dimension, 64) / 8) * BQSpaceUtils.B_QUERY; + this.byteBuffer = ByteBuffer.allocate(binaryDimensions); + this.binaryValue = byteBuffer.array(); + // + 1 for the quantized sum + this.correctiveValues = new float[3]; + this.byteSize = binaryDimensions + Float.BYTES * 3 + Short.BYTES; + } + + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + vectorValue(targetOrd); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + + public int size() { + return size; + } + + public int dimension() { + return dimension; + } + + public OffHeapBinarizedQueryVectorValues copy() throws IOException { + return new OffHeapBinarizedQueryVectorValues(slice.clone(), dimension, size); + } + + public IndexInput getSlice() { + return slice; + } + + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(binaryValue, 0, binaryValue.length); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + lastOrd = targetOrd; + return binaryValue; + } + } + + static class BinarizedFloatVectorValues extends BinarizedByteVectorValues { + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final byte[] binarized; + private final byte[] initQuantized; + private final float[] centroid; + private final FloatVectorValues values; + private final OptimizedScalarQuantizer quantizer; + + private int lastOrd = -1; + + BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer, float[] centroid) { + this.values = delegate; + this.quantizer = quantizer; + this.binarized = new byte[BQVectorUtils.discretize(delegate.dimension(), 64) / 8]; + this.initQuantized = new byte[delegate.dimension()]; + this.centroid = centroid; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd + ); + } + return corrections; + } + + @Override + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + binarize(ord); + lastOrd = ord; + } + return binarized; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public OptimizedScalarQuantizer getQuantizer() { + throw new UnsupportedOperationException(); + } + + @Override + public float[] getCentroid() throws IOException { + return centroid; + } + + @Override + public 
int size() { + return values.size(); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public BinarizedByteVectorValues copy() throws IOException { + return new BinarizedFloatVectorValues(values.copy(), quantizer, centroid); + } + + private void binarize(int ord) throws IOException { + corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, (byte) 1, centroid); + BQVectorUtils.packAsBinary(initQuantized, binarized); + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + } + + static class BinarizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { + private final RandomVectorScorerSupplier supplier; + private final KnnVectorValues vectorValues; + private final Closeable onClose; + + BinarizedCloseableRandomVectorScorerSupplier(RandomVectorScorerSupplier supplier, KnnVectorValues vectorValues, Closeable onClose) { + this.supplier = supplier; + this.onClose = onClose; + this.vectorValues = vectorValues; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + return supplier.scorer(ord); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return vectorValues.size(); + } + } + + static final class NormalizedFloatVectorValues extends FloatVectorValues { + private final FloatVectorValues values; + private final float[] normalizedVector; + + NormalizedFloatVectorValues(FloatVectorValues values) { + this.values = values; + this.normalizedVector = new float[values.dimension()]; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + return normalizedVector; + } + + @Override + public DocIndexIterator iterator() { + return values.iterator(); + } + + @Override + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java new file mode 100644 index 0000000000000..56942017c3cef --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormat.java @@ -0,0 +1,145 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.util.hnsw.HnswGraph; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; +import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +public class ES818HnswBinaryQuantizedVectorsFormat extends KnnVectorsFormat { + + public static final String NAME = "ES818HnswBinaryQuantizedVectorsFormat"; + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. + */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} + * for details. + */ + private final int beamWidth; + + /** The format for storing, reading, merging vectors on disk */ + private static final FlatVectorsFormat flatVectorsFormat = new ES818BinaryQuantizedVectorsFormat(); + + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + /** Constructs a format using default graph construction parameters */ + public ES818HnswBinaryQuantizedVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters and scalar quantization. 
+ * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public ES818HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn + ); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth + ); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return MAX_DIMS_COUNT; + } + + @Override + public String toString() { + return "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java new file mode 100644 index 0000000000000..72333169b39b5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OffHeapBinarizedVectorValues.java @@ -0,0 +1,371 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.packed.DirectMonotonicReader; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; +import java.nio.ByteBuffer; + +/** Binarized vector values loaded from off-heap */ +abstract class OffHeapBinarizedVectorValues extends BinarizedByteVectorValues { + + final int dimension; + final int size; + final int numBytes; + final VectorSimilarityFunction similarityFunction; + final FlatVectorsScorer vectorsScorer; + + final IndexInput slice; + final byte[] binaryValue; + final ByteBuffer byteBuffer; + final int byteSize; + private int lastOrd = -1; + final float[] correctiveValues; + int quantizedComponentSum; + final OptimizedScalarQuantizer binaryQuantizer; + final float[] centroid; + final float centroidDp; + private final int discretizedDimensions; + + OffHeapBinarizedVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer quantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) { + this.dimension = dimension; + this.size = size; + this.similarityFunction = similarityFunction; + this.vectorsScorer = vectorsScorer; + this.slice = slice; + this.centroid = centroid; + this.centroidDp = centroidDp; + this.numBytes = BQVectorUtils.discretize(dimension, 64) / 8; + this.correctiveValues = new float[3]; + this.byteSize = numBytes + (Float.BYTES * 3) + Short.BYTES; + this.byteBuffer = ByteBuffer.allocate(numBytes); + this.binaryValue = byteBuffer.array(); + this.binaryQuantizer = quantizer; + this.discretizedDimensions = BQVectorUtils.discretize(dimension, 64); + } + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return size; + } + + @Override + public byte[] vectorValue(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return binaryValue; + } + slice.seek((long) targetOrd * byteSize); + slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), numBytes); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + lastOrd = targetOrd; + return binaryValue; + } + + @Override + public int discretizedDimensions() { + return discretizedDimensions; + } + + @Override + public float getCentroidDP() { + return centroidDp; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int targetOrd) throws IOException { + if (lastOrd == targetOrd) { + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + slice.seek(((long) targetOrd * byteSize) + numBytes); + slice.readFloats(correctiveValues, 0, 3); + quantizedComponentSum = Short.toUnsignedInt(slice.readShort()); + return new OptimizedScalarQuantizer.QuantizationResult( + correctiveValues[0], + correctiveValues[1], + correctiveValues[2], + quantizedComponentSum + ); + } + + @Override + public 
OptimizedScalarQuantizer getQuantizer() { + return binaryQuantizer; + } + + @Override + public float[] getCentroid() { + return centroid; + } + + @Override + public int getVectorByteLength() { + return numBytes; + } + + static OffHeapBinarizedVectorValues load( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + OptimizedScalarQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + float[] centroid, + float centroidDp, + long quantizedVectorDataOffset, + long quantizedVectorDataLength, + IndexInput vectorData + ) throws IOException { + if (configuration.isEmpty()) { + return new EmptyOffHeapVectorValues(dimension, similarityFunction, vectorsScorer); + } + assert centroid != null; + IndexInput bytesSlice = vectorData.slice("quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); + if (configuration.isDense()) { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + bytesSlice + ); + } else { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + vectorData, + similarityFunction, + vectorsScorer, + bytesSlice + ); + } + } + + /** Dense off-heap binarized vector values */ + static class DenseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + DenseOffHeapVectorValues( + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer binaryQuantizer, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) { + super(dimension, size, centroid, centroidDp, binaryQuantizer, similarityFunction, vectorsScorer, slice); + } + + @Override + public DenseOffHeapVectorValues copy() throws IOException { + return new DenseOffHeapVectorValues( + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + similarityFunction, + vectorsScorer, + slice.clone() + ); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + } + + /** Sparse off-heap binarized vector values */ + private static class SparseOffHeapVectorValues extends OffHeapBinarizedVectorValues { + private final DirectMonotonicReader ordToDoc; + private final IndexedDISI disi; + // dataIn was used to init a new IndexedDIS for #randomAccess() + private final IndexInput dataIn; + private final OrdToDocDISIReaderConfiguration configuration; + + SparseOffHeapVectorValues( + OrdToDocDISIReaderConfiguration configuration, + int dimension, + int size, + float[] centroid, + float centroidDp, + OptimizedScalarQuantizer binaryQuantizer, + IndexInput dataIn, + VectorSimilarityFunction similarityFunction, + FlatVectorsScorer vectorsScorer, + IndexInput slice + ) throws IOException { + super(dimension, size, centroid, centroidDp, binaryQuantizer, similarityFunction, vectorsScorer, slice); + this.configuration 
= configuration; + this.dataIn = dataIn; + this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); + this.disi = configuration.getIndexedDISI(dataIn); + } + + @Override + public SparseOffHeapVectorValues copy() throws IOException { + return new SparseOffHeapVectorValues( + configuration, + dimension, + size, + centroid, + centroidDp, + binaryQuantizer, + dataIn, + similarityFunction, + vectorsScorer, + slice.clone() + ); + } + + @Override + public int ordToDoc(int ord) { + return (int) ordToDoc.get(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size; + } + }; + } + + @Override + public DocIndexIterator iterator() { + return IndexedDISI.asDocIndexIterator(disi); + } + + @Override + public VectorScorer scorer(float[] target) throws IOException { + SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); + RandomVectorScorer scorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); + return new VectorScorer() { + @Override + public float score() throws IOException { + return scorer.score(iterator.index()); + } + + @Override + public DocIdSetIterator iterator() { + return iterator; + } + }; + } + } + + private static class EmptyOffHeapVectorValues extends OffHeapBinarizedVectorValues { + EmptyOffHeapVectorValues(int dimension, VectorSimilarityFunction similarityFunction, FlatVectorsScorer vectorsScorer) { + super(dimension, 0, null, Float.NaN, null, similarityFunction, vectorsScorer, null); + } + + @Override + public DocIndexIterator iterator() { + return createDenseIterator(); + } + + @Override + public DenseOffHeapVectorValues copy() { + throw new UnsupportedOperationException(); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return null; + } + + @Override + public VectorScorer scorer(float[] target) { + return null; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java new file mode 100644 index 0000000000000..d5ed38cb5a0e1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizer.java @@ -0,0 +1,246 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; + +class OptimizedScalarQuantizer { + // The initial interval is set to the minimum MSE grid for each number of bits + // these starting points are derived from the optimal MSE grid for a uniform distribution + static final float[][] MINIMUM_MSE_GRID = new float[][] { + { -0.798f, 0.798f }, + { -1.493f, 1.493f }, + { -2.051f, 2.051f }, + { -2.514f, 2.514f }, + { -2.916f, 2.916f }, + { -3.278f, 3.278f }, + { -3.611f, 3.611f }, + { -3.922f, 3.922f } }; + private static final float DEFAULT_LAMBDA = 0.1f; + private static final int DEFAULT_ITERS = 5; + private final VectorSimilarityFunction similarityFunction; + private final float lambda; + private final int iters; + + OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) { + this.similarityFunction = similarityFunction; + this.lambda = lambda; + this.iters = iters; + } + + OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) { + this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS); + } + + public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {} + + public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destinations, byte[] bits, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert bits.length == destinations.length; + float[] intervalScratch = new float[2]; + double vecMean = 0; + double vecVar = 0; + float norm2 = 0; + float centroidDot = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + for (int i = 0; i < vector.length; ++i) { + if (similarityFunction != EUCLIDEAN) { + centroidDot += vector[i] * centroid[i]; + } + vector[i] = vector[i] - centroid[i]; + min = Math.min(min, vector[i]); + max = Math.max(max, vector[i]); + norm2 += (vector[i] * vector[i]); + double delta = vector[i] - vecMean; + vecMean += delta / (i + 1); + vecVar += delta * (vector[i] - vecMean); + } + vecVar /= vector.length; + double vecStd = Math.sqrt(vecVar); + QuantizationResult[] results = new QuantizationResult[bits.length]; + for (int i = 0; i < bits.length; ++i) { + assert bits[i] > 0 && bits[i] <= 8; + int points = (1 << bits[i]); + // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds + intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][0] + vecMean) * vecStd, min, max); + intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits[i] - 1][1] + vecMean) * vecStd, min, max); + optimizeIntervals(intervalScratch, vector, norm2, points); + float nSteps = ((1 << bits[i]) - 1); + float a = intervalScratch[0]; + float b = intervalScratch[1]; + float step = (b - a) / nSteps; + int sumQuery = 0; + // Now we have the optimized intervals, quantize the vector + for (int h = 0; h < vector.length; h++) { + float xi = (float) clamp(vector[h], a, b); + int assignment = Math.round((xi - a) / step); + sumQuery += assignment; + destinations[i][h] = (byte) assignment; + } + results[i] = new QuantizationResult( + intervalScratch[0], + intervalScratch[1], + similarityFunction == EUCLIDEAN ? 
norm2 : centroidDot, + sumQuery + ); + } + return results; + } + + public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) { + assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector); + assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid); + assert vector.length <= destination.length; + assert bits > 0 && bits <= 8; + float[] intervalScratch = new float[2]; + int points = 1 << bits; + double vecMean = 0; + double vecVar = 0; + float norm2 = 0; + float centroidDot = 0; + float min = Float.MAX_VALUE; + float max = -Float.MAX_VALUE; + for (int i = 0; i < vector.length; ++i) { + if (similarityFunction != EUCLIDEAN) { + centroidDot += vector[i] * centroid[i]; + } + vector[i] = vector[i] - centroid[i]; + min = Math.min(min, vector[i]); + max = Math.max(max, vector[i]); + norm2 += (vector[i] * vector[i]); + double delta = vector[i] - vecMean; + vecMean += delta / (i + 1); + vecVar += delta * (vector[i] - vecMean); + } + vecVar /= vector.length; + double vecStd = Math.sqrt(vecVar); + // Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds + intervalScratch[0] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][0] + vecMean) * vecStd, min, max); + intervalScratch[1] = (float) clamp((MINIMUM_MSE_GRID[bits - 1][1] + vecMean) * vecStd, min, max); + optimizeIntervals(intervalScratch, vector, norm2, points); + float nSteps = ((1 << bits) - 1); + // Now we have the optimized intervals, quantize the vector + float a = intervalScratch[0]; + float b = intervalScratch[1]; + float step = (b - a) / nSteps; + int sumQuery = 0; + for (int h = 0; h < vector.length; h++) { + float xi = (float) clamp(vector[h], a, b); + int assignment = Math.round((xi - a) / step); + sumQuery += assignment; + destination[h] = (byte) assignment; + } + return new QuantizationResult( + intervalScratch[0], + intervalScratch[1], + similarityFunction == EUCLIDEAN ? norm2 : centroidDot, + sumQuery + ); + } + + /** + * Compute the loss of the vector given the interval. Effectively, we are computing the MSE of a dequantized vector with the raw + * vector. + * @param vector raw vector + * @param interval interval to quantize the vector + * @param points number of quantization points + * @param norm2 squared norm of the vector + * @return the loss + */ + private double loss(float[] vector, float[] interval, int points, float norm2) { + double a = interval[0]; + double b = interval[1]; + double step = ((b - a) / (points - 1.0F)); + double stepInv = 1.0 / step; + double xe = 0.0; + double e = 0.0; + for (double xi : vector) { + // this is quantizing and then dequantizing the vector + double xiq = (a + step * Math.round((clamp(xi, a, b) - a) * stepInv)); + // how much does the de-quantized value differ from the original value + xe += xi * (xi - xiq); + e += (xi - xiq) * (xi - xiq); + } + return (1.0 - lambda) * xe * xe / norm2 + lambda * e; + } + + /** + * Optimize the quantization interval for the given vector. This is done via a coordinate descent trying to minimize the quantization + * loss. Note, the loss is not always guaranteed to decrease, so we have a maximum number of iterations and will exit early if the + * loss increases. 
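+     * In outline (descriptive note): each iteration fixes the grid assignment of every centered component under the
+     * current interval, accumulates a small 2x2 linear system from those assignments, and solves it for new interval
+     * endpoints, exiting early if the system is singular, the endpoints stop moving, or the loss gets worse.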
+ * @param initInterval initial interval, the optimized interval will be stored here + * @param vector raw vector + * @param norm2 squared norm of the vector + * @param points number of quantization points + */ + private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) { + double initialLoss = loss(vector, initInterval, points, norm2); + final float scale = (1.0f - lambda) / norm2; + if (Float.isFinite(scale) == false) { + return; + } + for (int i = 0; i < iters; ++i) { + float a = initInterval[0]; + float b = initInterval[1]; + float stepInv = (points - 1.0f) / (b - a); + // calculate the grid points for coordinate descent + double daa = 0; + double dab = 0; + double dbb = 0; + double dax = 0; + double dbx = 0; + for (float xi : vector) { + float k = Math.round((clamp(xi, a, b) - a) * stepInv); + float s = k / (points - 1); + daa += (1.0 - s) * (1.0 - s); + dab += (1.0 - s) * s; + dbb += s * s; + dax += xi * (1.0 - s); + dbx += xi * s; + } + double m0 = scale * dax * dax + lambda * daa; + double m1 = scale * dax * dbx + lambda * dab; + double m2 = scale * dbx * dbx + lambda * dbb; + // its possible that the determinant is 0, in which case we can't update the interval + double det = m0 * m2 - m1 * m1; + if (det == 0) { + return; + } + float aOpt = (float) ((m2 * dax - m1 * dbx) / det); + float bOpt = (float) ((m0 * dbx - m1 * dax) / det); + // If there is no change in the interval, we can stop + if ((Math.abs(initInterval[0] - aOpt) < 1e-8 && Math.abs(initInterval[1] - bOpt) < 1e-8)) { + return; + } + double newLoss = loss(vector, new float[] { aOpt, bOpt }, points, norm2); + // If the new loss is worse, don't update the interval and exit + // This optimization, unlike kMeans, does not always converge to better loss + // So exit if we are getting worse + if (newLoss > initialLoss) { + return; + } + // Update the interval and go again + initInterval[0] = aOpt; + initInterval[1] = bOpt; + initialLoss = newLoss; + } + } + + private static double clamp(double x, double a, double b) { + return Math.min(Math.max(x, a), b); + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 0a6a24f727572..d780faad96f2d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -46,8 +46,8 @@ import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; import org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat; import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat; -import org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat; -import org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; +import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; import org.elasticsearch.index.fielddata.FieldDataContext; import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.mapper.ArraySourceValueFetcher; @@ -1788,7 +1788,7 @@ static class BBQHnswIndexOptions extends IndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT; - return new ES816HnswBinaryQuantizedVectorsFormat(m, 
efConstruction); + return new ES818HnswBinaryQuantizedVectorsFormat(m, efConstruction); } @Override @@ -1836,7 +1836,7 @@ static class BBQFlatIndexOptions extends IndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT; - return new ES816BinaryQuantizedVectorsFormat(); + return new ES818BinaryQuantizedVectorsFormat(); } @Override diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 794b30aa5aab2..57980321bdc3d 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -44,6 +44,7 @@ private SearchCapabilities() {} private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim_with_bugfix"; private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs"; + private static final String OPTIMIZED_SCALAR_QUANTIZATION_BBQ = "optimized_scalar_quantization_bbq"; public static final Set CAPABILITIES; static { @@ -55,6 +56,7 @@ private SearchCapabilities() {} capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER); capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT); capabilities.add(RANDOM_SAMPLER_WITH_SCORED_SUBAGGS); + capabilities.add(OPTIMIZED_SCALAR_QUANTIZATION_BBQ); if (MultiDenseVectorFieldMapper.FEATURE_FLAG.isEnabled()) { capabilities.add(MULTI_DENSE_VECTOR_FIELD_MAPPER); capabilities.add(MULTI_DENSE_VECTOR_SCRIPT_ACCESS); diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 389555e60b43b..cef8d09980814 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -5,3 +5,5 @@ org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat org.elasticsearch.index.codec.vectors.ES815BitFlatVectorFormat org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java index 9f9114c70b6db..270ad54e9a962 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/BQVectorUtilsTests.java @@ -38,6 +38,32 @@ public static int popcount(byte[] a, int aOffset, byte[] b, int length) { private static float DELTA = Float.MIN_VALUE; + public void testPackAsBinary() { + // 5 bits + byte[] toPack = new byte[] { 1, 1, 0, 0, 1 }; + byte[] packed = new byte[1]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001000 }, packed); + + // 8 bits + toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0 }; + packed = new byte[1]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001010 }, packed); + + // 10 bits + toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 
0, 1, 1 }; + packed = new byte[2]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11000000 }, packed); + + // 16 bits + toPack = new byte[] { 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0 }; + packed = new byte[2]; + BQVectorUtils.packAsBinary(toPack, packed); + assertArrayEquals(new byte[] { (byte) 0b11001010, (byte) 0b11100110 }, packed); + } + public void testPadFloat() { assertArrayEquals(new float[] { 1, 2, 3, 4 }, BQVectorUtils.pad(new float[] { 1, 2, 3, 4 }, 4), DELTA); assertArrayEquals(new float[] { 1, 2, 3, 4 }, BQVectorUtils.pad(new float[] { 1, 2, 3, 4 }, 3), DELTA); diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java new file mode 100644 index 0000000000000..0bebe16f468ce --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatRWVectorsScorer.java @@ -0,0 +1,256 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. 
+ */ +package org.elasticsearch.index.codec.vectors.es816; + +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.elasticsearch.index.codec.vectors.BQSpaceUtils; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.simdvec.ESVectorUtil; + +import java.io.IOException; + +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; + +/** Vector scorer over binarized vector values */ +class ES816BinaryFlatRWVectorsScorer implements FlatVectorsScorer { + private final FlatVectorsScorer nonQuantizedDelegate; + + ES816BinaryFlatRWVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) { + this.nonQuantizedDelegate = nonQuantizedDelegate; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues) { + throw new UnsupportedOperationException( + "getRandomVectorScorerSupplier(VectorSimilarityFunction,RandomAccessVectorValues) not implemented for binarized format" + ); + } + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + float[] target + ) throws IOException { + if (vectorValues instanceof BinarizedByteVectorValues binarizedVectors) { + BinaryQuantizer quantizer = binarizedVectors.getQuantizer(); + float[] centroid = binarizedVectors.getCentroid(); + // FIXME: precompute this once? 
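+            // Assumption (not stated in this patch): BQVectorUtils.discretize(length, 64) rounds the dimension count up
+            // to the next multiple of 64 so the packed binary codes fill whole 64-bit words.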
+ int discretizedDimensions = BQVectorUtils.discretize(target.length, 64); + if (similarityFunction == COSINE) { + float[] copy = ArrayUtil.copyOfSubArray(target, 0, target.length); + VectorUtil.l2normalize(copy); + target = copy; + } + byte[] quantized = new byte[BQSpaceUtils.B_QUERY * discretizedDimensions / 8]; + BinaryQuantizer.QueryFactors factors = quantizer.quantizeForQuery(target, quantized, centroid); + BinaryQueryVector queryVector = new BinaryQueryVector(quantized, factors); + return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction); + } + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + KnnVectorValues vectorValues, + byte[] target + ) throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, + ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues scoringVectors, + BinarizedByteVectorValues targetVectors + ) { + return new BinarizedRandomVectorScorerSupplier(scoringVectors, targetVectors, similarityFunction); + } + + @Override + public String toString() { + return "ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=" + nonQuantizedDelegate + ")"; + } + + /** Vector scorer supplier over binarized vector values */ + static class BinarizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + private final ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors; + private final BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + BinarizedRandomVectorScorerSupplier( + ES816BinaryQuantizedVectorsWriter.OffHeapBinarizedQueryVectorValues queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + this.queryVectors = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] vector = queryVectors.vectorValue(ord); + int quantizedSum = queryVectors.sumQuantizedValues(ord); + float distanceToCentroid = queryVectors.getCentroidDistance(ord); + float lower = queryVectors.getLower(ord); + float width = queryVectors.getWidth(ord); + float normVmC = 0f; + float vDotC = 0f; + if (similarityFunction != EUCLIDEAN) { + normVmC = queryVectors.getNormVmC(ord); + vDotC = queryVectors.getVDotC(ord); + } + BinaryQueryVector binaryQueryVector = new BinaryQueryVector( + vector, + new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, lower, width, normVmC, vDotC) + ); + return new BinarizedRandomVectorScorer(binaryQueryVector, targetVectors, similarityFunction); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BinarizedRandomVectorScorerSupplier(queryVectors.copy(), targetVectors.copy(), similarityFunction); + } + } + + /** A binarized query representing its quantized form along with factors */ + record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {} + + /** Vector scorer over binarized vector values */ + static class BinarizedRandomVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + private final BinaryQueryVector queryVector; + private final 
BinarizedByteVectorValues targetVectors; + private final VectorSimilarityFunction similarityFunction; + + private final float sqrtDimensions; + private final float maxX1; + + BinarizedRandomVectorScorer( + BinaryQueryVector queryVectors, + BinarizedByteVectorValues targetVectors, + VectorSimilarityFunction similarityFunction + ) { + super(targetVectors); + this.queryVector = queryVectors; + this.targetVectors = targetVectors; + this.similarityFunction = similarityFunction; + // FIXME: precompute this once? + this.sqrtDimensions = targetVectors.sqrtDimensions(); + this.maxX1 = targetVectors.maxX1(); + } + + @Override + public float score(int targetOrd) throws IOException { + byte[] quantizedQuery = queryVector.vector(); + int quantizedSum = queryVector.factors().quantizedSum(); + float lower = queryVector.factors().lower(); + float width = queryVector.factors().width(); + float distanceToCentroid = queryVector.factors().distToC(); + if (similarityFunction == EUCLIDEAN) { + return euclideanScore(targetOrd, sqrtDimensions, quantizedQuery, distanceToCentroid, lower, quantizedSum, width); + } + + float vmC = queryVector.factors().normVmC(); + float vDotC = queryVector.factors().vDotC(); + float cDotC = targetVectors.getCentroidDP(); + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + float ooq = targetVectors.getOOQ(targetOrd); + float normOC = targetVectors.getNormOC(targetOrd); + float oDotC = targetVectors.getODotC(targetOrd); + + float qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + + // FIXME: pre-compute these only once for each target vector + // ... pull this out or use a similar cache mechanism as do in score + float xbSum = (float) BQVectorUtils.popcount(binaryCode); + final float dist; + // If ||o-c|| == 0, so, it's ok to throw the rest of the equation away + // and simply use `oDotC + vDotC - cDotC` as centroid == doc vector + if (normOC == 0 || ooq == 0) { + dist = oDotC + vDotC - cDotC; + } else { + // If ||o-c|| != 0, we should assume that `ooq` is finite + assert Float.isFinite(ooq); + float estimatedDot = (2 * width / sqrtDimensions * qcDist + 2 * lower / sqrtDimensions * xbSum - width / sqrtDimensions + * quantizedSum - sqrtDimensions * lower) / ooq; + dist = vmC * normOC * estimatedDot + oDotC + vDotC - cDotC; + } + assert Float.isFinite(dist); + + float ooqSqr = (float) Math.pow(ooq, 2); + float errorBound = (float) (vmC * normOC * (maxX1 * Math.sqrt((1 - ooqSqr) / ooqSqr))); + float score = Float.isFinite(errorBound) ? dist - errorBound : dist; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } + + private float euclideanScore( + int targetOrd, + float sqrtDimensions, + byte[] quantizedQuery, + float distanceToCentroid, + float lower, + int quantizedSum, + float width + ) throws IOException { + byte[] binaryCode = targetVectors.vectorValue(targetOrd); + + // FIXME: pre-compute these only once for each target vector + // .. not sure how to enumerate the target ordinals but that's what we did in PoC + float targetDistToC = targetVectors.getCentroidDistance(targetOrd); + float x0 = targetVectors.getVectorMagnitude(targetOrd); + float sqrX = targetDistToC * targetDistToC; + double xX0 = targetDistToC / x0; + + // TODO maybe store? 
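+            // xbSum is the number of set bits in the stored binary code; combined with the correction factors below it
+            // yields an estimate of the squared query-to-document distance, which is folded into a 1/(1+score) similarity.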
+ float xbSum = (float) BQVectorUtils.popcount(binaryCode); + float factorPPC = (float) (-2.0 / sqrtDimensions * xX0 * (xbSum * 2.0 - targetVectors.dimension())); + float factorIP = (float) (-2.0 / sqrtDimensions * xX0); + + long qcDist = ESVectorUtil.ipByteBinByte(quantizedQuery, binaryCode); + float score = sqrX + distanceToCentroid + factorPPC * lower + (qcDist * 2 - quantizedSum) * factorIP * width; + float projectionDist = (float) Math.sqrt(xX0 * xX0 - targetDistToC * targetDistToC); + float error = 2.0f * maxX1 * projectionDist; + float y = (float) Math.sqrt(distanceToCentroid); + float errorBound = y * error; + if (Float.isFinite(errorBound)) { + score = score + errorBound; + } + return Math.max(1 / (1f + score), 0); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java index a75b9bc6064d1..ffe007be9799d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryFlatVectorsScorerTests.java @@ -59,7 +59,7 @@ public void testScore() throws IOException { short quantizedSum = (short) random().nextInt(0, 4097); float normVmC = random().nextFloat(-1000f, 1000f); float vDotC = random().nextFloat(-1000f, 1000f); - ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector( + ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector( vector, new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, normVmC, vDotC) ); @@ -134,7 +134,7 @@ public int dimension() { } }; - ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer( queryVector, targetVectors, similarityFunction @@ -217,7 +217,7 @@ public void testScoreEuclidean() throws IOException { float vl = -57.883f; float width = 9.972266f; short quantizedSum = 795; - ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector( + ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector( vector, new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, 0f, 0f) ); @@ -420,7 +420,7 @@ public int dimension() { VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN; - ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer( queryVector, targetVectors, similarityFunction @@ -824,7 +824,7 @@ public void testScoreMIP() throws IOException { float normVmC = 9.766797f; float vDotC = 133.56123f; float cDotC = 132.20227f; - ES816BinaryFlatVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatVectorsScorer.BinaryQueryVector( + ES816BinaryFlatRWVectorsScorer.BinaryQueryVector queryVector = new ES816BinaryFlatRWVectorsScorer.BinaryQueryVector( vector, new BinaryQuantizer.QueryFactors(quantizedSum, distanceToCentroid, vl, width, normVmC, vDotC) ); @@ 
-1768,7 +1768,7 @@ public int dimension() { VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; - ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatVectorsScorer.BinarizedRandomVectorScorer( + ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer scorer = new ES816BinaryFlatRWVectorsScorer.BinarizedRandomVectorScorer( queryVector, targetVectors, similarityFunction diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java new file mode 100644 index 0000000000000..c54903a94b54f --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedRWVectorsFormat.java @@ -0,0 +1,52 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es816; + +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; + +/** + * Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10 + */ +public class ES816BinaryQuantizedRWVectorsFormat extends ES816BinaryQuantizedVectorsFormat { + + private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final ES816BinaryFlatRWVectorsScorer scorer = new ES816BinaryFlatRWVectorsScorer( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + /** Creates a new instance with the default number of vectors per cluster. 
*/ + public ES816BinaryQuantizedRWVectorsFormat() { + super(); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ES816BinaryQuantizedVectorsWriter(scorer, rawVectorFormat.fieldsWriter(state), state); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java index 681f615653d40..48ba566353f5d 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsFormatTests.java @@ -63,7 +63,7 @@ protected Codec getCodec() { return new Lucene100Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new ES816BinaryQuantizedVectorsFormat(); + return new ES816BinaryQuantizedRWVectorsFormat(); } }; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java similarity index 99% rename from server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java rename to server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java index 31ae977e81118..4d97235c5fae5 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816BinaryQuantizedVectorsWriter.java @@ -77,7 +77,7 @@ class ES816BinaryQuantizedVectorsWriter extends FlatVectorsWriter { private final List fields = new ArrayList<>(); private final IndexOutput meta, binarizedVectorData; private final FlatVectorsWriter rawVectorDelegate; - private final ES816BinaryFlatVectorsScorer vectorsScorer; + private final ES816BinaryFlatRWVectorsScorer vectorsScorer; private boolean finished; /** @@ -86,7 +86,7 @@ class ES816BinaryQuantizedVectorsWriter extends FlatVectorsWriter { * @param vectorsScorer the scorer to use for scoring vectors */ protected ES816BinaryQuantizedVectorsWriter( - ES816BinaryFlatVectorsScorer vectorsScorer, + ES816BinaryFlatRWVectorsScorer vectorsScorer, FlatVectorsWriter rawVectorDelegate, SegmentWriteState state ) throws IOException { diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java new file mode 100644 index 0000000000000..e9bace72b591c --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedRWVectorsFormat.java @@ -0,0 +1,55 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es816; + +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.SegmentWriteState; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; + +class ES816HnswBinaryQuantizedRWVectorsFormat extends ES816HnswBinaryQuantizedVectorsFormat { + + private static final FlatVectorsFormat flatVectorsFormat = new ES816BinaryQuantizedRWVectorsFormat(); + + /** Constructs a format using default graph construction parameters */ + ES816HnswBinaryQuantizedRWVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH); + } + + ES816HnswBinaryQuantizedRWVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + ES816HnswBinaryQuantizedRWVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(maxConn, beamWidth, numMergeWorkers, mergeExec); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java index a25fa2836ee34..03aa847f3a5d4 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es816/ES816HnswBinaryQuantizedVectorsFormatTests.java @@ -59,7 +59,7 @@ protected Codec getCodec() { return new Lucene100Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new ES816HnswBinaryQuantizedVectorsFormat(); + return new ES816HnswBinaryQuantizedRWVectorsFormat(); } }; } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java new file mode 100644 index 0000000000000..397cc472592b6 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818BinaryQuantizedVectorsFormatTests.java @@ -0,0 +1,181 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene100.Lucene100Codec; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; + +import java.io.IOException; +import java.util.Locale; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES818BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Override + protected Codec getCodec() { + return new Lucene100Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new ES818BinaryQuantizedVectorsFormat(); + } + }; + } + + public void testSearch() throws Exception { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + IndexWriterConfig iwc = newIndexWriterConfig(); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, iwc)) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + } + w.commit(); + + try (IndexReader reader = DirectoryReader.open(w)) { + IndexSearcher searcher = new IndexSearcher(reader); + final int k = random().nextInt(5, 50); + float[] queryVector = randomVector(dims); + Query q = new KnnFloatVectorQuery(fieldName, queryVector, k); + TopDocs collectedDocs = 
searcher.search(q, k); + assertEquals(k, collectedDocs.totalHits.value()); + assertEquals(TotalHits.Relation.EQUAL_TO, collectedDocs.totalHits.relation()); + } + } + } + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES818BinaryQuantizedVectorsFormat(); + } + }; + String expectedPattern = "ES818BinaryQuantizedVectorsFormat(" + + "name=ES818BinaryQuantizedVectorsFormat, " + + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s()))"; + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + @Override + public void testRandomWithUpdatesAndGraph() { + // graph not supported + } + + @Override + public void testSearchWithVisitedLimit() { + // visited limit is not respected, as it is brute force search + } + + public void testQuantizedVectorsWriteAndRead() throws IOException { + String fieldName = "field"; + int numVectors = random().nextInt(99, 500); + int dims = random().nextInt(4, 65); + + float[] vector = randomVector(dims); + VectorSimilarityFunction similarityFunction = randomSimilarity(); + KnnFloatVectorField knnField = new KnnFloatVectorField(fieldName, vector, similarityFunction); + try (Directory dir = newDirectory()) { + try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numVectors; i++) { + Document doc = new Document(); + knnField.setVectorValue(randomVector(dims)); + doc.add(knnField); + w.addDocument(doc); + if (i % 101 == 0) { + w.commit(); + } + } + w.commit(); + w.forceMerge(1); + + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); + assertEquals(vectorValues.size(), numVectors); + BinarizedByteVectorValues qvectorValues = ((ES818BinaryQuantizedVectorsReader.BinarizedVectorValues) vectorValues) + .getQuantizedVectorValues(); + float[] centroid = qvectorValues.getCentroid(); + assertEquals(centroid.length, dims); + + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction); + byte[] quantizedVector = new byte[dims]; + byte[] expectedVector = new byte[BQVectorUtils.discretize(dims, 64) / 8]; + if (similarityFunction == VectorSimilarityFunction.COSINE) { + vectorValues = new ES818BinaryQuantizedVectorsWriter.NormalizedFloatVectorValues(vectorValues); + } + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + OptimizedScalarQuantizer.QuantizationResult corrections = quantizer.scalarQuantize( + vectorValues.vectorValue(docIndexIterator.index()), + quantizedVector, + (byte) 1, + centroid + ); + BQVectorUtils.packAsBinary(quantizedVector, expectedVector); + assertArrayEquals(expectedVector, qvectorValues.vectorValue(docIndexIterator.index())); + assertEquals(corrections, qvectorValues.getCorrectiveTerms(docIndexIterator.index())); + } + } + } + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java new file mode 100644 index 
0000000000000..b6ae3199bb896 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/ES818HnswBinaryQuantizedVectorsFormatTests.java @@ -0,0 +1,132 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * Modifications copyright (C) 2024 Elasticsearch B.V. + */ +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene100.Lucene100Codec; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.util.SameThreadExecutorService; +import org.elasticsearch.common.logging.LogConfigurator; + +import java.util.Arrays; +import java.util.Locale; + +import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ES818HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Override + protected Codec getCodec() { + return new Lucene100Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new ES818HnswBinaryQuantizedVectorsFormat(); + } + }; + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ES818HnswBinaryQuantizedVectorsFormat(10, 20, 1, null); + } + }; + String expectedPattern = + "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20," + + " flatVectorFormat=ES818BinaryQuantizedVectorsFormat(name=ES818BinaryQuantizedVectorsFormat," + + " flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=%s())))"; + + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, 
"Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testSingleVectorCase() throws Exception { + float[] vector = randomVector(random().nextInt(12, 500)); + for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) { + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, similarityFunction)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + FloatVectorValues vectorValues = r.getFloatVectorValues("f"); + KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator(); + assert (vectorValues.size() == 1); + while (docIndexIterator.nextDoc() != NO_MORE_DOCS) { + assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f); + } + float[] randomVector = randomVector(vector.length); + float trueScore = similarityFunction.compare(vector, randomVector); + TopDocs td = r.searchNearestVectors("f", randomVector, 1, null, Integer.MAX_VALUE); + assertEquals(1, td.totalHits.value()); + assertTrue(td.scoreDocs[0].score >= 0); + // When it's the only vector in a segment, the score should be very close to the true score + assertEquals(trueScore, td.scoreDocs[0].score, 0.0001f); + } + } + } + } + + public void testLimits() { + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(-1, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(0, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 0)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, -1)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(512 + 1, 20)); + expectThrows(IllegalArgumentException.class, () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 3201)); + expectThrows( + IllegalArgumentException.class, + () -> new ES818HnswBinaryQuantizedVectorsFormat(20, 100, 1, new SameThreadExecutorService()) + ); + } + + // Ensures that all expected vector similarity functions are translatable in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. + var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } +} diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java new file mode 100644 index 0000000000000..e3e2d6caafe0e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/es818/OptimizedScalarQuantizerTests.java @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.es818; + +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.test.ESTestCase; + +import static org.elasticsearch.index.codec.vectors.es818.OptimizedScalarQuantizer.MINIMUM_MSE_GRID; + +public class OptimizedScalarQuantizerTests extends ESTestCase { + + static final byte[] ALL_BITS = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + + public void testAbusiveEdgeCases() { + // large zero array + for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) { + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + continue; + } + float[] vector = new float[4096]; + float[] centroid = new float[4096]; + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][4096]; + OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (byte[] destination : destinations) { + assertArrayEquals(new byte[4096], destination); + } + byte[] destination = new byte[4096]; + for (byte bit : ALL_BITS) { + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid); + assertValidResults(result); + assertArrayEquals(new byte[4096], destination); + } + } + + // single value array + for (VectorSimilarityFunction vectorSimilarityFunction : VectorSimilarityFunction.values()) { + float[] vector = new float[] { randomFloat() }; + float[] centroid = new float[] { randomFloat() }; + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][1]; + OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(vector, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (int i = 0; i < ALL_BITS.length; i++) { + assertValidQuantizedRange(destinations[i], ALL_BITS[i]); + } + for (byte bit : ALL_BITS) { + vector = new float[] { randomFloat() }; + centroid = new float[] { randomFloat() }; + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(vector); + VectorUtil.l2normalize(centroid); + } + byte[] destination = new byte[1]; + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(vector, destination, bit, centroid); + assertValidResults(result); + assertValidQuantizedRange(destination, bit); + } + } + + } + + public void testMathematicalConsistency() { + int dims = randomIntBetween(1, 4096); + float[] vector = new float[dims]; + for (int i = 0; i < dims; ++i) { + vector[i] = randomFloat(); + } + float[] centroid = new float[dims]; + for (int i = 0; i < dims; ++i) { + centroid[i] = randomFloat(); + } + float[] copy = new float[dims]; + for (VectorSimilarityFunction 
vectorSimilarityFunction : VectorSimilarityFunction.values()) { + // copy the vector to avoid modifying it + System.arraycopy(vector, 0, copy, 0, dims); + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(copy); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(vectorSimilarityFunction); + byte[][] destinations = new byte[MINIMUM_MSE_GRID.length][dims]; + OptimizedScalarQuantizer.QuantizationResult[] results = osq.multiScalarQuantize(copy, destinations, ALL_BITS, centroid); + assertEquals(MINIMUM_MSE_GRID.length, results.length); + assertValidResults(results); + for (int i = 0; i < ALL_BITS.length; i++) { + assertValidQuantizedRange(destinations[i], ALL_BITS[i]); + } + for (byte bit : ALL_BITS) { + byte[] destination = new byte[dims]; + System.arraycopy(vector, 0, copy, 0, dims); + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + VectorUtil.l2normalize(copy); + VectorUtil.l2normalize(centroid); + } + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(copy, destination, bit, centroid); + assertValidResults(result); + assertValidQuantizedRange(destination, bit); + } + } + } + + static void assertValidQuantizedRange(byte[] quantized, byte bits) { + for (byte b : quantized) { + if (bits < 8) { + assertTrue(b >= 0); + } + assertTrue(b < 1 << bits); + } + } + + static void assertValidResults(OptimizedScalarQuantizer.QuantizationResult... results) { + for (OptimizedScalarQuantizer.QuantizationResult result : results) { + assertTrue(Float.isFinite(result.lowerInterval())); + assertTrue(Float.isFinite(result.upperInterval())); + assertTrue(result.lowerInterval() <= result.upperInterval()); + assertTrue(Float.isFinite(result.additionalCorrection())); + assertTrue(result.quantizedComponentSum() >= 0); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java index de084cd4582e2..c043b9ffb381a 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java @@ -1970,13 +1970,13 @@ public void testKnnBBQHNSWVectorsFormat() throws IOException { assertThat(codec, instanceOf(LegacyPerFieldMapperCodec.class)); knnVectorsFormat = ((LegacyPerFieldMapperCodec) codec).getKnnVectorsFormatForField("field"); } - String expectedString = "ES816HnswBinaryQuantizedVectorsFormat(name=ES816HnswBinaryQuantizedVectorsFormat, maxConn=" + String expectedString = "ES818HnswBinaryQuantizedVectorsFormat(name=ES818HnswBinaryQuantizedVectorsFormat, maxConn=" + m + ", beamWidth=" + efConstruction - + ", flatVectorFormat=ES816BinaryQuantizedVectorsFormat(" - + "name=ES816BinaryQuantizedVectorsFormat, " - + "flatVectorScorer=ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())))"; + + ", flatVectorFormat=ES818BinaryQuantizedVectorsFormat(" + + "name=ES818BinaryQuantizedVectorsFormat, " + + "flatVectorScorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate=DefaultFlatVectorScorer())))"; assertEquals(expectedString, knnVectorsFormat.toString()); } From 4b868b0e11f4a59d608523c57d1a62b870ac8e0e Mon Sep 17 00:00:00 2001 From: Niels Bauman <33722607+nielsbauman@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:57:40 +0100 Subject: [PATCH 6/6] Fix enrich cache size setting name (#117575) The 

The enrich cache size setting accidentally got renamed from `enrich.cache_size` to
`enrich.cache.size` in #111412. This commit updates the enrich plugin to accept both
names and deprecates the wrong name.
---
 docs/changelog/117575.yaml              |  5 ++
 .../xpack/enrich/EnrichPlugin.java      | 59 +++++++++++++++++-
 .../xpack/enrich/EnrichPluginTests.java | 62 +++++++++++++++++++
 3 files changed, 123 insertions(+), 3 deletions(-)
 create mode 100644 docs/changelog/117575.yaml
 create mode 100644 x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPluginTests.java

diff --git a/docs/changelog/117575.yaml b/docs/changelog/117575.yaml
new file mode 100644
index 0000000000000..781444ae97be5
--- /dev/null
+++ b/docs/changelog/117575.yaml
@@ -0,0 +1,5 @@
+pr: 117575
+summary: Fix enrich cache size setting name
+area: Ingest Node
+type: bug
+issues: []

diff --git a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java
index 1a68ada60b6f1..d46639d700420 100644
--- a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java
+++ b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/EnrichPlugin.java
@@ -14,6 +14,8 @@
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.logging.DeprecationCategory;
+import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.IndexScopedSettings;
 import org.elasticsearch.common.settings.Setting;
@@ -23,6 +25,7 @@
 import org.elasticsearch.common.unit.MemorySizeValue;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.core.UpdateForV10;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.indices.SystemIndexDescriptor;
 import org.elasticsearch.ingest.Processor;
@@ -74,6 +77,8 @@ public class EnrichPlugin extends Plugin implements SystemIndexPlugin, IngestPlugin {
 
+    private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(EnrichPlugin.class);
+
     static final Setting ENRICH_FETCH_SIZE_SETTING = Setting.intSetting(
         "enrich.fetch_size",
         10000,
@@ -126,9 +131,9 @@ public class EnrichPlugin extends Plugin implements SystemIndexPlugin, IngestPlu
         return String.valueOf(maxConcurrentRequests * maxLookupsPerRequest);
     }, val -> Setting.parseInt(val, 1, Integer.MAX_VALUE, QUEUE_CAPACITY_SETTING_NAME), Setting.Property.NodeScope);
 
-    public static final String CACHE_SIZE_SETTING_NAME = "enrich.cache.size";
+    public static final String CACHE_SIZE_SETTING_NAME = "enrich.cache_size";
     public static final Setting CACHE_SIZE = new Setting<>(
-        "enrich.cache.size",
+        CACHE_SIZE_SETTING_NAME,
         (String) null,
         (String s) -> FlatNumberOrByteSizeValue.parse(
             s,
@@ -138,16 +143,59 @@ public class EnrichPlugin extends Plugin implements SystemIndexPlugin, IngestPlu
         Setting.Property.NodeScope
     );
 
+    /**
+     * This setting solely exists because the original setting was accidentally renamed in
+     * https://github.com/elastic/elasticsearch/pull/111412.
+     */
+    @UpdateForV10(owner = UpdateForV10.Owner.DATA_MANAGEMENT)
+    public static final String CACHE_SIZE_SETTING_BWC_NAME = "enrich.cache.size";
+    public static final Setting CACHE_SIZE_BWC = new Setting<>(
+        CACHE_SIZE_SETTING_BWC_NAME,
+        (String) null,
+        (String s) -> FlatNumberOrByteSizeValue.parse(
+            s,
+            CACHE_SIZE_SETTING_BWC_NAME,
+            new FlatNumberOrByteSizeValue(ByteSizeValue.ofBytes((long) (0.01 * JvmInfo.jvmInfo().getConfiguredMaxHeapSize())))
+        ),
+        Setting.Property.NodeScope,
+        Setting.Property.Deprecated
+    );
+
     private final Settings settings;
     private final EnrichCache enrichCache;
+    private final long maxCacheSize;
 
     public EnrichPlugin(final Settings settings) {
         this.settings = settings;
-        FlatNumberOrByteSizeValue maxSize = CACHE_SIZE.get(settings);
+        FlatNumberOrByteSizeValue maxSize;
+        if (settings.hasValue(CACHE_SIZE_SETTING_BWC_NAME)) {
+            if (settings.hasValue(CACHE_SIZE_SETTING_NAME)) {
+                throw new IllegalArgumentException(
+                    Strings.format(
+                        "Both [{}] and [{}] are set, please use [{}]",
+                        CACHE_SIZE_SETTING_NAME,
+                        CACHE_SIZE_SETTING_BWC_NAME,
+                        CACHE_SIZE_SETTING_NAME
+                    )
+                );
+            }
+            deprecationLogger.warn(
+                DeprecationCategory.SETTINGS,
+                "enrich_cache_size_name",
+                "The [{}] setting is deprecated and will be removed in a future version. Please use [{}] instead.",
+                CACHE_SIZE_SETTING_BWC_NAME,
+                CACHE_SIZE_SETTING_NAME
+            );
+            maxSize = CACHE_SIZE_BWC.get(settings);
+        } else {
+            maxSize = CACHE_SIZE.get(settings);
+        }
         if (maxSize.byteSizeValue() != null) {
             this.enrichCache = new EnrichCache(maxSize.byteSizeValue());
+            this.maxCacheSize = maxSize.byteSizeValue().getBytes();
         } else {
             this.enrichCache = new EnrichCache(maxSize.flatNumber());
+            this.maxCacheSize = maxSize.flatNumber();
         }
     }
 
@@ -286,6 +334,11 @@ public String getFeatureDescription() {
         return "Manages data related to Enrich policies";
     }
 
+    // Visible for testing
+    long getMaxCacheSize() {
+        return maxCacheSize;
+    }
+
     /**
      * A class that specifies either a flat (unit-less) number or a byte size value.
      */

diff --git a/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPluginTests.java b/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPluginTests.java
new file mode 100644
index 0000000000000..07de0e0967448
--- /dev/null
+++ b/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPluginTests.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.enrich;
+
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.List;
+
+public class EnrichPluginTests extends ESTestCase {
+
+    public void testConstructWithByteSize() {
+        final var size = randomNonNegativeInt();
+        Settings settings = Settings.builder().put(EnrichPlugin.CACHE_SIZE_SETTING_NAME, size + "b").build();
+        EnrichPlugin plugin = new EnrichPlugin(settings);
+        assertEquals(size, plugin.getMaxCacheSize());
+    }
+
+    public void testConstructWithFlatNumber() {
+        final var size = randomNonNegativeInt();
+        Settings settings = Settings.builder().put(EnrichPlugin.CACHE_SIZE_SETTING_NAME, size).build();
+        EnrichPlugin plugin = new EnrichPlugin(settings);
+        assertEquals(size, plugin.getMaxCacheSize());
+    }
+
+    public void testConstructWithByteSizeBwc() {
+        final var size = randomNonNegativeInt();
+        Settings settings = Settings.builder().put(EnrichPlugin.CACHE_SIZE_SETTING_BWC_NAME, size + "b").build();
+        EnrichPlugin plugin = new EnrichPlugin(settings);
+        assertEquals(size, plugin.getMaxCacheSize());
+    }
+
+    public void testConstructWithFlatNumberBwc() {
+        final var size = randomNonNegativeInt();
+        Settings settings = Settings.builder().put(EnrichPlugin.CACHE_SIZE_SETTING_BWC_NAME, size).build();
+        EnrichPlugin plugin = new EnrichPlugin(settings);
+        assertEquals(size, plugin.getMaxCacheSize());
+    }
+
+    public void testConstructWithBothSettings() {
+        Settings settings = Settings.builder()
+            .put(EnrichPlugin.CACHE_SIZE_SETTING_NAME, randomNonNegativeInt())
+            .put(EnrichPlugin.CACHE_SIZE_SETTING_BWC_NAME, randomNonNegativeInt())
+            .build();
+        assertThrows(IllegalArgumentException.class, () -> new EnrichPlugin(settings));
+    }
+
+    @Override
+    protected List filteredWarnings() {
+        final var warnings = super.filteredWarnings();
+        warnings.add("[enrich.cache.size] setting was deprecated in Elasticsearch and will be removed in a future release.");
+        warnings.add(
+            "The [enrich.cache.size] setting is deprecated and will be removed in a future version. Please use [enrich.cache_size] instead."
+        );
+        return warnings;
+    }
+}
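
The constructor change in PATCH 6/6 amounts to a three-way precedence rule: fail fast when both `enrich.cache_size` and the deprecated `enrich.cache.size` are configured, warn and honour the deprecated name when it is the only one set, and otherwise read the new name (falling back to the plugin default of roughly 1% of the configured heap). The following standalone Java sketch mirrors that rule under simplifying assumptions: a plain Map stands in for Elasticsearch's Settings, and the class and method names (EnrichCacheSizeResolution, resolveCacheSize) are hypothetical, not part of the Elasticsearch API.

import java.util.Map;
import java.util.Optional;

// Hypothetical, simplified model of the fallback logic added to EnrichPlugin's constructor.
// A Map stands in for Elasticsearch's Settings; names and messages are illustrative only.
class EnrichCacheSizeResolution {

    static final String NEW_NAME = "enrich.cache_size";
    static final String OLD_NAME = "enrich.cache.size"; // deprecated alias

    static Optional<String> resolveCacheSize(Map<String, String> settings) {
        boolean oldSet = settings.containsKey(OLD_NAME);
        boolean newSet = settings.containsKey(NEW_NAME);
        if (oldSet && newSet) {
            // Mirrors the IllegalArgumentException thrown when both names are configured.
            throw new IllegalArgumentException(
                "Both [" + NEW_NAME + "] and [" + OLD_NAME + "] are set, please use [" + NEW_NAME + "]"
            );
        }
        if (oldSet) {
            // Mirrors the deprecation warning emitted before falling back to the old name.
            System.err.println("[" + OLD_NAME + "] is deprecated, use [" + NEW_NAME + "] instead");
            return Optional.of(settings.get(OLD_NAME));
        }
        // Empty means the caller applies the plugin default (about 1% of the max heap).
        return Optional.ofNullable(settings.get(NEW_NAME));
    }

    public static void main(String[] args) {
        System.out.println(resolveCacheSize(Map.of(NEW_NAME, "1000")));  // Optional[1000]
        System.out.println(resolveCacheSize(Map.of(OLD_NAME, "512mb"))); // warns, Optional[512mb]
        System.out.println(resolveCacheSize(Map.of()));                  // Optional.empty -> default
    }
}

The EnrichPluginTests added above exercise the same three paths through the real Settings API: new name only, deprecated name only, and both names set (which must throw).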