Skip to content

Commit

Permalink
[ML] Force stop deployment in use (elastic#80431)
Browse files Browse the repository at this point in the history
Implements a `force` parameter to the stop deployment API.
This allows a user to forcefully stop a deployment. Currently,
this specifically allows stopping a deployment that is in use
by ingest processors.

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
dimitris-athanasiou and elasticmachine committed Nov 8, 2021
1 parent e8bb396 commit f6c40fd
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,18 @@ experimental::[]
(Required, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]

////

[[stop-trained-model-deployment-query-params]]
== {api-query-parms-title}
////

`allow_no_match`::
(Optional, Boolean)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match]


`force`::
(Optional, Boolean) If true, the deployment is stopped even if it is referenced by
ingest pipelines.

////
[role="child_attributes"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@
}
}
]
},
"params":{
"allow_no_match":{
"type":"boolean",
"required":false,
"description":"Whether to ignore if a wildcard expression matches no deployments. (This includes `_all` string or when no deployments have been specified)"
},
"force":{
"type":"boolean",
"required":false,
"description":"True if the deployment should be forcefully stopped"
}
},
"body":{
"description":"The stop deployment parameters"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
import org.elasticsearch.action.support.tasks.BaseTasksResponse;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand All @@ -41,6 +45,26 @@ public static class Request extends BaseTasksRequest<Request> implements ToXCont
private boolean allowNoMatch = true;
private boolean force;

private static final ObjectParser<Request, Void> PARSER = new ObjectParser<>(NAME, Request::new);

static {
PARSER.declareString(Request::setId, TrainedModelConfig.MODEL_ID);
PARSER.declareBoolean(Request::setAllowNoMatch, ALLOW_NO_MATCH);
PARSER.declareBoolean(Request::setForce, FORCE);
}

public static Request parseRequest(String id, XContentParser parser) {
Request request = PARSER.apply(parser, null);
if (request.getId() == null) {
request.setId(id);
} else if (Strings.isNullOrEmpty(id) == false && id.equals(request.getId()) == false) {
throw new IllegalArgumentException(
Messages.getMessage(Messages.INCONSISTENT_ID, TrainedModelConfig.MODEL_ID, request.getId(), id)
);
}
return request;
}

public Request(String id) {
setId(id);
}
Expand All @@ -52,6 +76,8 @@ public Request(StreamInput in) throws IOException {
force = in.readBoolean();
}

private Request() {}

public final void setId(String id) {
this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction.Request;

import java.io.IOException;

public class StopTrainedModelDeploymentRequestTests extends AbstractSerializingTestCase<Request> {

@Override
protected Request doParseInstance(XContentParser parser) throws IOException {
return Request.parseRequest(null, parser);
}

@Override
protected Writeable.Reader<Request> instanceReader() {
return Request::new;
}

@Override
protected Request createTestInstance() {
Request request = new Request(randomAlphaOfLength(10));
if (randomBoolean()) {
request.setAllowNoMatch(randomBoolean());
}
if (randomBoolean()) {
request.setForce(randomBoolean());
}
return request;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,42 @@ public void testInferencePipelineAgainstUnallocatedModel() throws IOException {
);
}

public void testStopUsedDeploymentByIngestProcessor() throws IOException {
String modelId = "test_stop_used_deployment_by_ingest_processor";
createTrainedModel(modelId);
putModelDefinition(modelId);
putVocabulary(List.of("these", "are", "my", "words"), modelId);
startDeployment(modelId);

client().performRequest(
putPipeline(
"my_pipeline",
"{"
+ "\"processors\": [\n"
+ " {\n"
+ " \"inference\": {\n"
+ " \"model_id\": \""
+ modelId
+ "\"\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}"
)
);
ResponseException ex = expectThrows(ResponseException.class, () -> stopDeployment(modelId));
assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(409));
assertThat(
EntityUtils.toString(ex.getResponse().getEntity()),
containsString(
"Cannot stop deployment for model [test_stop_used_deployment_by_ingest_processor] as it is referenced by"
+ " ingest processors; use force to stop the deployment"
)
);

stopDeployment(modelId, true);
}

private int sumInferenceCountOnNodes(List<Map<String, Object>> nodes) {
int inferenceCount = 0;
for (var node : nodes) {
Expand Down Expand Up @@ -554,7 +590,15 @@ private Response startDeployment(String modelId, String waitForState) throws IOE
}

private void stopDeployment(String modelId) throws IOException {
Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_stop");
stopDeployment(modelId, false);
}

private void stopDeployment(String modelId, boolean force) throws IOException {
String endpoint = "/_ml/trained_models/" + modelId + "/deployment/_stop";
if (force) {
endpoint += "?force=true";
}
Request request = new Request("POST", endpoint);
client().performRequest(request);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ protected void doExecute(
return;
}

logger.debug("[{}] Received request to undeploy", request.getId());
logger.debug(
() -> new ParameterizedMessage("[{}] Received request to undeploy{}", request.getId(), request.isForce() ? " (force)" : "")
);

ActionListener<GetTrainedModelsAction.Response> getModelListener = ActionListener.wrap(getModelsResponse -> {
List<TrainedModelConfig> models = getModelsResponse.getResources().results();
Expand All @@ -136,10 +138,10 @@ protected void doExecute(
IngestMetadata currentIngestMetadata = state.metadata().custom(IngestMetadata.TYPE);
Set<String> referencedModels = getReferencedModelKeys(currentIngestMetadata, ingestService);

if (referencedModels.contains(modelId)) {
if (request.isForce() == false && referencedModels.contains(modelId)) {
listener.onFailure(
new ElasticsearchStatusException(
"Cannot stop allocation for model [{}] as it is still referenced by ingest processors",
"Cannot stop deployment for model [{}] as it is referenced by ingest processors; use force to stop the deployment",
RestStatus.CONFLICT,
modelId
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,21 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
StopTrainedModelDeploymentAction.Request request = new StopTrainedModelDeploymentAction.Request(modelId);
StopTrainedModelDeploymentAction.Request request;
if (restRequest.hasContentOrSourceParam()) {
request = StopTrainedModelDeploymentAction.Request.parseRequest(modelId, restRequest.contentOrSourceParamParser());
} else {
request = new StopTrainedModelDeploymentAction.Request(modelId);
request.setAllowNoMatch(
restRequest.paramAsBoolean(
StopTrainedModelDeploymentAction.Request.ALLOW_NO_MATCH.getPreferredName(),
request.isAllowNoMatch()
)
);
request.setForce(
restRequest.paramAsBoolean(StopTrainedModelDeploymentAction.Request.FORCE.getPreferredName(), request.isForce())
);
}
return channel -> client.execute(StopTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel));
}
}

0 comments on commit f6c40fd

Please sign in to comment.