Skip to content

Commit

Permalink
Refactor FieldMapping logic
Browse files Browse the repository at this point in the history
Refactors FieldMapper logic. It removes the LegacyFieldMapper and
replaces it with a FlatFieldMapper. The FlatFieldMapper's role is to
create fields that do not build ANN indices.

Additionally, it puts dimension, model_id, and knn_method_context in a
new ANNConfig class and adds some safety checks around accessing them.
This should make calling logic easier to handle.

Lastly, it cleans up the parsing so that there isnt encoder parsing
directly in the KNNVectorFieldMapper.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Aug 8, 2024
1 parent df7627c commit 2dda3a8
Show file tree
Hide file tree
Showing 34 changed files with 1,354 additions and 863 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams;
import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.mapper.ANNConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;

import java.util.Optional;
Expand Down Expand Up @@ -66,16 +68,19 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
);
return defaultFormatSupplier.get();
}
var type = (KNNVectorFieldType) mapperService.orElseThrow(
KNNVectorFieldType mappedFieldType = (KNNVectorFieldType) mapperService.orElseThrow(
() -> new IllegalStateException(
String.format("Cannot read field type for field [%s] because mapper service is not available", field)
)
).fieldType(field);
var params = type.getKnnMethodContext().getMethodComponentContext().getParameters();

if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE
&& params != null
&& params.containsKey(METHOD_ENCODER_PARAMETER)) {
ANNConfig annConfig = mappedFieldType.getAnnConfig();
KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext()
.orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty"));

var params = knnMethodContext.getMethodComponentContext().getParameters();

if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE && params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
defaultMaxConnections,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.opensearch.knn.index.codec.transfer.VectorTransferByte;
import org.opensearch.knn.index.codec.transfer.VectorTransferFloat;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.indices.Model;
Expand Down Expand Up @@ -57,7 +56,6 @@

import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize;
import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
Expand Down Expand Up @@ -213,35 +211,19 @@ private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNN

private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath)
throws IOException {
Map<String, Object> parameters = new HashMap<>();
Map<String, String> fieldAttributes = fieldInfo.attributes();
String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS);

// parametersString will be null when legacy mapper is used
if (parametersString == null) {
parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()));

String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION);
Map<String, Object> algoParams = new HashMap<>();
if (efConstruction != null) {
algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction));
}

String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M);
if (m != null) {
algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m));
}
parameters.put(PARAMETERS, algoParams);
} else {
parameters.putAll(
XContentHelper.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(parametersString),
MediaTypeRegistry.getDefaultMediaType()
).map()
);
throw new IllegalStateException("Parameter string is not set. Something is wrong");
}
Map<String, Object> parameters = new HashMap<>(
XContentHelper.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(parametersString),
MediaTypeRegistry.getDefaultMediaType()
).map()
);

// Update index description of Faiss for binary data type
if (KNNEngine.FAISS == knnEngine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.opensearch.Version;
import org.opensearch.common.ValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand All @@ -20,7 +19,6 @@
import org.opensearch.index.mapper.MapperParsingException;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
Expand All @@ -29,7 +27,6 @@
import org.opensearch.knn.training.VectorSpaceInfo;

import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
Expand All @@ -42,21 +39,6 @@
@Getter
public class KNNMethodContext implements ToXContentFragment, Writeable {

private static KNNMethodContext defaultInstance = null;

/**
* This is used only for testing
* @return default KNNMethodContext for testing
*/
public static synchronized KNNMethodContext getDefault() {
if (defaultInstance == null) {
MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap());
methodComponentContext.setIndexVersion(Version.CURRENT);
defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext);
}
return defaultInstance;
}

@NonNull
private final KNNEngine knnEngine;
@NonNull
Expand Down
119 changes: 119 additions & 0 deletions src/main/java/org/opensearch/knn/index/mapper/ANNConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.mapper;

import lombok.Getter;
import org.opensearch.knn.index.engine.KNNMethodContext;

import java.util.Optional;

/**
* Class holds information about how the ANN indices are created. The design of this class ensures that we do not
* accidentally configure an index that has multiple ways it can be created. This class is immutable.
*/
public final class ANNConfig {

@Getter
private final ANNConfigType annConfigType;
private final KNNMethodContext knnMethodContext;
private final String modelId;
private final Integer dimension;

/**
* Constructor
*
* @param annConfigType Configurational context index was built. Cannot be null
* @param knnMethodContext Method context used to create the index; null if not created from method
* @param modelId Model id used to create the index; null if not created from model
* @param dimension Dimension used to create the index; needs to be null for model-based indices
*/
public ANNConfig(ANNConfigType annConfigType, KNNMethodContext knnMethodContext, String modelId, Integer dimension) {
if (annConfigType == null) {
throw new IllegalArgumentException("ANNConfiguration cannot be null");
}

this.annConfigType = annConfigType;
this.knnMethodContext = knnMethodContext;
this.modelId = modelId;
this.dimension = dimension;

if (ANNConfigType.FROM_METHOD == annConfigType) {
validateFromMethod();
return;
}

if (ANNConfigType.FROM_MODEL == annConfigType) {
validateFromModel();
return;
}

if (ANNConfigType.SKIP == annConfigType) {
validateSkip();
}
}

private void validateFromMethod() {
if (knnMethodContext == null) {
throw new IllegalArgumentException("knnMethodContext cannot be null when created from method");
}

if (modelId != null) {
throw new IllegalArgumentException("modelId cannot be specified when created from method");
}

if (dimension == null) {
throw new IllegalArgumentException("dimension must be specified when created from method");
}
}

private void validateFromModel() {
if (modelId == null) {
throw new IllegalArgumentException("modelId cannot be null when created from method");
}

if (knnMethodContext != null) {
throw new IllegalArgumentException("knnMethodContext cannot be specified when created from method");
}

if (dimension != null) {
throw new IllegalArgumentException("dimension must be null when created from model");
}
}

private void validateSkip() {
if (knnMethodContext != null || modelId != null) {
throw new IllegalArgumentException("knnMethodContext or modelId cannot be specified when skipping");
}

if (dimension == null) {
throw new IllegalArgumentException("dimension must be specified when created from model");
}
}

/**
*
* @return Optional containing the modelId if created from model, otherwise empty
*/
public Optional<String> getModelId() {
return Optional.ofNullable(modelId);
}

/**
*
* @return Optional containing the KNNMethodContext if created from method, otherwise empty
*/
public Optional<KNNMethodContext> getKnnMethodContext() {
return Optional.ofNullable(knnMethodContext);
}

/**
*
* @return the dimension of the index; for model based indices, it will be null
*/
public Optional<Integer> getDimension() {
return Optional.ofNullable(dimension);
}
}
15 changes: 15 additions & 0 deletions src/main/java/org/opensearch/knn/index/mapper/ANNConfigType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.mapper;

/**
* Types of configurations to build ANN indices
*/
public enum ANNConfigType {
FROM_METHOD,
FROM_MODEL,
SKIP
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.mapper;

import org.apache.lucene.document.FieldType;
import org.opensearch.Version;
import org.opensearch.common.Explicit;
import org.opensearch.knn.index.VectorDataType;

/**
* Mapper used when you dont want to build an underlying KNN struct - you just want to
* store vectors as doc values
*/
public class FlatVectorFieldMapper extends KNNVectorFieldMapper {

private final PerDimensionValidator perDimensionValidator;

public FlatVectorFieldMapper(
String simpleName,
KNNVectorFieldType mappedFieldType,
MultiFields multiFields,
CopyTo copyTo,
Explicit<Boolean> ignoreMalformed,
boolean stored,
boolean hasDocValues,
Version indexCreatedVersion
) {
super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null);
this.perDimensionValidator = selectPerDimensionValidator(vectorDataType);
this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE);
this.fieldType.freeze();
}

private PerDimensionValidator selectPerDimensionValidator(VectorDataType vectorDataType) {
if (VectorDataType.BINARY == vectorDataType) {
return PerDimensionValidator.DEFAULT_BIT_VALIDATOR;
}

if (VectorDataType.BYTE == vectorDataType) {
return PerDimensionValidator.DEFAULT_BYTE_VALIDATOR;
}

return PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR;
}

@Override
protected VectorValidator getVectorValidator() {
return VectorValidator.NOOP_VECTOR_VALIDATOR;
}

@Override
protected PerDimensionValidator getPerDimensionValidator() {
return perDimensionValidator;
}

@Override
protected PerDimensionProcessor getPerDimensionProcessor() {
return PerDimensionProcessor.NOOP_PROCESSOR;
}
}
Loading

0 comments on commit 2dda3a8

Please sign in to comment.