diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b1755e21..7f492eff3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Support filter and nested field in faiss engine radial search [#1652](https://github.com/opensearch-project/k-NN/pull/1652) ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) -* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) * Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices [#1604](https://github.com/opensearch-project/k-NN/pull/1604) * Remove unnecessary toString conversion of vector field and added some minor optimization in KNNCodec [1613](https://github.com/opensearch-project/k-NN/pull/1613) * Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499) diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index 85f037c0f..f4caa4f20 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -5,10 +5,9 @@ package org.opensearch.knn.index; +import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; -import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.index.fielddata.LeafFieldData; import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues; @@ -40,29 +39,10 @@ public long ramBytesUsed() { @Override public ScriptDocValues getScriptValues() { try { - FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName); - if (fieldInfo == null) { - return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType); - } - - DocIdSetIterator values; - if (fieldInfo.hasVectorValues()) { - switch (fieldInfo.getVectorEncoding()) { - case FLOAT32: - values = reader.getFloatVectorValues(fieldName); - break; - case BYTE: - values = reader.getByteVectorValues(fieldName); - break; - default: - throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding()); - } - } else { - values = DocValues.getBinary(reader, fieldName); - } - return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType); + BinaryDocValues values = DocValues.getBinary(reader, fieldName); + return new KNNVectorScriptDocValues(values, fieldName, vectorDataType); } catch (IOException e) { - throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e); + throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e); } } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index f69ad850e..349988c93 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -6,21 +6,18 @@ package org.opensearch.knn.index; import java.io.IOException; -import java.util.Objects; -import lombok.AccessLevel; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; -@RequiredArgsConstructor(access = AccessLevel.PRIVATE) -public abstract class KNNVectorScriptDocValues extends ScriptDocValues { +import java.io.IOException; + +@RequiredArgsConstructor +public final class KNNVectorScriptDocValues extends ScriptDocValues { - private final DocIdSetIterator vectorValues; + private final BinaryDocValues binaryDocValues; private final String fieldName; @Getter private final VectorDataType vectorDataType; @@ -28,7 +25,11 @@ public abstract class KNNVectorScriptDocValues extends ScriptDocValues @Override public void setNextDocId(int docId) throws IOException { - docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId; + if (binaryDocValues.advanceExact(docId)) { + docExists = true; + return; + } + docExists = false; } public float[] getValue() { @@ -43,14 +44,12 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - return doGetValue(); + return vectorDataType.getVectorFromBytesRef(binaryDocValues.binaryValue()); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } } - protected abstract float[] doGetValue() throws IOException; - @Override public int size() { return docExists ? 1 : 0; @@ -60,89 +59,4 @@ public int size() { public float[] get(int i) { throw new UnsupportedOperationException("knn vector does not support this operation"); } - - /** - * Creates a KNNVectorScriptDocValues object based on the provided parameters. - * - * @param values The DocIdSetIterator representing the vector values. - * @param fieldName The name of the field. - * @param vectorDataType The data type of the vector. - * @return A KNNVectorScriptDocValues object based on the type of the values. - * @throws IllegalArgumentException If the type of values is unsupported. - */ - public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { - Objects.requireNonNull(values, "values must not be null"); - if (values instanceof ByteVectorValues) { - return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); - } else if (values instanceof FloatVectorValues) { - return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType); - } else if (values instanceof BinaryDocValues) { - return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType); - } else { - throw new IllegalArgumentException("Unsupported values type: " + values.getClass()); - } - } - - private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { - private final ByteVectorValues values; - - KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) { - super(values, field, type); - this.values = values; - } - - @Override - protected float[] doGetValue() throws IOException { - byte[] bytes = values.vectorValue(); - float[] value = new float[bytes.length]; - for (int i = 0; i < bytes.length; i++) { - value[i] = (float) bytes[i]; - } - return value; - } - } - - private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { - private final FloatVectorValues values; - - KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) { - super(values, field, type); - this.values = values; - } - - @Override - protected float[] doGetValue() throws IOException { - return values.vectorValue(); - } - } - - private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues { - private final BinaryDocValues values; - - KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) { - super(values, field, type); - this.values = values; - } - - @Override - protected float[] doGetValue() throws IOException { - return getVectorDataType().getVectorFromBytesRef(values.binaryValue()); - } - } - - /** - * Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type. - * - * @param fieldName The name of the field. - * @param type The data type of the vector. - * @return An empty KNNVectorScriptDocValues object. - */ - public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) { - return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { - @Override - protected float[] doGetValue() throws IOException { - throw new UnsupportedOperationException("empty values"); - } - }; - } } diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index 66e2893c0..3f98a9136 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -5,15 +5,7 @@ package org.opensearch.knn.index; -import org.apache.lucene.document.Field; -import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.index.BinaryDocValues; -import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.LeafReader; -import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.index.LeafReaderContext; import org.opensearch.knn.KNNTestCase; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; @@ -41,39 +33,26 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { public void setUp() throws Exception { super.setUp(); directory = newDirectory(); - Class valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class); - createKNNVectorDocument(directory, valuesClass); + createKNNVectorDocument(directory); reader = DirectoryReader.open(directory); - LeafReader leafReader = reader.getContext().leaves().get(0).reader(); - DocIdSetIterator vectorValues; - if (BinaryDocValues.class.equals(valuesClass)) { - vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME); - } else if (ByteVectorValues.class.equals(valuesClass)) { - vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME); - } else { - vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME); - } - - scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + scriptDocValues = new KNNVectorScriptDocValues( + leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), + MOCK_INDEX_FIELD_NAME, + VectorDataType.FLOAT + ); } - private void createKNNVectorDocument(Directory directory, Class valuesClass) throws IOException { + private void createKNNVectorDocument(Directory directory) throws IOException { IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - Field field; - if (BinaryDocValues.class.equals(valuesClass)) { - field = new BinaryDocValuesField( + knnDocument.add( + new BinaryDocValuesField( MOCK_INDEX_FIELD_NAME, new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() - ); - } else if (ByteVectorValues.class.equals(valuesClass)) { - field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA); - } else { - field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA); - } - - knnDocument.add(field); + ) + ); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -105,18 +84,4 @@ public void testSize() throws IOException { public void testGet() throws IOException { expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); } - - public void testUnsupportedValues() throws IOException { - expectThrows( - IllegalArgumentException.class, - () -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT) - ); - } - - public void testEmptyValues() throws IOException { - KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); - assertEquals(0, values.size()); - scriptDocValues.setNextDocId(0); - assertEquals(0, values.size()); - } } diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 19270717d..4423c85d8 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -57,7 +57,7 @@ private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { createKNNFloatVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return KNNVectorScriptDocValues.create( + return new KNNVectorScriptDocValues( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME, VectorDataType.FLOAT @@ -70,7 +70,7 @@ private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { createKNNByteVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return KNNVectorScriptDocValues.create( + return new KNNVectorScriptDocValues( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME, VectorDataType.BYTE diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 22110accd..8c43a4acf 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -280,7 +280,7 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = KNNVectorScriptDocValues.create( + scriptDocValues = new KNNVectorScriptDocValues( leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName, VectorDataType.FLOAT diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 11c626ff7..901511a68 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -612,16 +612,7 @@ private List createMappers(int dimensions) throws Exception { dimensions, KNNConstants.METHOD_HNSW, KNNEngine.LUCENE.getName(), - SpaceType.DEFAULT.getValue(), - true - ), - createKnnIndexMapping( - FIELD_NAME, - dimensions, - KNNConstants.METHOD_HNSW, - KNNEngine.LUCENE.getName(), - SpaceType.DEFAULT.getValue(), - false + SpaceType.DEFAULT.getValue() ) ); } diff --git a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java index 5325d1205..0315c47c5 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java @@ -563,9 +563,7 @@ public void testL2ScriptingWithLuceneBackedIndex() throws Exception { new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) ); properties.add( - new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2") - .knnMethodContext(knnMethodContext) - .docValues(randomBoolean()) + new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").knnMethodContext(knnMethodContext) ); String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME);