From ad4ce3fb9c839454e993ef84a6f840ddce151424 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Sun, 9 Jul 2023 20:21:49 -0500 Subject: [PATCH] Add more unit-tests for Lucene Byte Vector Signed-off-by: Naveen Tatikonda --- .../index/mapper/KNNVectorFieldMapper.java | 4 +- .../index/KNNVectorDVLeafFieldDataTests.java | 26 ++++--- .../knn/index/VectorDataTypeTests.java | 58 +++++++++++++++ .../mapper/KNNVectorFieldMapperTests.java | 11 +++ .../knn/index/query/KNNQueryBuilderTests.java | 23 +++++- .../knn/index/query/KNNQueryFactoryTests.java | 20 +++++ .../script/KNNScoringSpaceUtilTests.java | 23 ++++++ .../plugin/script/KNNScoringUtilTests.java | 73 +++++++++++++++++-- 8 files changed, 219 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 346d4c238..5d2b232d7 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.mapper; +import com.google.common.annotations.VisibleForTesting; import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.opensearch.common.ValidationException; @@ -110,7 +111,8 @@ public static class Builder extends ParametrizedFieldMapper.Builder { * data_type which defines the datatype of the vector values. This is an optional parameter and * this is right now only relevant for lucene engine. The default value is float. */ - private final Parameter vectorDataType = new Parameter<>( + @VisibleForTesting + protected final Parameter vectorDataType = new Parameter<>( VECTOR_DATA_TYPE_FIELD, false, () -> DEFAULT_VECTOR_DATA_TYPE_FIELD, diff --git a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java index cbe11dd6b..c96040310 100644 --- a/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNVectorDVLeafFieldDataTests.java @@ -61,15 +61,12 @@ public void tearDown() throws Exception { directory.close(); } - public void testGetScriptValues() { - KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( - leafReaderContext.reader(), - MOCK_INDEX_FIELD_NAME, - VectorDataType.FLOAT - ); - ScriptDocValues scriptValues = leafFieldData.getScriptValues(); - assertNotNull(scriptValues); - assertTrue(scriptValues instanceof KNNVectorScriptDocValues); + public void testGetScriptValuesFloatVectorDataType() { + validateGetScriptValuesWithVectorDataType(VectorDataType.FLOAT); + } + + public void testGetScriptValuesByteVectorDataType() { + validateGetScriptValuesWithVectorDataType(VectorDataType.BYTE); } public void testGetScriptValuesWrongFieldName() { @@ -87,6 +84,17 @@ public void testGetScriptValuesWrongFieldType() { expectThrows(IllegalStateException.class, () -> leafFieldData.getScriptValues()); } + private void validateGetScriptValuesWithVectorDataType(VectorDataType vectorDataType) { + KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData( + leafReaderContext.reader(), + MOCK_INDEX_FIELD_NAME, + vectorDataType + ); + ScriptDocValues scriptValues = leafFieldData.getScriptValues(); + assertNotNull(scriptValues); + assertTrue(scriptValues instanceof KNNVectorScriptDocValues); + } + public void testRamBytesUsed() { KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "", VectorDataType.FLOAT); assertEquals(0, leafFieldData.ramBytesUsed()); diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java index 4423c85d8..c6a9e2bea 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java @@ -17,8 +17,12 @@ import org.apache.lucene.tests.analysis.MockAnalyzer; import org.junit.Assert; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil; import java.io.IOException; +import java.util.Locale; + +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; public class VectorDataTypeTests extends KNNTestCase { @@ -51,6 +55,60 @@ public void testGetDocValuesWithByteVectorDataType() { directory.close(); } + public void testFloatVectorValueValidations() { + // Validate Float Vector Value which is NaN and throws exception + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> KNNVectorFieldMapperUtil.validateFloatVectorValue(Float.NaN) + ); + assertTrue(ex.getMessage().contains("KNN vector values cannot be NaN")); + + // Validate Float Vector Value which is infinite and throws exception + IllegalArgumentException ex1 = expectThrows( + IllegalArgumentException.class, + () -> KNNVectorFieldMapperUtil.validateFloatVectorValue(Float.POSITIVE_INFINITY) + ); + assertTrue(ex1.getMessage().contains("KNN vector values cannot be infinity")); + } + + public void testByteVectorValueValidations() { + // Validate Byte Vector Value which is float with decimal values and throws exception + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> KNNVectorFieldMapperUtil.validateByteVectorValue(10.54f) + ); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + ) + ); + + // Validate Byte Vector Value which is not in the byte range and throws exception + IllegalArgumentException ex1 = expectThrows( + IllegalArgumentException.class, + () -> KNNVectorFieldMapperUtil.validateByteVectorValue(200f) + ); + assertTrue( + ex1.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + Byte.MIN_VALUE, + Byte.MAX_VALUE + ) + ) + ); + } + @SneakyThrows private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() { directory = newDirectory(); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 1f3598781..1f27144f6 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -7,8 +7,10 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.BytesRef; @@ -218,6 +220,7 @@ public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOExcep .startObject() .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) .field(DIMENSION_FIELD_NAME, dimension) + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue()) .startObject(KNN_METHOD) .field(NAME, METHOD_HNSW) .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) @@ -237,6 +240,7 @@ public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOExcep builder.build(builderContext); assertEquals(METHOD_HNSW, builder.knnMethodContext.get().getMethodComponent().getName()); + assertEquals(VectorDataType.BYTE.getValue(), builder.vectorDataType.getValue().getValue()); assertEquals( efConstruction, builder.knnMethodContext.get().getMethodComponent().getParameters().get(METHOD_PARAMETER_EF_CONSTRUCTION) @@ -871,6 +875,13 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue()); } + public void testBuildDocValuesFieldType() { + FieldType fieldType = KNNVectorFieldMapperUtil.buildDocValuesFieldType(KNNEngine.LUCENE); + assertNotNull(fieldType); + assertEquals(KNNEngine.LUCENE.getName(), fieldType.getAttributes().get(KNN_ENGINE)); + assertEquals(DocValuesType.BINARY, fieldType.docValuesType()); + } + private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder( VectorDataType vectorDataType ) { diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 4e7f739a7..0ea5ce713 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import com.google.common.collect.ImmutableMap; +import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.opensearch.Version; @@ -146,7 +147,7 @@ protected NamedWriteableRegistry writableRegistry() { return new NamedWriteableRegistry(entries); } - public void testDoToQuery_Normal() throws Exception { + public void testDoToQuery_Normal() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); Index dummyIndex = new Index("dummy", "dummy"); @@ -162,6 +163,26 @@ public void testDoToQuery_Normal() throws Exception { assertEquals(knnQueryBuilder.vector(), query.getQueryVector()); } + public void testDoToQuery_Normal_ByteVectorDataType() { + // Validate doToQuery with Byte vector data type + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); + Index dummyIndex = new Index("dummy", "dummy"); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + KNNMethodContext mockKNNMethodContext = mock(KNNMethodContext.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn(mockKNNMethodContext); + when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BYTE); + when(mockKNNMethodContext.getKnnEngine()).thenReturn(KNNEngine.LUCENE); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + assertNotNull(query); + assertTrue(query.getClass().isAssignableFrom(KnnByteVectorQuery.class)); + } + public void testDoToQuery_KnnQueryWithFilter() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 4dccfd087..168b40567 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import org.apache.lucene.index.Term; +import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; @@ -15,6 +16,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.util.Arrays; @@ -73,6 +75,24 @@ public void testCreateLuceneDefaultQuery() { } } + public void testCreateLuceneQueryByteVectorDataType() { + byte[] byteQueryVector = { 1, 2, 3, 4 }; + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.LUCENE) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(null) + .byteVector(byteQueryVector) + .vectorDataType(VectorDataType.BYTE) + .k(testK) + .filter(null) + .context(mockQueryShardContext) + .build(); + Query query = KNNQueryFactory.create(createQueryRequest); + assertTrue(query.getClass().isAssignableFrom(KnnByteVectorQuery.class)); + } + public void testCreateLuceneQueryWithFilter() { List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index b5bc4b95f..82e259e99 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -75,4 +75,27 @@ public void testParseKNNVectorQuery() { String invalidObject = "invalidObject"; expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT)); } + + public void testParseKNNVectorQueryByteVectorDataType() { + float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; + List arrayListQueryObject = new ArrayList<>(Arrays.asList(1, 2, 3)); + KNNVectorFieldMapper.KNNVectorFieldType fieldType = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(fieldType.getDimension()).thenReturn(3); + // Query vector is a byte vector, so test should succeed + assertArrayEquals(arrayFloat, KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 3, VectorDataType.BYTE), 0.1f); + + // Query vector is a float vector for byte vector data type, so test should throw IllegalArgumentException + List arrayListQueryObject1 = new ArrayList<>(Arrays.asList(1.1, 2.56, 3.67)); + expectThrows( + IllegalArgumentException.class, + () -> KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject1, 3, VectorDataType.BYTE) + ); + + // Query vector is not within the byte range for byte vector data type, so test should throw IllegalArgumentException + List arrayListQueryObject2 = new ArrayList<>(Arrays.asList(1000, 2, 3)); + expectThrows( + IllegalArgumentException.class, + () -> KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject2, 3, VectorDataType.BYTE) + ); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java index 4a2bb7254..9107bff50 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringUtilTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.script; +import lombok.SneakyThrows; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNVectorScriptDocValues; import org.opensearch.knn.index.VectorDataType; @@ -23,6 +24,7 @@ import java.io.IOException; import java.math.BigInteger; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; public class KNNScoringUtilTests extends KNNTestCase { @@ -174,7 +176,7 @@ public void testL2SquaredAllowlistedScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); - KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name", VectorDataType.FLOAT); scriptDocValues.setNextDocId(0); Float distance = KNNScoringUtil.l2Squared(queryVector, scriptDocValues); assertEquals(27.0f, distance, 0.1f); @@ -185,16 +187,60 @@ public void testScriptDocValuesFailsL2() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); - KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name", VectorDataType.FLOAT); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.l2Squared(queryVector, scriptDocValues)); dataset.close(); } + public void testL2SquaredScoringFunctionByteVectorDataType() throws IOException { + List queryVector = getTestQueryVector(); + TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); + dataset.createKNNVectorDocument(new byte[] { 4, 4, 4 }, "test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name", VectorDataType.BYTE); + scriptDocValues.setNextDocId(0); + Float distance = KNNScoringUtil.l2Squared(queryVector, scriptDocValues); + assertEquals(27.0f, distance, 0.1f); + dataset.close(); + } + + @SneakyThrows + public void testL2SquaredScoringFunctionByteVectorDataTypeThrowsException() { + // Float query vector for byte vector data type which throws IllegalArgumentException + List queryVector = Arrays.asList(10.56, 54.0, 65.0); + TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); + dataset.createKNNVectorDocument(new byte[] { 4, 4, 4 }, "test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name", VectorDataType.BYTE); + scriptDocValues.setNextDocId(0); + expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.l2Squared(queryVector, scriptDocValues)); + dataset.close(); + + // Invalid byte query vector(outside of range [-128, 127]) for byte vector data type which throws IllegalArgumentException + List queryVector1 = Arrays.asList(1056, 540, 650); + TestKNNScriptDocValues dataset1 = new TestKNNScriptDocValues(); + dataset1.createKNNVectorDocument(new byte[] { 4, 4, 4 }, "test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues1 = dataset1.getScriptDocValues("test-index-field-name", VectorDataType.BYTE); + scriptDocValues1.setNextDocId(0); + expectThrows(IllegalArgumentException.class, () -> KNNScoringUtil.l2Squared(queryVector1, scriptDocValues1)); + dataset1.close(); + } + public void testCosineSimilarityScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); - KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name", VectorDataType.FLOAT); + scriptDocValues.setNextDocId(0); + + Float actualScore = KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues); + assertEquals(1.0f, actualScore, 0.0001); + dataset.close(); + } + + public void testCosineSimilarityScoringFunctionByteVectorDataType() throws IOException { + List queryVector = getTestQueryVector(); + TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); + dataset.createKNNVectorDocument(new byte[] { 4, 4, 4 }, "test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name", VectorDataType.BYTE); scriptDocValues.setNextDocId(0); Float actualScore = KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues); @@ -206,7 +252,7 @@ public void testScriptDocValuesFailsCosineSimilarity() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); - KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name", VectorDataType.FLOAT); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues)); dataset.close(); } @@ -215,7 +261,7 @@ public void testCosineSimilarityOptimizedScoringFunction() throws IOException { List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); - KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name", VectorDataType.FLOAT); scriptDocValues.setNextDocId(0); Float actualScore = KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f); assertEquals(1.0f, actualScore, 0.0001); @@ -226,7 +272,7 @@ public void testScriptDocValuesFailsCosineSimilarityOptimized() throws IOExcepti List queryVector = getTestQueryVector(); TestKNNScriptDocValues dataset = new TestKNNScriptDocValues(); dataset.createKNNVectorDocument(new float[] { 4.0f, 4.0f, 4.0f }, "test-index-field-name"); - KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name"); + KNNVectorScriptDocValues scriptDocValues = dataset.getScriptDocValues("test-index-field-name", VectorDataType.FLOAT); expectThrows(IllegalStateException.class, () -> KNNScoringUtil.cosineSimilarity(queryVector, scriptDocValues, 3.0f)); dataset.close(); } @@ -240,14 +286,14 @@ class TestKNNScriptDocValues { directory = newDirectory(); } - public KNNVectorScriptDocValues getScriptDocValues(String fieldName) throws IOException { + private KNNVectorScriptDocValues getScriptDocValues(String fieldName, VectorDataType vectorDataType) throws IOException { if (scriptDocValues == null) { reader = DirectoryReader.open(directory); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); scriptDocValues = new KNNVectorScriptDocValues( leafReaderContext.reader().getBinaryDocValues(fieldName), fieldName, - VectorDataType.FLOAT + vectorDataType ); } return scriptDocValues; @@ -268,5 +314,16 @@ public void createKNNVectorDocument(final float[] content, final String fieldNam writer.commit(); writer.close(); } + + public void createKNNVectorDocument(final byte[] content, final String fieldName) throws IOException { + IndexWriterConfig conf = newIndexWriterConfig(new MockAnalyzer(random())); + IndexWriter writer = new IndexWriter(directory, conf); + conf.setMergePolicy(NoMergePolicy.INSTANCE); // prevent merges for this test + Document knnDocument = new Document(); + knnDocument.add(new BinaryDocValuesField(fieldName, new VectorField(fieldName, content, new FieldType()).binaryValue())); + writer.addDocument(knnDocument); + writer.commit(); + writer.close(); + } } }