Refactor model management to support apis #95

Merged
Changes from 4 commits
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
@@ -47,6 +47,11 @@ public class KNNConstants {
public static final String PLUGIN_NAME = "knn";
public static final String MODEL_METADATA_FIELD = "knn-models";

VijayanB marked this conversation as resolved.
public static final String MODEL_STATE = "state";
public static final String MODEL_TIMESTAMP = "timestamp";
public static final String MODEL_DESCRIPTION = "description";
public static final String MODEL_ERROR = "error";

// nmslib specific constants
public static final String NMSLIB_NAME = "nmslib";
public static final String SPACE_TYPE = "spaceType"; // used as field info key
@@ -114,6 +114,10 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);

if (model.getModelBlob() == null) {
throw new RuntimeException("Model blob cannot be null");
Member @vamshin, Sep 22, 2021:

Let's use a user-friendly message, something like "There is no model associated to this id xyz".

Member Author:

Sure, will update.
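
A minimal sketch of the kind of message requested in this thread, assuming the model id is available at the point where the null check happens. The class name `ModelErrorMessages` and both method names are hypothetical and are not part of the plugin; the exception type is also only illustrative.

```java
// Hypothetical helper; none of these names come from the plugin itself.
final class ModelErrorMessages {

    private ModelErrorMessages() {}

    // Builds a message that names the model id instead of the internal term "model blob".
    static String missingModel(String modelId) {
        return "There is no trained model associated with id \"" + modelId + "\"";
    }

    // Returns the binary unchanged, or throws with the user-facing message when it is absent.
    static byte[] requireModelBlob(String modelId, byte[] modelBlob) {
        if (modelBlob == null) {
            throw new IllegalStateException(missingModel(modelId));
        }
        return modelBlob;
    }

    public static void main(String[] args) {
        try {
            requireModelBlob("xyz", null);
        } catch (IllegalStateException e) {
            // The message now tells the user which id had no model.
            System.out.println(e.getMessage());
        }
    }
}
```

Keeping the wording in one place would also make it easier to document what the error means and how to resolve it.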

}

if (model.getModelMetadata().getKnnEngine() != knnEngine) {
throw new RuntimeException("Model Engine \"" + model.getModelMetadata().getKnnEngine().getName()
+ "\" cannot be different than index engine \"" + knnEngine.getName() + "\"");
39 changes: 29 additions & 10 deletions src/main/java/org/opensearch/knn/indices/Model.java
@@ -13,23 +13,30 @@

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.common.Nullable;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

public class Model {

final private ModelMetadata modelMetadata;
final private byte[] modelBlob;
private ModelMetadata modelMetadata;
private AtomicReference<byte[]> modelBlob;

/**
* Constructor
*
* @param modelMetadata metadata about the model
* @param modelBlob binary representation of model template index
* @param modelBlob binary representation of model template index. Can be null if model is not yet in CREATED state.
*/
public Model(ModelMetadata modelMetadata, byte[] modelBlob) {
public Model(ModelMetadata modelMetadata, @Nullable byte[] modelBlob) {
this.modelMetadata = Objects.requireNonNull(modelMetadata, "modelMetadata must not be null");
this.modelBlob = Objects.requireNonNull(modelBlob, "modelBlob must not be null");

if (ModelState.CREATED.equals(this.modelMetadata.getState()) && modelBlob == null) {
Member:

Should this be equalsIgnoreCase instead of equals?

Member Author:

CREATED is an enum not a string.
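
To illustrate the point, here is a small sketch of enum comparison in Java. The nested `State` enum is a hypothetical stand-in for the plugin's `ModelState` (its constant list is an assumption), and `fromName` is an illustrative helper rather than the plugin's own parsing method.

```java
import java.util.Locale;

final class EnumComparisonSketch {

    // Hypothetical stand-in for ModelState; the constants listed here are assumed.
    enum State {
        TRAINING, CREATED, FAILED;

        // Case handling belongs where a string is parsed into the enum,
        // not where two enum values are compared.
        static State fromName(String name) {
            return valueOf(name.toUpperCase(Locale.ROOT));
        }
    }

    private EnumComparisonSketch() {}

    // Constant-first equals is null-safe; for enum constants it is equivalent to ==.
    static boolean isCreated(State state) {
        return State.CREATED.equals(state);
    }

    public static void main(String[] args) {
        System.out.println(isCreated(State.fromName("created")));
        System.out.println(isCreated(State.TRAINING));
        System.out.println(isCreated(null));
    }
}
```

Since enum constants are singletons, there is no casing to ignore at comparison time; `equalsIgnoreCase` only exists on `String`.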

throw new IllegalArgumentException("Model blob cannot be null when model metadata says model is created");
Member:

Will the user understand what a model blob or model metadata is?

Member Author:

Right, we may have had this discussion before. Basically a model is made of a binary model and some metadata about the model. model_blob refers to the binary. I am not sure if there is a better name. Do you have any recs?

Member:

As long as we document what this error message means / how to resolve it, it should be fine.

Member:

Agree with Vijay. Could we use a message like "We could not find model associated to this id xyz"?

Member Author:

Sure, will update error message.

}

this.modelBlob = new AtomicReference<>(modelBlob);
}

/**
@@ -47,7 +54,7 @@ public ModelMetadata getModelMetadata() {
* @return modelBlob
*/
public byte[] getModelBlob() {
return modelBlob;
return modelBlob.get();
}

/**
@@ -56,7 +63,19 @@ public byte[] getModelBlob() {
* @return length of model blob
*/
public int getLength() {
return modelBlob.length;
if (getModelBlob() == null) {
return 0;
Member:

Minor: Should this be -1 to make it clear to the caller that there is no model?

Member Author:

Given that we use this in the cache, I think it should just be 0, indicating the model has a size of 0.
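
A rough sketch of why 0 is the safer value here, assuming the cache tracks its footprint by summing the length of each entry. `WeightedCacheSketch` is a hypothetical, simplified stand-in and does not reflect how `ModelCache` is actually implemented.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical size-accounting cache; not the plugin's ModelCache.
final class WeightedCacheSketch {

    private final Map<String, byte[]> entries = new LinkedHashMap<>();
    private long totalWeight = 0;

    // Mirrors getLength(): a model without a binary occupies no space.
    static int weigh(byte[] blob) {
        return blob == null ? 0 : blob.length;
    }

    void put(String modelId, byte[] blob) {
        // Subtract the weight of any entry being replaced before adding the new one.
        if (entries.containsKey(modelId)) {
            totalWeight -= weigh(entries.get(modelId));
        }
        entries.put(modelId, blob);
        totalWeight += weigh(blob);
    }

    long totalWeight() {
        return totalWeight;
    }

    public static void main(String[] args) {
        WeightedCacheSketch cache = new WeightedCacheSketch();
        cache.put("trained", new byte[10]);
        cache.put("still-training", null);
        System.out.println(cache.totalWeight());
    }
}
```

With a -1 sentinel, every model that is still training would lower the running total by one, so the accounting would drift away from the real memory use.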

}
return getModelBlob().length;
}

/**
* Sets model blob to new value
*
* @param modelBlob updated model blob
*/
public synchronized void setModelBlob(byte[] modelBlob) {
this.modelBlob = new AtomicReference<>(Objects.requireNonNull(modelBlob, "model blob cannot be updated to null"));
}

@Override
@@ -68,14 +87,14 @@ public boolean equals(Object obj) {
Model other = (Model) obj;

EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(modelMetadata, other.modelMetadata);
equalsBuilder.append(modelBlob, other.modelBlob);
equalsBuilder.append(getModelMetadata(), other.getModelMetadata());
equalsBuilder.append(getModelBlob(), other.getModelBlob());

return equalsBuilder.isEquals();
}

@Override
public int hashCode() {
return new HashCodeBuilder().append(modelMetadata).append(modelBlob).toHashCode();
return new HashCodeBuilder().append(getModelMetadata()).append(getModelBlob()).toHashCode();
}
}
159 changes: 104 additions & 55 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
@@ -12,7 +12,6 @@
package org.opensearch.knn.indices;

import com.google.common.base.Charsets;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.Resources;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@@ -33,7 +32,9 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Nullable;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
@@ -44,6 +45,7 @@
import java.io.IOException;
import java.net.URL;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;

@@ -79,18 +81,27 @@ public interface ModelDao {
*
* @param modelId Id of model to create
* @param model Model to be indexed
* @param listener handles acknowledged response
* @param listener handles index response
*/
void put(String modelId, Model model, ActionListener<AcknowledgedResponse> listener) throws IOException;
void put(String modelId, Model model, ActionListener<IndexResponse> listener) throws IOException;

/**
* Put a model into the system index. Non-blocking. When no id is passed in, OpenSearch will generate the id
* automatically. The id can be retrieved in the IndexResponse.
*
* @param model Model to be indexed
* @param listener handles acknowledged response
* @param listener handles index response
*/
void put(Model model, ActionListener<AcknowledgedResponse> listener) throws IOException;
void put(Model model, ActionListener<IndexResponse> listener) throws IOException;

/**
* Update model of model id with new model.
*
* @param modelId model id to update
* @param model new model
* @param listener handles index response
*/
void update(String modelId, Model model, ActionListener<IndexResponse> listener) throws IOException;
Member:

I thought we did not want to update a model once it is created. Can someone do that using this function?

Member Author:

This is required because we initially set the state to training, then we train, and then we update the state to created and add the model. This will not be exposed to the user.
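
The lifecycle described in this reply can be sketched with an in-memory map standing in for the model system index. Everything below is hypothetical (`ModelLifecycleSketch`, `Entry`, the two-value `State` enum); it only mirrors the difference between create-only and overwrite semantics that the diff expresses through `OpType.CREATE` and `OpType.INDEX`.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical in-memory stand-in for the model index; not the plugin's ModelDao.
final class ModelLifecycleSketch {

    enum State { TRAINING, CREATED }

    static final class Entry {
        final State state;
        final byte[] blob; // null until training has finished

        Entry(State state, byte[] blob) {
            // Same invariant as the Model constructor: a created model needs its binary.
            if (state == State.CREATED && blob == null) {
                throw new IllegalArgumentException("A CREATED model must carry its binary");
            }
            this.state = state;
            this.blob = blob;
        }
    }

    private final Map<String, Entry> index = new HashMap<>();

    // Create-only semantics: fails if the id is already present.
    void put(String modelId, Entry entry) {
        if (index.putIfAbsent(modelId, entry) != null) {
            throw new IllegalStateException("Model \"" + modelId + "\" already exists");
        }
    }

    // Overwrite semantics: meant for the internal TRAINING -> CREATED transition only.
    void update(String modelId, Entry entry) {
        index.put(modelId, entry);
    }

    Entry get(String modelId) {
        return index.get(modelId);
    }

    public static void main(String[] args) {
        ModelLifecycleSketch dao = new ModelLifecycleSketch();
        dao.put("m1", new Entry(State.TRAINING, null));                   // registered before training
        dao.update("m1", new Entry(State.CREATED, new byte[] {1, 2, 3})); // binary added afterwards
        System.out.println(dao.get("m1").state);
    }
}
```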


/**
* Get a model from the system index. Call blocks.
@@ -183,75 +194,104 @@ public boolean isCreated() {
}

@Override
public void put(String modelId, Model model, ActionListener<AcknowledgedResponse> listener) throws IOException {
String base64Model = Base64.getEncoder().encodeToString(model.getModelBlob());

Map<String, Object> parameters = ImmutableMap.of(
KNNConstants.KNN_ENGINE, model.getModelMetadata().getKnnEngine().getName(),
KNNConstants.METHOD_PARAMETER_SPACE_TYPE, model.getModelMetadata().getSpaceType().getValue(),
KNNConstants.DIMENSION, model.getModelMetadata().getDimension(),
KNNConstants.MODEL_BLOB_PARAMETER, base64Model
);
public void put(String modelId, Model model, ActionListener<IndexResponse> listener) throws IOException {
putInternal(modelId, model, listener, DocWriteRequest.OpType.CREATE);
}

IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME, "_doc");
indexRequestBuilder.setId(modelId);
indexRequestBuilder.setSource(parameters);
@Override
public void put(Model model, ActionListener<IndexResponse> listener) throws IOException {
putInternal(null, model, listener, DocWriteRequest.OpType.CREATE);
}

// Fail if the id already exists. Models are not updateable
indexRequestBuilder.setOpType(DocWriteRequest.OpType.CREATE);
indexRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
@Override
public void update(String modelId, Model model, ActionListener<IndexResponse> listener)
throws IOException {
putInternal(modelId, model, listener, DocWriteRequest.OpType.INDEX);
}

// After the model is indexed, update metadata
ActionListener<IndexResponse> putMetadataListener = getUpdateModelMetadataListener(model.getModelMetadata(),
listener);
private void putInternal(@Nullable String modelId, Model model, ActionListener<IndexResponse> listener,
DocWriteRequest.OpType requestOpType) throws IOException {

if (!isCreated()) {
create(ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(putMetadataListener),
listener::onFailure));
return;
if (model == null) {
throw new IllegalArgumentException("Model cannot be null");
}

indexRequestBuilder.execute(putMetadataListener);
}
ModelMetadata modelMetadata = model.getModelMetadata();

@Override
public void put(Model model, ActionListener<AcknowledgedResponse> listener) throws IOException {
String base64Model = Base64.getEncoder().encodeToString(model.getModelBlob());

Map<String, Object> parameters = ImmutableMap.of(
KNNConstants.KNN_ENGINE, model.getModelMetadata().getKnnEngine().getName(),
KNNConstants.METHOD_PARAMETER_SPACE_TYPE, model.getModelMetadata().getSpaceType().getValue(),
KNNConstants.DIMENSION, model.getModelMetadata().getDimension(),
KNNConstants.MODEL_BLOB_PARAMETER, base64Model
);
Map<String, Object> parameters = new HashMap<String, Object>() {{
put(KNNConstants.KNN_ENGINE, modelMetadata.getKnnEngine().getName());
put(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue());
put(KNNConstants.DIMENSION, modelMetadata.getDimension());
put(KNNConstants.MODEL_STATE, modelMetadata.getState().getName());
put(KNNConstants.MODEL_TIMESTAMP, modelMetadata.getTimestamp().toString());
put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription());
put(KNNConstants.MODEL_ERROR, modelMetadata.getError());
}};

byte[] modelBlob = model.getModelBlob();

if (modelBlob == null && ModelState.CREATED.equals(modelMetadata.getState())) {
throw new IllegalArgumentException("Model blob cannot be null when model state is CREATED");
}

// Only add model if it is not null
if (modelBlob != null) {
String base64Model = Base64.getEncoder().encodeToString(modelBlob);
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model);
}

IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME, "_doc");

// Set id for request only if modelId is present
if (modelId != null) {
indexRequestBuilder.setId(modelId);
}

indexRequestBuilder.setSource(parameters);

// Fail if the id already exists. Models are not updateable
indexRequestBuilder.setOpType(DocWriteRequest.OpType.CREATE);
indexRequestBuilder.setOpType(requestOpType); // Delegate whether this request can update based on opType
indexRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// After the model is indexed, update metadata
ActionListener<IndexResponse> putMetadataListener = getUpdateModelMetadataListener(model.getModelMetadata(),
listener);
// After metadata update finishes, remove item from cache if necessary. If no model id is
// passed then nothing needs to be removed from the cache
//TODO: Bug. Model needs to be removed from all nodes caches, not just local.
Member:

Are we planning to fix this TODO in the next PR? Making models updatable would result in these scenarios. Ideally, we should not be updating them.

Member Author:

Yes, this is a bug that needs to be fixed. The problem is more relevant in DELETE. #93

Member:

OK, let's also put the issue number in the TODO comment.

Member Author:

Will do.
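
For readers following the listener chaining in this hunk, here is a simplified sketch of the pattern: wrap the caller's listener so a local cache entry is evicted before the response is passed on, and skip the wrapping when there is no id. The `Listener` interface and `wrap` helper are hypothetical stand-ins written for this example; they are not OpenSearch's `ActionListener`, and the sketch only touches a local cache, so it shares the limitation the TODO describes.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

final class ListenerWrapSketch {

    // Hypothetical minimal callback interface, loosely modelled on a response/failure listener.
    interface Listener<T> {
        void onResponse(T response);
        void onFailure(Exception e);
    }

    private ListenerWrapSketch() {}

    // Builds a listener from two callbacks.
    static <T> Listener<T> wrap(Consumer<T> onResponse, Consumer<Exception> onFailure) {
        return new Listener<T>() {
            @Override
            public void onResponse(T response) {
                onResponse.accept(response);
            }

            @Override
            public void onFailure(Exception e) {
                onFailure.accept(e);
            }
        };
    }

    // Evicts the id from a (local) cache, then forwards the response unchanged.
    static <T> Listener<T> evictThenNotify(String modelId, Consumer<String> evict, Listener<T> delegate) {
        if (modelId == null) {
            return delegate; // id is generated later, so nothing can be cached under it yet
        }
        return wrap(response -> {
            evict.accept(modelId);
            delegate.onResponse(response);
        }, delegate::onFailure);
    }

    public static void main(String[] args) {
        List<String> events = new ArrayList<>();
        Listener<String> caller = wrap(r -> events.add("response:" + r), e -> events.add("failure"));
        evictThenNotify("m1", id -> events.add("evict:" + id), caller).onResponse("ok");
        System.out.println(events);
    }
}
```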

ActionListener<IndexResponse> onMetaListener;
if (modelId != null) {
onMetaListener = ActionListener.wrap(response -> {
ModelCache.getInstance().remove(modelId);
listener.onResponse(response);
}, listener::onFailure);
} else {
onMetaListener = listener;
}

// If the index has not been created yet, create it and then add the document
// After the model is indexed, update metadata only if the model is in CREATED state
ActionListener<IndexResponse> onIndexListener;
if (ModelState.CREATED.equals(model.getModelMetadata().getState())) {
onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener);
} else {
onIndexListener = onMetaListener;
}

// Create the model index if it does not already exist
if (!isCreated()) {
create(ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(putMetadataListener),
listener::onFailure));
create(ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(onIndexListener),
onIndexListener::onFailure));
return;
}

indexRequestBuilder.execute(putMetadataListener);
indexRequestBuilder.execute(onIndexListener);
}

private ActionListener<IndexResponse> getUpdateModelMetadataListener(ModelMetadata modelMetadata,
ActionListener<AcknowledgedResponse> listener) {
ActionListener<IndexResponse> listener) {
return ActionListener.wrap(indexResponse -> client.execute(
UpdateModelMetadataAction.INSTANCE,
new UpdateModelMetadataRequest(indexResponse.getId(), false, modelMetadata),
listener
// Here we wrap the IndexResponse listener around an AcknowledgedListener. This allows us
// to pass the indexResponse back up.
ActionListener.wrap(acknowledgedResponse -> listener.onResponse(indexResponse),
listener::onFailure)
), listener::onFailure);
}

@@ -269,15 +309,23 @@ public Model get(String modelId) throws ExecutionException, InterruptedException
Object engine = responseMap.get(KNNConstants.KNN_ENGINE);
Object space = responseMap.get(KNNConstants.METHOD_PARAMETER_SPACE_TYPE);
Object dimension = responseMap.get(KNNConstants.DIMENSION);
Object state = responseMap.get(KNNConstants.MODEL_STATE);
Object timestamp = responseMap.get(KNNConstants.MODEL_TIMESTAMP);
Object description = responseMap.get(KNNConstants.MODEL_DESCRIPTION);
Object error = responseMap.get(KNNConstants.MODEL_ERROR);
Object blob = responseMap.get(KNNConstants.MODEL_BLOB_PARAMETER);

if (blob == null) {
throw new IllegalArgumentException("No model available in \"" + MODEL_INDEX_NAME + "\" index with id \""
+ modelId + "\".");
// If byte blob is not there, it means that the state has not yet been updated to CREATED.
byte[] byteBlob = null;
if (blob != null) {
byteBlob = Base64.getDecoder().decode((String) blob);
}

ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.getEngine((String) engine),
SpaceType.getSpace((String) space), (Integer) dimension);
return new Model(modelMetadata, Base64.getDecoder().decode((String) blob));
SpaceType.getSpace((String) space), (Integer) dimension, ModelState.getModelState((String) state),
TimeValue.parseTimeValue((String) timestamp, KNNConstants.MODEL_TIMESTAMP), (String) description,
(String) error);
return new Model(modelMetadata, byteBlob);
}

@Override
@@ -326,6 +374,7 @@ public void delete(String modelId, ActionListener<DeleteResponse> listener) {
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model deletion from the index, remove the model from the model cache
//TODO: Bug. Model needs to be removed from all nodes caches, not just local.
ActionListener<DeleteResponse> onModelDeleteListener = ActionListener.wrap(deleteResponse -> {
ModelCache.getInstance().remove(modelId);
listener.onResponse(deleteResponse);