diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 3fb54d1b3..5a42f983a 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -18,6 +18,7 @@ import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionListener; import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.DocWriteResponse; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.delete.DeleteAction; @@ -430,6 +431,11 @@ public void delete(String modelId, ActionListener listener) { //TODO: Bug. Model needs to be removed from all nodes caches, not just local. // https://github.com/opensearch-project/k-NN/issues/93 ActionListener onModelDeleteListener = ActionListener.wrap(deleteResponse -> { + if(deleteResponse.getResult() != DocWriteResponse.Result.DELETED){ + String errorMessage = String.format("Model \" %s \" does not exist", modelId); + listener.onFailure(new ResourceNotFoundException(modelId, errorMessage)); + return; + } ModelCache.getInstance().remove(modelId); listener.onResponse(deleteResponse); }, listener::onFailure); diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 6e8ac85a8..c6bd08ec8 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -33,11 +33,14 @@ import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; +import org.opensearch.knn.plugin.rest.RestDeleteModelHandler; import org.opensearch.knn.plugin.rest.RestGetModelHandler; import org.opensearch.knn.plugin.rest.RestKNNStatsHandler; import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler; import org.opensearch.knn.plugin.script.KNNScoringScriptEngine; import org.opensearch.knn.plugin.stats.KNNStats; +import org.opensearch.knn.plugin.transport.DeleteModelAction; +import org.opensearch.knn.plugin.transport.DeleteModelTransportAction; import org.opensearch.knn.plugin.transport.GetModelAction; import org.opensearch.knn.plugin.transport.GetModelTransportAction; import org.opensearch.knn.plugin.transport.KNNStatsAction; @@ -191,8 +194,11 @@ public List getRestHandlers(Settings settings, RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler(settings, restController, clusterService, indexNameExpressionResolver); RestGetModelHandler restGetModelHandler = new RestGetModelHandler(); + RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler(); - return Arrays.asList(restKNNStatsHandler, restKNNWarmupHandler, restGetModelHandler); + return ImmutableList.of( + restKNNStatsHandler, restKNNWarmupHandler, restGetModelHandler, restDeleteModelHandler + ); } /** @@ -206,7 +212,8 @@ public List getRestHandlers(Settings settings, new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class), new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE, TrainingJobRouteDecisionInfoTransportAction.class), - new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class) + new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class), + new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class) ); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestDeleteModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestDeleteModelHandler.java new file mode 100644 index 000000000..017c0f87d --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestDeleteModelHandler.java @@ -0,0 +1,60 @@ +package org.opensearch.knn.plugin.rest; + +import com.google.common.collect.ImmutableList; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.Strings; +import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.knn.plugin.transport.DeleteModelAction; +import org.opensearch.knn.plugin.transport.DeleteModelRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.rest.action.admin.cluster.RestNodesUsageAction; + +import java.util.List; +import java.util.Locale; + +import static org.opensearch.knn.common.KNNConstants.MODELS; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; + +public class RestDeleteModelHandler extends BaseRestHandler { + + public static final String NAME = "knn_delete_model_action"; + + /** + * @return the name of RestDeleteModelHandler.This is used in the response to the + * {@link RestNodesUsageAction}. + */ + @Override + public String getName() { + return NAME; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.DELETE, + String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, MODELS, MODEL_ID) + ) + ); + } + + /** + * Prepare the request for deleting model. + * + * @param request the request to execute + * @param client client for executing actions on the local node + * @return the action to execute + */ + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + String modelID = request.param(MODEL_ID); + if (!Strings.hasText(modelID)) { + throw new IllegalArgumentException("model ID cannot be empty"); + } + DeleteModelRequest deleteModelRequest = new DeleteModelRequest(modelID); + return channel -> client.execute(DeleteModelAction.INSTANCE, deleteModelRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelAction.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelAction.java new file mode 100644 index 000000000..0c12aea62 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.delete.DeleteResponse; + +public class DeleteModelAction extends ActionType { + + + public static final DeleteModelAction INSTANCE = new DeleteModelAction(); + public static final String NAME = "cluster:admin/knn_delete_model_action"; + + private DeleteModelAction() { + super(NAME, DeleteResponse::new); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java new file mode 100644 index 000000000..1f1c5315c --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelRequest.java @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.Strings; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; + +import java.io.IOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +public class DeleteModelRequest extends ActionRequest { + + private String modelID; + + public DeleteModelRequest(StreamInput in) throws IOException { + super(in); + this.modelID = in.readString(); + } + + public DeleteModelRequest(String modelID) { + super(); + this.modelID = modelID; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelID); + } + + @Override + public ActionRequestValidationException validate() { + if(Strings.hasText(modelID)) { + return null; + } + return addValidationError("Model id cannot be empty ", null); + } + + public String getModelID() { + return modelID; + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java new file mode 100644 index 000000000..a2b4170a7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/DeleteModelTransportAction.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.inject.Inject; +import org.opensearch.knn.indices.ModelDao; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class DeleteModelTransportAction extends HandledTransportAction { + + + private final ModelDao modelDao; + + @Inject + public DeleteModelTransportAction(TransportService transportService, ActionFilters filters) { + super(DeleteModelAction.NAME, transportService, filters, DeleteModelRequest::new); + this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + } + + @Override + protected void doExecute(Task task, DeleteModelRequest request, ActionListener listener) { + String modelID = request.getModelID(); + modelDao.delete(modelID, listener); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java new file mode 100644 index 000000000..795d5fe02 --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.action; + +import org.apache.http.util.EntityUtils; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; +import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.rest.RestStatus; + +import java.io.IOException; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MODELS; +import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; +import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; +import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; + +/** + * Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestDeleteModelHandler} + */ + +public class RestDeleteModelHandlerIT extends KNNRestTestCase { + + private ModelMetadata getModelMetadata() { + return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, + "2021-03-27", "test model", ""); + } + + public void testDeleteModelExists() throws IOException { + createModelSystemIndex(); + String testModelID = "test-model-id"; + byte[] testModelBlob = "hello".getBytes(); + ModelMetadata testModelMetadata = getModelMetadata(); + + addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob); + assertEquals(getDocCount(MODEL_INDEX_NAME),1); + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); + Request request = new Request("DELETE", restURI); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + assertEquals(getDocCount(MODEL_INDEX_NAME),0); + } + + public void testDeleteModelFailsInvalid() throws IOException { + createModelSystemIndex(); + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "invalid-model-id"); + Request request = new Request("DELETE", restURI); + + ResponseException ex = expectThrows(ResponseException.class, () -> + client().performRequest(request)); + assertTrue(ex.getMessage().contains("\"invalid-model-id\"")); + } + + public void testDeleteModelFailsBlank() throws IOException { + createModelSystemIndex(); + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, " "); + Request request = new Request("DELETE", restURI); + + expectThrows(IllegalArgumentException.class, () -> client().performRequest(request)); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/transport/DeleteModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/DeleteModelRequestTests.java new file mode 100644 index 000000000..671e56731 --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/transport/DeleteModelRequestTests.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.knn.KNNTestCase; + +import java.io.IOException; + +public class DeleteModelRequestTests extends KNNTestCase { + public void testStreams() throws IOException { + String modelId = "test-model"; + DeleteModelRequest deleteModelRequest = new DeleteModelRequest(modelId); + BytesStreamOutput streamOutput = new BytesStreamOutput(); + deleteModelRequest.writeTo(streamOutput); + DeleteModelRequest deleteModelRequestCopy = new DeleteModelRequest(streamOutput.bytes().streamInput()); + assertEquals(deleteModelRequest.getModelID(), deleteModelRequestCopy.getModelID()); + } +}