Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lucene codec after lucene version bumped to 9.12 #2195

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Documentation
* Fix sed command in DEVELOPER_GUIDE.md to append a new line character '\n'. [#2181](https://github.com/opensearch-project/k-NN/pull/2181)
### Maintenance
* Fix lucene codec after lucene version bumped to 9.12. [#2195](https://github.com/opensearch-project/k-NN/pull/2195)
### Refactoring
* Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN9120Codec;

import lombok.Builder;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNNCodecVersion;
import org.opensearch.knn.index.codec.KNNFormatFacade;

/**
* KNN Codec that wraps the Lucene Codec which is part of Lucene 9.12
*/
public class KNN9120Codec extends FilterCodec {
private static final KNNCodecVersion VERSION = KNNCodecVersion.V_9_12_0;
private final KNNFormatFacade knnFormatFacade;
private final PerFieldKnnVectorsFormat perFieldKnnVectorsFormat;

/**
* No arg constructor that uses Lucene99 as the delegate
*/
public KNN9120Codec() {
this(VERSION.getDefaultCodecDelegate(), VERSION.getPerFieldKnnVectorsFormat());
}

/**
* Sole constructor. When subclassing this codec, create a no-arg ctor and pass the delegate codec
* and a unique name to this ctor.
*
* @param delegate codec that will perform all operations this codec does not override
* @param knnVectorsFormat per field format for KnnVector
*/
@Builder
protected KNN9120Codec(Codec delegate, PerFieldKnnVectorsFormat knnVectorsFormat) {
super(VERSION.getCodecName(), delegate);
knnFormatFacade = VERSION.getKnnFormatFacadeSupplier().apply(delegate);
perFieldKnnVectorsFormat = knnVectorsFormat;
}

@Override
public DocValuesFormat docValuesFormat() {
return knnFormatFacade.docValuesFormat();
}

@Override
public CompoundFormat compoundFormat() {
return knnFormatFacade.compoundFormat();
}

@Override
public KnnVectorsFormat knnVectorsFormat() {
return perFieldKnnVectorsFormat;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@

import lombok.Getter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

Expand All @@ -44,22 +46,37 @@ class NativeEngineFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
@Getter
private final DocsWithFieldSet docsWithField;
private final InfoStream infoStream;
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;

static NativeEngineFieldVectorsWriter<?> create(final FieldInfo fieldInfo, final InfoStream infoStream) {
@SuppressWarnings("unchecked")
static NativeEngineFieldVectorsWriter<?> create(
final FieldInfo fieldInfo,
final FlatFieldVectorsWriter<?> flatFieldVectorsWriter,
final InfoStream infoStream
) {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
return new NativeEngineFieldVectorsWriter<float[]>(fieldInfo, infoStream);
return new NativeEngineFieldVectorsWriter<>(
fieldInfo,
(FlatFieldVectorsWriter<float[]>) flatFieldVectorsWriter,
infoStream
);
case BYTE:
return new NativeEngineFieldVectorsWriter<byte[]>(fieldInfo, infoStream);
return new NativeEngineFieldVectorsWriter<>(fieldInfo, (FlatFieldVectorsWriter<byte[]>) flatFieldVectorsWriter, infoStream);
}
throw new IllegalStateException("Unsupported Vector encoding : " + fieldInfo.getVectorEncoding());
}

private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStream infoStream) {
private NativeEngineFieldVectorsWriter(
final FieldInfo fieldInfo,
final FlatFieldVectorsWriter<T> flatFieldVectorsWriter,
final InfoStream infoStream
) {
this.fieldInfo = fieldInfo;
this.infoStream = infoStream;
vectors = new HashMap<>();
this.docsWithField = new DocsWithFieldSet();
this.flatFieldVectorsWriter = flatFieldVectorsWriter;
}

/**
Expand All @@ -70,7 +87,7 @@ private NativeEngineFieldVectorsWriter(final FieldInfo fieldInfo, final InfoStre
* @param vectorValue T
*/
@Override
public void addValue(int docID, T vectorValue) {
public void addValue(int docID, T vectorValue) throws IOException {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"[NativeEngineKNNVectorWriter]VectorValuesField \""
Expand All @@ -81,6 +98,8 @@ public void addValue(int docID, T vectorValue) {
// TODO: we can build the graph here too iteratively. but right now I am skipping that as we need iterative
// graph build support on the JNI layer.
assert docID > lastDocID;
// ensuring that vector is provided to flatFieldWriter.
flatFieldVectorsWriter.addValue(docID, vectorValue);
vectors.put(docID, vectorValue);
docsWithField.add(docID);
lastDocID = docID;
Expand All @@ -105,6 +124,7 @@ public long ramBytesUsed() {
return SHALLOW_SIZE + docsWithField.ramBytesUsed() + (long) this.vectors.size() * (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size() * RamUsageEstimator.shallowSizeOfInstance(
Integer.class
) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter
.ramBytesUsed();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,13 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla
*/
@Override
public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOException {
final NativeEngineFieldVectorsWriter<?> newField = NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream);
final NativeEngineFieldVectorsWriter<?> newField = NativeEngineFieldVectorsWriter.create(
fieldInfo,
flatVectorsWriter.addField(fieldInfo),
segmentWriteState.infoStream
);
fields.add(newField);
return flatVectorsWriter.addField(fieldInfo, newField);
return newField;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import org.apache.lucene.backward_codecs.lucene94.Lucene94Codec;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.backward_codecs.lucene95.Lucene95Codec;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.backward_codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene912.Lucene912Codec;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80CompoundFormat;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesFormat;
import org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec;
import org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec;
import org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec;
import org.opensearch.knn.index.codec.KNN920Codec.KNN920PerFieldKnnVectorsFormat;
import org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec;
Expand Down Expand Up @@ -110,9 +112,24 @@ public enum KNNCodecVersion {
.knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService)))
.build(),
KNN990Codec::new
),

V_9_12_0(
"KNN9120Codec",
new Lucene912Codec(),
new KNN990PerFieldKnnVectorsFormat(Optional.empty()),
(delegate) -> new KNNFormatFacade(
new KNN80DocValuesFormat(delegate.docValuesFormat()),
new KNN80CompoundFormat(delegate.compoundFormat())
),
(userCodec, mapperService) -> KNN9120Codec.builder()
.delegate(userCodec)
.knnVectorsFormat(new KNN990PerFieldKnnVectorsFormat(Optional.ofNullable(mapperService)))
.build(),
KNN9120Codec::new
);

private static final KNNCodecVersion CURRENT = V_9_9_0;
private static final KNNCodecVersion CURRENT = V_9_12_0;

private final String codecName;
private final Codec defaultCodecDelegate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ public KNNScalarQuantizedVectorsFormatParams(Map<String, Object> params, int def
Map<String, Object> sqEncoderParams = encoderMethodComponentContext.getParameters();
this.initConfidenceInterval(sqEncoderParams);
this.initBits(sqEncoderParams);
this.initCompressFlag();
// compression flag should be set after bits has been initialised as compressionFlag depends on bits.
this.setCompressionFlag();
}

@Override
Expand Down Expand Up @@ -76,7 +77,14 @@ private void initBits(final Map<String, Object> params) {
this.bits = LUCENE_SQ_DEFAULT_BITS;
}

private void initCompressFlag() {
this.compressFlag = true;
private void setCompressionFlag() {
if (this.bits <= 0) {
throw new IllegalArgumentException(
"Either bits are set to less than 0 or they have not been initialized." + " Bit value: " + this.bits
);
}
// This check is coming from Lucene. Code ref:
// https://github.com/apache/lucene/blob/branch_9_12/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java#L113-L116
this.compressFlag = this.bits <= 4;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ org.opensearch.knn.index.codec.KNN920Codec.KNN920Codec
org.opensearch.knn.index.codec.KNN940Codec.KNN940Codec
org.opensearch.knn.index.codec.KNN950Codec.KNN950Codec
org.opensearch.knn.index.codec.KNN990Codec.KNN990Codec
org.opensearch.knn.index.codec.KNN9120Codec.KNN9120Codec
org.opensearch.knn.index.codec.KNN990Codec.UnitTestCodec
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.SneakyThrows;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.InfoStream;
Expand All @@ -21,82 +23,109 @@
public class NativeEngineFieldVectorsWriterTests extends KNNCodecTestCase {

@SuppressWarnings("unchecked")
@SneakyThrows
public void testCreate_ForDifferentInputs_thenSuccess() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
final FlatFieldVectorsWriter<float[]> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
floatWriter.addValue(1, new float[] { 1.0f, 2.0f });
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
final float[] floatVector = new float[] { 1.0f, 2.0f };
floatWriter.addValue(1, floatVector);
Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(1, floatVector);

Mockito.verify(fieldInfo).getVectorEncoding();
Mockito.verify(mockedFlatFieldVectorsWriter).addValue(1, floatVector);

final byte[] byteVector = new byte[] { 1, 2 };
final FlatFieldVectorsWriter<byte[]> mockedFlatFieldByteVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.doNothing().when(mockedFlatFieldByteVectorsWriter).addValue(1, byteVector);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter.create(
fieldInfo,
mockedFlatFieldByteVectorsWriter,
InfoStream.getDefault()
);
Assert.assertNotNull(byteWriter);
Mockito.verify(fieldInfo, Mockito.times(2)).getVectorEncoding();
byteWriter.addValue(1, new byte[] { 1, 2 });
byteWriter.addValue(1, byteVector);
Mockito.verify(mockedFlatFieldByteVectorsWriter).addValue(1, byteVector);
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testAddValue_ForDifferentInputs_thenSuccess() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
final FlatFieldVectorsWriter<float[]> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
final float[] vec1 = new float[] { 1.0f, 2.0f };
final float[] vec2 = new float[] { 2.0f, 2.0f };
Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(1, vec1);
Mockito.doNothing().when(mockedFlatFieldVectorsWriter).addValue(2, vec2);
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
floatWriter.addValue(1, vec1);
floatWriter.addValue(2, vec2);
Mockito.verify(mockedFlatFieldVectorsWriter).addValue(1, vec1);
Mockito.verify(mockedFlatFieldVectorsWriter).addValue(2, vec2);

Assert.assertEquals(vec1, floatWriter.getVectors().get(1));
Assert.assertEquals(vec2, floatWriter.getVectors().get(2));
Mockito.verify(fieldInfo).getVectorEncoding();

Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
final FlatFieldVectorsWriter<byte[]> mockedFlatFieldByteVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
final byte[] bvec1 = new byte[] { 1, 2 };
final byte[] bvec2 = new byte[] { 2, 2 };
Mockito.doNothing().when(mockedFlatFieldByteVectorsWriter).addValue(1, bvec1);
Mockito.doNothing().when(mockedFlatFieldByteVectorsWriter).addValue(2, bvec2);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, mockedFlatFieldByteVectorsWriter, InfoStream.getDefault());
byteWriter.addValue(1, bvec1);
byteWriter.addValue(2, bvec2);

Assert.assertEquals(bvec1, byteWriter.getVectors().get(1));
Assert.assertEquals(bvec2, byteWriter.getVectors().get(2));
Mockito.verify(fieldInfo, Mockito.times(2)).getVectorEncoding();
Mockito.verify(mockedFlatFieldByteVectorsWriter).addValue(1, bvec1);
Mockito.verify(mockedFlatFieldByteVectorsWriter).addValue(2, bvec2);
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testCopyValue_whenValidInput_thenException() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
expectThrows(UnsupportedOperationException.class, () -> floatWriter.copyValue(new float[3]));

Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
expectThrows(UnsupportedOperationException.class, () -> byteWriter.copyValue(new byte[3]));
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testRamByteUsed_whenValidInput_thenSuccess() {
final FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32);
Mockito.when(fieldInfo.getVectorDimension()).thenReturn(2);
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.when(mockedFlatFieldVectorsWriter.ramBytesUsed()).thenReturn(1L);
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
// testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too.
Assert.assertTrue(floatWriter.ramBytesUsed() > 0);

Mockito.when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.BYTE);
final NativeEngineFieldVectorsWriter<byte[]> byteWriter = (NativeEngineFieldVectorsWriter<byte[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, InfoStream.getDefault());
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
// testing for value > 0 as we don't have a concrete way to find out expected bytes. This can OS dependent too.
Assert.assertTrue(byteWriter.ramBytesUsed() > 0);
Mockito.verify(mockedFlatFieldVectorsWriter, Mockito.times(2)).ramBytesUsed();

}
}
Loading
Loading