From cfdfa8c50e262f4773f0e460dc115ea189bdfc06 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Wed, 17 Jul 2024 16:28:01 -0500 Subject: [PATCH] Refactor code Signed-off-by: Naveen Tatikonda --- .../codec/BasePerFieldKnnVectorsFormat.java | 78 ++++++++++++------- .../index/mapper/KNNVectorFieldMapper.java | 4 +- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index dd0bd253bd..6452f98f79 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -13,6 +13,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.util.KNNEngine; import java.util.Map; import java.util.Optional; @@ -61,34 +62,19 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { var params = type.getKnnMethodContext().getMethodComponentContext().getParameters(); int maxConnections = getMaxConnections(params); int beamWidth = getBeamWidth(params); - if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { - if (params.get(METHOD_ENCODER_PARAMETER) != null) { - MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER); - if (ENCODER_SQ.equals(encoderMethodComponentContext.getName())) { - Map sqEncoderParams = encoderMethodComponentContext.getParameters(); - - Float confidenceInterval = getConfidenceInterval(sqEncoderParams); - int bits = getBits(sqEncoderParams); - boolean compressFlag = getCompressFlag(sqEncoderParams); - log.debug( - "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\",[{}] = \"{}\"", - field, - MAX_CONNECTIONS, - maxConnections, - BEAM_WIDTH, - beamWidth, - LUCENE_SQ_CONFIDENCE_INTERVAL, - confidenceInterval, - LUCENE_SQ_BITS, - bits, - LUCENE_SQ_COMPRESS, - compressFlag - ); - return quantizedVectorsFormatSupplier.apply(maxConnections, beamWidth, confidenceInterval, bits, compressFlag); - } + if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE + && params != null + && params.containsKey(METHOD_ENCODER_PARAMETER)) { + final KnnVectorsFormat knnVectorsFormat = validateAndApplyQuantizedVectorsFormatForLuceneEngine( + params, + field, + maxConnections, + beamWidth + ); + if (knnVectorsFormat != null) { + return knnVectorsFormat; } - } log.debug( @@ -100,6 +86,46 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { return formatSupplier.apply(maxConnections, beamWidth); } + private KnnVectorsFormat validateAndApplyQuantizedVectorsFormatForLuceneEngine( + final Map params, + final String field, + final int maxConnections, + final int beamWidth + ) { + + if (params.get(METHOD_ENCODER_PARAMETER) == null) { + return null; + } + + // Validate if the object is of type MethodComponentContext before casting it later + if (!(params.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { + return null; + } + MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER); + if (!ENCODER_SQ.equals(encoderMethodComponentContext.getName())) { + return null; + } + Map sqEncoderParams = encoderMethodComponentContext.getParameters(); + Float confidenceInterval = getConfidenceInterval(sqEncoderParams); + int bits = getBits(sqEncoderParams); + boolean compressFlag = getCompressFlag(sqEncoderParams); + log.debug( + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\",[{}] = \"{}\"", + field, + MAX_CONNECTIONS, + maxConnections, + BEAM_WIDTH, + beamWidth, + LUCENE_SQ_CONFIDENCE_INTERVAL, + confidenceInterval, + LUCENE_SQ_BITS, + bits, + LUCENE_SQ_COMPRESS, + compressFlag + ); + return quantizedVectorsFormatSupplier.apply(maxConnections, beamWidth, confidenceInterval, bits, compressFlag); + } + @Override public int getMaxDimensions(String fieldName) { return getKnnVectorsFormatForField(fieldName).getMaxDimensions(fieldName); diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 7f7c83f3ee..d5ca5bc15f 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -349,7 +349,7 @@ private void validateEncoder(final KNNMethodContext knnMethodContext, final Vect return; } - if (VectorDataType.BINARY != vectorDataType) { + if (VectorDataType.FLOAT == vectorDataType) { return; } @@ -380,7 +380,7 @@ private void validateEncoder(final KNNMethodContext knnMethodContext, final Vect String.format( Locale.ROOT, "%s data type does not support %s encoder", - VectorDataType.BINARY.getValue(), + vectorDataType.getValue(), encoderMethodComponentContext.getName() ) );