Skip to content

Commit

Permalink
Add support for Lucene inbuilt Scalar Quantizer (#1848) (#1871)
Browse files Browse the repository at this point in the history
  • Loading branch information
opensearch-trigger-bot[bot] authored Jul 23, 2024
1 parent bfed576 commit f84caf8
Show file tree
Hide file tree
Showing 16 changed files with 746 additions and 31 deletions.
1 change: 1 addition & 0 deletions release-notes/opensearch-knn.release-notes-2.16.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Compatible with OpenSearch 2.16.0
* Add script scoring support for knn field with binary data type [#1826](https://github.com/opensearch-project/k-NN/pull/1826)
* Add painless script support for hamming with binary vector data type [#1839](https://github.com/opensearch-project/k-NN/pull/1839)
* Add binary format support with IVF method in Faiss Engine [#1784](https://github.com/opensearch-project/k-NN/pull/1784)
* Add support for Lucene inbuilt Scalar Quantizer [#1848](https://github.com/opensearch-project/k-NN/pull/1848)
### Enhancements
* Switch from byte stream to byte ref for serde [#1825](https://github.com/opensearch-project/k-NN/pull/1825)
### Bug Fixes
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ public class KNNConstants {

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
public static final String LUCENE_SQ_CONFIDENCE_INTERVAL = "confidence_interval";
public static final int DYNAMIC_CONFIDENCE_INTERVAL = 0;
public static final double MINIMUM_CONFIDENCE_INTERVAL = 0.9;
public static final double MAXIMUM_CONFIDENCE_INTERVAL = 1.0;
public static final String LUCENE_SQ_BITS = "bits";
public static final int LUCENE_SQ_DEFAULT_BITS = 7;

// nmslib specific constants
public static final String NMSLIB_NAME = "nmslib";
Expand Down
81 changes: 81 additions & 0 deletions src/main/java/org/opensearch/knn/index/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
import org.opensearch.common.ValidationException;
import org.opensearch.knn.training.VectorSpaceInfo;

import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Predicate;

Expand Down Expand Up @@ -204,6 +206,85 @@ public ValidationException validateWithData(Object value, VectorSpaceInfo vector
}
}

/**
* Double method parameter
*/
public static class DoubleParameter extends Parameter<Double> {
public DoubleParameter(String name, Double defaultValue, Predicate<Double> validator) {
super(name, defaultValue, validator);
}

public DoubleParameter(
String name,
Double defaultValue,
Predicate<Double> validator,
BiFunction<Double, VectorSpaceInfo, Boolean> validatorWithData
) {
super(name, defaultValue, validator, validatorWithData);
}

@Override
public ValidationException validate(Object value) {
if (Objects.isNull(value)) {
String validationErrorMsg = String.format(Locale.ROOT, "Null value provided for Double " + "parameter \"%s\".", getName());
return getValidationException(validationErrorMsg);
}
if (value.equals(0)) value = 0.0;

if (!(value instanceof Double)) {
String validationErrorMsg = String.format(
Locale.ROOT,
"Value not of type Double for Double " + "parameter \"%s\".",
getName()
);
return getValidationException(validationErrorMsg);
}

if (!validator.test((Double) value)) {
String validationErrorMsg = String.format(
Locale.ROOT,
"Parameter validation failed for Double " + "parameter \"%s\".",
getName()
);
return getValidationException(validationErrorMsg);
}
return null;
}

@Override
public ValidationException validateWithData(Object value, VectorSpaceInfo vectorSpaceInfo) {
if (Objects.isNull(value)) {
String validationErrorMsg = String.format(Locale.ROOT, "Null value provided for Double " + "parameter \"%s\".", getName());
return getValidationException(validationErrorMsg);
}

if (!(value instanceof Double)) {
String validationErrorMsg = String.format(
Locale.ROOT,
"value is not an instance of Double for Double parameter [%s].",
getName()
);
return getValidationException(validationErrorMsg);
}

if (validatorWithData == null) {
return null;
}

if (!validatorWithData.apply((Double) value, vectorSpaceInfo)) {
String validationErrorMsg = String.format(Locale.ROOT, "parameter validation failed for Double parameter [%s].", getName());
return getValidationException(validationErrorMsg);
}
return null;
}

private ValidationException getValidationException(String validationErrorMsg) {
ValidationException validationException = new ValidationException();
validationException.addValidationError(validationErrorMsg);
return validationException;
}
}

/**
* String method parameter
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams;
import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

/**
* Base class for PerFieldKnnVectorsFormat, builds KnnVectorsFormat based on specific Lucene version
*/
Expand All @@ -29,15 +34,34 @@ public abstract class BasePerFieldKnnVectorsFormat extends PerFieldKnnVectorsFor
private final int defaultMaxConnections;
private final int defaultBeamWidth;
private final Supplier<KnnVectorsFormat> defaultFormatSupplier;
private final BiFunction<Integer, Integer, KnnVectorsFormat> formatSupplier;
private final Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier;
private Function<KNNScalarQuantizedVectorsFormatParams, KnnVectorsFormat> scalarQuantizedVectorsFormatSupplier;
private static final String MAX_CONNECTIONS = "max_connections";
private static final String BEAM_WIDTH = "beam_width";

public BasePerFieldKnnVectorsFormat(
Optional<MapperService> mapperService,
int defaultMaxConnections,
int defaultBeamWidth,
Supplier<KnnVectorsFormat> defaultFormatSupplier,
Function<KNNVectorsFormatParams, KnnVectorsFormat> vectorsFormatSupplier
) {
this.mapperService = mapperService;
this.defaultMaxConnections = defaultMaxConnections;
this.defaultBeamWidth = defaultBeamWidth;
this.defaultFormatSupplier = defaultFormatSupplier;
this.vectorsFormatSupplier = vectorsFormatSupplier;
}

@Override
public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
if (isKnnVectorFieldType(field) == false) {
log.debug(
"Initialize KNN vector format for field [{}] with default params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
"Initialize KNN vector format for field [{}] with default params [{}] = \"{}\" and [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
defaultMaxConnections,
BEAM_WIDTH,
defaultBeamWidth
);
return defaultFormatSupplier.get();
Expand All @@ -48,15 +72,43 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponentContext().getParameters();
int maxConnections = getMaxConnections(params);
int beamWidth = getBeamWidth(params);

if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE
&& params != null
&& params.containsKey(METHOD_ENCODER_PARAMETER)) {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth
);
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
LUCENE_SQ_CONFIDENCE_INTERVAL,
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
LUCENE_SQ_BITS,
knnScalarQuantizedVectorsFormatParams.getBits()
);
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
}

}

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [max_connections] = \"{}\" and [beam_width] = \"{}\"",
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
maxConnections,
beamWidth
MAX_CONNECTIONS,
knnVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnVectorsFormatParams.getBeamWidth()
);
return formatSupplier.apply(maxConnections, beamWidth);
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
}

@Override
Expand All @@ -67,18 +119,4 @@ public int getMaxDimensions(String fieldName) {
private boolean isKnnVectorFieldType(final String field) {
return mapperService.isPresent() && mapperService.get().fieldType(field) instanceof KNNVectorFieldMapper.KNNVectorFieldType;
}

private int getMaxConnections(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_M)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_M);
}
return defaultMaxConnections;
}

private int getBeamWidth(final Map<String, Object> params) {
if (params != null && params.containsKey(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) {
return (int) params.get(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION);
}
return defaultBeamWidth;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ public KNN920PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene92HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene92HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene92HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene92HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene92HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ public KNN940PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene94HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene94HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene94HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ public KNN950PerFieldKnnVectorsFormat(final Optional<MapperService> mapperServic
Lucene95HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene95HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene95HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene95HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene95HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.codec.KNN990Codec;

import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.codec.BasePerFieldKnnVectorsFormat;
Expand All @@ -16,14 +17,27 @@
* Class provides per field format implementation for Lucene Knn vector type
*/
public class KNN990PerFieldKnnVectorsFormat extends BasePerFieldKnnVectorsFormat {
private static final int NUM_MERGE_WORKERS = 1;

public KNN990PerFieldKnnVectorsFormat(final Optional<MapperService> mapperService) {
super(
mapperService,
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
() -> new Lucene99HnswVectorsFormat(),
(maxConnm, beamWidth) -> new Lucene99HnswVectorsFormat(maxConnm, beamWidth)
knnVectorsFormatParams -> new Lucene99HnswVectorsFormat(
knnVectorsFormatParams.getMaxConnections(),
knnVectorsFormatParams.getBeamWidth()
),
knnScalarQuantizedVectorsFormatParams -> new Lucene99HnswScalarQuantizedVectorsFormat(
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
NUM_MERGE_WORKERS,
knnScalarQuantizedVectorsFormatParams.getBits(),
knnScalarQuantizedVectorsFormatParams.isCompressFlag(),
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
null
)
);
}

Expand Down
Loading

0 comments on commit f84caf8

Please sign in to comment.