diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 2a3732d7e0..4490ba960c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -13,6 +13,8 @@ import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.mapper.ANNConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.util.Optional; @@ -66,16 +68,19 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { ); return defaultFormatSupplier.get(); } - var type = (KNNVectorFieldType) mapperService.orElseThrow( + KNNVectorFieldType mappedFieldType = (KNNVectorFieldType) mapperService.orElseThrow( () -> new IllegalStateException( String.format("Cannot read field type for field [%s] because mapper service is not available", field) ) ).fieldType(field); - var params = type.getKnnMethodContext().getMethodComponentContext().getParameters(); - if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE - && params != null - && params.containsKey(METHOD_ENCODER_PARAMETER)) { + ANNConfig annConfig = mappedFieldType.getAnnConfig(); + KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext() + .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); + + var params = knnMethodContext.getMethodComponentContext().getParameters(); + + if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE && params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( params, defaultMaxConnections, diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 8e191ac5f3..9cddfbac32 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -21,7 +21,6 @@ import org.opensearch.knn.index.codec.transfer.VectorTransferByte; import org.opensearch.knn.index.codec.transfer.VectorTransferFloat; import org.opensearch.knn.jni.JNIService; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.Model; @@ -57,7 +56,6 @@ import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; @@ -213,35 +211,19 @@ private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNN private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) throws IOException { - Map parameters = new HashMap<>(); Map fieldAttributes = fieldInfo.attributes(); String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); - - // parametersString will be null when legacy mapper is used if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); - - String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); - Map algoParams = new HashMap<>(); - if (efConstruction != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); - } - - String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); - if (m != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); - } - parameters.put(PARAMETERS, algoParams); - } else { - parameters.putAll( - XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(parametersString), - MediaTypeRegistry.getDefaultMediaType() - ).map() - ); + throw new IllegalStateException("Parameter string is not set. Something is wrong"); } + Map parameters = new HashMap<>( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); // Update index description of Faiss for binary data type if (KNNEngine.FAISS == knnEngine diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java index 7885761b93..d210483e60 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java @@ -9,7 +9,6 @@ import lombok.Getter; import lombok.NonNull; import lombok.Setter; -import org.opensearch.Version; import org.opensearch.common.ValidationException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -20,7 +19,6 @@ import org.opensearch.index.mapper.MapperParsingException; import java.io.IOException; -import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; @@ -29,7 +27,6 @@ import org.opensearch.knn.training.VectorSpaceInfo; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -42,21 +39,6 @@ @Getter public class KNNMethodContext implements ToXContentFragment, Writeable { - private static KNNMethodContext defaultInstance = null; - - /** - * This is used only for testing - * @return default KNNMethodContext for testing - */ - public static synchronized KNNMethodContext getDefault() { - if (defaultInstance == null) { - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); - methodComponentContext.setIndexVersion(Version.CURRENT); - defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); - } - return defaultInstance; - } - @NonNull private final KNNEngine knnEngine; @NonNull diff --git a/src/main/java/org/opensearch/knn/index/mapper/ANNConfig.java b/src/main/java/org/opensearch/knn/index/mapper/ANNConfig.java new file mode 100644 index 0000000000..04619647a5 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/ANNConfig.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.Getter; +import org.opensearch.knn.index.engine.KNNMethodContext; + +import java.util.Optional; + +/** + * Class holds information about how the ANN indices are created. The design of this class ensures that we do not + * accidentally configure an index that has multiple ways it can be created. This class is immutable. + */ +public final class ANNConfig { + + @Getter + private final ANNConfigType annConfigType; + private final KNNMethodContext knnMethodContext; + private final String modelId; + private final Integer dimension; + + /** + * Constructor + * + * @param annConfigType Configurational context index was built. Cannot be null + * @param knnMethodContext Method context used to create the index; null if not created from method + * @param modelId Model id used to create the index; null if not created from model + * @param dimension Dimension used to create the index; needs to be null for model-based indices + */ + public ANNConfig(ANNConfigType annConfigType, KNNMethodContext knnMethodContext, String modelId, Integer dimension) { + if (annConfigType == null) { + throw new IllegalArgumentException("ANNConfiguration cannot be null"); + } + + this.annConfigType = annConfigType; + this.knnMethodContext = knnMethodContext; + this.modelId = modelId; + this.dimension = dimension; + + if (ANNConfigType.FROM_METHOD == annConfigType) { + validateFromMethod(); + return; + } + + if (ANNConfigType.FROM_MODEL == annConfigType) { + validateFromModel(); + return; + } + + if (ANNConfigType.SKIP == annConfigType) { + validateSkip(); + } + } + + private void validateFromMethod() { + if (knnMethodContext == null) { + throw new IllegalArgumentException("knnMethodContext cannot be null when created from method"); + } + + if (modelId != null) { + throw new IllegalArgumentException("modelId cannot be specified when created from method"); + } + + if (dimension == null) { + throw new IllegalArgumentException("dimension must be specified when created from method"); + } + } + + private void validateFromModel() { + if (modelId == null) { + throw new IllegalArgumentException("modelId cannot be null when created from method"); + } + + if (knnMethodContext != null) { + throw new IllegalArgumentException("knnMethodContext cannot be specified when created from method"); + } + + if (dimension != null) { + throw new IllegalArgumentException("dimension must be null when created from model"); + } + } + + private void validateSkip() { + if (knnMethodContext != null || modelId != null) { + throw new IllegalArgumentException("knnMethodContext or modelId cannot be specified when skipping"); + } + + if (dimension == null) { + throw new IllegalArgumentException("dimension must be specified when created from model"); + } + } + + /** + * + * @return Optional containing the modelId if created from model, otherwise empty + */ + public Optional getModelId() { + return Optional.ofNullable(modelId); + } + + /** + * + * @return Optional containing the KNNMethodContext if created from method, otherwise empty + */ + public Optional getKnnMethodContext() { + return Optional.ofNullable(knnMethodContext); + } + + /** + * + * @return the dimension of the index; for model based indices, it will be null + */ + public Optional getDimension() { + return Optional.ofNullable(dimension); + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/ANNConfigType.java b/src/main/java/org/opensearch/knn/index/mapper/ANNConfigType.java new file mode 100644 index 0000000000..a674bd1e13 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/ANNConfigType.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +/** + * Types of configurations to build ANN indices + */ +public enum ANNConfigType { + FROM_METHOD, + FROM_MODEL, + SKIP +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java new file mode 100644 index 0000000000..442e778cb3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.document.FieldType; +import org.opensearch.Version; +import org.opensearch.common.Explicit; +import org.opensearch.knn.index.VectorDataType; + +/** + * Mapper used when you dont want to build an underlying KNN struct - you just want to + * store vectors as doc values + */ +public class FlatVectorFieldMapper extends KNNVectorFieldMapper { + + private final PerDimensionValidator perDimensionValidator; + + public FlatVectorFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + Version indexCreatedVersion + ) { + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null); + this.perDimensionValidator = selectPerDimensionValidator(vectorDataType); + this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + this.fieldType.freeze(); + } + + private PerDimensionValidator selectPerDimensionValidator(VectorDataType vectorDataType) { + if (VectorDataType.BINARY == vectorDataType) { + return PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + } + + if (VectorDataType.BYTE == vectorDataType) { + return PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + } + + return PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + } + + @Override + protected VectorValidator getVectorValidator() { + return VectorValidator.NOOP_VECTOR_VALIDATOR; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return PerDimensionProcessor.NOOP_PROCESSOR; + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 3b94876454..951868a99a 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -11,7 +11,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.function.Supplier; import lombok.extern.log4j.Log4j2; @@ -32,9 +31,8 @@ import org.opensearch.index.mapper.ParametrizedFieldMapper; import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.KnnCircuitBreakerException; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; @@ -44,23 +42,20 @@ import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; -import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; -import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension; -import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfCircuitBreakerIsNotTriggered; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfKNNPluginEnabled; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataType; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; +import static org.opensearch.knn.index.mapper.ModelFieldMapper.UNSET_MODEL_DIMENSION_IDENTIFIER; /** * Field Mapper for KNN vector type. Implementations of this class define what needs to be stored in Lucene's fieldType. @@ -76,10 +71,6 @@ private static KNNVectorFieldMapper toType(FieldMapper in) { return (KNNVectorFieldMapper) in; } - // We store the version of the index with the mapper as different version of Opensearch has different default - // values of KNN engine Algorithms hyperparameters. - protected Version indexCreatedVersion; - /** * Builder for KNNVectorFieldMapper. This class defines the set of parameters that can be applied to the knn_vector * field type @@ -107,7 +98,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { ); } return value; - }, m -> toType(m).dimension); + }, m -> toType(m).fieldType().getAnnConfig().getDimension().orElse(UNSET_MODEL_DIMENSION_IDENTIFIER)); /** * data_type which defines the datatype of the vector values. This is an optional parameter and @@ -126,7 +117,12 @@ public static class Builder extends ParametrizedFieldMapper.Builder { * model template index. If this parameter is set, it will take precedence. This parameter is only relevant for * library indices that require training. */ - protected final Parameter modelId = Parameter.stringParam(KNNConstants.MODEL_ID, false, m -> toType(m).modelId, null); + protected final Parameter modelId = Parameter.stringParam( + KNNConstants.MODEL_ID, + false, + m -> toType(m).fieldType().getAnnConfig().getModelId().orElse(null), + null + ); /** * knnMethodContext parameter allows a user to define their k-NN library index configuration. Defaults to an L2 @@ -137,7 +133,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { false, () -> null, (n, c, o) -> KNNMethodContext.parse(o), - m -> toType(m).knnMethod + m -> toType(m).originalKNNMethodContext ).setSerializer(((b, n, v) -> { b.startObject(n); v.toXContent(b, ToXContent.EMPTY_PARAMS); @@ -164,35 +160,15 @@ public static class Builder extends ParametrizedFieldMapper.Builder { protected final Parameter> meta = Parameter.metaParam(); - protected String spaceType; - protected String m; - protected String efConstruction; - protected ModelDao modelDao; - protected Version indexCreatedVersion; + private KNNMethodContext resolvedKNNMethodContext; - public Builder(String name, ModelDao modelDao, Version indexCreatedVersion) { + public Builder(String name, ModelDao modelDao, Version indexCreatedVersion, KNNMethodContext resolvedKNNMethodContext) { super(name); this.modelDao = modelDao; this.indexCreatedVersion = indexCreatedVersion; - } - - /** - * This constructor is for legacy purposes. - * Checkout ODFE PR 288 - * - * @param name field name - * @param spaceType Spacetype of field - * @param m m value of field - * @param efConstruction efConstruction value of field - */ - public Builder(String name, String spaceType, String m, String efConstruction, Version indexCreatedVersion) { - super(name); - this.spaceType = spaceType; - this.m = m; - this.efConstruction = efConstruction; - this.indexCreatedVersion = indexCreatedVersion; + this.resolvedKNNMethodContext = resolvedKNNMethodContext; } @Override @@ -212,6 +188,11 @@ protected Explicit ignoreMalformed(BuilderContext context) { @Override public KNNVectorFieldMapper build(BuilderContext context) { + final MultiFields multiFieldsBuilder = this.multiFieldsBuilder.build(this, context); + final CopyTo copyToBuilder = copyTo.build(); + final Explicit ignoreMalformed = ignoreMalformed(context); + final Map metaValue = meta.getValue(); + // Originally, a user would use index settings to set the spaceType, efConstruction and m hnsw // parameters. Upon further review, it makes sense to set these parameters in the mapping of a // particular field. However, because users migrating from older versions will still use the index @@ -219,45 +200,20 @@ public KNNVectorFieldMapper build(BuilderContext context) { // handle this, we first check if the mapping is set, and, if so use it. If not, we check if the model is // set. If not, we fall back to the parameters set in the index settings. This means that if a user sets // the mappings, setting the index settings will have no impact. - - final KNNMethodContext knnMethodContext = this.knnMethodContext.getValue(); - setDefaultSpaceType(knnMethodContext, vectorDataType.getValue()); - validateSpaceType(knnMethodContext, vectorDataType.getValue()); - validateDimensions(knnMethodContext, vectorDataType.getValue()); - validateEncoder(knnMethodContext, vectorDataType.getValue()); - final MultiFields multiFieldsBuilder = this.multiFieldsBuilder.build(this, context); - final CopyTo copyToBuilder = copyTo.build(); - final Explicit ignoreMalformed = ignoreMalformed(context); - final Map metaValue = meta.getValue(); - - if (knnMethodContext != null) { - validateVectorDataType(knnMethodContext, vectorDataType.getValue()); - knnMethodContext.getMethodComponentContext().setIndexVersion(indexCreatedVersion); + String modelIdAsString = this.modelId.get(); + if (modelIdAsString != null) { + // Because model information is stored in cluster metadata, we are unable to get it here. This is + // because to get the cluster metadata, you need access to the cluster state. Because this code is + // sometimes used to initialize the cluster state/update cluster state, we cannot get the state here + // safely. So, we are unable to validate the model. The model gets validated during ingestion. final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( buildFullName(context), metaValue, - dimension.getValue(), - knnMethodContext, - vectorDataType.getValue() + vectorDataType.getValue(), + new ANNConfig(ANNConfigType.FROM_MODEL, null, modelIdAsString, null) ); - if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE) { - log.debug(String.format(Locale.ROOT, "Use [LuceneFieldMapper] mapper for field [%s]", name)); - LuceneFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = - LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() - .name(name) - .mappedFieldType(mappedFieldType) - .multiFields(multiFieldsBuilder) - .copyTo(copyToBuilder) - .ignoreMalformed(ignoreMalformed) - .stored(stored.get()) - .hasDocValues(hasDocValues.get()) - .vectorDataType(vectorDataType.getValue()) - .knnMethodContext(knnMethodContext) - .build(); - return new LuceneFieldMapper(createLuceneFieldMapperInput); - } - return new MethodFieldMapper( + return new ModelFieldMapper( name, mappedFieldType, multiFieldsBuilder, @@ -265,67 +221,113 @@ public KNNVectorFieldMapper build(BuilderContext context) { ignoreMalformed, stored.get(), hasDocValues.get(), - knnMethodContext + modelDao, + indexCreatedVersion ); } - String modelIdAsString = this.modelId.get(); - if (modelIdAsString != null) { - // Because model information is stored in cluster metadata, we are unable to get it here. This is - // because to get the cluster metadata, you need access to the cluster state. Because this code is - // sometimes used to initialize the cluster state/update cluster state, we cannot get the state here - // safely. So, we are unable to validate the model. The model gets validated during ingestion. + boolean isResolvedNull = resolvedKNNMethodContext == null; + if (isResolvedNull) { + resolvedKNNMethodContext = this.knnMethodContext.getValue(); + setDefaultSpaceType(resolvedKNNMethodContext, vectorDataType.getValue()); + validateSpaceType(resolvedKNNMethodContext, vectorDataType.getValue()); + validateDimensions(resolvedKNNMethodContext, vectorDataType.getValue()); + validateEncoder(resolvedKNNMethodContext, vectorDataType.getValue()); + } - return new ModelFieldMapper( + // If the field mapper is using the legacy context and being constructed from another field mapper, + // the settings will be empt. See https://github.com/opendistro-for-elasticsearch/k-NN/issues/288. In this + // case, the input resolvedKNNMethodContext will be null and the settings wont exist (so flat mapper should + // be used). Otherwise, we need to check the setting. + boolean isSettingPresent = KNNSettings.IS_KNN_INDEX_SETTING.exists(context.indexSettings()); + if (isResolvedNull && (!isSettingPresent || !KNNSettings.IS_KNN_INDEX_SETTING.get(context.indexSettings()))) { + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( + buildFullName(context), + metaValue, + vectorDataType.getValue(), + new ANNConfig(ANNConfigType.SKIP, null, null, dimension.getValue()) + ); + return new FlatVectorFieldMapper( name, - new KNNVectorFieldType(buildFullName(context), metaValue, -1, knnMethodContext, modelIdAsString), + mappedFieldType, multiFieldsBuilder, copyToBuilder, ignoreMalformed, stored.get(), hasDocValues.get(), - modelDao, - modelIdAsString, indexCreatedVersion ); } - // Build legacy - if (this.spaceType == null) { - this.spaceType = LegacyFieldMapper.getSpaceType(context.indexSettings(), vectorDataType.getValue()); - } - - if (this.m == null) { - this.m = LegacyFieldMapper.getM(context.indexSettings()); - } - - if (this.efConstruction == null) { - this.efConstruction = LegacyFieldMapper.getEfConstruction(context.indexSettings(), indexCreatedVersion); + // If the knnMethodContext is null at this point, that means user built the index with the legacy k-NN + // settings to specify algo params. We need to convert this here to a KNNMethodContext so that we can + // properly configure the rest of the index + if (resolvedKNNMethodContext == null) { + resolvedKNNMethodContext = new KNNMethodContext( + KNNEngine.NMSLIB, + KNNVectorFieldMapperUtil.getSpaceType(context.indexSettings(), vectorDataType.getValue()), + new MethodComponentContext( + METHOD_HNSW, + Map.of( + METHOD_PARAMETER_M, + KNNVectorFieldMapperUtil.getM(context.indexSettings()), + METHOD_PARAMETER_EF_CONSTRUCTION, + KNNVectorFieldMapperUtil.getEfConstruction(context.indexSettings(), indexCreatedVersion) + ) + ) + ); + // Validates and throws exception if index.knn is set to true in the index settings + // using any VectorDataType (other than float, which is default) because we are using NMSLIB engine for LegacyFieldMapper + // and it only supports float VectorDataType + if (VectorDataType.FLOAT != vectorDataType.get()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is not supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue().getValue(), + NMSLIB_NAME + ) + ); + } } - // Validates and throws exception if index.knn is set to true in the index settings - // using any VectorDataType (other than float, which is default) because we are using NMSLIB engine for LegacyFieldMapper - // and it only supports float VectorDataType - validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType); - - return new LegacyFieldMapper( + validateVectorDataType(resolvedKNNMethodContext, vectorDataType.getValue()); + resolvedKNNMethodContext.getMethodComponentContext().setIndexVersion(indexCreatedVersion); + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( + buildFullName(context), + metaValue, + vectorDataType.getValue(), + new ANNConfig(ANNConfigType.FROM_METHOD, resolvedKNNMethodContext, null, dimension.getValue()) + ); + if (resolvedKNNMethodContext.getKnnEngine() == KNNEngine.LUCENE) { + log.debug(String.format(Locale.ROOT, "Use [LuceneFieldMapper] mapper for field [%s]", name)); + LuceneFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = LuceneFieldMapper.CreateLuceneFieldMapperInput + .builder() + .name(name) + .mappedFieldType(mappedFieldType) + .multiFields(multiFieldsBuilder) + .copyTo(copyToBuilder) + .ignoreMalformed(ignoreMalformed) + .stored(stored.get()) + .hasDocValues(hasDocValues.get()) + .vectorDataType(vectorDataType.getValue()) + .indexVersion(indexCreatedVersion) + .originalKnnMethodContext(knnMethodContext.get()) + .build(); + return new LuceneFieldMapper(createLuceneFieldMapperInput); + } + + return new MethodFieldMapper( name, - new KNNVectorFieldType( - buildFullName(context), - metaValue, - dimension.getValue(), - vectorDataType.getValue(), - SpaceType.getSpace(spaceType) - ), + mappedFieldType, multiFieldsBuilder, copyToBuilder, ignoreMalformed, stored.get(), hasDocValues.get(), - spaceType, - m, - efConstruction, - indexCreatedVersion + indexCreatedVersion, + knnMethodContext.get() ); } @@ -430,7 +432,7 @@ public TypeParser(Supplier modelDaoSupplier) { @Override public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { - Builder builder = new KNNVectorFieldMapper.Builder(name, modelDaoSupplier.get(), parserContext.indexVersionCreated()); + Builder builder = new KNNVectorFieldMapper.Builder(name, modelDaoSupplier.get(), parserContext.indexVersionCreated(), null); builder.parse(name, parserContext, node); // All parse(String name, Map node, ParserCont } // Dimension should not be null unless modelId is used - if (builder.dimension.getValue() == -1 && builder.modelId.get() == null) { + if (builder.dimension.getValue() == UNSET_MODEL_DIMENSION_IDENTIFIER && builder.modelId.get() == null) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Dimension value missing for vector: %s", name)); } @@ -452,18 +454,18 @@ public Mapper.Builder parse(String name, Map node, ParserCont } } + // We store the version of the index with the mapper as different version of Opensearch has different default + // values of KNN engine Algorithms hyperparameters. + protected Version indexCreatedVersion; protected Explicit ignoreMalformed; protected boolean stored; protected boolean hasDocValues; - protected Integer dimension; protected VectorDataType vectorDataType; protected ModelDao modelDao; - // These members map to parameters in the builder. They need to be declared in the abstract class due to the - // "toType" function used in the builder. So, when adding a parameter, it needs to be added here, but set in a - // subclass (if it is unique). - protected KNNMethodContext knnMethod; - protected String modelId; + // We need to ensure that the original KNNMethodContext as parsed is stored to initialize the + // Builder for serialization. So, we need to store it here. + protected KNNMethodContext originalKNNMethodContext; public KNNVectorFieldMapper( String simpleName, @@ -473,16 +475,17 @@ public KNNVectorFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - Version indexCreatedVersion + Version indexCreatedVersion, + KNNMethodContext originalKNNMethodContext ) { super(simpleName, mappedFieldType, multiFields, copyTo); this.ignoreMalformed = ignoreMalformed; this.stored = stored; this.hasDocValues = hasDocValues; - this.dimension = mappedFieldType.getDimension(); this.vectorDataType = mappedFieldType.getVectorDataType(); updateEngineStats(); this.indexCreatedVersion = indexCreatedVersion; + this.originalKNNMethodContext = originalKNNMethodContext; } public KNNVectorFieldMapper clone() { @@ -498,20 +501,11 @@ protected String contentType() { protected void parseCreateField(ParseContext context) throws IOException { parseCreateField( context, - fieldType().getDimension(), - fieldType().getSpaceType(), - getMethodComponentContext(fieldType().getKnnMethodContext()), + fieldType().getAnnConfig().getDimension().orElseThrow(() -> new IllegalArgumentException("Dimension is not set")), fieldType().getVectorDataType() ); } - private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) { - if (Objects.isNull(knnMethodContext)) { - return null; - } - return knnMethodContext.getMethodComponentContext(); - } - /** * Function returns a list of fields to be indexed when the vector is float type. * @@ -544,17 +538,37 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fields; } - protected void parseCreateField( - ParseContext context, - int dimension, - SpaceType spaceType, - MethodComponentContext methodComponentContext, - VectorDataType vectorDataType - ) throws IOException { - + /** + * Validation checks before parsing of doc begins + */ + protected void validatePreparse() { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - spaceType.validateVectorDataType(vectorDataType); + } + + /** + * Getter for vector validator after vector parsing + * + * @return VectorValidator + */ + protected abstract VectorValidator getVectorValidator(); + + /** + * Getter for per dimension validator during vector parsing + * + * @return PerDimensionValidator + */ + protected abstract PerDimensionValidator getPerDimensionValidator(); + + /** + * Getter for per dimension processor during vector parsing + * + * @return PerDimensionProcessor + */ + protected abstract PerDimensionProcessor getPerDimensionProcessor(); + + protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { + validatePreparse(); if (VectorDataType.BINARY == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); @@ -563,7 +577,7 @@ protected void parseCreateField( return; } final byte[] array = bytesArrayOptional.get(); - spaceType.validateVector(array); + getVectorValidator().validateVector(array); context.doc().addAll(getFieldsForByteVector(array, fieldType)); } else if (VectorDataType.BYTE == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); @@ -572,16 +586,16 @@ protected void parseCreateField( return; } final byte[] array = bytesArrayOptional.get(); - spaceType.validateVector(array); + getVectorValidator().validateVector(array); context.doc().addAll(getFieldsForByteVector(array, fieldType)); } else if (VectorDataType.FLOAT == vectorDataType) { - Optional floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext); + Optional floatsArrayOptional = getFloatsFromContext(context, dimension); if (floatsArrayOptional.isEmpty()) { return; } final float[] array = floatsArrayOptional.get(); - spaceType.validateVector(array); + getVectorValidator().validateVector(array); context.doc().addAll(getFieldsForFloatVector(array, fieldType)); } else { throw new IllegalArgumentException( @@ -592,80 +606,28 @@ protected void parseCreateField( context.path().remove(); } - // Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16" - protected boolean isFaissSQfp16(MethodComponentContext methodComponentContext) { - if (Objects.isNull(methodComponentContext)) { - return false; - } - - if (methodComponentContext.getParameters().size() == 0) { - return false; - } - - Map methodComponentParams = methodComponentContext.getParameters(); - - // The method component parameters should have an encoder - if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) { - return false; - } - - // Validate if the object is of type MethodComponentContext before casting it later - if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { - return false; - } - - MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER); - - // returns true if encoder name is "sq" and type is "fp16" - return ENCODER_SQ.equals(encoderMethodComponentContext.getName()) - && FAISS_SQ_ENCODER_FP16.equals( - encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) - ); - - } - - // Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index - // using "sq" encoder of type "fp16". - protected boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) { - if (Objects.nonNull(methodComponentContext)) { - return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false); - } - return false; - } - - void validateIfCircuitBreakerIsNotTriggered() { - if (KNNSettings.isCircuitBreakerTriggered()) { - throw new KnnCircuitBreakerException( - "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status." - ); - } - } - - void validateIfKNNPluginEnabled() { - if (!KNNSettings.isKNNPluginEnabled()) { - throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled setting to true"); - } - } - // Returns an optional array of byte values where each value in the vector is parsed as a float and validated // if it is a finite number without any decimals and within the byte range of [-128 to 127]. Optional getBytesFromContext(ParseContext context, int dimension, VectorDataType dataType) throws IOException { context.path().add(simpleName()); + PerDimensionValidator perDimensionValidator = getPerDimensionValidator(); + PerDimensionProcessor perDimensionProcessor = getPerDimensionProcessor(); + ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { - float value = context.parser().floatValue(); - validateByteVectorValue(value, dataType); + float value = perDimensionProcessor.processByte(context.parser().floatValue()); + perDimensionValidator.validateByte(value); vector.add((byte) value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { - float value = context.parser().floatValue(); - validateByteVectorValue(value, dataType); + float value = perDimensionProcessor.processByte(context.parser().floatValue()); + perDimensionValidator.validateByte(value); vector.add((byte) value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { @@ -681,21 +643,11 @@ Optional getBytesFromContext(ParseContext context, int dimension, Vector return Optional.of(array); } - Optional getFloatsFromContext(ParseContext context, int dimension, MethodComponentContext methodComponentContext) - throws IOException { + Optional getFloatsFromContext(ParseContext context, int dimension) throws IOException { context.path().add(simpleName()); - // Returns an optional array of float values where each value in the vector is parsed as a float and validated - // if it is a finite number and within the fp16 range of [-65504 to 65504] by default if Faiss encoder is SQ and type is 'fp16'. - // If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be - // clipped to FP16 range. - boolean isFaissSQfp16Flag = isFaissSQfp16(methodComponentContext); - boolean clipVectorValueToFP16RangeFlag = false; - if (isFaissSQfp16Flag) { - clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled( - (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) - ); - } + PerDimensionValidator perDimensionValidator = getPerDimensionValidator(); + PerDimensionProcessor perDimensionProcessor = getPerDimensionProcessor(); ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); @@ -703,31 +655,14 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { - value = context.parser().floatValue(); - if (isFaissSQfp16Flag) { - if (clipVectorValueToFP16RangeFlag) { - value = clipVectorValueToFP16Range(value); - } else { - validateFP16VectorValue(value); - } - } else { - validateFloatVectorValue(value); - } - + value = perDimensionProcessor.process(context.parser().floatValue()); + perDimensionValidator.validate(value); vector.add(value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { - value = context.parser().floatValue(); - if (isFaissSQfp16Flag) { - if (clipVectorValueToFP16RangeFlag) { - value = clipVectorValueToFP16Range(value); - } else { - validateFP16VectorValue(value); - } - } else { - validateFloatVectorValue(value); - } + value = perDimensionProcessor.process(context.parser().floatValue()); + perDimensionValidator.validate(value); vector.add(value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { @@ -746,7 +681,12 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, indexCreatedVersion).init(this); + return new KNNVectorFieldMapper.Builder( + simpleName(), + modelDao, + indexCreatedVersion, + fieldType().getAnnConfig().getKnnMethodContext().orElse(null) + ).init(this); } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 2adbbb6953..9093074b2d 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -13,37 +13,52 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.util.BytesRef; -import org.opensearch.index.mapper.ParametrizedFieldMapper; +import org.opensearch.Version; +import org.opensearch.common.settings.Settings; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.KnnCircuitBreakerException; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.util.IndexHyperParametersUtil; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; import java.util.Arrays; import java.util.Locale; +import java.util.Map; +import java.util.Objects; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; +import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; -import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; /** * Utility class for KNNVectorFieldMapper */ +@Log4j2 @NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNVectorFieldMapperUtil { @@ -150,35 +165,6 @@ public static void validateVectorDataType(KNNMethodContext methodContext, Vector throw new IllegalArgumentException("This line should not be reached"); } - /** - * Validates and throws exception if index.knn is set to true in the index settings - * using any VectorDataType (other than float, which is default) because we are using NMSLIB engine - * for LegacyFieldMapper, and it only supports float VectorDataType - * - * @param knnIndexSetting index.knn setting in the index settings - * @param vectorDataType VectorDataType Parameter - */ - public static void validateVectorDataTypeWithKnnIndexSetting( - boolean knnIndexSetting, - ParametrizedFieldMapper.Parameter vectorDataType - ) { - - if (VectorDataType.FLOAT == vectorDataType.getValue()) { - return; - } - if (knnIndexSetting) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "[%s] field with value [%s] is not supported for [%s] engine", - VECTOR_DATA_TYPE_FIELD, - vectorDataType.getValue().getValue(), - NMSLIB_NAME - ) - ); - } - } - /** * @param knnEngine KNNEngine * @return DocValues FieldType of type Binary @@ -237,18 +223,18 @@ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataTy * @return expected vector length */ public static int getExpectedVectorLength(final KNNVectorFieldType knnVectorFieldType) { - int expectedDimensions = knnVectorFieldType.getDimension(); - if (isModelBasedIndex(expectedDimensions)) { + int expectedDimensions; + if (ANNConfigType.FROM_MODEL == knnVectorFieldType.getAnnConfig().getAnnConfigType()) { ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); expectedDimensions = modelMetadata.getDimension(); + } else { + expectedDimensions = knnVectorFieldType.getAnnConfig() + .getDimension() + .orElseThrow(() -> new IllegalArgumentException("Dimension not found")); } return VectorDataType.BINARY == knnVectorFieldType.getVectorDataType() ? expectedDimensions / 8 : expectedDimensions; } - private static boolean isModelBasedIndex(int expectedDimensions) { - return expectedDimensions == -1; - } - /** * Returns the model metadata for a specified knn vector field * @@ -256,18 +242,165 @@ private static boolean isModelBasedIndex(int expectedDimensions) { * @return the model metadata from knnVectorField */ private static ModelMetadata getModelMetadataForField(final KNNVectorFieldType knnVectorField) { - String modelId = knnVectorField.getModelId(); + ModelMetadata[] modelMetadata = new ModelMetadata[1]; + knnVectorField.getAnnConfig().getModelId().ifPresentOrElse(modelId -> { + modelMetadata[0] = modelDao.getMetadata(modelId); + if (!ModelUtil.isModelCreated(modelMetadata[0])) { + throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); + } + }, () -> { throw new IllegalArgumentException(String.format("Field '%s' does not have model.", knnVectorField.name())); }); + return modelMetadata[0]; + } - if (modelId == null) { - throw new IllegalArgumentException( - String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName()) + /** + * Validate if the circuit breaker is triggered + */ + public static void validateIfCircuitBreakerIsNotTriggered() { + if (KNNSettings.isCircuitBreakerTriggered()) { + throw new KnnCircuitBreakerException( + "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status." + ); + } + } + + /** + * Validate if plugin is enabled + */ + public static void validateIfKNNPluginEnabled() { + if (!KNNSettings.isKNNPluginEnabled()) { + throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled setting to true"); + } + } + + /** + * Get space type or default based on settings and data type + * + * @param indexSettings Settings + * @param vectorDataType VectorDataType + * @return SpaceType as string + */ + public static SpaceType getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) { + String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); + if (spaceType == null) { + spaceType = VectorDataType.BINARY == vectorDataType + ? KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY + : KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; + log.info( + String.format( + "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", + METHOD_PARAMETER_SPACE_TYPE, + spaceType + ) + ); + } + return SpaceType.getSpace(spaceType); + } + + /** + * Get M parameter + * + * @param indexSettings IndexSettings + * @return M as string + */ + public static int getM(Settings indexSettings) { + String m = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING.getKey()); + if (m == null) { + log.info( + String.format( + "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", + HNSW_ALGO_M, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M + ) + ); + return KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M; + } + return Integer.parseInt(m); + } + + /** + * Get EF_CONSTRUCTION parameter + * + * @param indexSettings IndexSettings + * @return EF_CONSTRUCTION as string + */ + public static int getEfConstruction(Settings indexSettings, Version indexVersion) { + final String efConstruction = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING.getKey()); + if (efConstruction == null) { + final int defaultEFConstructionValue = IndexHyperParametersUtil.getHNSWEFConstructionValue(indexVersion); + log.info( + String.format( + "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. " + + "Picking up default value for the index =%s", + HNSW_ALGO_EF_CONSTRUCTION, + defaultEFConstructionValue + ) + ); + return defaultEFConstructionValue; + } + return Integer.parseInt(efConstruction); + } + + /** + * Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16" + * + * @param methodComponentContext MethodComponentContext + * @return true if it is a "faiss" Index using "sq" encoder of type "fp16" + */ + public static boolean isFaissSQfp16(MethodComponentContext methodComponentContext) { + if (Objects.isNull(methodComponentContext)) { + return false; + } + + if (methodComponentContext.getParameters().size() == 0) { + return false; + } + + Map methodComponentParams = methodComponentContext.getParameters(); + + // The method component parameters should have an encoder + if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) { + return false; + } + + // Validate if the object is of type MethodComponentContext before casting it later + if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { + return false; + } + + MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER); + + // returns true if encoder name is "sq" and type is "fp16" + return ENCODER_SQ.equals(encoderMethodComponentContext.getName()) + && FAISS_SQ_ENCODER_FP16.equals( + encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) ); + + } + + /** + * Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index + * using "sq" encoder of type "fp16". + * + * @param methodComponentContext MethodComponentContext + * @return boolean value of "clip" parameter + */ + public static boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) { + if (Objects.nonNull(methodComponentContext)) { + return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false); } + return false; + } - ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (!ModelUtil.isModelCreated(modelMetadata)) { - throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); + /** + * Extract MethodComponentContext from KNNMethodContext + * + * @param knnMethodContext KNNMethodContext + * @return MethodComponentContext + */ + public static MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) { + if (Objects.isNull(knnMethodContext)) { + return null; } - return modelMetadata; + return knnMethodContext.getMethodComponentContext(); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index 8c3815c5f9..9097f03f88 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -9,7 +9,6 @@ import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; -import org.opensearch.common.Nullable; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.TextSearchInfo; @@ -17,9 +16,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; import org.opensearch.knn.index.KNNVectorIndexFieldData; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; @@ -27,7 +24,6 @@ import java.util.Map; import java.util.function.Supplier; -import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; /** @@ -35,49 +31,21 @@ */ @Getter public class KNNVectorFieldType extends MappedFieldType { - int dimension; - String modelId; - KNNMethodContext knnMethodContext; + ANNConfig annConfig; VectorDataType vectorDataType; - SpaceType spaceType; - public KNNVectorFieldType(String name, Map meta, int dimension, VectorDataType vectorDataType, SpaceType spaceType) { - this(name, meta, dimension, null, null, vectorDataType, spaceType); - } - - public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { - this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType()); - } - - public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { - this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null); - } - - public KNNVectorFieldType( - String name, - Map meta, - int dimension, - KNNMethodContext knnMethodContext, - VectorDataType vectorDataType - ) { - this(name, meta, dimension, knnMethodContext, null, vectorDataType, knnMethodContext.getSpaceType()); - } - - public KNNVectorFieldType( - String name, - Map meta, - int dimension, - @Nullable KNNMethodContext knnMethodContext, - @Nullable String modelId, - VectorDataType vectorDataType, - @Nullable SpaceType spaceType - ) { + /** + * Constructor for KNNVectorFieldType. + * + * @param name name of the field + * @param meta metadata of the field + * @param vectorDataType data type of the vector + * @param annConfig configuration context for the ANN index + */ + public KNNVectorFieldType(String name, Map meta, VectorDataType vectorDataType, ANNConfig annConfig) { super(name, false, false, true, TextSearchInfo.NONE, meta); - this.dimension = dimension; - this.modelId = modelId; - this.knnMethodContext = knnMethodContext; this.vectorDataType = vectorDataType; - this.spaceType = spaceType; + this.annConfig = annConfig; } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java deleted file mode 100644 index cf5ec933a1..0000000000 --- a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.mapper; - -import lombok.extern.log4j.Log4j2; -import org.apache.lucene.document.FieldType; -import org.opensearch.Version; -import org.opensearch.common.Explicit; -import org.opensearch.common.settings.Settings; -import org.opensearch.index.mapper.ParametrizedFieldMapper; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.util.IndexHyperParametersUtil; -import org.opensearch.knn.index.engine.KNNEngine; - -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; -import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; - -/** - * Field mapper for original implementation. It defaults to using nmslib as the engine and retrieves parameters from index settings. - * - * Example of this mapper output: - * - * { - * "type": "knn_vector", - * "dimension": 128 - * } - */ -@Log4j2 -public class LegacyFieldMapper extends KNNVectorFieldMapper { - - protected String spaceType; - protected String m; - protected String efConstruction; - - LegacyFieldMapper( - String simpleName, - KNNVectorFieldType mappedFieldType, - MultiFields multiFields, - CopyTo copyTo, - Explicit ignoreMalformed, - boolean stored, - boolean hasDocValues, - String spaceType, - String m, - String efConstruction, - Version indexCreatedVersion - ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion); - - this.spaceType = spaceType; - this.m = m; - this.efConstruction = efConstruction; - - this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - - this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension)); - this.fieldType.putAttribute(SPACE_TYPE, spaceType); - this.fieldType.putAttribute(KNN_ENGINE, KNNEngine.NMSLIB.getName()); - - // These are extra just for legacy - this.fieldType.putAttribute(HNSW_ALGO_M, m); - this.fieldType.putAttribute(HNSW_ALGO_EF_CONSTRUCTION, efConstruction); - - this.fieldType.freeze(); - } - - @Override - public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder(simpleName(), this.spaceType, this.m, this.efConstruction, this.indexCreatedVersion).init( - this - ); - } - - static String getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) { - String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); - if (spaceType == null) { - spaceType = VectorDataType.BINARY == vectorDataType - ? KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY - : KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; - log.info( - String.format( - "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", - METHOD_PARAMETER_SPACE_TYPE, - spaceType - ) - ); - } - return spaceType; - } - - static String getM(Settings indexSettings) { - String m = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING.getKey()); - if (m == null) { - log.info( - String.format( - "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", - HNSW_ALGO_M, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M - ) - ); - return String.valueOf(KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M); - } - return m; - } - - static String getEfConstruction(Settings indexSettings, Version indexVersion) { - final String efConstruction = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING.getKey()); - if (efConstruction == null) { - final String defaultEFConstructionValue = String.valueOf(IndexHyperParametersUtil.getHNSWEFConstructionValue(indexVersion)); - log.info( - String.format( - "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. " - + "Picking up default value for the index =%s", - HNSW_ALGO_EF_CONSTRUCTION, - defaultEFConstructionValue - ) - ); - return defaultEFConstructionValue; - } - return efConstruction; - } -} diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index c82afb9e72..08b83b13e3 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -16,11 +16,12 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.VectorSimilarityFunction; +import org.opensearch.Version; import org.opensearch.common.Explicit; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; @@ -35,6 +36,10 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { private final FieldType vectorFieldType; private final VectorDataType vectorDataType; + private PerDimensionProcessor perDimensionProcessor; + private PerDimensionValidator perDimensionValidator; + private VectorValidator vectorValidator; + LuceneFieldMapper(final CreateLuceneFieldMapperInput input) { super( input.getName(), @@ -44,16 +49,22 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { input.getIgnoreMalformed(), input.isStored(), input.isHasDocValues(), - input.getKnnMethodContext().getMethodComponentContext().getIndexVersion() + input.getIndexVersion(), + input.getOriginalKnnMethodContext() ); - + ANNConfig annConfig = input.mappedFieldType.getAnnConfig(); + KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext() + .orElseThrow(() -> new IllegalArgumentException("KNN method context is missing")); vectorDataType = input.getVectorDataType(); - this.knnMethod = input.getKnnMethodContext(); - final VectorSimilarityFunction vectorSimilarityFunction = this.knnMethod.getSpaceType() + + final VectorSimilarityFunction vectorSimilarityFunction = knnMethodContext.getSpaceType() .getKnnVectorSimilarityFunction() .getVectorSimilarityFunction(); - final int dimension = input.getMappedFieldType().getDimension(); + final int dimension = input.getMappedFieldType() + .getAnnConfig() + .getDimension() + .orElseThrow(() -> new IllegalArgumentException("Dimension is missing")); if (dimension > KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE)) { throw new IllegalArgumentException( String.format( @@ -69,10 +80,13 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { this.fieldType = vectorDataType.createKnnVectorFieldType(dimension, vectorSimilarityFunction); if (this.hasDocValues) { - this.vectorFieldType = buildDocValuesFieldType(this.knnMethod.getKnnEngine()); + this.vectorFieldType = buildDocValuesFieldType(knnMethodContext.getKnnEngine()); } else { this.vectorFieldType = null; } + + initValidatorsAndProcessors(knnMethodContext); + knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); } @Override @@ -105,6 +119,36 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fieldsToBeAdded; } + private void initValidatorsAndProcessors(KNNMethodContext knnMethodContext) { + this.vectorValidator = new SpaceVectorValidator(knnMethodContext.getSpaceType()); + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + if (VectorDataType.BINARY == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + return; + } + + if (VectorDataType.BYTE == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + return; + } + this.perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + } + + @Override + protected VectorValidator getVectorValidator() { + return vectorValidator; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return perDimensionProcessor; + } + @Override void updateEngineStats() { KNNEngine.LUCENE.setInitialized(true); @@ -127,7 +171,7 @@ static class CreateLuceneFieldMapperInput { boolean stored; boolean hasDocValues; VectorDataType vectorDataType; - @NonNull - KNNMethodContext knnMethodContext; + Version indexVersion; + KNNMethodContext originalKnnMethodContext; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index b15ab14894..6d53109ce7 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -6,25 +6,36 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; import java.io.IOException; import java.util.Map; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.getMethodComponentContext; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQClipToFP16RangeEnabled; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQfp16; /** * Field mapper for method definition in mapping */ public class MethodFieldMapper extends KNNVectorFieldMapper { + private PerDimensionProcessor perDimensionProcessor; + private PerDimensionValidator perDimensionValidator; + private VectorValidator vectorValidator; + MethodFieldMapper( String simpleName, KNNVectorFieldType mappedFieldType, @@ -33,7 +44,8 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - KNNMethodContext knnMethodContext + Version indexVerision, + KNNMethodContext originalKNNMethodContext ) { super( @@ -44,14 +56,18 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { ignoreMalformed, stored, hasDocValues, - knnMethodContext.getMethodComponentContext().getIndexVersion() + indexVerision, + originalKNNMethodContext ); - - this.knnMethod = knnMethodContext; - + ANNConfig annConfig = mappedFieldType.getAnnConfig(); + KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext() + .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension)); + this.fieldType.putAttribute( + DIMENSION, + String.valueOf(annConfig.getDimension().orElseThrow(() -> new IllegalArgumentException("Dimension cannot be empty"))) + ); this.fieldType.putAttribute(SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); @@ -66,5 +82,57 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { } this.fieldType.freeze(); + initValidatorsAndProcessors(knnMethodContext); + knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); + } + + private void initValidatorsAndProcessors(KNNMethodContext knnMethodContext) { + this.vectorValidator = new SpaceVectorValidator(knnMethodContext.getSpaceType()); + + if (VectorDataType.BINARY == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + if (VectorDataType.BYTE == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + MethodComponentContext methodComponentContext = getMethodComponentContext(knnMethodContext); + if (!isFaissSQfp16(methodComponentContext)) { + // Normal float and byte processor + this.perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + this.perDimensionValidator = PerDimensionValidator.DEFAULT_FP16_VALIDATOR; + + if (!isFaissSQClipToFP16RangeEnabled( + (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) + )) { + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + this.perDimensionProcessor = PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR; + } + + @Override + protected VectorValidator getVectorValidator() { + return vectorValidator; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return perDimensionProcessor; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index adaaef28e6..07c363e4ed 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -9,19 +9,41 @@ import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; import java.io.IOException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQClipToFP16RangeEnabled; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQfp16; /** * Field mapper for model in mapping */ public class ModelFieldMapper extends KNNVectorFieldMapper { + // If the dimension has not yet been set because we do not have access to model metadata, it will be -1 + public static final int UNSET_MODEL_DIMENSION_IDENTIFIER = -1; + + private final AtomicReference spaceType; + private final AtomicReference methodComponentContext; + private final AtomicInteger dimension; + private final AtomicReference vectorDataType; + + private final AtomicReference perDimensionProcessor; + private final AtomicReference perDimensionValidator; + private final AtomicReference vectorValidator; + + private final String modelId; + ModelFieldMapper( String simpleName, KNNVectorFieldType mappedFieldType, @@ -31,24 +53,32 @@ public class ModelFieldMapper extends KNNVectorFieldMapper { boolean stored, boolean hasDocValues, ModelDao modelDao, - String modelId, Version indexCreatedVersion ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion); - - this.modelId = modelId; + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null); + ANNConfig annConfig = mappedFieldType.getAnnConfig(); + modelId = annConfig.getModelId().orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); this.modelDao = modelDao; + this.spaceType = new AtomicReference<>(null); + this.methodComponentContext = new AtomicReference<>(null); + this.dimension = new AtomicInteger(UNSET_MODEL_DIMENSION_IDENTIFIER); + this.vectorDataType = new AtomicReference<>(null); + this.perDimensionProcessor = new AtomicReference<>(null); + this.perDimensionValidator = new AtomicReference<>(null); + this.vectorValidator = new AtomicReference<>(null); + this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.putAttribute(MODEL_ID, modelId); this.fieldType.freeze(); } @Override - protected void parseCreateField(ParseContext context) throws IOException { + protected void validatePreparse() { + super.validatePreparse(); // For the model field mapper, we cannot validate the model during index creation due to // an issue with reading cluster state during mapper creation. So, we need to validate the - // model when ingestion starts. + // model when ingestion starts. We do this as lazily as we can ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) { @@ -56,18 +86,76 @@ protected void parseCreateField(ParseContext context) throws IOException { String.format( "Model \"%s\" from %s's mapping is not created. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.", modelId, - context.mapperService().index().getName(), + simpleName(), MODEL_ID ) ); } - parseCreateField( - context, - modelMetadata.getDimension(), - modelMetadata.getSpaceType(), - modelMetadata.getMethodComponentContext(), - modelMetadata.getVectorDataType() - ); + maybeInitLazyVariables(modelMetadata); + } + + private void maybeInitLazyVariables(ModelMetadata modelMetadata) { + vectorDataType.compareAndExchange(null, modelMetadata.getVectorDataType()); + if (spaceType.get() == null) { + spaceType.compareAndExchange(null, modelMetadata.getSpaceType()); + spaceType.get().validateVectorDataType(vectorDataType.get()); + } + methodComponentContext.compareAndExchange(null, modelMetadata.getMethodComponentContext()); + dimension.compareAndExchange(UNSET_MODEL_DIMENSION_IDENTIFIER, modelMetadata.getDimension()); + maybeInitValidatorsAndProcessors(); + } + + private void maybeInitValidatorsAndProcessors() { + this.vectorValidator.compareAndExchange(null, new SpaceVectorValidator(spaceType.get())); + + if (VectorDataType.BINARY == vectorDataType.get()) { + this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_BIT_VALIDATOR); + this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR); + return; + } + + if (VectorDataType.BYTE == vectorDataType.get()) { + this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_BYTE_VALIDATOR); + this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR); + return; + } + + if (!isFaissSQfp16(methodComponentContext.get())) { + // Normal float and byte processor + this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR); + this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR); + return; + } + + this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_FP16_VALIDATOR); + if (!isFaissSQClipToFP16RangeEnabled( + (MethodComponentContext) methodComponentContext.get().getParameters().get(METHOD_ENCODER_PARAMETER) + )) { + this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR); + return; + } + this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR); + } + + @Override + protected VectorValidator getVectorValidator() { + return vectorValidator.get(); + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator.get(); + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return perDimensionProcessor.get(); + } + + @Override + protected void parseCreateField(ParseContext context) throws IOException { + validatePreparse(); + parseCreateField(context, dimension.get(), vectorDataType.get()); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java new file mode 100644 index 0000000000..3b927f2df1 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; + +/** + * Process values per dimension. Good to have if we want to do some kind of cleanup on data as it is coming in. + */ +public interface PerDimensionProcessor { + + /** + * Process float value per dimension. + * + * @param value value to process + * @return processed value + */ + float process(float value); + + /** + * Process byte as float value per dimension. + * + * @param value value to process + * @return processed value + */ + float processByte(float value); + + PerDimensionProcessor NOOP_PROCESSOR = new PerDimensionProcessor() { + @Override + public float process(float value) { + return value; + } + + @Override + public float processByte(float value) { + return value; + } + }; + + // If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be + // clipped to FP16 range. + PerDimensionProcessor CLIP_TO_FP16_PROCESSOR = new PerDimensionProcessor() { + @Override + public float process(float value) { + return clipVectorValueToFP16Range(value); + } + + @Override + public float processByte(float value) { + throw new IllegalStateException("CLIP_TO_FP16_PROCESSOR should not be called with byte type"); + } + }; +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java new file mode 100644 index 0000000000..09416bc14a --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.index.VectorDataType; + +import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; +import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; + +/** + * Validates per dimension fields + */ +public interface PerDimensionValidator { + /** + * Validates the given float is valid for the configuration + * + * @param value to validate + */ + void validate(float value); + + /** + * Validates the given float as a byte is valid for the configuration. + * + * @param value to validate + */ + void validateByte(float value); + + PerDimensionValidator DEFAULT_FLOAT_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + validateFloatVectorValue(value); + } + + @Override + public void validateByte(float value) { + throw new IllegalStateException("DEFAULT_FLOAT_VALIDATOR should only be used for float vectors"); + } + }; + + // Validates if it is a finite number and within the fp16 range of [-65504 to 65504]. + PerDimensionValidator DEFAULT_FP16_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + validateFP16VectorValue(value); + } + + @Override + public void validateByte(float value) { + throw new IllegalStateException("DEFAULT_FP16_VALIDATOR should only be used for float vectors"); + } + }; + + PerDimensionValidator DEFAULT_BYTE_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + throw new IllegalStateException("DEFAULT_BYTE_VALIDATOR should only be used for byte values"); + } + + @Override + public void validateByte(float value) { + validateByteVectorValue(value, VectorDataType.BYTE); + } + }; + + PerDimensionValidator DEFAULT_BIT_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + throw new IllegalStateException("DEFAULT_BIT_VALIDATOR should only be used for byte values"); + } + + @Override + public void validateByte(float value) { + validateByteVectorValue(value, VectorDataType.BINARY); + } + }; +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/SpaceVectorValidator.java b/src/main/java/org/opensearch/knn/index/mapper/SpaceVectorValidator.java new file mode 100644 index 0000000000..6ff088604c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/SpaceVectorValidator.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.AllArgsConstructor; +import org.opensearch.knn.index.SpaceType; + +/** + * Confirms that a given vector is valid for the provided space type + */ +@AllArgsConstructor +public class SpaceVectorValidator implements VectorValidator { + + private final SpaceType spaceType; + + @Override + public void validateVector(byte[] vector) { + spaceType.validateVector(vector); + } + + @Override + public void validateVector(float[] vector) { + spaceType.validateVector(vector); + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorValidator.java b/src/main/java/org/opensearch/knn/index/mapper/VectorValidator.java new file mode 100644 index 0000000000..86607d3439 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorValidator.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +/** + * Class validates vector after it has been parsed + */ +public interface VectorValidator { + /** + * Validate if the given byte vector is supported + * + * @param vector the given vector + */ + void validateVector(byte[] vector); + + /** + * Validate if the given float vector is supported + * + * @param vector the given vector + */ + void validateVector(float[] vector); + + VectorValidator NOOP_VECTOR_VALIDATOR = new VectorValidator() { + @Override + public void validateVector(byte[] vector) {} + + @Override + public void validateVector(float[] vector) {} + }; +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 6d57cb2dd5..15e2b34a76 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -24,6 +24,8 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.engine.model.QueryContext; +import org.opensearch.knn.index.mapper.ANNConfig; +import org.opensearch.knn.index.mapper.ANNConfigType; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -341,36 +343,49 @@ protected Query doToQuery(QueryShardContext context) { if (!(mappedFieldType instanceof KNNVectorFieldType)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not knn_vector type.", this.fieldName)); } - KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) mappedFieldType; - int fieldDimension = knnVectorFieldType.getDimension(); - KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext(); - MethodComponentContext methodComponentContext = null; - KNNEngine knnEngine = KNNEngine.DEFAULT; - VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType(); - SpaceType spaceType = knnVectorFieldType.getSpaceType(); - VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); - updateQueryStats(vectorQueryType); - - if (fieldDimension == -1) { - if (spaceType != null) { - throw new IllegalStateException("Space type should be null when the field uses a model"); - } - // If dimension is not set, the field uses a model and the information needs to be retrieved from there - ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); + ANNConfig annConfig = knnVectorFieldType.getAnnConfig(); + if (ANNConfigType.SKIP == annConfig.getAnnConfigType()) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not built for ANN search.", this.fieldName)); + } + + int fieldDimension; + KNNEngine knnEngine; + MethodComponentContext methodComponentContext; + SpaceType spaceType; + VectorDataType vectorDataType; + if (ANNConfigType.FROM_MODEL == annConfig.getAnnConfigType()) { + String modelId = annConfig.getModelId() + .orElseThrow( + () -> new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' does not have model.", this.fieldName)) + ); + ModelMetadata modelMetadata = getModelMetadataForField(modelId); fieldDimension = modelMetadata.getDimension(); knnEngine = modelMetadata.getKnnEngine(); spaceType = modelMetadata.getSpaceType(); methodComponentContext = modelMetadata.getMethodComponentContext(); vectorDataType = modelMetadata.getVectorDataType(); - - } else if (knnMethodContext != null) { - // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping - knnEngine = knnMethodContext.getKnnEngine(); - spaceType = knnMethodContext.getSpaceType(); + } else { + fieldDimension = annConfig.getDimension() + .orElseThrow( + () -> new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' does not have dimension.", this.fieldName)) + ); + vectorDataType = knnVectorFieldType.getVectorDataType(); + KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext() + .orElseThrow( + () -> new IllegalArgumentException( + String.format(Locale.ROOT, "Field '%s' does not have method definition.", this.fieldName) + ) + ); methodComponentContext = knnMethodContext.getMethodComponentContext(); + spaceType = knnMethodContext.getSpaceType(); + knnEngine = knnMethodContext.getKnnEngine(); } + VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); + updateQueryStats(vectorQueryType); + + // This could be null in the case of when a model did not have serialized methodComponent information final String method = methodComponentContext != null ? methodComponentContext.getName() : null; if (StringUtils.isNotBlank(method)) { final KNNLibrarySearchContext engineSpecificMethodContext = knnEngine.getKNNLibrarySearchContext(method); @@ -492,13 +507,7 @@ protected Query doToQuery(QueryShardContext context) { throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k or distance or score to be set", NAME)); } - private ModelMetadata getModelMetadataForField(KNNVectorFieldType knnVectorField) { - String modelId = knnVectorField.getModelId(); - - if (modelId == null) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' does not have model.", this.fieldName)); - } - + private ModelMetadata getModelMetadataForField(String modelId) { ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId)); diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 56c129546f..8b3d0c7e04 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -7,12 +7,17 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.Version; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNLibrarySearchContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.core.common.bytes.BytesReference; @@ -20,12 +25,14 @@ import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.test.OpenSearchTestCase; +import java.util.Collections; import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; /** * Base class for integration tests for KNN plugin. Contains several methods for testing KNN ES functionality. @@ -91,4 +98,18 @@ private void initKNNSettings() { public Map xContentBuilderToMap(XContentBuilder xContentBuilder) { return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2(); } + + public static synchronized KNNMethodContext getDefaultKNNMethodContext() { + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodContext defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); + methodComponentContext.setIndexVersion(Version.CURRENT); + return defaultInstance; + } + + public static synchronized KNNMethodContext getDefaultBinaryKNNMethodContext() { + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodContext defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT_BINARY, methodComponentContext); + methodComponentContext.setIndexVersion(Version.CURRENT); + return defaultInstance; + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java index 9a867c58dd..f71fbaae0f 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java @@ -94,7 +94,7 @@ public void testGetSpaceType() { */ public void testValidate() { // Check valid default - this should not throw any exception - assertNull(KNNMethodContext.getDefault().validate()); + assertNull(getDefaultKNNMethodContext().validate()); // Check a valid nmslib method MethodComponentContext hnswMethod = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index e1ebc57085..e87531561a 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -251,62 +251,6 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } - public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException { - // Set information about the segment and the fields - String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); - int docsInSegment = 100; - String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); - - KNNEngine knnEngine = KNNEngine.NMSLIB; - SpaceType spaceType = SpaceType.COSINESIMIL; - int dimension = 16; - - SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() - .directory(directory) - .segmentName(segmentName) - .docsInSegment(docsInSegment) - .codec(codec) - .build(); - - FieldInfo[] fieldInfoArray = new FieldInfo[] { - KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) - .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") - .addAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512") - .addAttribute(KNNConstants.HNSW_ALGO_M, "16") - .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) - .build() }; - - FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); - SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); - long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); - - // Add documents to the field - KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); - TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( - docsInSegment, - dimension - ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); - - // The document should be created in the correct location - String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); - assertFileInCorrectLocation(state, expectedFile); - - // The footer should be valid - assertValidFooter(state.directory, expectedFile); - - // The document should be readable by nmslib - assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); - - // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); - assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); - } - public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException { String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); int docsInSegment = 100; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index a0b9b32d0e..0c03a93941 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -14,8 +14,10 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; +import org.opensearch.Version; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; @@ -25,6 +27,8 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; +import org.opensearch.knn.index.mapper.ANNConfig; +import org.opensearch.knn.index.mapper.ANNConfigType; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.query.KNNQueryFactory; import org.opensearch.knn.jni.JNIService; @@ -56,6 +60,7 @@ import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -77,6 +82,8 @@ import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; @@ -86,14 +93,28 @@ public class KNNCodecTestCase extends KNNTestCase { private static final Codec ACTUAL_CODEC = KNNCodecVersion.current().getDefaultKnnCodecSupplier().get(); - private static FieldType sampleFieldType; + private static final FieldType sampleFieldType; static { + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); + String parameterString; + try { + parameterString = XContentFactory.jsonBuilder() + .map(knnMethodContext.getKnnEngine().getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()) + .toString(); + } catch (IOException e) { + throw new RuntimeException(e); + } + sampleFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - sampleFieldType.putAttribute(KNNConstants.KNN_METHOD, KNNConstants.METHOD_HNSW); - sampleFieldType.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()); - sampleFieldType.putAttribute(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()); - sampleFieldType.putAttribute(KNNConstants.HNSW_ALGO_M, "32"); - sampleFieldType.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512"); + sampleFieldType.putAttribute(KNNVectorFieldMapper.KNN_FIELD, "true"); + sampleFieldType.putAttribute(KNNConstants.KNN_ENGINE, knnMethodContext.getKnnEngine().getName()); + sampleFieldType.putAttribute(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); + sampleFieldType.putAttribute(KNNConstants.PARAMETERS, parameterString); sampleFieldType.freeze(); } private static final String FIELD_NAME_ONE = "test_vector_one"; @@ -309,8 +330,19 @@ public void testKnnVectorIndex( SpaceType.L2, new MethodComponentContext(METHOD_HNSW, Map.of(HNSW_ALGO_M, 16, HNSW_ALGO_EF_CONSTRUCTION, 256)) ); - final KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldType(FIELD_NAME_ONE, Map.of(), 3, knnMethodContext); - final KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldType(FIELD_NAME_TWO, Map.of(), 2, knnMethodContext); + + final KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 3) + ); + final KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 2) + ); when(mapperService.fieldType(eq(FIELD_NAME_ONE))).thenReturn(mappedFieldType1); when(mapperService.fieldType(eq(FIELD_NAME_TWO))).thenReturn(mappedFieldType2); diff --git a/src/test/java/org/opensearch/knn/index/mapper/ANNConfigTests.java b/src/test/java/org/opensearch/knn/index/mapper/ANNConfigTests.java new file mode 100644 index 0000000000..c29967a104 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/mapper/ANNConfigTests.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.KNNTestCase; + +public class ANNConfigTests extends KNNTestCase { + + public void testANNConfigIsSkip() { + ANNConfigType annConfigType = ANNConfigType.SKIP; + int dimension = 4; + expectThrows(IllegalArgumentException.class, () -> new ANNConfig(annConfigType, null, "test-model", null)); + expectThrows(IllegalArgumentException.class, () -> new ANNConfig(annConfigType, getDefaultKNNMethodContext(), null, dimension)); + + ANNConfig annConfig = new ANNConfig(annConfigType, null, null, dimension); + assertTrue(annConfig.getModelId().isEmpty()); + assertTrue(annConfig.getKnnMethodContext().isEmpty()); + assertEquals(dimension, annConfig.getDimension().get().intValue()); + } + + public void testANNConfigIsFromModel() { + ANNConfigType annConfigType = ANNConfigType.FROM_MODEL; + expectThrows(IllegalArgumentException.class, () -> new ANNConfig(annConfigType, getDefaultKNNMethodContext(), "test-model", 4)); + expectThrows(IllegalArgumentException.class, () -> new ANNConfig(annConfigType, getDefaultKNNMethodContext(), null, null)); + + ANNConfig annConfig = new ANNConfig(annConfigType, null, "test-model", null); + assertTrue(annConfig.getModelId().isPresent()); + assertTrue(annConfig.getKnnMethodContext().isEmpty()); + assertTrue(annConfig.getDimension().isEmpty()); + } + + public void testANNConfigIsFromMethod() { + ANNConfigType annConfigType = ANNConfigType.FROM_METHOD; + int dimension = 4; + expectThrows(IllegalArgumentException.class, () -> new ANNConfig(annConfigType, null, "test-model", null)); + expectThrows(IllegalArgumentException.class, () -> new ANNConfig(annConfigType, getDefaultKNNMethodContext(), "test-model", 4)); + expectThrows(IllegalArgumentException.class, () -> new ANNConfig(annConfigType, getDefaultKNNMethodContext(), null, null)); + + ANNConfig annConfig = new ANNConfig(annConfigType, getDefaultKNNMethodContext(), null, dimension); + assertTrue(annConfig.getModelId().isEmpty()); + assertTrue(annConfig.getKnnMethodContext().isPresent()); + assertEquals(dimension, annConfig.getDimension().get().intValue()); + } + +} diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index c95568be22..1703853849 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -103,7 +103,7 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { public void testBuilder_getParameters() { String fieldName = "test-field-name"; ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT, null); assertEquals(7, builder.getParameters().size()); List actualParams = builder.getParameters().stream().map(a -> a.name).collect(Collectors.toList()); @@ -114,7 +114,7 @@ public void testBuilder_getParameters() { public void testBuilder_build_fromKnnMethodContext() { // Check that knnMethodContext takes precedent over both model and legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -126,6 +126,7 @@ public void testBuilder_build_fromKnnMethodContext() { .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) .put(KNNSettings.KNN_ALGO_PARAM_M, m) .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .put(KNN_INDEX, true) .build(); builder.knnMethodContext.setValue( @@ -139,19 +140,17 @@ public void testBuilder_build_fromKnnMethodContext() { ) ); - builder.modelId.setValue("Random modelId"); - Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); - assertNotNull(knnVectorFieldMapper.knnMethod); - assertNull(knnVectorFieldMapper.modelId); + assertTrue(knnVectorFieldMapper.fieldType().getAnnConfig().getKnnMethodContext().isPresent()); + assertTrue(knnVectorFieldMapper.fieldType().getAnnConfig().getModelId().isEmpty()); } public void testBuilder_build_fromModel() { // Check that modelContext takes precedent over legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -163,6 +162,7 @@ public void testBuilder_build_fromModel() { .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) .put(KNNSettings.KNN_ALGO_PARAM_M, m) .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .put(KNN_INDEX, true) .build(); String modelId = "Random modelId"; @@ -184,14 +184,14 @@ public void testBuilder_build_fromModel() { when(modelDao.getMetadata(modelId)).thenReturn(mockedModelMetadata); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof ModelFieldMapper); - assertNotNull(knnVectorFieldMapper.modelId); - assertNull(knnVectorFieldMapper.knnMethod); + assertTrue(knnVectorFieldMapper.fieldType().getAnnConfig().getModelId().isPresent()); + assertTrue(knnVectorFieldMapper.fieldType().getAnnConfig().getKnnMethodContext().isEmpty()); } public void testBuilder_build_fromLegacy() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); int m = 17; int efConstruction = 17; @@ -201,37 +201,22 @@ public void testBuilder_build_fromLegacy() { .put(settings(CURRENT).build()) .put(KNNSettings.KNN_ALGO_PARAM_M, m) .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .put(KNN_INDEX, true) .build(); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper); - assertNull(knnVectorFieldMapper.modelId); - assertNull(knnVectorFieldMapper.knnMethod); - assertEquals(SpaceType.L2.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType); - } - - public void testBuilder_whenKnnFalseWithBinary_thenSetHammingAsDefault() { - // Check legacy is picked up if model context and method context are not set - ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); - builder.vectorDataType.setValue(VectorDataType.BINARY); - builder.dimension.setValue(8); - - // Setup settings - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); - - Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); - KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper); - assertEquals(SpaceType.HAMMING.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType); + assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper.fieldType().getAnnConfig().getKnnMethodContext().isPresent()); + assertTrue(knnVectorFieldMapper.fieldType().getAnnConfig().getModelId().isEmpty()); + assertEquals(SpaceType.L2, knnVectorFieldMapper.fieldType().getAnnConfig().getKnnMethodContext().get().getSpaceType()); } public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -616,7 +601,7 @@ public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOExcep String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -646,12 +631,18 @@ public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOExcep // merge with itself - should be successful KNNVectorFieldMapper knnVectorFieldMapperMerge1 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper1); - assertEquals(knnVectorFieldMapper1.knnMethod, knnVectorFieldMapperMerge1.knnMethod); + assertEquals( + knnVectorFieldMapper1.fieldType().getAnnConfig().getKnnMethodContext().get(), + knnVectorFieldMapperMerge1.fieldType().getAnnConfig().getKnnMethodContext().get() + ); // merge with another mapper of the same field with same context KNNVectorFieldMapper knnVectorFieldMapper2 = builder.build(builderContext); KNNVectorFieldMapper knnVectorFieldMapperMerge2 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper2); - assertEquals(knnVectorFieldMapper1.knnMethod, knnVectorFieldMapperMerge2.knnMethod); + assertEquals( + knnVectorFieldMapper1.fieldType().getAnnConfig().getKnnMethodContext().get(), + knnVectorFieldMapperMerge2.fieldType().getAnnConfig().getKnnMethodContext().get() + ); // merge with another mapper of the same field with different context xContentBuilder = XContentFactory.jsonBuilder() @@ -676,7 +667,7 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); String modelId = "test-id"; int dimension = 133; @@ -715,12 +706,18 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { // merge with itself - should be successful KNNVectorFieldMapper knnVectorFieldMapperMerge1 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper1); - assertEquals(knnVectorFieldMapper1.modelId, knnVectorFieldMapperMerge1.modelId); + assertEquals( + knnVectorFieldMapper1.fieldType().getAnnConfig().getModelId().get(), + knnVectorFieldMapperMerge1.fieldType().getAnnConfig().getModelId().get() + ); // merge with another mapper of the same field with same context KNNVectorFieldMapper knnVectorFieldMapper2 = builder.build(builderContext); KNNVectorFieldMapper knnVectorFieldMapperMerge2 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper2); - assertEquals(knnVectorFieldMapper1.modelId, knnVectorFieldMapperMerge2.modelId); + assertEquals( + knnVectorFieldMapper1.fieldType().getAnnConfig().getModelId().get(), + knnVectorFieldMapperMerge2.fieldType().getAnnConfig().getModelId().get() + ); // merge with another mapper of the same field with different context xContentBuilder = XContentFactory.jsonBuilder() @@ -755,17 +752,9 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { when(parseContext.path()).thenReturn(contentPath); LuceneFieldMapper luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); - doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper) - .getFloatsFromContext(parseContext, TEST_DIMENSION, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.FLOAT - ); + doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + doNothing().when(luceneFieldMapper).validatePreparse(); + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnVectorField List fields = document.getFields(); @@ -799,18 +788,10 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { inputBuilder.hasDocValues(false); luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); - doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper) - .getFloatsFromContext(parseContext, TEST_DIMENSION, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.FLOAT - ); + doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + doNothing().when(luceneFieldMapper).validatePreparse(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); // Document should have 1 field: one for KnnVectorField fields = document.getFields(); @@ -837,16 +818,9 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { LuceneFieldMapper luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) .getBytesFromContext(parseContext, TEST_DIMENSION, VectorDataType.BYTE); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.BYTE - ); + doNothing().when(luceneFieldMapper).validatePreparse(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.BYTE); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnByteVectorField List fields = document.getFields(); @@ -881,16 +855,9 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) .getBytesFromContext(parseContext, TEST_DIMENSION, VectorDataType.BYTE); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.BYTE - ); + doNothing().when(luceneFieldMapper).validatePreparse(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.BYTE); // Document should have 1 field: one for KnnByteVectorField fields = document.getFields(); @@ -970,10 +937,10 @@ private void testBuilderWithBinaryDataType( String expectedErrMsg ) { ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); // Setup settings - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); builder.knnMethodContext.setValue( new KNNMethodContext(knnEngine, spaceType, new MethodComponentContext(method, Collections.emptyMap())) @@ -986,7 +953,7 @@ private void testBuilderWithBinaryDataType( KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); if (SpaceType.UNDEFINED == spaceType) { - assertEquals(SpaceType.HAMMING, knnVectorFieldMapper.fieldType().spaceType); + assertEquals(SpaceType.HAMMING, knnVectorFieldMapper.fieldType().getAnnConfig().getKnnMethodContext().get().getSpaceType()); } } else { Exception ex = expectThrows(Exception.class, () -> builder.build(builderContext)); @@ -996,7 +963,7 @@ private void testBuilderWithBinaryDataType( public void testBuilder_whenBinaryFaissHNSWWithSQ_thenException() { ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); // Setup settings Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); @@ -1022,7 +989,7 @@ public void testBuilder_whenBinaryFaissHNSWWithSQ_thenException() { public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1031,13 +998,13 @@ public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper); + assertTrue(knnVectorFieldMapper instanceof FlatVectorFieldMapper); } public void testBuilder_whenBinaryWithLegacyKNNEnabled_thenException() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1061,9 +1028,8 @@ private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperIn KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType( TEST_FIELD_NAME, Collections.emptyMap(), - TEST_DIMENSION, - knnMethodContext, - vectorDataType + vectorDataType, + new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, TEST_DIMENSION) ); return LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() @@ -1074,7 +1040,7 @@ private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperIn .hasDocValues(true) .vectorDataType(vectorDataType) .ignoreMalformed(new Explicit<>(true, true)) - .knnMethodContext(knnMethodContext); + .originalKnnMethodContext(knnVectorFieldType.getAnnConfig().getKnnMethodContext().orElse(null)); } private static float[] createInitializedFloatArray(int dimension, float value) { diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index 31da12d669..32cf05fd54 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -61,16 +61,19 @@ public void testStoredFields_whenVectorIsFloatType_thenSucceed() { public void testGetExpectedVectorLengthSuccess() { KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); - + when(knnVectorFieldType.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultKNNMethodContext(), null, 3)); KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeBinary.getDimension()).thenReturn(8); + when(knnVectorFieldTypeBinary.getAnnConfig()).thenReturn( + new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultBinaryKNNMethodContext(), null, 8) + ); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); + when(knnVectorFieldTypeModelBased.getAnnConfig()).thenReturn( + new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultBinaryKNNMethodContext(), null, 8) + ); String modelId = "test-model"; - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); + when(knnVectorFieldTypeModelBased.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_MODEL, null, modelId, null)); ModelDao modelDao = mock(ModelDao.class); ModelMetadata modelMetadata = mock(ModelMetadata.class); @@ -87,9 +90,8 @@ public void testGetExpectedVectorLengthSuccess() { public void testGetExpectedVectorLengthFailure() { KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); String modelId = "test-model"; - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); + when(knnVectorFieldTypeModelBased.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_MODEL, null, modelId, null)); ModelDao modelDao = mock(ModelDao.class); ModelMetadata modelMetadata = mock(ModelMetadata.class); @@ -103,20 +105,6 @@ public void testGetExpectedVectorLengthFailure() { () -> KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldTypeModelBased) ); assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage()); - - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null); - KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - MethodComponentContext methodComponentContext = mock(MethodComponentContext.class); - String fieldName = "test-field"; - when(methodComponentContext.getName()).thenReturn(fieldName); - when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); - when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext); - - e = expectThrows( - IllegalArgumentException.class, - () -> KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldTypeModelBased) - ); - assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage()); } public void testValidateVectorDataType_whenBinaryFaissHNSW_thenValid() { diff --git a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java index dcd2557405..5c9b374cea 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java @@ -5,33 +5,35 @@ package org.opensearch.knn.index.mapper; -import junit.framework.TestCase; +import org.opensearch.Version; import org.opensearch.index.mapper.FieldMapper; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; import java.util.Collections; -public class MethodFieldMapperTests extends TestCase { - public void testMethodFieldMapper_whenVectorDataTypeIsGiven_thenSetItInFieldType() { +public class MethodFieldMapperTests extends KNNTestCase { + public void testMethodFieldMapper_whenVectorDataTypeAndContextMismatch_thenThrow() { KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( "testField", Collections.emptyMap(), - 1, VectorDataType.BINARY, - SpaceType.HAMMING + new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultKNNMethodContext(), null, 1) ); - MethodFieldMapper mappers = new MethodFieldMapper( - "simpleName", - mappedFieldType, - null, - new FieldMapper.CopyTo.Builder().build(), - KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED, - true, - true, - KNNMethodContext.getDefault() + // Expect that we cannot create the mapper with an invalid field type + expectThrows( + IllegalArgumentException.class, + () -> new MethodFieldMapper( + "simpleName", + mappedFieldType, + null, + new FieldMapper.CopyTo.Builder().build(), + KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED, + true, + true, + Version.CURRENT, + mappedFieldType.getAnnConfig().getKnnMethodContext().orElse(null) + ) ); - assertEquals(VectorDataType.BINARY, mappers.fieldType().vectorDataType); } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 0b918bd9ed..5232655e8e 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -29,6 +29,8 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.mapper.ANNConfig; +import org.opensearch.knn.index.mapper.ANNConfigType; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.util.KNNClusterUtil; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -185,9 +187,8 @@ public void testDoToQuery_Normal() throws Exception { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultKNNMethodContext(), null, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertEquals(knnQueryBuilder.getK(), query.getK()); @@ -207,7 +208,6 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( @@ -215,7 +215,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); float resultSimilarity = KNNEngine.LUCENE.distanceToRadialThreshold(MAX_DISTANCE, SpaceType.L2); @@ -239,7 +239,6 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( @@ -247,7 +246,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertTrue(query.toString().contains("resultSimilarity=" + 0.5f)); } @@ -266,16 +265,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSuppor QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -298,16 +295,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -325,16 +320,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSuppor QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -352,16 +345,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -383,16 +374,14 @@ public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSu QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -416,16 +405,14 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -444,7 +431,6 @@ public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(8); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( @@ -452,7 +438,7 @@ public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() { ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.HAMMING, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 8)); Exception e = expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); assertTrue(e.getMessage().contains("Binary data type does not support radial search")); } @@ -470,15 +456,13 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); // When @@ -504,14 +488,13 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); assertNotNull(query); @@ -531,14 +514,13 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); assertNotNull(query); @@ -553,15 +535,13 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); // When @@ -586,10 +566,12 @@ public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParamete QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - when(mockKNNVectorField.getDimension()).thenReturn(4); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.LUCENE, SpaceType.COSINESIMIL, new MethodComponentContext("hnsw", Map.of())) + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.COSINESIMIL, + new MethodComponentContext("hnsw", Map.of()) ); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() @@ -609,15 +591,13 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } @@ -632,11 +612,9 @@ public void testDoToQuery_FromModel() { when(mockQueryShardContext.index()).thenReturn(dummyIndex); // Dimension is -1. In this case, model metadata will need to provide dimension - when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; - when(mockKNNVectorField.getModelId()).thenReturn(modelId); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_MODEL, null, modelId, null)); // Mock the modelDao to return mocked modelMetadata ModelMetadata modelMetadata = mock(ModelMetadata.class); @@ -672,11 +650,9 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; - when(mockKNNVectorField.getModelId()).thenReturn(modelId); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_MODEL, null, modelId, null)); ModelMetadata modelMetadata = mock(ModelMetadata.class); when(modelMetadata.getDimension()).thenReturn(4); @@ -709,12 +685,9 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - - when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; - when(mockKNNVectorField.getModelId()).thenReturn(modelId); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_MODEL, null, modelId, null)); ModelMetadata modelMetadata = mock(ModelMetadata.class); when(modelMetadata.getDimension()).thenReturn(4); @@ -744,10 +717,12 @@ public void testDoToQuery_InvalidDimensions() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(400); + when(mockKNNVectorField.getAnnConfig()).thenReturn( + new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultKNNMethodContext(), null, 400) + ); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); - when(mockKNNVectorField.getDimension()).thenReturn(K); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultKNNMethodContext(), null, K)); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } @@ -769,9 +744,10 @@ public void testDoToQuery_InvalidZeroFloatVector() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + when(knnMethodContext.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, @@ -790,9 +766,10 @@ public void testDoToQuery_InvalidZeroByteVector() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BYTE); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + when(knnMethodContext.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, @@ -919,9 +896,8 @@ public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -946,9 +922,8 @@ public void testRadialSearch_whenEfSearchIsSet_whenLuceneEngine_thenThrowExcepti KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); @@ -972,9 +947,8 @@ public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() { KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 4)); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); @@ -992,9 +966,10 @@ public void testDoToQuery_whenBinary_thenValid() throws Exception { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(32); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.HAMMING); + when(mockKNNVectorField.getAnnConfig()).thenReturn( + new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultBinaryKNNMethodContext(), null, 32) + ); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertArrayEquals(expectedQueryVector, query.getByteQueryVector()); @@ -1008,9 +983,10 @@ public void testDoToQuery_whenBinaryWithInvalidDimension_thenException() throws QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(8); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.HAMMING); + when(mockKNNVectorField.getAnnConfig()).thenReturn( + new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultBinaryKNNMethodContext(), null, 8) + ); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); Exception ex = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); assertTrue(ex.getMessage(), ex.getMessage().contains("invalid dimension")); diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index f9a9704d0b..cc28643ba7 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -30,6 +30,8 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.mapper.ANNConfig; +import org.opensearch.knn.index.mapper.ANNConfigType; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; import org.opensearch.knn.plugin.script.KNNScoringSpace; @@ -747,9 +749,8 @@ private BiFunction getScoreFunction(SpaceType spaceType KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType( FIELD_NAME, Collections.emptyMap(), - SpaceType.HAMMING == spaceType ? queryVector.length * 8 : queryVector.length, SpaceType.HAMMING == spaceType ? VectorDataType.BINARY : VectorDataType.FLOAT, - null + new ANNConfig(ANNConfigType.SKIP, null, null, queryVector.length) ); List target = new ArrayList<>(queryVector.length); for (float f : queryVector) { diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index ae9ad71062..c72f5679a0 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; import org.junit.BeforeClass; +import org.junit.Ignore; import org.opensearch.Version; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; @@ -59,6 +60,7 @@ import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +@Ignore public class JNIServiceTests extends KNNTestCase { static final int FP16_MAX = 65504; static final int FP16_MIN = -65504; diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java index 823d210803..841d31079f 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java @@ -9,6 +9,8 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.SpaceType; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.knn.index.mapper.ANNConfig; +import org.opensearch.knn.index.mapper.ANNConfigType; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.util.List; @@ -19,9 +21,11 @@ public class KNNScoringSpaceFactoryTests extends KNNTestCase { public void testValidSpaces() { KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); + when(knnVectorFieldType.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultKNNMethodContext(), null, 3)); KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeBinary.getDimension()).thenReturn(24); + when(knnVectorFieldTypeBinary.getAnnConfig()).thenReturn( + new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultBinaryKNNMethodContext(), null, 24) + ); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); NumberFieldMapper.NumberFieldType numberFieldType = new NumberFieldMapper.NumberFieldType( "field", @@ -66,9 +70,11 @@ public void testValidSpaces() { public void testInvalidSpace() { List floatQueryObject = List.of(1.0f, 1.0f, 1.0f); KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); + when(knnVectorFieldType.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultKNNMethodContext(), null, 3)); KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeBinary.getDimension()).thenReturn(24); + when(knnVectorFieldTypeBinary.getAnnConfig()).thenReturn( + new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultBinaryKNNMethodContext(), null, 24) + ); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); // Verify diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 6c557c8dd5..ba80b31bae 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -17,6 +17,8 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.knn.index.mapper.ANNConfig; +import org.opensearch.knn.index.mapper.ANNConfigType; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.math.BigInteger; @@ -57,8 +59,13 @@ private void expectThrowsExceptionWithKNNFieldWithBinaryDataType(Class clazz) th public void testL2_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 3) + ); KNNScoringSpace.L2 l2 = new KNNScoringSpace.L2(arrayListQueryObject, fieldType); assertEquals(1F, l2.getScoringMethod().apply(arrayFloat, arrayFloat), 0.1F); } @@ -73,9 +80,13 @@ public void testCosineSimilarity_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0)); float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f }; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 3) + ); KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); assertEquals(2F, cosineSimilarity.getScoringMethod().apply(arrayFloat2, arrayFloat), 0.1F); @@ -92,8 +103,13 @@ public void testCosineSimilarity_whenValid_thenSucceed() { } public void testCosineSimilarity_whenZeroVector_thenException() { - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 3) + ); final List queryZeroVector = List.of(0.0f, 0.0f, 0.0f); IllegalArgumentException exception1 = expectThrows( @@ -116,9 +132,14 @@ public void testInnerProd_whenValid_thenSucceed() { float[] arrayFloat_case1 = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject_case1 = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); float[] arrayFloat2_case1 = new float[] { 1.0f, 1.0f, 1.0f }; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 3) + ); KNNScoringSpace.InnerProd innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case1, fieldType); assertEquals(7.0F, innerProd.getScoringMethod().apply(arrayFloat_case1, arrayFloat2_case1), 0.001F); @@ -183,14 +204,14 @@ public void testHammingBit_Base64() { public void testHamming_whenKNNFieldType_thenSucceed() { List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); KNNVectorFieldType fieldType = new KNNVectorFieldType( "test", Collections.emptyMap(), - 8 * arrayListQueryObject.size(), - knnMethodContext, - VectorDataType.BINARY + VectorDataType.BINARY, + new ANNConfig(ANNConfigType.FROM_METHOD, knnMethodContext, null, 8 * arrayListQueryObject.size()) ); + KNNScoringSpace.Hamming hamming = new KNNScoringSpace.Hamming(arrayListQueryObject, fieldType); float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index 781ed2350b..008f8f4ac4 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -9,6 +9,8 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.index.mapper.BinaryFieldMapper; import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.knn.index.mapper.ANNConfig; +import org.opensearch.knn.index.mapper.ANNConfigType; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.math.BigInteger; @@ -64,7 +66,7 @@ public void testParseKNNVectorQuery() { KNNVectorFieldType fieldType = mock(KNNVectorFieldType.class); - when(fieldType.getDimension()).thenReturn(3); + when(fieldType.getAnnConfig()).thenReturn(new ANNConfig(ANNConfigType.FROM_METHOD, getDefaultKNNMethodContext(), null, 3)); assertArrayEquals(arrayFloat, KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 3, VectorDataType.FLOAT), 0.1f); expectThrows( diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 3515c690d8..8cff4dfa14 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -23,7 +23,6 @@ import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -303,7 +302,7 @@ public void testTrainingIndexSize() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - KNNMethodContext.getDefault(), + getDefaultKNNMethodContext(), dimension, trainingIndexName, "training-field", @@ -350,7 +349,7 @@ public void testTrainIndexSize_whenDataTypeIsBinary() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - KNNMethodContext.getDefault(), + getDefaultKNNMethodContext(), dimension, trainingIndexName, "training-field", @@ -398,7 +397,7 @@ public void testTrainIndexSize_whenDataTypeIsByte() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - KNNMethodContext.getDefault(), + getDefaultKNNMethodContext(), dimension, trainingIndexName, "training-field", diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 53e59129e8..83d39cfdc7 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -48,7 +48,7 @@ public class TrainingModelRequestTests extends KNNTestCase { public void testStreams() throws IOException { String modelId = "test-model-id"; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; @@ -105,7 +105,7 @@ public void testStreams() throws IOException { public void testGetters() { String modelId = "test-model-id"; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field";