From c66a5922493b221dad0b3f463fd5e7cd5590d8fd Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Sun, 21 Jul 2024 12:42:58 -0500 Subject: [PATCH] Add Constructor overloading and other refactoring changes Signed-off-by: Naveen Tatikonda --- .../opensearch/knn/common/KNNConstants.java | 4 --- .../codec/BasePerFieldKnnVectorsFormat.java | 28 +++++++++++++++---- .../KNN920PerFieldKnnVectorsFormat.java | 5 ---- .../KNN940PerFieldKnnVectorsFormat.java | 4 --- .../KNN950PerFieldKnnVectorsFormat.java | 4 --- .../KNN990PerFieldKnnVectorsFormat.java | 3 +- ...KNNScalarQuantizedVectorsFormatParams.java | 10 ++----- .../codec/params/KNNVectorsFormatParams.java | 8 +----- .../org/opensearch/knn/index/util/Lucene.java | 2 +- 9 files changed, 29 insertions(+), 39 deletions(-) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 32b2b4d42..3002e426f 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -16,8 +16,6 @@ public class KNNConstants { public static final String KNN_METHOD = "method"; public static final String NAME = "name"; public static final String PARAMETERS = "parameters"; - public static final String MAX_CONNECTIONS = "max_connections"; - public static final String BEAM_WIDTH = "beam_width"; public static final String METHOD_HNSW = "hnsw"; public static final String TYPE = "type"; public static final String TYPE_NESTED = "nested"; @@ -82,8 +80,6 @@ public class KNNConstants { public static final double MAXIMUM_CONFIDENCE_INTERVAL = 1.0; public static final String LUCENE_SQ_BITS = "bits"; public static final Integer LUCENE_SQ_DEFAULT_BITS = 7; - public static final List LUCENE_SQ_BITS_SUPPORTED = List.of(7); - public static final String LUCENE_SQ_COMPRESS = "compress"; // nmslib specific constants public static final String NMSLIB_NAME = "nmslib"; diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index d0d5d55e4..f3738452a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -19,10 +19,8 @@ import java.util.function.Function; import java.util.function.Supplier; -import static org.opensearch.knn.common.KNNConstants.BEAM_WIDTH; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL; -import static org.opensearch.knn.common.KNNConstants.MAX_CONNECTIONS; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; /** @@ -37,15 +35,33 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor private final int defaultBeamWidth; private final Supplier defaultFormatSupplier; private final Function vectorsFormatSupplier; - private final Function scalarQuantizedVectorsFormatSupplier; + private Function scalarQuantizedVectorsFormatSupplier; + private static final String MAX_CONNECTIONS = "max_connections"; + private static final String BEAM_WIDTH = "beam_width"; + + public BasePerFieldKnnVectorsFormat( + Optional mapperService, + int defaultMaxConnections, + int defaultBeamWidth, + Supplier defaultFormatSupplier, + Function vectorsFormatSupplier + ) { + this.mapperService = mapperService; + this.defaultMaxConnections = defaultMaxConnections; + this.defaultBeamWidth = defaultBeamWidth; + this.defaultFormatSupplier = defaultFormatSupplier; + this.vectorsFormatSupplier = vectorsFormatSupplier; + } @Override public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { if (isKnnVectorFieldType(field) == false) { log.debug( - "Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"", + "Initialize KNN vector format for field [{}] with default params [{}] = \"{}\" and [{}] = \"{}\"", field, + MAX_CONNECTIONS, defaultMaxConnections, + BEAM_WIDTH, defaultBeamWidth ); return defaultFormatSupplier.get(); @@ -85,9 +101,11 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth); log.debug( - "Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"", + "Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"", field, + MAX_CONNECTIONS, knnVectorsFormatParams.getMaxConnections(), + BEAM_WIDTH, knnVectorsFormatParams.getBeamWidth() ); return vectorsFormatSupplier.apply(knnVectorsFormatParams); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java index 3ff501214..7cca04319 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN920Codec/KNN920PerFieldKnnVectorsFormat.java @@ -25,12 +25,7 @@ public KNN920PerFieldKnnVectorsFormat(final Optional mapperServic knnVectorsFormatParams -> new Lucene92HnswVectorsFormat( knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth() - ), - knnScalarQuantizedVectorsFormatParams -> new Lucene92HnswVectorsFormat( - knnScalarQuantizedVectorsFormatParams.getMaxConnections(), - knnScalarQuantizedVectorsFormatParams.getBeamWidth() ) - ); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java index 649a413f0..1ed9c929c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN940Codec/KNN940PerFieldKnnVectorsFormat.java @@ -25,10 +25,6 @@ public KNN940PerFieldKnnVectorsFormat(final Optional mapperServic knnVectorsFormatParams -> new Lucene94HnswVectorsFormat( knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth() - ), - knnScalarQuantizedVectorsFormatParams -> new Lucene94HnswVectorsFormat( - knnScalarQuantizedVectorsFormatParams.getMaxConnections(), - knnScalarQuantizedVectorsFormatParams.getBeamWidth() ) ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java index e9bb875df..978b22003 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN950Codec/KNN950PerFieldKnnVectorsFormat.java @@ -26,10 +26,6 @@ public KNN950PerFieldKnnVectorsFormat(final Optional mapperServic knnVectorsFormatParams -> new Lucene95HnswVectorsFormat( knnVectorsFormatParams.getMaxConnections(), knnVectorsFormatParams.getBeamWidth() - ), - knnScalarQuantizedVectorsFormatParams -> new Lucene95HnswVectorsFormat( - knnScalarQuantizedVectorsFormatParams.getMaxConnections(), - knnScalarQuantizedVectorsFormatParams.getBeamWidth() ) ); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java index 820536b22..e8ecfad18 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990PerFieldKnnVectorsFormat.java @@ -17,6 +17,7 @@ * Class provides per field format implementation for Lucene Knn vector type */ public class KNN990PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat { + private static final int NUM_MERGE_WORKERS = 1; public KNN990PerFieldKnnVectorsFormat(final Optional mapperService) { super( @@ -31,7 +32,7 @@ public KNN990PerFieldKnnVectorsFormat(final Optional mapperServic knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat( knnScalarQuantizedVectorsFormatParams.getMaxConnections(), knnScalarQuantizedVectorsFormatParams.getBeamWidth(), - 1, + NUM_MERGE_WORKERS, knnScalarQuantizedVectorsFormatParams.getBits(), knnScalarQuantizedVectorsFormatParams.isCompressFlag(), knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(), diff --git a/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java b/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java index 16c70d643..79bf1cbdb 100644 --- a/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParams.java @@ -1,12 +1,6 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ package org.opensearch.knn.index.codec.params; @@ -70,7 +64,7 @@ private void initConfidenceInterval(final Map params) { } // If confidence_interval is not provided by user, then it will be set with a default value as null so that - // it will be computer later in Lucene based on the dimension of the vector as 1 - 1/(1 + d) + // it will be computed later in Lucene based on the dimension of the vector as 1 - 1/(1 + d) this.confidenceInterval = null; } diff --git a/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java b/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java index e5ed44c9f..52134bc7e 100644 --- a/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParams.java @@ -1,12 +1,6 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ package org.opensearch.knn.index.codec.params; diff --git a/src/main/java/org/opensearch/knn/index/util/Lucene.java b/src/main/java/org/opensearch/knn/index/util/Lucene.java index 947f6e065..e68020ed9 100644 --- a/src/main/java/org/opensearch/knn/index/util/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/util/Lucene.java @@ -20,7 +20,6 @@ import static org.opensearch.knn.common.KNNConstants.DYNAMIC_CONFIDENCE_INTERVAL; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS; -import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS_SUPPORTED; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL; import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS; import static org.opensearch.knn.common.KNNConstants.MAXIMUM_CONFIDENCE_INTERVAL; @@ -36,6 +35,7 @@ public class Lucene extends JVMLibrary { Map> distanceTransform; + private static final List LUCENE_SQ_BITS_SUPPORTED = List.of(7); private final static Map HNSW_ENCODERS = ImmutableMap.of( ENCODER_SQ,