From 4227c96df32204bb6c1b827db2d5ce398b454afb Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 15 Jun 2023 16:14:38 -0500 Subject: [PATCH 1/4] Add Indexing Support for Lucene Byte Sized Vector Signed-off-by: Naveen Tatikonda --- .../opensearch/knn/common/KNNConstants.java | 5 + .../opensearch/knn/index/VectorDataType.java | 293 ++++++++++++++++++ .../org/opensearch/knn/index/VectorField.java | 15 + .../index/mapper/KNNVectorFieldMapper.java | 119 +++++-- .../knn/index/mapper/LuceneFieldMapper.java | 59 ++-- 5 files changed, 440 insertions(+), 51 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/VectorDataType.java diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 87d7a1c21..47ce0c957 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -5,6 +5,8 @@ package org.opensearch.knn.common; +import org.opensearch.knn.index.VectorDataType; + public class KNNConstants { // shared across library constants public static final String DIMENSION = "dimension"; @@ -50,6 +52,9 @@ public class KNNConstants { public static final String MAX_VECTOR_COUNT_PARAMETER = "max_training_vector_count"; public static final String SEARCH_SIZE_PARAMETER = "search_size"; + public static final String VECTOR_DATA_TYPE = "data_type"; + public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE = VectorDataType.FLOAT; + // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java new file mode 100644 index 000000000..70db606c0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -0,0 +1,293 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index; + 
+import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexableFieldType; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.opensearch.index.mapper.ParametrizedFieldMapper; +import org.opensearch.knn.index.util.KNNEngine; + +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; + +/** + * Enum contains data_type of vectors and right now only supported for lucene engine in k-NN plugin. + * We have two vector data_types, one is float (default) and the other one is byte. 
+ */ +public enum VectorDataType { + BYTE("byte") { + /** + * @param dimension Dimension of the vector + * @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType + * @return FieldType of type KnnByteVectorField + */ + @Override + public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { + return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction); + } + + /** + * @param knnEngine KNNEngine + * @return DocValues FieldType of type Binary and with BYTE VectorEncoding + */ + @Override + public FieldType buildDocValuesFieldType(KNNEngine knnEngine) { + IndexableFieldType indexableFieldType = new IndexableFieldType() { + @Override + public boolean stored() { + return false; + } + + @Override + public boolean tokenized() { + return true; + } + + @Override + public boolean storeTermVectors() { + return false; + } + + @Override + public boolean storeTermVectorOffsets() { + return false; + } + + @Override + public boolean storeTermVectorPositions() { + return false; + } + + @Override + public boolean storeTermVectorPayloads() { + return false; + } + + @Override + public boolean omitNorms() { + return false; + } + + @Override + public IndexOptions indexOptions() { + return IndexOptions.NONE; + } + + @Override + public DocValuesType docValuesType() { + return DocValuesType.NONE; + } + + @Override + public int pointDimensionCount() { + return 0; + } + + @Override + public int pointIndexDimensionCount() { + return 0; + } + + @Override + public int pointNumBytes() { + return 0; + } + + @Override + public int vectorDimension() { + return 0; + } + + @Override + public VectorEncoding vectorEncoding() { + return VectorEncoding.BYTE; + } + + @Override + public VectorSimilarityFunction vectorSimilarityFunction() { + return VectorSimilarityFunction.EUCLIDEAN; + } + + @Override + public Map getAttributes() { + return null; + } + }; + FieldType field = new FieldType(indexableFieldType); + 
field.putAttribute(KNN_ENGINE, knnEngine.getName()); + field.setDocValuesType(DocValuesType.BINARY); + field.freeze(); + return field; + } + }, + FLOAT("float") { + /** + * @param dimension Dimension of the vector + * @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType + * @return FieldType of type KnnFloatVectorField + */ + @Override + public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { + return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction); + } + + /** + * @param knnEngine KNNEngine + * @return DocValues FieldType of type Binary and with FLOAT32 VectorEncoding + */ + @Override + public FieldType buildDocValuesFieldType(KNNEngine knnEngine) { + FieldType field = new FieldType(); + field.putAttribute(KNN_ENGINE, knnEngine.getName()); + field.setDocValuesType(DocValuesType.BINARY); + field.freeze(); + return field; + } + + }; + + private final String value; + + VectorDataType(String value) { + this.value = value; + } + + /** + * Get VectorDataType name + * + * @return name + */ + public String getValue() { + return value; + } + + public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction); + + public abstract FieldType buildDocValuesFieldType(KNNEngine knnEngine); + + /** + * @return Set of names of all the supporting VectorDataTypes + */ + public static Set getValues() { + Set values = new HashSet<>(); + + for (VectorDataType dataType : VectorDataType.values()) { + values.add(dataType.getValue()); + } + return values; + } + + /** + * Validates if given VectorDataType is in the list of supported data types. + * @param vectorDataType VectorDataType + * @return the same VectorDataType if it is in the supported values else throw exception. 
+ */ + public static VectorDataType get(String vectorDataType) { + String supportedTypes = String.join(",", getValues()); + Objects.requireNonNull( + vectorDataType, + String.format("[%s] should not be null. Supported types are [%s]", VECTOR_DATA_TYPE, supportedTypes) + ); + for (VectorDataType currentDataType : VectorDataType.values()) { + if (currentDataType.getValue().equalsIgnoreCase(vectorDataType)) { + return currentDataType; + } + } + throw new IllegalArgumentException( + String.format( + "[%s] field was set as [%s] in index mapping. But, supported values are [%s]", + VECTOR_DATA_TYPE, + vectorDataType, + supportedTypes + ) + ); + } + + /** + * Validate that the float vector value is a number (not NaN) and is finite. + * + * @param value float vector value + */ + public static void validateFloatVectorValues(float value) { + if (Float.isNaN(value)) { + throw new IllegalArgumentException("KNN vector values cannot be NaN"); + } + + if (Float.isInfinite(value)) { + throw new IllegalArgumentException("KNN vector values cannot be infinity"); + } + } + + /** + * Validate the float vector value in the byte range if it is a finite number, + * with no decimal values and in the byte range of [-128 to 127]. + * + * @param value float value in byte range + */ + public static void validateByteVectorValues(float value) { + validateFloatVectorValues(value); + if (value % 1 != 0) { + throw new IllegalArgumentException( + "[data_type] field was set as [byte] in index mapping. But, KNN vector values are floats instead of byte integers" + ); + } + if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) { + throw new IllegalArgumentException( + String.format( + "[%s] field was set as [%s] in index mapping. 
But, KNN vector values are not within in the byte range [{}, {}]", + VECTOR_DATA_TYPE, + VectorDataType.BYTE.getValue(), + Byte.MIN_VALUE, + Byte.MAX_VALUE + ) + ); + } + } + + /** + * Validate if the given vector size matches with the dimension provided in mapping. + * + * @param dimension dimension of vector + * @param vectorSize size of the vector + */ + public static void validateVectorDimension(int dimension, int vectorSize) { + if (dimension != vectorSize) { + String errorMessage = String.format("Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize); + throw new IllegalArgumentException(errorMessage); + } + + } + + /** + * Validates and throws exception if data_type field is set in the index mapping + * using any VectorDataType (other than float, which is default) with any engine (except lucene). + * + * @param knnMethodContext KNNMethodContext Parameter + * @param vectorDataType VectorDataType Parameter + */ + public static void validateVectorDataType_Engine( + ParametrizedFieldMapper.Parameter knnMethodContext, + ParametrizedFieldMapper.Parameter vectorDataType + ) { + if (vectorDataType.getValue() != DEFAULT_VECTOR_DATA_TYPE + && (knnMethodContext.get() == null || knnMethodContext.getValue().getKnnEngine() != KNNEngine.LUCENE)) { + throw new IllegalArgumentException(String.format("[%s] is only supported for [%s] engine", VECTOR_DATA_TYPE, LUCENE_NAME)); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/VectorField.java b/src/main/java/org/opensearch/knn/index/VectorField.java index c88630f6c..2c346992d 100644 --- a/src/main/java/org/opensearch/knn/index/VectorField.java +++ b/src/main/java/org/opensearch/knn/index/VectorField.java @@ -23,4 +23,19 @@ public VectorField(String name, float[] value, IndexableFieldType type) { throw new RuntimeException(e); } } + + /** + * @param name FieldType name + * @param value an array of byte vector values + * @param type FieldType to build DocValues + */ + public VectorField(String 
name, byte[] value, IndexableFieldType type) { + super(name, new BytesRef(), type); + try { + this.setBytesValue(value); + } catch (Exception e) { + throw new IllegalArgumentException(e); + } + + } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index ab45c384f..27416a7fb 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -35,6 +35,7 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.KNNVectorIndexFieldData; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; @@ -49,7 +50,13 @@ import java.util.Optional; import java.util.function.Supplier; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; +import static org.opensearch.knn.index.VectorDataType.validateByteVectorValues; +import static org.opensearch.knn.index.VectorDataType.validateFloatVectorValues; +import static org.opensearch.knn.index.VectorDataType.validateVectorDataType_Engine; +import static org.opensearch.knn.index.VectorDataType.validateVectorDimension; /** * Field Mapper for KNN vector type. @@ -96,6 +103,18 @@ public static class Builder extends ParametrizedFieldMapper.Builder { return value; }, m -> toType(m).dimension); + /** + * data_type which defines the datatype of the vector values. This is an optional parameter and + * this is right now only relevant for lucene engine. The default value is float. 
+ */ + protected final Parameter vectorDataType = new Parameter<>( + VECTOR_DATA_TYPE, + false, + () -> DEFAULT_VECTOR_DATA_TYPE, + (n, c, o) -> VectorDataType.get((String) o), + m -> toType(m).vectorDataType + ); + /** * modelId provides a way for a user to generate the underlying library indices from an already serialized * model template index. If this parameter is set, it will take precedence. This parameter is only relevant for @@ -168,7 +187,7 @@ public Builder(String name, String spaceType, String m, String efConstruction) { @Override protected List> getParameters() { - return Arrays.asList(stored, hasDocValues, dimension, meta, knnMethodContext, modelId); + return Arrays.asList(stored, hasDocValues, dimension, vectorDataType, meta, knnMethodContext, modelId); } protected Explicit ignoreMalformed(BuilderContext context) { @@ -203,7 +222,8 @@ public KNNVectorFieldMapper build(BuilderContext context) { buildFullName(context), metaValue, dimension.getValue(), - knnMethodContext + knnMethodContext, + vectorDataType.getValue() ); if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE) { log.debug(String.format("Use [LuceneFieldMapper] mapper for field [%s]", name)); @@ -216,6 +236,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { .ignoreMalformed(ignoreMalformed) .stored(stored.get()) .hasDocValues(hasDocValues.get()) + .vectorDataType(vectorDataType.getValue()) .knnMethodContext(knnMethodContext) .build(); return new LuceneFieldMapper(createLuceneFieldMapperInput); @@ -327,6 +348,10 @@ public Mapper.Builder parse(String name, Map node, ParserCont throw new IllegalArgumentException(String.format("Dimension value missing for vector: %s", name)); } + // Validates and throws exception if data_type field is set in the index mapping + // using any VectorDataType (other than float, which is default) with any engine (except lucene). 
+ validateVectorDataType_Engine(builder.knnMethodContext, builder.vectorDataType); + return builder; } } @@ -336,20 +361,43 @@ public static class KNNVectorFieldType extends MappedFieldType { int dimension; String modelId; KNNMethodContext knnMethodContext; + VectorDataType vectorDataType; public KNNVectorFieldType(String name, Map meta, int dimension) { - this(name, meta, dimension, null, null); + this(name, meta, dimension, null, null, DEFAULT_VECTOR_DATA_TYPE); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { - this(name, meta, dimension, knnMethodContext, null); + this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { + this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE); + } + + public KNNVectorFieldType( + String name, + Map meta, + int dimension, + KNNMethodContext knnMethodContext, + VectorDataType vectorDataType + ) { + this(name, meta, dimension, knnMethodContext, null, vectorDataType); + } + + public KNNVectorFieldType( + String name, + Map meta, + int dimension, + KNNMethodContext knnMethodContext, + String modelId, + VectorDataType vectorDataType + ) { super(name, false, false, true, TextSearchInfo.NONE, meta); this.dimension = dimension; this.modelId = modelId; this.knnMethodContext = knnMethodContext; + this.vectorDataType = vectorDataType; } @Override @@ -386,6 +434,7 @@ public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, S protected boolean stored; protected boolean hasDocValues; protected Integer dimension; + protected VectorDataType vectorDataType; protected ModelDao modelDao; // These members map to parameters in the builder. 
They need to be declared in the abstract class due to the @@ -408,6 +457,7 @@ public KNNVectorFieldMapper( this.stored = stored; this.hasDocValues = hasDocValues; this.dimension = mappedFieldType.getDimension(); + this.vectorDataType = mappedFieldType.getVectorDataType(); updateEngineStats(); } @@ -459,50 +509,65 @@ void validateIfKNNPluginEnabled() { } } - Optional getFloatsFromContext(ParseContext context, int dimension) throws IOException { + // Returns an optional array of byte values where each value in the vector is parsed as a float and validated + // if it is a finite number without any decimals and within the byte range of [-128 to 127]. + Optional getBytesFromContext(ParseContext context, int dimension) throws IOException { context.path().add(simpleName()); - ArrayList vector = new ArrayList<>(); + ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); float value; + if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { value = context.parser().floatValue(); - - if (Float.isNaN(value)) { - throw new IllegalArgumentException("KNN vector values cannot be NaN"); - } - - if (Float.isInfinite(value)) { - throw new IllegalArgumentException("KNN vector values cannot be infinity"); - } - - vector.add(value); + validateByteVectorValues(value); + vector.add((byte) value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { value = context.parser().floatValue(); + validateByteVectorValues(value); + vector.add((byte) value); + context.parser().nextToken(); + } else if (token == XContentParser.Token.VALUE_NULL) { + context.path().remove(); + return Optional.empty(); + } + validateVectorDimension(dimension, vector.size()); + byte[] array = new byte[vector.size()]; + int i = 0; + for (Byte f : vector) { + array[i++] = f; + } + return Optional.of(array); + } - if (Float.isNaN(value)) { - throw new 
IllegalArgumentException("KNN vector values cannot be NaN"); - } + Optional getFloatsFromContext(ParseContext context, int dimension) throws IOException { + context.path().add(simpleName()); - if (Float.isInfinite(value)) { - throw new IllegalArgumentException("KNN vector values cannot be infinity"); + ArrayList vector = new ArrayList<>(); + XContentParser.Token token = context.parser().currentToken(); + float value; + if (token == XContentParser.Token.START_ARRAY) { + token = context.parser().nextToken(); + while (token != XContentParser.Token.END_ARRAY) { + value = context.parser().floatValue(); + validateFloatVectorValues(value); + vector.add(value); + token = context.parser().nextToken(); } - + } else if (token == XContentParser.Token.VALUE_NUMBER) { + value = context.parser().floatValue(); + validateFloatVectorValues(value); vector.add(value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { context.path().remove(); return Optional.empty(); } - - if (dimension != vector.size()) { - String errorMessage = String.format("Vector dimension mismatch. 
Expected: %d, Given: %d", dimension, vector.size()); - throw new IllegalArgumentException(errorMessage); - } + validateVectorDimension(dimension, vector.size()); float[] array = new float[vector.size()]; int i = 0; diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 5dcb09318..b571a9d2f 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -9,13 +9,14 @@ import lombok.Getter; import lombok.NonNull; import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.document.StoredField; -import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.util.KNNEngine; @@ -23,7 +24,6 @@ import java.util.Optional; import static org.apache.lucene.index.VectorValues.MAX_DIMENSIONS; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; /** * Field mapper for case when Lucene has been set as an engine. @@ -34,6 +34,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { /** FieldType used for initializing VectorField, which is used for creating binary doc values. 
**/ private final FieldType vectorFieldType; + private final VectorDataType vectorDataType; LuceneFieldMapper(final CreateLuceneFieldMapperInput input) { super( @@ -46,6 +47,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { input.isHasDocValues() ); + vectorDataType = input.getVectorDataType(); this.knnMethod = input.getKnnMethodContext(); final VectorSimilarityFunction vectorSimilarityFunction = this.knnMethod.getSpaceType().getVectorSimilarityFunction(); @@ -61,45 +63,53 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { ); } - this.fieldType = KnnVectorField.createFieldType(dimension, vectorSimilarityFunction); + this.fieldType = vectorDataType.createKnnVectorFieldType(dimension, vectorSimilarityFunction); if (this.hasDocValues) { - this.vectorFieldType = buildDocValuesFieldType(this.knnMethod.getKnnEngine()); + this.vectorFieldType = vectorDataType.buildDocValuesFieldType(this.knnMethod.getKnnEngine()); } else { this.vectorFieldType = null; } } - private static FieldType buildDocValuesFieldType(KNNEngine knnEngine) { - FieldType field = new FieldType(); - field.putAttribute(KNN_ENGINE, knnEngine.getName()); - field.setDocValuesType(DocValuesType.BINARY); - field.freeze(); - return field; - } - @Override protected void parseCreateField(ParseContext context, int dimension) throws IOException { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - Optional arrayOptional = getFloatsFromContext(context, dimension); + if (vectorDataType.equals(VectorDataType.BYTE)) { + Optional arrayOptional = getBytesFromContext(context, dimension); + if (arrayOptional.isEmpty()) { + return; + } + final byte[] array = arrayOptional.get(); + KnnByteVectorField point = new KnnByteVectorField(name(), array, fieldType); + context.doc().add(point); + if (fieldType.stored()) { + context.doc().add(new StoredField(name(), point.toString())); + } + if (hasDocValues && vectorFieldType != null) { + context.doc().add(new VectorField(name(), 
array, vectorFieldType)); + } + } else { + Optional arrayOptional = getFloatsFromContext(context, dimension); - if (arrayOptional.isEmpty()) { - return; - } - final float[] array = arrayOptional.get(); + if (arrayOptional.isEmpty()) { + return; + } + final float[] array = arrayOptional.get(); - KnnVectorField point = new KnnVectorField(name(), array, fieldType); + KnnVectorField point = new KnnVectorField(name(), array, fieldType); - context.doc().add(point); - if (fieldType.stored()) { - context.doc().add(new StoredField(name(), point.toString())); - } + context.doc().add(point); + if (fieldType.stored()) { + context.doc().add(new StoredField(name(), point.toString())); + } - if (hasDocValues && vectorFieldType != null) { - context.doc().add(new VectorField(name(), array, vectorFieldType)); + if (hasDocValues && vectorFieldType != null) { + context.doc().add(new VectorField(name(), array, vectorFieldType)); + } } context.path().remove(); @@ -126,6 +136,7 @@ static class CreateLuceneFieldMapperInput { Explicit ignoreMalformed; boolean stored; boolean hasDocValues; + VectorDataType vectorDataType; @NonNull KNNMethodContext knnMethodContext; } From ab7a3b84a8c2177afad29e959fc08e0675b65f57 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 15 Jun 2023 16:15:15 -0500 Subject: [PATCH 2/4] Add tests for Indexing Signed-off-by: Naveen Tatikonda --- .../knn/index/VectorDataTypeIT.java | 188 ++++++++++++++++++ .../mapper/KNNVectorFieldMapperTests.java | 184 ++++++++++++++--- 2 files changed, 349 insertions(+), 23 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java new file mode 100644 index 000000000..865402225 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -0,0 +1,188 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + 
*/ + +package org.opensearch.knn.index; + +import org.junit.After; +import org.opensearch.client.ResponseException; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.util.KNNEngine; + +import java.io.IOException; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; +import static org.opensearch.knn.index.VectorDataType.getValues; + +public class VectorDataTypeIT extends KNNRestTestCase { + private static final String INDEX_NAME = "test-index-vec-dt"; + private static final String FIELD_NAME = "test-field-vec-dt"; + private static final String PROPERTIES_FIELD = "properties"; + private static final String DOC_ID = "doc1"; + private static final String TYPE_FIELD_NAME = "type"; + private static final String KNN_VECTOR_TYPE = "knn_vector"; + private static final int EF_CONSTRUCTION = 128; + private static final int M = 16; + + @After + public final void cleanUp() throws IOException { + deleteKNNIndex(INDEX_NAME); + } + + // Validate if we are able to create an index by setting data_type field as byte and add a doc to it + public void testAddDocWithByteVector() throws Exception { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Byte[] vector = { 6, 6 }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + refreshAllIndices(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + // Validate by creating an index by setting data_type field as byte, add a doc to it and update it later. 
+ public void testUpdateDocWithByteVector() throws Exception { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Byte[] vector = { -36, 78 }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + Byte[] updatedVector = { 89, -8 }; + updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector); + + refreshAllIndices(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + // Validate by creating an index by setting data_type field as byte, add a doc to it and delete it later. + public void testDeleteDocWithByteVector() throws Exception { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Byte[] vector = { 35, -46 }; + addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); + + deleteKnnDoc(INDEX_NAME, DOC_ID); + refreshAllIndices(); + + assertEquals(0, getDocCount(INDEX_NAME)); + } + + // Set an invalid value for data_type field while creating the index which should throw an exception + public void testInvalidVectorDataType() { + String vectorDataType = "invalidVectorType"; + String supportedTypes = String.join(",", getValues()); + ResponseException ex = expectThrows( + ResponseException.class, + () -> createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, vectorDataType) + ); + assertTrue( + ex.getMessage() + .contains( + String.format( + "[%s] field was set as [%s] in index mapping. 
But, supported values are [%s]", + VECTOR_DATA_TYPE, + vectorDataType, + supportedTypes + ) + ) + ); + } + + // Set null value for data_type field while creating the index which should throw an exception + public void testVectorDataTypeAsNull() { + ResponseException ex = expectThrows(ResponseException.class, () -> createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, null)); + assertTrue( + ex.getMessage() + .contains( + String.format( + "[%s] on mapper [%s] of type [%s] must not have a [null] value", + VECTOR_DATA_TYPE, + FIELD_NAME, + KNN_VECTOR_TYPE + ) + ) + ); + } + + // Create an index with byte vector data_type and add a doc with decimal values which should throw exception + public void testInvalidVectorData() throws Exception { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Float[] vector = { -10.76f, 15.89f }; + + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector)); + assertTrue( + ex.getMessage() + .contains( + "[data_type] field was set as [byte] in index mapping. But, KNN vector values are floats instead of byte integers" + ) + ); + } + + // Create an index with byte vector data_type and add a doc with values out of byte range which should throw exception + public void testInvalidByteVectorRange() throws Exception { + createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); + Float[] vector = { -1000f, 155f }; + + ResponseException ex = expectThrows(ResponseException.class, () -> addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector)); + assertTrue( + ex.getMessage() + .contains( + String.format( + "[%s] field was set as [%s] in index mapping. 
But, KNN vector values are not within in the byte range [{}, {}]", + VECTOR_DATA_TYPE, + VectorDataType.BYTE.getValue(), + Byte.MIN_VALUE, + Byte.MAX_VALUE + ) + ) + ); + } + + // Create an index with byte vector data_type using nmslib engine which should throw an exception + public void testByteVectorDataTypeWithNmslibEngine() { + ResponseException ex = expectThrows( + ResponseException.class, + () -> createKnnIndexMappingWithNmslibEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()) + ); + assertTrue(ex.getMessage().contains(String.format("[%s] is only supported for [%s] engine", VECTOR_DATA_TYPE, LUCENE_NAME))); + } + + private void createKnnIndexMappingWithNmslibEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception { + createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.NMSLIB.getName()); + } + + private void createKnnIndexMappingWithLuceneEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception { + createKnnIndexMappingWithCustomEngine(dimension, spaceType, vectorDataType, KNNEngine.LUCENE.getName()); + } + + private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spaceType, String vectorDataType, String engine) + throws Exception { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(FIELD_NAME) + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, dimension) + .field(VECTOR_DATA_TYPE, vectorDataType) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, METHOD_HNSW) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, engine) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, M) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, EF_CONSTRUCTION) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + String mapping = Strings.toString(builder); + 
createKnnIndex(INDEX_NAME, mapping); + } +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index d4a5b5aea..d578729bb 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -6,8 +6,10 @@ package org.opensearch.knn.index.mapper; import com.google.common.collect.ImmutableMap; +import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.BytesRef; import org.mockito.Mockito; import org.opensearch.common.Explicit; @@ -28,6 +30,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.util.KNNEngine; @@ -46,6 +49,7 @@ import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; @@ -60,6 +64,8 @@ import static org.opensearch.Version.CURRENT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; +import static org.opensearch.knn.index.VectorDataType.getValues; public class KNNVectorFieldMapperTests extends KNNTestCase { @@ -71,9 +77,13 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { private 
final static float[] TEST_VECTOR = createInitializedFloatArray(TEST_DIMENSION, TEST_VECTOR_VALUE); + private final static byte TEST_BYTE_VECTOR_VALUE = 10; + private final static byte[] TEST_BYTE_VECTOR = createInitializedByteArray(TEST_DIMENSION, TEST_BYTE_VECTOR_VALUE); + private final static BytesRef TEST_VECTOR_BYTES_REF = new BytesRef( KNNVectorSerializerFactory.getDefaultSerializer().floatToByteArray(TEST_VECTOR) ); + private final static BytesRef TEST_BYTE_VECTOR_BYTES_REF = new BytesRef(TEST_BYTE_VECTOR); private static final String DIMENSION_FIELD_NAME = "dimension"; private static final String KNN_VECTOR_TYPE = "knn_vector"; private static final String TYPE_FIELD_NAME = "type"; @@ -82,7 +92,8 @@ public void testBuilder_getParameters() { String fieldName = "test-field-name"; ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao); - assertEquals(6, builder.getParameters().size()); + + assertEquals(7, builder.getParameters().size()); } public void testBuilder_build_fromKnnMethodContext() { @@ -334,6 +345,52 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidDimension() throws ); } + // Validate TypeParser parsing invalid vector data_type which throws exception + public void testTypeParser_parse_invalidVectorDataType() throws IOException { + String fieldName = "test-field-name-vec"; + String indexName = "test-index-name-vec"; + String vectorDataType = "invalid"; + String supportedTypes = String.join(",", getValues()); + + Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + + ModelDao modelDao = mock(ModelDao.class); + KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); + + XContentBuilder xContentBuilderOverInvalidVectorType = XContentFactory.jsonBuilder() + .startObject() + .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) + .field(DIMENSION, 10) + .field(VECTOR_DATA_TYPE, vectorDataType) + 
.startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) + .field(KNN_ENGINE, LUCENE_NAME) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, 128) + .endObject() + .endObject() + .endObject(); + + IllegalArgumentException ex = expectThrows( + IllegalArgumentException.class, + () -> typeParser.parse( + fieldName, + xContentBuilderToMap(xContentBuilderOverInvalidVectorType), + buildParserContext(indexName, settings) + ) + ); + assertEquals( + String.format( + "[%s] field was set as [%s] in index mapping. But, supported values are [%s]", + VECTOR_DATA_TYPE, + vectorDataType, + supportedTypes + ), + ex.getMessage() + ); + } + public void testTypeParser_parse_fromKnnMethodContext_invalidSpaceType() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; @@ -673,30 +730,10 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { expectThrows(IllegalArgumentException.class, () -> knnVectorFieldMapper1.merge(knnVectorFieldMapper3)); } - public void testLuceneFieldMapper_parseCreateField_docValues() throws IOException { + public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() throws IOException { // Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField - KNNMethodContext knnMethodContext = new KNNMethodContext( - KNNEngine.LUCENE, - SpaceType.DEFAULT, - new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) - ); - - KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( - TEST_FIELD_NAME, - Collections.emptyMap(), - TEST_DIMENSION, - knnMethodContext - ); - LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder inputBuilder = - LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() - .name(TEST_FIELD_NAME) - .mappedFieldType(knnVectorFieldType) - .multiFields(FieldMapper.MultiFields.empty()) - 
.copyTo(FieldMapper.CopyTo.empty()) - .hasDocValues(true) - .ignoreMalformed(new Explicit<>(true, true)) - .knnMethodContext(knnMethodContext); + createLuceneFieldMapperInputBuilder(VectorDataType.FLOAT); ParseContext.Document document = new ParseContext.Document(); ContentPath contentPath = new ContentPath(); @@ -731,6 +768,7 @@ public void testLuceneFieldMapper_parseCreateField_docValues() throws IOExceptio } assertEquals(TEST_VECTOR_BYTES_REF, vectorField.binaryValue()); + assertEquals(VectorEncoding.FLOAT32, vectorField.fieldType().vectorEncoding()); assertArrayEquals(TEST_VECTOR, knnVectorField.vectorValue(), 0.001f); // Test when doc values are disabled @@ -757,12 +795,112 @@ public void testLuceneFieldMapper_parseCreateField_docValues() throws IOExceptio assertArrayEquals(TEST_VECTOR, knnVectorField.vectorValue(), 0.001f); } + public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() throws IOException { + // Create a lucene field mapper that creates a binary doc values field as well as KnnByteVectorField + + LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder inputBuilder = + createLuceneFieldMapperInputBuilder(VectorDataType.BYTE); + + ParseContext.Document document = new ParseContext.Document(); + ContentPath contentPath = new ContentPath(); + ParseContext parseContext = mock(ParseContext.class); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + + LuceneFieldMapper luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION); + doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); + doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION); + + // Document should have 2 fields: one for VectorField (binary doc values) and one 
for KnnByteVectorField + List fields = document.getFields(); + assertEquals(2, fields.size()); + IndexableField field1 = fields.get(0); + IndexableField field2 = fields.get(1); + + VectorField vectorField; + KnnByteVectorField knnByteVectorField; + if (field1 instanceof VectorField) { + assertTrue(field2 instanceof KnnByteVectorField); + vectorField = (VectorField) field1; + knnByteVectorField = (KnnByteVectorField) field2; + } else { + assertTrue(field1 instanceof KnnByteVectorField); + assertTrue(field2 instanceof VectorField); + knnByteVectorField = (KnnByteVectorField) field1; + vectorField = (VectorField) field2; + } + + assertEquals(TEST_BYTE_VECTOR_BYTES_REF, vectorField.binaryValue()); + assertEquals(VectorEncoding.BYTE, vectorField.fieldType().vectorEncoding()); + assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue()); + + // Test when doc values are disabled + document = new ParseContext.Document(); + contentPath = new ContentPath(); + parseContext = mock(ParseContext.class); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + + inputBuilder.hasDocValues(false); + luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION); + doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); + doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION); + + // Document should have 1 field: one for KnnByteVectorField + fields = document.getFields(); + assertEquals(1, fields.size()); + IndexableField field = fields.get(0); + assertTrue(field instanceof KnnByteVectorField); + knnByteVectorField = (KnnByteVectorField) field; + assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue()); + } + + private 
LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder( + VectorDataType vectorDataType + ) { + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.DEFAULT, + new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) + ); + + KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldMapper.KNNVectorFieldType( + TEST_FIELD_NAME, + Collections.emptyMap(), + TEST_DIMENSION, + knnMethodContext, + vectorDataType + ); + + return LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() + .name(TEST_FIELD_NAME) + .mappedFieldType(knnVectorFieldType) + .multiFields(FieldMapper.MultiFields.empty()) + .copyTo(FieldMapper.CopyTo.empty()) + .hasDocValues(true) + .vectorDataType(vectorDataType) + .ignoreMalformed(new Explicit<>(true, true)) + .knnMethodContext(knnMethodContext); + } + public static float[] createInitializedFloatArray(int dimension, float value) { float[] array = new float[dimension]; Arrays.fill(array, value); return array; } + public static byte[] createInitializedByteArray(int dimension, byte value) { + byte[] array = new byte[dimension]; + Arrays.fill(array, value); + return array; + } + public IndexMetadata buildIndexMetaData(String indexName, Settings settings) { return IndexMetadata.builder(indexName) .settings(settings) From a5889735b9a987267fcbb8f516c04b715f97427b Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 15 Jun 2023 16:33:08 -0500 Subject: [PATCH 3/4] Add CHANGELOG Signed-off-by: Naveen Tatikonda --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e57e634be..81a407b95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.8...2.x) ### Features * Added efficient filtering support for Faiss Engine 
([#936](https://github.com/opensearch-project/k-NN/pull/936)) +* Add Indexing Support for Lucene Byte Sized Vector ([#937](https://github.com/opensearch-project/k-NN/pull/937)) ### Enhancements ### Bug Fixes ### Infrastructure From 5c73aa2fb271181f0706257b0bfec08cc27217fe Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Fri, 30 Jun 2023 11:43:26 -0500 Subject: [PATCH 4/4] Address Review Comments Signed-off-by: Naveen Tatikonda --- .../opensearch/knn/common/KNNConstants.java | 4 +- .../opensearch/knn/index/VectorDataType.java | 266 ++---------------- .../org/opensearch/knn/index/VectorField.java | 2 +- .../index/mapper/KNNVectorFieldMapper.java | 40 ++- .../mapper/KNNVectorFieldMapperUtil.java | 143 ++++++++++ .../knn/index/mapper/LuceneFieldMapper.java | 38 +-- .../knn/index/VectorDataTypeIT.java | 64 +++-- .../mapper/KNNVectorFieldMapperTests.java | 36 ++- 8 files changed, 285 insertions(+), 308 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 47ce0c957..6d387eec4 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -52,8 +52,8 @@ public class KNNConstants { public static final String MAX_VECTOR_COUNT_PARAMETER = "max_training_vector_count"; public static final String SEARCH_SIZE_PARAMETER = "search_size"; - public static final String VECTOR_DATA_TYPE = "data_type"; - public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE = VectorDataType.FLOAT; + public static final String VECTOR_DATA_TYPE_FIELD = "data_type"; + public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; // Lucene specific constants public static final String LUCENE_NAME = "lucene"; diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java 
b/src/main/java/org/opensearch/knn/index/VectorDataType.java index 70db606c0..be4d7110c 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -5,289 +5,85 @@ package org.opensearch.knn.index; +import lombok.AllArgsConstructor; +import lombok.Getter; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; -import org.apache.lucene.index.DocValuesType; -import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.IndexableFieldType; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.opensearch.index.mapper.ParametrizedFieldMapper; -import org.opensearch.knn.index.util.KNNEngine; -import java.util.HashSet; -import java.util.Map; +import java.util.Arrays; +import java.util.Locale; import java.util.Objects; -import java.util.Set; +import java.util.stream.Collectors; -import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; /** * Enum contains data_type of vectors and right now only supported for lucene engine in k-NN plugin. * We have two vector data_types, one is float (default) and the other one is byte. 
*/ +@AllArgsConstructor public enum VectorDataType { BYTE("byte") { - /** - * @param dimension Dimension of the vector - * @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType - * @return FieldType of type KnnByteVectorField - */ + @Override public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { return KnnByteVectorField.createFieldType(dimension, vectorSimilarityFunction); } - - /** - * @param knnEngine KNNEngine - * @return DocValues FieldType of type Binary and with BYTE VectorEncoding - */ - @Override - public FieldType buildDocValuesFieldType(KNNEngine knnEngine) { - IndexableFieldType indexableFieldType = new IndexableFieldType() { - @Override - public boolean stored() { - return false; - } - - @Override - public boolean tokenized() { - return true; - } - - @Override - public boolean storeTermVectors() { - return false; - } - - @Override - public boolean storeTermVectorOffsets() { - return false; - } - - @Override - public boolean storeTermVectorPositions() { - return false; - } - - @Override - public boolean storeTermVectorPayloads() { - return false; - } - - @Override - public boolean omitNorms() { - return false; - } - - @Override - public IndexOptions indexOptions() { - return IndexOptions.NONE; - } - - @Override - public DocValuesType docValuesType() { - return DocValuesType.NONE; - } - - @Override - public int pointDimensionCount() { - return 0; - } - - @Override - public int pointIndexDimensionCount() { - return 0; - } - - @Override - public int pointNumBytes() { - return 0; - } - - @Override - public int vectorDimension() { - return 0; - } - - @Override - public VectorEncoding vectorEncoding() { - return VectorEncoding.BYTE; - } - - @Override - public VectorSimilarityFunction vectorSimilarityFunction() { - return VectorSimilarityFunction.EUCLIDEAN; - } - - @Override - public Map getAttributes() { - return null; - } - }; - FieldType field = new 
FieldType(indexableFieldType); - field.putAttribute(KNN_ENGINE, knnEngine.getName()); - field.setDocValuesType(DocValuesType.BINARY); - field.freeze(); - return field; - } }, FLOAT("float") { - /** - * @param dimension Dimension of the vector - * @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType - * @return FieldType of type KnnFloatVectorField - */ + @Override public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { return KnnVectorField.createFieldType(dimension, vectorSimilarityFunction); } - /** - * @param knnEngine KNNEngine - * @return DocValues FieldType of type Binary and with FLOAT32 VectorEncoding - */ - @Override - public FieldType buildDocValuesFieldType(KNNEngine knnEngine) { - FieldType field = new FieldType(); - field.putAttribute(KNN_ENGINE, knnEngine.getName()); - field.setDocValuesType(DocValuesType.BINARY); - field.freeze(); - return field; - } - }; + public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values()) + .map(VectorDataType::getValue) + .collect(Collectors.joining(",")); + @Getter private final String value; - VectorDataType(String value) { - this.value = value; - } - /** - * Get VectorDataType name + * Creates a KnnVectorFieldType based on the VectorDataType using the provided dimension and + * VectorSimilarityFunction. 
* - * @return name + * @param dimension Dimension of the vector + * @param vectorSimilarityFunction VectorSimilarityFunction for a given spaceType + * @return FieldType */ - public String getValue() { - return value; - } - public abstract FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction); - public abstract FieldType buildDocValuesFieldType(KNNEngine knnEngine); - - /** - * @return Set of names of all the supporting VectorDataTypes - */ - public static Set getValues() { - Set values = new HashSet<>(); - - for (VectorDataType dataType : VectorDataType.values()) { - values.add(dataType.getValue()); - } - return values; - } - /** * Validates if given VectorDataType is in the list of supported data types. * @param vectorDataType VectorDataType - * @return the same VectorDataType if it is in the supported values else throw exception. + * @return the same VectorDataType if it is in the supported values + * throws Exception if an invalid value is provided. */ public static VectorDataType get(String vectorDataType) { - String supportedTypes = String.join(",", getValues()); Objects.requireNonNull( vectorDataType, - String.format("[{}] should not be null. Supported types are [{}]", VECTOR_DATA_TYPE, supportedTypes) - ); - for (VectorDataType currentDataType : VectorDataType.values()) { - if (currentDataType.getValue().equalsIgnoreCase(vectorDataType)) { - return currentDataType; - } - } - throw new IllegalArgumentException( String.format( - "[%s] field was set as [%s] in index mapping. But, supported values are [%s]", - VECTOR_DATA_TYPE, - vectorDataType, - supportedTypes + Locale.ROOT, + "[%s] should not be null. Supported types are [%s]", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES ) ); - } - - /** - * Validate the float vector values if it is a number and in the finite range. 
- * - * @param value float vector value - */ - public static void validateFloatVectorValues(float value) { - if (Float.isNaN(value)) { - throw new IllegalArgumentException("KNN vector values cannot be NaN"); - } - - if (Float.isInfinite(value)) { - throw new IllegalArgumentException("KNN vector values cannot be infinity"); - } - } - - /** - * Validate the float vector value in the byte range if it is a finite number, - * with no decimal values and in the byte range of [-128 to 127]. - * - * @param value float value in byte range - */ - public static void validateByteVectorValues(float value) { - validateFloatVectorValues(value); - if (value % 1 != 0) { - throw new IllegalArgumentException( - "[data_type] field was set as [byte] in index mapping. But, KNN vector values are floats instead of byte integers" - ); - } - if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) { + try { + return VectorDataType.valueOf(vectorDataType.toUpperCase(Locale.ROOT)); + } catch (Exception e) { throw new IllegalArgumentException( String.format( - "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [{}, {}]", - VECTOR_DATA_TYPE, - VectorDataType.BYTE.getValue(), - Byte.MIN_VALUE, - Byte.MAX_VALUE + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES ) ); } } - - /** - * Validate if the given vector size matches with the dimension provided in mapping. - * - * @param dimension dimension of vector - * @param vectorSize size of the vector - */ - public static void validateVectorDimension(int dimension, int vectorSize) { - if (dimension != vectorSize) { - String errorMessage = String.format("Vector dimension mismatch. 
Expected: %d, Given: %d", dimension, vectorSize); - throw new IllegalArgumentException(errorMessage); - } - - } - - /** - * Validates and throws exception if data_type field is set in the index mapping - * using any VectorDataType (other than float, which is default) with any engine (except lucene). - * - * @param knnMethodContext KNNMethodContext Parameter - * @param vectorDataType VectorDataType Parameter - */ - public static void validateVectorDataType_Engine( - ParametrizedFieldMapper.Parameter knnMethodContext, - ParametrizedFieldMapper.Parameter vectorDataType - ) { - if (vectorDataType.getValue() != DEFAULT_VECTOR_DATA_TYPE - && (knnMethodContext.get() == null || knnMethodContext.getValue().getKnnEngine() != KNNEngine.LUCENE)) { - throw new IllegalArgumentException(String.format("[%s] is only supported for [%s] engine", VECTOR_DATA_TYPE, LUCENE_NAME)); - } - } } diff --git a/src/main/java/org/opensearch/knn/index/VectorField.java b/src/main/java/org/opensearch/knn/index/VectorField.java index 2c346992d..f28ef6238 100644 --- a/src/main/java/org/opensearch/knn/index/VectorField.java +++ b/src/main/java/org/opensearch/knn/index/VectorField.java @@ -34,7 +34,7 @@ public VectorField(String name, byte[] value, IndexableFieldType type) { try { this.setBytesValue(value); } catch (Exception e) { - throw new IllegalArgumentException(e); + throw new RuntimeException(e); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 27416a7fb..4b9980e27 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -11,7 +11,6 @@ import org.opensearch.knn.common.KNNConstants; import org.apache.lucene.document.FieldType; -import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.IndexOptions; import 
org.apache.lucene.search.DocValuesFieldExistsQuery; @@ -50,13 +49,14 @@ import java.util.Optional; import java.util.function.Supplier; -import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE; +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; -import static org.opensearch.knn.index.VectorDataType.validateByteVectorValues; -import static org.opensearch.knn.index.VectorDataType.validateFloatVectorValues; -import static org.opensearch.knn.index.VectorDataType.validateVectorDataType_Engine; -import static org.opensearch.knn.index.VectorDataType.validateVectorDimension; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateByteVectorValue; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFloatVectorValue; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithEngine; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDimension; /** * Field Mapper for KNN vector type. @@ -107,10 +107,10 @@ public static class Builder extends ParametrizedFieldMapper.Builder { * data_type which defines the datatype of the vector values. This is an optional parameter and * this is right now only relevant for lucene engine. The default value is float. 
*/ - protected final Parameter vectorDataType = new Parameter<>( - VECTOR_DATA_TYPE, + private final Parameter vectorDataType = new Parameter<>( + VECTOR_DATA_TYPE_FIELD, false, - () -> DEFAULT_VECTOR_DATA_TYPE, + () -> DEFAULT_VECTOR_DATA_TYPE_FIELD, (n, c, o) -> VectorDataType.get((String) o), m -> toType(m).vectorDataType ); @@ -350,7 +350,7 @@ public Mapper.Builder parse(String name, Map node, ParserCont // Validates and throws exception if data_type field is set in the index mapping // using any VectorDataType (other than float, which is default) with any engine (except lucene). - validateVectorDataType_Engine(builder.knnMethodContext, builder.vectorDataType); + validateVectorDataTypeWithEngine(builder.knnMethodContext, builder.vectorDataType); return builder; } @@ -364,15 +364,15 @@ public static class KNNVectorFieldType extends MappedFieldType { VectorDataType vectorDataType; public KNNVectorFieldType(String name, Map meta, int dimension) { - this(name, meta, dimension, null, null, DEFAULT_VECTOR_DATA_TYPE); + this(name, meta, dimension, null, null, DEFAULT_VECTOR_DATA_TYPE_FIELD); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { - this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE); + this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD); } public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { - this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE); + this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD); } public KNNVectorFieldType( @@ -489,9 +489,7 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx VectorField point = new VectorField(name(), array, fieldType); context.doc().add(point); - if (fieldType.stored()) { - context.doc().add(new StoredField(name(), point.toString())); - } + 
addStoredFieldForVectorField(context, fieldType, name(), point.toString()); context.path().remove(); } @@ -522,13 +520,13 @@ Optional getBytesFromContext(ParseContext context, int dimension) throws token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { value = context.parser().floatValue(); - validateByteVectorValues(value); + validateByteVectorValue(value); vector.add((byte) value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { value = context.parser().floatValue(); - validateByteVectorValues(value); + validateByteVectorValue(value); vector.add((byte) value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { @@ -554,13 +552,13 @@ Optional getFloatsFromContext(ParseContext context, int dimension) thro token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { value = context.parser().floatValue(); - validateFloatVectorValues(value); + validateFloatVectorValue(value); vector.add(value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { value = context.parser().floatValue(); - validateFloatVectorValues(value); + validateFloatVectorValue(value); vector.add(value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java new file mode 100644 index 000000000..2784d2a33 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -0,0 +1,143 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.StoredField; +import org.apache.lucene.index.DocValuesType; +import org.opensearch.index.mapper.ParametrizedFieldMapper; +import org.opensearch.index.mapper.ParseContext; +import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.util.KNNEngine; + +import java.util.Locale; + +import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; + +public class KNNVectorFieldMapperUtil { + /** + * Validate the float vector value and throw IllegalArgumentException if it is NaN or infinite. + * + * @param value float vector value + */ + public static void validateFloatVectorValue(float value) { + if (Float.isNaN(value)) { + throw new IllegalArgumentException("KNN vector values cannot be NaN"); + } + + if (Float.isInfinite(value)) { + throw new IllegalArgumentException("KNN vector values cannot be infinity"); + } + } + + /** + * Validate that the float vector value is a finite whole number (no fractional part) + * within the byte range [-128, 127]. Otherwise throw IllegalArgumentException. + * + * @param value float value in byte range + */ + public static void validateByteVectorValue(float value) { + validateFloatVectorValue(value); + if (value % 1 != 0) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping.
But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) + + ); + } + if ((int) value < Byte.MIN_VALUE || (int) value > Byte.MAX_VALUE) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + Byte.MIN_VALUE, + Byte.MAX_VALUE + ) + ); + } + } + + /** + * Validate if the given vector size matches with the dimension provided in mapping. + * + * @param dimension dimension of vector + * @param vectorSize size of the vector + */ + public static void validateVectorDimension(int dimension, int vectorSize) { + if (dimension != vectorSize) { + String errorMessage = String.format(Locale.ROOT, "Vector dimension mismatch. Expected: %d, Given: %d", dimension, vectorSize); + throw new IllegalArgumentException(errorMessage); + } + + } + + /** + * Validates and throws exception if data_type field is set in the index mapping + * using any VectorDataType (other than float, which is default) with any engine (except lucene). 
+ * + * @param knnMethodContext KNNMethodContext Parameter + * @param vectorDataType VectorDataType Parameter + */ + public static void validateVectorDataTypeWithEngine( + ParametrizedFieldMapper.Parameter knnMethodContext, + ParametrizedFieldMapper.Parameter vectorDataType + ) { + if (vectorDataType.getValue() == DEFAULT_VECTOR_DATA_TYPE_FIELD) { + return; + } + if ((knnMethodContext.getValue() == null && KNNEngine.DEFAULT != KNNEngine.LUCENE) + || knnMethodContext.getValue().getKnnEngine() != KNNEngine.LUCENE) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue().getValue(), + LUCENE_NAME + ) + ); + } + } + + /** + * @param knnEngine KNNEngine + * @return DocValues FieldType of type Binary + */ + public static FieldType buildDocValuesFieldType(KNNEngine knnEngine) { + FieldType field = new FieldType(); + field.putAttribute(KNN_ENGINE, knnEngine.getName()); + field.setDocValuesType(DocValuesType.BINARY); + field.freeze(); + return field; + } + + public static void addStoredFieldForVectorField( + ParseContext context, + FieldType fieldType, + String mapperName, + String vectorFieldAsString + ) { + if (fieldType.stored()) { + context.doc().add(new StoredField(mapperName, vectorFieldAsString)); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index b571a9d2f..4b5a73d9a 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -11,7 +11,6 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; -import org.apache.lucene.document.StoredField; import org.apache.lucene.index.VectorSimilarityFunction; import 
org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; @@ -21,9 +20,13 @@ import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.Locale; import java.util.Optional; import static org.apache.lucene.index.VectorValues.MAX_DIMENSIONS; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.addStoredFieldForVectorField; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.buildDocValuesFieldType; /** * Field mapper for case when Lucene has been set as an engine. @@ -55,6 +58,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { if (dimension > LUCENE_MAX_DIMENSION) { throw new IllegalArgumentException( String.format( + Locale.ROOT, "Dimension value cannot be greater than [%s] but got [%s] for vector [%s]", LUCENE_MAX_DIMENSION, dimension, @@ -66,7 +70,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { this.fieldType = vectorDataType.createKnnVectorFieldType(dimension, vectorSimilarityFunction); if (this.hasDocValues) { - this.vectorFieldType = vectorDataType.buildDocValuesFieldType(this.knnMethod.getKnnEngine()); + this.vectorFieldType = buildDocValuesFieldType(this.knnMethod.getKnnEngine()); } else { this.vectorFieldType = null; } @@ -78,38 +82,40 @@ protected void parseCreateField(ParseContext context, int dimension) throws IOEx validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - if (vectorDataType.equals(VectorDataType.BYTE)) { - Optional arrayOptional = getBytesFromContext(context, dimension); - if (arrayOptional.isEmpty()) { + if (VectorDataType.BYTE.equals(vectorDataType)) { + Optional bytesArrayOptional = getBytesFromContext(context, dimension); + if (bytesArrayOptional.isEmpty()) { return; } - final byte[] array = arrayOptional.get(); + final byte[] array = bytesArrayOptional.get(); KnnByteVectorField point = new 
KnnByteVectorField(name(), array, fieldType); + context.doc().add(point); - if (fieldType.stored()) { - context.doc().add(new StoredField(name(), point.toString())); - } + addStoredFieldForVectorField(context, fieldType, name(), point.toString()); + if (hasDocValues && vectorFieldType != null) { context.doc().add(new VectorField(name(), array, vectorFieldType)); } - } else { - Optional arrayOptional = getFloatsFromContext(context, dimension); + } else if (VectorDataType.FLOAT.equals(vectorDataType)) { + Optional floatsArrayOptional = getFloatsFromContext(context, dimension); - if (arrayOptional.isEmpty()) { + if (floatsArrayOptional.isEmpty()) { return; } - final float[] array = arrayOptional.get(); + final float[] array = floatsArrayOptional.get(); KnnVectorField point = new KnnVectorField(name(), array, fieldType); context.doc().add(point); - if (fieldType.stored()) { - context.doc().add(new StoredField(name(), point.toString())); - } + addStoredFieldForVectorField(context, fieldType, name(), point.toString()); if (hasDocValues && vectorFieldType != null) { context.doc().add(new VectorField(name(), array, vectorFieldType)); } + } else { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) + ); } context.path().remove(); diff --git a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java index 865402225..80ec9164f 100644 --- a/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java +++ b/src/test/java/org/opensearch/knn/index/VectorDataTypeIT.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index; +import lombok.SneakyThrows; import org.junit.After; import org.opensearch.client.ResponseException; import org.opensearch.common.Strings; @@ -14,13 +15,13 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.util.KNNEngine; -import java.io.IOException; +import 
java.util.Locale; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; -import static org.opensearch.knn.index.VectorDataType.getValues; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; public class VectorDataTypeIT extends KNNRestTestCase { private static final String INDEX_NAME = "test-index-vec-dt"; @@ -33,12 +34,14 @@ public class VectorDataTypeIT extends KNNRestTestCase { private static final int M = 16; @After - public final void cleanUp() throws IOException { + @SneakyThrows + public final void cleanUp() { deleteKNNIndex(INDEX_NAME); } // Validate if we are able to create an index by setting data_type field as byte and add a doc to it - public void testAddDocWithByteVector() throws Exception { + @SneakyThrows + public void testAddDocWithByteVector() { createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); Byte[] vector = { 6, 6 }; addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); @@ -48,7 +51,8 @@ public void testAddDocWithByteVector() throws Exception { } // Validate by creating an index by setting data_type field as byte, add a doc to it and update it later. - public void testUpdateDocWithByteVector() throws Exception { + @SneakyThrows + public void testUpdateDocWithByteVector() { createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); Byte[] vector = { -36, 78 }; addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); @@ -61,7 +65,8 @@ public void testUpdateDocWithByteVector() throws Exception { } // Validate by creating an index by setting data_type field as byte, add a doc to it and delete it later. 
- public void testDeleteDocWithByteVector() throws Exception { + @SneakyThrows + public void testDeleteDocWithByteVector() { createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); Byte[] vector = { 35, -46 }; addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector); @@ -75,7 +80,6 @@ public void testDeleteDocWithByteVector() throws Exception { // Set an invalid value for data_type field while creating the index which should throw an exception public void testInvalidVectorDataType() { String vectorDataType = "invalidVectorType"; - String supportedTypes = String.join(",", getValues()); ResponseException ex = expectThrows( ResponseException.class, () -> createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, vectorDataType) @@ -84,10 +88,10 @@ public void testInvalidVectorDataType() { ex.getMessage() .contains( String.format( - "[%s] field was set as [%s] in index mapping. But, supported values are [%s]", - VECTOR_DATA_TYPE, - vectorDataType, - supportedTypes + Locale.ROOT, + "Invalid value provided for [%s] field. Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, + SUPPORTED_VECTOR_DATA_TYPES ) ) ); @@ -100,8 +104,9 @@ public void testVectorDataTypeAsNull() { ex.getMessage() .contains( String.format( + Locale.ROOT, "[%s] on mapper [%s] of type [%s] must not have a [null] value", - VECTOR_DATA_TYPE, + VECTOR_DATA_TYPE_FIELD, FIELD_NAME, KNN_VECTOR_TYPE ) @@ -110,7 +115,8 @@ public void testVectorDataTypeAsNull() { } // Create an index with byte vector data_type and add a doc with decimal values which should throw exception - public void testInvalidVectorData() throws Exception { + @SneakyThrows + public void testInvalidVectorData() { createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); Float[] vector = { -10.76f, 15.89f }; @@ -118,13 +124,19 @@ public void testInvalidVectorData() throws Exception { assertTrue( ex.getMessage() .contains( - "[data_type] field was set as [byte] in index mapping. 
But, KNN vector values are floats instead of byte integers" + String.format( + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue() + ) ) ); } // Create an index with byte vector data_type and add a doc with values out of byte range which should throw exception - public void testInvalidByteVectorRange() throws Exception { + @SneakyThrows + public void testInvalidByteVectorRange() { createKnnIndexMappingWithLuceneEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()); Float[] vector = { -1000f, 155f }; @@ -133,8 +145,9 @@ public void testInvalidByteVectorRange() throws Exception { ex.getMessage() .contains( String.format( - "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [{}, {}]", - VECTOR_DATA_TYPE, + Locale.ROOT, + "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]", + VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue(), Byte.MIN_VALUE, Byte.MAX_VALUE @@ -149,7 +162,18 @@ public void testByteVectorDataTypeWithNmslibEngine() { ResponseException.class, () -> createKnnIndexMappingWithNmslibEngine(2, SpaceType.L2, VectorDataType.BYTE.getValue()) ); - assertTrue(ex.getMessage().contains(String.format("[%s] is only supported for [%s] engine", VECTOR_DATA_TYPE, LUCENE_NAME))); + assertTrue( + ex.getMessage() + .contains( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is only supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + VectorDataType.BYTE.getValue(), + LUCENE_NAME + ) + ) + ); } private void createKnnIndexMappingWithNmslibEngine(int dimension, SpaceType spaceType, String vectorDataType) throws Exception { @@ -168,7 +192,7 @@ private void createKnnIndexMappingWithCustomEngine(int dimension, SpaceType spac .startObject(FIELD_NAME) .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) .field(DIMENSION, dimension) - 
.field(VECTOR_DATA_TYPE, vectorDataType) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) .startObject(KNNConstants.KNN_METHOD) .field(KNNConstants.NAME, METHOD_HNSW) .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index d578729bb..1f3598781 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.mapper; import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.IndexableField; @@ -45,7 +46,9 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Locale; import java.util.Optional; +import java.util.stream.Collectors; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; @@ -64,8 +67,7 @@ import static org.opensearch.Version.CURRENT; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE; -import static org.opensearch.knn.index.VectorDataType.getValues; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; public class KNNVectorFieldMapperTests extends KNNTestCase { @@ -94,6 +96,9 @@ public void testBuilder_getParameters() { KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao); assertEquals(7, builder.getParameters().size()); + List actualParams = builder.getParameters().stream().map(a -> a.name).collect(Collectors.toList()); + List expectedParams = Arrays.asList("store", "doc_values", DIMENSION, VECTOR_DATA_TYPE_FIELD, "meta", 
KNN_METHOD, MODEL_ID); + assertEquals(expectedParams, actualParams); } public void testBuilder_build_fromKnnMethodContext() { @@ -346,11 +351,15 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidDimension() throws } // Validate TypeParser parsing invalid vector data_type which throws exception - public void testTypeParser_parse_invalidVectorDataType() throws IOException { + @SneakyThrows + public void testTypeParser_parse_invalidVectorDataType() { String fieldName = "test-field-name-vec"; String indexName = "test-index-name-vec"; String vectorDataType = "invalid"; - String supportedTypes = String.join(",", getValues()); + String supportedTypes = String.join( + ",", + Arrays.stream((VectorDataType.values())).map(VectorDataType::getValue).collect(Collectors.toCollection(HashSet::new)) + ); Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); @@ -361,7 +370,7 @@ public void testTypeParser_parse_invalidVectorDataType() throws IOException { .startObject() .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) .field(DIMENSION, 10) - .field(VECTOR_DATA_TYPE, vectorDataType) + .field(VECTOR_DATA_TYPE_FIELD, vectorDataType) .startObject(KNN_METHOD) .field(NAME, METHOD_HNSW) .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) @@ -382,9 +391,9 @@ public void testTypeParser_parse_invalidVectorDataType() throws IOException { ); assertEquals( String.format( - "[%s] field was set as [%s] in index mapping. But, supported values are [%s]", - VECTOR_DATA_TYPE, - vectorDataType, + Locale.ROOT, + "Invalid value provided for [%s] field. 
Supported values are [%s]", + VECTOR_DATA_TYPE_FIELD, supportedTypes ), ex.getMessage() @@ -730,7 +739,8 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { expectThrows(IllegalArgumentException.class, () -> knnVectorFieldMapper1.merge(knnVectorFieldMapper3)); } - public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() throws IOException { + @SneakyThrows + public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { // Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder inputBuilder = createLuceneFieldMapperInputBuilder(VectorDataType.FLOAT); @@ -795,7 +805,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() throws assertArrayEquals(TEST_VECTOR, knnVectorField.vectorValue(), 0.001f); } - public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() throws IOException { + @SneakyThrows + public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { // Create a lucene field mapper that creates a binary doc values field as well as KnnByteVectorField LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder inputBuilder = @@ -834,7 +845,6 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() throws } assertEquals(TEST_BYTE_VECTOR_BYTES_REF, vectorField.binaryValue()); - assertEquals(VectorEncoding.BYTE, vectorField.fieldType().vectorEncoding()); assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue()); // Test when doc values are disabled @@ -889,13 +899,13 @@ private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperIn .knnMethodContext(knnMethodContext); } - public static float[] createInitializedFloatArray(int dimension, float value) { + private static float[] createInitializedFloatArray(int dimension, float value) { float[] array = new float[dimension]; 
Arrays.fill(array, value); return array; } - public static byte[] createInitializedByteArray(int dimension, byte value) { + private static byte[] createInitializedByteArray(int dimension, byte value) { byte[] array = new byte[dimension]; Arrays.fill(array, value); return array;