Skip to content

Commit

Permalink
[ML] Detect timeout when waiting for download task (elastic#103197)
Browse files Browse the repository at this point in the history
A list tasks timeout indicates the task exists and is in progress.
Interpreting the timeout as the task not existing meant the download
check would incorrectly assume the download had completed.
# Conflicts:
#	x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java
#	x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml
  • Loading branch information
davidkyle committed Jan 15, 2024
1 parent 653086e commit 8578dce
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ protected void masterOperation(Task task, Request request, ClusterState state, A
.execute(() -> importModel(client, taskManager, request, modelImporter, listener, downloadTask));
} catch (Exception e) {
taskManager.unregister(downloadTask);
throw e;
listener.onFailure(e);
return;
}

if (request.isWaitForCompletion() == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ static void cancelDownloadTask(Client client, String modelId, ActionListener<Can
);

// setting waitForCompletion to false here so that we don't block waiting for an existing task to complete before returning it
getDownloadTaskInfo(mlClient, modelId, false, taskListener, timeout);
getDownloadTaskInfo(mlClient, modelId, false, timeout, () -> null, taskListener);
}

static Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata, IngestService ingestService) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,14 +403,21 @@ static void checkForExistingTask(
ActionListener<Void> storeModelListener,
TimeValue timeout
) {
TaskRetriever.getDownloadTaskInfo(client, modelId, isWaitForCompletion, ActionListener.wrap(taskInfo -> {
if (taskInfo != null) {
getModelInformation(client, modelId, sendResponseListener);
} else {
// no task exists so proceed with creating the model
storeModelListener.onResponse(null);
}
}, sendResponseListener::onFailure), timeout);
TaskRetriever.getDownloadTaskInfo(
client,
modelId,
isWaitForCompletion,
timeout,
() -> "Timed out waiting for model download to complete",
ActionListener.wrap(taskInfo -> {
if (taskInfo != null) {
getModelInformation(client, modelId, sendResponseListener);
} else {
// no task exists so proceed with creating the model
storeModelListener.onResponse(null);
}
}, sendResponseListener::onFailure)
);
}

private static void getExistingTaskInfo(Client client, String modelId, boolean waitForCompletion, ActionListener<TaskInfo> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,21 +559,29 @@ private static void step1CheckForDownloadTask(
String modelId,
ActionListener<Runnable> nextStepListener
) {
TaskRetriever.getDownloadTaskInfo(mlOriginClient, modelId, timeout != null, ActionListener.wrap(taskInfo -> {
if (taskInfo == null) {
nextStepListener.onResponse(null);
} else {
failOrRespondWith0(
() -> new ElasticsearchStatusException(
Messages.getMessage(Messages.MODEL_DOWNLOAD_IN_PROGRESS, modelId),
RestStatus.REQUEST_TIMEOUT
),
errorIfDefinitionIsMissing,
modelId,
failureListener
);
}
}, failureListener::onFailure), timeout);
// check task is present, do not wait for completion
TaskRetriever.getDownloadTaskInfo(
mlOriginClient,
modelId,
timeout != null,
timeout,
() -> Messages.getMessage(Messages.MODEL_DOWNLOAD_IN_PROGRESS, modelId),
ActionListener.wrap(taskInfo -> {
if (taskInfo == null) {
nextStepListener.onResponse(null);
} else {
failOrRespondWith0(
() -> new ElasticsearchStatusException(
Messages.getMessage(Messages.MODEL_DOWNLOAD_IN_PROGRESS, modelId),
RestStatus.REQUEST_TIMEOUT
),
errorIfDefinitionIsMissing,
modelId,
failureListener
);
}
}, failureListener::onFailure)
);
}

private static void failOrRespondWith0(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,47 @@

package org.elasticsearch.xpack.ml.utils;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ElasticsearchTimeoutException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.transport.ReceiveTimeoutTransportException;
import org.elasticsearch.xpack.core.ml.MlTasks;

import java.util.function.Supplier;

import static org.elasticsearch.xpack.core.ml.MlTasks.downloadModelTaskDescription;

/**
* Utility class for retrieving download tasks created by a PUT trained model API request.
*/
public class TaskRetriever {

/**
* Returns a {@link TaskInfo} if one exists representing an in-progress trained model download.
*
* @param client a {@link Client} used to retrieve the task
* @param modelId the id of the model to check for an existing task
* @param waitForCompletion a boolean flag determine if the request should wait for an existing task to complete before returning (aka
* wait for the download to complete)
* @param timeout the timeout value in seconds that the request should fail if it does not complete
* @param errorMessageOnWaitTimeout Message to use if the request times out with {@code waitForCompletion == true}
* @param listener a listener, if a task is found it is returned via {@code ActionListener.onResponse(taskInfo)}.
* If a task is not found null is returned
* @param timeout the timeout value in seconds that the request should fail if it does not complete
*/
public static void getDownloadTaskInfo(
Client client,
String modelId,
boolean waitForCompletion,
ActionListener<TaskInfo> listener,
TimeValue timeout
TimeValue timeout,
Supplier<String> errorMessageOnWaitTimeout,
ActionListener<TaskInfo> listener
) {
client.admin()
.cluster()
Expand All @@ -53,19 +63,46 @@ public static void getDownloadTaskInfo(
if (tasks.size() > 0) {
// there really shouldn't be more than a single task but if there is we'll just use the first one
listener.onResponse(tasks.get(0));
} else if (waitForCompletion && didItTimeout(response)) {
listener.onFailure(taskDidNotCompleteException(errorMessageOnWaitTimeout.get()));
} else {
response.rethrowFailures("Checking model [" + modelId + "] download status");
listener.onResponse(null);
}
},
e -> listener.onFailure(
}, e -> {
listener.onFailure(
new ElasticsearchStatusException(
"Unable to retrieve task information for model id [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
e,
modelId
)
)
));
);
}));
}

private static boolean didItTimeout(ListTasksResponse response) {
if (response.getNodeFailures().isEmpty() == false) {
// if one node timed out then the others will also have timed out
var firstNodeFailure = response.getNodeFailures().get(0);
if (firstNodeFailure.status() == RestStatus.REQUEST_TIMEOUT) {
return true;
}

var timeoutException = ExceptionsHelper.unwrap(
firstNodeFailure,
ElasticsearchTimeoutException.class,
ReceiveTimeoutTransportException.class
);
if (timeoutException != null) {
return true;
}
}
return false;
}

private static ElasticsearchException taskDidNotCompleteException(String message) {
return new ElasticsearchStatusException(message, RestStatus.REQUEST_TIMEOUT);
}

private TaskRetriever() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.utils;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequestBuilder;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
Expand Down Expand Up @@ -67,7 +68,7 @@ public void testGetExistingTaskInfoCallsOnFailureForAnError() {

var listener = new PlainActionFuture<TaskInfo>();

getDownloadTaskInfo(client, "modelId", false, listener, TIMEOUT);
getDownloadTaskInfo(client, "modelId", false, TIMEOUT, () -> "", listener);

var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.status(), is(RestStatus.INTERNAL_SERVER_ERROR));
Expand All @@ -78,7 +79,7 @@ public void testGetExistingTaskInfoCallsListenerWithNullWhenNoTasksExist() {
var client = mockClientWithTasksResponse(Collections.emptyList(), threadPool);
var listener = new PlainActionFuture<TaskInfo>();

getDownloadTaskInfo(client, "modelId", false, listener, TIMEOUT);
getDownloadTaskInfo(client, "modelId", false, TIMEOUT, () -> "", listener);

assertThat(listener.actionGet(TIMEOUT), nullValue());
}
Expand All @@ -88,7 +89,7 @@ public void testGetExistingTaskInfoCallsListenerWithTaskInfoWhenTaskExists() {
var client = mockClientWithTasksResponse(listTaskInfo, threadPool);
var listener = new PlainActionFuture<TaskInfo>();

getDownloadTaskInfo(client, "modelId", false, listener, TIMEOUT);
getDownloadTaskInfo(client, "modelId", false, TIMEOUT, () -> "", listener);

assertThat(listener.actionGet(TIMEOUT), is(listTaskInfo.get(0)));
}
Expand All @@ -98,11 +99,37 @@ public void testGetExistingTaskInfoCallsListenerWithFirstTaskInfoWhenMultipleTas
var client = mockClientWithTasksResponse(listTaskInfo, threadPool);
var listener = new PlainActionFuture<TaskInfo>();

getDownloadTaskInfo(client, "modelId", false, listener, TIMEOUT);
getDownloadTaskInfo(client, "modelId", false, TIMEOUT, () -> "", listener);

assertThat(listener.actionGet(TIMEOUT), is(listTaskInfo.get(0)));
}

public void testGetTimeoutOnWaitForCompletion() {
var client = mockListTasksClient(threadPool);

doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
ActionListener<ListTasksResponse> actionListener = (ActionListener<ListTasksResponse>) invocationOnMock.getArguments()[2];
actionListener.onResponse(
new ListTasksResponse(
List.of(),
List.of(),
List.of(new ElasticsearchStatusException("node timeout", RestStatus.REQUEST_TIMEOUT))
)
);

return Void.TYPE;
}).when(client).execute(same(TransportListTasksAction.TYPE), any(), any());

var listener = new PlainActionFuture<TaskInfo>();

getDownloadTaskInfo(client, "modelId", true, TIMEOUT, () -> "Testing timeout", listener);

var exception = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(exception.status(), is(RestStatus.REQUEST_TIMEOUT));
assertThat(exception.getMessage(), is("Testing timeout"));
}

/**
* A helper method for setting up a mock cluster client to return the passed in list of tasks.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,6 @@ setup:

---
"Test start deployment fails while model download in progress":
- skip:
version: "all"
reason: "Awaits fix: https://github.com/elastic/elasticsearch/issues/102948"
- do:
ml.put_trained_model:
model_id: .elser_model_2
Expand All @@ -230,10 +227,13 @@ setup:
"field_names": ["text_field"]
}
}
# Set a low timeout so the test doesn't actually wait
# for the model download to complete
- do:
catch: /Model download task is currently running\. Wait for trained model \[.elser_model_2\] download task to complete then try again/
ml.start_trained_model_deployment:
model_id: .elser_model_2
timeout: 1s
- do:
ml.delete_trained_model:
model_id: .elser_model_2
Expand Down

0 comments on commit 8578dce

Please sign in to comment.