From 771c4b54a74b7c4406c71a8bf758378329cfe4d5 Mon Sep 17 00:00:00 2001 From: panguixin Date: Fri, 29 Mar 2024 03:31:47 +0800 Subject: [PATCH] Support script score when doc value is disabled (#1573) * support script score when doc value is disabled Signed-off-by: panguixin * add test Signed-off-by: panguixin * apply review comments Signed-off-by: panguixin * fix test Signed-off-by: panguixin --------- Signed-off-by: panguixin --- CHANGELOG.md | 1 + .../knn/index/KNNVectorDVLeafFieldData.java | 28 +- .../knn/index/KNNVectorScriptDocValues.java | 109 ++++- .../org/opensearch/knn/index/FaissIT.java | 6 +- .../index/KNNVectorScriptDocValuesTests.java | 62 ++- .../opensearch/knn/index/LuceneEngineIT.java | 9 +- .../org/opensearch/knn/index/NmslibIT.java | 4 +- .../opensearch/knn/index/OpenSearchIT.java | 5 +- .../knn/index/VectorDataTypeTests.java | 4 +- .../plugin/script/KNNScoringUtilTests.java | 2 +- .../knn/plugin/script/KNNScriptScoringIT.java | 383 ++++++------------ .../knn/plugin/script/PainlessScriptIT.java | 20 +- .../org/opensearch/knn/KNNRestTestCase.java | 49 ++- .../java/org/opensearch/knn/KNNResult.java | 27 +- 14 files changed, 384 insertions(+), 325 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03de217a8..e82376a6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements * Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) +* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573) ### Bug Fixes ### Infrastructure * Add micro-benchmark module in k-NN plugin for benchmark streaming vectors to JNI layer functionality. [#1583](https://github.com/opensearch-project/k-NN/pull/1583) diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java index f4caa4f20..85f037c0f 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorDVLeafFieldData.java @@ -5,9 +5,10 @@ package org.opensearch.knn.index; -import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.index.fielddata.LeafFieldData; import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues; @@ -39,10 +40,29 @@ public long ramBytesUsed() { @Override public ScriptDocValues getScriptValues() { try { - BinaryDocValues values = DocValues.getBinary(reader, fieldName); - return new KNNVectorScriptDocValues(values, fieldName, vectorDataType); + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(fieldName); + if (fieldInfo == null) { + return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType); + } + + DocIdSetIterator values; + if (fieldInfo.hasVectorValues()) { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + values = reader.getFloatVectorValues(fieldName); + break; + case BYTE: + values = reader.getByteVectorValues(fieldName); + break; + default: + throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding()); + } + } else { + values = DocValues.getBinary(reader, fieldName); + } + return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType); } catch (IOException e) { - throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e); + throw new IllegalStateException("Cannot load values for knn vector field: " + fieldName, e); } } diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index 9f7d52205..c733c534e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -5,18 +5,22 @@ package org.opensearch.knn.index; +import java.io.IOException; +import java.util.Objects; +import lombok.AccessLevel; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.ExceptionsHelper; import org.opensearch.index.fielddata.ScriptDocValues; -import java.io.IOException; - -@RequiredArgsConstructor -public final class KNNVectorScriptDocValues extends ScriptDocValues { +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public abstract class KNNVectorScriptDocValues extends ScriptDocValues { - private final BinaryDocValues binaryDocValues; + private final DocIdSetIterator vectorValues; private final String fieldName; @Getter private final VectorDataType vectorDataType; @@ -24,11 +28,7 @@ public final class KNNVectorScriptDocValues extends ScriptDocValues { @Override public void setNextDocId(int docId) throws IOException { - if (binaryDocValues.advanceExact(docId)) { - docExists = true; - return; - } - docExists = false; + docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId; } public float[] getValue() { @@ -43,12 +43,14 @@ public float[] getValue() { throw new IllegalStateException(errorMessage); } try { - return vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue()); + return doGetValue(); } catch (IOException e) { throw ExceptionsHelper.convertToOpenSearchException(e); } } + protected abstract float[] doGetValue() throws IOException; + @Override public int size() { return docExists ? 1 : 0; @@ -58,4 +60,89 @@ public int size() { public float[] get(int i) { throw new UnsupportedOperationException("knn vector does not support this operation"); } + + /** + * Creates a KNNVectorScriptDocValues object based on the provided parameters. + * + * @param values The DocIdSetIterator representing the vector values. + * @param fieldName The name of the field. + * @param vectorDataType The data type of the vector. + * @return A KNNVectorScriptDocValues object based on the type of the values. + * @throws IllegalArgumentException If the type of values is unsupported. + */ + public static KNNVectorScriptDocValues create(DocIdSetIterator values, String fieldName, VectorDataType vectorDataType) { + Objects.requireNonNull(values, "values must not be null"); + if (values instanceof ByteVectorValues) { + return new KNNByteVectorScriptDocValues((ByteVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof FloatVectorValues) { + return new KNNFloatVectorScriptDocValues((FloatVectorValues) values, fieldName, vectorDataType); + } else if (values instanceof BinaryDocValues) { + return new KNNNativeVectorScriptDocValues((BinaryDocValues) values, fieldName, vectorDataType); + } else { + throw new IllegalArgumentException("Unsupported values type: " + values.getClass()); + } + } + + private static final class KNNByteVectorScriptDocValues extends KNNVectorScriptDocValues { + private final ByteVectorValues values; + + KNNByteVectorScriptDocValues(ByteVectorValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + byte[] bytes = values.vectorValue(); + float[] value = new float[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + value[i] = (float) bytes[i]; + } + return value; + } + } + + private static final class KNNFloatVectorScriptDocValues extends KNNVectorScriptDocValues { + private final FloatVectorValues values; + + KNNFloatVectorScriptDocValues(FloatVectorValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + return values.vectorValue(); + } + } + + private static final class KNNNativeVectorScriptDocValues extends KNNVectorScriptDocValues { + private final BinaryDocValues values; + + KNNNativeVectorScriptDocValues(BinaryDocValues values, String field, VectorDataType type) { + super(values, field, type); + this.values = values; + } + + @Override + protected float[] doGetValue() throws IOException { + return getVectorDataType().getVectorFromDocValues(values.binaryValue()); + } + } + + /** + * Creates an empty KNNVectorScriptDocValues object based on the provided field name and vector data type. + * + * @param fieldName The name of the field. + * @param type The data type of the vector. + * @return An empty KNNVectorScriptDocValues object. + */ + public static KNNVectorScriptDocValues emptyValues(String fieldName, VectorDataType type) { + return new KNNVectorScriptDocValues(DocIdSetIterator.empty(), fieldName, type) { + @Override + protected float[] doGetValue() throws IOException { + throw new UnsupportedOperationException("empty values"); + } + }; + } } diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 3fafae9ba..0cec3810e 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -148,7 +148,7 @@ public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), @@ -258,7 +258,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), @@ -828,7 +828,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java index a0df3ce64..66e2893c0 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorScriptDocValuesTests.java @@ -5,6 +5,15 @@ package org.opensearch.knn.index; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.search.DocIdSetIterator; import org.opensearch.knn.KNNTestCase; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.document.BinaryDocValuesField; @@ -13,7 +22,6 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; -import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.store.Directory; import org.junit.Assert; import org.junit.Before; @@ -24,6 +32,7 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { private static final String MOCK_INDEX_FIELD_NAME = "test-index-field-name"; private static final float[] SAMPLE_VECTOR_DATA = new float[] { 1.0f, 2.0f }; + private static final byte[] SAMPLE_BYTE_VECTOR_DATA = new byte[] { 1, 2 }; private KNNVectorScriptDocValues scriptDocValues; private Directory directory; private DirectoryReader reader; @@ -32,26 +41,39 @@ public class KNNVectorScriptDocValuesTests extends KNNTestCase { public void setUp() throws Exception { super.setUp(); directory = newDirectory(); - createKNNVectorDocument(directory); + Class valuesClass = randomFrom(BinaryDocValues.class, ByteVectorValues.class, FloatVectorValues.class); + createKNNVectorDocument(directory, valuesClass); reader = DirectoryReader.open(directory); - LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues( - leafReaderContext.reader().getBinaryDocValues(MOCK_INDEX_FIELD_NAME), - MOCK_INDEX_FIELD_NAME, - VectorDataType.FLOAT - ); + LeafReader leafReader = reader.getContext().leaves().get(0).reader(); + DocIdSetIterator vectorValues; + if (BinaryDocValues.class.equals(valuesClass)) { + vectorValues = DocValues.getBinary(leafReader, MOCK_INDEX_FIELD_NAME); + } else if (ByteVectorValues.class.equals(valuesClass)) { + vectorValues = leafReader.getByteVectorValues(MOCK_INDEX_FIELD_NAME); + } else { + vectorValues = leafReader.getFloatVectorValues(MOCK_INDEX_FIELD_NAME); + } + + scriptDocValues = KNNVectorScriptDocValues.create(vectorValues, MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); } - private void createKNNVectorDocument(Directory directory) throws IOException { + private void createKNNVectorDocument(Directory directory, Class valuesClass) throws IOException { IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); IndexWriter writer = new IndexWriter(directory, conf); Document knnDocument = new Document(); - knnDocument.add( - new BinaryDocValuesField( + Field field; + if (BinaryDocValues.class.equals(valuesClass)) { + field = new BinaryDocValuesField( MOCK_INDEX_FIELD_NAME, new VectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA, new FieldType()).binaryValue() - ) - ); + ); + } else if (ByteVectorValues.class.equals(valuesClass)) { + field = new KnnByteVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_BYTE_VECTOR_DATA); + } else { + field = new KnnFloatVectorField(MOCK_INDEX_FIELD_NAME, SAMPLE_VECTOR_DATA); + } + + knnDocument.add(field); writer.addDocument(knnDocument); writer.commit(); writer.close(); @@ -83,4 +105,18 @@ public void testSize() throws IOException { public void testGet() throws IOException { expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0)); } + + public void testUnsupportedValues() throws IOException { + expectThrows( + IllegalArgumentException.class, + () -> KNNVectorScriptDocValues.create(DocValues.emptyNumeric(), MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT) + ); + } + + public void testEmptyValues() throws IOException { + KNNVectorScriptDocValues values = KNNVectorScriptDocValues.emptyValues(MOCK_INDEX_FIELD_NAME, VectorDataType.FLOAT); + assertEquals(0, values.size()); + scriptDocValues.setNextDocId(0); + assertEquals(0, values.size()); + } } diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 8919519d1..b17155704 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -7,7 +7,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.primitives.Floats; import lombok.SneakyThrows; import org.apache.commons.lang.math.RandomUtils; import org.apache.hc.core5.http.io.entity.EntityUtils; @@ -307,14 +306,14 @@ public void testIndexReopening() throws Exception { final float[] searchVector = TEST_QUERY_VECTORS[0]; final int k = 1 + RandomUtils.nextInt(TEST_INDEX_VECTORS.length); - final List knnResultsBeforeIndexClosure = queryResults(searchVector, k); + final List knnResultsBeforeIndexClosure = queryResults(searchVector, k); closeIndex(INDEX_NAME); openIndex(INDEX_NAME); ensureGreen(INDEX_NAME); - final List knnResultsAfterIndexClosure = queryResults(searchVector, k); + final List knnResultsAfterIndexClosure = queryResults(searchVector, k); assertArrayEquals(knnResultsBeforeIndexClosure.toArray(), knnResultsAfterIndexClosure.toArray()); } @@ -365,7 +364,7 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws Excep List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); float distance = TestUtils.computeDistFromSpaceType(spaceType, primitiveArray, queryVector); float rawScore = VECTOR_SIMILARITY_TO_SCORE.get(spaceType.getVectorSimilarityFunction()).apply(distance); assertEquals(KNNEngine.LUCENE.score(rawScore, spaceType), actualScores.get(j), 0.0001); @@ -373,7 +372,7 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws Excep } } - private List queryResults(final float[] searchVector, final int k) throws Exception { + private List queryResults(final float[] searchVector, final int k) throws Exception { final String responseBody = EntityUtils.toString( searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity() ); diff --git a/src/test/java/org/opensearch/knn/index/NmslibIT.java b/src/test/java/org/opensearch/knn/index/NmslibIT.java index 8007504cf..86745ab13 100644 --- a/src/test/java/org/opensearch/knn/index/NmslibIT.java +++ b/src/test/java/org/opensearch/knn/index/NmslibIT.java @@ -30,11 +30,9 @@ import java.io.IOException; import java.net.URL; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.TreeMap; -import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; @@ -115,7 +113,7 @@ public void testEndToEnd() throws Exception { List actualScores = parseSearchResponseScore(responseBody, fieldName); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( KNNEngine.NMSLIB.score(KNNScoringUtil.l1Norm(testData.queries[i], primitiveArray), spaceType), actualScores.get(j), diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 2e37e26c4..d82a7f98c 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -39,7 +39,6 @@ import java.util.List; import java.util.Map; import java.util.TreeMap; -import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; @@ -143,7 +142,7 @@ public void testEndToEnd() throws Exception { List actualScores = parseSearchResponseScore(responseBody, fieldName1); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( knnEngine1.score(1 - KNNScoringUtil.cosinesimil(testData.queries[i], primitiveArray), spaceType1), actualScores.get(j), @@ -159,7 +158,7 @@ public void testEndToEnd() throws Exception { actualScores = parseSearchResponseScore(responseBody, fieldName2); for (int j = 0; j < k; j++) { - float[] primitiveArray = Floats.toArray(Arrays.stream(knnResults.get(j).getVector()).collect(Collectors.toList())); + float[] primitiveArray = knnResults.get(j).getVector(); assertEquals( knnEngine2.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType2), actualScores.get(j), diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 4423c85d8..19270717d 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -57,7 +57,7 @@ private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { createKNNFloatVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return new KNNVectorScriptDocValues( + return KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_FLOAT_INDEX_FIELD_NAME, VectorDataType.FLOAT @@ -70,7 +70,7 @@ private KNNVectorScriptDocValues getKNNByteVectorScriptDocValues() { createKNNByteVectorDocument(directory); reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - return new KNNVectorScriptDocValues( + return KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME), VectorDataTypeTests.MOCK_BYTE_INDEX_FIELD_NAME, VectorDataType.BYTE diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 8c43a4acf..22110accd 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -280,7 +280,7 @@ public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOEx if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - scriptDocValues = new KNNVectorScriptDocValues( + scriptDocValues = KNNVectorScriptDocValues.create( leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName, VectorDataType.FLOAT diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 59c4f8c0e..8d014afec 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -5,8 +5,11 @@ package org.opensearch.knn.plugin.script; +import java.util.function.BiFunction; +import java.util.function.Function; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.opensearch.client.Request; @@ -21,6 +24,9 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.opensearch.core.rest.RestStatus; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.script.Script; import java.util.ArrayList; @@ -37,214 +43,19 @@ public class KNNScriptScoringIT extends KNNRestTestCase { public void testKNNL2ScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = { 6.0f, 6.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 2.0f, 2.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 4.0f, 4.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 3.0f, 3.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "vector": [2.0, 2.0] - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.L2.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("2", "4", "3", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("2", results.get(0).getDocId()); - assertEquals("4", results.get(1).getDocId()); - assertEquals("3", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.L2); } public void testKNNL1ScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = { 6.0f, 6.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 4.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 3.0f, 3.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 5.0f, 5.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "vector": [1.0, 1.0] - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.L1); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("2", "4", "3", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("2", results.get(0).getDocId()); - assertEquals("3", results.get(1).getDocId()); - assertEquals("4", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.L1); } public void testKNNLInfScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = { 6.0f, 6.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 4.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 3.0f, 3.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 5.0f, 5.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "vector": [1.0, 1.0] - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.LINF.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("3", "2", "4", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("3", results.get(0).getDocId()); - assertEquals("2", results.get(1).getDocId()); - assertEquals("4", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.LINF); } public void testKNNCosineScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = { 1.0f, -1.0f }; - addKnnDoc(INDEX_NAME, "0", FIELD_NAME, f1); - - Float[] f2 = { 1.0f, 0.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f2); - - Float[] f3 = { 1.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f3); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "query_value": [2.0, 2.0], - * "space_type": "L2" - * } - * - * - */ - float[] queryVector = { 2.0f, -2.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.COSINESIMIL.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("0", "1", "2"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(3, results.size()); - - // assert document order - assertEquals("0", results.get(0).getDocId()); - assertEquals("1", results.get(1).getDocId()); - assertEquals("2", results.get(2).getDocId()); + testKNNScriptScore(SpaceType.COSINESIMIL); } public void testKNNInvalidSourceScript() throws Exception { @@ -396,10 +207,7 @@ public void testKNNScoreforNonVectorDocument() throws Exception { responseBody ).map().get("hits")).get("hits"); - List docIds = hits.stream().map(hit -> { - String id = ((String) ((Map) hit).get("_id")); - return id; - }).collect(Collectors.toList()); + List docIds = hits.stream().map(hit -> ((String) ((Map) hit).get("_id"))).collect(Collectors.toList()); // assert document order assertEquals("1", docIds.get(0)); assertEquals("0", docIds.get(1)); @@ -633,57 +441,7 @@ public void testHammingScriptScore_Base64() throws Exception { } public void testKNNInnerProdScriptScore() throws Exception { - /* - * Create knn index and populate data - */ - createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); - Float[] f1 = { -2.0f, -2.0f }; - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); - - Float[] f2 = { 1.0f, 1.0f }; - addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); - - Float[] f3 = { 2.0f, 2.0f }; - addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); - - Float[] f4 = { 2.0f, -2.0f }; - addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); - - /** - * Construct Search Request - */ - QueryBuilder qb = new MatchAllQueryBuilder(); - Map params = new HashMap<>(); - /* - * params": { - * "field": "my_dense_vector", - * "query_value": [1.0, 1.0], - * "space_type": "innerproduct", - * } - */ - float[] queryVector = { 1.0f, 1.0f }; - params.put("field", FIELD_NAME); - params.put("query_value", queryVector); - params.put("space_type", SpaceType.INNER_PRODUCT.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); - List expectedDocids = Arrays.asList("3", "2", "4", "1"); - - List actualDocids = new ArrayList<>(); - for (KNNResult result : results) { - actualDocids.add(result.getDocId()); - } - - assertEquals(4, results.size()); - - // assert document order - assertEquals("3", results.get(0).getDocId()); - assertEquals("2", results.get(1).getDocId()); - assertEquals("4", results.get(2).getDocId()); - assertEquals("1", results.get(3).getDocId()); + testKNNScriptScore(SpaceType.INNER_PRODUCT); } public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { @@ -791,4 +549,121 @@ public void testKNNScriptScoreWithRequestCacheEnabled() throws Exception { // assert that the request cache was hit at second request assertEquals(1, secondQueryCacheMap.get("hit_count")); } + + private List createMappers(int dimensions) throws Exception { + return List.of( + createKnnIndexMapping(FIELD_NAME, dimensions), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + true + ), + createKnnIndexMapping( + FIELD_NAME, + dimensions, + KNNConstants.METHOD_HNSW, + KNNEngine.LUCENE.getName(), + SpaceType.DEFAULT.getValue(), + false + ) + ); + } + + private float[] randomVector(int dimensions) { + final float[] vector = new float[dimensions]; + for (int i = 0; i < dimensions; i++) { + vector[i] = randomFloat(); + } + return vector; + } + + private Map createDataset(Function scoreFunction, int dimensions, int numDocs) { + final Map dataset = new HashMap<>(numDocs); + for (int i = 0; i < numDocs; i++) { + final float[] vector = randomVector(dimensions); + final float score = scoreFunction.apply(vector); + dataset.put(Integer.toString(i), new KNNResult(Integer.toString(i), vector, score)); + } + return dataset; + } + + private BiFunction getScoreFunction(SpaceType spaceType, float[] queryVector) { + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( + FIELD_NAME, + Collections.emptyMap(), + queryVector.length, + VectorDataType.FLOAT, + null + ); + List target = new ArrayList<>(queryVector.length); + for (float f : queryVector) { + target.add(f); + } + KNNScoringSpace knnScoringSpace = KNNScoringSpaceFactory.create(spaceType.getValue(), target, knnVectorFieldType); + switch (spaceType) { + case L1: + return ((KNNScoringSpace.L1) knnScoringSpace).scoringMethod; + case L2: + return ((KNNScoringSpace.L2) knnScoringSpace).scoringMethod; + case LINF: + return ((KNNScoringSpace.LInf) knnScoringSpace).scoringMethod; + case COSINESIMIL: + return ((KNNScoringSpace.CosineSimilarity) knnScoringSpace).scoringMethod; + case INNER_PRODUCT: + return ((KNNScoringSpace.InnerProd) knnScoringSpace).scoringMethod; + default: + throw new IllegalArgumentException(); + } + } + + private void testKNNScriptScore(SpaceType spaceType) throws Exception { + final int dims = randomIntBetween(2, 10); + final float[] queryVector = randomVector(dims); + final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); + for (String mapper : createMappers(dims)) { + createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector); + } + } + + private void createIndexAndAssertScriptScore( + String mapper, + SpaceType spaceType, + BiFunction scoreFunction, + int dimensions, + float[] queryVector + ) throws Exception { + /* + * Create knn index and populate data + */ + createKnnIndex(INDEX_NAME, mapper); + Map dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, randomIntBetween(4, 10)); + for (Map.Entry entry : dataset.entrySet()) { + addKnnDoc(INDEX_NAME, entry.getKey(), FIELD_NAME, entry.getValue().getVector()); + } + + /** + * Construct Search Request + */ + QueryBuilder qb = new MatchAllQueryBuilder(); + Map params = new HashMap<>(); + /* + * params": { + * "field": FIELD_NAME, + * "vector": queryVector + * } + */ + params.put("field", FIELD_NAME); + params.put("query_value", queryVector); + params.put("space_type", spaceType.getValue()); + Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + assertTrue(results.stream().allMatch(r -> dataset.get(r.getDocId()).equals(r))); + deleteKNNIndex(INDEX_NAME); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java index 5fa88b0a5..5325d1205 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/PainlessScriptIT.java @@ -53,6 +53,10 @@ protected String createMapping(List properties) throws IOExcept builder.field("dimension", property.getDimension()); } + if (property.getDocValues() != null) { + builder.field("doc_values", property.getDocValues()); + } + if (property.getKnnMethodContext() != null) { builder.startObject(KNNConstants.KNN_METHOD); property.getKnnMethodContext().toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -554,12 +558,14 @@ public void testScriptedMetricIsSupported() throws Exception { public void testL2ScriptingWithLuceneBackedIndex() throws Exception { List properties = new ArrayList<>(); KNNMethodContext knnMethodContext = new KNNMethodContext( - KNNEngine.NMSLIB, + KNNEngine.LUCENE, SpaceType.DEFAULT, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) ); properties.add( - new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2").knnMethodContext(knnMethodContext) + new MappingProperty(FIELD_NAME, KNNVectorFieldMapper.CONTENT_TYPE).dimension("2") + .knnMethodContext(knnMethodContext) + .docValues(randomBoolean()) ); String source = String.format("1/(1 + l2Squared([1.0f, 1.0f], doc['%s']))", FIELD_NAME); @@ -585,6 +591,7 @@ static class MappingProperty { private String dimension; private KNNMethodContext knnMethodContext; + private Boolean docValues; MappingProperty(String name, String type) { this.name = name; @@ -601,6 +608,11 @@ MappingProperty knnMethodContext(KNNMethodContext knnMethodContext) { return this; } + MappingProperty docValues(boolean docValues) { + this.docValues = docValues; + return this; + } + KNNMethodContext getKnnMethodContext() { return knnMethodContext; } @@ -616,5 +628,9 @@ String getName() { String getType() { return type; } + + Boolean getDocValues() { + return docValues; + } } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 0b6ae3a5e..68255388b 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -245,10 +245,16 @@ protected List parseSearchResponse(String responseBody, String fieldN @SuppressWarnings("unchecked") List knnSearchResponses = hits.stream().map(hit -> { @SuppressWarnings("unchecked") - Float[] vector = Arrays.stream( - ((ArrayList) ((Map) ((Map) hit).get("_source")).get(fieldName)).toArray() - ).map(Object::toString).map(Float::valueOf).toArray(Float[]::new); - return new KNNResult((String) ((Map) hit).get("_id"), vector); + final float[] vector = Floats.toArray( + Arrays.stream( + ((ArrayList) ((Map) ((Map) hit).get("_source")).get(fieldName)).toArray() + ).map(Object::toString).map(Float::valueOf).collect(Collectors.toList()) + ); + return new KNNResult( + (String) ((Map) hit).get("_id"), + vector, + ((Double) ((Map) hit).get("_score")).floatValue() + ); }).collect(Collectors.toList()); return knnSearchResponses; @@ -329,20 +335,7 @@ protected String createKnnIndexMapping(String fieldName, Integer dimensions) thr * Utility to create a Knn Index Mapping with specific algorithm and engine */ protected String createKnnIndexMapping(String fieldName, Integer dimensions, String algoName, String knnEngine) throws IOException { - return XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimensions.toString()) - .startObject("method") - .field("name", algoName) - .field("engine", knnEngine) - .endObject() - .endObject() - .endObject() - .endObject() - .toString(); + return this.createKnnIndexMapping(fieldName, dimensions, algoName, knnEngine, SpaceType.DEFAULT.getValue()); } /** @@ -350,12 +343,27 @@ protected String createKnnIndexMapping(String fieldName, Integer dimensions, Str */ protected String createKnnIndexMapping(String fieldName, Integer dimensions, String algoName, String knnEngine, String spaceType) throws IOException { + return this.createKnnIndexMapping(fieldName, dimensions, algoName, knnEngine, spaceType, true); + } + + /** + * Utility to create a Knn Index Mapping with specific algorithm, engine, spaceType and docValues + */ + protected String createKnnIndexMapping( + String fieldName, + Integer dimensions, + String algoName, + String knnEngine, + String spaceType, + boolean docValues + ) throws IOException { return XContentFactory.jsonBuilder() .startObject() .startObject("properties") .startObject(fieldName) .field(KNNConstants.TYPE, KNNConstants.TYPE_KNN_VECTOR) .field(KNNConstants.DIMENSION, dimensions.toString()) + .field("doc_values", docValues) .startObject(KNNConstants.KNN_METHOD) .field(KNNConstants.NAME, algoName) .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType) @@ -480,7 +488,7 @@ protected void forceMergeKnnIndex(String index, int maxSegments) throws Exceptio /** * Add a single KNN Doc to an index */ - protected void addKnnDoc(String index, String docId, String fieldName, Object[] vector) throws IOException { + protected void addKnnDoc(String index, String docId, String fieldName, T vector) throws IOException { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field(fieldName, vector).endObject(); @@ -1041,8 +1049,7 @@ public float[][] getIndexVectorsFromIndex(String testIndex, String testField, in int i = 0; for (KNNResult result : results) { - float[] primitiveArray = Floats.toArray(Arrays.stream(result.getVector()).collect(Collectors.toList())); - vectors[i++] = primitiveArray; + vectors[i++] = result.getVector(); } return vectors; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNResult.java b/src/testFixtures/java/org/opensearch/knn/KNNResult.java index 803c2ae72..ee2ba39f7 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNResult.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNResult.java @@ -5,20 +5,41 @@ package org.opensearch.knn; +import java.util.Arrays; +import java.util.Objects; + public class KNNResult { + private final static float delta = 1e-3f; + private String docId; - private Float[] vector; + private float[] vector; + private Float score; - public KNNResult(String docId, Float[] vector) { + public KNNResult(String docId, float[] vector, Float score) { this.docId = docId; this.vector = vector; + this.score = score; } public String getDocId() { return docId; } - public Float[] getVector() { + public float[] getVector() { return vector; } + + public Float getScore() { + return score; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + KNNResult knnResult = (KNNResult) o; + return Objects.equals(docId, knnResult.docId) + && Arrays.equals(vector, knnResult.vector) + && (Float.compare(score, knnResult.score) == 0 || Math.abs(score - knnResult.score) <= delta); + } }