Skip to content

Commit

Permalink
Address Review Comments
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 18, 2024
1 parent dcbe6ad commit a9fd88b
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,17 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.BEAM_WIDTH;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.MAX_CONNECTIONS;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

Expand All @@ -40,8 +35,8 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor
private final int defaultMaxConnections;
private final int defaultBeamWidth;
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
private final BiFunction<Integer, Integer, KnnVectorsFormat> formatSupplier;
private final Function5Arity<Integer, Integer, Float, Integer, Boolean, KnnVectorsFormat> quantizedVectorsFormatSupplier;
private final Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier;
private final Function<KNNScalarQuantizedVectorsFormatParams, KnnVectorsFormat> scalarQuantizedVectorsFormatSupplier;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
Expand All @@ -60,70 +55,41 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponentContext().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);

if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE
&& params != null
&& params.containsKey(METHOD_ENCODER_PARAMETER)) {
final KnnVectorsFormat knnVectorsFormat = validateAndApplyQuantizedVectorsFormatForLuceneEngine(
params,
field,
maxConnections,
beamWidth
);
if (knnVectorsFormat != null) {
return knnVectorsFormat;
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams();
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
knnScalarQuantizedVectorsFormatParams.initialize(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\",[{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
LUCENE_SQ_CONFIDENCE_INTERVAL,
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
LUCENE_SQ_BITS,
knnScalarQuantizedVectorsFormatParams.getBits(),
LUCENE_SQ_COMPRESS,
knnScalarQuantizedVectorsFormatParams.isCompressFlag()
);
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
}
}

log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
maxConnections,
beamWidth
);
return formatSupplier.apply(maxConnections, beamWidth);
}

private KnnVectorsFormat validateAndApplyQuantizedVectorsFormatForLuceneEngine(
final Map<String, Object> params,
final String field,
final int maxConnections,
final int beamWidth
) {

if (params.get(METHOD_ENCODER_PARAMETER) == null) {
return null;
}

// Validate if the object is of type MethodComponentContext before casting it later
if (!(params.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) {
return null;
}
MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER);
if (!ENCODER_SQ.equals(encoderMethodComponentContext.getName())) {
return null;
}
Map<String, Object> sqEncoderParams = encoderMethodComponentContext.getParameters();
Float confidenceInterval = getConfidenceInterval(sqEncoderParams);
int bits = getBits(sqEncoderParams);
boolean compressFlag = getCompressFlag(sqEncoderParams);
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams();
knnVectorsFormatParams.initialize(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\",[{}] = \"{}\"",
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
MAX_CONNECTIONS,
maxConnections,
BEAM_WIDTH,
beamWidth,
LUCENE_SQ_CONFIDENCE_INTERVAL,
confidenceInterval,
LUCENE_SQ_BITS,
bits,
LUCENE_SQ_COMPRESS,
compressFlag
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
);
return quantizedVectorsFormatSupplier.apply(maxConnections, beamWidth, confidenceInterval, bits, compressFlag);
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
}

@Override
Expand All @@ -134,43 +100,4 @@ public int getMaxDimensions(String fieldName) {
private boolean isKnnVectorFieldType(final String field) {
return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return defaultMaxConnections;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return defaultBeamWidth;
}

private Float getConfidenceInterval(final Map<String, Object> params) {

if (params != null && params.containsKey(LUCENE_SQ_CONFIDENCE_INTERVAL)) {
if (params.get("confidence_interval").equals(0)) return Float.valueOf(0);

return ((Double) params.get("confidence_interval")).floatValue();

}
return null;
}

private int getBits(final Map<String, Object> params) {
if (params != null && params.containsKey(LUCENE_SQ_BITS)) {
return (int) params.get("bits");
}
return LUCENE_SQ_DEFAULT_BITS;
}

private boolean getCompressFlag(final Map<String, Object> params) {
if (params != null && params.containsKey(LUCENE_SQ_COMPRESS)) {
return (boolean) params.get("compress");
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@ public KNN920PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene92HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth),
(maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene92HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
knnScalarQuantizedVectorsFormatParams -> new Lucene92HnswVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth()
)

);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@ public KNN940PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene94HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth),
(maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene94HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
knnScalarQuantizedVectorsFormatParams -> new Lucene94HnswVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth()
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ public KNN950PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene95HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene95HnswVectorsFormat(maxConnm, beamWidth),
(maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene95HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene95HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
knnScalarQuantizedVectorsFormatParams -> new Lucene95HnswVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth()
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@ public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene99HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene99HnswVectorsFormat(maxConnm, beamWidth),
(maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene99HnswScalarQuantizedVectorsFormat(
maxConnm,
beamWidth,
knnVectorsFormatParams -> new Lucene99HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
1,
bits,
compress,
confidenceInterval,
knnScalarQuantizedVectorsFormatParams.getBits(),
knnScalarQuantizedVectorsFormatParams.isCompressFlag(),
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
null
)
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec;

import lombok.Getter;
import lombok.NoArgsConstructor;
import org.opensearch.knn.index.MethodComponentContext;

import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

/**
* Class provides params for LuceneHnswScalarQuantizedVectorsFormat
*/
@Getter
@NoArgsConstructor
public class KNNScalarQuantizedVectorsFormatParams extends KNNVectorsFormatParams {
private float confidenceInterval;
private int bits;
private boolean compressFlag;

@Override
boolean validate(Map<String, Object> params) {
if (params.get(METHOD_ENCODER_PARAMETER) == null) {
return false;
}

// Validate if the object is of type MethodComponentContext before casting it later
if (!(params.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) {
return false;
}
MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER);
if (!ENCODER_SQ.equals(encoderMethodComponentContext.getName())) {
return false;
}

return true;
}

@Override
void initialize(Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth) {
super.initialize(params, defaultMaxConnections, defaultBeamWidth);
MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER);
Map<String, Object> sqEncoderParams = encoderMethodComponentContext.getParameters();
this.confidenceInterval = getConfidenceInterval(sqEncoderParams);
this.bits = getBits(sqEncoderParams);
this.compressFlag = getCompressFlag(sqEncoderParams);
}

private Float getConfidenceInterval(final Map<String, Object> params) {

if (params != null && params.containsKey(LUCENE_SQ_CONFIDENCE_INTERVAL)) {
if (params.get(LUCENE_SQ_CONFIDENCE_INTERVAL).equals(0)) return Float.valueOf(0);

return ((Double) params.get(LUCENE_SQ_CONFIDENCE_INTERVAL)).floatValue();

}
return null;
}

private int getBits(final Map<String, Object> params) {
if (params != null && params.containsKey(LUCENE_SQ_BITS)) {
return (int) params.get(LUCENE_SQ_BITS);
}
return LUCENE_SQ_DEFAULT_BITS;
}

private boolean getCompressFlag(final Map<String, Object> params) {
if (params != null && params.containsKey(LUCENE_SQ_COMPRESS)) {
return (boolean) params.get(LUCENE_SQ_COMPRESS);
}
return false;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec;

import lombok.Getter;
import lombok.NoArgsConstructor;
import org.opensearch.knn.common.KNNConstants;

import java.util.Map;

/**
* Class provides params for LuceneHNSWVectorsFormat
*/
@NoArgsConstructor
@Getter
public class KNNVectorsFormatParams {
private int maxConnections;
private int beamWidth;

boolean validate(final Map<String, Object> params) {
return false;
}

void initialize(final Map<String, Object> params, int defaultMaxConnections, int defaultBeamWidth) {
this.maxConnections = getMaxConnections(params, defaultMaxConnections);
this.beamWidth = getBeamWidth(params, defaultBeamWidth);
}

private int getMaxConnections(final Map<String, Object> params, int defaultMaxConnections) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return defaultMaxConnections;
}

private int getBeamWidth(final Map<String, Object> params, int defaultBeamWidth) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return defaultBeamWidth;
}
}

0 comments on commit a9fd88b

Please sign in to comment.