Skip to content

Commit

Permalink
Add support for Lucene Inbuilt Scalar Quantizer
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 17, 2024
1 parent b422466 commit 721d5aa
Show file tree
Hide file tree
Showing 9 changed files with 202 additions and 4 deletions.
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public class KNNConstants {
public static final String KNN_METHOD = "method";
public static final String NAME = "name";
public static final String PARAMETERS = "parameters";
// Labels for the HNSW graph settings; BasePerFieldKnnVectorsFormat passes them as
// the parameter names in its "Initialize KNN vector format" debug log line.
public static final String MAX_CONNECTIONS = "max_connections";
public static final String BEAM_WIDTH = "beam_width";
public static final String METHOD_HNSW = "hnsw";
public static final String TYPE = "type";
public static final String TYPE_NESTED = "nested";
Expand Down Expand Up @@ -73,6 +75,14 @@ public class KNNConstants {

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
// Key of the scalar-quantizer encoder parameter holding the confidence interval.
public static final String LUCENE_SQ_CONFIDENCE_INTERVAL = "confidence_interval";
// Special confidence-interval value that the Lucene encoder validator accepts in
// addition to the [MINIMUM_CONFIDENCE_INTERVAL, MAXIMUM_CONFIDENCE_INTERVAL] range.
// NOTE(review): presumably asks Lucene to derive the interval dynamically - confirm.
public static final int DYNAMIC_CONFIDENCE_INTERVAL = 0;
// Inclusive bounds for an explicitly supplied confidence interval.
public static final double MINIMUM_CONFIDENCE_INTERVAL = 0.9;
public static final double MAXIMUM_CONFIDENCE_INTERVAL = 1.0;
// Key of the scalar-quantizer encoder parameter holding the number of bits.
public static final String LUCENE_SQ_BITS = "bits";
// Value used for "bits" when the encoder parameters do not specify one.
public static final Integer LUCENE_SQ_DEFAULT_BITS = 7;
// Values of "bits" that pass validation; currently only 7.
public static final List<Integer> LUCENE_SQ_BITS_SUPPORTED = List.of(7);
// Key of the scalar-quantizer encoder parameter holding the compress flag.
public static final String LUCENE_SQ_COMPRESS = "compress";

// nmslib specific constants
public static final String NMSLIB_NAME = "nmslib";
Expand Down
63 changes: 63 additions & 0 deletions src/main/java/org/opensearch/knn/index/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,69 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector
}
}

/**
 * Double method parameter.
 *
 * <p>Accepts values that are instances of {@link Double}. As a convenience, the
 * {@code Integer} value {@code 0} is treated as {@code 0.0} by both
 * {@link #validate(Object)} and {@link #validateWithData(Object, VectorSpaceInfo)}, so a
 * value accepted by one is never rejected by the other purely because of its boxed type.
 * A {@code null} value is reported as a validation error instead of raising a
 * {@link NullPointerException}.
 */
public static class DoubleParameter extends Parameter<Double> {

    /**
     * @param name         parameter name, used in validation error messages
     * @param defaultValue value to use when the parameter is not supplied; may be {@code null}
     * @param validator    predicate a supplied value must satisfy
     */
    public DoubleParameter(String name, Double defaultValue, Predicate<Double> validator) {
        super(name, defaultValue, validator);
    }

    /**
     * @param name              parameter name, used in validation error messages
     * @param defaultValue      value to use when the parameter is not supplied; may be {@code null}
     * @param validator         predicate a supplied value must satisfy
     * @param validatorWithData predicate that additionally receives the {@link VectorSpaceInfo};
     *                          may be {@code null}, in which case data validation is skipped
     */
    public DoubleParameter(
        String name,
        Double defaultValue,
        Predicate<Double> validator,
        BiFunction<Double, VectorSpaceInfo, Boolean> validatorWithData
    ) {
        super(name, defaultValue, validator, validatorWithData);
    }

    /**
     * Checks the type of {@code value} and runs the validator on it.
     *
     * @param value candidate parameter value; may be {@code null}
     * @return {@code null} when the value is valid, otherwise an exception describing the problem
     */
    @Override
    public ValidationException validate(Object value) {
        final Object candidate = coerceIntegerZero(value);

        if (!(candidate instanceof Double)) {
            return newValidationException(String.format("Value not of type Double for Double parameter \"%s\".", getName()));
        }

        if (!validator.test((Double) candidate)) {
            return newValidationException(String.format("Parameter validation failed for Double parameter \"%s\".", getName()));
        }
        return null;
    }

    /**
     * Checks the type of {@code value} and, when a data-aware validator was configured,
     * runs it with the given vector space information.
     *
     * @param value           candidate parameter value; may be {@code null}
     * @param vectorSpaceInfo vector space information handed to the data-aware validator
     * @return {@code null} when the value is valid or no data-aware validator is configured,
     *         otherwise an exception describing the problem
     */
    @Override
    public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
        // Same normalisation as validate(), so Integer 0 is not rejected here after
        // having been accepted there.
        final Object candidate = coerceIntegerZero(value);

        if (!(candidate instanceof Double)) {
            return newValidationException(String.format("value is not an instance of Double for Double parameter [%s].", getName()));
        }

        if (validatorWithData == null) {
            return null;
        }

        if (!validatorWithData.apply((Double) candidate, vectorSpaceInfo)) {
            return newValidationException(String.format("parameter validation failed for Double parameter [%s].", getName()));
        }
        return null;
    }

    /**
     * Maps the {@code Integer} value {@code 0} to the {@code Double} value {@code 0.0} and
     * returns every other value (including {@code null}) unchanged.
     */
    private static Object coerceIntegerZero(Object value) {
        // Integer#equals is null-safe on its argument and only matches an Integer 0,
        // which is exactly what the previous value.equals(0) check matched.
        return Integer.valueOf(0).equals(value) ? Double.valueOf(0.0) : value;
    }

    /** Builds a {@link ValidationException} carrying a single error message. */
    private static ValidationException newValidationException(String message) {
        final ValidationException validationException = new ValidationException();
        validationException.addValidationError(message);
        return validationException;
    }
}

/**
* String method parameter
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,23 @@
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;

import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.BEAM_WIDTH;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.MAX_CONNECTIONS;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

/**
* Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version
*/
Expand All @@ -30,6 +40,7 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor
private final int defaultBeamWidth;
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
private final BiFunction<Integer, Integer, KnnVectorsFormat> formatSupplier;
private final Function5Arity<Integer, Integer, Float, Integer, Boolean, KnnVectorsFormat> quantizedVectorsFormatSupplier;

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
Expand All @@ -50,6 +61,36 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
var params = type.getKnnMethodContext().getMethodComponentContext().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);
if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) {
if (params.get(METHOD_ENCODER_PARAMETER) != null) {
MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) params.get(METHOD_ENCODER_PARAMETER);
if (ENCODER_SQ.equals(encoderMethodComponentContext.getName())) {
Map<String, Object> sqEncoderParams = encoderMethodComponentContext.getParameters();

Float confidenceInterval = getConfidenceInterval(sqEncoderParams);
int bits = getBits(sqEncoderParams);
boolean compressFlag = getCompressFlag(sqEncoderParams);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\",[{}] = \"{}\"",
field,
MAX_CONNECTIONS,
maxConnections,
BEAM_WIDTH,
beamWidth,
LUCENE_SQ_CONFIDENCE_INTERVAL,
confidenceInterval,
LUCENE_SQ_BITS,
bits,
LUCENE_SQ_COMPRESS,
compressFlag
);
return quantizedVectorsFormatSupplier.apply(maxConnections, beamWidth, confidenceInterval, bits, compressFlag);
}

}

}

log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
field,
Expand Down Expand Up @@ -81,4 +122,29 @@ private int getBeamWidth(final Map<String, Object> params) {
}
return defaultBeamWidth;
}

/**
 * Reads the scalar-quantizer confidence interval from the encoder parameters.
 *
 * @param params encoder parameters; may be {@code null}
 * @return the configured confidence interval converted to {@code Float}, or {@code null}
 *         when {@code params} is {@code null}, the key is absent, or it is mapped to
 *         {@code null}; the caller passes that {@code null} on to the quantized format supplier
 * @throws ClassCastException if the stored value is not a {@link Number}
 */
private Float getConfidenceInterval(final Map<String, Object> params) {
    if (params == null) {
        return null;
    }

    // Look the value up with the same constant used to register the parameter,
    // rather than a repeated string literal.
    final Object value = params.get(LUCENE_SQ_CONFIDENCE_INTERVAL);
    if (value == null) {
        // Covers both "key absent" and "key explicitly mapped to null".
        return null;
    }

    // The value may arrive as Integer (e.g. 0) or Double; Number#floatValue handles both.
    return ((Number) value).floatValue();
}

/**
 * Reads the scalar-quantizer bit count from the encoder parameters.
 *
 * @param params encoder parameters; may be {@code null}
 * @return the configured number of bits, or {@code LUCENE_SQ_DEFAULT_BITS} when
 *         {@code params} is {@code null}, the key is absent, or it is mapped to {@code null}
 * @throws ClassCastException if the stored value is not an {@link Integer}
 */
private int getBits(final Map<String, Object> params) {
    if (params == null) {
        return LUCENE_SQ_DEFAULT_BITS;
    }

    // Use the constant for the lookup so the key cannot drift from the one that is registered.
    final Object value = params.get(LUCENE_SQ_BITS);
    if (value == null) {
        // Fall back to the default instead of failing while unboxing null.
        return LUCENE_SQ_DEFAULT_BITS;
    }
    return (Integer) value;
}

/**
 * Reads the scalar-quantizer compress flag from the encoder parameters.
 *
 * @param params encoder parameters; may be {@code null}
 * @return the configured flag, or {@code false} when {@code params} is {@code null},
 *         the key is absent, or it is mapped to {@code null}
 * @throws ClassCastException if the stored value is not a {@link Boolean}
 */
private boolean getCompressFlag(final Map<String, Object> params) {
    if (params == null) {
        return false;
    }

    // Use the constant for the lookup so the key cannot drift from the one that is registered.
    final Object value = params.get(LUCENE_SQ_COMPRESS);
    if (value == null) {
        // Fall back to "not compressed" instead of failing while unboxing null.
        return false;
    }
    return (Boolean) value;
}
}
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/knn/index/codec/Function5Arity.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec;

@FunctionalInterface
public interface Function5Arity<S, T, U, V, X, R> {
R apply(S s, T t, U u, V v, X x);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ public KNN920PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene92HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth)
(maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth),
(maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth)

);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ public KNN940PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene94HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth)
(maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth),
(maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ public KNN950PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene95HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene95HnswVectorsFormat(maxConnm, beamWidth)
(maxConnm, beamWidth) -> new Lucene95HnswVectorsFormat(maxConnm, beamWidth),
(maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene95HnswVectorsFormat(maxConnm, beamWidth)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
Expand All @@ -23,7 +24,16 @@ public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene99HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene99HnswVectorsFormat(maxConnm, beamWidth)
(maxConnm, beamWidth) -> new Lucene99HnswVectorsFormat(maxConnm, beamWidth),
(maxConnm, beamWidth, confidenceInterval, bits, compress) -> new Lucene99HnswScalarQuantizedVectorsFormat(
maxConnm,
beamWidth,
1,
bits,
compress,
confidenceInterval,
null
)
);
}

Expand Down
34 changes: 34 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,22 @@

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.DYNAMIC_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS_SUPPORTED;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_COMPRESS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.MAXIMUM_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.MINIMUM_CONFIDENCE_INTERVAL;

/**
* KNN Library for Lucene
Expand All @@ -28,6 +39,25 @@ public class Lucene extends JVMLibrary {

Map<SpaceType, Function<Float, Float>> distanceTransform;

// Encoders that may be supplied for the Lucene HNSW method, keyed by encoder name.
// Currently a single entry, ENCODER_SQ, with three parameters:
// confidence interval, bits and compress.
private final static Map<String, MethodComponent> HNSW_ENCODERS = ImmutableMap.of(
ENCODER_SQ,
MethodComponent.Builder.builder(ENCODER_SQ)
// Confidence interval: no default (null). Valid when it equals
// DYNAMIC_CONFIDENCE_INTERVAL or lies within
// [MINIMUM_CONFIDENCE_INTERVAL, MAXIMUM_CONFIDENCE_INTERVAL], bounds inclusive.
.addParameter(
LUCENE_SQ_CONFIDENCE_INTERVAL,
new Parameter.DoubleParameter(
LUCENE_SQ_CONFIDENCE_INTERVAL,
null,
v -> v == DYNAMIC_CONFIDENCE_INTERVAL || (v >= MINIMUM_CONFIDENCE_INTERVAL && v <= MAXIMUM_CONFIDENCE_INTERVAL)
)
)
// Bits: defaults to LUCENE_SQ_DEFAULT_BITS and must be one of LUCENE_SQ_BITS_SUPPORTED.
.addParameter(
LUCENE_SQ_BITS,
new Parameter.IntegerParameter(LUCENE_SQ_BITS, LUCENE_SQ_DEFAULT_BITS, LUCENE_SQ_BITS_SUPPORTED::contains)
)
// Compress: defaults to false; the validator only requires the value to be non-null.
.addParameter(LUCENE_SQ_COMPRESS, new Parameter.BooleanParameter(LUCENE_SQ_COMPRESS, false, Objects::nonNull))
.build()
);

final static Map<String, KNNMethod> METHODS = ImmutableMap.of(
METHOD_HNSW,
KNNMethod.Builder.builder(
Expand All @@ -44,6 +74,10 @@ public class Lucene extends JVMLibrary {
v -> v > 0
)
)
.addParameter(
METHOD_ENCODER_PARAMETER,
new Parameter.MethodComponentContextParameter(METHOD_ENCODER_PARAMETER, null, HNSW_ENCODERS)
)
.build()
).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT).build()
);
Expand Down

0 comments on commit 721d5aa

Please sign in to comment.