Skip to content

Commit

Permalink
Refactoring of resolution logic
Browse files Browse the repository at this point in the history
PR changes a lot of the resolution logic and does some renaming.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Sep 2, 2024
1 parent 28e25b5 commit 51e0828
Show file tree
Hide file tree
Showing 70 changed files with 1,547 additions and 1,920 deletions.
50 changes: 40 additions & 10 deletions src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

Expand All @@ -21,6 +22,7 @@

import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;

import java.util.Locale;

Expand All @@ -47,20 +49,43 @@ public static KNNEngine extractKNNEngine(final FieldInfo field) {
}

/**
* Extracts VectorDataType from FieldInfo
* Extracts VectorDataType from FieldInfo. This VectorDataType represents what vectors will be input to the
* library layer. For the data type that is transfered to the native layer, see extractVectorDataTypeForTransfer (better comment)
*
* @param fieldInfo {@link FieldInfo}
* @return {@link VectorDataType}
*/
public static VectorDataType extractVectorDataType(final FieldInfo fieldInfo) {
String vectorDataTypeString = fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD);
if (StringUtils.isEmpty(vectorDataTypeString)) {
final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
if (modelMetadata != null) {
VectorDataType vectorDataType = modelMetadata.getVectorDataType();
vectorDataTypeString = vectorDataType == null ? null : vectorDataType.getValue();
}
if (StringUtils.isNotEmpty(vectorDataTypeString)) {
return VectorDataType.get(vectorDataTypeString);
}

final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
if (modelMetadata == null) {
return VectorDataType.DEFAULT;
}
return modelMetadata.getVectorDataType();
}

/**
* Extracts VectorDataType for transfer from FieldInfo. This VectorDataType represents what vectors will be transfered
* to the native layer. For the data type that is input to the library layer, see extractVectorDataType (better comment)
*
* @param fieldInfo {@link FieldInfo}
* @param quantizationParams {@link QuantizationParams}
* @return {@link VectorDataType}
*/
public static VectorDataType extractVectorDataTypeForTransfer(final FieldInfo fieldInfo, QuantizationParams quantizationParams) {
if (quantizationParams != null) {
return QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo);
}
return StringUtils.isNotEmpty(vectorDataTypeString) ? VectorDataType.get(vectorDataTypeString) : VectorDataType.DEFAULT;
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != null && quantizationConfig != QuantizationConfig.EMPTY) {
return VectorDataType.BINARY;
}

return extractVectorDataType(fieldInfo);
}

/**
Expand All @@ -71,10 +96,15 @@ public static VectorDataType extractVectorDataType(final FieldInfo fieldInfo) {
*/
public static QuantizationConfig extractQuantizationConfig(final FieldInfo fieldInfo) {
String quantizationConfigString = fieldInfo.getAttribute(QFRAMEWORK_CONFIG);
if (StringUtils.isEmpty(quantizationConfigString)) {
if (StringUtils.isNotEmpty(quantizationConfigString)) {
return QuantizationConfigParser.fromCsv(quantizationConfigString);
}

final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
if (modelMetadata == null || modelMetadata.getKNNLibraryIndex().isEmpty()) {
return QuantizationConfig.EMPTY;
}
return QuantizationConfigParser.fromCsv(quantizationConfigString);
return modelMetadata.getKNNLibraryIndex().get().getQuantizationConfig();
}

/**
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/index/KNNIndexShard.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataTypeForTransfer;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.util.IndexUtil.getParametersAtLoading;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFilePrefix;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileSuffix;
Expand Down Expand Up @@ -182,7 +182,7 @@ List<EngineFileContext> getEngineFileContexts(IndexReader indexReader, KNNEngine
shardPath,
spaceType,
modelId,
VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()))
extractVectorDataTypeForTransfer(fieldInfo, null)
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,56 +77,54 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
).fieldType(field);

if (mappedFieldType.getModelId().isPresent()) {
return getFormatForModelBasedIndices();
}
if (mappedFieldType.getKNNEngine() == null) {
throw new IllegalStateException("Method config context cannot be empty");
return getNativeEngines990KnnVectorsFormat();
}
return getFormatForMethodBasedIndices(mappedFieldType.getKNNEngine(), mappedFieldType.getLibraryParameters(), field);
}

private KnnVectorsFormat getFormatForModelBasedIndices() {
return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()));
}

private KnnVectorsFormat getFormatForMethodBasedIndices(KNNEngine knnEngine, Map<String, Object> params, String field) {
if (knnEngine == KNNEngine.LUCENE) {
if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth
);
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
LUCENE_SQ_CONFIDENCE_INTERVAL,
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
LUCENE_SQ_BITS,
knnScalarQuantizedVectorsFormatParams.getBits()
);
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
}
}
if (knnEngine != KNNEngine.LUCENE) {
return getNativeEngines990KnnVectorsFormat();
}

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnVectorsFormatParams.getBeamWidth()
// For Lucene, we need to properly configure the format because format initialization is when parameters are
// set
if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
defaultMaxConnections,
defaultBeamWidth
);
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
LUCENE_SQ_CONFIDENCE_INTERVAL,
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
LUCENE_SQ_BITS,
knnScalarQuantizedVectorsFormatParams.getBits()
);
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
}
}

// All native engines to use NativeEngines990KnnVectorsFormat
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
log.debug(
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
field,
MAX_CONNECTIONS,
knnVectorsFormatParams.getMaxConnections(),
BEAM_WIDTH,
knnVectorsFormatParams.getBeamWidth()
);
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
}

private NativeEngines990KnnVectorsFormat getNativeEngines990KnnVectorsFormat() {
return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

import java.io.IOException;

import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine;
import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine;

/**
* This class writes the KNN docvalues to the segments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNIndexContext;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelCache;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
Expand All @@ -47,7 +43,7 @@
import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine;
import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataTypeForTransfer;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
Expand Down Expand Up @@ -161,17 +157,14 @@ private void buildAndWriteIndex(final KNNVectorValues<?> knnVectorValues) throws
// TODO: Refactor this so its scalable. Possibly move it out of this class
private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNEngine knnEngine) throws IOException {
final Map<String, Object> parameters;
VectorDataType vectorDataType;
if (quantizationState != null) {
vectorDataType = QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo);
} else {
vectorDataType = extractVectorDataType(fieldInfo);
}
if (fieldInfo.attributes().containsKey(MODEL_ID)) {
Model model = getModel(fieldInfo);
parameters = getTemplateParameters(fieldInfo, model);
} else {
VectorDataType vectorDataType = extractVectorDataTypeForTransfer(
fieldInfo,
quantizationState == null ? null : quantizationState.getQuantizationParams()
);
if (fieldInfo.attributes().containsKey(MODEL_ID) == false) {
parameters = getParameters(fieldInfo, vectorDataType, knnEngine);
} else {
parameters = getTemplateParameters(fieldInfo, vectorDataType);
}

return BuildIndexParams.builder()
Expand Down Expand Up @@ -215,7 +208,6 @@ private Map<String, Object> getParameters(FieldInfo fieldInfo, VectorDataType ve
);
}

parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue());
// In OpenSearch 2.16, we added the prefix for binary indices in the index description in the codec logic.
// After 2.16, we added the binary prefix in the faiss library code. However, to ensure backwards compatibility,
// we need to ensure that if the description does not contain the prefix but the type is binary, we add the
Expand All @@ -228,60 +220,20 @@ private Map<String, Object> getParameters(FieldInfo fieldInfo, VectorDataType ve
return parameters;
}

private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map<String, Object> parameters, Map<String, String> fieldAttributes) {
if (KNNEngine.FAISS != knnEngine) {
return;
}

if (!VectorDataType.BINARY.getValue()
.equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))) {
return;
}

if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) == null) {
return;
}

if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) {
return;
private Map<String, Object> getTemplateParameters(FieldInfo fieldInfo, VectorDataType vectorDataTypeForTransfer) {
Model model = ModelUtil.getModel(fieldInfo.getAttribute(MODEL_ID));
if (model == null) {
throw new IllegalStateException("Model not found for field " + fieldInfo.name);
}

parameters.put(
KNNConstants.INDEX_DESCRIPTION_PARAMETER,
FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString()
);
IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
}

private Map<String, Object> getTemplateParameters(FieldInfo fieldInfo, Model model) throws IOException {
Map<String, Object> parameters = new HashMap<>();
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
parameters.put(KNNConstants.MODEL_ID, fieldInfo.attributes().get(MODEL_ID));
parameters.put(KNNConstants.MODEL_ID, model.getModelID());
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, model.getModelBlob());

// TODO: Is there any way we could avoid resolving it like this?
KNNIndexContext knnIndexContext = ModelUtil.getKnnMethodContextFromModelMetadata(model.getModelID(), model.getModelMetadata());
if (knnIndexContext != null && knnIndexContext.getLibraryParameters().containsKey(VECTOR_DATA_TYPE_FIELD)) {
IndexUtil.updateVectorDataTypeToParameters(
parameters,
VectorDataType.get((String) knnIndexContext.getLibraryParameters().get(VECTOR_DATA_TYPE_FIELD))
);
} else {
IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType());
}

parameters.put(VECTOR_DATA_TYPE_FIELD, vectorDataTypeForTransfer.getValue());
return parameters;
}

private Model getModel(FieldInfo fieldInfo) {
String modelId = fieldInfo.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);
if (model.getModelBlob() == null) {
throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
}
return model;
}

private void startMergeStats(int numDocs, long bytesPerVector) {
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs);
Expand Down Expand Up @@ -358,4 +310,30 @@ private static NativeIndexWriter createWriter(
: DefaultIndexBuildStrategy.getInstance();
return new NativeIndexWriter(state, fieldInfo, strategy, quantizationState);
}

private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map<String, Object> parameters, Map<String, String> fieldAttributes) {
if (KNNEngine.FAISS != knnEngine) {
return;
}

if (!VectorDataType.BINARY.getValue()
.equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))) {
return;
}

if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) == null) {
return;
}

if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) {
return;
}

parameters.put(
KNNConstants.INDEX_DESCRIPTION_PARAMETER,
FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString()
);

parameters.put(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ public class BuildIndexParams {
String fieldName;
KNNEngine knnEngine;
String indexPath;
/**
* Vector data type represents the type used to build the library index. If something like binary quantization is
* done, then this will be different from the vector data type the user provides
*/
VectorDataType vectorDataType;
Map<String, Object> parameters;
/**
Expand Down
Loading

0 comments on commit 51e0828

Please sign in to comment.