Skip to content

Commit

Permalink
Support data_type field when train model
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jul 9, 2024
1 parent 9361a96 commit 2609a68
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,8 @@ protected void parseCreateField(ParseContext context) throws IOException {
context,
fieldType().getDimension(),
fieldType().getSpaceType(),
getMethodComponentContext(fieldType().getKnnMethodContext())
getMethodComponentContext(fieldType().getKnnMethodContext()),
fieldType().getVectorDataType()
);
}

Expand Down Expand Up @@ -605,7 +606,7 @@ protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType
return fields;
}

protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext)
protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext, VectorDataType vectorDataType)
throws IOException {

validateIfKNNPluginEnabled();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ protected void parseCreateField(ParseContext context) throws IOException {
);
}

parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext());
parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext(), modelMetadata.getVectorDataType());
}
}
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription());
put(KNNConstants.MODEL_ERROR, modelMetadata.getError());
put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment());
put(KNNConstants.VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType());

MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (!methodComponentContext.getName().isEmpty()) {
Expand Down
49 changes: 40 additions & 9 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.MethodComponentContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.io.IOException;
Expand All @@ -48,6 +49,7 @@ public class ModelMetadata implements Writeable, ToXContentObject {
final private String timestamp;
final private String description;
final private String trainingNodeAssignment;
final private VectorDataType vectorDataType;
private MethodComponentContext methodComponentContext;
private String error;

Expand Down Expand Up @@ -81,6 +83,12 @@ public ModelMetadata(StreamInput in) throws IOException {
} else {
this.methodComponentContext = MethodComponentContext.EMPTY;
}

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_DATA_TYPE)) {
this.vectorDataType = VectorDataType.get(in.readOptionalString());
} else {
this.vectorDataType = VectorDataType.FLOAT;
}
}

/**
Expand All @@ -105,7 +113,8 @@ public ModelMetadata(
String description,
String error,
String trainingNodeAssignment,
MethodComponentContext methodComponentContext
MethodComponentContext methodComponentContext,
VectorDataType vectorDataType
) {
this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null");
this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null");
Expand All @@ -128,6 +137,7 @@ public ModelMetadata(
this.error = Objects.requireNonNull(error, "error must not be null");
this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null");
this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null");
this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null");
}

/**
Expand Down Expand Up @@ -211,6 +221,10 @@ public MethodComponentContext getMethodComponentContext() {
return methodComponentContext;
}

/**
 * Getter for the vector data type held by this model metadata.
 *
 * @return the value of the vectorDataType field
 */
public VectorDataType getVectorDataType() {
    return this.vectorDataType;
}

/**
* setter for model's state
*
Expand Down Expand Up @@ -241,7 +255,8 @@ public String toString() {
description,
error,
trainingNodeAssignment,
methodComponentContext.toClusterStateString()
methodComponentContext.toClusterStateString(),
vectorDataType.getValue()
);
}

Expand All @@ -259,6 +274,7 @@ public boolean equals(Object obj) {
equalsBuilder.append(getTimestamp(), other.getTimestamp());
equalsBuilder.append(getDescription(), other.getDescription());
equalsBuilder.append(getError(), other.getError());
equalsBuilder.append(getVectorDataType(), other.getVectorDataType());

return equalsBuilder.isEquals();
}
Expand All @@ -273,6 +289,7 @@ public int hashCode() {
.append(getDescription())
.append(getError())
.append(getMethodComponentContext())
.append(getVectorDataType())
.toHashCode();
}

Expand All @@ -288,7 +305,7 @@ public static ModelMetadata fromString(String modelMetadataString) {
// Training node assignment was added as a field in Version 2.12.0
// Because models can be created on older versions and the cluster can be upgraded after,
// we need to accept model metadata arrays both with and without the training node assignment.
if (modelMetadataArray.length == 7) {
if (modelMetadataArray.length == 8) {
log.debug(
"Model metadata array does not contain training node assignment or method component context. Assuming empty string node assignment and empty method component context."
);
Expand All @@ -299,6 +316,7 @@ public static ModelMetadata fromString(String modelMetadataString) {
String timestamp = modelMetadataArray[4];
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];
// This branch only runs when the array has exactly 8 elements (valid indices 0-7), so reading
// index 9 always threw ArrayIndexOutOfBoundsException. Indices 0-6 are consumed above; the
// vector data type is serialized last by toString(), i.e. it is the final element at index 7.
VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[7]);
return new ModelMetadata(
knnEngine,
spaceType,
Expand All @@ -308,9 +326,10 @@ public static ModelMetadata fromString(String modelMetadataString) {
description,
error,
"",
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
vectorDataType
);
} else if (modelMetadataArray.length == 8) {
} else if (modelMetadataArray.length == 9) {
log.debug("Model metadata contains training node assignment. Assuming empty method component context.");
KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
Expand All @@ -320,6 +339,7 @@ public static ModelMetadata fromString(String modelMetadataString) {
String description = modelMetadataArray[5];
String error = modelMetadataArray[6];
String trainingNodeAssignment = modelMetadataArray[7];
// This branch only runs when the array has exactly 9 elements (valid indices 0-8), so reading
// index 10 always threw ArrayIndexOutOfBoundsException. Indices 0-7 are consumed above; the
// vector data type is serialized last by toString(), i.e. it is the final element at index 8.
VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[8]);
return new ModelMetadata(
knnEngine,
spaceType,
Expand All @@ -329,9 +349,10 @@ public static ModelMetadata fromString(String modelMetadataString) {
description,
error,
trainingNodeAssignment,
MethodComponentContext.EMPTY
MethodComponentContext.EMPTY,
vectorDataType
);
} else if (modelMetadataArray.length == 9) {
} else if (modelMetadataArray.length == 10) {
log.debug("Model metadata contains training node assignment and method context");
KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]);
SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]);
Expand All @@ -342,6 +363,7 @@ public static ModelMetadata fromString(String modelMetadataString) {
String error = modelMetadataArray[6];
String trainingNodeAssignment = modelMetadataArray[7];
MethodComponentContext methodComponentContext = MethodComponentContext.fromClusterStateString(modelMetadataArray[8]);
VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[9]);
return new ModelMetadata(
knnEngine,
spaceType,
Expand All @@ -351,7 +373,8 @@ public static ModelMetadata fromString(String modelMetadataString) {
description,
error,
trainingNodeAssignment,
methodComponentContext
methodComponentContext,
vectorDataType
);
} else {
throw new IllegalArgumentException(
Expand Down Expand Up @@ -387,6 +410,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR);
Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT);
Object methodComponentContext = modelSourceMap.get(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT);
Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD);

if (trainingNodeAssignment == null) {
trainingNodeAssignment = "";
Expand Down Expand Up @@ -416,7 +440,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
objectToString(description),
objectToString(error),
objectToString(trainingNodeAssignment),
(MethodComponentContext) methodComponentContext
(MethodComponentContext) methodComponentContext,
VectorDataType.get(objectToString(vectorDataType))
);
return modelMetadata;
}
Expand All @@ -436,6 +461,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) {
getMethodComponentContext().writeTo(out);
}
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_DATA_TYPE)) {
out.writeOptionalString(vectorDataType.getValue());
}
}

@Override
Expand All @@ -456,6 +484,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
getMethodComponentContext().toXContent(builder, params);
builder.endObject();
}
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IndexUtil.MODEL_DATA_TYPE)) {
builder.field(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue());
}
return builder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.transport.TrainingJobRouterAction;
Expand All @@ -30,16 +31,7 @@
import java.util.Locale;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.MAX_VECTOR_COUNT_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.PREFERENCE_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.SEARCH_SIZE_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.*;

/**
* Rest Handler for model training api endpoint.
Expand Down Expand Up @@ -83,6 +75,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
String trainingIndex = (String) DEFAULT_NOT_SET_OBJECT_VALUE;
String trainingField = (String) DEFAULT_NOT_SET_OBJECT_VALUE;
String description = (String) DEFAULT_NOT_SET_OBJECT_VALUE;
VectorDataType vectorDataType = (VectorDataType) DEFAULT_NOT_SET_OBJECT_VALUE;

int dimension = DEFAULT_NOT_SET_INT_VALUE;
int maximumVectorCount = DEFAULT_NOT_SET_INT_VALUE;
Expand Down Expand Up @@ -110,6 +103,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
} else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) {
description = parser.textOrNull();
ModelUtil.blockCommasInModelDescription(description);
} else if (VECTOR_DATA_TYPE_FIELD.equals(fieldName) && ensureNotSet(fieldName, vectorDataType)) {
vectorDataType = VectorDataType.get(parser.text());
} else {
throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter.");
}
Expand All @@ -126,14 +121,19 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
description = "";
}

if (vectorDataType == DEFAULT_NOT_SET_OBJECT_VALUE) {
vectorDataType = VectorDataType.FLOAT;
}

TrainingModelRequest trainingModelRequest = new TrainingModelRequest(
modelId,
knnMethodContext,
dimension,
trainingIndex,
trainingField,
preferredNodeId,
description
description,
vectorDataType
);

if (maximumVectorCount != DEFAULT_NOT_SET_INT_VALUE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.training.VectorSpaceInfo;

Expand All @@ -41,6 +42,7 @@ public class TrainingModelRequest extends ActionRequest {
private final String trainingField;
private final String preferredNodeId;
private final String description;
private final VectorDataType vectorDataType;

private int maximumVectorCount;
private int searchSize;
Expand All @@ -65,7 +67,8 @@ public TrainingModelRequest(
String trainingIndex,
String trainingField,
String preferredNodeId,
String description
String description,
VectorDataType vectorDataType
) {
super();
this.modelId = modelId;
Expand All @@ -75,6 +78,7 @@ public TrainingModelRequest(
this.trainingField = trainingField;
this.preferredNodeId = preferredNodeId;
this.description = description;
this.vectorDataType = vectorDataType;

// Set these as defaults initially. If call wants to override them, they can use the setters.
this.maximumVectorCount = Integer.MAX_VALUE; // By default, get all vectors in the index
Expand Down Expand Up @@ -103,6 +107,7 @@ public TrainingModelRequest(StreamInput in) throws IOException {
this.maximumVectorCount = in.readInt();
this.searchSize = in.readInt();
this.trainingDataSizeInKB = in.readInt();
this.vectorDataType = VectorDataType.get(in.readOptionalString());
}

/**
Expand Down Expand Up @@ -213,6 +218,10 @@ public int getSearchSize() {
return searchSize;
}

/**
 * Getter for the vector data type carried by this training request.
 *
 * @return the value of the vectorDataType field
 */
public VectorDataType getVectorDataType() {
    return this.vectorDataType;
}

/**
* Setter for search size.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener
modelAnonymousEntryContext,
request.getDimension(),
request.getDescription(),
clusterService.localNode().getEphemeralId()
clusterService.localNode().getEphemeralId(),
request.getVectorDataType()
);

KNNCounter.TRAINING_REQUESTS.increment();
Expand Down
7 changes: 5 additions & 2 deletions src/main/java/org/opensearch/knn/training/TrainingJob.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
Expand Down Expand Up @@ -69,7 +70,8 @@ public TrainingJob(
NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext,
int dimension,
String description,
String nodeAssignment
String nodeAssignment,
VectorDataType vectorDataType
) {
// Generate random base64 string if one is not provided
this.modelId = StringUtils.isNotBlank(modelId) ? modelId : UUIDs.randomBase64UUID();
Expand All @@ -87,7 +89,8 @@ public TrainingJob(
description,
"",
nodeAssignment,
knnMethodContext.getMethodComponentContext()
knnMethodContext.getMethodComponentContext(),
vectorDataType
),
null,
this.modelId
Expand Down

0 comments on commit 2609a68

Please sign in to comment.