Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mode and compression based parameter support #1

Open
wants to merge 4 commits into
base: disk-staging-base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 40 additions & 10 deletions src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

Expand All @@ -21,6 +22,7 @@

import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;

import java.util.Locale;

Expand All @@ -47,20 +49,43 @@ public static KNNEngine extractKNNEngine(final FieldInfo field) {
}

/**
* Extracts VectorDataType from FieldInfo
* Extracts VectorDataType from FieldInfo. This VectorDataType represents the type of the vectors that are input to the
* library layer. For the data type that is transferred to the native layer, see extractVectorDataTypeForTransfer (better comment)
*
* @param fieldInfo {@link FieldInfo}
* @return {@link VectorDataType}
*/
public static VectorDataType extractVectorDataType(final FieldInfo fieldInfo) {
String vectorDataTypeString = fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD);
if (StringUtils.isEmpty(vectorDataTypeString)) {
final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
if (modelMetadata != null) {
VectorDataType vectorDataType = modelMetadata.getVectorDataType();
vectorDataTypeString = vectorDataType == null ? null : vectorDataType.getValue();
}
if (StringUtils.isNotEmpty(vectorDataTypeString)) {
return VectorDataType.get(vectorDataTypeString);
}

final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
if (modelMetadata == null) {
return VectorDataType.DEFAULT;
}
return modelMetadata.getVectorDataType();
}

/**
* Extracts VectorDataType for transfer from FieldInfo. This VectorDataType represents the type of the vectors that are transferred
* to the native layer. For the data type that is input to the library layer, see extractVectorDataType (better comment)
*
* @param fieldInfo {@link FieldInfo}
* @param quantizationParams {@link QuantizationParams}
* @return {@link VectorDataType}
*/
public static VectorDataType extractVectorDataTypeForTransfer(final FieldInfo fieldInfo, QuantizationParams quantizationParams) {
if (quantizationParams != null) {
return QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo);
}
return StringUtils.isNotEmpty(vectorDataTypeString) ? VectorDataType.get(vectorDataTypeString) : VectorDataType.DEFAULT;
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != null && quantizationConfig != QuantizationConfig.EMPTY) {
return VectorDataType.BINARY;
}

return extractVectorDataType(fieldInfo);
}

/**
Expand All @@ -71,10 +96,15 @@ public static VectorDataType extractVectorDataType(final FieldInfo fieldInfo) {
*/
public static QuantizationConfig extractQuantizationConfig(final FieldInfo fieldInfo) {
String quantizationConfigString = fieldInfo.getAttribute(QFRAMEWORK_CONFIG);
if (StringUtils.isEmpty(quantizationConfigString)) {
if (StringUtils.isNotEmpty(quantizationConfigString)) {
return QuantizationConfigParser.fromCsv(quantizationConfigString);
}

final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
if (modelMetadata == null || modelMetadata.getKNNLibraryIndex().isEmpty()) {
return QuantizationConfig.EMPTY;
}
return QuantizationConfigParser.fromCsv(quantizationConfigString);
return modelMetadata.getKNNLibraryIndex().get().getQuantizationConfig();
}

/**
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ public class KNNConstants {
public static final String MODEL = "model";
public static final String MODELS = "models";
public static final String MODEL_ID = "model_id";
public static final String MODE_PARAMETER = "mode";
public static final String COMPRESSION_PARAMETER = "compression";
public static final String MODE_IN_MEMORY_NAME = "in_memory";
public static final String MODE_ON_DISK_NAME = "on_disk";

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to make this an enum?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is an enum. I will move it out of the constants.

public static final String MODEL_BLOB_PARAMETER = "model_blob";
public static final String MODEL_INDEX_MAPPING_PATH = "mappings/model-index.json";
public static final String MODEL_INDEX_NAME = ".opensearch-knn-models";
Expand Down Expand Up @@ -72,6 +76,8 @@ public class KNNConstants {
public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD;
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

public static final String MINIMAL_MODE_AND_COMPRESSION_FEATURE = "compression_and_mode_feature_flag";

public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String QUANTIZATION_STATE_FILE_SUFFIX = "qstate";

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataTypeForTransfer;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.util.IndexUtil.getParametersAtLoading;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFilePrefix;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileSuffix;
Expand Down Expand Up @@ -182,7 +182,7 @@ List<EngineFileContext> getEngineFileContexts(IndexReader indexReader, KNNEngine
shardPath,
spaceType,
modelId,
VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()))
extractVectorDataTypeForTransfer(fieldInfo, null)
)
);
}
Expand Down
13 changes: 0 additions & 13 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,6 @@
* nmslib calls the inner_product space "negdotprod". This translation should take place in the nmslib's jni layer.
*/
public enum SpaceType {
// This undefined space type is used to indicate that space type is not provided by user
// Later, we need to assign a default value based on data type
UNDEFINED("undefined") {
@Override
public float scoreTranslation(final float rawScore) {
throw new IllegalStateException("Unsupported method");
}

@Override
public void validateVectorDataType(VectorDataType vectorDataType) {
throw new IllegalStateException("Unsupported method");
}
},
L2("l2") {
@Override
public float scoreTranslation(float rawScore) {
Expand Down
13 changes: 3 additions & 10 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
Expand Down Expand Up @@ -114,15 +113,9 @@ public float[] getVectorFromBytesRef(BytesRef binaryValue) {
* throws Exception if an invalid value is provided.
*/
public static VectorDataType get(String vectorDataType) {
Objects.requireNonNull(
vectorDataType,
String.format(
Locale.ROOT,
"[%s] should not be null. Supported types are [%s]",
VECTOR_DATA_TYPE_FIELD,
SUPPORTED_VECTOR_DATA_TYPES
)
);
if (vectorDataType == null) {
return DEFAULT;
}
try {
return VectorDataType.valueOf(vectorDataType.toUpperCase(Locale.ROOT));
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams;
import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;

import java.util.Map;
Expand Down Expand Up @@ -78,50 +76,55 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
)
).fieldType(field);

KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig();
KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext()
.orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty"));
if (mappedFieldType.getModelId().isPresent()) {
return getNativeEngines990KnnVectorsFormat();
}
return getFormatForMethodBasedIndices(mappedFieldType.getKNNEngine(), mappedFieldType.getLibraryParameters(), field);
}

final KNNEngine engine = knnMethodContext.getKnnEngine();
final Map<String, Object> params = knnMethodContext.getMethodComponentContext().getParameters();
private KnnVectorsFormat getFormatForMethodBasedIndices(KNNEngine knnEngine, Map<String, Object> params, String field) {
if (knnEngine != KNNEngine.LUCENE) {
return getNativeEngines990KnnVectorsFormat();
}

if (engine == KNNEngine.LUCENE) {
if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth
// For Lucene, we need to properly configure the format because format initialization is when parameters are
// set
if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth
);
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
LUCENE_SQ_CONFIDENCE_INTERVAL,
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
LUCENE_SQ_BITS,
knnScalarQuantizedVectorsFormatParams.getBits()
);
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
LUCENE_SQ_CONFIDENCE_INTERVAL,
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
LUCENE_SQ_BITS,
knnScalarQuantizedVectorsFormatParams.getBits()
);
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
}
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
}

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnVectorsFormatParams.getBeamWidth()
);
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
}

// All native engines to use NativeEngines990KnnVectorsFormat
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnVectorsFormatParams.getBeamWidth()
);
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
}

/**
 * Creates a new {@link NativeEngines990KnnVectorsFormat} that wraps a {@link Lucene99FlatVectorsFormat}
 * configured with the flat vectors scorer supplied by {@link FlatVectorScorerUtil}.
 *
 * @return a freshly constructed {@link NativeEngines990KnnVectorsFormat}
 */
private NativeEngines990KnnVectorsFormat getNativeEngines990KnnVectorsFormat() {
    final Lucene99FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(
        FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
    );
    return new NativeEngines990KnnVectorsFormat(flatVectorsFormat);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

import java.io.IOException;

import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine;
import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine;

/**
* This class writes the KNN docvalues to the segments
Expand Down
Loading