Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FAIS delete model API #118

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.delete.DeleteAction;
Expand Down Expand Up @@ -430,6 +431,11 @@ public void delete(String modelId, ActionListener<DeleteResponse> listener) {
//TODO: Bug. Model needs to be removed from all nodes caches, not just local.
// https://github.com/opensearch-project/k-NN/issues/93
ActionListener<DeleteResponse> onModelDeleteListener = ActionListener.wrap(deleteResponse -> {
if(deleteResponse.getResult() != DocWriteResponse.Result.DELETED){
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
listener.onFailure(new ResourceNotFoundException(modelId, errorMessage));
return;
}
ModelCache.getInstance().remove(modelId);
listener.onResponse(deleteResponse);
}, listener::onFailure);
Expand Down
11 changes: 9 additions & 2 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
import org.opensearch.knn.indices.ModelCache;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.plugin.rest.RestDeleteModelHandler;
import org.opensearch.knn.plugin.rest.RestGetModelHandler;
import org.opensearch.knn.plugin.rest.RestKNNStatsHandler;
import org.opensearch.knn.plugin.rest.RestKNNWarmupHandler;
import org.opensearch.knn.plugin.script.KNNScoringScriptEngine;
import org.opensearch.knn.plugin.stats.KNNStats;
import org.opensearch.knn.plugin.transport.DeleteModelAction;
import org.opensearch.knn.plugin.transport.DeleteModelTransportAction;
import org.opensearch.knn.plugin.transport.GetModelAction;
import org.opensearch.knn.plugin.transport.GetModelTransportAction;
import org.opensearch.knn.plugin.transport.KNNStatsAction;
Expand Down Expand Up @@ -191,8 +194,11 @@ public List<RestHandler> getRestHandlers(Settings settings,
RestKNNWarmupHandler restKNNWarmupHandler = new RestKNNWarmupHandler(settings, restController, clusterService,
indexNameExpressionResolver);
RestGetModelHandler restGetModelHandler = new RestGetModelHandler();
RestDeleteModelHandler restDeleteModelHandler = new RestDeleteModelHandler();

return Arrays.asList(restKNNStatsHandler, restKNNWarmupHandler, restGetModelHandler);
return ImmutableList.of(
restKNNStatsHandler, restKNNWarmupHandler, restGetModelHandler, restDeleteModelHandler
);
}

/**
Expand All @@ -206,7 +212,8 @@ public List<RestHandler> getRestHandlers(Settings settings,
new ActionHandler<>(UpdateModelMetadataAction.INSTANCE, UpdateModelMetadataTransportAction.class),
new ActionHandler<>(TrainingJobRouteDecisionInfoAction.INSTANCE,
TrainingJobRouteDecisionInfoTransportAction.class),
new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class)
new ActionHandler<>(GetModelAction.INSTANCE, GetModelTransportAction.class),
new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package org.opensearch.knn.plugin.rest;

import com.google.common.collect.ImmutableList;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.Strings;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.transport.DeleteModelAction;
import org.opensearch.knn.plugin.transport.DeleteModelRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
import org.opensearch.rest.action.admin.cluster.RestNodesUsageAction;

import java.util.List;
import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;

public class RestDeleteModelHandler extends BaseRestHandler {

public static final String NAME = "knn_delete_model_action";

/**
* @return the name of RestDeleteModelHandler.This is used in the response to the
* {@link RestNodesUsageAction}.
*/
@Override
public String getName() {
return NAME;
}

@Override
public List<Route> routes() {
return ImmutableList
.of(
new Route(
RestRequest.Method.DELETE,
String.format(Locale.ROOT, "%s/%s/{%s}", KNNPlugin.KNN_BASE_URI, MODELS, MODEL_ID)
)
);
}

/**
* Prepare the request for deleting model.
*
* @param request the request to execute
* @param client client for executing actions on the local node
* @return the action to execute
*/
@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
String modelID = request.param(MODEL_ID);
if (!Strings.hasText(modelID)) {
throw new IllegalArgumentException("model ID cannot be empty");
}
DeleteModelRequest deleteModelRequest = new DeleteModelRequest(modelID);
return channel -> client.execute(DeleteModelAction.INSTANCE, deleteModelRequest, new RestToXContentListener<>(channel));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.plugin.transport;

import org.opensearch.action.ActionType;
import org.opensearch.action.delete.DeleteResponse;

public class DeleteModelAction extends ActionType<DeleteResponse> {


public static final DeleteModelAction INSTANCE = new DeleteModelAction();
public static final String NAME = "cluster:admin/knn_delete_model_action";

private DeleteModelAction() {
super(NAME, DeleteResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.plugin.transport;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.Strings;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;

import java.io.IOException;

import static org.opensearch.action.ValidateActions.addValidationError;

public class DeleteModelRequest extends ActionRequest {

private String modelID;

public DeleteModelRequest(StreamInput in) throws IOException {
super(in);
this.modelID = in.readString();
}

public DeleteModelRequest(String modelID) {
super();
this.modelID = modelID;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelID);
}

@Override
public ActionRequestValidationException validate() {
if(Strings.hasText(modelID)) {
return null;
}
return addValidationError("Model id cannot be empty ", null);
}

public String getModelID() {
return modelID;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.plugin.transport;

import org.opensearch.action.ActionListener;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

public class DeleteModelTransportAction extends HandledTransportAction<DeleteModelRequest, DeleteResponse> {


private final ModelDao modelDao;

@Inject
public DeleteModelTransportAction(TransportService transportService, ActionFilters filters) {
super(DeleteModelAction.NAME, transportService, filters, DeleteModelRequest::new);
this.modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
}

@Override
protected void doExecute(Task task, DeleteModelRequest request, ActionListener<DeleteResponse> listener) {
String modelID = request.getModelID();
modelDao.delete(modelID, listener);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.plugin.action;

import org.apache.http.util.EntityUtils;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.rest.RestStatus;

import java.io.IOException;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.MODELS;
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR;
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.MODEL_STATE;
import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP;

/**
* Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestDeleteModelHandler}
*/

public class RestDeleteModelHandlerIT extends KNNRestTestCase {

private ModelMetadata getModelMetadata() {
return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED,
"2021-03-27", "test model", "");
}

public void testDeleteModelExists() throws IOException {
createModelSystemIndex();
String testModelID = "test-model-id";
byte[] testModelBlob = "hello".getBytes();
ModelMetadata testModelMetadata = getModelMetadata();

addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob);
assertEquals(getDocCount(MODEL_INDEX_NAME),1);

String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID);
Request request = new Request("DELETE", restURI);

Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK,
RestStatus.fromCode(response.getStatusLine().getStatusCode()));

assertEquals(getDocCount(MODEL_INDEX_NAME),0);
}

public void testDeleteModelFailsInvalid() throws IOException {
createModelSystemIndex();
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "invalid-model-id");
Request request = new Request("DELETE", restURI);

ResponseException ex = expectThrows(ResponseException.class, () ->
client().performRequest(request));
assertTrue(ex.getMessage().contains("\"invalid-model-id\""));
}

public void testDeleteModelFailsBlank() throws IOException {
createModelSystemIndex();
String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, " ");
Request request = new Request("DELETE", restURI);

expectThrows(IllegalArgumentException.class, () -> client().performRequest(request));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.plugin.transport;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.knn.KNNTestCase;

import java.io.IOException;

public class DeleteModelRequestTests extends KNNTestCase {
public void testStreams() throws IOException {
String modelId = "test-model";
DeleteModelRequest deleteModelRequest = new DeleteModelRequest(modelId);
BytesStreamOutput streamOutput = new BytesStreamOutput();
deleteModelRequest.writeTo(streamOutput);
DeleteModelRequest deleteModelRequestCopy = new DeleteModelRequest(streamOutput.bytes().streamInput());
assertEquals(deleteModelRequest.getModelID(), deleteModelRequestCopy.getModelID());
}
}