diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java index 3db834bb579ff..69767ce0b24f0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java @@ -8,6 +8,9 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseListener; +import org.elasticsearch.common.Strings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; @@ -16,9 +19,12 @@ import org.junit.Before; import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -108,4 +114,38 @@ private static void assertDefaultE5Config(Map modelConfig) { Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32)) ); } + + public void testMultipleInferencesTriggeringDownloadAndDeploy() throws InterruptedException { + int numParallelRequests = 4; + var latch = new CountDownLatch(numParallelRequests); + // Async callbacks for the parallel requests can race, so failures are collected in a thread-safe list + var errors = new CopyOnWriteArrayList<Exception>(); + + var listener = new ResponseListener() { + @Override + public void onSuccess(Response response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exception) { + errors.add(exception); 
+ latch.countDown(); + } + }; + + var inputs = List.of("Hello World", "Goodnight moon"); + var queryParams = Map.of("timeout", "120s"); + for (int i = 0; i < numParallelRequests; i++) { + var request = createInferenceRequest( + Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID), + inputs, + queryParams + ); + client().performRequestAsync(request, listener); + } + + latch.await(); + assertThat(errors.toString(), errors, empty()); + } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 6790b9bb14c5a..4e32ef99d06dd 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -373,12 +373,17 @@ protected Map infer(String modelId, TaskType taskType, List inferInternal(String endpoint, List input, Map queryParameters) throws IOException { + protected Request createInferenceRequest(String endpoint, List input, Map queryParameters) { var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input)); if (queryParameters.isEmpty() == false) { request.addParameters(queryParameters); } + return request; + } + + private Map inferInternal(String endpoint, List input, Map queryParameters) throws IOException { + var request = createInferenceRequest(endpoint, input, queryParameters); var response = client().performRequest(request); assertOkOrCreated(response); return entityAsMap(response); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java index b710b24cbda31..b76de5eeedbfc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java @@ -7,14 +7,9 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; public class CustomElandModel extends ElasticsearchInternalModel { @@ -39,31 +34,10 @@ public CustomElandModel( } @Override - public ActionListener getCreateTrainedModelAssignmentActionListener( - Model model, - ActionListener listener - ) { - - return new ActionListener<>() { - @Override - public void onResponse(CreateTrainedModelAssignmentAction.Response response) { - listener.onResponse(Boolean.TRUE); - } - - @Override - public void onFailure(Exception e) { - if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { - listener.onFailure( - new ResourceNotFoundException( - "Could not start the inference as the custom eland model [{0}] for this platform cannot be found." - + " Custom models need to be loaded into the cluster with eland before they can be started.", - internalServiceSettings.modelId() - ) - ); - return; - } - listener.onFailure(e); - } - }; + protected String modelNotFoundErrorMessage(String modelId) { + return "Could not deploy model [" + + modelId + + "] as the model cannot be found." 
+ + " Custom models need to be loaded into the cluster with Eland before they can be started."; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java index 724c7a8f0a166..ce6c6258d0393 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticDeployedModel.java @@ -36,6 +36,11 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA throw new IllegalStateException("cannot start model that uses an existing deployment"); } + @Override + protected String modelNotFoundErrorMessage(String modelId) { + throw new IllegalStateException("cannot start model [" + modelId + "] that uses an existing deployment"); + } + @Override public ActionListener getCreateTrainedModelAssignmentActionListener( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java index 2405243f302bc..aa12bf0c645c3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; 
import org.elasticsearch.common.Strings; import org.elasticsearch.core.TimeValue; @@ -15,8 +18,10 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED; @@ -79,10 +84,39 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA return startRequest; } - public abstract ActionListener getCreateTrainedModelAssignmentActionListener( + public ActionListener getCreateTrainedModelAssignmentActionListener( Model model, ActionListener listener - ); + ) { + return new ActionListener<>() { + @Override + public void onResponse(CreateTrainedModelAssignmentAction.Response response) { + listener.onResponse(Boolean.TRUE); + } + + @Override + public void onFailure(Exception e) { + var cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof ResourceNotFoundException) { + listener.onFailure(new ResourceNotFoundException(modelNotFoundErrorMessage(internalServiceSettings.modelId()))); + return; + } else if (cause instanceof ElasticsearchStatusException statusException) { + if (statusException.status() == RestStatus.CONFLICT + && statusException.getRootCause() instanceof ResourceAlreadyExistsException) { + // Deployment is already started + listener.onResponse(Boolean.TRUE); + return; + } + // Any other status exception falls through so the listener is still notified of the failure + } + listener.onFailure(e); + } + }; + } + + protected String modelNotFoundErrorMessage(String modelId) { + return "Could not deploy model [" + modelId + "] as the model cannot be found."; + } public boolean usesExistingDeployment() { return internalServiceSettings.getDeploymentId() != null; diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java index 8d2f59171a601..2594f18db3fb5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModel.java @@ -7,13 +7,8 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; public class ElserInternalModel extends ElasticsearchInternalModel { @@ -37,31 +32,4 @@ public ElserInternalServiceSettings getServiceSettings() { public ElserMlNodeTaskSettings getTaskSettings() { return (ElserMlNodeTaskSettings) super.getTaskSettings(); } - - @Override - public ActionListener getCreateTrainedModelAssignmentActionListener( - Model model, - ActionListener listener - ) { - return new ActionListener<>() { - @Override - public void onResponse(CreateTrainedModelAssignmentAction.Response response) { - listener.onResponse(Boolean.TRUE); - } - - @Override - public void onFailure(Exception e) { - if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { - listener.onFailure( - new ResourceNotFoundException( - "Could not start the ELSER service as the ELSER model for this platform cannot be found." - + " ELSER needs to be downloaded before it can be started." 
- ) - ); - return; - } - listener.onFailure(e); - } - }; - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java index fee00d04d940b..2dcf91140c995 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java @@ -7,13 +7,8 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; public class MultilingualE5SmallModel extends ElasticsearchInternalModel { @@ -31,34 +26,4 @@ public MultilingualE5SmallModel( public MultilingualE5SmallInternalServiceSettings getServiceSettings() { return (MultilingualE5SmallInternalServiceSettings) super.getServiceSettings(); } - - @Override - public ActionListener getCreateTrainedModelAssignmentActionListener( - Model model, - ActionListener listener - ) { - - return new ActionListener<>() { - @Override - public void onResponse(CreateTrainedModelAssignmentAction.Response response) { - listener.onResponse(Boolean.TRUE); - } - - @Override - public void onFailure(Exception e) { - if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { - listener.onFailure( - new ResourceNotFoundException( - "Could not start the TextEmbeddingService service as the " - + "Multilingual-E5-Small model for this platform cannot be 
found." - + " Multilingual-E5-Small needs to be downloaded before it can be started" - ) - ); - return; - } - listener.onFailure(e); - } - }; - } - } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/DownloadTaskRemovedListener.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/DownloadTaskRemovedListener.java new file mode 100644 index 0000000000000..929dac6ee357a --- /dev/null +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/DownloadTaskRemovedListener.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.packageloader.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.tasks.RemovedTaskListener; +import org.elasticsearch.tasks.Task; + +public record DownloadTaskRemovedListener(ModelDownloadTask trackedTask, ActionListener listener) + implements + RemovedTaskListener { + + @Override + public void onRemoved(Task task) { + if (task.getId() == trackedTask.getId()) { + if (trackedTask.getTaskException() == null) { + listener.onResponse(AcknowledgedResponse.TRUE); + } else { + listener.onFailure(trackedTask.getTaskException()); + } + } + } +} diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java index 59977bd418e11..dd09c3cf65fec 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java +++ 
b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java @@ -13,6 +13,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.MlTasks; import java.io.IOException; import java.util.Map; @@ -51,9 +52,12 @@ public void writeTo(StreamOutput out) throws IOException { } private final AtomicReference downloadProgress = new AtomicReference<>(new DownLoadProgress(0, 0)); + private final String modelId; + private volatile Exception taskException; - public ModelDownloadTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers) { - super(id, type, action, description, parentTaskId, headers); + public ModelDownloadTask(long id, String type, String action, String modelId, TaskId parentTaskId, Map headers) { + super(id, type, action, taskDescription(modelId), parentTaskId, headers); + this.modelId = modelId; } void setProgress(int totalParts, int downloadedParts) { @@ -65,4 +69,19 @@ public DownloadStatus getStatus() { return new DownloadStatus(downloadProgress.get()); } + public String getModelId() { + return modelId; + } + + public void setTaskException(Exception exception) { + this.taskException = exception; + } + + public Exception getTaskException() { + return taskException; + } + + public static String taskDescription(String modelId) { + return MlTasks.downloadModelTaskDescription(modelId); + } } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java index 76b7781b1cffe..2a14a8761e357 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java +++ 
b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java @@ -30,7 +30,6 @@ import org.elasticsearch.tasks.TaskAwareRequest; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.common.notifications.Level; @@ -42,6 +41,9 @@ import java.io.IOException; import java.net.MalformedURLException; import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -49,7 +51,6 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_ACTION; import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_TYPE; -import static org.elasticsearch.xpack.core.ml.MlTasks.downloadModelTaskDescription; public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction { @@ -57,6 +58,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction< private final Client client; private final CircuitBreakerService circuitBreakerService; + final Map> taskRemovedListenersByModelId; @Inject public TransportLoadTrainedModelPackage( @@ -81,6 +83,7 @@ public TransportLoadTrainedModelPackage( ); this.client = new OriginSettingClient(client, ML_ORIGIN); this.circuitBreakerService = circuitBreakerService; + taskRemovedListenersByModelId = new HashMap<>(); } @Override @@ -91,6 +94,12 @@ protected ClusterBlockException checkBlock(Request request, ClusterState state) @Override protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) throws Exception { + if (handleDownloadInProgress(request.getModelId(), 
request.isWaitForCompletion(), listener)) { + logger.debug("Existing download of model [{}] in progress", request.getModelId()); + // download in progress, nothing to do + return; + } + ModelDownloadTask downloadTask = createDownloadTask(request); try { @@ -107,7 +116,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A var downloadCompleteListener = request.isWaitForCompletion() ? listener : ActionListener.noop(); - importModel(client, taskManager, request, modelImporter, downloadCompleteListener, downloadTask); + importModel(client, () -> unregisterTask(downloadTask), request, modelImporter, downloadTask, downloadCompleteListener); } catch (Exception e) { taskManager.unregister(downloadTask); listener.onFailure(e); @@ -124,22 +133,91 @@ private ParentTaskAssigningClient getParentTaskAssigningClient(Task originTask) return new ParentTaskAssigningClient(client, parentTaskId); } + /** + * Look for a current download task of the model and optionally wait + * for that task to complete if there is one. + * synchronized with {@code unregisterTask} to prevent the task being + * removed before the remove listener is added. 
+ * @param modelId Model being downloaded + * @param isWaitForCompletion Wait until the download completes before + * calling the listener + * @param listener Model download listener + * @return True if a download task is in progress + */ + synchronized boolean handleDownloadInProgress( + String modelId, + boolean isWaitForCompletion, + ActionListener listener + ) { + var description = ModelDownloadTask.taskDescription(modelId); + var tasks = taskManager.getCancellableTasks().values(); + + ModelDownloadTask inProgress = null; + for (var task : tasks) { + if (description.equals(task.getDescription()) && task instanceof ModelDownloadTask downloadTask) { + inProgress = downloadTask; + break; + } + } + + if (inProgress != null) { + if (isWaitForCompletion == false) { + // Not waiting for the download to complete, it is enough that the download is in progress + // Respond now, not when the download completes + listener.onResponse(AcknowledgedResponse.TRUE); + return true; + } + // Otherwise register a task removed listener which is called + // once the task is complete and unregistered + var tracker = new DownloadTaskRemovedListener(inProgress, listener); + taskRemovedListenersByModelId.computeIfAbsent(modelId, s -> new ArrayList<>()).add(tracker); + taskManager.registerRemovedTaskListener(tracker); + return true; + } + + return false; + } + + /** + * Unregister the completed task triggering any remove task listeners. + * This method is synchronized to prevent the task being removed while + * {@code handleDownloadInProgress} is in progress. 
+ * @param task The completed task + */ + synchronized void unregisterTask(ModelDownloadTask task) { + taskManager.unregister(task); // unregister will call the on remove function + + var trackers = taskRemovedListenersByModelId.remove(task.getModelId()); + if (trackers != null) { + for (var tracker : trackers) { + taskManager.unregisterRemovedTaskListener(tracker); + } + } + } + /** * This is package scope so that we can test the logic directly. - * This should only be called from the masterOperation method and the tests + * This should only be called from the masterOperation method and the tests. + * This method is static for testing. * * @param auditClient a client which should only be used to send audit notifications. This client cannot be associated with the passed * in task, that way when the task is cancelled the notification requests can * still be performed. If it is associated with the task (i.e. via ParentTaskAssigningClient), * then the requests will throw a TaskCancelledException. + * @param unregisterTaskFn Runnable to unregister the task. Because this is a static function + * a lambda is used rather than the instance method. 
+ * @param request The download request + * @param modelImporter The importer + * @param task Download task + * @param listener Listener */ static void importModel( Client auditClient, - TaskManager taskManager, + Runnable unregisterTaskFn, Request request, ModelImporter modelImporter, - ActionListener listener, - Task task + ModelDownloadTask task, + ActionListener listener ) { final String modelId = request.getModelId(); final long relativeStartNanos = System.nanoTime(); @@ -155,9 +233,12 @@ static void importModel( Level.INFO ); listener.onResponse(AcknowledgedResponse.TRUE); - }, exception -> listener.onFailure(processException(auditClient, modelId, exception))); + }, exception -> { + task.setTaskException(exception); + listener.onFailure(processException(auditClient, modelId, exception)); + }); - modelImporter.doImport(ActionListener.runAfter(finishListener, () -> taskManager.unregister(task))); + modelImporter.doImport(ActionListener.runAfter(finishListener, unregisterTaskFn)); } static Exception processException(Client auditClient, String modelId, Exception e) { @@ -197,14 +278,7 @@ public TaskId getParentTask() { @Override public ModelDownloadTask createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { - return new ModelDownloadTask( - id, - type, - action, - downloadModelTaskDescription(request.getModelId()), - parentTaskId, - headers - ); + return new ModelDownloadTask(id, type, action, request.getModelId(), parentTaskId, headers); } }, false); } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java index cbcfd5b760779..3486ce6af0db5 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java +++ 
b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java @@ -10,13 +10,19 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.common.notifications.Level; import org.elasticsearch.xpack.core.ml.action.AuditMlNotificationAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; @@ -27,9 +33,13 @@ import java.io.IOException; import java.net.MalformedURLException; import java.net.URISyntaxException; +import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_ACTION; +import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_TYPE; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.core.Is.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -37,6 +47,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; 
+import static org.mockito.Mockito.when; public class TransportLoadTrainedModelPackageTests extends ESTestCase { private static final String MODEL_IMPORT_FAILURE_MSG_FORMAT = "Model importing failed due to %s [%s]"; @@ -44,17 +55,10 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase { public void testSendsFinishedUploadNotification() { var uploader = createUploader(null); var taskManager = mock(TaskManager.class); - var task = mock(Task.class); + var task = mock(ModelDownloadTask.class); var client = mock(Client.class); - TransportLoadTrainedModelPackage.importModel( - client, - taskManager, - createRequestWithWaiting(), - uploader, - ActionListener.noop(), - task - ); + TransportLoadTrainedModelPackage.importModel(client, () -> {}, createRequestWithWaiting(), uploader, task, ActionListener.noop()); var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class); // 2 notifications- the start and finish messages @@ -108,32 +112,63 @@ public void testSendsWarningNotificationForTaskCancelledException() throws Excep public void testCallsOnResponseWithAcknowledgedResponse() throws Exception { var client = mock(Client.class); var taskManager = mock(TaskManager.class); - var task = mock(Task.class); + var task = mock(ModelDownloadTask.class); ModelImporter uploader = createUploader(null); var responseRef = new AtomicReference(); var listener = ActionListener.wrap(responseRef::set, e -> fail("received an exception: " + e.getMessage())); - TransportLoadTrainedModelPackage.importModel(client, taskManager, createRequestWithWaiting(), uploader, listener, task); + TransportLoadTrainedModelPackage.importModel(client, () -> {}, createRequestWithWaiting(), uploader, task, listener); assertThat(responseRef.get(), is(AcknowledgedResponse.TRUE)); } public void testDoesNotCallListenerWhenNotWaitingForCompletion() { var uploader = mock(ModelImporter.class); var client = mock(Client.class); - var taskManager = mock(TaskManager.class); - var 
task = mock(Task.class); - + var task = mock(ModelDownloadTask.class); TransportLoadTrainedModelPackage.importModel( client, - taskManager, + () -> {}, createRequestWithoutWaiting(), uploader, - ActionListener.running(ESTestCase::fail), - task + task, + ActionListener.running(ESTestCase::fail) ); } + public void testWaitForExistingDownload() { + var taskManager = mock(TaskManager.class); + var modelId = "foo"; + var task = new ModelDownloadTask(1L, MODEL_IMPORT_TASK_TYPE, MODEL_IMPORT_TASK_ACTION, modelId, new TaskId("node", 1L), Map.of()); + when(taskManager.getCancellableTasks()).thenReturn(Map.of(1L, task)); + + var transportService = mock(TransportService.class); + when(transportService.getTaskManager()).thenReturn(taskManager); + + var action = new TransportLoadTrainedModelPackage( + transportService, + mock(ClusterService.class), + mock(ThreadPool.class), + mock(ActionFilters.class), + mock(IndexNameExpressionResolver.class), + mock(Client.class), + mock(CircuitBreakerService.class) + ); + + assertTrue(action.handleDownloadInProgress(modelId, true, ActionListener.noop())); + verify(taskManager).registerRemovedTaskListener(any()); + assertThat(action.taskRemovedListenersByModelId.entrySet(), hasSize(1)); + assertThat(action.taskRemovedListenersByModelId.get(modelId), hasSize(1)); + + // With wait for completion == false no new removed listener will be added + assertTrue(action.handleDownloadInProgress(modelId, false, ActionListener.noop())); + verify(taskManager, times(1)).registerRemovedTaskListener(any()); + assertThat(action.taskRemovedListenersByModelId.entrySet(), hasSize(1)); + assertThat(action.taskRemovedListenersByModelId.get(modelId), hasSize(1)); + + assertFalse(action.handleDownloadInProgress("no-task-for-this-one", randomBoolean(), ActionListener.noop())); + } + private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws Exception { var esStatusException = new 
ElasticsearchStatusException(message, status, exception); @@ -152,7 +187,7 @@ private void assertNotificationAndOnFailure( ) throws Exception { var client = mock(Client.class); var taskManager = mock(TaskManager.class); - var task = mock(Task.class); + var task = mock(ModelDownloadTask.class); ModelImporter uploader = createUploader(thrownException); var failureRef = new AtomicReference(); @@ -160,7 +195,14 @@ private void assertNotificationAndOnFailure( (AcknowledgedResponse response) -> { fail("received a acknowledged response: " + response.toString()); }, failureRef::set ); - TransportLoadTrainedModelPackage.importModel(client, taskManager, createRequestWithWaiting(), uploader, listener, task); + TransportLoadTrainedModelPackage.importModel( + client, + () -> taskManager.unregister(task), + createRequestWithWaiting(), + uploader, + task, + listener + ); var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class); // 2 notifications- the starting message and the failure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index 5fd70ce71cd24..f01372ca4f246 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -190,11 +190,11 @@ protected void masterOperation( () -> "[" + request.getDeploymentId() + "] creating new assignment for model [" + request.getModelId() + "] failed", e ); - if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException resourceAlreadyExistsException) { e = new ElasticsearchStatusException( "Cannot start deployment [{}] because it has already 
been started", RestStatus.CONFLICT, - e, + resourceAlreadyExistsException, request.getDeploymentId() ); }