Skip to content

Commit

Permalink
Fix lucene codec after lucene version bumped to 9.12
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Oct 8, 2024
1 parent 2c170fb commit bf2234a
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
* Fix sed command in DEVELOPER_GUIDE.md to append a new line character '\n'. [#2181](https://github.com/opensearch-project/k-NN/pull/2181)
### Maintenance
* Fix lucene codec after lucene version bumped to 9.12. [#2195](https://github.com/opensearch-project/k-NN/pull/2195)
### Refactoring
* Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import lombok.Builder;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;

/**
* KNN Codec that wraps the Lucene Codec which is part of Lucene 9.12
*/
public class KNN9120Codec extends FilterCodec {
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene99 as the delegate
*/
public KNN9120Codec() {
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
}

/**
* Sole constructor. When subclassing this codec, create a no-arg ctor and pass the delegate codec
* and a unique name to this ctor.
*
* @param delegate codec that will perform all operations this codec does not override
* @param knnVectorsFormat per field format for KnnVector
*/
@Builder
protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
}

@Override
public DocValuesFormat docValuesFormat() {
return knnFormatFacade.docValuesFormat();
}

@Override
public CompoundFormat compoundFormat() {
return knnFormatFacade.compoundFormat();
}

@Override
public KnnVectorsFormat knnVectorsFormat() {
return perFieldKnnVectorsFormat;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@

import lombok.Getter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

Expand All @@ -44,22 +46,37 @@ class NativeEngineFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
@Getter
private final DocsWithFieldSet docsWithField;
private final InfoStream infoStream;
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;

static NativeEngineFieldVectorsWriter<?> create(final FieldInfo fieldInfo, final InfoStream infoStream) {
@SuppressWarnings("unchecked")
static NativeEngineFieldVectorsWriter<?> create(
final FieldInfo fieldInfo,
final FlatFieldVectorsWriter<?> flatFieldVectorsWriter,
final InfoStream infoStream
) {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
return new NativeEngineFieldVectorsWriter<float[]>(fieldInfo, infoStream);
return new NativeEngineFieldVectorsWriter<>(
fieldInfo,
(FlatFieldVectorsWriter<float[]>) flatFieldVectorsWriter,
infoStream
);
case BYTE:
return new NativeEngineFieldVectorsWriter<byte[]>(fieldInfo, infoStream);
return new NativeEngineFieldVectorsWriter<>(fieldInfo, (FlatFieldVectorsWriter<byte[]>) flatFieldVectorsWriter, infoStream);
}
throw new IllegalStateException("Unsupported Vector encoding : " + fieldInfo.getVectorEncoding());
}

private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStream infoStream) {
private NativeEngineFieldVectorsWriter(
final FieldInfo fieldInfo,
final FlatFieldVectorsWriter<T> flatFieldVectorsWriter,
final InfoStream infoStream
) {
this.fieldInfo = fieldInfo;
this.infoStream = infoStream;
vectors = new HashMap<>();
this.docsWithField = new DocsWithFieldSet();
this.flatFieldVectorsWriter = flatFieldVectorsWriter;
}

/**
Expand All @@ -70,7 +87,7 @@ private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStre
* @param vectorValue T
*/
@Override
public void addValue(int docID, T vectorValue) {
public void addValue(int docID, T vectorValue) throws IOException {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"[NativeEngineKNNVectorWriter]VectorValuesField \""
Expand All @@ -81,6 +98,8 @@ public void addValue(int docID, T vectorValue) {
// TODO: we can build the graph here too iteratively. but right now I am skipping that as we need iterative
// graph build support on the JNI layer.
assert docID > lastDocID;
// ensuring that vector is provided to flatFieldWriter.
flatFieldVectorsWriter.addValue(docID, vectorValue);
vectors.put(docID, vectorValue);
docsWithField.add(docID);
lastDocID = docID;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla
*/
@Override
public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOException {
final NativeEngineFieldVectorsWriter<?> newField = NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream);
final NativeEngineFieldVectorsWriter<?> newField = NativeEngineFieldVectorsWriter.create(
fieldInfo,
flatVectorsWriter.addField(fieldInfo),
segmentWriteState.infoStream
);
fields.add(newField);
return flatVectorsWriter.addField(fieldInfo, newField);
return newField;
}

/**
Expand Down
21 changes: 19 additions & 2 deletions src/main/java/org/opensearch/knn/index/codec/KNNCodecVersion.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import org.apache.lucene.backward_codecs.lucene94.Lucene94Codec;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.backward_codecs.lucene95.Lucene95Codec;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.backward_codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene912.Lucene912Codec;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat;
import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec;
import org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec;
import org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec;
import org.opensearch.knn.index.codec.KNN920Codec.KNN920PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec;
Expand Down Expand Up @@ -110,9 +112,24 @@ public enum KNNCodecVersion {
.knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService)))
.build(),
KNN990Codec::new
),

V_9_12_0(
"KNN990Codec",
new Lucene912Codec(),
new KNN990PerFieldKnnVectorsFormat(Optional.empty()),
(delegate) -> new KNNFormatFacade(
new KNN80DocValuesFormat(delegate.docValuesFormat()),
new KNN80CompoundFormat(delegate.compoundFormat())
),
(userCodec, mapperService) -> KNN9120Codec.builder()
.delegate(userCodec)
.knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService)))
.build(),
KNN9120Codec::new
);

private static final KNNCodecVersion CURRENT = V_9_9_0;
private static final KNNCodecVersion CURRENT = V_9_12_0;

private final String codecName;
private final Codec defaultCodecDelegate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public KNNScalarQuantizedVectorsFormatParams(Map<String, Object> params, int def
Map<String, Object> sqEncoderParams = encoderMethodComponentContext.getParameters();
this.initConfidenceInterval(sqEncoderParams);
this.initBits(sqEncoderParams);
// compression flag should be inited after initBits as compressionFlag Depends on bits.
this.initCompressFlag();
}

Expand Down Expand Up @@ -77,6 +78,8 @@ private void initBits(final Map<String, Object> params) {
}

private void initCompressFlag() {
this.compressFlag = true;
// This check is coming from Lucene. Code ref:
// https://github.com/apache/lucene/blob/branch_9_12/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java#L113-L116
this.compressFlag = this.bits <= 4;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.SneakyThrows;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.InfoStream;
Expand All @@ -21,18 +23,22 @@
public class NativeEngineFieldVectorsWriterTests extends KNNCodecTestCase {

@SuppressWarnings("unchecked")
@SneakyThrows
public void testCreate_ForDifferentInputs_thenSuccess() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.eq(1), Mockito.any());
NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
floatWriter.addValue(1, new float[] { 1.0f, 2.0f });

Mockito.verify(fieldInfo).getVectorEncoding();

Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter.create(
fieldInfo,
mockedFlatFieldVectorsWriter,
InfoStream.getDefault()
);
Assert.assertNotNull(byteWriter);
Expand All @@ -41,11 +47,15 @@ public void testCreate_ForDifferentInputs_thenSuccess() {
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testAddValue_ForDifferentInputs_thenSuccess() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.eq(1), Mockito.any());
Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.eq(2), Mockito.any());
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
final float[] vec1 = new float[] { 1.0f, 2.0f };
final float[] vec2 = new float[] { 2.0f, 2.0f };
floatWriter.addValue(1, vec1);
Expand All @@ -57,7 +67,7 @@ public void testAddValue_ForDifferentInputs_thenSuccess() {

Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
final byte[] bvec1 = new byte[] { 1, 2 };
final byte[] bvec2 = new byte[] { 2, 2 };
byteWriter.addValue(1, bvec1);
Expand All @@ -69,32 +79,36 @@ public void testAddValue_ForDifferentInputs_thenSuccess() {
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testCopyValue_whenValidInput_thenException() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
expectThrows(UnsupportedOperationException.class, () -> floatWriter.copyValue(new float[3]));

Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
expectThrows(UnsupportedOperationException.class, () -> byteWriter.copyValue(new byte[3]));
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testRamByteUsed_whenValidInput_thenSuccess() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
Mockito.when(fieldInfo.getVectorDimension()).thenReturn(2);
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
// testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too.
Assert.assertTrue(floatWriter.ramBytesUsed() > 0);

Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
// testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too.
Assert.assertTrue(byteWriter.ramBytesUsed() > 0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
Expand All @@ -16,6 +17,7 @@
import org.mockito.Mock;
import org.mockito.MockedConstruction;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.VectorDataType;
Expand Down Expand Up @@ -68,6 +70,8 @@ public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCa
@Mock
private NativeIndexWriter nativeIndexWriter;

private FlatFieldVectorsWriter mockedFlatFieldVectorsWriter;

private NativeEngines990KnnVectorsWriter objectUnderTest;

private final String description;
Expand All @@ -78,6 +82,9 @@ public void setUp() throws Exception {
super.setUp();
MockitoAnnotations.openMocks(this);
objectUnderTest = new NativeEngines990KnnVectorsWriter(segmentWriteState, flatVectorsWriter);
mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(Mockito.anyInt(), Mockito.any());
Mockito.when(flatVectorsWriter.addField(Mockito.any())).thenReturn(mockedFlatFieldVectorsWriter);
}

@ParametersFactory
Expand Down Expand Up @@ -139,8 +146,9 @@ public void testFlush() {
);

NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i));
fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream))
.thenReturn(field);
fieldWriterMockedStatic.when(
() -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream)
).thenReturn(field);

try {
objectUnderTest.addField(fieldInfo);
Expand Down Expand Up @@ -227,8 +235,9 @@ public void testFlush_WithQuantization() {
);

NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i));
fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream))
.thenReturn(field);
fieldWriterMockedStatic.when(
() -> NativeEngineFieldVectorsWriter.create(fieldInfo, mockedFlatFieldVectorsWriter, segmentWriteState.infoStream)
).thenReturn(field);

try {
objectUnderTest.addField(fieldInfo);
Expand Down
Loading

0 comments on commit bf2234a

Please sign in to comment.