diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 69cdbfcd72..b489f89862 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -562,7 +562,8 @@ protected void parseCreateField(ParseContext context) throws IOException { context, fieldType().getDimension(), fieldType().getSpaceType(), - getMethodComponentContext(fieldType().getKnnMethodContext()) + getMethodComponentContext(fieldType().getKnnMethodContext()), + fieldType().getVectorDataType() ); } @@ -605,7 +606,7 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fields; } - protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext) + protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext, VectorDataType vectorDataType) throws IOException { validateIfKNNPluginEnabled(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 5548712790..2327a8c3b7 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -62,6 +62,6 @@ protected void parseCreateField(ParseContext context) throws IOException { ); } - parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext()); + parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext(), modelMetadata.getVectorDataType()); } } diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 0bc6c5edbb..37edcd3aef 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -292,6 +292,7 @@ private void putInternal(Model model, ActionListener listener, Do put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription()); put(KNNConstants.MODEL_ERROR, modelMetadata.getError()); put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment()); + put(KNNConstants.VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType()); MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); if (!methodComponentContext.getName().isEmpty()) { diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index f3a5506cdb..1498987fd6 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -26,6 +26,7 @@ import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; @@ -48,6 +49,7 @@ public class ModelMetadata implements Writeable, ToXContentObject { final private String timestamp; final private String description; final private String trainingNodeAssignment; + final private VectorDataType vectorDataType; private MethodComponentContext methodComponentContext; private String error; @@ -81,6 +83,12 @@ public ModelMetadata(StreamInput in) throws IOException { } else { this.methodComponentContext = MethodComponentContext.EMPTY; } + + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_DATA_TYPE)) { + this.vectorDataType = VectorDataType.get(in.readOptionalString()); + } else { + this.vectorDataType = VectorDataType.FLOAT; + } } /** @@ -105,7 +113,8 @@ public ModelMetadata( String description, String error, String trainingNodeAssignment, - MethodComponentContext methodComponentContext + MethodComponentContext methodComponentContext, + VectorDataType vectorDataType ) { this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null"); this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null"); @@ -128,6 +137,7 @@ public ModelMetadata( this.error = Objects.requireNonNull(error, "error must not be null"); this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null"); this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null"); + this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null"); } /** @@ -211,6 +221,10 @@ public MethodComponentContext getMethodComponentContext() { return methodComponentContext; } + public VectorDataType getVectorDataType() { + return vectorDataType; + } + /** * setter for model's state * @@ -241,7 +255,8 @@ public String toString() { description, error, trainingNodeAssignment, - methodComponentContext.toClusterStateString() + methodComponentContext.toClusterStateString(), + vectorDataType.getValue() ); } @@ -259,6 +274,7 @@ public boolean equals(Object obj) { equalsBuilder.append(getTimestamp(), other.getTimestamp()); equalsBuilder.append(getDescription(), other.getDescription()); equalsBuilder.append(getError(), other.getError()); + equalsBuilder.append(getVectorDataType(), other.getVectorDataType()); return equalsBuilder.isEquals(); } @@ -273,6 +289,7 @@ public int hashCode() { .append(getDescription()) .append(getError()) .append(getMethodComponentContext()) + .append(getVectorDataType()) .toHashCode(); } @@ -288,7 +305,7 @@ public static ModelMetadata fromString(String modelMetadataString) { // Training node assignment was added as a field in Version 2.12.0 // Because models can be created on older versions and the cluster can be upgraded after, // we need to accept model metadata arrays both with and without the training node assignment. - if (modelMetadataArray.length == 7) { + if (modelMetadataArray.length == 8) { log.debug( "Model metadata array does not contain training node assignment or method component context. Assuming empty string node assignment and empty method component context." ); @@ -299,6 +316,7 @@ public static ModelMetadata fromString(String modelMetadataString) { String timestamp = modelMetadataArray[4]; String description = modelMetadataArray[5]; String error = modelMetadataArray[6]; + VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[9]); return new ModelMetadata( knnEngine, spaceType, @@ -308,9 +326,10 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + vectorDataType ); - } else if (modelMetadataArray.length == 8) { + } else if (modelMetadataArray.length == 9) { log.debug("Model metadata contains training node assignment. Assuming empty method component context."); KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); @@ -320,6 +339,7 @@ public static ModelMetadata fromString(String modelMetadataString) { String description = modelMetadataArray[5]; String error = modelMetadataArray[6]; String trainingNodeAssignment = modelMetadataArray[7]; + VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[10]); return new ModelMetadata( knnEngine, spaceType, @@ -329,9 +349,10 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, trainingNodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + vectorDataType ); - } else if (modelMetadataArray.length == 9) { + } else if (modelMetadataArray.length == 10) { log.debug("Model metadata contains training node assignment and method context"); KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); @@ -342,6 +363,7 @@ public static ModelMetadata fromString(String modelMetadataString) { String error = modelMetadataArray[6]; String trainingNodeAssignment = modelMetadataArray[7]; MethodComponentContext methodComponentContext = MethodComponentContext.fromClusterStateString(modelMetadataArray[8]); + VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[9]); return new ModelMetadata( knnEngine, spaceType, @@ -351,7 +373,8 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, trainingNodeAssignment, - methodComponentContext + methodComponentContext, + vectorDataType ); } else { throw new IllegalArgumentException( @@ -387,6 +410,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR); Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT); Object methodComponentContext = modelSourceMap.get(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT); + Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD); if (trainingNodeAssignment == null) { trainingNodeAssignment = ""; @@ -416,7 +440,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m objectToString(description), objectToString(error), objectToString(trainingNodeAssignment), - (MethodComponentContext) methodComponentContext + (MethodComponentContext) methodComponentContext, + VectorDataType.get(objectToString(vectorDataType)) ); return modelMetadata; } @@ -436,6 +461,9 @@ public void writeTo(StreamOutput out) throws IOException { if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) { getMethodComponentContext().writeTo(out); } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_DATA_TYPE)) { + out.writeOptionalString(vectorDataType.getValue()); + } } @Override @@ -456,6 +484,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws getMethodComponentContext().toXContent(builder, params); builder.endObject(); } + if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_DATA_TYPE)) { + builder.field(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + } return builder; } } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index fb8ccc4cec..eec2540af9 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -17,6 +17,7 @@ import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.TrainingJobRouterAction; @@ -30,16 +31,7 @@ import java.util.Locale; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; -import static org.opensearch.knn.common.KNNConstants.MAX_VECTOR_COUNT_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.MODELS; -import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; -import static org.opensearch.knn.common.KNNConstants.MODEL_ID; -import static org.opensearch.knn.common.KNNConstants.PREFERENCE_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.SEARCH_SIZE_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; -import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.*; /** * Rest Handler for model training api endpoint. @@ -83,6 +75,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr String trainingIndex = (String) DEFAULT_NOT_SET_OBJECT_VALUE; String trainingField = (String) DEFAULT_NOT_SET_OBJECT_VALUE; String description = (String) DEFAULT_NOT_SET_OBJECT_VALUE; + VectorDataType vectorDataType = (VectorDataType) DEFAULT_NOT_SET_OBJECT_VALUE; int dimension = DEFAULT_NOT_SET_INT_VALUE; int maximumVectorCount = DEFAULT_NOT_SET_INT_VALUE; @@ -110,6 +103,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr } else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) { description = parser.textOrNull(); ModelUtil.blockCommasInModelDescription(description); + } else if (VECTOR_DATA_TYPE_FIELD.equals(fieldName) && ensureNotSet(fieldName, vectorDataType)) { + vectorDataType = VectorDataType.get(parser.text()); } else { throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter."); } @@ -126,6 +121,10 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr description = ""; } + if (vectorDataType == DEFAULT_NOT_SET_OBJECT_VALUE) { + vectorDataType = VectorDataType.FLOAT; + } + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( modelId, knnMethodContext, @@ -133,7 +132,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr trainingIndex, trainingField, preferredNodeId, - description + description, + vectorDataType ); if (maximumVectorCount != DEFAULT_NOT_SET_INT_VALUE) { diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 5f3913ac53..ebb286659d 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -21,6 +21,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.training.VectorSpaceInfo; @@ -41,6 +42,7 @@ public class TrainingModelRequest extends ActionRequest { private final String trainingField; private final String preferredNodeId; private final String description; + private final VectorDataType vectorDataType; private int maximumVectorCount; private int searchSize; @@ -65,7 +67,8 @@ public TrainingModelRequest( String trainingIndex, String trainingField, String preferredNodeId, - String description + String description, + VectorDataType vectorDataType ) { super(); this.modelId = modelId; @@ -75,6 +78,7 @@ public TrainingModelRequest( this.trainingField = trainingField; this.preferredNodeId = preferredNodeId; this.description = description; + this.vectorDataType = vectorDataType; // Set these as defaults initially. If call wants to override them, they can use the setters. this.maximumVectorCount = Integer.MAX_VALUE; // By default, get all vectors in the index @@ -103,6 +107,7 @@ public TrainingModelRequest(StreamInput in) throws IOException { this.maximumVectorCount = in.readInt(); this.searchSize = in.readInt(); this.trainingDataSizeInKB = in.readInt(); + this.vectorDataType = VectorDataType.get(in.readOptionalString()); } /** @@ -213,6 +218,10 @@ public int getSearchSize() { return searchSize; } + public VectorDataType getVectorDataType() { + return vectorDataType; + } + /** * Setter for search size. * diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index 33b420e2c0..58ac41b313 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -68,7 +68,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener modelAnonymousEntryContext, request.getDimension(), request.getDescription(), - clusterService.localNode().getEphemeralId() + clusterService.localNode().getEphemeralId(), + request.getVectorDataType() ); KNNCounter.TRAINING_REQUESTS.increment(); diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index 0486bec6fc..7db931b9c4 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -18,6 +18,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -69,7 +70,8 @@ public TrainingJob( NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext, int dimension, String description, - String nodeAssignment + String nodeAssignment, + VectorDataType vectorDataType ) { // Generate random base64 string if one is not provided this.modelId = StringUtils.isNotBlank(modelId) ? modelId : UUIDs.randomBase64UUID(); @@ -87,7 +89,8 @@ public TrainingJob( description, "", nodeAssignment, - knnMethodContext.getMethodComponentContext() + knnMethodContext.getMethodComponentContext(), + vectorDataType ), null, this.modelId