From bf2234ad19bf4e7e32af567d533abf606ad2d4f9 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Mon, 7 Oct 2024 14:43:31 -0700 Subject: [PATCH] Fix lucene codec after lucene version bumped to 9.12 Signed-off-by: Navneet Verma --- CHANGELOG.md | 1 + .../codec/KNN9120Codec/KNN9120Codec.java | 61 +++++++++++++++++++ .../NativeEngineFieldVectorsWriter.java | 29 +++++++-- .../NativeEngines990KnnVectorsWriter.java | 8 ++- .../knn/index/codec/KNNCodecVersion.java | 21 ++++++- ...KNNScalarQuantizedVectorsFormatParams.java | 5 +- .../NativeEngineFieldVectorsWriterTests.java | 28 ++++++--- ...eEngines990KnnVectorsWriterFlushTests.java | 17 ++++-- ...eEngines990KnnVectorsWriterMergeTests.java | 16 +++-- ...alarQuantizedVectorsFormatParamsTests.java | 4 +- 10 files changed, 163 insertions(+), 27 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f615a78fb9..7bc3019dfe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Documentation * Fix sed command in DEVELOPER_GUIDE.md to append a new line character '\n'. [#2181](https://github.com/opensearch-project/k-NN/pull/2181) ### Maintenance +* Fix lucene codec after lucene version bumped to 9.12. [#2195](https://github.com/opensearch-project/k-NN/pull/2195) ### Refactoring * Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java new file mode 100644 index 0000000000..a370197ecc --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNN9120Codec/KNN9120Codec.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.KNN9120Codec; + +import lombok.Builder; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.CompoundFormat; +import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.opensearch.knn.index.codec.KNNCodecVersion; +import org.opensearch.knn.index.codec.KNNFormatFacade; + +/** + * KNN Codec that wraps the Lucene Codec which is part of Lucene 9.12 + */ +public class KNN9120Codec extends FilterCodec { + private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0; + private final KNNFormatFacade knnFormatFacade; + private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat; + + /** + * No arg constructor that uses Lucene99 as the delegate + */ + public KNN9120Codec() { + this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat()); + } + + /** + * Sole constructor. When subclassing this codec, create a no-arg ctor and pass the delegate codec + * and a unique name to this ctor. + * + * @param delegate codec that will perform all operations this codec does not override + * @param knnVectorsFormat per field format for KnnVector + */ + @Builder + protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) { + super(VERSION.getCodecName(), delegate); + knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate); + perFieldKnnVectorsFormat = knnVectorsFormat; + } + + @Override + public DocValuesFormat docValuesFormat() { + return knnFormatFacade.docValuesFormat(); + } + + @Override + public CompoundFormat compoundFormat() { + return knnFormatFacade.compoundFormat(); + } + + @Override + public KnnVectorsFormat knnVectorsFormat() { + return perFieldKnnVectorsFormat; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java index 1abb849446..dd7ad4af58 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriter.java @@ -13,11 +13,13 @@ import lombok.Getter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.RamUsageEstimator; +import java.io.IOException; import java.util.HashMap; import java.util.Map; @@ -44,22 +46,37 @@ class NativeEngineFieldVectorsWriter extends KnnFieldVectorsWriter { @Getter private final DocsWithFieldSet docsWithField; private final InfoStream infoStream; + private final FlatFieldVectorsWriter flatFieldVectorsWriter; - static NativeEngineFieldVectorsWriter create(final FieldInfo fieldInfo, final InfoStream infoStream) { + @SuppressWarnings("unchecked") + static NativeEngineFieldVectorsWriter create( + final FieldInfo fieldInfo, + final FlatFieldVectorsWriter flatFieldVectorsWriter, + final InfoStream infoStream + ) { switch (fieldInfo.getVectorEncoding()) { case FLOAT32: - return new NativeEngineFieldVectorsWriter(fieldInfo, infoStream); + return new NativeEngineFieldVectorsWriter<>( + fieldInfo, + (FlatFieldVectorsWriter) flatFieldVectorsWriter, + infoStream + ); case BYTE: - return new NativeEngineFieldVectorsWriter(fieldInfo, infoStream); + return new NativeEngineFieldVectorsWriter<>(fieldInfo, (FlatFieldVectorsWriter) flatFieldVectorsWriter, infoStream); } throw new IllegalStateException("Unsupported Vector encoding : " + fieldInfo.getVectorEncoding()); } - private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStream infoStream) { + private NativeEngineFieldVectorsWriter( + final FieldInfo fieldInfo, + final FlatFieldVectorsWriter flatFieldVectorsWriter, + final InfoStream infoStream + ) { this.fieldInfo = fieldInfo; this.infoStream = infoStream; vectors = new HashMap<>(); this.docsWithField = new DocsWithFieldSet(); + this.flatFieldVectorsWriter = flatFieldVectorsWriter; } /** @@ -70,7 +87,7 @@ private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStre * @param vectorValue T */ @Override - public void addValue(int docID, T vectorValue) { + public void addValue(int docID, T vectorValue) throws IOException { if (docID == lastDocID) { throw new IllegalArgumentException( "[NativeEngineKNNVectorWriter]VectorValuesField \"" @@ -81,6 +98,8 @@ public void addValue(int docID, T vectorValue) { // TODO: we can build the graph here too iteratively. but right now I am skipping that as we need iterative // graph build support on the JNI layer. assert docID > lastDocID; + // ensuring that vector is provided to flatFieldWriter. + flatFieldVectorsWriter.addValue(docID, vectorValue); vectors.put(docID, vectorValue); docsWithField.add(docID); lastDocID = docID; diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 2f22565c98..eccad41c8b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -65,9 +65,13 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla */ @Override public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOException { - final NativeEngineFieldVectorsWriter newField = NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream); + final NativeEngineFieldVectorsWriter newField = NativeEngineFieldVectorsWriter.create( + fieldInfo, + flatVectorsWriter.addField(fieldInfo), + segmentWriteState.infoStream + ); fields.add(newField); - return flatVectorsWriter.addField(fieldInfo, newField); + return newField; } /** diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java index 419505aa20..4bbeceffb3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java @@ -12,12 +12,14 @@ import org.apache.lucene.backward_codecs.lucene94.Lucene94Codec; import org.apache.lucene.codecs.Codec; import org.apache.lucene.backward_codecs.lucene95.Lucene95Codec; -import org.apache.lucene.codecs.lucene99.Lucene99Codec; +import org.apache.lucene.backward_codecs.lucene99.Lucene99Codec; +import org.apache.lucene.codecs.lucene912.Lucene912Codec; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat; import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat; import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec; +import org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec; import org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec; import org.opensearch.knn.index.codec.KNN920Codec.KNN920PerFieldKnnVectorsFormat; import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec; @@ -110,9 +112,24 @@ public enum KNNCodecVersion { .knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) .build(), KNN990Codec::new + ), + + V_9_12_0( + "KNN990Codec", + new Lucene912Codec(), + new KNN990PerFieldKnnVectorsFormat(Optional.empty()), + (delegate) -> new KNNFormatFacade( + new KNN80DocValuesFormat(delegate.docValuesFormat()), + new KNN80CompoundFormat(delegate.compoundFormat()) + ), + (userCodec, mapperService) -> KNN9120Codec.builder() + .delegate(userCodec) + .knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService))) + .build(), + KNN9120Codec::new ); - private static final KNNCodecVersion CURRENT = V_9_9_0; + private static final KNNCodecVersion CURRENT = V_9_12_0; private final String codecName; private final Codec defaultCodecDelegate; diff --git a/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java b/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java index e2d31183b6..cdcecac0fb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java @@ -31,6 +31,7 @@ public KNNScalarQuantizedVectorsFormatParams(Map params, int def Map sqEncoderParams = encoderMethodComponentContext.getParameters(); this.initConfidenceInterval(sqEncoderParams); this.initBits(sqEncoderParams); + // compression flag should be inited after initBits as compressionFlag Depends on bits. this.initCompressFlag(); } @@ -77,6 +78,8 @@ private void initBits(final Map params) { } private void initCompressFlag() { - this.compressFlag = true; + // This check is coming from Lucene. Code ref: + // https://github.com/apache/lucene/blob/branch_9_12/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java#L113-L116 + this.compressFlag = this.bits <= 4; } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java index 29e3531cf5..daee736f83 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngineFieldVectorsWriterTests.java @@ -11,6 +11,8 @@ package org.opensearch.knn.index.codec.KNN990Codec; +import lombok.SneakyThrows; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.InfoStream; @@ -21,11 +23,14 @@ public class NativeEngineFieldVectorsWriterTests extends KNNCodecTestCase { @SuppressWarnings("unchecked") + @SneakyThrows public void testCreate_ForDifferentInputs_thenSuccess() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); + FlatFieldVectorsWriter mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.eq(1), Mockito.any()); NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); floatWriter.addValue(1, new float[] { 1.0f, 2.0f }); Mockito.verify(fieldInfo).getVectorEncoding(); @@ -33,6 +38,7 @@ public void testCreate_ForDifferentInputs_thenSuccess() { Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter.create( fieldInfo, + mockedFlatFieldVectorsWriter, InfoStream.getDefault() ); Assert.assertNotNull(byteWriter); @@ -41,11 +47,15 @@ public void testCreate_ForDifferentInputs_thenSuccess() { } @SuppressWarnings("unchecked") + @SneakyThrows public void testAddValue_ForDifferentInputs_thenSuccess() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); + FlatFieldVectorsWriter mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.eq(1), Mockito.any()); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.eq(2), Mockito.any()); final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); final float[] vec1 = new float[] { 1.0f, 2.0f }; final float[] vec2 = new float[] { 2.0f, 2.0f }; floatWriter.addValue(1, vec1); @@ -57,7 +67,7 @@ public void testAddValue_ForDifferentInputs_thenSuccess() { Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); final NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); final byte[] bvec1 = new byte[] { 1, 2 }; final byte[] bvec2 = new byte[] { 2, 2 }; byteWriter.addValue(1, bvec1); @@ -69,32 +79,36 @@ public void testAddValue_ForDifferentInputs_thenSuccess() { } @SuppressWarnings("unchecked") + @SneakyThrows public void testCopyValue_whenValidInput_thenException() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); + FlatFieldVectorsWriter mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); expectThrows(UnsupportedOperationException.class, () -> floatWriter.copyValue(new float[3])); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); final NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); expectThrows(UnsupportedOperationException.class, () -> byteWriter.copyValue(new byte[3])); } @SuppressWarnings("unchecked") + @SneakyThrows public void testRamByteUsed_whenValidInput_thenSuccess() { final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); Mockito.when(fieldInfo.getVectorDimension()).thenReturn(2); + FlatFieldVectorsWriter mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); final NativeEngineFieldVectorsWriter floatWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); // testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too. Assert.assertTrue(floatWriter.ramBytesUsed() > 0); Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE); final NativeEngineFieldVectorsWriter byteWriter = (NativeEngineFieldVectorsWriter) NativeEngineFieldVectorsWriter - .create(fieldInfo, InfoStream.getDefault()); + .create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault()); // testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too. Assert.assertTrue(byteWriter.ramBytesUsed() > 0); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 9f74b2c104..5bb6d19268 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -8,6 +8,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -16,6 +17,7 @@ import org.mockito.Mock; import org.mockito.MockedConstruction; import org.mockito.MockedStatic; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; @@ -68,6 +70,8 @@ public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCa @Mock private NativeIndexWriter nativeIndexWriter; + private FlatFieldVectorsWriter mockedFlatFieldVectorsWriter; + private NativeEngines990KnnVectorsWriter objectUnderTest; private final String description; @@ -78,6 +82,9 @@ public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.anyInt(), Mockito.any()); + Mockito.when(flatVectorsWriter.addField(Mockito.any())).thenReturn(mockedFlatFieldVectorsWriter); } @ParametersFactory @@ -139,8 +146,9 @@ public void testFlush() { ); NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); try { objectUnderTest.addField(fieldInfo); @@ -227,8 +235,9 @@ public void testFlush_WithQuantization() { ); NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); try { objectUnderTest.addField(fieldInfo); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index 41940c4d47..af18cd281f 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -9,6 +9,7 @@ import lombok.RequiredArgsConstructor; import lombok.SneakyThrows; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; @@ -19,6 +20,7 @@ import org.mockito.Mock; import org.mockito.MockedConstruction; import org.mockito.MockedStatic; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; @@ -74,12 +76,16 @@ public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCa private final String description; private final Map mergedVectors; + private FlatFieldVectorsWriter mockedFlatFieldVectorsWriter; @Override public void setUp() throws Exception { super.setUp(); MockitoAnnotations.openMocks(this); objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter); + mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class); + Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.anyInt(), Mockito.any()); + Mockito.when(flatVectorsWriter.addField(Mockito.any())).thenReturn(mockedFlatFieldVectorsWriter); } @ParametersFactory @@ -120,8 +126,9 @@ public void testMerge() { ); NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) .thenReturn(floatVectorValues); @@ -184,8 +191,9 @@ public void testMerge_WithQuantization() { ); NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors); - fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) - .thenReturn(field); + fieldWriterMockedStatic.when( + () -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream) + ).thenReturn(field); mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)) .thenReturn(floatVectorValues); diff --git a/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java b/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java index 573e826e0f..a2bba1b4b6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java @@ -39,7 +39,7 @@ public void testInitParams_whenCalled_thenReturnDefaultParams() { assertEquals(DEFAULT_MAX_CONNECTIONS, knnScalarQuantizedVectorsFormatParams.getMaxConnections()); assertEquals(DEFAULT_BEAM_WIDTH, knnScalarQuantizedVectorsFormatParams.getBeamWidth()); assertNull(knnScalarQuantizedVectorsFormatParams.getConfidenceInterval()); - assertTrue(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); + assertFalse(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); assertEquals(LUCENE_SQ_DEFAULT_BITS, knnScalarQuantizedVectorsFormatParams.getBits()); } @@ -65,7 +65,7 @@ public void testInitParams_whenCalled_thenReturnParams() { assertEquals(m, knnScalarQuantizedVectorsFormatParams.getMaxConnections()); assertEquals(efConstruction, knnScalarQuantizedVectorsFormatParams.getBeamWidth()); assertEquals((float) MINIMUM_CONFIDENCE_INTERVAL, knnScalarQuantizedVectorsFormatParams.getConfidenceInterval()); - assertTrue(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); + assertFalse(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); assertEquals(LUCENE_SQ_DEFAULT_BITS, knnScalarQuantizedVectorsFormatParams.getBits()); }