Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model version to model metadata and change model metadata reads to be from cluster metadata #2005

Merged
merged 13 commits into from
Sep 5, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823)
### Enhancements
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
* Add model version to model metadata and change model metadata reads to be from cluster metadata [#2005](https://github.com/opensearch-project/k-NN/pull/2005)
### Bug Fixes
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
* Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public class KNNConstants {
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String MODEL_VERSION = "model_version";
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/index/util/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class IndexUtil {
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS = Version.V_2_16_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE = Version.V_2_16_0;
private static final Version MINIMAL_RESCORE_FEATURE = Version.V_2_17_0;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION = Version.V_2_17_0;
// public so neural search can access it
public static final Map<String, Version> minimalRequiredVersionMap = initializeMinimalRequiredVersionMap();

Expand Down Expand Up @@ -405,6 +406,7 @@ private static Map<String, Version> initializeMinimalRequiredVersionMap() {
put(KNNConstants.METHOD_PARAMETER, MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS);
put(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE);
put(RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE);
put(KNNConstants.MODEL_VERSION, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION);
}
};

Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
put(KNNConstants.MODEL_ERROR, modelMetadata.getError());
put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment());
put(KNNConstants.VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType());
put(KNNConstants.MODEL_VERSION, modelMetadata.getModelVersion());

MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
if (!methodComponentContext.getName().isEmpty()) {
Expand Down
54 changes: 47 additions & 7 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.Version;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -51,6 +52,7 @@ public class ModelMetadata implements Writeable, ToXContentObject {
final private String trainingNodeAssignment;
final private VectorDataType vectorDataType;
private MethodComponentContext methodComponentContext;
private final Version version;
private String error;

/**
Expand All @@ -59,7 +61,6 @@ public class ModelMetadata implements Writeable, ToXContentObject {
* @param in Stream input
*/
public ModelMetadata(StreamInput in) throws IOException {
String tempTrainingNodeAssignment;
this.knnEngine = KNNEngine.getEngine(in.readString());
this.spaceType = SpaceType.getSpace(in.readString());
this.dimension = in.readInt();
Expand Down Expand Up @@ -89,6 +90,12 @@ public ModelMetadata(StreamInput in) throws IOException {
} else {
this.vectorDataType = VectorDataType.DEFAULT;
}

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MODEL_VERSION)) {
this.version = Version.fromString(in.readString());
} else {
this.version = Version.V_EMPTY;
}
}

/**
Expand All @@ -115,7 +122,8 @@ public ModelMetadata(
String error,
String trainingNodeAssignment,
MethodComponentContext methodComponentContext,
VectorDataType vectorDataType
VectorDataType vectorDataType,
Version version
) {
this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null");
this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null");
Expand All @@ -139,6 +147,7 @@ public ModelMetadata(
this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null");
this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null");
this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null");
this.version = Objects.requireNonNull(version, "model version must not be null");
}

/**
Expand Down Expand Up @@ -226,6 +235,14 @@ public VectorDataType getVectorDataType() {
return vectorDataType;
}

/**
* Getter for the model version
* @return version
*/
public Version getModelVersion() {
return version;
}

/**
* setter for model's state
*
Expand Down Expand Up @@ -257,7 +274,8 @@ public String toString() {
error,
trainingNodeAssignment,
methodComponentContext.toClusterStateString(),
vectorDataType.getValue()
vectorDataType.getValue(),
version.toString()
);
}

Expand Down Expand Up @@ -291,6 +309,7 @@ public int hashCode() {
.append(getError())
.append(getMethodComponentContext())
.append(getVectorDataType())
.append(getModelVersion())
.toHashCode();
}

Expand All @@ -304,13 +323,14 @@ public static ModelMetadata fromString(String modelMetadataString) {
String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1);
int length = modelMetadataArray.length;

if (length < 7 || length > 10) {
if (length < 7 || length > 11) {
ryanbogan marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalArgumentException(
"Illegal format for model metadata. Must be of the form "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>\"."
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>\" or "
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>,<Version>\"."
);
}

Expand All @@ -326,6 +346,7 @@ public static ModelMetadata fromString(String modelMetadataString) {
? MethodComponentContext.fromClusterStateString(modelMetadataArray[8])
: MethodComponentContext.EMPTY;
VectorDataType vectorDataType = length > 9 ? VectorDataType.get(modelMetadataArray[9]) : VectorDataType.DEFAULT;
Version version = length > 10 ? Version.fromString(modelMetadataArray[10]) : Version.V_EMPTY;

log.debug(getLogMessage(length));

Expand All @@ -339,7 +360,8 @@ public static ModelMetadata fromString(String modelMetadataString) {
error,
trainingNodeAssignment,
methodComponentContext,
vectorDataType
vectorDataType,
version
);
}

Expand All @@ -353,6 +375,8 @@ private static String getLogMessage(int length) {
return "Model metadata contains training node assignment and method context.";
case 10:
return "Model metadata contains training node assignment, method context and vector data type.";
case 11:
return "Model metadata contains training node assignment, method context, vector data type, and version";
default:
throw new IllegalArgumentException("Unexpected metadata array length: " + length);
}
Expand Down Expand Up @@ -385,6 +409,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT);
Object methodComponentContext = modelSourceMap.get(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT);
Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD);
Object version = modelSourceMap.get(KNNConstants.MODEL_VERSION);

if (trainingNodeAssignment == null) {
trainingNodeAssignment = "";
Expand All @@ -409,6 +434,10 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
vectorDataType = VectorDataType.DEFAULT.getValue();
}

if (version == null) {
version = Version.V_EMPTY;
}

ModelMetadata modelMetadata = new ModelMetadata(
KNNEngine.getEngine(objectToString(engine)),
SpaceType.getSpace(objectToString(space)),
Expand All @@ -419,7 +448,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
objectToString(error),
objectToString(trainingNodeAssignment),
(MethodComponentContext) methodComponentContext,
VectorDataType.get(objectToString(vectorDataType))
VectorDataType.get(objectToString(vectorDataType)),
Version.fromString(version.toString())
);
return modelMetadata;
}
Expand All @@ -442,6 +472,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) {
out.writeString(vectorDataType.getValue());
}
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MODEL_VERSION)) {
out.writeString(version.toString());
}
}

@Override
Expand All @@ -465,6 +498,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) {
builder.field(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue());
}
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.MODEL_VERSION)) {
String versionString = "unknown";
if (version != Version.V_EMPTY) {
versionString = version.toString();
}
builder.field(KNNConstants.MODEL_VERSION, versionString);
}
return builder;
}
}
4 changes: 2 additions & 2 deletions src/main/java/org/opensearch/knn/indices/ModelUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public static ModelMetadata getModelMetadata(final String modelId) {
if (StringUtils.isEmpty(modelId)) {
return null;
}
final Model model = ModelCache.getInstance().get(modelId);
final ModelMetadata modelMetadata = model.getModelMetadata();
ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
final ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (isModelCreated(modelMetadata) == false) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ public TrainingJob(
"",
nodeAssignment,
knnMethodContext.getMethodComponentContext(),
knnMethodConfigContext.getVectorDataType()
knnMethodConfigContext.getVectorDataType(),
knnMethodConfigContext.getVersionCreated()
),
null,
this.modelId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
Expand Down Expand Up @@ -166,11 +165,11 @@ private void train(TrainingJob trainingJob) {
private void serializeModel(TrainingJob trainingJob, ActionListener<IndexResponse> listener, boolean update) throws IOException,
ExecutionException, InterruptedException {
if (update) {
Model model = modelDao.get(trainingJob.getModelId());
if (model.getModelMetadata().getState().equals(ModelState.TRAINING)) {
ModelMetadata modelMetadata = modelDao.getMetadata(trainingJob.getModelId());
if (modelMetadata.getState().equals(ModelState.TRAINING)) {
modelDao.update(trainingJob.getModel(), listener);
} else {
logger.info("Model state is {}. Skipping serialization of trained data", model.getModelMetadata().getState());
logger.info("Model state is {}. Skipping serialization of trained data", modelMetadata.getState());
}
} else {
modelDao.put(trainingJob.getModel(), listener);
Expand Down
3 changes: 3 additions & 0 deletions src/main/resources/mappings/model-index.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
},
"method_component_context": {
"type": "keyword"
},
"model_version": {
"type": "keyword"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.knn.index;

import com.google.common.collect.ImmutableMap;
import org.opensearch.Version;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.admin.indices.create.CreateIndexRequestBuilder;
import org.opensearch.common.settings.Settings;
Expand Down Expand Up @@ -65,7 +66,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException
"",
"test-node",
MethodComponentContext.EMPTY,
VectorDataType.FLOAT
VectorDataType.FLOAT,
Version.V_EMPTY
);

Model model = new Model(modelMetadata, modelBlob, modelId);
Expand Down
Loading
Loading