Skip to content

Commit

Permalink
Add delete model API (#118)
Browse files Browse the repository at this point in the history
Added delete model  API to delete previously added model for given ModelID.
This requires Handler to register route and prepare request,
DeleteModelRequest, DeleteModelResponse to represent request and response,
DeleteModelTransportAction to sends request across nodes.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB authored and jmazanec15 committed Oct 22, 2021
1 parent f31e260 commit ae50235
Show file tree
Hide file tree
Showing 8 changed files with 311 additions and 2 deletions.
6 changes: 6 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.delete.DeleteAction;
Expand Down Expand Up @@ -430,6 +431,11 @@ public void delete(String modelId, ActionListener<DeleteResponse> listener) {
//TODO: Bug. Model needs to be removed from all nodes caches, not just local.
// https://github.com/opensearch-project/k-NN/issues/93
ActionListener<DeleteResponse> onModelDeleteListener = ActionListener.wrap(deleteResponse -> {
if(deleteResponse.getResult() != DocWriteResponse.Result.DELETED){
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
listener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
ModelCache.getInstance().remove(modelId);
listener.onResponse(deleteResponse);
}, listener::onFailure);
Expand Down
11 changes: 9 additions & 2 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.indices.ModelCache;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.plugin.rest.RestDeleteModelHandler;
import org.opensearch.knn.plugin.rest.RestGetModelHandler;
import org.opensearch.knn.plugin.rest.RestKNNStatsHandler;
import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler;
import org.opensearch.knn.plugin.script.KNNScoringScriptEngine;
import org.opensearch.knn.plugin.stats.KNNStats;
import org.opensearch.knn.plugin.transport.DeleteModelAction;
import org.opensearch.knn.plugin.transport.DeleteModelTransportAction;
import org.opensearch.knn.plugin.transport.GetModelAction;
import org.opensearch.knn.plugin.transport.GetModelTransportAction;
import org.opensearch.knn.plugin.transport.KNNStatsAction;
Expand Down Expand Up @@ -191,8 +194,11 @@ public List<RestHandler> getRestHandlers(Settings settings,
RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler(settings, restController, clusterService,
indexNameExpressionResolver);
RestGetModelHandler restGetModelHandler = new RestGetModelHandler();
RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler();

return Arrays.asList(restKNNStatsHandler, restKNNWarmupHandler, restGetModelHandler);
return ImmutableList.of(
restKNNStatsHandler, restKNNWarmupHandler, restGetModelHandler, restDeleteModelHandler
);
}

/**
Expand All @@ -206,7 +212,8 @@ public List<RestHandler> getRestHandlers(Settings settings,
new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class),
new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE,
TrainingJobRouteDecisionInfoTransportAction.class),
new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class)
new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class),
new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package org.opensearch.knn.plugin.rest;

import com.google.common.collect.ImmutableList;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.Strings;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.transport.DeleteModelAction;
import org.opensearch.knn.plugin.transport.DeleteModelRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
import org.opensearch.rest.action.admin.cluster.RestNodesUsageAction;

import java.util.List;
import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;

public class RestDeleteModelHandler extends BaseRestHandler {

public static final String NAME = "knn_delete_model_action";

/**
* @return the name of RestDeleteModelHandler.This is used in the response to the
* {@link RestNodesUsageAction}.
*/
@Override
public String getName() {
return NAME;
}

@Override
public List<Route> routes() {
return ImmutableList
.of(
new Route(
RestRequest.Method.DELETE,
String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, MODELS, MODEL_ID)
)
);
}

/**
* Prepare the request for deleting model.
*
* @param request the request to execute
* @param client client for executing actions on the local node
* @return the action to execute
*/
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
String modelID = request.param(MODEL_ID);
if (!Strings.hasText(modelID)) {
throw new IllegalArgumentException("model ID cannot be empty");
}
DeleteModelRequest deleteModelRequest = new DeleteModelRequest(modelID);
return channel -> client.execute(DeleteModelAction.INSTANCE, deleteModelRequest, new RestToXContentListener<>(channel));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.plugin.transport;

import org.opensearch.action.ActionType;
import org.opensearch.action.delete.DeleteResponse;

public class DeleteModelAction extends ActionType<DeleteResponse> {


public static final DeleteModelAction INSTANCE = new DeleteModelAction();
public static final String NAME = "cluster:admin/knn_delete_model_action";

private DeleteModelAction() {
super(NAME, DeleteResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.plugin.transport;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.Strings;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;

import java.io.IOException;

import static org.opensearch.action.ValidateActions.addValidationError;

public class DeleteModelRequest extends ActionRequest {

private String modelID;

public DeleteModelRequest(StreamInput in) throws IOException {
super(in);
this.modelID = in.readString();
}

public DeleteModelRequest(String modelID) {
super();
this.modelID = modelID;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelID);
}

@Override
public ActionRequestValidationException validate() {
if(Strings.hasText(modelID)) {
return null;
}
return addValidationError("Model id cannot be empty ", null);
}

public String getModelID() {
return modelID;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.plugin.transport;

import org.opensearch.action.ActionListener;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

public class DeleteModelTransportAction extends HandledTransportAction<DeleteModelRequest, DeleteResponse> {


private final ModelDao modelDao;

@Inject
public DeleteModelTransportAction(TransportService transportService, ActionFilters filters) {
super(DeleteModelAction.NAME, transportService, filters, DeleteModelRequest::new);
this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
}

@Override
protected void doExecute(Task task, DeleteModelRequest request, ActionListener<DeleteResponse> listener) {
String modelID = request.getModelID();
modelDao.delete(modelID, listener);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.plugin.action;

import org.apache.http.util.EntityUtils;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.rest.RestStatus;

import java.io.IOException;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP;

/**
* Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestDeleteModelHandler}
*/

public class RestDeleteModelHandlerIT extends KNNRestTestCase {

private ModelMetadata getModelMetadata() {
return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED,
"2021-03-27", "test model", "");
}

public void testDeleteModelExists() throws IOException {
createModelSystemIndex();
String testModelID = "test-model-id";
byte[] testModelBlob = "hello".getBytes();
ModelMetadata testModelMetadata = getModelMetadata();

addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob);
assertEquals(getDocCount(MODEL_INDEX_NAME),1);

String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID);
Request request = new Request("DELETE", restURI);

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK,
RestStatus.fromCode(response.getStatusLine().getStatusCode()));

assertEquals(getDocCount(MODEL_INDEX_NAME),0);
}

public void testDeleteModelFailsInvalid() throws IOException {
createModelSystemIndex();
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "invalid-model-id");
Request request = new Request("DELETE", restURI);

ResponseException ex = expectThrows(ResponseException.class, () ->
client().performRequest(request));
assertTrue(ex.getMessage().contains("\"invalid-model-id\""));
}

public void testDeleteModelFailsBlank() throws IOException {
createModelSystemIndex();
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, " ");
Request request = new Request("DELETE", restURI);

expectThrows(IllegalArgumentException.class, () -> client().performRequest(request));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.plugin.transport;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.knn.KNNTestCase;

import java.io.IOException;

public class DeleteModelRequestTests extends KNNTestCase {
public void testStreams() throws IOException {
String modelId = "test-model";
DeleteModelRequest deleteModelRequest = new DeleteModelRequest(modelId);
BytesStreamOutput streamOutput = new BytesStreamOutput();
deleteModelRequest.writeTo(streamOutput);
DeleteModelRequest deleteModelRequestCopy = new DeleteModelRequest(streamOutput.bytes().streamInput());
assertEquals(deleteModelRequest.getModelID(), deleteModelRequestCopy.getModelID());
}
}

0 comments on commit ae50235

Please sign in to comment.