From 400b7ecb14c9b10c80173c91358b9fc08138d8bc Mon Sep 17 00:00:00 2001 From: Iraklis Psaroudakis Date: Mon, 3 Apr 2023 15:54:18 +0300 Subject: [PATCH] Child requests proactively cancel children tasks (#92588) To make this possible we modify the CancellableTasksTracker to track children tasks by the Request ID as well. That way, we can send an Action to cancel a child based on the parent task and the Request ID. This is especially useful when parents' children requests timeout on the parents' side. Fixes #90353 Relates #66992 --- docs/changelog/92588.yaml | 6 + .../node/tasks/CancellableTasksIT.java | 108 ++++++++++++----- .../TransportReplicationAction.java | 10 ++ .../cluster/service/MasterService.java | 3 + .../PersistentTasksNodeService.java | 5 + .../tasks/CancellableTasksTracker.java | 112 ++++++++++++------ .../elasticsearch/tasks/TaskAwareRequest.java | 12 ++ .../tasks/TaskCancellationService.java | 74 ++++++++++++ .../org/elasticsearch/tasks/TaskManager.java | 61 ++++++++-- .../transport/InboundHandler.java | 2 + .../transport/TransportRequest.java | 18 +++ .../transport/TransportService.java | 19 ++- .../tasks/BanFailureLoggingTests.java | 5 + .../tasks/CancellableTasksTrackerTests.java | 10 +- .../elasticsearch/tasks/TaskManagerTests.java | 9 ++ .../transport/TransportActionProxyTests.java | 4 + ...ortServiceDeserializationFailureTests.java | 5 + .../action/InternalExecutePolicyAction.java | 5 + .../xpack/enrich/EnrichPolicyRunnerTests.java | 6 + .../TrainedModelAssignmentNodeService.java | 5 + .../InferencePyTorchActionTests.java | 3 + .../ql/async/AsyncTaskManagementService.java | 10 ++ .../rest-api-spec/test/10_analyze.yml | 4 +- .../blobstore/testkit/BlobAnalyzeAction.java | 31 +++-- .../testkit/GetBlobChecksumAction.java | 3 +- 25 files changed, 435 insertions(+), 95 deletions(-) create mode 100644 docs/changelog/92588.yaml diff --git a/docs/changelog/92588.yaml b/docs/changelog/92588.yaml new file mode 100644 index 0000000000000..0447207b398b7 --- /dev/null +++ b/docs/changelog/92588.yaml @@ -0,0 +1,6 @@ +pr: 92588 +summary: Failed tasks proactively cancel children tasks +area: Snapshot/Restore +type: enhancement +issues: + - 90353 diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java index 3bcd7626a5f02..7aff274009588 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java @@ -34,6 +34,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.tasks.CancellableTask; @@ -44,8 +45,11 @@ import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.threadpool.ThreadPoolStats; +import org.elasticsearch.transport.ReceiveTimeoutTransportException; import org.elasticsearch.transport.SendRequestTransportException; import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; import org.junit.After; @@ -63,6 +67,7 @@ import java.util.stream.Collectors; import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsStringIgnoringCase; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -77,24 +82,25 @@ public class CancellableTasksIT extends ESIntegTestCase { static final Map completedLatches = ConcurrentCollections.newConcurrentMap(); @After - public void ensureAllBansRemoved() throws Exception { + public void ensureBansAndCancellationsConsistency() throws Exception { assertBusy(() -> { for (String node : internalCluster().getNodeNames()) { TaskManager taskManager = internalCluster().getInstance(TransportService.class, node).getTaskManager(); assertThat("node " + node, taskManager.getBannedTaskIds(), empty()); + assertThat("node " + node, taskManager.assertCancellableTaskConsistency(), equalTo(true)); } }, 30, TimeUnit.SECONDS); } - static TestRequest generateTestRequest(Set nodes, int level, int maxLevel) { + static TestRequest generateTestRequest(Set nodes, int level, int maxLevel, boolean timeout) { List subRequests = new ArrayList<>(); int lower = level == 0 ? 1 : 0; int upper = 10 / (level + 1); int numOfSubRequests = randomIntBetween(lower, upper); for (int i = 0; i < numOfSubRequests && level <= maxLevel; i++) { - subRequests.add(generateTestRequest(nodes, level + 1, maxLevel)); + subRequests.add(generateTestRequest(nodes, level + 1, maxLevel, timeout)); } - final TestRequest request = new TestRequest(idGenerator++, randomFrom(nodes), subRequests); + final TestRequest request = new TestRequest(idGenerator++, randomFrom(nodes), subRequests, level == 0 ? false : timeout); beforeSendLatches.put(request, new CountDownLatch(1)); arrivedLatches.put(request, new CountDownLatch(1)); beforeExecuteLatches.put(request, new CountDownLatch(1)); @@ -157,7 +163,7 @@ public void testBanOnlyNodesWithOutstandingDescendantTasks() throws Exception { internalCluster().startNodes(randomIntBetween(1, 3)); } Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4)); + final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4), false); ActionFuture rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); Set pendingRequests = allowPartialRequest(rootRequest); TaskId rootTaskId = getRootTaskId(rootRequest); @@ -203,14 +209,14 @@ public void testBanOnlyNodesWithOutstandingDescendantTasks() throws Exception { } finally { allowEntireRequest(rootRequest); cancelFuture.actionGet(); - waitForRootTask(rootTaskFuture); - ensureAllBansRemoved(); + waitForRootTask(rootTaskFuture, false); + ensureBansAndCancellationsConsistency(); } } public void testCancelTaskMultipleTimes() throws Exception { Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - TestRequest rootRequest = generateTestRequest(nodes, 0, randomIntBetween(1, 3)); + TestRequest rootRequest = generateTestRequest(nodes, 0, randomIntBetween(1, 3), false); ActionFuture mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); TaskId taskId = getRootTaskId(rootRequest); allowPartialRequest(rootRequest); @@ -227,7 +233,7 @@ public void testCancelTaskMultipleTimes() throws Exception { allowEntireRequest(rootRequest); assertThat(cancelFuture.actionGet().getTaskFailures(), empty()); assertThat(cancelFuture.actionGet().getTaskFailures(), empty()); - waitForRootTask(mainTaskFuture); + waitForRootTask(mainTaskFuture, false); CancelTasksResponse cancelError = client().admin() .cluster() .prepareCancelTasks() @@ -237,12 +243,12 @@ public void testCancelTaskMultipleTimes() throws Exception { assertThat(cancelError.getNodeFailures(), hasSize(1)); final Throwable notFound = ExceptionsHelper.unwrap(cancelError.getNodeFailures().get(0), ResourceNotFoundException.class); assertThat(notFound.getMessage(), equalTo("task [" + taskId + "] is not found")); - ensureAllBansRemoved(); + ensureBansAndCancellationsConsistency(); } public void testDoNotWaitForCompletion() throws Exception { Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3)); + TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3), false); ActionFuture mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); TaskId taskId = getRootTaskId(rootRequest); if (randomBoolean()) { @@ -261,34 +267,34 @@ public void testDoNotWaitForCompletion() throws Exception { assertBusy(() -> assertTrue(cancelFuture.isDone())); } allowEntireRequest(rootRequest); - waitForRootTask(mainTaskFuture); + waitForRootTask(mainTaskFuture, false); cancelFuture.actionGet(); - ensureAllBansRemoved(); + ensureBansAndCancellationsConsistency(); } public void testFailedToStartChildTaskAfterCancelled() throws Exception { Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3)); + TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3), false); ActionFuture rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); TaskId taskId = getRootTaskId(rootRequest); client().admin().cluster().prepareCancelTasks().setTargetTaskId(taskId).waitForCompletion(false).get(); DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get(); TransportTestAction mainAction = internalCluster().getInstance(TransportTestAction.class, nodeWithParentTask.getName()); PlainActionFuture future = new PlainActionFuture<>(); - TestRequest subRequest = generateTestRequest(nodes, 0, between(0, 1)); + TestRequest subRequest = generateTestRequest(nodes, 0, between(0, 1), false); beforeSendLatches.get(subRequest).countDown(); mainAction.startSubTask(taskId, subRequest, future); TaskCancelledException te = expectThrows(TaskCancelledException.class, future::actionGet); assertThat(te.getMessage(), equalTo("parent task was cancelled [by user request]")); allowEntireRequest(rootRequest); - waitForRootTask(rootTaskFuture); - ensureAllBansRemoved(); + waitForRootTask(rootTaskFuture, false); + ensureBansAndCancellationsConsistency(); } public void testCancelOrphanedTasks() throws Exception { final String nodeWithRootTask = internalCluster().startDataOnlyNode(); Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3)); + TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3), false); client(nodeWithRootTask).execute(TransportTestAction.ACTION, rootRequest); allowPartialRequest(rootRequest); try { @@ -307,13 +313,13 @@ public void testCancelOrphanedTasks() throws Exception { }, 30, TimeUnit.SECONDS); } finally { allowEntireRequest(rootRequest); - ensureAllBansRemoved(); + ensureBansAndCancellationsConsistency(); } } public void testRemoveBanParentsOnDisconnect() throws Exception { Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); - final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4)); + final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4), false); client().execute(TransportTestAction.ACTION, rootRequest); Set pendingRequests = allowPartialRequest(rootRequest); TaskId rootTaskId = getRootTaskId(rootRequest); @@ -367,10 +373,28 @@ public void testRemoveBanParentsOnDisconnect() throws Exception { } finally { allowEntireRequest(rootRequest); cancelFuture.actionGet(); - ensureAllBansRemoved(); + ensureBansAndCancellationsConsistency(); } } + public void testChildrenTasksCancelledOnTimeout() throws Exception { + Set nodes = clusterService().state().nodes().stream().collect(Collectors.toSet()); + final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4), true); + ActionFuture rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); + allowEntireRequest(rootRequest); + waitForRootTask(rootTaskFuture, true); + assertBusy(() -> { + for (DiscoveryNode node : nodes) { + TransportService transportService = internalCluster().getInstance(TransportService.class, node.getName()); + for (ThreadPoolStats.Stats stat : transportService.getThreadPool().stats()) { + assertEquals(0, stat.getActive()); + assertEquals(0, stat.getQueue()); + } + } + }, 60L, TimeUnit.SECONDS); + ensureBansAndCancellationsConsistency(); + } + static TaskId getRootTaskId(TestRequest request) throws Exception { SetOnce taskId = new SetOnce<>(); assertBusy(() -> { @@ -390,19 +414,24 @@ static TaskId getRootTaskId(TestRequest request) throws Exception { return taskId.get(); } - static void waitForRootTask(ActionFuture rootTask) { + static void waitForRootTask(ActionFuture rootTask, boolean expectToTimeout) { try { rootTask.actionGet(); } catch (Exception e) { - final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class); + final Throwable cause = ExceptionsHelper.unwrap( + e, + expectToTimeout ? ReceiveTimeoutTransportException.class : TaskCancelledException.class + ); assertNotNull(cause); assertThat( cause.getMessage(), - anyOf( - equalTo("parent task was cancelled [by user request]"), - equalTo("task cancelled before starting [by user request]"), - equalTo("task cancelled [by user request]") - ) + expectToTimeout + ? containsStringIgnoringCase("timed out after") + : anyOf( + equalTo("parent task was cancelled [by user request]"), + equalTo("task cancelled before starting [by user request]"), + equalTo("task cancelled [by user request]") + ) ); } } @@ -411,11 +440,13 @@ static class TestRequest extends ActionRequest { final int id; final DiscoveryNode node; final List subRequests; + final boolean timeout; - TestRequest(int id, DiscoveryNode node, List subRequests) { + TestRequest(int id, DiscoveryNode node, List subRequests, boolean timeout) { this.id = id; this.node = node; this.subRequests = subRequests; + this.timeout = timeout; } TestRequest(StreamInput in) throws IOException { @@ -423,6 +454,7 @@ static class TestRequest extends ActionRequest { this.id = in.readInt(); this.node = new DiscoveryNode(in); this.subRequests = in.readList(TestRequest::new); + this.timeout = in.readBoolean(); } List descendants() { @@ -445,6 +477,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(id); node.writeTo(out); out.writeList(subRequests); + out.writeBoolean(timeout); } @Override @@ -513,7 +546,16 @@ protected void doExecute(Task task, TestRequest request, ActionListener { assertTrue(beforeExecuteLatches.get(request).await(60, TimeUnit.SECONDS)); - ((CancellableTask) task).ensureNotCancelled(); + if (request.timeout) { + // Simulate working until cancelled + while (((CancellableTask) task).isCancelled() == false) { + try { + Thread.sleep(1); + } catch (InterruptedException e) {} + } + } else { + ((CancellableTask) task).ensureNotCancelled(); + } return new TestResponse(); })); for (TestRequest subRequest : subRequests) { @@ -535,17 +577,21 @@ public void onFailure(Exception e) { @Override protected void doRun() throws Exception { assertTrue(beforeSendLatches.get(subRequest).await(60, TimeUnit.SECONDS)); - if (client.getLocalNodeId().equals(subRequest.node.getId()) && randomBoolean()) { + if (client.getLocalNodeId().equals(subRequest.node.getId()) && subRequest.timeout == false && randomBoolean()) { try { client.executeLocally(TransportTestAction.ACTION, subRequest, latchedListener); } catch (TaskCancelledException e) { latchedListener.onFailure(new SendRequestTransportException(subRequest.node, ACTION.name(), e)); } } else { + final TransportRequestOptions transportRequestOptions = subRequest.timeout + ? TransportRequestOptions.timeout(TimeValue.timeValueMillis(400)) + : TransportRequestOptions.EMPTY; transportService.sendRequest( subRequest.node, ACTION.name(), subRequest, + transportRequestOptions, new ActionListenerResponseHandler(latchedListener, TestResponse::new) ); } diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java index 26c4f1ded8cfc..c639afb7fc5cf 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportReplicationAction.java @@ -1335,6 +1335,16 @@ public TaskId getParentTask() { return request.getParentTask(); } + @Override + public void setRequestId(long requestId) { + request.setRequestId(requestId); + } + + @Override + public long getRequestId() { + return request.getRequestId(); + } + @Override public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { return request.createTask(id, type, action, parentTaskId, headers); diff --git a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java index 7397bfff39ca9..5dfd60274225e 100644 --- a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java +++ b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java @@ -254,6 +254,9 @@ private void executeAndPublishBatch( @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java index f48ad3c856fef..0b4d9a0b80ceb 100644 --- a/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java +++ b/server/src/main/java/org/elasticsearch/persistent/PersistentTasksNodeService.java @@ -175,6 +175,11 @@ public void setParentTask(TaskId taskId) { throw new UnsupportedOperationException("parent task if for persistent tasks shouldn't change"); } + @Override + public void setRequestId(long requestId) { + throw new UnsupportedOperationException("does not have a request ID"); + } + @Override public TaskId getParentTask() { return parentTaskId; diff --git a/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java b/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java index c1723e492dde3..a44b653d66fb2 100644 --- a/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java +++ b/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java @@ -31,21 +31,45 @@ public CancellableTasksTracker(T[] empty) { } private final Map byTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); - private final Map byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); + private final Map> byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); + + /** + * Gets the cancellable children of a parent task. + * + * Note: children of non-positive request IDs (e.g., -1) may be grouped together. + */ + public Stream getChildrenByRequestId(TaskId parentTaskId, long childRequestId) { + Map byRequestId = byParentTaskId.get(parentTaskId); + if (byRequestId != null) { + T[] children = byRequestId.get(childRequestId); + if (children != null) { + return Arrays.stream(children); + } + } + return Stream.empty(); + } /** * Add an item for the given task. Should only be called once for each task, and {@code item} must be unique per task too. */ - public void put(Task task, T item) { + public void put(Task task, long requestId, T item) { final long taskId = task.getId(); if (task.getParentTaskId().isSet()) { - byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> { - if (oldValue == null) { - oldValue = empty; + byParentTaskId.compute(task.getParentTaskId(), (taskKey, oldRequestIdMap) -> { + if (oldRequestIdMap == null) { + oldRequestIdMap = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); } - final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1); - newValue[oldValue.length] = item; - return newValue; + + oldRequestIdMap.compute(requestId, (requestIdKey, oldValue) -> { + if (oldValue == null) { + oldValue = empty; + } + final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1); + newValue[oldValue.length] = item; + return newValue; + }); + + return oldRequestIdMap; }); } final T oldItem = byTaskId.put(taskId, item); @@ -60,36 +84,50 @@ public T get(long id) { } /** - * Remove (and return) the item that corresponds with the given task. Return {@code null} if not present. Safe to call multiple times - * for each task. However, {@link #getByParent} may return this task even after a call to this method completes, if the removal is - * actually being completed by a concurrent call that's still ongoing. + * Remove (and return) the item that corresponds with the given task and request ID. Return {@code null} if not present. Safe to call + * multiple times for each task. However, {@link #getByParent} may return this task even after a call to this method completes, if + * the removal is actually being completed by a concurrent call that's still ongoing. */ public T remove(Task task) { final long taskId = task.getId(); final T oldItem = byTaskId.remove(taskId); if (oldItem != null && task.getParentTaskId().isSet()) { - byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> { - if (oldValue == null) { + byParentTaskId.compute(task.getParentTaskId(), (taskKey, oldRequestIdMap) -> { + if (oldRequestIdMap == null) { return null; } - if (oldValue.length == 1) { - if (oldValue[0] == oldItem) { - return null; - } else { + + for (Long requestId : oldRequestIdMap.keySet()) { + oldRequestIdMap.compute(requestId, (requestIdKey, oldValue) -> { + if (oldValue == null) { + return null; + } + if (oldValue.length == 1) { + if (oldValue[0] == oldItem) { + return null; + } else { + return oldValue; + } + } + if (oldValue[0] == oldItem) { + return Arrays.copyOfRange(oldValue, 1, oldValue.length); + } + for (int i = 1; i < oldValue.length; i++) { + if (oldValue[i] == oldItem) { + final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1); + System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1); + return newValue; + } + } return oldValue; - } - } - if (oldValue[0] == oldItem) { - return Arrays.copyOfRange(oldValue, 1, oldValue.length); + }); } - for (int i = 1; i < oldValue.length; i++) { - if (oldValue[i] == oldItem) { - final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1); - System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1); - return newValue; - } + + if (oldRequestIdMap.keySet().isEmpty()) { + return null; } - return oldValue; + + return oldRequestIdMap; }); } return oldItem; @@ -109,11 +147,11 @@ public Collection values() { * started before this method was called have not completed. */ public Stream getByParent(TaskId parentTaskId) { - final T[] byParent = byParentTaskId.get(parentTaskId); + final Map byParent = byParentTaskId.get(parentTaskId); if (byParent == null) { return Stream.empty(); } - return Arrays.stream(byParent); + return byParent.values().stream().flatMap(Stream::of); } // assertion for tests, not an invariant but should eventually be true @@ -123,12 +161,14 @@ boolean assertConsistent() { // every by-parent value must be tracked by task too; the converse isn't true since we don't track values without a parent final Set byTaskValues = new HashSet<>(byTaskId.values()); - for (T[] byParent : byParentTaskId.values()) { - assert byParent.length > 0; - for (T t : byParent) { - assert byTaskValues.contains(t); - } - } + byParentTaskId.values().forEach(byParentMap -> { + byParentMap.forEach((requestId, byParentArray) -> { + assert byParentArray.length > 0; + for (T t : byParentArray) { + assert byTaskValues.contains(t); + } + }); + }); return true; } diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java b/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java index d0f7e3565e233..a791066ea5089 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java @@ -26,6 +26,18 @@ default void setParentTask(String parentTaskNode, long parentTaskId) { */ void setParentTask(TaskId taskId); + /** + * Gets the request ID. Defaults to -1, meaning "no request ID is set". + */ + default long getRequestId() { + return -1; + } + + /** + * Set the request ID related to this task. + */ + void setRequestId(long requestId); + /** * Get a reference to the task that created this request. Implementers should default to * {@link TaskId#EMPTY_TASK_ID}, meaning "there is no parent". diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java index 1407d5abb2ce9..9114f6695aaea 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.threadpool.ThreadPool; @@ -44,6 +45,8 @@ public class TaskCancellationService { public static final String BAN_PARENT_ACTION_NAME = "internal:admin/tasks/ban"; + public static final String CANCEL_CHILD_ACTION_NAME = "internal:admin/tasks/cancel_child"; + public static final TransportVersion VERSION_SUPPORTING_CANCEL_CHILD_ACTION = TransportVersion.V_8_8_0; private static final Logger logger = LogManager.getLogger(TaskCancellationService.class); private final TransportService transportService; private final TaskManager taskManager; @@ -59,6 +62,12 @@ public TaskCancellationService(TransportService transportService) { BanParentTaskRequest::new, new BanParentRequestHandler() ); + transportService.registerRequestHandler( + CANCEL_CHILD_ACTION_NAME, + ThreadPool.Names.SAME, + CancelChildRequest::new, + new CancelChildRequestHandler() + ); } private String localNodeId() { @@ -341,4 +350,69 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC } } } + + private static class CancelChildRequest extends TransportRequest { + + private final TaskId parentTaskId; + private final long childRequestId; + private final String reason; + + static CancelChildRequest createCancelChildRequest(TaskId parentTaskId, long childRequestId, String reason) { + return new CancelChildRequest(parentTaskId, childRequestId, reason); + } + + private CancelChildRequest(TaskId parentTaskId, long childRequestId, String reason) { + this.parentTaskId = parentTaskId; + this.childRequestId = childRequestId; + this.reason = reason; + } + + private CancelChildRequest(StreamInput in) throws IOException { + super(in); + parentTaskId = TaskId.readFromStream(in); + childRequestId = in.readLong(); + reason = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + parentTaskId.writeTo(out); + out.writeLong(childRequestId); + out.writeString(reason); + } + } + + private class CancelChildRequestHandler implements TransportRequestHandler { + @Override + public void messageReceived(final CancelChildRequest request, final TransportChannel channel, Task task) throws Exception { + taskManager.cancelChildLocal(request.parentTaskId, request.childRequestId, request.reason); + channel.sendResponse(TransportResponse.Empty.INSTANCE); + } + } + + /** + * Sends an action to cancel a child task, associated with the given request ID and parent task. + */ + public void cancelChildRemote(TaskId parentTask, long childRequestId, Transport.Connection childConnection, String reason) { + if (childConnection.getTransportVersion().onOrAfter(VERSION_SUPPORTING_CANCEL_CHILD_ACTION)) { + DiscoveryNode childNode = childConnection.getNode(); + logger.debug( + "sending cancellation of child of parent task [{}] with request ID [{}] to node [{}] because of [{}]", + parentTask, + childRequestId, + childNode, + reason + ); + final CancelChildRequest request = CancelChildRequest.createCancelChildRequest(parentTask, childRequestId, reason); + transportService.sendRequest( + childNode, + CANCEL_CHILD_ACTION_NAME, + request, + TransportRequestOptions.EMPTY, + EmptyTransportResponseHandler.INSTANCE_SAME + ); + } + } + } diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java index 2a3c22c353da8..8a3d8f5ec9184 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -151,7 +151,7 @@ public Task register(String type, String action, TaskAwareRequest request, boole } if (task instanceof CancellableTask) { - registerCancellableTask(task, traceRequest); + registerCancellableTask(task, request.getRequestId(), traceRequest); } else { Task previousTask = tasks.put(task.getId(), task); assert previousTask == null; @@ -209,6 +209,9 @@ public void onResponse(Response response) { @Override public void onFailure(Exception e) { try { + if (request.getParentTask().isSet()) { + cancelChildLocal(request.getParentTask(), request.getRequestId(), e.toString()); + } release(); } finally { taskListener.onFailure(e); @@ -228,10 +231,10 @@ private void release() { } } - private void registerCancellableTask(Task task, boolean traceRequest) { + private void registerCancellableTask(Task task, long requestId, boolean traceRequest) { CancellableTask cancellableTask = (CancellableTask) task; CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask); - cancellableTasks.put(task, holder); + cancellableTasks.put(task, requestId, holder); if (traceRequest) { startTrace(threadPool.getThreadContext(), task); } @@ -250,6 +253,16 @@ private void registerCancellableTask(Task task, boolean traceRequest) { } } + private TaskCancellationService getCancellationService() { + final TaskCancellationService service = cancellationService.get(); + if (service != null) { + return service; + } else { + assert false : "TaskCancellationService is not initialized"; + throw new IllegalStateException("TaskCancellationService is not initialized"); + } + } + /** * Cancels a task *

@@ -267,6 +280,40 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { } } + /** + * Cancels children tasks of the specified parent, with the request ID specified, as long as the request ID is positive. + * + * Note: There may be multiple children for the same request ID. In this edge case all these multiple children are cancelled. + */ + public void cancelChildLocal(TaskId parentTaskId, long childRequestId, String reason) { + if (childRequestId > 0) { + List children = cancellableTasks.getChildrenByRequestId(parentTaskId, childRequestId).toList(); + if (children.isEmpty() == false) { + for (CancellableTaskHolder child : children) { + if (logger.isTraceEnabled()) { + logger.trace( + "cancelling child task [{}] of parent task [{}] and request ID [{}] with reason [{}]", + child.getTask(), + parentTaskId, + childRequestId, + reason + ); + } + child.cancel(reason); + } + } + } + } + + /** + * Send an Action to cancel children tasks of the specified parent, with the request ID specified. + * + * Note: There may be multiple children for the same request ID. In this edge case all these multiple children are cancelled. + */ + public void cancelChildRemote(TaskId parentTask, long childRequestId, Transport.Connection childConnection, String reason) { + getCancellationService().cancelChildRemote(parentTask, childRequestId, childConnection, reason); + } + /** * Unregister the task */ @@ -769,13 +816,7 @@ protected void doRun() { } public void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) { - final TaskCancellationService service = cancellationService.get(); - if (service != null) { - service.cancelTaskAndDescendants(task, reason, waitForCompletion, listener); - } else { - assert false : "TaskCancellationService is not initialized"; - throw new IllegalStateException("TaskCancellationService is not initialized"); - } + getCancellationService().cancelTaskAndDescendants(task, reason, waitForCompletion, listener); } public List getTaskHeaders() { diff --git a/server/src/main/java/org/elasticsearch/transport/InboundHandler.java b/server/src/main/java/org/elasticsearch/transport/InboundHandler.java index 503d2c3be81e1..94158fb4a7f1e 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundHandler.java @@ -277,6 +277,8 @@ private void handleRequest(TcpChannel channel, Head } try { request.remoteAddress(channel.getRemoteAddress()); + assert requestId > 0; + request.setRequestId(requestId); // in case we throw an exception, i.e. when the limit is hit, we don't want to verify final int nextByte = stream.read(); // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker diff --git a/server/src/main/java/org/elasticsearch/transport/TransportRequest.java b/server/src/main/java/org/elasticsearch/transport/TransportRequest.java index 094d441d8a1c8..7646703faaa70 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportRequest.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportRequest.java @@ -31,6 +31,11 @@ public Empty(StreamInput in) throws IOException { */ private TaskId parentTaskId = TaskId.EMPTY_TASK_ID; + /** + * Request ID. Defaults to -1, meaning "no request ID is set". + */ + private volatile long requestId = -1; + public TransportRequest() {} public TransportRequest(StreamInput in) throws IOException { @@ -53,6 +58,19 @@ public TaskId getParentTask() { return parentTaskId; } + /** + * Set the request ID of this request. + */ + @Override + public void setRequestId(long requestId) { + this.requestId = requestId; + } + + @Override + public long getRequestId() { + return requestId; + } + @Override public void writeTo(StreamOutput out) throws IOException { parentTaskId.writeTo(out); diff --git a/server/src/main/java/org/elasticsearch/transport/TransportService.java b/server/src/main/java/org/elasticsearch/transport/TransportService.java index b22f4436e9338..ed82b670b1c2e 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportService.java @@ -793,7 +793,14 @@ public final void sendRequest( if (unregisterChildNode == null) { delegate = handler; } else { - delegate = new UnregisterChildTransportResponseHandler<>(unregisterChildNode, handler, action); + delegate = new UnregisterChildTransportResponseHandler<>( + unregisterChildNode, + handler, + action, + request, + unwrappedConn, + taskManager + ); } } else { delegate = handler; @@ -895,6 +902,7 @@ private void sendRequestInternal( ContextRestoreResponseHandler responseHandler = new ContextRestoreResponseHandler<>(storedContextSupplier, handler); // TODO we can probably fold this entire request ID dance into connection.sendRequest but it will be a bigger refactoring final long requestId = responseHandlers.add(new Transport.ResponseContext<>(responseHandler, connection, action)); + request.setRequestId(requestId); final TimeoutHandler timeoutHandler; if (options.timeout() != null) { timeoutHandler = new TimeoutHandler(requestId, connection.getNode(), action); @@ -915,6 +923,7 @@ private void sendRequestInternal( assert options.timeout() != null; timeoutHandler.scheduleTimeout(options.timeout()); } + logger.trace("sending internal request id [{}] action [{}] request [{}] options [{}]", requestId, action, request, options); connection.sendRequest(requestId, action, request, options); // local node optimization happens upstream } catch (final Exception e) { handleInternalSendException(action, node, requestId, timeoutHandler, e); @@ -1654,7 +1663,10 @@ Releasable withRef() { private record UnregisterChildTransportResponseHandler ( Releasable unregisterChildNode, TransportResponseHandler handler, - String action + String action, + TransportRequest childRequest, + Transport.Connection childConnection, + TaskManager taskManager ) implements TransportResponseHandler { @Override @@ -1665,6 +1677,9 @@ public void handleResponse(T response) { @Override public void handleException(TransportException exp) { + assert childRequest.getParentTask().isSet(); + taskManager.cancelChildRemote(childRequest.getParentTask(), childRequest.getRequestId(), childConnection, exp.toString()); + unregisterChildNode.close(); handler.handleException(exp); } diff --git a/server/src/test/java/org/elasticsearch/tasks/BanFailureLoggingTests.java b/server/src/test/java/org/elasticsearch/tasks/BanFailureLoggingTests.java index 843b05f7b8a94..644d6bcb901ad 100644 --- a/server/src/test/java/org/elasticsearch/tasks/BanFailureLoggingTests.java +++ b/server/src/test/java/org/elasticsearch/tasks/BanFailureLoggingTests.java @@ -205,6 +205,11 @@ public void setParentTask(TaskId taskId) { fail("setParentTask should not be called"); } + @Override + public void setRequestId(long requestId) { + fail("setRequestId should not be called"); + } + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java b/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java index 7c29405ec7e0d..da40e307bce28 100644 --- a/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java +++ b/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java @@ -35,6 +35,7 @@ private static class TestTask { // 0 == before put, 1 == during put, 2 == after put, before remove, 3 == during remove, 4 == after remove private final AtomicInteger state = new AtomicInteger(); private final boolean concurrentRemove = randomBoolean(); + private final long requestId = randomIntBetween(-1, 10); TestTask(Task task, String item, CancellableTasksTracker tracker, Runnable awaitStart) { if (concurrentRemove) { @@ -58,7 +59,7 @@ private static class TestTask { awaitStart.run(); state.incrementAndGet(); - tracker.put(task, item); + tracker.put(task, requestId, item); state.incrementAndGet(); Thread.yield(); @@ -80,6 +81,8 @@ private static class TestTask { final int stateBefore = state.get(); final String getResult = tracker.get(task.getId()); final Set getByParentResult = tracker.getByParent(task.getParentTaskId()).collect(Collectors.toSet()); + final Set getByChildrenResult = tracker.getChildrenByRequestId(task.getParentTaskId(), requestId) + .collect(Collectors.toSet()); final Set values = new HashSet<>(tracker.values()); final int stateAfter = state.get(); @@ -87,11 +90,13 @@ private static class TestTask { if (getResult != null && task.getParentTaskId().isSet() && tracker.get(task.getId()) != null) { assertThat(getByParentResult, hasItem(item)); + assertThat(getByChildrenResult, hasItem(item)); } if (stateAfter == 0) { assertNull(getResult); assertThat(getByParentResult, not(hasItem(item))); + assertThat(getByChildrenResult, not(hasItem(item))); assertThat(values, not(hasItem(item))); } @@ -99,8 +104,10 @@ private static class TestTask { assertSame(item, getResult); if (task.getParentTaskId().isSet()) { assertThat(getByParentResult, hasItem(item)); + assertThat(getByChildrenResult, hasItem(item)); } else { assertThat(getByParentResult, empty()); + assertThat(getByChildrenResult, empty()); } assertThat(values, hasItem(item)); } @@ -109,6 +116,7 @@ private static class TestTask { assertNull(getResult); if (concurrentRemove == false) { assertThat(getByParentResult, not(hasItem(item))); + assertThat(getByChildrenResult, not(hasItem(item))); } // else our remove might have completed but the concurrent one hasn't updated the parent ID map yet assertThat(values, not(hasItem(item))); } diff --git a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java index a00801c4501d7..ae8bdf08a5ede 100644 --- a/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java +++ b/server/src/test/java/org/elasticsearch/tasks/TaskManagerTests.java @@ -288,6 +288,9 @@ public void testRegisterTaskStartsTracing() { @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; @@ -309,6 +312,9 @@ public void testUnregisterTaskStopsTracing() { @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; @@ -476,6 +482,9 @@ private TaskAwareRequest makeTaskRequest(boolean cancellable, final int parentTa @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return new TaskId("something", parentTaskNum); diff --git a/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java b/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java index 7a7cdd521a8e9..7901e7c31cece 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.core.RefCounted; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancellationService; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.MockTransportService; @@ -68,10 +69,13 @@ public void setUp() throws Exception { super.setUp(); threadPool = new TestThreadPool(getClass().getName()); serviceA = buildService(version0, transportVersion0); // this one supports dynamic tracer updates + serviceA.taskManager.setTaskCancellationService(new TaskCancellationService(serviceA)); nodeA = serviceA.getLocalDiscoNode(); serviceB = buildService(version1, transportVersion1); // this one doesn't support dynamic tracer updates + serviceB.taskManager.setTaskCancellationService(new TaskCancellationService(serviceB)); nodeB = serviceB.getLocalDiscoNode(); serviceC = buildService(version1, transportVersion1); // this one doesn't support dynamic tracer updates + serviceC.taskManager.setTaskCancellationService(new TaskCancellationService(serviceC)); nodeC = serviceC.getLocalDiscoNode(); serviceD = buildService(version1, transportVersion1); nodeD = serviceD.getLocalDiscoNode(); diff --git a/server/src/test/java/org/elasticsearch/transport/TransportServiceDeserializationFailureTests.java b/server/src/test/java/org/elasticsearch/transport/TransportServiceDeserializationFailureTests.java index ec1944a65519b..4cfda499f028c 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportServiceDeserializationFailureTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportServiceDeserializationFailureTests.java @@ -125,6 +125,11 @@ public void setParentTask(TaskId taskId) { fail("should not be called"); } + @Override + public void setRequestId(long requestId) { + fail("should not be called"); + } + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/InternalExecutePolicyAction.java b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/InternalExecutePolicyAction.java index feeba1a4c3ccf..e99b787926361 100644 --- a/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/InternalExecutePolicyAction.java +++ b/x-pack/plugin/enrich/src/main/java/org/elasticsearch/xpack/enrich/action/InternalExecutePolicyAction.java @@ -142,6 +142,11 @@ public void setParentTask(TaskId taskId) { request.setParentTask(taskId); } + @Override + public void setRequestId(long requestId) { + request.setRequestId(requestId); + } + @Override public TaskId getParentTask() { return request.getParentTask(); diff --git a/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunnerTests.java b/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunnerTests.java index 5792e6e123ef5..67f7b8498ecd6 100644 --- a/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunnerTests.java +++ b/x-pack/plugin/enrich/src/test/java/org/elasticsearch/xpack/enrich/EnrichPolicyRunnerTests.java @@ -1781,6 +1781,9 @@ public void testRunnerWithForceMergeRetry() throws Exception { @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; @@ -2026,6 +2029,9 @@ private EnrichPolicyRunner createPolicyRunner( @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index 8a1f818ed22f6..2d6045534bfd3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -299,6 +299,11 @@ public void setParentTask(TaskId taskId) { throw new UnsupportedOperationException("parent task id for model assignment tasks shouldn't change"); } + @Override + public void setRequestId(long requestId) { + throw new UnsupportedOperationException("does not have request ID"); + } + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java index aa7831bcc03f1..181b6abd5b549 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/InferencePyTorchActionTests.java @@ -194,6 +194,9 @@ public void testCallingRunAfterParentTaskCancellation() throws Exception { @Override public void setParentTask(TaskId taskId) {} + @Override + public void setRequestId(long requestId) {} + @Override public TaskId getParentTask() { return TaskId.EMPTY_TASK_ID; diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java index 494533d2c061f..80e2a265c71c5 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java @@ -107,6 +107,16 @@ public TaskId getParentTask() { return request.getParentTask(); } + @Override + public void setRequestId(long requestId) { + request.setRequestId(requestId); + } + + @Override + public long getRequestId() { + return request.getRequestId(); + } + @Override public Task createTask(long id, String type, String actionName, TaskId parentTaskId, Map headers) { Map originHeaders = ClientHelper.getPersistableSafeSecurityHeaders( diff --git a/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml b/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml index 25c1ed26bcdc7..6223ca8443b0e 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml +++ b/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml @@ -158,8 +158,8 @@ setup: --- "Timeout with large blobs": - skip: - version: all - reason: "AwaitsFix https://github.com/elastic/elasticsearch/issues/90353" + version: "- 7.13.99" + reason: "abortWrites flag introduced in 7.14, and mixed-cluster support not required" - do: catch: request diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/BlobAnalyzeAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/BlobAnalyzeAction.java index 65acc3b7e2f14..c82e2991d9f79 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/BlobAnalyzeAction.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/BlobAnalyzeAction.java @@ -32,6 +32,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.CancellableThreads; import org.elasticsearch.core.Nullable; import org.elasticsearch.repositories.RepositoriesService; import org.elasticsearch.repositories.Repository; @@ -39,7 +40,6 @@ import org.elasticsearch.repositories.blobstore.BlobStoreRepository; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskAwareRequest; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportRequestOptions; @@ -216,6 +216,7 @@ static class BlobAnalysis { private final GroupedActionListener readNodesListener; private final StepListener write1Step = new StepListener<>(); private final StepListener write2Step = new StepListener<>(); + private final CancellableThreads cancellableThreads = new CancellableThreads(); BlobAnalysis( TransportService transportService, @@ -271,6 +272,8 @@ static class BlobAnalysis { ), this::cancelReadsCleanUpAndReturnFailure ); + + task.addListener(() -> { cancellableThreads.cancel(task.getReasonCancelled()); }); } void run() { @@ -332,15 +335,21 @@ public StreamInput streamInput() throws IOException { blobContainer.writeBlob(request.blobName, bytesReference, failIfExists); } } else { - blobContainer.writeBlob( - request.blobName, - repository.maybeRateLimitSnapshots( - new RandomBlobContentStream(content, request.getTargetLength()), - throttledNanos::addAndGet - ), - request.targetLength, - failIfExists - ); + cancellableThreads.execute(() -> { + try { + blobContainer.writeBlob( + request.blobName, + repository.maybeRateLimitSnapshots( + new RandomBlobContentStream(content, request.getTargetLength()), + throttledNanos::addAndGet + ), + request.targetLength, + failIfExists + ); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); } final long elapsedNanos = System.nanoTime() - startNanos; final long checksum = content.getChecksum(checksumStart, checksumEnd); @@ -621,7 +630,7 @@ private WriteDetails(long bytesWritten, long elapsedNanos, long throttledNanos, } } - public static class Request extends ActionRequest implements TaskAwareRequest { + public static class Request extends ActionRequest { private final String repositoryName; private final String blobPath; private final String blobName; diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/GetBlobChecksumAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/GetBlobChecksumAction.java index 14e760875c9c7..96828fd5a4c04 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/GetBlobChecksumAction.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/GetBlobChecksumAction.java @@ -28,7 +28,6 @@ import org.elasticsearch.repositories.blobstore.BlobStoreRepository; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskAwareRequest; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -186,7 +185,7 @@ protected void doExecute(Task task, Request request, ActionListener li } - public static class Request extends ActionRequest implements TaskAwareRequest { + public static class Request extends ActionRequest { private final String repositoryName; private final String blobPath;