From 706bbaa2272cbdcc02995f1d4db1d29d322ad40e Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Wed, 14 Aug 2024 11:52:30 -0700 Subject: [PATCH] Fix model field mapper Signed-off-by: John Mazanec --- .../knn/index/mapper/ModelFieldMapper.java | 47 ++++--------------- 1 file changed, 8 insertions(+), 39 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 0b572b36eb..7d49a5de34 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -13,9 +13,6 @@ import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; -import org.opensearch.knn.index.engine.KNNMethodConfigContext; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; @@ -34,7 +31,6 @@ public class ModelFieldMapper extends KNNVectorFieldMapper { // If the dimension has not yet been set because we do not have access to model metadata, it will be -1 public static final int UNSET_MODEL_DIMENSION_IDENTIFIER = -1; - private PerDimensionProcessor perDimensionProcessor; private PerDimensionValidator perDimensionValidator; private VectorValidator vectorValidator; @@ -98,7 +94,6 @@ private ModelFieldMapper( // For the model field mapper, we cannot validate the model during index creation due to // an issue with reading cluster state during mapper creation. So, we need to validate the // model when ingestion starts. We do this as lazily as we can - this.perDimensionProcessor = null; this.perDimensionValidator = null; this.vectorValidator = null; @@ -121,8 +116,7 @@ protected PerDimensionValidator getPerDimensionValidator() { @Override protected PerDimensionProcessor getPerDimensionProcessor() { - initPerDimensionProcessor(); - return perDimensionProcessor; + return PerDimensionProcessor.NOOP_PROCESSOR; } private void initVectorValidator() { @@ -130,11 +124,7 @@ private void initVectorValidator() { return; } ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); - KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); - KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); - KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() - .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); - vectorValidator = knnLibraryIndexingContext.getVectorValidator(); + vectorValidator = new SpaceVectorValidator(modelMetadata.getSpaceType()); } private void initPerDimensionValidator() { @@ -142,23 +132,13 @@ private void initPerDimensionValidator() { return; } ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); - KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); - KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); - KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() - .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); - perDimensionValidator = knnLibraryIndexingContext.getPerDimensionValidator(); - } - - private void initPerDimensionProcessor() { - if (perDimensionProcessor != null) { - return; + if (modelMetadata.getVectorDataType() == VectorDataType.BINARY) { + perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + } else if (modelMetadata.getVectorDataType() == VectorDataType.BYTE) { + perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + } else { + perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; } - ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); - KNNMethodContext knnMethodContext = getKNNMethodContextFromModelMetadata(modelMetadata); - KNNMethodConfigContext knnMethodConfigContext = getKNNMethodConfigContextFromModelMetadata(modelMetadata); - KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine() - .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext); - perDimensionProcessor = knnLibraryIndexingContext.getPerDimensionProcessor(); } @Override @@ -183,17 +163,6 @@ protected void parseCreateField(ParseContext context) throws IOException { parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType()); } - private static KNNMethodContext getKNNMethodContextFromModelMetadata(ModelMetadata modelMetadata) { - return new KNNMethodContext(modelMetadata.getKnnEngine(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext()); - } - - private static KNNMethodConfigContext getKNNMethodConfigContextFromModelMetadata(ModelMetadata modelMetadata) { - return KNNMethodConfigContext.builder() - .vectorDataType(modelMetadata.getVectorDataType()) - .dimension(modelMetadata.getDimension()) - .build(); - } - private static ModelMetadata getModelMetadata(ModelDao modelDao, String modelId) { ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) {