diff --git a/CHANGELOG.md b/CHANGELOG.md index 44f387533a..81c90802ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,4 +31,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) * Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931) * Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) -* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) \ No newline at end of file +* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) +* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939) \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 2a3732d7e0..69229036ed 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -13,6 +13,8 @@ import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.util.Optional; @@ -66,16 +68,19 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { ); return defaultFormatSupplier.get(); } - var type = (KNNVectorFieldType) mapperService.orElseThrow( + KNNVectorFieldType mappedFieldType = (KNNVectorFieldType) mapperService.orElseThrow( () -> new IllegalStateException( String.format("Cannot read field type for field [%s] because mapper service is not available", field) ) ).fieldType(field); - var params = type.getKnnMethodContext().getMethodComponentContext().getParameters(); - if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE - && params != null - && params.containsKey(METHOD_ENCODER_PARAMETER)) { + KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); + KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() + .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); + + var params = knnMethodContext.getMethodComponentContext().getParameters(); + + if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE && params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( params, defaultMaxConnections, diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 989af4063b..83237ee4ca 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -21,7 +21,6 @@ import org.opensearch.knn.index.codec.transfer.VectorTransferByte; import org.opensearch.knn.index.codec.transfer.VectorTransferFloat; import org.opensearch.knn.jni.JNIService; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.Model; @@ -57,7 +56,6 @@ import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; @@ -214,35 +212,19 @@ private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNN private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) throws IOException { - Map parameters = new HashMap<>(); Map fieldAttributes = fieldInfo.attributes(); String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); - - // parametersString will be null when legacy mapper is used if (parametersString == null) { - parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); - - String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); - Map algoParams = new HashMap<>(); - if (efConstruction != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); - } - - String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); - if (m != null) { - algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); - } - parameters.put(PARAMETERS, algoParams); - } else { - parameters.putAll( - XContentHelper.createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(parametersString), - MediaTypeRegistry.getDefaultMediaType() - ).map() - ); + throw new IllegalStateException("Parameter string is not set. Something is wrong"); } + Map parameters = new HashMap<>( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); // Update index description of Faiss for binary data type if (KNNEngine.FAISS == knnEngine diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java index 7885761b93..d210483e60 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java @@ -9,7 +9,6 @@ import lombok.Getter; import lombok.NonNull; import lombok.Setter; -import org.opensearch.Version; import org.opensearch.common.ValidationException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -20,7 +19,6 @@ import org.opensearch.index.mapper.MapperParsingException; import java.io.IOException; -import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; @@ -29,7 +27,6 @@ import org.opensearch.knn.training.VectorSpaceInfo; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -42,21 +39,6 @@ @Getter public class KNNMethodContext implements ToXContentFragment, Writeable { - private static KNNMethodContext defaultInstance = null; - - /** - * This is used only for testing - * @return default KNNMethodContext for testing - */ - public static synchronized KNNMethodContext getDefault() { - if (defaultInstance == null) { - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); - methodComponentContext.setIndexVersion(Version.CURRENT); - defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); - } - return defaultInstance; - } - @NonNull private final KNNEngine knnEngine; @NonNull diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java new file mode 100644 index 0000000000..3ea8221323 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.document.FieldType; +import org.opensearch.Version; +import org.opensearch.common.Explicit; +import org.opensearch.knn.index.VectorDataType; + +import java.util.Map; +import java.util.Optional; + +/** + * Mapper used when you dont want to build an underlying KNN struct - you just want to + * store vectors as doc values + */ +public class FlatVectorFieldMapper extends KNNVectorFieldMapper { + + private final PerDimensionValidator perDimensionValidator; + + public static FlatVectorFieldMapper createFieldMapper( + String fullname, + String simpleName, + Map metaValue, + VectorDataType vectorDataType, + Integer dimension, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + Version indexCreatedVersion + ) { + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() { + @Override + public Optional getDimension() { + return Optional.of(dimension); + } + }); + return new FlatVectorFieldMapper( + simpleName, + mappedFieldType, + multiFields, + copyTo, + ignoreMalformed, + stored, + hasDocValues, + indexCreatedVersion + ); + } + + private FlatVectorFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + Version indexCreatedVersion + ) { + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null); + this.perDimensionValidator = selectPerDimensionValidator(vectorDataType); + this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + this.fieldType.freeze(); + } + + private PerDimensionValidator selectPerDimensionValidator(VectorDataType vectorDataType) { + if (VectorDataType.BINARY == vectorDataType) { + return PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + } + + if (VectorDataType.BYTE == vectorDataType) { + return PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + } + + return PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + } + + @Override + protected VectorValidator getVectorValidator() { + return VectorValidator.NOOP_VECTOR_VALIDATOR; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return PerDimensionProcessor.NOOP_PROCESSOR; + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java b/src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java new file mode 100644 index 0000000000..34e2947dbe --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.index.engine.KNNMethodContext; + +import java.util.Optional; + +/** + * Class holds information about how the ANN indices are created. The design of this class ensures that we do not + * accidentally configure an index that has multiple ways it can be created. This class is immutable. + */ +public interface KNNMappingConfig { + /** + * + * @return Optional containing the modelId if created from model, otherwise empty + */ + default Optional getModelId() { + return Optional.empty(); + } + + /** + * + * @return Optional containing the KNNMethodContext if created from method, otherwise empty + */ + default Optional getKnnMethodContext() { + return Optional.empty(); + } + + /** + * + * @return the dimension of the index; for model based indices, it will be null + */ + default Optional getDimension() { + return Optional.empty(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 3b94876454..8b7bfb2389 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -11,7 +11,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.function.Supplier; import lombok.extern.log4j.Log4j2; @@ -32,9 +31,8 @@ import org.opensearch.index.mapper.ParametrizedFieldMapper; import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.KnnCircuitBreakerException; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; @@ -44,23 +42,17 @@ import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; -import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; -import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension; -import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createKNNMethodContextFromLegacy; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfCircuitBreakerIsNotTriggered; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfKNNPluginEnabled; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataType; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; +import static org.opensearch.knn.index.mapper.ModelFieldMapper.UNSET_MODEL_DIMENSION_IDENTIFIER; /** * Field Mapper for KNN vector type. Implementations of this class define what needs to be stored in Lucene's fieldType. @@ -76,10 +68,6 @@ private static KNNVectorFieldMapper toType(FieldMapper in) { return (KNNVectorFieldMapper) in; } - // We store the version of the index with the mapper as different version of Opensearch has different default - // values of KNN engine Algorithms hyperparameters. - protected Version indexCreatedVersion; - /** * Builder for KNNVectorFieldMapper. This class defines the set of parameters that can be applied to the knn_vector * field type @@ -107,7 +95,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { ); } return value; - }, m -> toType(m).dimension); + }, m -> toType(m).fieldType().getKnnMappingConfig().getDimension().orElse(UNSET_MODEL_DIMENSION_IDENTIFIER)); /** * data_type which defines the datatype of the vector values. This is an optional parameter and @@ -126,7 +114,12 @@ public static class Builder extends ParametrizedFieldMapper.Builder { * model template index. If this parameter is set, it will take precedence. This parameter is only relevant for * library indices that require training. */ - protected final Parameter modelId = Parameter.stringParam(KNNConstants.MODEL_ID, false, m -> toType(m).modelId, null); + protected final Parameter modelId = Parameter.stringParam( + KNNConstants.MODEL_ID, + false, + m -> toType(m).fieldType().getKnnMappingConfig().getModelId().orElse(null), + null + ); /** * knnMethodContext parameter allows a user to define their k-NN library index configuration. Defaults to an L2 @@ -137,7 +130,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { false, () -> null, (n, c, o) -> KNNMethodContext.parse(o), - m -> toType(m).knnMethod + m -> toType(m).originalKNNMethodContext ).setSerializer(((b, n, v) -> { b.startObject(n); v.toXContent(b, ToXContent.EMPTY_PARAMS); @@ -164,35 +157,30 @@ public static class Builder extends ParametrizedFieldMapper.Builder { protected final Parameter> meta = Parameter.metaParam(); - protected String spaceType; - protected String m; - protected String efConstruction; - protected ModelDao modelDao; - protected Version indexCreatedVersion; - - public Builder(String name, ModelDao modelDao, Version indexCreatedVersion) { + // KNNMethodContext that allows us to properly configure a KNNVectorFieldMapper from another + // KNNVectorFieldMapper. To support our legacy field mapping, on parsing, if index.knn=true and no method is + // passed, we build a KNNMethodContext using the space type, ef_construction and m that are set in the index + // settings. However, for fieldmappers for merging, we need to be able to initialize one field mapper from + // another (see + // https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L98). + // The problem is that in this case, the settings are set to empty so we cannot properly resolve the KNNMethodContext. + // (see + // https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L130). + // While we could override the KNNMethodContext parameter initializer to set the knnMethodContext based on the + // constructed KNNMethodContext from the other field mapper, this can result in merge conflict/serialization + // exceptions. See + // (https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L322-L324). + // So, what we do is pass in a "resolvedKNNMethodContext" that will either be null or be set via the merge builder + // constructor. A similar approach was taken for https://github.com/opendistro-for-elasticsearch/k-NN/issues/288 + private KNNMethodContext resolvedKNNMethodContext; + + public Builder(String name, ModelDao modelDao, Version indexCreatedVersion, KNNMethodContext resolvedKNNMethodContext) { super(name); this.modelDao = modelDao; this.indexCreatedVersion = indexCreatedVersion; - } - - /** - * This constructor is for legacy purposes. - * Checkout ODFE PR 288 - * - * @param name field name - * @param spaceType Spacetype of field - * @param m m value of field - * @param efConstruction efConstruction value of field - */ - public Builder(String name, String spaceType, String m, String efConstruction, Version indexCreatedVersion) { - super(name); - this.spaceType = spaceType; - this.m = m; - this.efConstruction = efConstruction; - this.indexCreatedVersion = indexCreatedVersion; + this.resolvedKNNMethodContext = resolvedKNNMethodContext; } @Override @@ -210,121 +198,116 @@ protected Explicit ignoreMalformed(BuilderContext context) { return KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED; } + private void validateFlatMapper() { + if (modelId.get() != null || knnMethodContext.get() != null) { + throw new IllegalArgumentException("Cannot set modelId or method parameters when index.knn setting is false"); + } + } + @Override public KNNVectorFieldMapper build(BuilderContext context) { - // Originally, a user would use index settings to set the spaceType, efConstruction and m hnsw - // parameters. Upon further review, it makes sense to set these parameters in the mapping of a - // particular field. However, because users migrating from older versions will still use the index - // settings to set these parameters, we will need to provide backwards compatibilty. In order to - // handle this, we first check if the mapping is set, and, if so use it. If not, we check if the model is - // set. If not, we fall back to the parameters set in the index settings. This means that if a user sets - // the mappings, setting the index settings will have no impact. - - final KNNMethodContext knnMethodContext = this.knnMethodContext.getValue(); - setDefaultSpaceType(knnMethodContext, vectorDataType.getValue()); - validateSpaceType(knnMethodContext, vectorDataType.getValue()); - validateDimensions(knnMethodContext, vectorDataType.getValue()); - validateEncoder(knnMethodContext, vectorDataType.getValue()); final MultiFields multiFieldsBuilder = this.multiFieldsBuilder.build(this, context); final CopyTo copyToBuilder = copyTo.build(); final Explicit ignoreMalformed = ignoreMalformed(context); final Map metaValue = meta.getValue(); - if (knnMethodContext != null) { - validateVectorDataType(knnMethodContext, vectorDataType.getValue()); - knnMethodContext.getMethodComponentContext().setIndexVersion(indexCreatedVersion); - final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( + // Index is being created from model + String modelIdAsString = this.modelId.get(); + if (modelIdAsString != null) { + return ModelFieldMapper.createFieldMapper( buildFullName(context), - metaValue, - dimension.getValue(), - knnMethodContext, - vectorDataType.getValue() - ); - if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE) { - log.debug(String.format(Locale.ROOT, "Use [LuceneFieldMapper] mapper for field [%s]", name)); - LuceneFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = - LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() - .name(name) - .mappedFieldType(mappedFieldType) - .multiFields(multiFieldsBuilder) - .copyTo(copyToBuilder) - .ignoreMalformed(ignoreMalformed) - .stored(stored.get()) - .hasDocValues(hasDocValues.get()) - .vectorDataType(vectorDataType.getValue()) - .knnMethodContext(knnMethodContext) - .build(); - return new LuceneFieldMapper(createLuceneFieldMapperInput); - } - - return new MethodFieldMapper( name, - mappedFieldType, + metaValue, + vectorDataType.getValue(), + modelIdAsString, multiFieldsBuilder, copyToBuilder, ignoreMalformed, stored.get(), hasDocValues.get(), - knnMethodContext + modelDao, + indexCreatedVersion ); } - String modelIdAsString = this.modelId.get(); - if (modelIdAsString != null) { - // Because model information is stored in cluster metadata, we are unable to get it here. This is - // because to get the cluster metadata, you need access to the cluster state. Because this code is - // sometimes used to initialize the cluster state/update cluster state, we cannot get the state here - // safely. So, we are unable to validate the model. The model gets validated during ingestion. - - return new ModelFieldMapper( + // If the field mapper is using the legacy context and being constructed from another field mapper, + // the settings will be empty. See https://github.com/opendistro-for-elasticsearch/k-NN/issues/288. In this + // case, the input resolvedKNNMethodContext will be null and the settings wont exist (so flat mapper should + // be used). Otherwise, we need to check the setting. + boolean isResolvedNull = resolvedKNNMethodContext == null; + boolean isSettingPresent = KNNSettings.IS_KNN_INDEX_SETTING.exists(context.indexSettings()); + if (isResolvedNull && (!isSettingPresent || !KNNSettings.IS_KNN_INDEX_SETTING.get(context.indexSettings()))) { + validateFlatMapper(); + return FlatVectorFieldMapper.createFieldMapper( + buildFullName(context), name, - new KNNVectorFieldType(buildFullName(context), metaValue, -1, knnMethodContext, modelIdAsString), + metaValue, + vectorDataType.getValue(), + dimension.getValue(), multiFieldsBuilder, copyToBuilder, ignoreMalformed, stored.get(), hasDocValues.get(), - modelDao, - modelIdAsString, indexCreatedVersion ); } - // Build legacy - if (this.spaceType == null) { - this.spaceType = LegacyFieldMapper.getSpaceType(context.indexSettings(), vectorDataType.getValue()); - } - - if (this.m == null) { - this.m = LegacyFieldMapper.getM(context.indexSettings()); - } - - if (this.efConstruction == null) { - this.efConstruction = LegacyFieldMapper.getEfConstruction(context.indexSettings(), indexCreatedVersion); - } - - // Validates and throws exception if index.knn is set to true in the index settings - // using any VectorDataType (other than float, which is default) because we are using NMSLIB engine for LegacyFieldMapper - // and it only supports float VectorDataType - validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType); - - return new LegacyFieldMapper( - name, - new KNNVectorFieldType( + // See resolvedKNNMethodContext definition for explanation + if (isResolvedNull) { + resolvedKNNMethodContext = this.knnMethodContext.getValue(); + setDefaultSpaceType(resolvedKNNMethodContext, vectorDataType.getValue()); + validateSpaceType(resolvedKNNMethodContext, vectorDataType.getValue()); + validateDimensions(resolvedKNNMethodContext, vectorDataType.getValue()); + validateEncoder(resolvedKNNMethodContext, vectorDataType.getValue()); + } + + // If the knnMethodContext is null at this point, that means user built the index with the legacy k-NN + // settings to specify algo params. We need to convert this here to a KNNMethodContext so that we can + // properly configure the rest of the index + if (resolvedKNNMethodContext == null) { + resolvedKNNMethodContext = createKNNMethodContextFromLegacy(context, vectorDataType.getValue(), indexCreatedVersion); + } + + validateVectorDataType(resolvedKNNMethodContext, vectorDataType.getValue()); + resolvedKNNMethodContext.getMethodComponentContext().setIndexVersion(indexCreatedVersion); + if (resolvedKNNMethodContext.getKnnEngine() == KNNEngine.LUCENE) { + log.debug(String.format(Locale.ROOT, "Use [LuceneFieldMapper] mapper for field [%s]", name)); + LuceneFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = LuceneFieldMapper.CreateLuceneFieldMapperInput + .builder() + .name(name) + .multiFields(multiFieldsBuilder) + .copyTo(copyToBuilder) + .ignoreMalformed(ignoreMalformed) + .stored(stored.getValue()) + .hasDocValues(hasDocValues.getValue()) + .vectorDataType(vectorDataType.getValue()) + .indexVersion(indexCreatedVersion) + .originalKnnMethodContext(knnMethodContext.get()) + .build(); + return LuceneFieldMapper.createFieldMapper( buildFullName(context), metaValue, - dimension.getValue(), vectorDataType.getValue(), - SpaceType.getSpace(spaceType) - ), + dimension.getValue(), + resolvedKNNMethodContext, + createLuceneFieldMapperInput + ); + } + + return MethodFieldMapper.createFieldMapper( + buildFullName(context), + name, + metaValue, + vectorDataType.getValue(), + dimension.getValue(), + resolvedKNNMethodContext, + knnMethodContext.get(), multiFieldsBuilder, copyToBuilder, ignoreMalformed, - stored.get(), - hasDocValues.get(), - spaceType, - m, - efConstruction, + stored.getValue(), + hasDocValues.getValue(), indexCreatedVersion ); } @@ -430,7 +413,7 @@ public TypeParser(Supplier modelDaoSupplier) { @Override public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { - Builder builder = new KNNVectorFieldMapper.Builder(name, modelDaoSupplier.get(), parserContext.indexVersionCreated()); + Builder builder = new KNNVectorFieldMapper.Builder(name, modelDaoSupplier.get(), parserContext.indexVersionCreated(), null); builder.parse(name, parserContext, node); // All parse(String name, Map node, ParserCont } // Dimension should not be null unless modelId is used - if (builder.dimension.getValue() == -1 && builder.modelId.get() == null) { + if (builder.dimension.getValue() == UNSET_MODEL_DIMENSION_IDENTIFIER && builder.modelId.get() == null) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Dimension value missing for vector: %s", name)); } @@ -452,18 +435,18 @@ public Mapper.Builder parse(String name, Map node, ParserCont } } + // We store the version of the index with the mapper as different version of Opensearch has different default + // values of KNN engine Algorithms hyperparameters. + protected Version indexCreatedVersion; protected Explicit ignoreMalformed; protected boolean stored; protected boolean hasDocValues; - protected Integer dimension; protected VectorDataType vectorDataType; protected ModelDao modelDao; - // These members map to parameters in the builder. They need to be declared in the abstract class due to the - // "toType" function used in the builder. So, when adding a parameter, it needs to be added here, but set in a - // subclass (if it is unique). - protected KNNMethodContext knnMethod; - protected String modelId; + // We need to ensure that the original KNNMethodContext as parsed is stored to initialize the + // Builder for serialization. So, we need to store it here. + protected KNNMethodContext originalKNNMethodContext; public KNNVectorFieldMapper( String simpleName, @@ -473,16 +456,17 @@ public KNNVectorFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - Version indexCreatedVersion + Version indexCreatedVersion, + KNNMethodContext originalKNNMethodContext ) { super(simpleName, mappedFieldType, multiFields, copyTo); this.ignoreMalformed = ignoreMalformed; this.stored = stored; this.hasDocValues = hasDocValues; - this.dimension = mappedFieldType.getDimension(); this.vectorDataType = mappedFieldType.getVectorDataType(); updateEngineStats(); this.indexCreatedVersion = indexCreatedVersion; + this.originalKNNMethodContext = originalKNNMethodContext; } public KNNVectorFieldMapper clone() { @@ -498,20 +482,11 @@ protected String contentType() { protected void parseCreateField(ParseContext context) throws IOException { parseCreateField( context, - fieldType().getDimension(), - fieldType().getSpaceType(), - getMethodComponentContext(fieldType().getKnnMethodContext()), + fieldType().getKnnMappingConfig().getDimension().orElseThrow(() -> new IllegalArgumentException("Dimension is not set")), fieldType().getVectorDataType() ); } - private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) { - if (Objects.isNull(knnMethodContext)) { - return null; - } - return knnMethodContext.getMethodComponentContext(); - } - /** * Function returns a list of fields to be indexed when the vector is float type. * @@ -544,17 +519,37 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fields; } - protected void parseCreateField( - ParseContext context, - int dimension, - SpaceType spaceType, - MethodComponentContext methodComponentContext, - VectorDataType vectorDataType - ) throws IOException { - + /** + * Validation checks before parsing of doc begins + */ + protected void validatePreparse() { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - spaceType.validateVectorDataType(vectorDataType); + } + + /** + * Getter for vector validator after vector parsing + * + * @return VectorValidator + */ + protected abstract VectorValidator getVectorValidator(); + + /** + * Getter for per dimension validator during vector parsing + * + * @return PerDimensionValidator + */ + protected abstract PerDimensionValidator getPerDimensionValidator(); + + /** + * Getter for per dimension processor during vector parsing + * + * @return PerDimensionProcessor + */ + protected abstract PerDimensionProcessor getPerDimensionProcessor(); + + protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { + validatePreparse(); if (VectorDataType.BINARY == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); @@ -563,7 +558,7 @@ protected void parseCreateField( return; } final byte[] array = bytesArrayOptional.get(); - spaceType.validateVector(array); + getVectorValidator().validateVector(array); context.doc().addAll(getFieldsForByteVector(array, fieldType)); } else if (VectorDataType.BYTE == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); @@ -572,16 +567,16 @@ protected void parseCreateField( return; } final byte[] array = bytesArrayOptional.get(); - spaceType.validateVector(array); + getVectorValidator().validateVector(array); context.doc().addAll(getFieldsForByteVector(array, fieldType)); } else if (VectorDataType.FLOAT == vectorDataType) { - Optional floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext); + Optional floatsArrayOptional = getFloatsFromContext(context, dimension); if (floatsArrayOptional.isEmpty()) { return; } final float[] array = floatsArrayOptional.get(); - spaceType.validateVector(array); + getVectorValidator().validateVector(array); context.doc().addAll(getFieldsForFloatVector(array, fieldType)); } else { throw new IllegalArgumentException( @@ -592,80 +587,28 @@ protected void parseCreateField( context.path().remove(); } - // Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16" - protected boolean isFaissSQfp16(MethodComponentContext methodComponentContext) { - if (Objects.isNull(methodComponentContext)) { - return false; - } - - if (methodComponentContext.getParameters().size() == 0) { - return false; - } - - Map methodComponentParams = methodComponentContext.getParameters(); - - // The method component parameters should have an encoder - if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) { - return false; - } - - // Validate if the object is of type MethodComponentContext before casting it later - if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { - return false; - } - - MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER); - - // returns true if encoder name is "sq" and type is "fp16" - return ENCODER_SQ.equals(encoderMethodComponentContext.getName()) - && FAISS_SQ_ENCODER_FP16.equals( - encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) - ); - - } - - // Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index - // using "sq" encoder of type "fp16". - protected boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) { - if (Objects.nonNull(methodComponentContext)) { - return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false); - } - return false; - } - - void validateIfCircuitBreakerIsNotTriggered() { - if (KNNSettings.isCircuitBreakerTriggered()) { - throw new KnnCircuitBreakerException( - "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status." - ); - } - } - - void validateIfKNNPluginEnabled() { - if (!KNNSettings.isKNNPluginEnabled()) { - throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled setting to true"); - } - } - // Returns an optional array of byte values where each value in the vector is parsed as a float and validated // if it is a finite number without any decimals and within the byte range of [-128 to 127]. Optional getBytesFromContext(ParseContext context, int dimension, VectorDataType dataType) throws IOException { context.path().add(simpleName()); + PerDimensionValidator perDimensionValidator = getPerDimensionValidator(); + PerDimensionProcessor perDimensionProcessor = getPerDimensionProcessor(); + ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { - float value = context.parser().floatValue(); - validateByteVectorValue(value, dataType); + float value = perDimensionProcessor.processByte(context.parser().floatValue()); + perDimensionValidator.validateByte(value); vector.add((byte) value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { - float value = context.parser().floatValue(); - validateByteVectorValue(value, dataType); + float value = perDimensionProcessor.processByte(context.parser().floatValue()); + perDimensionValidator.validateByte(value); vector.add((byte) value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { @@ -681,21 +624,11 @@ Optional getBytesFromContext(ParseContext context, int dimension, Vector return Optional.of(array); } - Optional getFloatsFromContext(ParseContext context, int dimension, MethodComponentContext methodComponentContext) - throws IOException { + Optional getFloatsFromContext(ParseContext context, int dimension) throws IOException { context.path().add(simpleName()); - // Returns an optional array of float values where each value in the vector is parsed as a float and validated - // if it is a finite number and within the fp16 range of [-65504 to 65504] by default if Faiss encoder is SQ and type is 'fp16'. - // If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be - // clipped to FP16 range. - boolean isFaissSQfp16Flag = isFaissSQfp16(methodComponentContext); - boolean clipVectorValueToFP16RangeFlag = false; - if (isFaissSQfp16Flag) { - clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled( - (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) - ); - } + PerDimensionValidator perDimensionValidator = getPerDimensionValidator(); + PerDimensionProcessor perDimensionProcessor = getPerDimensionProcessor(); ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); @@ -703,31 +636,14 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { - value = context.parser().floatValue(); - if (isFaissSQfp16Flag) { - if (clipVectorValueToFP16RangeFlag) { - value = clipVectorValueToFP16Range(value); - } else { - validateFP16VectorValue(value); - } - } else { - validateFloatVectorValue(value); - } - + value = perDimensionProcessor.process(context.parser().floatValue()); + perDimensionValidator.validate(value); vector.add(value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { - value = context.parser().floatValue(); - if (isFaissSQfp16Flag) { - if (clipVectorValueToFP16RangeFlag) { - value = clipVectorValueToFP16Range(value); - } else { - validateFP16VectorValue(value); - } - } else { - validateFloatVectorValue(value); - } + value = perDimensionProcessor.process(context.parser().floatValue()); + perDimensionValidator.validate(value); vector.add(value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { @@ -746,7 +662,12 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, indexCreatedVersion).init(this); + return new KNNVectorFieldMapper.Builder( + simpleName(), + modelDao, + indexCreatedVersion, + fieldType().getKnnMappingConfig().getKnnMethodContext().orElse(null) + ).init(this); } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 2adbbb6953..06de84b536 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -13,30 +13,48 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.util.BytesRef; -import org.opensearch.index.mapper.ParametrizedFieldMapper; +import org.opensearch.Version; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.mapper.Mapper; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.KnnCircuitBreakerException; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.util.IndexHyperParametersUtil; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; import java.util.Arrays; import java.util.Locale; +import java.util.Map; +import java.util.Objects; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; +import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; @@ -44,6 +62,7 @@ /** * Utility class for KNNVectorFieldMapper */ +@Log4j2 @NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNVectorFieldMapperUtil { @@ -150,35 +169,6 @@ public static void validateVectorDataType(KNNMethodContext methodContext, Vector throw new IllegalArgumentException("This line should not be reached"); } - /** - * Validates and throws exception if index.knn is set to true in the index settings - * using any VectorDataType (other than float, which is default) because we are using NMSLIB engine - * for LegacyFieldMapper, and it only supports float VectorDataType - * - * @param knnIndexSetting index.knn setting in the index settings - * @param vectorDataType VectorDataType Parameter - */ - public static void validateVectorDataTypeWithKnnIndexSetting( - boolean knnIndexSetting, - ParametrizedFieldMapper.Parameter vectorDataType - ) { - - if (VectorDataType.FLOAT == vectorDataType.getValue()) { - return; - } - if (knnIndexSetting) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "[%s] field with value [%s] is not supported for [%s] engine", - VECTOR_DATA_TYPE_FIELD, - vectorDataType.getValue().getValue(), - NMSLIB_NAME - ) - ); - } - } - /** * @param knnEngine KNNEngine * @return DocValues FieldType of type Binary @@ -237,37 +227,198 @@ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataTy * @return expected vector length */ public static int getExpectedVectorLength(final KNNVectorFieldType knnVectorFieldType) { - int expectedDimensions = knnVectorFieldType.getDimension(); - if (isModelBasedIndex(expectedDimensions)) { - ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); - expectedDimensions = modelMetadata.getDimension(); - } + int expectedDimensions = knnVectorFieldType.getKnnMappingConfig() + .getDimension() + .orElseGet( + () -> getDimensionFromModelId( + knnVectorFieldType.getKnnMappingConfig() + .getModelId() + .orElseThrow( + () -> new IllegalStateException( + "Unable to look up dimension because its not accessible from the KNNVectorFieldType" + ) + ) + ) + ); return VectorDataType.BINARY == knnVectorFieldType.getVectorDataType() ? expectedDimensions / 8 : expectedDimensions; } - private static boolean isModelBasedIndex(int expectedDimensions) { - return expectedDimensions == -1; - } - /** * Returns the model metadata for a specified knn vector field * - * @param knnVectorField knn vector field + * @param modelId ID of model * @return the model metadata from knnVectorField */ - private static ModelMetadata getModelMetadataForField(final KNNVectorFieldType knnVectorField) { - String modelId = knnVectorField.getModelId(); + private static int getDimensionFromModelId(String modelId) { + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (!ModelUtil.isModelCreated(modelMetadata)) { + throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); + } + return modelMetadata.getDimension(); + } + + /** + * Validate if the circuit breaker is triggered + */ + static void validateIfCircuitBreakerIsNotTriggered() { + if (KNNSettings.isCircuitBreakerTriggered()) { + throw new KnnCircuitBreakerException( + "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status." + ); + } + } - if (modelId == null) { - throw new IllegalArgumentException( - String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName()) + /** + * Validate if plugin is enabled + */ + static void validateIfKNNPluginEnabled() { + if (!KNNSettings.isKNNPluginEnabled()) { + throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled setting to true"); + } + } + + private static SpaceType getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) { + String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); + if (spaceType == null) { + spaceType = VectorDataType.BINARY == vectorDataType + ? KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY + : KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; + log.info( + String.format( + "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", + METHOD_PARAMETER_SPACE_TYPE, + spaceType + ) ); } + return SpaceType.getSpace(spaceType); + } - ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (!ModelUtil.isModelCreated(modelMetadata)) { - throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); + private static int getM(Settings indexSettings) { + String m = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING.getKey()); + if (m == null) { + log.info( + String.format( + "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", + HNSW_ALGO_M, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M + ) + ); + return KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M; } - return modelMetadata; + return Integer.parseInt(m); + } + + private static int getEfConstruction(Settings indexSettings, Version indexVersion) { + final String efConstruction = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING.getKey()); + if (efConstruction == null) { + final int defaultEFConstructionValue = IndexHyperParametersUtil.getHNSWEFConstructionValue(indexVersion); + log.info( + String.format( + "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. " + + "Picking up default value for the index =%s", + HNSW_ALGO_EF_CONSTRUCTION, + defaultEFConstructionValue + ) + ); + return defaultEFConstructionValue; + } + return Integer.parseInt(efConstruction); + } + + /** + * Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16" + * + * @param methodComponentContext MethodComponentContext + * @return true if it is a "faiss" Index using "sq" encoder of type "fp16" + */ + static boolean isFaissSQfp16(MethodComponentContext methodComponentContext) { + if (Objects.isNull(methodComponentContext)) { + return false; + } + + if (methodComponentContext.getParameters().size() == 0) { + return false; + } + + Map methodComponentParams = methodComponentContext.getParameters(); + + // The method component parameters should have an encoder + if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) { + return false; + } + + // Validate if the object is of type MethodComponentContext before casting it later + if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { + return false; + } + + MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER); + + // returns true if encoder name is "sq" and type is "fp16" + return ENCODER_SQ.equals(encoderMethodComponentContext.getName()) + && FAISS_SQ_ENCODER_FP16.equals( + encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) + ); + + } + + /** + * Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index + * using "sq" encoder of type "fp16". + * + * @param methodComponentContext MethodComponentContext + * @return boolean value of "clip" parameter + */ + static boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) { + if (Objects.nonNull(methodComponentContext)) { + return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false); + } + return false; + } + + /** + * Extract MethodComponentContext from KNNMethodContext + * + * @param knnMethodContext KNNMethodContext + * @return MethodComponentContext + */ + static MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) { + if (Objects.isNull(knnMethodContext)) { + return null; + } + return knnMethodContext.getMethodComponentContext(); + } + + static KNNMethodContext createKNNMethodContextFromLegacy( + Mapper.BuilderContext context, + VectorDataType vectorDataType, + Version indexCreatedVersion + ) { + if (VectorDataType.FLOAT != vectorDataType) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "[%s] field with value [%s] is not supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue(), + NMSLIB_NAME + ) + ); + } + + return new KNNMethodContext( + KNNEngine.NMSLIB, + KNNVectorFieldMapperUtil.getSpaceType(context.indexSettings(), vectorDataType), + new MethodComponentContext( + METHOD_HNSW, + Map.of( + METHOD_PARAMETER_M, + KNNVectorFieldMapperUtil.getM(context.indexSettings()), + METHOD_PARAMETER_EF_CONSTRUCTION, + KNNVectorFieldMapperUtil.getEfConstruction(context.indexSettings(), indexCreatedVersion) + ) + ) + ); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index 8c3815c5f9..0fbc569f77 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -9,7 +9,6 @@ import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; -import org.opensearch.common.Nullable; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.TextSearchInfo; @@ -17,9 +16,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; import org.opensearch.knn.index.KNNVectorIndexFieldData; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; @@ -27,7 +24,6 @@ import java.util.Map; import java.util.function.Supplier; -import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; /** @@ -35,49 +31,21 @@ */ @Getter public class KNNVectorFieldType extends MappedFieldType { - int dimension; - String modelId; - KNNMethodContext knnMethodContext; + KNNMappingConfig knnMappingConfig; VectorDataType vectorDataType; - SpaceType spaceType; - public KNNVectorFieldType(String name, Map meta, int dimension, VectorDataType vectorDataType, SpaceType spaceType) { - this(name, meta, dimension, null, null, vectorDataType, spaceType); - } - - public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { - this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType()); - } - - public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { - this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null); - } - - public KNNVectorFieldType( - String name, - Map meta, - int dimension, - KNNMethodContext knnMethodContext, - VectorDataType vectorDataType - ) { - this(name, meta, dimension, knnMethodContext, null, vectorDataType, knnMethodContext.getSpaceType()); - } - - public KNNVectorFieldType( - String name, - Map meta, - int dimension, - @Nullable KNNMethodContext knnMethodContext, - @Nullable String modelId, - VectorDataType vectorDataType, - @Nullable SpaceType spaceType - ) { - super(name, false, false, true, TextSearchInfo.NONE, meta); - this.dimension = dimension; - this.modelId = modelId; - this.knnMethodContext = knnMethodContext; + /** + * Constructor for KNNVectorFieldType. + * + * @param name name of the field + * @param metadata metadata of the field + * @param vectorDataType data type of the vector + * @param annConfig configuration context for the ANN index + */ + public KNNVectorFieldType(String name, Map metadata, VectorDataType vectorDataType, KNNMappingConfig annConfig) { + super(name, false, false, true, TextSearchInfo.NONE, metadata); this.vectorDataType = vectorDataType; - this.spaceType = spaceType; + this.knnMappingConfig = annConfig; } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java deleted file mode 100644 index cf5ec933a1..0000000000 --- a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.mapper; - -import lombok.extern.log4j.Log4j2; -import org.apache.lucene.document.FieldType; -import org.opensearch.Version; -import org.opensearch.common.Explicit; -import org.opensearch.common.settings.Settings; -import org.opensearch.index.mapper.ParametrizedFieldMapper; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.util.IndexHyperParametersUtil; -import org.opensearch.knn.index.engine.KNNEngine; - -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; -import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; - -/** - * Field mapper for original implementation. It defaults to using nmslib as the engine and retrieves parameters from index settings. - * - * Example of this mapper output: - * - * { - * "type": "knn_vector", - * "dimension": 128 - * } - */ -@Log4j2 -public class LegacyFieldMapper extends KNNVectorFieldMapper { - - protected String spaceType; - protected String m; - protected String efConstruction; - - LegacyFieldMapper( - String simpleName, - KNNVectorFieldType mappedFieldType, - MultiFields multiFields, - CopyTo copyTo, - Explicit ignoreMalformed, - boolean stored, - boolean hasDocValues, - String spaceType, - String m, - String efConstruction, - Version indexCreatedVersion - ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion); - - this.spaceType = spaceType; - this.m = m; - this.efConstruction = efConstruction; - - this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - - this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension)); - this.fieldType.putAttribute(SPACE_TYPE, spaceType); - this.fieldType.putAttribute(KNN_ENGINE, KNNEngine.NMSLIB.getName()); - - // These are extra just for legacy - this.fieldType.putAttribute(HNSW_ALGO_M, m); - this.fieldType.putAttribute(HNSW_ALGO_EF_CONSTRUCTION, efConstruction); - - this.fieldType.freeze(); - } - - @Override - public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder(simpleName(), this.spaceType, this.m, this.efConstruction, this.indexCreatedVersion).init( - this - ); - } - - static String getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) { - String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); - if (spaceType == null) { - spaceType = VectorDataType.BINARY == vectorDataType - ? KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY - : KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; - log.info( - String.format( - "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", - METHOD_PARAMETER_SPACE_TYPE, - spaceType - ) - ); - } - return spaceType; - } - - static String getM(Settings indexSettings) { - String m = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING.getKey()); - if (m == null) { - log.info( - String.format( - "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", - HNSW_ALGO_M, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M - ) - ); - return String.valueOf(KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M); - } - return m; - } - - static String getEfConstruction(Settings indexSettings, Version indexVersion) { - final String efConstruction = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING.getKey()); - if (efConstruction == null) { - final String defaultEFConstructionValue = String.valueOf(IndexHyperParametersUtil.getHNSWEFConstructionValue(indexVersion)); - log.info( - String.format( - "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. " - + "Picking up default value for the index =%s", - HNSW_ALGO_EF_CONSTRUCTION, - defaultEFConstructionValue - ) - ); - return defaultEFConstructionValue; - } - return efConstruction; - } -} diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index c82afb9e72..c375e294aa 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -8,6 +8,9 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Map; +import java.util.Optional; + import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NonNull; @@ -16,11 +19,12 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.VectorSimilarityFunction; +import org.opensearch.Version; import org.opensearch.common.Explicit; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; @@ -35,25 +39,55 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { private final FieldType vectorFieldType; private final VectorDataType vectorDataType; - LuceneFieldMapper(final CreateLuceneFieldMapperInput input) { + private PerDimensionProcessor perDimensionProcessor; + private PerDimensionValidator perDimensionValidator; + private VectorValidator vectorValidator; + + static LuceneFieldMapper createFieldMapper( + String fullname, + Map metaValue, + VectorDataType vectorDataType, + Integer dimension, + KNNMethodContext knnMethodContext, + CreateLuceneFieldMapperInput createLuceneFieldMapperInput + ) { + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() { + @Override + public Optional getKnnMethodContext() { + return Optional.of(knnMethodContext); + } + + @Override + public Optional getDimension() { + return Optional.of(dimension); + } + }); + + return new LuceneFieldMapper(mappedFieldType, createLuceneFieldMapperInput); + } + + private LuceneFieldMapper(final KNNVectorFieldType mappedFieldType, final CreateLuceneFieldMapperInput input) { super( input.getName(), - input.getMappedFieldType(), + mappedFieldType, input.getMultiFields(), input.getCopyTo(), input.getIgnoreMalformed(), input.isStored(), input.isHasDocValues(), - input.getKnnMethodContext().getMethodComponentContext().getIndexVersion() + input.getIndexVersion(), + mappedFieldType.knnMappingConfig.getKnnMethodContext().orElse(null) ); - + KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); + KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() + .orElseThrow(() -> new IllegalArgumentException("KNN method context is missing")); vectorDataType = input.getVectorDataType(); - this.knnMethod = input.getKnnMethodContext(); - final VectorSimilarityFunction vectorSimilarityFunction = this.knnMethod.getSpaceType() + + final VectorSimilarityFunction vectorSimilarityFunction = knnMethodContext.getSpaceType() .getKnnVectorSimilarityFunction() .getVectorSimilarityFunction(); - final int dimension = input.getMappedFieldType().getDimension(); + final int dimension = knnMappingConfig.getDimension().orElseThrow(() -> new IllegalArgumentException("Dimension is missing")); if (dimension > KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE)) { throw new IllegalArgumentException( String.format( @@ -69,10 +103,13 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { this.fieldType = vectorDataType.createKnnVectorFieldType(dimension, vectorSimilarityFunction); if (this.hasDocValues) { - this.vectorFieldType = buildDocValuesFieldType(this.knnMethod.getKnnEngine()); + this.vectorFieldType = buildDocValuesFieldType(knnMethodContext.getKnnEngine()); } else { this.vectorFieldType = null; } + + initValidatorsAndProcessors(knnMethodContext); + knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); } @Override @@ -105,6 +142,36 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fieldsToBeAdded; } + private void initValidatorsAndProcessors(KNNMethodContext knnMethodContext) { + this.vectorValidator = new SpaceVectorValidator(knnMethodContext.getSpaceType()); + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + if (VectorDataType.BINARY == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + return; + } + + if (VectorDataType.BYTE == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + return; + } + this.perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + } + + @Override + protected VectorValidator getVectorValidator() { + return vectorValidator; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return perDimensionProcessor; + } + @Override void updateEngineStats() { KNNEngine.LUCENE.setInitialized(true); @@ -117,8 +184,6 @@ static class CreateLuceneFieldMapperInput { @NonNull String name; @NonNull - KNNVectorFieldType mappedFieldType; - @NonNull MultiFields multiFields; @NonNull CopyTo copyTo; @@ -127,7 +192,7 @@ static class CreateLuceneFieldMapperInput { boolean stored; boolean hasDocValues; VectorDataType vectorDataType; - @NonNull - KNNMethodContext knnMethodContext; + Version indexVersion; + KNNMethodContext originalKnnMethodContext; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index b15ab14894..b38fc70206 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -6,37 +6,64 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; import java.io.IOException; import java.util.Map; +import java.util.Optional; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.getMethodComponentContext; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQClipToFP16RangeEnabled; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQfp16; /** * Field mapper for method definition in mapping */ public class MethodFieldMapper extends KNNVectorFieldMapper { - MethodFieldMapper( + private PerDimensionProcessor perDimensionProcessor; + private PerDimensionValidator perDimensionValidator; + private VectorValidator vectorValidator; + + public static MethodFieldMapper createFieldMapper( + String fullname, String simpleName, - KNNVectorFieldType mappedFieldType, + Map metaValue, + VectorDataType vectorDataType, + Integer dimension, + KNNMethodContext knnMethodContext, + KNNMethodContext originalKNNMethodContext, MultiFields multiFields, CopyTo copyTo, Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - KNNMethodContext knnMethodContext + Version indexCreatedVersion ) { + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() { + @Override + public Optional getKnnMethodContext() { + return Optional.of(knnMethodContext); + } - super( + @Override + public Optional getDimension() { + return Optional.of(dimension); + } + }); + return new MethodFieldMapper( simpleName, mappedFieldType, multiFields, @@ -44,14 +71,43 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { ignoreMalformed, stored, hasDocValues, - knnMethodContext.getMethodComponentContext().getIndexVersion() + indexCreatedVersion, + originalKNNMethodContext ); + } - this.knnMethod = knnMethodContext; + private MethodFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + Version indexVerision, + KNNMethodContext originalKNNMethodContext + ) { + super( + simpleName, + mappedFieldType, + multiFields, + copyTo, + ignoreMalformed, + stored, + hasDocValues, + indexVerision, + originalKNNMethodContext + ); + KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); + KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext() + .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension)); + this.fieldType.putAttribute( + DIMENSION, + String.valueOf(annConfig.getDimension().orElseThrow(() -> new IllegalArgumentException("Dimension cannot be empty"))) + ); this.fieldType.putAttribute(SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); @@ -66,5 +122,57 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { } this.fieldType.freeze(); + initValidatorsAndProcessors(knnMethodContext); + knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); + } + + private void initValidatorsAndProcessors(KNNMethodContext knnMethodContext) { + this.vectorValidator = new SpaceVectorValidator(knnMethodContext.getSpaceType()); + + if (VectorDataType.BINARY == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + if (VectorDataType.BYTE == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + MethodComponentContext methodComponentContext = getMethodComponentContext(knnMethodContext); + if (!isFaissSQfp16(methodComponentContext)) { + // Normal float and byte processor + this.perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + this.perDimensionValidator = PerDimensionValidator.DEFAULT_FP16_VALIDATOR; + + if (!isFaissSQClipToFP16RangeEnabled( + (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) + )) { + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + this.perDimensionProcessor = PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR; + } + + @Override + protected VectorValidator getVectorValidator() { + return vectorValidator; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return perDimensionProcessor; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index adaaef28e6..1fdd4fb3b6 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -9,46 +9,112 @@ import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQClipToFP16RangeEnabled; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQfp16; /** * Field mapper for model in mapping */ public class ModelFieldMapper extends KNNVectorFieldMapper { - ModelFieldMapper( + // If the dimension has not yet been set because we do not have access to model metadata, it will be -1 + public static final int UNSET_MODEL_DIMENSION_IDENTIFIER = -1; + + private final AtomicReference spaceType; + private final AtomicReference methodComponentContext; + private final AtomicInteger dimension; + private final AtomicReference vectorDataType; + + private final AtomicReference perDimensionProcessor; + private final AtomicReference perDimensionValidator; + private final AtomicReference vectorValidator; + + private final String modelId; + + public static ModelFieldMapper createFieldMapper( + String fullname, String simpleName, - KNNVectorFieldType mappedFieldType, + Map metaValue, + VectorDataType vectorDataType, + String modelId, MultiFields multiFields, CopyTo copyTo, Explicit ignoreMalformed, boolean stored, boolean hasDocValues, ModelDao modelDao, - String modelId, Version indexCreatedVersion ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion); - this.modelId = modelId; + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() { + @Override + public Optional getModelId() { + return Optional.of(modelId); + } + }); + return new ModelFieldMapper( + simpleName, + mappedFieldType, + multiFields, + copyTo, + ignoreMalformed, + stored, + hasDocValues, + modelDao, + indexCreatedVersion + ); + } + + private ModelFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + ModelDao modelDao, + Version indexCreatedVersion + ) { + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null); + KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); + modelId = annConfig.getModelId().orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); this.modelDao = modelDao; + this.spaceType = new AtomicReference<>(null); + this.methodComponentContext = new AtomicReference<>(null); + this.dimension = new AtomicInteger(UNSET_MODEL_DIMENSION_IDENTIFIER); + this.vectorDataType = new AtomicReference<>(null); + this.perDimensionProcessor = new AtomicReference<>(null); + this.perDimensionValidator = new AtomicReference<>(null); + this.vectorValidator = new AtomicReference<>(null); + this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.putAttribute(MODEL_ID, modelId); this.fieldType.freeze(); } @Override - protected void parseCreateField(ParseContext context) throws IOException { + protected void validatePreparse() { + super.validatePreparse(); // For the model field mapper, we cannot validate the model during index creation due to // an issue with reading cluster state during mapper creation. So, we need to validate the - // model when ingestion starts. + // model when ingestion starts. We do this as lazily as we can ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) { @@ -56,18 +122,76 @@ protected void parseCreateField(ParseContext context) throws IOException { String.format( "Model \"%s\" from %s's mapping is not created. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.", modelId, - context.mapperService().index().getName(), + simpleName(), MODEL_ID ) ); } - parseCreateField( - context, - modelMetadata.getDimension(), - modelMetadata.getSpaceType(), - modelMetadata.getMethodComponentContext(), - modelMetadata.getVectorDataType() - ); + maybeInitLazyVariables(modelMetadata); + } + + private void maybeInitLazyVariables(ModelMetadata modelMetadata) { + vectorDataType.compareAndExchange(null, modelMetadata.getVectorDataType()); + if (spaceType.get() == null) { + spaceType.compareAndExchange(null, modelMetadata.getSpaceType()); + spaceType.get().validateVectorDataType(vectorDataType.get()); + } + methodComponentContext.compareAndExchange(null, modelMetadata.getMethodComponentContext()); + dimension.compareAndExchange(UNSET_MODEL_DIMENSION_IDENTIFIER, modelMetadata.getDimension()); + maybeInitValidatorsAndProcessors(); + } + + private void maybeInitValidatorsAndProcessors() { + this.vectorValidator.compareAndExchange(null, new SpaceVectorValidator(spaceType.get())); + + if (VectorDataType.BINARY == vectorDataType.get()) { + this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_BIT_VALIDATOR); + this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR); + return; + } + + if (VectorDataType.BYTE == vectorDataType.get()) { + this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_BYTE_VALIDATOR); + this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR); + return; + } + + if (!isFaissSQfp16(methodComponentContext.get())) { + // Normal float and byte processor + this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR); + this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR); + return; + } + + this.perDimensionValidator.compareAndExchange(null, PerDimensionValidator.DEFAULT_FP16_VALIDATOR); + if (!isFaissSQClipToFP16RangeEnabled( + (MethodComponentContext) methodComponentContext.get().getParameters().get(METHOD_ENCODER_PARAMETER) + )) { + this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.NOOP_PROCESSOR); + return; + } + this.perDimensionProcessor.compareAndExchange(null, PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR); + } + + @Override + protected VectorValidator getVectorValidator() { + return vectorValidator.get(); + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator.get(); + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return perDimensionProcessor.get(); + } + + @Override + protected void parseCreateField(ParseContext context) throws IOException { + validatePreparse(); + parseCreateField(context, dimension.get(), vectorDataType.get()); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java new file mode 100644 index 0000000000..21139f2ad4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; + +/** + * Process values per dimension. Good to have if we want to do some kind of cleanup on data as it is coming in. + */ +public interface PerDimensionProcessor { + + /** + * Process float value per dimension. + * + * @param value value to process + * @return processed value + */ + default float process(float value) { + return value; + } + + /** + * Process byte as float value per dimension. + * + * @param value value to process + * @return processed value + */ + default float processByte(float value) { + return value; + } + + PerDimensionProcessor NOOP_PROCESSOR = new PerDimensionProcessor() { + }; + + // If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be + // clipped to FP16 range. + PerDimensionProcessor CLIP_TO_FP16_PROCESSOR = new PerDimensionProcessor() { + @Override + public float process(float value) { + return clipVectorValueToFP16Range(value); + } + + @Override + public float processByte(float value) { + throw new IllegalStateException("CLIP_TO_FP16_PROCESSOR should not be called with byte type"); + } + }; +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java new file mode 100644 index 0000000000..2ca0761c02 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.index.VectorDataType; + +import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; +import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; + +/** + * Validates per dimension fields + */ +public interface PerDimensionValidator { + /** + * Validates the given float is valid for the configuration + * + * @param value to validate + */ + default void validate(float value) {} + + /** + * Validates the given float as a byte is valid for the configuration. + * + * @param value to validate + */ + default void validateByte(float value) {} + + PerDimensionValidator DEFAULT_FLOAT_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + validateFloatVectorValue(value); + } + + @Override + public void validateByte(float value) { + throw new IllegalStateException("DEFAULT_FLOAT_VALIDATOR should only be used for float vectors"); + } + }; + + // Validates if it is a finite number and within the fp16 range of [-65504 to 65504]. + PerDimensionValidator DEFAULT_FP16_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + validateFP16VectorValue(value); + } + + @Override + public void validateByte(float value) { + throw new IllegalStateException("DEFAULT_FP16_VALIDATOR should only be used for float vectors"); + } + }; + + PerDimensionValidator DEFAULT_BYTE_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + throw new IllegalStateException("DEFAULT_BYTE_VALIDATOR should only be used for byte values"); + } + + @Override + public void validateByte(float value) { + validateByteVectorValue(value, VectorDataType.BYTE); + } + }; + + PerDimensionValidator DEFAULT_BIT_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + throw new IllegalStateException("DEFAULT_BIT_VALIDATOR should only be used for byte values"); + } + + @Override + public void validateByte(float value) { + validateByteVectorValue(value, VectorDataType.BINARY); + } + }; +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/SpaceVectorValidator.java b/src/main/java/org/opensearch/knn/index/mapper/SpaceVectorValidator.java new file mode 100644 index 0000000000..6ff088604c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/SpaceVectorValidator.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.AllArgsConstructor; +import org.opensearch.knn.index.SpaceType; + +/** + * Confirms that a given vector is valid for the provided space type + */ +@AllArgsConstructor +public class SpaceVectorValidator implements VectorValidator { + + private final SpaceType spaceType; + + @Override + public void validateVector(byte[] vector) { + spaceType.validateVector(vector); + } + + @Override + public void validateVector(float[] vector) { + spaceType.validateVector(vector); + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/VectorValidator.java b/src/main/java/org/opensearch/knn/index/mapper/VectorValidator.java new file mode 100644 index 0000000000..f4253ae373 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorValidator.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +/** + * Class validates vector after it has been parsed + */ +public interface VectorValidator { + /** + * Validate if the given byte vector is supported + * + * @param vector the given vector + */ + default void validateVector(byte[] vector) {} + + /** + * Validate if the given float vector is supported + * + * @param vector the given vector + */ + default void validateVector(float[] vector) {} + + VectorValidator NOOP_VECTOR_VALIDATOR = new VectorValidator() { + }; +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 6d57cb2dd5..744800a09c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -24,9 +24,9 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.engine.model.QueryContext; +import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.util.IndexUtil; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -43,6 +43,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; @@ -341,36 +342,54 @@ protected Query doToQuery(QueryShardContext context) { if (!(mappedFieldType instanceof KNNVectorFieldType)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not knn_vector type.", this.fieldName)); } - KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) mappedFieldType; - int fieldDimension = knnVectorFieldType.getDimension(); - KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext(); - MethodComponentContext methodComponentContext = null; - KNNEngine knnEngine = KNNEngine.DEFAULT; - VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType(); - SpaceType spaceType = knnVectorFieldType.getSpaceType(); + KNNMappingConfig knnMappingConfig = knnVectorFieldType.getKnnMappingConfig(); + final AtomicReference queryConfigFromMapping = new AtomicReference<>(); + knnMappingConfig.getKnnMethodContext() + .ifPresentOrElse( + knnMethodContext -> queryConfigFromMapping.set( + new QueryConfigFromMapping( + knnMappingConfig.getDimension() + .orElseThrow( + () -> new IllegalStateException( + String.format(Locale.ROOT, "Field '%s' does not have dimension set when it should.", this.fieldName) + ) + ), + knnMethodContext.getKnnEngine(), + knnMethodContext.getMethodComponentContext(), + knnMethodContext.getSpaceType(), + knnVectorFieldType.getVectorDataType() + ) + ), + () -> knnMappingConfig.getModelId().ifPresentOrElse(modelId -> { + ModelMetadata modelMetadata = getModelMetadataForField(modelId); + queryConfigFromMapping.set( + new QueryConfigFromMapping( + modelMetadata.getDimension(), + modelMetadata.getKnnEngine(), + modelMetadata.getMethodComponentContext(), + modelMetadata.getSpaceType(), + modelMetadata.getVectorDataType() + ) + ); + }, + () -> { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Field '%s' is not built for ANN search.", this.fieldName) + ); + } + ) + ); + int fieldDimension = queryConfigFromMapping.get().getFieldDimension(); + KNNEngine knnEngine = queryConfigFromMapping.get().getKnnEngine(); + MethodComponentContext methodComponentContext = queryConfigFromMapping.get().getMethodComponentContext(); + SpaceType spaceType = queryConfigFromMapping.get().getSpaceType(); + VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType(); + VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); updateQueryStats(vectorQueryType); - if (fieldDimension == -1) { - if (spaceType != null) { - throw new IllegalStateException("Space type should be null when the field uses a model"); - } - // If dimension is not set, the field uses a model and the information needs to be retrieved from there - ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); - fieldDimension = modelMetadata.getDimension(); - knnEngine = modelMetadata.getKnnEngine(); - spaceType = modelMetadata.getSpaceType(); - methodComponentContext = modelMetadata.getMethodComponentContext(); - vectorDataType = modelMetadata.getVectorDataType(); - - } else if (knnMethodContext != null) { - // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping - knnEngine = knnMethodContext.getKnnEngine(); - spaceType = knnMethodContext.getSpaceType(); - methodComponentContext = knnMethodContext.getMethodComponentContext(); - } - + // This could be null in the case of when a model did not have serialized methodComponent information final String method = methodComponentContext != null ? methodComponentContext.getName() : null; if (StringUtils.isNotBlank(method)) { final KNNLibrarySearchContext engineSpecificMethodContext = knnEngine.getKNNLibrarySearchContext(method); @@ -492,13 +511,7 @@ protected Query doToQuery(QueryShardContext context) { throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k or distance or score to be set", NAME)); } - private ModelMetadata getModelMetadataForField(KNNVectorFieldType knnVectorField) { - String modelId = knnVectorField.getModelId(); - - if (modelId == null) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' does not have model.", this.fieldName)); - } - + private ModelMetadata getModelMetadataForField(String modelId) { ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId)); @@ -568,4 +581,14 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws I } return super.doRewrite(queryShardContext); } + + @Getter + @AllArgsConstructor + private static class QueryConfigFromMapping { + private final int fieldDimension; + private final KNNEngine knnEngine; + private final MethodComponentContext methodComponentContext; + private final SpaceType spaceType; + private final VectorDataType vectorDataType; + } } diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 56c129546f..f5481e9b6a 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -7,12 +7,18 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.Version; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNLibrarySearchContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.core.common.bytes.BytesReference; @@ -20,12 +26,15 @@ import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.test.OpenSearchTestCase; +import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; /** * Base class for integration tests for KNN plugin. Contains several methods for testing KNN ES functionality. @@ -91,4 +100,51 @@ private void initKNNSettings() { public Map xContentBuilderToMap(XContentBuilder xContentBuilder) { return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2(); } + + public static KNNMethodContext getDefaultKNNMethodContext() { + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodContext defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); + methodComponentContext.setIndexVersion(Version.CURRENT); + return defaultInstance; + } + + public static KNNMethodContext getDefaultBinaryKNNMethodContext() { + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodContext defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT_BINARY, methodComponentContext); + methodComponentContext.setIndexVersion(Version.CURRENT); + return defaultInstance; + } + + public static KNNMappingConfig getMappingConfigForMethodMapping(KNNMethodContext knnMethodContext, int dimension) { + return new KNNMappingConfig() { + @Override + public Optional getKnnMethodContext() { + return Optional.of(knnMethodContext); + } + + @Override + public Optional getDimension() { + return Optional.of(dimension); + } + }; + } + + public static KNNMappingConfig getMappingConfigForFlatMapping(int dimension) { + + return new KNNMappingConfig() { + @Override + public Optional getDimension() { + return Optional.of(dimension); + } + }; + } + + public static KNNMappingConfig getMappingConfigForModelMapping(String modelId) { + return new KNNMappingConfig() { + @Override + public Optional getModelId() { + return Optional.of(modelId); + } + }; + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java index 9a867c58dd..f71fbaae0f 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java @@ -94,7 +94,7 @@ public void testGetSpaceType() { */ public void testValidate() { // Check valid default - this should not throw any exception - assertNull(KNNMethodContext.getDefault().validate()); + assertNull(getDefaultKNNMethodContext().validate()); // Check a valid nmslib method MethodComponentContext hnswMethod = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index e1ebc57085..e87531561a 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -251,62 +251,6 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } - public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException { - // Set information about the segment and the fields - String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); - int docsInSegment = 100; - String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); - - KNNEngine knnEngine = KNNEngine.NMSLIB; - SpaceType spaceType = SpaceType.COSINESIMIL; - int dimension = 16; - - SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() - .directory(directory) - .segmentName(segmentName) - .docsInSegment(docsInSegment) - .codec(codec) - .build(); - - FieldInfo[] fieldInfoArray = new FieldInfo[] { - KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) - .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") - .addAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512") - .addAttribute(KNNConstants.HNSW_ALGO_M, "16") - .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) - .build() }; - - FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); - SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); - long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); - - // Add documents to the field - KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); - TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( - docsInSegment, - dimension - ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); - - // The document should be created in the correct location - String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); - assertFileInCorrectLocation(state, expectedFile); - - // The footer should be valid - assertValidFooter(state.directory, expectedFile); - - // The document should be readable by nmslib - assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); - - // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); - assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); - } - public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException { String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); int docsInSegment = 100; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index a0b9b32d0e..00cc2b167c 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -14,8 +14,10 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; +import org.opensearch.Version; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; @@ -56,6 +58,7 @@ import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -77,6 +80,8 @@ import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; @@ -86,14 +91,28 @@ public class KNNCodecTestCase extends KNNTestCase { private static final Codec ACTUAL_CODEC = KNNCodecVersion.current().getDefaultKnnCodecSupplier().get(); - private static FieldType sampleFieldType; + private static final FieldType sampleFieldType; static { + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); + String parameterString; + try { + parameterString = XContentFactory.jsonBuilder() + .map(knnMethodContext.getKnnEngine().getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()) + .toString(); + } catch (IOException e) { + throw new RuntimeException(e); + } + sampleFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - sampleFieldType.putAttribute(KNNConstants.KNN_METHOD, KNNConstants.METHOD_HNSW); - sampleFieldType.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()); - sampleFieldType.putAttribute(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()); - sampleFieldType.putAttribute(KNNConstants.HNSW_ALGO_M, "32"); - sampleFieldType.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512"); + sampleFieldType.putAttribute(KNNVectorFieldMapper.KNN_FIELD, "true"); + sampleFieldType.putAttribute(KNNConstants.KNN_ENGINE, knnMethodContext.getKnnEngine().getName()); + sampleFieldType.putAttribute(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); + sampleFieldType.putAttribute(KNNConstants.PARAMETERS, parameterString); sampleFieldType.freeze(); } private static final String FIELD_NAME_ONE = "test_vector_one"; @@ -309,8 +328,19 @@ public void testKnnVectorIndex( SpaceType.L2, new MethodComponentContext(METHOD_HNSW, Map.of(HNSW_ALGO_M, 16, HNSW_ALGO_EF_CONSTRUCTION, 256)) ); - final KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldType(FIELD_NAME_ONE, Map.of(), 3, knnMethodContext); - final KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldType(FIELD_NAME_TWO, Map.of(), 2, knnMethodContext); + + final KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) + ); + final KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 2) + ); when(mapperService.fieldType(eq(FIELD_NAME_ONE))).thenReturn(mappedFieldType1); when(mapperService.fieldType(eq(FIELD_NAME_TWO))).thenReturn(mappedFieldType2); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index c95568be22..f06ff79353 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -103,7 +103,7 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { public void testBuilder_getParameters() { String fieldName = "test-field-name"; ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT, null); assertEquals(7, builder.getParameters().size()); List actualParams = builder.getParameters().stream().map(a -> a.name).collect(Collectors.toList()); @@ -114,7 +114,7 @@ public void testBuilder_getParameters() { public void testBuilder_build_fromKnnMethodContext() { // Check that knnMethodContext takes precedent over both model and legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -126,6 +126,7 @@ public void testBuilder_build_fromKnnMethodContext() { .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) .put(KNNSettings.KNN_ALGO_PARAM_M, m) .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .put(KNN_INDEX, true) .build(); builder.knnMethodContext.setValue( @@ -139,19 +140,17 @@ public void testBuilder_build_fromKnnMethodContext() { ) ); - builder.modelId.setValue("Random modelId"); - Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); - assertNotNull(knnVectorFieldMapper.knnMethod); - assertNull(knnVectorFieldMapper.modelId); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); } public void testBuilder_build_fromModel() { // Check that modelContext takes precedent over legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -163,6 +162,7 @@ public void testBuilder_build_fromModel() { .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) .put(KNNSettings.KNN_ALGO_PARAM_M, m) .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .put(KNN_INDEX, true) .build(); String modelId = "Random modelId"; @@ -184,14 +184,14 @@ public void testBuilder_build_fromModel() { when(modelDao.getMetadata(modelId)).thenReturn(mockedModelMetadata); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof ModelFieldMapper); - assertNotNull(knnVectorFieldMapper.modelId); - assertNull(knnVectorFieldMapper.knnMethod); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isPresent()); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isEmpty()); } public void testBuilder_build_fromLegacy() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); int m = 17; int efConstruction = 17; @@ -201,37 +201,22 @@ public void testBuilder_build_fromLegacy() { .put(settings(CURRENT).build()) .put(KNNSettings.KNN_ALGO_PARAM_M, m) .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .put(KNN_INDEX, true) .build(); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper); - assertNull(knnVectorFieldMapper.modelId); - assertNull(knnVectorFieldMapper.knnMethod); - assertEquals(SpaceType.L2.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType); - } - - public void testBuilder_whenKnnFalseWithBinary_thenSetHammingAsDefault() { - // Check legacy is picked up if model context and method context are not set - ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); - builder.vectorDataType.setValue(VectorDataType.BINARY); - builder.dimension.setValue(8); - - // Setup settings - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); - - Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); - KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper); - assertEquals(SpaceType.HAMMING.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType); + assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); + assertEquals(SpaceType.L2, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()); } public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -317,7 +302,7 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidDimension() throws String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).put(KNN_INDEX, true).build()).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -616,7 +601,7 @@ public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOExcep String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -646,12 +631,18 @@ public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOExcep // merge with itself - should be successful KNNVectorFieldMapper knnVectorFieldMapperMerge1 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper1); - assertEquals(knnVectorFieldMapper1.knnMethod, knnVectorFieldMapperMerge1.knnMethod); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getKnnMethodContext().get(), + knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getKnnMethodContext().get() + ); // merge with another mapper of the same field with same context KNNVectorFieldMapper knnVectorFieldMapper2 = builder.build(builderContext); KNNVectorFieldMapper knnVectorFieldMapperMerge2 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper2); - assertEquals(knnVectorFieldMapper1.knnMethod, knnVectorFieldMapperMerge2.knnMethod); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getKnnMethodContext().get(), + knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getKnnMethodContext().get() + ); // merge with another mapper of the same field with different context xContentBuilder = XContentFactory.jsonBuilder() @@ -676,7 +667,7 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); String modelId = "test-id"; int dimension = 133; @@ -715,12 +706,18 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { // merge with itself - should be successful KNNVectorFieldMapper knnVectorFieldMapperMerge1 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper1); - assertEquals(knnVectorFieldMapper1.modelId, knnVectorFieldMapperMerge1.modelId); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getModelId().get(), + knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getModelId().get() + ); // merge with another mapper of the same field with same context KNNVectorFieldMapper knnVectorFieldMapper2 = builder.build(builderContext); KNNVectorFieldMapper knnVectorFieldMapperMerge2 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper2); - assertEquals(knnVectorFieldMapper1.modelId, knnVectorFieldMapperMerge2.modelId); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getModelId().get(), + knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getModelId().get() + ); // merge with another mapper of the same field with different context xContentBuilder = XContentFactory.jsonBuilder() @@ -754,18 +751,19 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { when(parseContext.doc()).thenReturn(document); when(parseContext.path()).thenReturn(contentPath); - LuceneFieldMapper luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); - doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper) - .getFloatsFromContext(parseContext, TEST_DIMENSION, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.FLOAT + LuceneFieldMapper luceneFieldMapper = Mockito.spy( + LuceneFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + Collections.emptyMap(), + VectorDataType.FLOAT, + TEST_DIMENSION, + getDefaultKNNMethodContext(), + inputBuilder.build() + ) ); + doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + doNothing().when(luceneFieldMapper).validatePreparse(); + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnVectorField List fields = document.getFields(); @@ -798,19 +796,26 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { when(parseContext.path()).thenReturn(contentPath); inputBuilder.hasDocValues(false); - luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); - doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper) - .getFloatsFromContext(parseContext, TEST_DIMENSION, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.FLOAT + + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.DEFAULT, + new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) ); + luceneFieldMapper = Mockito.spy( + LuceneFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + Collections.emptyMap(), + VectorDataType.FLOAT, + TEST_DIMENSION, + knnMethodContext, + inputBuilder.build() + ) + ); + doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + doNothing().when(luceneFieldMapper).validatePreparse(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); // Document should have 1 field: one for KnnVectorField fields = document.getFields(); @@ -834,19 +839,21 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { when(parseContext.doc()).thenReturn(document); when(parseContext.path()).thenReturn(contentPath); - LuceneFieldMapper luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); + LuceneFieldMapper luceneFieldMapper = Mockito.spy( + LuceneFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + Collections.emptyMap(), + VectorDataType.BYTE, + TEST_DIMENSION, + getDefaultKNNMethodContext(), + inputBuilder.build() + ) + ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) .getBytesFromContext(parseContext, TEST_DIMENSION, VectorDataType.BYTE); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.BYTE - ); + doNothing().when(luceneFieldMapper).validatePreparse(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.BYTE); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnByteVectorField List fields = document.getFields(); @@ -878,19 +885,21 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { when(parseContext.path()).thenReturn(contentPath); inputBuilder.hasDocValues(false); - luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); + luceneFieldMapper = Mockito.spy( + LuceneFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + Collections.emptyMap(), + VectorDataType.BYTE, + TEST_DIMENSION, + getDefaultKNNMethodContext(), + inputBuilder.build() + ) + ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) .getBytesFromContext(parseContext, TEST_DIMENSION, VectorDataType.BYTE); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.BYTE - ); + doNothing().when(luceneFieldMapper).validatePreparse(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.BYTE); // Document should have 1 field: one for KnnByteVectorField fields = document.getFields(); @@ -970,10 +979,10 @@ private void testBuilderWithBinaryDataType( String expectedErrMsg ) { ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); // Setup settings - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); builder.knnMethodContext.setValue( new KNNMethodContext(knnEngine, spaceType, new MethodComponentContext(method, Collections.emptyMap())) @@ -986,7 +995,10 @@ private void testBuilderWithBinaryDataType( KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); if (SpaceType.UNDEFINED == spaceType) { - assertEquals(SpaceType.HAMMING, knnVectorFieldMapper.fieldType().spaceType); + assertEquals( + SpaceType.HAMMING, + knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType() + ); } } else { Exception ex = expectThrows(Exception.class, () -> builder.build(builderContext)); @@ -996,10 +1008,10 @@ private void testBuilderWithBinaryDataType( public void testBuilder_whenBinaryFaissHNSWWithSQ_thenException() { ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); // Setup settings - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); builder.knnMethodContext.setValue( new KNNMethodContext( @@ -1022,7 +1034,7 @@ public void testBuilder_whenBinaryFaissHNSWWithSQ_thenException() { public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1031,13 +1043,13 @@ public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper); + assertTrue(knnVectorFieldMapper instanceof FlatVectorFieldMapper); } public void testBuilder_whenBinaryWithLegacyKNNEnabled_thenException() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1052,29 +1064,14 @@ public void testBuilder_whenBinaryWithLegacyKNNEnabled_thenException() { private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder( VectorDataType vectorDataType ) { - KNNMethodContext knnMethodContext = new KNNMethodContext( - KNNEngine.LUCENE, - SpaceType.DEFAULT, - new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) - ); - - KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType( - TEST_FIELD_NAME, - Collections.emptyMap(), - TEST_DIMENSION, - knnMethodContext, - vectorDataType - ); - return LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() .name(TEST_FIELD_NAME) - .mappedFieldType(knnVectorFieldType) .multiFields(FieldMapper.MultiFields.empty()) .copyTo(FieldMapper.CopyTo.empty()) .hasDocValues(true) .vectorDataType(vectorDataType) .ignoreMalformed(new Explicit<>(true, true)) - .knnMethodContext(knnMethodContext); + .originalKnnMethodContext(getDefaultKNNMethodContext()); } private static float[] createInitializedFloatArray(int dimension, float value) { diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index 31da12d669..58a30472c3 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -61,16 +61,19 @@ public void testStoredFields_whenVectorIsFloatType_thenSucceed() { public void testGetExpectedVectorLengthSuccess() { KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); - + when(knnVectorFieldType.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 3)); KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeBinary.getDimension()).thenReturn(8); + when(knnVectorFieldTypeBinary.getKnnMappingConfig()).thenReturn( + getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 8) + ); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); + when(knnVectorFieldTypeModelBased.getKnnMappingConfig()).thenReturn( + getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 8) + ); String modelId = "test-model"; - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); + when(knnVectorFieldTypeModelBased.getKnnMappingConfig()).thenReturn(getMappingConfigForModelMapping(modelId)); ModelDao modelDao = mock(ModelDao.class); ModelMetadata modelMetadata = mock(ModelMetadata.class); @@ -87,9 +90,8 @@ public void testGetExpectedVectorLengthSuccess() { public void testGetExpectedVectorLengthFailure() { KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); String modelId = "test-model"; - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); + when(knnVectorFieldTypeModelBased.getKnnMappingConfig()).thenReturn(getMappingConfigForModelMapping(modelId)); ModelDao modelDao = mock(ModelDao.class); ModelMetadata modelMetadata = mock(ModelMetadata.class); @@ -103,20 +105,6 @@ public void testGetExpectedVectorLengthFailure() { () -> KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldTypeModelBased) ); assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage()); - - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null); - KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - MethodComponentContext methodComponentContext = mock(MethodComponentContext.class); - String fieldName = "test-field"; - when(methodComponentContext.getName()).thenReturn(fieldName); - when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); - when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext); - - e = expectThrows( - IllegalArgumentException.class, - () -> KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldTypeModelBased) - ); - assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage()); } public void testValidateVectorDataType_whenBinaryFaissHNSW_thenValid() { diff --git a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java index dcd2557405..faae3e35d2 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java @@ -5,33 +5,35 @@ package org.opensearch.knn.index.mapper; -import junit.framework.TestCase; +import org.opensearch.Version; import org.opensearch.index.mapper.FieldMapper; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodContext; import java.util.Collections; -public class MethodFieldMapperTests extends TestCase { - public void testMethodFieldMapper_whenVectorDataTypeIsGiven_thenSetItInFieldType() { - KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( - "testField", - Collections.emptyMap(), - 1, - VectorDataType.BINARY, - SpaceType.HAMMING - ); - MethodFieldMapper mappers = new MethodFieldMapper( - "simpleName", - mappedFieldType, - null, - new FieldMapper.CopyTo.Builder().build(), - KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED, - true, - true, - KNNMethodContext.getDefault() +public class MethodFieldMapperTests extends KNNTestCase { + public void testMethodFieldMapper_whenVectorDataTypeAndContextMismatch_thenThrow() { + // Expect that we cannot create the mapper with an invalid field type + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + expectThrows( + IllegalArgumentException.class, + () -> MethodFieldMapper.createFieldMapper( + "testField", + "simpleName", + Collections.emptyMap(), + VectorDataType.BINARY, + 1, + knnMethodContext, + knnMethodContext, + null, + new FieldMapper.CopyTo.Builder().build(), + KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED, + true, + true, + Version.CURRENT + ) ); - assertEquals(VectorDataType.BINARY, mappers.fieldType().vectorDataType); } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 0b918bd9ed..75816eb825 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -185,9 +185,8 @@ public void testDoToQuery_Normal() throws Exception { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertEquals(knnQueryBuilder.getK(), query.getK()); @@ -207,7 +206,6 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( @@ -215,7 +213,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); float resultSimilarity = KNNEngine.LUCENE.distanceToRadialThreshold(MAX_DISTANCE, SpaceType.L2); @@ -239,7 +237,6 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( @@ -247,7 +244,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertTrue(query.toString().contains("resultSimilarity=" + 0.5f)); } @@ -266,16 +263,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSuppor QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -298,16 +293,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -325,16 +318,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSuppor QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -352,16 +343,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -383,16 +372,14 @@ public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSu QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -416,16 +403,14 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -444,7 +429,6 @@ public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(8); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( @@ -452,7 +436,7 @@ public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() { ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.HAMMING, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 8)); Exception e = expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); assertTrue(e.getMessage().contains("Binary data type does not support radial search")); } @@ -470,15 +454,13 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); // When @@ -504,14 +486,13 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); assertNotNull(query); @@ -531,14 +512,13 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); assertNotNull(query); @@ -553,15 +533,13 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); // When @@ -586,10 +564,12 @@ public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParamete QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - when(mockKNNVectorField.getDimension()).thenReturn(4); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.LUCENE, SpaceType.COSINESIMIL, new MethodComponentContext("hnsw", Map.of())) + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.COSINESIMIL, + new MethodComponentContext("hnsw", Map.of()) ); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() @@ -609,15 +589,13 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } @@ -632,11 +610,9 @@ public void testDoToQuery_FromModel() { when(mockQueryShardContext.index()).thenReturn(dummyIndex); // Dimension is -1. In this case, model metadata will need to provide dimension - when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; - when(mockKNNVectorField.getModelId()).thenReturn(modelId); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForModelMapping(modelId)); // Mock the modelDao to return mocked modelMetadata ModelMetadata modelMetadata = mock(ModelMetadata.class); @@ -672,11 +648,9 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; - when(mockKNNVectorField.getModelId()).thenReturn(modelId); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForModelMapping(modelId)); ModelMetadata modelMetadata = mock(ModelMetadata.class); when(modelMetadata.getDimension()).thenReturn(4); @@ -709,12 +683,9 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - - when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; - when(mockKNNVectorField.getModelId()).thenReturn(modelId); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForModelMapping(modelId)); ModelMetadata modelMetadata = mock(ModelMetadata.class); when(modelMetadata.getDimension()).thenReturn(4); @@ -744,10 +715,10 @@ public void testDoToQuery_InvalidDimensions() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(400); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 400)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); - when(mockKNNVectorField.getDimension()).thenReturn(K); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), K)); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } @@ -769,9 +740,10 @@ public void testDoToQuery_InvalidZeroFloatVector() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + when(knnMethodContext.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, @@ -790,9 +762,10 @@ public void testDoToQuery_InvalidZeroByteVector() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BYTE); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + when(knnMethodContext.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, @@ -919,9 +892,8 @@ public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -946,9 +918,8 @@ public void testRadialSearch_whenEfSearchIsSet_whenLuceneEngine_thenThrowExcepti KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); @@ -972,9 +943,8 @@ public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() { KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); @@ -992,9 +962,8 @@ public void testDoToQuery_whenBinary_thenValid() throws Exception { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(32); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.HAMMING); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 32)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertArrayEquals(expectedQueryVector, query.getByteQueryVector()); @@ -1008,9 +977,8 @@ public void testDoToQuery_whenBinaryWithInvalidDimension_thenException() throws QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(8); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.HAMMING); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 8)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); Exception ex = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); assertTrue(ex.getMessage(), ex.getMessage().contains("invalid dimension")); diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index f9a9704d0b..d1288c5f34 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -46,6 +46,7 @@ import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.knn.KNNTestCase.getMappingConfigForFlatMapping; import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; @@ -744,18 +745,20 @@ private Map createDataset( } private BiFunction getScoreFunction(SpaceType spaceType, float[] queryVector) { - KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType( - FIELD_NAME, - Collections.emptyMap(), - SpaceType.HAMMING == spaceType ? queryVector.length * 8 : queryVector.length, - SpaceType.HAMMING == spaceType ? VectorDataType.BINARY : VectorDataType.FLOAT, - null - ); List target = new ArrayList<>(queryVector.length); for (float f : queryVector) { target.add(f); } - KNNScoringSpace knnScoringSpace = KNNScoringSpaceFactory.create(spaceType.getValue(), target, knnVectorFieldType); + KNNScoringSpace knnScoringSpace = KNNScoringSpaceFactory.create( + spaceType.getValue(), + target, + new KNNVectorFieldType( + FIELD_NAME, + Collections.emptyMap(), + SpaceType.HAMMING == spaceType ? VectorDataType.BINARY : VectorDataType.FLOAT, + getMappingConfigForFlatMapping(SpaceType.HAMMING == spaceType ? queryVector.length * 8 : queryVector.length) + ) + ); switch (spaceType) { case L1: case L2: diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index ae9ad71062..c72f5679a0 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; import org.junit.BeforeClass; +import org.junit.Ignore; import org.opensearch.Version; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; @@ -59,6 +60,7 @@ import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +@Ignore public class JNIServiceTests extends KNNTestCase { static final int FP16_MAX = 65504; static final int FP16_MIN = -65504; diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java index 823d210803..c41e9763b5 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java @@ -19,9 +19,11 @@ public class KNNScoringSpaceFactoryTests extends KNNTestCase { public void testValidSpaces() { KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); + when(knnVectorFieldType.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 3)); KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeBinary.getDimension()).thenReturn(24); + when(knnVectorFieldTypeBinary.getKnnMappingConfig()).thenReturn( + getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 24) + ); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); NumberFieldMapper.NumberFieldType numberFieldType = new NumberFieldMapper.NumberFieldType( "field", @@ -66,9 +68,11 @@ public void testValidSpaces() { public void testInvalidSpace() { List floatQueryObject = List.of(1.0f, 1.0f, 1.0f); KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); + when(knnVectorFieldType.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 3)); KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeBinary.getDimension()).thenReturn(24); + when(knnVectorFieldTypeBinary.getKnnMappingConfig()).thenReturn( + getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 24) + ); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); // Verify diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 6c557c8dd5..4fc549d6bc 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -57,8 +57,13 @@ private void expectThrowsExceptionWithKNNFieldWithBinaryDataType(Class clazz) th public void testL2_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) + ); KNNScoringSpace.L2 l2 = new KNNScoringSpace.L2(arrayListQueryObject, fieldType); assertEquals(1F, l2.getScoringMethod().apply(arrayFloat, arrayFloat), 0.1F); } @@ -73,9 +78,13 @@ public void testCosineSimilarity_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0)); float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f }; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) + ); KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); assertEquals(2F, cosineSimilarity.getScoringMethod().apply(arrayFloat2, arrayFloat), 0.1F); @@ -92,8 +101,13 @@ public void testCosineSimilarity_whenValid_thenSucceed() { } public void testCosineSimilarity_whenZeroVector_thenException() { - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) + ); final List queryZeroVector = List.of(0.0f, 0.0f, 0.0f); IllegalArgumentException exception1 = expectThrows( @@ -116,9 +130,14 @@ public void testInnerProd_whenValid_thenSucceed() { float[] arrayFloat_case1 = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject_case1 = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); float[] arrayFloat2_case1 = new float[] { 1.0f, 1.0f, 1.0f }; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) + ); KNNScoringSpace.InnerProd innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case1, fieldType); assertEquals(7.0F, innerProd.getScoringMethod().apply(arrayFloat_case1, arrayFloat2_case1), 0.001F); @@ -183,14 +202,14 @@ public void testHammingBit_Base64() { public void testHamming_whenKNNFieldType_thenSucceed() { List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); KNNVectorFieldType fieldType = new KNNVectorFieldType( "test", Collections.emptyMap(), - 8 * arrayListQueryObject.size(), - knnMethodContext, - VectorDataType.BINARY + VectorDataType.BINARY, + getMappingConfigForMethodMapping(knnMethodContext, 8 * arrayListQueryObject.size()) ); + KNNScoringSpace.Hamming hamming = new KNNScoringSpace.Hamming(arrayListQueryObject, fieldType); float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index 781ed2350b..2374e4f7bb 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -64,7 +64,7 @@ public void testParseKNNVectorQuery() { KNNVectorFieldType fieldType = mock(KNNVectorFieldType.class); - when(fieldType.getDimension()).thenReturn(3); + when(fieldType.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 3)); assertArrayEquals(arrayFloat, KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 3, VectorDataType.FLOAT), 0.1f); expectThrows( diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 3515c690d8..8cff4dfa14 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -23,7 +23,6 @@ import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -303,7 +302,7 @@ public void testTrainingIndexSize() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - KNNMethodContext.getDefault(), + getDefaultKNNMethodContext(), dimension, trainingIndexName, "training-field", @@ -350,7 +349,7 @@ public void testTrainIndexSize_whenDataTypeIsBinary() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - KNNMethodContext.getDefault(), + getDefaultKNNMethodContext(), dimension, trainingIndexName, "training-field", @@ -398,7 +397,7 @@ public void testTrainIndexSize_whenDataTypeIsByte() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - KNNMethodContext.getDefault(), + getDefaultKNNMethodContext(), dimension, trainingIndexName, "training-field", diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 53e59129e8..83d39cfdc7 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -48,7 +48,7 @@ public class TrainingModelRequestTests extends KNNTestCase { public void testStreams() throws IOException { String modelId = "test-model-id"; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; @@ -105,7 +105,7 @@ public void testStreams() throws IOException { public void testGetters() { String modelId = "test-model-id"; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field";