Skip to content

Commit

Permalink
Fix model field mapper
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Aug 14, 2024
1 parent 8ca67a0 commit 706bbaa
Showing 1 changed file with 8 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
Expand All @@ -34,7 +31,6 @@ public class ModelFieldMapper extends KNNVectorFieldMapper {
// If the dimension has not yet been set because we do not have access to model metadata, it will be -1
public static final int UNSET_MODEL_DIMENSION_IDENTIFIER = -1;

private PerDimensionProcessor perDimensionProcessor;
private PerDimensionValidator perDimensionValidator;
private VectorValidator vectorValidator;

Expand Down Expand Up @@ -98,7 +94,6 @@ private ModelFieldMapper(
// For the model field mapper, we cannot validate the model during index creation due to
// an issue with reading cluster state during mapper creation. So, we need to validate the
// model when ingestion starts. We do this as lazily as we can
this.perDimensionProcessor = null;
this.perDimensionValidator = null;
this.vectorValidator = null;

Expand All @@ -121,44 +116,29 @@ protected PerDimensionValidator getPerDimensionValidator() {

@Override
protected PerDimensionProcessor getPerDimensionProcessor() {
initPerDimensionProcessor();
return perDimensionProcessor;
return PerDimensionProcessor.NOOP_PROCESSOR;
}

private void initVectorValidator() {
if (vectorValidator != null) {
return;
}
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata);
KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata);
KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
vectorValidator = knnLibraryIndexingContext.getVectorValidator();
vectorValidator = new SpaceVectorValidator(modelMetadata.getSpaceType());
}

private void initPerDimensionValidator() {
if (perDimensionValidator != null) {
return;
}
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata);
KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata);
KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator();
}

private void initPerDimensionProcessor() {
if (perDimensionProcessor != null) {
return;
if (modelMetadata.getVectorDataType() == VectorDataType.BINARY) {
perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR;
} else if (modelMetadata.getVectorDataType() == VectorDataType.BYTE) {
perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR;
} else {
perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR;
}
ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId);
KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata);
KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata);
KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor();
}

@Override
Expand All @@ -183,17 +163,6 @@ protected void parseCreateField(ParseContext context) throws IOException {
parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType());
}

private static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) {
return new KNNMethodContext(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext());
}

private static KNNMethodConfigContext getKNNMethodConfigContextFromModelMetadata(ModelMetadata modelMetadata) {
return KNNMethodConfigContext.builder()
.vectorDataType(modelMetadata.getVectorDataType())
.dimension(modelMetadata.getDimension())
.build();
}

private static ModelMetadata getModelMetadata(ModelDao modelDao, String modelId) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
Expand Down

0 comments on commit 706bbaa

Please sign in to comment.