diff --git a/docs/reference/ml/df-analytics/apis/stop-trained-model-deployment.asciidoc b/docs/reference/ml/df-analytics/apis/stop-trained-model-deployment.asciidoc index f5dfcd82d4d75..eb58a9baf8d8c 100644 --- a/docs/reference/ml/df-analytics/apis/stop-trained-model-deployment.asciidoc +++ b/docs/reference/ml/df-analytics/apis/stop-trained-model-deployment.asciidoc @@ -30,10 +30,18 @@ experimental::[] (Required, string) include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] -//// + [[stop-trained-model-deployment-query-params]] == {api-query-parms-title} -//// + +`allow_no_match`:: +(Optional, Boolean) +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=allow-no-match] + + +`force`:: + (Optional, Boolean) If true, the deployment is stopped even if it is referenced by + ingest pipelines. //// [role="child_attributes"] diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.stop_trained_model_deployment.json b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.stop_trained_model_deployment.json index 3e608a890b0a1..f682be52be0e9 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/ml.stop_trained_model_deployment.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/ml.stop_trained_model_deployment.json @@ -26,6 +26,21 @@ } } ] + }, + "params":{ + "allow_no_match":{ + "type":"boolean", + "required":false, + "description":"Whether to ignore if a wildcard expression matches no deployments. (This includes `_all` string or when no deployments have been specified)" + }, + "force":{ + "type":"boolean", + "required":false, + "description":"True if the deployment should be forcefully stopped" + } + }, + "body":{ + "description":"The stop deployment parameters" } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java index 24a6adc899920..e0c8ad6fd90ea 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentAction.java @@ -10,14 +10,18 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.tasks.BaseTasksRequest; import org.elasticsearch.action.support.tasks.BaseTasksResponse; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.tasks.Task; +import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -41,6 +45,26 @@ public static class Request extends BaseTasksRequest implements ToXCont private boolean allowNoMatch = true; private boolean force; + private static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); + + static { + PARSER.declareString(Request::setId, TrainedModelConfig.MODEL_ID); + PARSER.declareBoolean(Request::setAllowNoMatch, ALLOW_NO_MATCH); + PARSER.declareBoolean(Request::setForce, FORCE); + } + + public static Request parseRequest(String id, XContentParser parser) { + Request request = PARSER.apply(parser, null); + if (request.getId() == null) { + request.setId(id); + } else if (Strings.isNullOrEmpty(id) == false && id.equals(request.getId()) == false) { + throw new IllegalArgumentException( + Messages.getMessage(Messages.INCONSISTENT_ID, TrainedModelConfig.MODEL_ID, request.getId(), id) + ); + } + return request; + } + public Request(String id) { setId(id); } @@ -52,6 +76,8 @@ public Request(StreamInput in) throws IOException { force = in.readBoolean(); } + private Request() {} + public final void setId(String id) { this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentRequestTests.java new file mode 100644 index 0000000000000..286926ead757b --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopTrainedModelDeploymentRequestTests.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction.Request; + +import java.io.IOException; + +public class StopTrainedModelDeploymentRequestTests extends AbstractSerializingTestCase { + + @Override + protected Request doParseInstance(XContentParser parser) throws IOException { + return Request.parseRequest(null, parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } + + @Override + protected Request createTestInstance() { + Request request = new Request(randomAlphaOfLength(10)); + if (randomBoolean()) { + request.setAllowNoMatch(randomBoolean()); + } + if (randomBoolean()) { + request.setForce(randomBoolean()); + } + return request; + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java index 2d4e3e46f7cea..e6602705dce41 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java @@ -484,6 +484,42 @@ public void testInferencePipelineAgainstUnallocatedModel() throws IOException { ); } + public void testStopUsedDeploymentByIngestProcessor() throws IOException { + String modelId = "test_stop_used_deployment_by_ingest_processor"; + createTrainedModel(modelId); + putModelDefinition(modelId); + putVocabulary(List.of("these", "are", "my", "words"), modelId); + startDeployment(modelId); + + client().performRequest( + putPipeline( + "my_pipeline", + "{" + + "\"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"model_id\": \"" + + modelId + + "\"\n" + + " }\n" + + " }\n" + + " ]\n" + + "}" + ) + ); + ResponseException ex = expectThrows(ResponseException.class, () -> stopDeployment(modelId)); + assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(409)); + assertThat( + EntityUtils.toString(ex.getResponse().getEntity()), + containsString( + "Cannot stop deployment for model [test_stop_used_deployment_by_ingest_processor] as it is referenced by" + + " ingest processors; use force to stop the deployment" + ) + ); + + stopDeployment(modelId, true); + } + private int sumInferenceCountOnNodes(List> nodes) { int inferenceCount = 0; for (var node : nodes) { @@ -554,7 +590,15 @@ private Response startDeployment(String modelId, String waitForState) throws IOE } private void stopDeployment(String modelId) throws IOException { - Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/deployment/_stop"); + stopDeployment(modelId, false); + } + + private void stopDeployment(String modelId, boolean force) throws IOException { + String endpoint = "/_ml/trained_models/" + modelId + "/deployment/_stop"; + if (force) { + endpoint += "?force=true"; + } + Request request = new Request("POST", endpoint); client().performRequest(request); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java index 0c2f20ca36b4f..2610f12556e89 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopTrainedModelDeploymentAction.java @@ -109,7 +109,9 @@ protected void doExecute( return; } - logger.debug("[{}] Received request to undeploy", request.getId()); + logger.debug( + () -> new ParameterizedMessage("[{}] Received request to undeploy{}", request.getId(), request.isForce() ? " (force)" : "") + ); ActionListener getModelListener = ActionListener.wrap(getModelsResponse -> { List models = getModelsResponse.getResources().results(); @@ -136,10 +138,10 @@ protected void doExecute( IngestMetadata currentIngestMetadata = state.metadata().custom(IngestMetadata.TYPE); Set referencedModels = getReferencedModelKeys(currentIngestMetadata, ingestService); - if (referencedModels.contains(modelId)) { + if (request.isForce() == false && referencedModels.contains(modelId)) { listener.onFailure( new ElasticsearchStatusException( - "Cannot stop allocation for model [{}] as it is still referenced by ingest processors", + "Cannot stop deployment for model [{}] as it is referenced by ingest processors; use force to stop the deployment", RestStatus.CONFLICT, modelId ) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java index dd83734554c6f..b555d421d9659 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStopTrainedModelDeploymentAction.java @@ -38,7 +38,21 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); - StopTrainedModelDeploymentAction.Request request = new StopTrainedModelDeploymentAction.Request(modelId); + StopTrainedModelDeploymentAction.Request request; + if (restRequest.hasContentOrSourceParam()) { + request = StopTrainedModelDeploymentAction.Request.parseRequest(modelId, restRequest.contentOrSourceParamParser()); + } else { + request = new StopTrainedModelDeploymentAction.Request(modelId); + request.setAllowNoMatch( + restRequest.paramAsBoolean( + StopTrainedModelDeploymentAction.Request.ALLOW_NO_MATCH.getPreferredName(), + request.isAllowNoMatch() + ) + ); + request.setForce( + restRequest.paramAsBoolean(StopTrainedModelDeploymentAction.Request.FORCE.getPreferredName(), request.isForce()) + ); + } return channel -> client.execute(StopTrainedModelDeploymentAction.INSTANCE, request, new RestToXContentListener<>(channel)); } }