Skip to content

Commit

Permalink
Add Constructor overloading and other refactoring changes
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 22, 2024
1 parent 4016e70 commit c66a592
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 39 deletions.
4 changes: 0 additions & 4 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ public class KNNConstants {
public static final String KNN_METHOD = "method";
public static final String NAME = "name";
public static final String PARAMETERS = "parameters";
public static final String MAX_CONNECTIONS = "max_connections";
public static final String BEAM_WIDTH = "beam_width";
public static final String METHOD_HNSW = "hnsw";
public static final String TYPE = "type";
public static final String TYPE_NESTED = "nested";
Expand Down Expand Up @@ -82,8 +80,6 @@ public class KNNConstants {
public static final double MAXIMUM_CONFIDENCE_INTERVAL = 1.0;
public static final String LUCENE_SQ_BITS = "bits";
public static final Integer LUCENE_SQ_DEFAULT_BITS = 7;
public static final List<Integer> LUCENE_SQ_BITS_SUPPORTED = List.of(7);
public static final String LUCENE_SQ_COMPRESS = "compress";

// nmslib specific constants
public static final String NMSLIB_NAME = "nmslib";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
import java.util.function.Function;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.BEAM_WIDTH;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.MAX_CONNECTIONS;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

/**
Expand All @@ -37,15 +35,33 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor
private final int defaultBeamWidth;
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
private final Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier;
private final Function<KNNScalarQuantizedVectorsFormatParams, KnnVectorsFormat> scalarQuantizedVectorsFormatSupplier;
private Function<KNNScalarQuantizedVectorsFormatParams, KnnVectorsFormat> scalarQuantizedVectorsFormatSupplier;
private static final String MAX_CONNECTIONS = "max_connections";
private static final String BEAM_WIDTH = "beam_width";

public BasePerFieldKnnVectorsFormat(
Optional<MapperService> mapperService,
int defaultMaxConnections,
int defaultBeamWidth,
Supplier<KnnVectorsFormat> defaultFormatSupplier,
Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier
) {
this.mapperService = mapperService;
this.defaultMaxConnections = defaultMaxConnections;
this.defaultBeamWidth = defaultBeamWidth;
this.defaultFormatSupplier = defaultFormatSupplier;
this.vectorsFormatSupplier = vectorsFormatSupplier;
}

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isKnnVectorFieldType(field) == false) {
log.debug(
"Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
"Initialize KNN vector format for field [{}] with default params [{}] = \"{}\" and [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
defaultMaxConnections,
BEAM_WIDTH,
defaultBeamWidth
);
return defaultFormatSupplier.get();
Expand Down Expand Up @@ -85,9 +101,11 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnVectorsFormatParams.getBeamWidth()
);
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@ public KNN920PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
knnVectorsFormatParams -> new Lucene92HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
knnScalarQuantizedVectorsFormatParams -> new Lucene92HnswVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth()
)

);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ public KNN940PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
knnVectorsFormatParams -> new Lucene94HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
knnScalarQuantizedVectorsFormatParams -> new Lucene94HnswVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ public KNN950PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
knnVectorsFormatParams -> new Lucene95HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
knnScalarQuantizedVectorsFormatParams -> new Lucene95HnswVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth()
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* Class provides per field format implementation for Lucene Knn vector type
*/
public class KNN990PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {
private static final int NUM_MERGE_WORKERS = 1;

public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
Expand All @@ -31,7 +32,7 @@ public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
1,
NUM_MERGE_WORKERS,
knnScalarQuantizedVectorsFormatParams.getBits(),
knnScalarQuantizedVectorsFormatParams.isCompressFlag(),
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.params;
Expand Down Expand Up @@ -70,7 +64,7 @@ private void initConfidenceInterval(final Map<String, Object> params) {
}

// If confidence_interval is not provided by user, then it will be set with a default value as null so that
// it will be computer later in Lucene based on the dimension of the vector as 1 - 1/(1 + d)
// it will be computed later in Lucene based on the dimension of the vector as 1 - 1/(1 + d)
this.confidenceInterval = null;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.params;
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/util/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static org.opensearch.knn.common.KNNConstants.DYNAMIC_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS_SUPPORTED;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.MAXIMUM_CONFIDENCE_INTERVAL;
Expand All @@ -36,6 +35,7 @@
public class Lucene extends JVMLibrary {

Map<SpaceType, Function<Float, Float>> distanceTransform;
private static final List<Integer> LUCENE_SQ_BITS_SUPPORTED = List.of(7);

private final static Map<String, MethodComponent> HNSW_ENCODERS = ImmutableMap.of(
ENCODER_SQ,
Expand Down

0 comments on commit c66a592

Please sign in to comment.