From 96bb1164f03a4b9ef19e966680b11f8ad123231f Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Mon, 6 Apr 2020 12:00:02 -0400 Subject: [PATCH] Support hierarchical task cancellation (#54757) With this change, when a task is canceled, the task manager will cancel not only its direct child tasks but also all of its descendant tasks. Closes #50990 --- docs/reference/cluster/tasks.asciidoc | 2 +- .../rest-api-spec/api/tasks.cancel.json | 2 +- .../node/tasks/cancel/CancelTasksRequest.java | 2 +- .../cancel/TransportCancelTasksAction.java | 69 ++- .../org/elasticsearch/tasks/TaskManager.java | 22 +- .../node/tasks/CancellableTasksIT.java | 451 +++++++++--------- 6 files changed, 288 insertions(+), 260 deletions(-) diff --git a/docs/reference/cluster/tasks.asciidoc b/docs/reference/cluster/tasks.asciidoc index 197da4a09ef56..af8f600c10df7 100644 --- a/docs/reference/cluster/tasks.asciidoc +++ b/docs/reference/cluster/tasks.asciidoc @@ -234,7 +234,7 @@ nodes `nodeId1` and `nodeId2`. `wait_for_completion`:: (Optional, boolean) If `true`, the request blocks until the cancellation of the -task and its child tasks is completed. Otherwise, the request can return soon +task and its descendant tasks is completed. Otherwise, the request can return soon after the cancellation is started. Defaults to `false`. [source,console] diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/tasks.cancel.json b/rest-api-spec/src/main/resources/rest-api-spec/api/tasks.cancel.json index 6c3d424050b50..966197f4e76d4 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/tasks.cancel.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/tasks.cancel.json @@ -42,7 +42,7 @@ }, "wait_for_completion": { "type":"boolean", - "description":"Should the request block until the cancellation of the task and its child tasks is completed. Defaults to false" + "description":"Should the request block until the cancellation of the task and its descendant tasks is completed. 
Defaults to false" } } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequest.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequest.java index 712f6bd65889b..94bff2abefe4b 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequest.java @@ -79,7 +79,7 @@ public String getReason() { } /** - * If {@code true}, the request blocks until the cancellation of the task and its child tasks is completed. + * If {@code true}, the request blocks until the cancellation of the task and its descendant tasks is completed. * Otherwise, the request can return soon after the cancellation is started. Defaults to {@code false}. */ public void setWaitForCompletion(boolean waitForCompletion) { diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/TransportCancelTasksAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/TransportCancelTasksAction.java index b2130dae9e568..1d2300c6996cd 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/TransportCancelTasksAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/TransportCancelTasksAction.java @@ -20,11 +20,13 @@ package org.elasticsearch.action.admin.cluster.node.tasks.cancel; import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.StepListener; import org.elasticsearch.action.TaskOperationFailure; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.ChannelActionListener; import 
org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -104,34 +106,43 @@ protected void processTasks(CancelTasksRequest request, Consumer listener) { String nodeId = clusterService.localNode().getId(); - if (cancellableTask.shouldCancelChildrenOnCancellation()) { + cancelTaskAndDescendants(cancellableTask, request.getReason(), request.waitForCompletion(), + ActionListener.map(listener, r -> cancellableTask.taskInfo(nodeId, false))); + } + + void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener listener) { + if (task.shouldCancelChildrenOnCancellation()) { StepListener completedListener = new StepListener<>(); GroupedActionListener groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3); Collection childrenNodes = - taskManager.startBanOnChildrenNodes(cancellableTask.getId(), () -> groupedListener.onResponse(null)); - taskManager.cancel(cancellableTask, request.getReason(), () -> groupedListener.onResponse(null)); + taskManager.startBanOnChildrenNodes(task.getId(), () -> groupedListener.onResponse(null)); + taskManager.cancel(task, reason, () -> groupedListener.onResponse(null)); StepListener banOnNodesListener = new StepListener<>(); - setBanOnNodes(request.getReason(), cancellableTask, childrenNodes, banOnNodesListener); + setBanOnNodes(reason, waitForCompletion, task, childrenNodes, banOnNodesListener); banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure); // We remove bans after all child tasks are completed although in theory we can do it on a per-node basis. 
- completedListener.whenComplete( - r -> removeBanOnNodes(cancellableTask, childrenNodes), - e -> removeBanOnNodes(cancellableTask, childrenNodes)); - // if wait_for_child_tasks is true, then only return when (1) bans are placed on child nodes, (2) child tasks are + completedListener.whenComplete(r -> removeBanOnNodes(task, childrenNodes), e -> removeBanOnNodes(task, childrenNodes)); + // if wait_for_completion is true, then only return when (1) bans are placed on child nodes, (2) child tasks are // completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes. - if (request.waitForCompletion()) { - completedListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure); + if (waitForCompletion) { + completedListener.whenComplete(r -> listener.onResponse(null), listener::onFailure); } else { - banOnNodesListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure); + banOnNodesListener.whenComplete(r -> listener.onResponse(null), listener::onFailure); } } else { - logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId()); - taskManager.cancel(cancellableTask, request.getReason(), () -> listener.onResponse(cancellableTask.taskInfo(nodeId, false))); + logger.trace("task {} doesn't have any children that should be cancelled", task.getId()); + if (waitForCompletion) { + taskManager.cancel(task, reason, () -> listener.onResponse(null)); + } else { + taskManager.cancel(task, reason, () -> {}); + listener.onResponse(null); + } } } - private void setBanOnNodes(String reason, CancellableTask task, Collection childNodes, ActionListener listener) { + private void setBanOnNodes(String reason, boolean waitForCompletion, CancellableTask task, + Collection childNodes, ActionListener listener) { if (childNodes.isEmpty()) { listener.onResponse(null); return; @@ -140,7 +151,7 @@ private void 
setBanOnNodes(String reason, CancellableTask task, Collection groupedListener = new GroupedActionListener<>(ActionListener.map(listener, r -> null), childNodes.size()); final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest( - new TaskId(clusterService.localNode().getId(), task.getId()), reason); + new TaskId(clusterService.localNode().getId(), task.getId()), reason, waitForCompletion); for (DiscoveryNode node : childNodes) { transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) { @@ -171,26 +182,29 @@ private static class BanParentTaskRequest extends TransportRequest { private final TaskId parentTaskId; private final boolean ban; + private final boolean waitForCompletion; private final String reason; - static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason) { - return new BanParentTaskRequest(parentTaskId, reason); + static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) { + return new BanParentTaskRequest(parentTaskId, reason, waitForCompletion); } static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) { return new BanParentTaskRequest(parentTaskId); } - private BanParentTaskRequest(TaskId parentTaskId, String reason) { + private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) { this.parentTaskId = parentTaskId; this.ban = true; this.reason = reason; + this.waitForCompletion = waitForCompletion; } private BanParentTaskRequest(TaskId parentTaskId) { this.parentTaskId = parentTaskId; this.ban = false; this.reason = null; + this.waitForCompletion = false; } private BanParentTaskRequest(StreamInput in) throws IOException { @@ -198,6 +212,11 @@ private BanParentTaskRequest(StreamInput in) throws IOException { parentTaskId = TaskId.readFromStream(in); ban = in.readBoolean(); reason = ban ? 
in.readString() : null; + if (in.getVersion().onOrAfter(Version.V_7_8_0)) { + waitForCompletion = in.readBoolean(); + } else { + waitForCompletion = false; + } } @Override @@ -208,6 +227,9 @@ public void writeTo(StreamOutput out) throws IOException { if (ban) { out.writeString(reason); } + if (out.getVersion().onOrAfter(Version.V_7_8_0)) { + out.writeBoolean(waitForCompletion); + } } } @@ -217,13 +239,20 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC if (request.ban) { logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId, clusterService.localNode().getId(), request.reason); - taskManager.setBan(request.parentTaskId, request.reason); + final List childTasks = taskManager.setBan(request.parentTaskId, request.reason); + final GroupedActionListener listener = new GroupedActionListener<>(ActionListener.map( + new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request), r -> TransportResponse.Empty.INSTANCE), + childTasks.size() + 1); + for (CancellableTask childTask : childTasks) { + cancelTaskAndDescendants(childTask, request.reason, request.waitForCompletion, listener); + } + listener.onResponse(null); } else { logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId, clusterService.localNode().getId()); taskManager.removeBan(request.parentTaskId); + channel.sendResponse(TransportResponse.Empty.INSTANCE); } - channel.sendResponse(TransportResponse.Empty.INSTANCE); } } diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java index f60d4aa79008a..dc6f126c18adb 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -333,8 +333,9 @@ public int getBanCount() { * Bans all tasks with the specified parent task from execution, cancels all tasks that are currently executing. *

* This method is called when a parent task that has children is cancelled. + * @return a list of pending cancellable child tasks */ - public void setBan(TaskId parentTaskId, String reason) { + public List setBan(TaskId parentTaskId, String reason) { logger.trace("setting ban for the parent task {} {}", parentTaskId, reason); // Set the ban first, so the newly created tasks cannot be registered @@ -344,14 +345,10 @@ public void setBan(TaskId parentTaskId, String reason) { banedParents.put(parentTaskId, reason); } } - - // Now go through already running tasks and cancel them - for (Map.Entry taskEntry : cancellableTasks.entrySet()) { - CancellableTaskHolder holder = taskEntry.getValue(); - if (holder.hasParent(parentTaskId)) { - holder.cancel(reason); - } - } + return cancellableTasks.values().stream() + .filter(t -> t.hasParent(parentTaskId)) + .map(t -> t.task) + .collect(Collectors.toList()); } /** @@ -365,11 +362,8 @@ public void removeBan(TaskId parentTaskId) { } // for testing - public boolean childTasksCancelledOrBanned(TaskId parentTaskId) { - if (banedParents.containsKey(parentTaskId)) { - return true; - } - return cancellableTasks.values().stream().noneMatch(task -> task.hasParent(parentTaskId)); + public Set getBannedTaskIds() { + return Collections.unmodifiableSet(banedParents.keySet()); } /** diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java index f0ccb79876004..ba04992d55af3 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java @@ -41,12 +41,14 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import 
org.elasticsearch.common.util.set.Sets; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -57,97 +59,147 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.StreamSupport; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.either; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; public class CancellableTasksIT extends ESIntegTestCase { - static final Map arrivedLatches = ConcurrentCollections.newConcurrentMap(); - static final Map beforeExecuteLatches = ConcurrentCollections.newConcurrentMap(); - static final Map completedLatches = ConcurrentCollections.newConcurrentMap(); + + static int idGenerator = 0; + static final Map beforeSendLatches = ConcurrentCollections.newConcurrentMap(); + static final Map arrivedLatches = ConcurrentCollections.newConcurrentMap(); + static final Map beforeExecuteLatches = ConcurrentCollections.newConcurrentMap(); + static final Map completedLatches = ConcurrentCollections.newConcurrentMap(); @Before public void resetTestStates() { + idGenerator = 0; + 
beforeSendLatches.clear(); arrivedLatches.clear(); beforeExecuteLatches.clear(); completedLatches.clear(); } - List setupChildRequests(Set nodes) { - int numRequests = randomIntBetween(1, 30); - List childRequests = new ArrayList<>(); - for (int i = 0; i < numRequests; i++) { - ChildRequest req = new ChildRequest(i, randomFrom(nodes)); - childRequests.add(req); - arrivedLatches.put(req, new CountDownLatch(1)); - beforeExecuteLatches.put(req, new CountDownLatch(1)); - completedLatches.put(req, new CountDownLatch(1)); - } - return childRequests; + static TestRequest generateTestRequest(Set nodes, int level, int maxLevel) { + List subRequests = new ArrayList<>(); + int lower = level == 0 ? 1 : 0; + int upper = 10 / (level + 1); + int numOfSubRequests = randomIntBetween(lower, upper); + for (int i = 0; i < numOfSubRequests && level <= maxLevel; i++) { + subRequests.add(generateTestRequest(nodes, level + 1, maxLevel)); + } + final TestRequest request = new TestRequest(idGenerator++, randomFrom(nodes), subRequests); + beforeSendLatches.put(request, new CountDownLatch(1)); + arrivedLatches.put(request, new CountDownLatch(1)); + beforeExecuteLatches.put(request, new CountDownLatch(1)); + completedLatches.put(request, new CountDownLatch(1)); + return request; } - public void testBanOnlyNodesWithOutstandingChildTasks() throws Exception { + static void randomDescendants(TestRequest request, Set result) { if (randomBoolean()) { - internalCluster().startNodes(randomIntBetween(1, 3)); + result.add(request); + for (TestRequest subReq : request.subRequests) { + randomDescendants(subReq, result); + } } - Set nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); - List childRequests = setupChildRequests(nodes); - ActionFuture mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests)); - List completedRequests = randomSubsetOf(between(0, childRequests.size() - 1), childRequests); - for 
(ChildRequest req : completedRequests) { + } + + /** + * Allow some parts of the request to be completed + * @return a pending child requests + */ + static Set allowPartialRequest(TestRequest request) throws Exception { + final Set sentRequests = new HashSet<>(); + while (sentRequests.isEmpty()) { + for (TestRequest subRequest : request.subRequests) { + randomDescendants(subRequest, sentRequests); + } + } + for (TestRequest req : sentRequests) { + beforeSendLatches.get(req).countDown(); + } + for (TestRequest req : sentRequests) { + arrivedLatches.get(req).await(); + } + Set completedRequests = new HashSet<>(); + for (TestRequest req : randomSubsetOf(sentRequests)) { + if (sentRequests.containsAll(req.descendants())) { + completedRequests.add(req); + completedRequests.addAll(req.descendants()); + } + } + for (TestRequest req : completedRequests) { beforeExecuteLatches.get(req).countDown(); + } + for (TestRequest req : completedRequests) { completedLatches.get(req).await(); } - List outstandingRequests = childRequests.stream(). 
- filter(r -> completedRequests.contains(r) == false) - .collect(Collectors.toList()); - for (ChildRequest req : outstandingRequests) { - arrivedLatches.get(req).await(); + return Sets.difference(sentRequests, completedRequests); + } + + static void allowEntireRequest(TestRequest request) { + beforeSendLatches.get(request).countDown(); + beforeExecuteLatches.get(request).countDown(); + for (TestRequest subReq : request.subRequests) { + allowEntireRequest(subReq); + } + } + + public void testBanOnlyNodesWithOutstandingDescendantTasks() throws Exception { + if (randomBoolean()) { + internalCluster().startNodes(randomIntBetween(1, 3)); + } + Set nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); + final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4)); + ActionFuture rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); + Set pendingRequests = allowPartialRequest(rootRequest); + TaskId rootTaskId = getRootTaskId(rootRequest); + ActionFuture cancelFuture = client().admin().cluster().prepareCancelTasks() + .setTaskId(rootTaskId).waitForCompletion(true).execute(); + if (randomBoolean()) { + List runningTasks = client().admin().cluster().prepareListTasks() + .setActions(TransportTestAction.ACTION.name()).setDetailed(true).get().getTasks(); + for (TaskInfo subTask : randomSubsetOf(runningTasks)) { + client().admin().cluster().prepareCancelTasks().setTaskId(subTask.getTaskId()).waitForCompletion(false).get(); + } } - TaskId taskId = getMainTaskId(); - ActionFuture cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId) - .waitForCompletion(true).execute(); - Set nodesWithOutstandingChildTask = outstandingRequests.stream().map(r -> r.targetNode).collect(Collectors.toSet()); assertBusy(() -> { for (DiscoveryNode node : nodes) { TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager(); - if 
(nodesWithOutstandingChildTask.contains(node)) { - assertThat(taskManager.getBanCount(), equalTo(1)); - } else { - assertThat(taskManager.getBanCount(), equalTo(0)); + Set expectedBans = new HashSet<>(); + for (TestRequest req : pendingRequests) { + if (req.node.equals(node)) { + List childTasks = taskManager.getTasks().values().stream() + .filter(t -> t.getParentTaskId() != null && t.getDescription().equals(req.taskDescription())) + .collect(Collectors.toList()); + assertThat(childTasks, hasSize(1)); + CancellableTask childTask = (CancellableTask) childTasks.get(0); + assertTrue(childTask.isCancelled()); + expectedBans.add(childTask.getParentTaskId()); + } } + assertThat(taskManager.getBannedTaskIds(), equalTo(expectedBans)); } }); - // failed to spawn child tasks after cancelling - if (randomBoolean()) { - DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get(); - TransportMainAction mainAction = internalCluster().getInstance(TransportMainAction.class, nodeWithParentTask.getName()); - PlainActionFuture future = new PlainActionFuture<>(); - ChildRequest req = new ChildRequest(-1, randomFrom(nodes)); - completedLatches.put(req, new CountDownLatch(1)); - mainAction.startChildTask(taskId, req, future); - TransportException te = expectThrows(TransportException.class, future::actionGet); - assertThat(te.getCause(), instanceOf(TaskCancelledException.class)); - assertThat(te.getCause().getMessage(), equalTo("The parent task was cancelled, shouldn't start any child tasks")); - } - for (ChildRequest req : outstandingRequests) { - beforeExecuteLatches.get(req).countDown(); - } + allowEntireRequest(rootRequest); cancelFuture.actionGet(); - waitForMainTask(mainTaskFuture); + waitForRootTask(rootTaskFuture); assertBusy(() -> { for (DiscoveryNode node : nodes) { TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager(); @@ -158,27 +210,20 @@ public void 
testBanOnlyNodesWithOutstandingChildTasks() throws Exception { public void testCancelTaskMultipleTimes() throws Exception { Set nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); - List childRequests = setupChildRequests(nodes); - ActionFuture mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests)); - for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) { - arrivedLatches.get(r).await(); - } - TaskId taskId = getMainTaskId(); + TestRequest rootRequest = generateTestRequest(nodes, 0, randomIntBetween(1, 3)); + ActionFuture mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); + TaskId taskId = getRootTaskId(rootRequest); + allowPartialRequest(rootRequest); + CancelTasksResponse resp = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get(); + assertThat(resp.getTaskFailures(), empty()); + assertThat(resp.getNodeFailures(), empty()); ActionFuture cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId) .waitForCompletion(true).execute(); - ensureChildTasksCancelledOrBanned(taskId); - if (randomBoolean()) { - CancelTasksResponse resp = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get(); - assertThat(resp.getTaskFailures(), empty()); - assertThat(resp.getNodeFailures(), empty()); - } assertFalse(cancelFuture.isDone()); - for (ChildRequest r : childRequests) { - beforeExecuteLatches.get(r).countDown(); - } + allowEntireRequest(rootRequest); assertThat(cancelFuture.actionGet().getTaskFailures(), empty()); assertThat(cancelFuture.actionGet().getTaskFailures(), empty()); - waitForMainTask(mainTaskFuture); + waitForRootTask(mainTaskFuture); CancelTasksResponse cancelError = client().admin().cluster().prepareCancelTasks() .setTaskId(taskId).waitForCompletion(randomBoolean()).get(); 
assertThat(cancelError.getNodeFailures(), hasSize(1)); @@ -188,12 +233,12 @@ public void testCancelTaskMultipleTimes() throws Exception { public void testDoNotWaitForCompletion() throws Exception { Set nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); - List childRequests = setupChildRequests(nodes); - ActionFuture mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests)); - for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) { - arrivedLatches.get(r).await(); + TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3)); + ActionFuture mainTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); + TaskId taskId = getRootTaskId(rootRequest); + if (randomBoolean()) { + allowPartialRequest(rootRequest); } - TaskId taskId = getMainTaskId(); boolean waitForCompletion = randomBoolean(); ActionFuture cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId) .waitForCompletion(waitForCompletion).execute(); @@ -202,145 +247,130 @@ public void testDoNotWaitForCompletion() throws Exception { } else { assertBusy(() -> assertTrue(cancelFuture.isDone())); } - for (ChildRequest r : childRequests) { - beforeExecuteLatches.get(r).countDown(); - } - waitForMainTask(mainTaskFuture); + allowEntireRequest(rootRequest); + waitForRootTask(mainTaskFuture); + } + + public void testFailedToStartChildTaskAfterCancelled() { + Set nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); + TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 3)); + ActionFuture rootTaskFuture = client().execute(TransportTestAction.ACTION, rootRequest); + TaskId taskId = getRootTaskId(rootRequest); + client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get(); + DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> 
n.getId().equals(taskId.getNodeId())).findFirst().get(); + TransportTestAction mainAction = internalCluster().getInstance(TransportTestAction.class, nodeWithParentTask.getName()); + PlainActionFuture future = new PlainActionFuture<>(); + TestRequest subRequest = generateTestRequest(nodes, 0, between(0, 1)); + beforeSendLatches.get(subRequest).countDown(); + mainAction.startSubTask(taskId, subRequest, future); + TransportException te = expectThrows(TransportException.class, future::actionGet); + assertThat(te.getCause(), instanceOf(TaskCancelledException.class)); + assertThat(te.getCause().getMessage(), equalTo("The parent task was cancelled, shouldn't start any child tasks")); + allowEntireRequest(rootRequest); + waitForRootTask(rootTaskFuture); } - TaskId getMainTaskId() { + static TaskId getRootTaskId(TestRequest request) { ListTasksResponse listTasksResponse = client().admin().cluster().prepareListTasks() - .setActions(TransportMainAction.ACTION.name()).setDetailed(true).get(); - assertThat(listTasksResponse.getTasks(), hasSize(1)); - return listTasksResponse.getTasks().get(0).getTaskId(); + .setActions(TransportTestAction.ACTION.name()).setDetailed(true).get(); + List tasks = listTasksResponse.getTasks().stream() + .filter(t -> t.getDescription().equals(request.taskDescription())) + .collect(Collectors.toList()); + assertThat(tasks, hasSize(1)); + return tasks.get(0).getTaskId(); } - void waitForMainTask(ActionFuture mainTask) { + static void waitForRootTask(ActionFuture rootTask) { try { - mainTask.actionGet(); + rootTask.actionGet(); } catch (Exception e) { final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class); - assertThat(cause.getMessage(), - either(equalTo("The parent task was cancelled, shouldn't start any child tasks")) - .or(containsString("Task cancelled before it started:"))); + assertThat(cause.getMessage(), anyOf( + equalTo("The parent task was cancelled, shouldn't start any child tasks"), + containsString("Task cancelled 
before it started:"), + equalTo("Task was cancelled while executing"))); } } - public static class MainRequest extends ActionRequest { - final List childRequests; + static class TestRequest extends ActionRequest { + final int id; + final DiscoveryNode node; + final List subRequests; - public MainRequest(List childRequests) { - this.childRequests = childRequests; + TestRequest(int id, DiscoveryNode node, List subRequests) { + this.id = id; + this.node = node; + this.subRequests = subRequests; } - public MainRequest(StreamInput in) throws IOException { + TestRequest(StreamInput in) throws IOException { super(in); - this.childRequests = in.readList(ChildRequest::new); - } - - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeList(childRequests); - } - - @Override - public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { - return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { - @Override - public boolean shouldCancelChildrenOnCancellation() { - return true; - } - }; - } - } - - public static class MainResponse extends ActionResponse { - public MainResponse() { + this.id = in.readInt(); + this.node = new DiscoveryNode(in); + this.subRequests = in.readList(TestRequest::new); } - public MainResponse(StreamInput in) throws IOException { - super(in); + List descendants() { + List descendants = new ArrayList<>(); + for (TestRequest subRequest : subRequests) { + descendants.add(subRequest); + descendants.addAll(subRequest.descendants()); + } + return descendants; } @Override - public void writeTo(StreamOutput out) throws IOException { - - } - } - - public static class ChildRequest extends ActionRequest { - final int id; - final DiscoveryNode targetNode; - - public ChildRequest(int id, DiscoveryNode targetNode) { - this.id = id; - this.targetNode = targetNode; - } 
- - - public ChildRequest(StreamInput in) throws IOException { - super(in); - this.id = in.readInt(); - this.targetNode = new DiscoveryNode(in); + public ActionRequestValidationException validate() { + return null; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeInt(id); - targetNode.writeTo(out); + node.writeTo(out); + out.writeList(subRequests); } @Override - public ActionRequestValidationException validate() { - return null; + public String getDescription() { + return taskDescription(); } - @Override - public String getDescription() { - return "childTask[" + id + "]"; + String taskDescription() { + return "id=" + id; } @Override public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { - if (randomBoolean()) { - boolean shouldCancelChildrenOnCancellation = randomBoolean(); - return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { - @Override - public boolean shouldCancelChildrenOnCancellation() { - return shouldCancelChildrenOnCancellation; - } - }; - } else { - return super.createTask(id, type, action, parentTaskId, headers); - } + return new CancellableTask(id, type, action, taskDescription(), parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return true; + } + }; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - ChildRequest that = (ChildRequest) o; - return id == that.id && targetNode.equals(that.targetNode); + TestRequest that = (TestRequest) o; + return id == that.id; } @Override public int hashCode() { - return Objects.hash(id, targetNode); + return Objects.hash(id); } } - public static class ChildResponse extends ActionResponse { - public ChildResponse() { + public static class TestResponse extends ActionResponse { + public TestResponse() { + } - public ChildResponse(StreamInput in) throws IOException { + 
public TestResponse(StreamInput in) throws IOException { super(in); } @@ -350,33 +380,44 @@ public void writeTo(StreamOutput out) throws IOException { } } - public static class TransportMainAction extends HandledTransportAction { + public static class TransportTestAction extends HandledTransportAction { - public static ActionType ACTION = new ActionType<>("internal::main_action", MainResponse::new); + static AtomicInteger counter = new AtomicInteger(); + public static ActionType ACTION = new ActionType<>("internal::test_action", TestResponse::new); private final TransportService transportService; private final NodeClient client; @Inject - public TransportMainAction(TransportService transportService, NodeClient client, ActionFilters actionFilters) { - super(ACTION.name(), transportService, actionFilters, MainRequest::new, ThreadPool.Names.GENERIC); + public TransportTestAction(TransportService transportService, NodeClient client, ActionFilters actionFilters) { + super(ACTION.name(), transportService, actionFilters, TestRequest::new, ThreadPool.Names.GENERIC); this.transportService = transportService; this.client = client; } @Override - protected void doExecute(Task task, MainRequest request, ActionListener listener) { - GroupedActionListener groupedListener = - new GroupedActionListener<>(ActionListener.map(listener, r -> new MainResponse()), request.childRequests.size()); - for (ChildRequest childRequest : request.childRequests) { + protected void doExecute(Task task, TestRequest request, ActionListener listener) { + arrivedLatches.get(request).countDown(); + List subRequests = request.subRequests; + GroupedActionListener groupedListener = + new GroupedActionListener<>(ActionListener.map(listener, r -> new TestResponse()), subRequests.size() + 1); + transportService.getThreadPool().generic().execute(ActionRunnable.supply(groupedListener, () -> { + beforeExecuteLatches.get(request).await(); + if (((CancellableTask) task).isCancelled()) { + throw new 
TaskCancelledException("Task was cancelled while executing"); + } + counter.incrementAndGet(); + return new TestResponse(); + })); + for (TestRequest subRequest : subRequests) { TaskId parentTaskId = new TaskId(client.getLocalNodeId(), task.getId()); - startChildTask(parentTaskId, childRequest, groupedListener); + startSubTask(parentTaskId, subRequest, groupedListener); } } - protected void startChildTask(TaskId parentTaskId, ChildRequest childRequest, ActionListener listener) { - childRequest.setParentTask(parentTaskId); - final CountDownLatch completeLatch = completedLatches.get(childRequest); - LatchedActionListener latchedListener = new LatchedActionListener<>(listener, completeLatch); + protected void startSubTask(TaskId parentTaskId, TestRequest subRequest, ActionListener listener) { + subRequest.setParentTask(parentTaskId); + CountDownLatch completeLatch = completedLatches.get(subRequest); + LatchedActionListener latchedListener = new LatchedActionListener<>(listener, completeLatch); transportService.getThreadPool().generic().execute(new AbstractRunnable() { @Override public void onFailure(Exception e) { @@ -384,20 +425,20 @@ public void onFailure(Exception e) { } @Override - protected void doRun() { - if (client.getLocalNodeId().equals(childRequest.targetNode.getId()) && randomBoolean()) { + protected void doRun() throws Exception { + beforeSendLatches.get(subRequest).await(); + if (client.getLocalNodeId().equals(subRequest.node.getId()) && randomBoolean()) { try { - client.executeLocally(TransportChildAction.ACTION, childRequest, latchedListener); + client.executeLocally(TransportTestAction.ACTION, subRequest, latchedListener); } catch (TaskCancelledException e) { latchedListener.onFailure(new TransportException(e)); } } else { - transportService.sendRequest(childRequest.targetNode, TransportChildAction.ACTION.name(), childRequest, - new TransportResponseHandler() { - + transportService.sendRequest(subRequest.node, ACTION.name(), subRequest, + new 
TransportResponseHandler() { @Override - public void handleResponse(ChildResponse response) { - latchedListener.onResponse(new ChildResponse()); + public void handleResponse(TestResponse response) { + latchedListener.onResponse(response); } @Override @@ -411,8 +452,8 @@ public String executor() { } @Override - public ChildResponse read(StreamInput in) throws IOException { - return new ChildResponse(in); + public TestResponse read(StreamInput in) throws IOException { + return new TestResponse(in); } }); } @@ -421,40 +462,16 @@ public ChildResponse read(StreamInput in) throws IOException { } } - public static class TransportChildAction extends HandledTransportAction { - public static ActionType ACTION = new ActionType<>("internal:child_action", ChildResponse::new); - private final TransportService transportService; - - - @Inject - public TransportChildAction(TransportService transportService, ActionFilters actionFilters) { - super(ACTION.name(), transportService, actionFilters, ChildRequest::new, ThreadPool.Names.GENERIC); - this.transportService = transportService; - } - - @Override - protected void doExecute(Task task, ChildRequest request, ActionListener listener) { - assertThat(request.targetNode, equalTo(transportService.getLocalNode())); - arrivedLatches.get(request).countDown(); - transportService.getThreadPool().executor(ThreadPool.Names.GENERIC).execute(ActionRunnable.supply(listener, () -> { - beforeExecuteLatches.get(request).await(); - return new ChildResponse(); - })); - } - } public static class TaskPlugin extends Plugin implements ActionPlugin { @Override public List> getActions() { - return Arrays.asList( - new ActionHandler<>(TransportMainAction.ACTION, TransportMainAction.class), - new ActionHandler<>(TransportChildAction.ACTION, TransportChildAction.class) - ); + return Collections.singletonList(new ActionHandler<>(TransportTestAction.ACTION, TransportTestAction.class)); } @Override public List> getClientActions() { - return 
Arrays.asList(TransportMainAction.ACTION, TransportChildAction.ACTION); + return Collections.singletonList(TransportTestAction.ACTION); } } @@ -471,16 +488,4 @@ protected Collection> transportClientPlugins() { plugins.add(TaskPlugin.class); return plugins; } - - /** - * Ensures that all outstanding child tasks of the given parent task are banned or being cancelled. - */ - protected static void ensureChildTasksCancelledOrBanned(TaskId taskId) throws Exception { - assertBusy(() -> { - for (String nodeName : internalCluster().getNodeNames()) { - final TaskManager taskManager = internalCluster().getInstance(TransportService.class, nodeName).getTaskManager(); - assertTrue(taskManager.childTasksCancelledOrBanned(taskId)); - } - }); - } }