Skip to content

Commit

Permalink
Refactor model management to support apis (opensearch-project#95)
Browse files Browse the repository at this point in the history
Signed-off-by: Jack Mazanec <[email protected]>
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
jmazanec15 authored and martin-gaievski committed Mar 7, 2022
1 parent cc5e35f commit 24016bb
Show file tree
Hide file tree
Showing 17 changed files with 837 additions and 158 deletions.
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public class KNNConstants {
public static final String PLUGIN_NAME = "knn";
public static final String MODEL_METADATA_FIELD = "knn-models";

// Field names for the model metadata that accompanies each model. In this commit the ModelDao
// implementation writes them into the index request source and reads them back from the get response map.
// Key for the name of the model's ModelState
public static final String MODEL_STATE = "state";
// Key for the model timestamp; written via toString() and parsed back with TimeValue.parseTimeValue
public static final String MODEL_TIMESTAMP = "timestamp";
// Key for the model description taken from ModelMetadata.getDescription()
public static final String MODEL_DESCRIPTION = "description";
// Key for the error string taken from ModelMetadata.getError()
public static final String MODEL_ERROR = "error";

// nmslib specific constants
public static final String NMSLIB_NAME = "nmslib";
public static final String SPACE_TYPE = "spaceType"; // used as field info key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer)
String modelId = field.attributes().get(MODEL_ID);
Model model = ModelCache.getInstance().get(modelId);

if (model.getModelBlob() == null) {
throw new RuntimeException("There is no model with id \"" + modelId + "\"");
}

if (model.getModelMetadata().getKnnEngine() != knnEngine) {
throw new RuntimeException("Model Engine \"" + model.getModelMetadata().getKnnEngine().getName()
+ "\" cannot be different than index engine \"" + knnEngine.getName() + "\"");
Expand Down
40 changes: 30 additions & 10 deletions src/main/java/org/opensearch/knn/indices/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,31 @@

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.common.Nullable;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

public class Model {

final private ModelMetadata modelMetadata;
final private byte[] modelBlob;
private ModelMetadata modelMetadata;
private AtomicReference<byte[]> modelBlob;

/**
* Constructor
*
* @param modelMetadata metadata about the model
* @param modelBlob binary representation of model template index
* @param modelBlob binary representation of model template index. Can be null if model is not yet in CREATED state.
*/
public Model(ModelMetadata modelMetadata, byte[] modelBlob) {
public Model(ModelMetadata modelMetadata, @Nullable byte[] modelBlob) {
this.modelMetadata = Objects.requireNonNull(modelMetadata, "modelMetadata must not be null");
this.modelBlob = Objects.requireNonNull(modelBlob, "modelBlob must not be null");

if (ModelState.CREATED.equals(this.modelMetadata.getState()) && modelBlob == null) {
throw new IllegalArgumentException("Cannot construct model in state CREATED when model binary is null. " +
"State must be either TRAINING or FAILED");
}

this.modelBlob = new AtomicReference<>(modelBlob);
}

/**
Expand All @@ -47,7 +55,7 @@ public ModelMetadata getModelMetadata() {
* @return modelBlob
*/
public byte[] getModelBlob() {
return modelBlob;
return modelBlob.get();
}

/**
Expand All @@ -56,7 +64,19 @@ public byte[] getModelBlob() {
* @return length of model blob
*/
public int getLength() {
return modelBlob.length;
if (getModelBlob() == null) {
return 0;
}
return getModelBlob().length;
}

/**
 * Sets model blob to new value.
 *
 * The value is published through the existing {@link AtomicReference} rather than by swapping the
 * reference object itself. {@link #getModelBlob()} reads the field without holding this object's
 * monitor, so replacing the (non-volatile) field would give readers no guarantee of ever observing
 * the update; {@code AtomicReference.set} does.
 *
 * @param modelBlob updated model blob; must not be null
 * @throws NullPointerException if modelBlob is null
 */
public synchronized void setModelBlob(byte[] modelBlob) {
    // Validate first so that a rejected update leaves the current blob untouched.
    byte[] updatedBlob = Objects.requireNonNull(modelBlob, "model blob cannot be updated to null");
    this.modelBlob.set(updatedBlob);
}

@Override
Expand All @@ -68,14 +88,14 @@ public boolean equals(Object obj) {
Model other = (Model) obj;

EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(modelMetadata, other.modelMetadata);
equalsBuilder.append(modelBlob, other.modelBlob);
equalsBuilder.append(getModelMetadata(), other.getModelMetadata());
equalsBuilder.append(getModelBlob(), other.getModelBlob());

return equalsBuilder.isEquals();
}

@Override
public int hashCode() {
return new HashCodeBuilder().append(modelMetadata).append(modelBlob).toHashCode();
return new HashCodeBuilder().append(getModelMetadata()).append(getModelBlob()).toHashCode();
}
}
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/indices/ModelCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public Model get(String modelId) {
try {
return cache.get(modelId, () -> modelDao.get(modelId));
} catch (ExecutionException ee) {
throw new IllegalStateException("Unable to retrieve model blob for \"" + modelId + "\": " + ee);
throw new IllegalStateException("Unable to retrieve model binary for \"" + modelId + "\": " + ee);
}
}

Expand Down
161 changes: 106 additions & 55 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
package org.opensearch.knn.indices;

import com.google.common.base.Charsets;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.Resources;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -33,7 +32,9 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Nullable;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
Expand All @@ -44,6 +45,7 @@
import java.io.IOException;
import java.net.URL;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutionException;

Expand Down Expand Up @@ -79,18 +81,27 @@ public interface ModelDao {
*
* @param modelId Id of model to create
* @param model Model to be indexed
* @param listener handles acknowledged response
* @param listener handles index response
*/
void put(String modelId, Model model, ActionListener<AcknowledgedResponse> listener) throws IOException;
void put(String modelId, Model model, ActionListener<IndexResponse> listener) throws IOException;

/**
* Put a model into the system index. Non-blocking. When no id is passed in, OpenSearch will generate the id
* automatically. The id can be retrieved in the IndexResponse.
*
* @param model Model to be indexed
* @param listener handles acknowledged response
* @param listener handles index response
*/
void put(Model model, ActionListener<AcknowledgedResponse> listener) throws IOException;
void put(Model model, ActionListener<IndexResponse> listener) throws IOException;

/**
* Update model of model id with new model. As with put, the outcome is delivered through the listener.
* In the implementation shown in this file, the request is issued with OpType.INDEX instead of the
* OpType.CREATE used by put, and on success the model id is removed from the local ModelCache before the
* listener is notified.
*
* @param modelId model id to update
* @param model new model
* @param listener handles index response
* @throws IOException declared on the signature; NOTE(review): the call that can raise it is not visible here
*/
void update(String modelId, Model model, ActionListener<IndexResponse> listener) throws IOException;

/**
* Get a model from the system index. Call blocks.
Expand Down Expand Up @@ -183,75 +194,105 @@ public boolean isCreated() {
}

@Override
public void put(String modelId, Model model, ActionListener<AcknowledgedResponse> listener) throws IOException {
String base64Model = Base64.getEncoder().encodeToString(model.getModelBlob());

Map<String, Object> parameters = ImmutableMap.of(
KNNConstants.KNN_ENGINE, model.getModelMetadata().getKnnEngine().getName(),
KNNConstants.METHOD_PARAMETER_SPACE_TYPE, model.getModelMetadata().getSpaceType().getValue(),
KNNConstants.DIMENSION, model.getModelMetadata().getDimension(),
KNNConstants.MODEL_BLOB_PARAMETER, base64Model
);
public void put(String modelId, Model model, ActionListener<IndexResponse> listener) throws IOException {
// Delegates to the shared write path with the caller-supplied id. OpType.CREATE is forwarded to
// setOpType; presumably this makes the request fail when the id already exists - confirm against
// the OpenSearch OpType documentation.
putInternal(modelId, model, listener, DocWriteRequest.OpType.CREATE);
}

IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME, "_doc");
indexRequestBuilder.setId(modelId);
indexRequestBuilder.setSource(parameters);
@Override
public void put(Model model, ActionListener<IndexResponse> listener) throws IOException {
// No id is supplied: putInternal only calls setId when modelId is non-null, so the id is left for
// OpenSearch to generate (per the interface javadoc) and no cache entry is evicted afterwards.
putInternal(null, model, listener, DocWriteRequest.OpType.CREATE);
}

// Fail if the id already exists. Models are not updateable
indexRequestBuilder.setOpType(DocWriteRequest.OpType.CREATE);
indexRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
@Override
public void update(String modelId, Model model, ActionListener<IndexResponse> listener)
throws IOException {
// Same write path as put, but with OpType.INDEX instead of CREATE. putInternal forwards the op type
// to the index request, which is what decides whether an existing id may be written again.
putInternal(modelId, model, listener, DocWriteRequest.OpType.INDEX);
}

// After the model is indexed, update metadata
ActionListener<IndexResponse> putMetadataListener = getUpdateModelMetadataListener(model.getModelMetadata(),
listener);
private void putInternal(@Nullable String modelId, Model model, ActionListener<IndexResponse> listener,
DocWriteRequest.OpType requestOpType) throws IOException {

if (!isCreated()) {
create(ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(putMetadataListener),
listener::onFailure));
return;
if (model == null) {
throw new IllegalArgumentException("Model cannot be null");
}

indexRequestBuilder.execute(putMetadataListener);
}
ModelMetadata modelMetadata = model.getModelMetadata();

@Override
public void put(Model model, ActionListener<AcknowledgedResponse> listener) throws IOException {
String base64Model = Base64.getEncoder().encodeToString(model.getModelBlob());

Map<String, Object> parameters = ImmutableMap.of(
KNNConstants.KNN_ENGINE, model.getModelMetadata().getKnnEngine().getName(),
KNNConstants.METHOD_PARAMETER_SPACE_TYPE, model.getModelMetadata().getSpaceType().getValue(),
KNNConstants.DIMENSION, model.getModelMetadata().getDimension(),
KNNConstants.MODEL_BLOB_PARAMETER, base64Model
);
Map<String, Object> parameters = new HashMap<String, Object>() {{
put(KNNConstants.KNN_ENGINE, modelMetadata.getKnnEngine().getName());
put(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, modelMetadata.getSpaceType().getValue());
put(KNNConstants.DIMENSION, modelMetadata.getDimension());
put(KNNConstants.MODEL_STATE, modelMetadata.getState().getName());
put(KNNConstants.MODEL_TIMESTAMP, modelMetadata.getTimestamp().toString());
put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription());
put(KNNConstants.MODEL_ERROR, modelMetadata.getError());
}};

byte[] modelBlob = model.getModelBlob();

if (modelBlob == null && ModelState.CREATED.equals(modelMetadata.getState())) {
throw new IllegalArgumentException("Model binary cannot be null when model state is CREATED");
}

// Only add model if it is not null
if (modelBlob != null) {
String base64Model = Base64.getEncoder().encodeToString(modelBlob);
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, base64Model);
}

IndexRequestBuilder indexRequestBuilder = client.prepareIndex(MODEL_INDEX_NAME, "_doc");

// Set id for request only if modelId is present
if (modelId != null) {
indexRequestBuilder.setId(modelId);
}

indexRequestBuilder.setSource(parameters);

// Fail if the id already exists. Models are not updateable
indexRequestBuilder.setOpType(DocWriteRequest.OpType.CREATE);
indexRequestBuilder.setOpType(requestOpType); // Delegate whether this request can update based on opType
indexRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// After the model is indexed, update metadata
ActionListener<IndexResponse> putMetadataListener = getUpdateModelMetadataListener(model.getModelMetadata(),
listener);
// After metadata update finishes, remove item from cache if necessary. If no model id is
// passed then nothing needs to be removed from the cache
//TODO: Bug. The model needs to be removed from the caches of all nodes, not just the local one.
// https://github.com/opensearch-project/k-NN/issues/93
ActionListener<IndexResponse> onMetaListener;
if (modelId != null) {
onMetaListener = ActionListener.wrap(response -> {
ModelCache.getInstance().remove(modelId);
listener.onResponse(response);
}, listener::onFailure);
} else {
onMetaListener = listener;
}

// If the index has not been created yet, create it and then add the document
// After the model is indexed, update metadata only if the model is in CREATED state
ActionListener<IndexResponse> onIndexListener;
if (ModelState.CREATED.equals(model.getModelMetadata().getState())) {
onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener);
} else {
onIndexListener = onMetaListener;
}

// Create the model index if it does not already exist
if (!isCreated()) {
create(ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(putMetadataListener),
listener::onFailure));
create(ActionListener.wrap(createIndexResponse -> indexRequestBuilder.execute(onIndexListener),
onIndexListener::onFailure));
return;
}

indexRequestBuilder.execute(putMetadataListener);
indexRequestBuilder.execute(onIndexListener);
}

private ActionListener<IndexResponse> getUpdateModelMetadataListener(ModelMetadata modelMetadata,
ActionListener<AcknowledgedResponse> listener) {
ActionListener<IndexResponse> listener) {
return ActionListener.wrap(indexResponse -> client.execute(
UpdateModelMetadataAction.INSTANCE,
new UpdateModelMetadataRequest(indexResponse.getId(), false, modelMetadata),
listener
// Here we wrap the IndexResponse listener around an AcknowledgedListener. This allows us
// to pass the indexResponse back up.
ActionListener.wrap(acknowledgedResponse -> listener.onResponse(indexResponse),
listener::onFailure)
), listener::onFailure);
}

Expand All @@ -269,15 +310,23 @@ public Model get(String modelId) throws ExecutionException, InterruptedException
Object engine = responseMap.get(KNNConstants.KNN_ENGINE);
Object space = responseMap.get(KNNConstants.METHOD_PARAMETER_SPACE_TYPE);
Object dimension = responseMap.get(KNNConstants.DIMENSION);
Object state = responseMap.get(KNNConstants.MODEL_STATE);
Object timestamp = responseMap.get(KNNConstants.MODEL_TIMESTAMP);
Object description = responseMap.get(KNNConstants.MODEL_DESCRIPTION);
Object error = responseMap.get(KNNConstants.MODEL_ERROR);
Object blob = responseMap.get(KNNConstants.MODEL_BLOB_PARAMETER);

if (blob == null) {
throw new IllegalArgumentException("No model available in \"" + MODEL_INDEX_NAME + "\" index with id \""
+ modelId + "\".");
// If byte blob is not there, it means that the state has not yet been updated to CREATED.
byte[] byteBlob = null;
if (blob != null) {
byteBlob = Base64.getDecoder().decode((String) blob);
}

ModelMetadata modelMetadata = new ModelMetadata(KNNEngine.getEngine((String) engine),
SpaceType.getSpace((String) space), (Integer) dimension);
return new Model(modelMetadata, Base64.getDecoder().decode((String) blob));
SpaceType.getSpace((String) space), (Integer) dimension, ModelState.getModelState((String) state),
TimeValue.parseTimeValue((String) timestamp, KNNConstants.MODEL_TIMESTAMP), (String) description,
(String) error);
return new Model(modelMetadata, byteBlob);
}

@Override
Expand Down Expand Up @@ -326,6 +375,8 @@ public void delete(String modelId, ActionListener<DeleteResponse> listener) {
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model deletion from the index, remove the model from the model cache
//TODO: Bug. The model needs to be removed from the caches of all nodes, not just the local one.
// https://github.com/opensearch-project/k-NN/issues/93
ActionListener<DeleteResponse> onModelDeleteListener = ActionListener.wrap(deleteResponse -> {
ModelCache.getInstance().remove(modelId);
listener.onResponse(deleteResponse);
Expand Down
Loading

0 comments on commit 24016bb

Please sign in to comment.