diff --git a/src/main/java/org/opensearch/knn/index/Parameter.java b/src/main/java/org/opensearch/knn/index/Parameter.java index dcff79fc0..50d792ebd 100644 --- a/src/main/java/org/opensearch/knn/index/Parameter.java +++ b/src/main/java/org/opensearch/knn/index/Parameter.java @@ -16,6 +16,7 @@ import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.function.BiFunction; import java.util.function.Predicate; @@ -224,35 +225,46 @@ public DoubleParameter( @Override public ValidationException validate(Object value) { - ValidationException validationException = null; + if (Objects.isNull(value)) { + String validationErrorMsg = String.format(Locale.ROOT, "Null value provided for Double " + "parameter \"%s\".", getName()); + return getValidationException(validationErrorMsg); + } if (value.equals(0)) value = 0.0; if (!(value instanceof Double)) { - validationException = new ValidationException(); - validationException.addValidationError( - String.format(Locale.ROOT, "Value not of type Double for Double " + "parameter \"%s\".", getName()) + String validationErrorMsg = String.format( + Locale.ROOT, + "Value not of type Double for Double " + "parameter \"%s\".", + getName() ); - return validationException; + return getValidationException(validationErrorMsg); } if (!validator.test((Double) value)) { - validationException = new ValidationException(); - validationException.addValidationError( - String.format(Locale.ROOT, "Parameter validation failed for Double " + "parameter \"%s\".", getName()) + String validationErrorMsg = String.format( + Locale.ROOT, + "Parameter validation failed for Double " + "parameter \"%s\".", + getName() ); + return getValidationException(validationErrorMsg); } - return validationException; + return null; } @Override public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) { - ValidationException validationException = null; + if (Objects.isNull(value)) { + String validationErrorMsg = String.format(Locale.ROOT, "Null value provided for Double " + "parameter \"%s\".", getName()); + return getValidationException(validationErrorMsg); + } + if (!(value instanceof Double)) { - validationException = new ValidationException(); - validationException.addValidationError( - String.format(Locale.ROOT, "value is not an instance of Double for Double parameter [%s].", getName()) + String validationErrorMsg = String.format( + Locale.ROOT, + "value is not an instance of Double for Double parameter [%s].", + getName() ); - return validationException; + return getValidationException(validationErrorMsg); } if (validatorWithData == null) { @@ -260,12 +272,15 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector } if (!validatorWithData.apply((Double) value, vectorSpaceInfo)) { - validationException = new ValidationException(); - validationException.addValidationError( - String.format(Locale.ROOT, "parameter validation failed for Double parameter [%s].", getName()) - ); + String validationErrorMsg = String.format(Locale.ROOT, "parameter validation failed for Double parameter [%s].", getName()); + return getValidationException(validationErrorMsg); } + return null; + } + private ValidationException getValidationException(String validationErrorMsg) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError(validationErrorMsg); return validationException; } } diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 6452f98f7..ea91b2b74 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -10,22 +10,17 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.opensearch.index.mapper.MapperService; -import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.util.KNNEngine; -import java.util.Map; import java.util.Optional; -import java.util.function.BiFunction; +import java.util.function.Function; import java.util.function.Supplier; import static org.opensearch.knn.common.KNNConstants.BEAM_WIDTH; -import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL; -import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS; import static org.opensearch.knn.common.KNNConstants.MAX_CONNECTIONS; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; @@ -40,8 +35,8 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor private final int defaultMaxConnections; private final int defaultBeamWidth; private final Supplier defaultFormatSupplier; - private final BiFunction formatSupplier; - private final Function5Arity quantizedVectorsFormatSupplier; + private final Function vectorsFormatSupplier; + private final Function scalarQuantizedVectorsFormatSupplier; @Override public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { @@ -60,70 +55,41 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { ) ).fieldType(field); var params = type.getKnnMethodContext().getMethodComponentContext().getParameters(); - int maxConnections = getMaxConnections(params); - int beamWidth = getBeamWidth(params); if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE && params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { - final KnnVectorsFormat knnVectorsFormat = validateAndApplyQuantizedVectorsFormatForLuceneEngine( - params, - field, - maxConnections, - beamWidth - ); - if (knnVectorsFormat != null) { - return knnVectorsFormat; + KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(); + if (knnScalarQuantizedVectorsFormatParams.validate(params)) { + knnScalarQuantizedVectorsFormatParams.initialize(params, defaultMaxConnections, defaultBeamWidth); + log.debug( + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\",[{}] = \"{}\"", + field, + MAX_CONNECTIONS, + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + BEAM_WIDTH, + knnScalarQuantizedVectorsFormatParams.getBeamWidth(), + LUCENE_SQ_CONFIDENCE_INTERVAL, + knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), + LUCENE_SQ_BITS, + knnScalarQuantizedVectorsFormatParams.getBits(), + LUCENE_SQ_COMPRESS, + knnScalarQuantizedVectorsFormatParams.isCompressFlag() + ); + return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams); } - } - - log.debug( - "Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"", - field, - maxConnections, - beamWidth - ); - return formatSupplier.apply(maxConnections, beamWidth); - } - - private KnnVectorsFormat validateAndApplyQuantizedVectorsFormatForLuceneEngine( - final Map params, - final String field, - final int maxConnections, - final int beamWidth - ) { - if (params.get(METHOD_ENCODER_PARAMETER) == null) { - return null; } - // Validate if the object is of type MethodComponentContext before casting it later - if (!(params.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { - return null; - } - MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER); - if (!ENCODER_SQ.equals(encoderMethodComponentContext.getName())) { - return null; - } - Map sqEncoderParams = encoderMethodComponentContext.getParameters(); - Float confidenceInterval = getConfidenceInterval(sqEncoderParams); - int bits = getBits(sqEncoderParams); - boolean compressFlag = getCompressFlag(sqEncoderParams); + KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(); + knnVectorsFormatParams.initialize(params, defaultMaxConnections, defaultBeamWidth); log.debug( - "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\",[{}] = \"{}\"", + "Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"", field, - MAX_CONNECTIONS, - maxConnections, - BEAM_WIDTH, - beamWidth, - LUCENE_SQ_CONFIDENCE_INTERVAL, - confidenceInterval, - LUCENE_SQ_BITS, - bits, - LUCENE_SQ_COMPRESS, - compressFlag + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth() ); - return quantizedVectorsFormatSupplier.apply(maxConnections, beamWidth, confidenceInterval, bits, compressFlag); + return vectorsFormatSupplier.apply(knnVectorsFormatParams); } @Override @@ -134,43 +100,4 @@ public int getMaxDimensions(String fieldName) { private boolean isKnnVectorFieldType(final String field) { return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType; } - - private int getMaxConnections(final Map params) { - if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) { - return (int) params.get(KNNConstants.METHOD_PARAMETER_M); - } - return defaultMaxConnections; - } - - private int getBeamWidth(final Map params) { - if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) { - return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION); - } - return defaultBeamWidth; - } - - private Float getConfidenceInterval(final Map params) { - - if (params != null && params.containsKey(LUCENE_SQ_CONFIDENCE_INTERVAL)) { - if (params.get("confidence_interval").equals(0)) return Float.valueOf(0); - - return ((Double) params.get("confidence_interval")).floatValue(); - - } - return null; - } - - private int getBits(final Map params) { - if (params != null && params.containsKey(LUCENE_SQ_BITS)) { - return (int) params.get("bits"); - } - return LUCENE_SQ_DEFAULT_BITS; - } - - private boolean getCompressFlag(final Map params) { - if (params != null && params.containsKey(LUCENE_SQ_COMPRESS)) { - return (boolean) params.get("compress"); - } - return false; - } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java index 714782848..3ff501214 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java @@ -22,8 +22,14 @@ public KNN920PerFieldKnnVectorsFormat(final Optional mapperServic Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH, () -> new Lucene92HnswVectorsFormat(), - (maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth), - (maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth) + knnVectorsFormatParams -> new Lucene92HnswVectorsFormat( + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth() + ), + knnScalarQuantizedVectorsFormatParams -> new Lucene92HnswVectorsFormat( + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + knnScalarQuantizedVectorsFormatParams.getBeamWidth() + ) ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java index 4910b8be7..649a413f0 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java @@ -22,8 +22,14 @@ public KNN940PerFieldKnnVectorsFormat(final Optional mapperServic Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH, () -> new Lucene94HnswVectorsFormat(), - (maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth), - (maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth) + knnVectorsFormatParams -> new Lucene94HnswVectorsFormat( + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth() + ), + knnScalarQuantizedVectorsFormatParams -> new Lucene94HnswVectorsFormat( + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + knnScalarQuantizedVectorsFormatParams.getBeamWidth() + ) ); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java index 87762c9ab..e9bb875df 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java @@ -23,8 +23,14 @@ public KNN950PerFieldKnnVectorsFormat(final Optional mapperServic Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH, () -> new Lucene95HnswVectorsFormat(), - (maxConnm, beamWidth) -> new Lucene95HnswVectorsFormat(maxConnm, beamWidth), - (maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene95HnswVectorsFormat(maxConnm, beamWidth) + knnVectorsFormatParams -> new Lucene95HnswVectorsFormat( + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth() + ), + knnScalarQuantizedVectorsFormatParams -> new Lucene95HnswVectorsFormat( + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + knnScalarQuantizedVectorsFormatParams.getBeamWidth() + ) ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java index 7b059aa43..820536b22 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java @@ -24,14 +24,17 @@ public KNN990PerFieldKnnVectorsFormat(final Optional mapperServic Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, () -> new Lucene99HnswVectorsFormat(), - (maxConnm, beamWidth) -> new Lucene99HnswVectorsFormat(maxConnm, beamWidth), - (maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene99HnswScalarQuantizedVectorsFormat( - maxConnm, - beamWidth, + knnVectorsFormatParams -> new Lucene99HnswVectorsFormat( + knnVectorsFormatParams.getMaxConnections(), + knnVectorsFormatParams.getBeamWidth() + ), + knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat( + knnScalarQuantizedVectorsFormatParams.getMaxConnections(), + knnScalarQuantizedVectorsFormatParams.getBeamWidth(), 1, - bits, - compress, - confidenceInterval, + knnScalarQuantizedVectorsFormatParams.getBits(), + knnScalarQuantizedVectorsFormatParams.isCompressFlag(), + knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), null ) ); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNScalarQuantizedVectorsFormatParams.java b/src/main/java/org/opensearch/knn/index/codec/KNNScalarQuantizedVectorsFormatParams.java new file mode 100644 index 000000000..c3dcb17e5 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNNScalarQuantizedVectorsFormatParams.java @@ -0,0 +1,89 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.codec; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.knn.index.MethodComponentContext; + +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS; +import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS; +import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL; +import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; + +/** + * Class provides params for LuceneHnswScalarQuantizedVectorsFormat + */ +@Getter +@NoArgsConstructor +public class KNNScalarQuantizedVectorsFormatParams extends KNNVectorsFormatParams { + private float confidenceInterval; + private int bits; + private boolean compressFlag; + + @Override + boolean validate(Map params) { + if (params.get(METHOD_ENCODER_PARAMETER) == null) { + return false; + } + + // Validate if the object is of type MethodComponentContext before casting it later + if (!(params.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { + return false; + } + MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER); + if (!ENCODER_SQ.equals(encoderMethodComponentContext.getName())) { + return false; + } + + return true; + } + + @Override + void initialize(Map params, int defaultMaxConnections, int defaultBeamWidth) { + super.initialize(params, defaultMaxConnections, defaultBeamWidth); + MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER); + Map sqEncoderParams = encoderMethodComponentContext.getParameters(); + this.confidenceInterval = getConfidenceInterval(sqEncoderParams); + this.bits = getBits(sqEncoderParams); + this.compressFlag = getCompressFlag(sqEncoderParams); + } + + private Float getConfidenceInterval(final Map params) { + + if (params != null && params.containsKey(LUCENE_SQ_CONFIDENCE_INTERVAL)) { + if (params.get(LUCENE_SQ_CONFIDENCE_INTERVAL).equals(0)) return Float.valueOf(0); + + return ((Double) params.get(LUCENE_SQ_CONFIDENCE_INTERVAL)).floatValue(); + + } + return null; + } + + private int getBits(final Map params) { + if (params != null && params.containsKey(LUCENE_SQ_BITS)) { + return (int) params.get(LUCENE_SQ_BITS); + } + return LUCENE_SQ_DEFAULT_BITS; + } + + private boolean getCompressFlag(final Map params) { + if (params != null && params.containsKey(LUCENE_SQ_COMPRESS)) { + return (boolean) params.get(LUCENE_SQ_COMPRESS); + } + return false; + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/KNNVectorsFormatParams.java b/src/main/java/org/opensearch/knn/index/codec/KNNVectorsFormatParams.java new file mode 100644 index 000000000..8e2ccc1e9 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/KNNVectorsFormatParams.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.knn.common.KNNConstants; + +import java.util.Map; + +/** + * Class provides params for LuceneHNSWVectorsFormat + */ +@NoArgsConstructor +@Getter +public class KNNVectorsFormatParams { + private int maxConnections; + private int beamWidth; + + boolean validate(final Map params) { + return false; + } + + void initialize(final Map params, int defaultMaxConnections, int defaultBeamWidth) { + this.maxConnections = getMaxConnections(params, defaultMaxConnections); + this.beamWidth = getBeamWidth(params, defaultBeamWidth); + } + + private int getMaxConnections(final Map params, int defaultMaxConnections) { + if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) { + return (int) params.get(KNNConstants.METHOD_PARAMETER_M); + } + return defaultMaxConnections; + } + + private int getBeamWidth(final Map params, int defaultBeamWidth) { + if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) { + return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION); + } + return defaultBeamWidth; + } +}