Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Lucene inbuilt Scalar Quantizer #1848

1 change: 1 addition & 0 deletions release-notes/opensearch-knn.release-notes-2.16.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Compatible with OpenSearch 2.16.0
* Add script scoring support for knn field with binary data type [#1826](https://github.com/opensearch-project/k-NN/pull/1826)
* Add painless script support for hamming with binary vector data type [#1839](https://github.com/opensearch-project/k-NN/pull/1839)
* Add binary format support with IVF method in Faiss Engine [#1784](https://github.com/opensearch-project/k-NN/pull/1784)
* Add support for Lucene inbuilt Scalar Quantizer [#1848](https://github.com/opensearch-project/k-NN/pull/1848)
### Enhancements
* Switch from byte stream to byte ref for serde [#1825](https://github.com/opensearch-project/k-NN/pull/1825)
### Bug Fixes
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ public class KNNConstants {

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
public static final String LUCENE_SQ_CONFIDENCE_INTERVAL = "confidence_interval";
public static final int DYNAMIC_CONFIDENCE_INTERVAL = 0;
public static final double MINIMUM_CONFIDENCE_INTERVAL = 0.9;
public static final double MAXIMUM_CONFIDENCE_INTERVAL = 1.0;
public static final String LUCENE_SQ_BITS = "bits";
public static final int LUCENE_SQ_DEFAULT_BITS = 7;

// nmslib specific constants
public static final String NMSLIB_NAME = "nmslib";
Expand Down
81 changes: 81 additions & 0 deletions src/main/java/org/opensearch/knn/index/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import org.opensearch.common.ValidationException;
import org.opensearch.knn.training.VectorSpaceInfo;

import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Predicate;

Expand Down Expand Up @@ -204,6 +206,85 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector
}
}

/**
* Double method parameter
*/
public static class DoubleParameter extends Parameter<Double> {
public DoubleParameter(String name, Double defaultValue, Predicate<Double> validator) {
super(name, defaultValue, validator);
}

public DoubleParameter(
String name,
Double defaultValue,
Predicate<Double> validator,
BiFunction<Double, VectorSpaceInfo, Boolean> validatorWithData
) {
super(name, defaultValue, validator, validatorWithData);
}

@Override
public ValidationException validate(Object value) {
if (Objects.isNull(value)) {
String validationErrorMsg = String.format(Locale.ROOT, "Null value provided for Double " + "parameter \"%s\".", getName());
return getValidationException(validationErrorMsg);
}
if (value.equals(0)) value = 0.0;
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit. This style is error prune.

Suggested change
if (value.equals(0)) value = 0.0;
if (value.equals(0)) {
value = 0.0;
}


if (!(value instanceof Double)) {
String validationErrorMsg = String.format(
Locale.ROOT,
"Value not of type Double for Double " + "parameter \"%s\".",
getName()
);
return getValidationException(validationErrorMsg);
}

if (!validator.test((Double) value)) {
String validationErrorMsg = String.format(
Locale.ROOT,
"Parameter validation failed for Double " + "parameter \"%s\".",
getName()
);
return getValidationException(validationErrorMsg);
}
return null;
}

@Override
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
if (Objects.isNull(value)) {
String validationErrorMsg = String.format(Locale.ROOT, "Null value provided for Double " + "parameter \"%s\".", getName());
return getValidationException(validationErrorMsg);
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No conversion here?

if (value.equals(0)) {
  value = 0.0;
  }

if (!(value instanceof Double)) {
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
String validationErrorMsg = String.format(
Locale.ROOT,
"value is not an instance of Double for Double parameter [%s].",
getName()
);
return getValidationException(validationErrorMsg);
}

if (validatorWithData == null) {
return null;
}

if (!validatorWithData.apply((Double) value, vectorSpaceInfo)) {
String validationErrorMsg = String.format(Locale.ROOT, "parameter validation failed for Double parameter [%s].", getName());
return getValidationException(validationErrorMsg);
}
return null;
}

private ValidationException getValidationException(String validationErrorMsg) {
ValidationException validationException = new ValidationException();
validationException.addValidationError(validationErrorMsg);
return validationException;
}
}

/**
* String method parameter
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams;
import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

/**
* Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version
*/
Expand All @@ -29,15 +34,34 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor
private final int defaultMaxConnections;
private final int defaultBeamWidth;
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
private final BiFunction<Integer, Integer, KnnVectorsFormat> formatSupplier;
private final Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier;
private Function<KNNScalarQuantizedVectorsFormatParams, KnnVectorsFormat> scalarQuantizedVectorsFormatSupplier;
private static final String MAX_CONNECTIONS = "max_connections";
private static final String BEAM_WIDTH = "beam_width";

public BasePerFieldKnnVectorsFormat(
Optional<MapperService> mapperService,
int defaultMaxConnections,
int defaultBeamWidth,
Supplier<KnnVectorsFormat> defaultFormatSupplier,
Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier
) {
this.mapperService = mapperService;
this.defaultMaxConnections = defaultMaxConnections;
this.defaultBeamWidth = defaultBeamWidth;
this.defaultFormatSupplier = defaultFormatSupplier;
this.vectorsFormatSupplier = vectorsFormatSupplier;
}

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isKnnVectorFieldType(field) == false) {
log.debug(
"Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
"Initialize KNN vector format for field [{}] with default params [{}] = \"{}\" and [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
defaultMaxConnections,
BEAM_WIDTH,
defaultBeamWidth
);
return defaultFormatSupplier.get();
Expand All @@ -48,15 +72,43 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponentContext().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);

if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE
&& params != null
&& params.containsKey(METHOD_ENCODER_PARAMETER)) {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth
);
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
LUCENE_SQ_CONFIDENCE_INTERVAL,
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
LUCENE_SQ_BITS,
knnScalarQuantizedVectorsFormatParams.getBits()
);
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
}

}

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
maxConnections,
beamWidth
MAX_CONNECTIONS,
knnVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnVectorsFormatParams.getBeamWidth()
);
return formatSupplier.apply(maxConnections, beamWidth);
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
}

@Override
Expand All @@ -67,18 +119,4 @@ public int getMaxDimensions(String fieldName) {
private boolean isKnnVectorFieldType(final String field) {
return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return defaultMaxConnections;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return defaultBeamWidth;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ public KNN920PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene92HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene92HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ public KNN940PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene94HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene94HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ public KNN950PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene95HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene95HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene95HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
Expand All @@ -16,14 +17,27 @@
* Class provides per field format implementation for Lucene Knn vector type
*/
public class KNN990PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {
private static final int NUM_MERGE_WORKERS = 1;

public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene99HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene99HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene99HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
NUM_MERGE_WORKERS,
knnScalarQuantizedVectorsFormatParams.getBits(),
knnScalarQuantizedVectorsFormatParams.isCompressFlag(),
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
null
)
);
}

Expand Down
Loading
Loading