diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java index b61b87e4a8139..ead7c836463fd 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java @@ -105,7 +105,8 @@ protected void masterOperation(Task task, Request request, ClusterState state, A .execute(() -> importModel(client, taskManager, request, modelImporter, listener, downloadTask)); } catch (Exception e) { taskManager.unregister(downloadTask); - throw e; + listener.onFailure(e); + return; } if (request.isWaitForCompletion() == false) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java index 093e4213a5db1..7eba51176aacf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java @@ -139,7 +139,7 @@ static void cancelDownloadTask(Client client, String modelId, ActionListener null, taskListener); } static Set getReferencedModelKeys(IngestMetadata ingestMetadata, IngestService ingestService) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java index 242d5e00f0ec7..9df04aa9c09d5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -403,14 +403,21 @@ static void checkForExistingTask( ActionListener storeModelListener, TimeValue timeout ) { - TaskRetriever.getDownloadTaskInfo(client, modelId, isWaitForCompletion, ActionListener.wrap(taskInfo -> { - if (taskInfo != null) { - getModelInformation(client, modelId, sendResponseListener); - } else { - // no task exists so proceed with creating the model - storeModelListener.onResponse(null); - } - }, sendResponseListener::onFailure), timeout); + TaskRetriever.getDownloadTaskInfo( + client, + modelId, + isWaitForCompletion, + timeout, + () -> "Timed out waiting for model download to complete", + ActionListener.wrap(taskInfo -> { + if (taskInfo != null) { + getModelInformation(client, modelId, sendResponseListener); + } else { + // no task exists so proceed with creating the model + storeModelListener.onResponse(null); + } + }, sendResponseListener::onFailure) + ); } private static void getExistingTaskInfo(Client client, String modelId, boolean waitForCompletion, ActionListener listener) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index 4a569b374582a..9f2a7e349f163 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -559,21 +559,29 @@ private static void step1CheckForDownloadTask( String modelId, ActionListener nextStepListener ) { - TaskRetriever.getDownloadTaskInfo(mlOriginClient, modelId, timeout != null, ActionListener.wrap(taskInfo -> { - if (taskInfo == null) { - nextStepListener.onResponse(null); - } else { - failOrRespondWith0( - () -> new ElasticsearchStatusException( - Messages.getMessage(Messages.MODEL_DOWNLOAD_IN_PROGRESS, modelId), - RestStatus.REQUEST_TIMEOUT - ), - errorIfDefinitionIsMissing, - modelId, - failureListener - ); - } - }, failureListener::onFailure), timeout); + // check task is present, do not wait for completion + TaskRetriever.getDownloadTaskInfo( + mlOriginClient, + modelId, + timeout != null, + timeout, + () -> Messages.getMessage(Messages.MODEL_DOWNLOAD_IN_PROGRESS, modelId), + ActionListener.wrap(taskInfo -> { + if (taskInfo == null) { + nextStepListener.onResponse(null); + } else { + failOrRespondWith0( + () -> new ElasticsearchStatusException( + Messages.getMessage(Messages.MODEL_DOWNLOAD_IN_PROGRESS, modelId), + RestStatus.REQUEST_TIMEOUT + ), + errorIfDefinitionIsMissing, + modelId, + failureListener + ); + } + }, failureListener::onFailure) + ); } private static void failOrRespondWith0( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/TaskRetriever.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/TaskRetriever.java index 652592bb08591..b60f57e5aaaf6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/TaskRetriever.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/TaskRetriever.java @@ -7,20 +7,28 @@ package org.elasticsearch.xpack.ml.utils; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ElasticsearchTimeoutException; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.core.TimeValue; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.TaskInfo; +import org.elasticsearch.transport.ReceiveTimeoutTransportException; import org.elasticsearch.xpack.core.ml.MlTasks; +import java.util.function.Supplier; + import static org.elasticsearch.xpack.core.ml.MlTasks.downloadModelTaskDescription; /** * Utility class for retrieving download tasks created by a PUT trained model API request. */ public class TaskRetriever { + /** * Returns a {@link TaskInfo} if one exists representing an in-progress trained model download. * @@ -28,16 +36,18 @@ public class TaskRetriever { * @param modelId the id of the model to check for an existing task * @param waitForCompletion a boolean flag determine if the request should wait for an existing task to complete before returning (aka * wait for the download to complete) + * @param timeout the timeout value in seconds that the request should fail if it does not complete + * @param errorMessageOnWaitTimeout Message to use if the request times out with {@code waitForCompletion == true} * @param listener a listener, if a task is found it is returned via {@code ActionListener.onResponse(taskInfo)}. * If a task is not found null is returned - * @param timeout the timeout value in seconds that the request should fail if it does not complete */ public static void getDownloadTaskInfo( Client client, String modelId, boolean waitForCompletion, - ActionListener listener, - TimeValue timeout + TimeValue timeout, + Supplier errorMessageOnWaitTimeout, + ActionListener listener ) { client.admin() .cluster() @@ -53,19 +63,46 @@ public static void getDownloadTaskInfo( if (tasks.size() > 0) { // there really shouldn't be more than a single task but if there is we'll just use the first one listener.onResponse(tasks.get(0)); + } else if (waitForCompletion && didItTimeout(response)) { + listener.onFailure(taskDidNotCompleteException(errorMessageOnWaitTimeout.get())); } else { + response.rethrowFailures("Checking model [" + modelId + "] download status"); listener.onResponse(null); } - }, - e -> listener.onFailure( + }, e -> { + listener.onFailure( new ElasticsearchStatusException( "Unable to retrieve task information for model id [{}]", RestStatus.INTERNAL_SERVER_ERROR, e, modelId ) - ) - )); + ); + })); + } + + private static boolean didItTimeout(ListTasksResponse response) { + if (response.getNodeFailures().isEmpty() == false) { + // if one node timed out then the others will also have timed out + var firstNodeFailure = response.getNodeFailures().get(0); + if (firstNodeFailure.status() == RestStatus.REQUEST_TIMEOUT) { + return true; + } + + var timeoutException = ExceptionsHelper.unwrap( + firstNodeFailure, + ElasticsearchTimeoutException.class, + ReceiveTimeoutTransportException.class + ); + if (timeoutException != null) { + return true; + } + } + return false; + } + + private static ElasticsearchException taskDidNotCompleteException(String message) { + return new ElasticsearchStatusException(message, RestStatus.REQUEST_TIMEOUT); } private TaskRetriever() {} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/TaskRetrieverTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/TaskRetrieverTests.java index 719a9be43080f..6ee39266ba5fc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/TaskRetrieverTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/TaskRetrieverTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.utils; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequestBuilder; import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse; @@ -67,7 +68,7 @@ public void testGetExistingTaskInfoCallsOnFailureForAnError() { var listener = new PlainActionFuture(); - getDownloadTaskInfo(client, "modelId", false, listener, TIMEOUT); + getDownloadTaskInfo(client, "modelId", false, TIMEOUT, () -> "", listener); var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(exception.status(), is(RestStatus.INTERNAL_SERVER_ERROR)); @@ -78,7 +79,7 @@ public void testGetExistingTaskInfoCallsListenerWithNullWhenNoTasksExist() { var client = mockClientWithTasksResponse(Collections.emptyList(), threadPool); var listener = new PlainActionFuture(); - getDownloadTaskInfo(client, "modelId", false, listener, TIMEOUT); + getDownloadTaskInfo(client, "modelId", false, TIMEOUT, () -> "", listener); assertThat(listener.actionGet(TIMEOUT), nullValue()); } @@ -88,7 +89,7 @@ public void testGetExistingTaskInfoCallsListenerWithTaskInfoWhenTaskExists() { var client = mockClientWithTasksResponse(listTaskInfo, threadPool); var listener = new PlainActionFuture(); - getDownloadTaskInfo(client, "modelId", false, listener, TIMEOUT); + getDownloadTaskInfo(client, "modelId", false, TIMEOUT, () -> "", listener); assertThat(listener.actionGet(TIMEOUT), is(listTaskInfo.get(0))); } @@ -98,11 +99,37 @@ public void testGetExistingTaskInfoCallsListenerWithFirstTaskInfoWhenMultipleTas var client = mockClientWithTasksResponse(listTaskInfo, threadPool); var listener = new PlainActionFuture(); - getDownloadTaskInfo(client, "modelId", false, listener, TIMEOUT); + getDownloadTaskInfo(client, "modelId", false, TIMEOUT, () -> "", listener); assertThat(listener.actionGet(TIMEOUT), is(listTaskInfo.get(0))); } + public void testGetTimeoutOnWaitForCompletion() { + var client = mockListTasksClient(threadPool); + + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + ActionListener actionListener = (ActionListener) invocationOnMock.getArguments()[2]; + actionListener.onResponse( + new ListTasksResponse( + List.of(), + List.of(), + List.of(new ElasticsearchStatusException("node timeout", RestStatus.REQUEST_TIMEOUT)) + ) + ); + + return Void.TYPE; + }).when(client).execute(same(TransportListTasksAction.TYPE), any(), any()); + + var listener = new PlainActionFuture(); + + getDownloadTaskInfo(client, "modelId", true, TIMEOUT, () -> "Testing timeout", listener); + + var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exception.status(), is(RestStatus.REQUEST_TIMEOUT)); + assertThat(exception.getMessage(), is("Testing timeout")); + } + /** * A helper method for setting up a mock cluster client to return the passed in list of tasks. * diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml index af3ecd2637843..fdccf473b358a 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml @@ -218,9 +218,6 @@ setup: --- "Test start deployment fails while model download in progress": - - skip: - version: "all" - reason: "Awaits fix: https://github.com/elastic/elasticsearch/issues/102948" - do: ml.put_trained_model: model_id: .elser_model_2 @@ -230,10 +227,13 @@ setup: "field_names": ["text_field"] } } + # Set a low timeout so the test doesn't actually wait + # for the model download to complete - do: catch: /Model download task is currently running\. Wait for trained model \[.elser_model_2\] download task to complete then try again/ ml.start_trained_model_deployment: model_id: .elser_model_2 + timeout: 1s - do: ml.delete_trained_model: model_id: .elser_model_2