Skip to content

Commit

Permalink
Support hierarchical task cancellation (elastic#54757)
Browse files Browse the repository at this point in the history
With this change, when a task is canceled, the task manager will cancel
not only its direct child tasks but all also its descendant tasks.

Closes elastic#50990
  • Loading branch information
dnhatn committed Apr 13, 2020
1 parent 51c6f69 commit 96bb116
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 260 deletions.
2 changes: 1 addition & 1 deletion docs/reference/cluster/tasks.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ nodes `nodeId1` and `nodeId2`.

`wait_for_completion`::
(Optional, boolean) If `true`, the request blocks until the cancellation of the
task and its child tasks is completed. Otherwise, the request can return soon
task and its descendant tasks is completed. Otherwise, the request can return soon
after the cancellation is started. Defaults to `false`.

[source,console]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
},
"wait_for_completion": {
"type":"boolean",
"description":"Should the request block until the cancellation of the task and its child tasks is completed. Defaults to false"
"description":"Should the request block until the cancellation of the task and its descendant tasks is completed. Defaults to false"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public String getReason() {
}

/**
* If {@code true}, the request blocks until the cancellation of the task and its child tasks is completed.
* If {@code true}, the request blocks until the cancellation of the task and its descendant tasks is completed.
* Otherwise, the request can return soon after the cancellation is started. Defaults to {@code false}.
*/
public void setWaitForCompletion(boolean waitForCompletion) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
package org.elasticsearch.action.admin.cluster.node.tasks.cancel;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.node.DiscoveryNode;
Expand Down Expand Up @@ -104,34 +106,43 @@ protected void processTasks(CancelTasksRequest request, Consumer<CancellableTask
@Override
protected void taskOperation(CancelTasksRequest request, CancellableTask cancellableTask, ActionListener<TaskInfo> listener) {
String nodeId = clusterService.localNode().getId();
if (cancellableTask.shouldCancelChildrenOnCancellation()) {
cancelTaskAndDescendants(cancellableTask, request.getReason(), request.waitForCompletion(),
ActionListener.map(listener, r -> cancellableTask.taskInfo(nodeId, false)));
}

void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
if (task.shouldCancelChildrenOnCancellation()) {
StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3);
Collection<DiscoveryNode> childrenNodes =
taskManager.startBanOnChildrenNodes(cancellableTask.getId(), () -> groupedListener.onResponse(null));
taskManager.cancel(cancellableTask, request.getReason(), () -> groupedListener.onResponse(null));
taskManager.startBanOnChildrenNodes(task.getId(), () -> groupedListener.onResponse(null));
taskManager.cancel(task, reason, () -> groupedListener.onResponse(null));

StepListener<Void> banOnNodesListener = new StepListener<>();
setBanOnNodes(request.getReason(), cancellableTask, childrenNodes, banOnNodesListener);
setBanOnNodes(reason, waitForCompletion, task, childrenNodes, banOnNodesListener);
banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
// We remove bans after all child tasks are completed although in theory we can do it on a per-node basis.
completedListener.whenComplete(
r -> removeBanOnNodes(cancellableTask, childrenNodes),
e -> removeBanOnNodes(cancellableTask, childrenNodes));
// if wait_for_child_tasks is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
completedListener.whenComplete(r -> removeBanOnNodes(task, childrenNodes), e -> removeBanOnNodes(task, childrenNodes));
// if wait_for_completion is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes.
if (request.waitForCompletion()) {
completedListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure);
if (waitForCompletion) {
completedListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
} else {
banOnNodesListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure);
banOnNodesListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
}
} else {
logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId());
taskManager.cancel(cancellableTask, request.getReason(), () -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)));
logger.trace("task {} doesn't have any children that should be cancelled", task.getId());
if (waitForCompletion) {
taskManager.cancel(task, reason, () -> listener.onResponse(null));
} else {
taskManager.cancel(task, reason, () -> {});
listener.onResponse(null);
}
}
}

private void setBanOnNodes(String reason, CancellableTask task, Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
private void setBanOnNodes(String reason, boolean waitForCompletion, CancellableTask task,
Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
if (childNodes.isEmpty()) {
listener.onResponse(null);
return;
Expand All @@ -140,7 +151,7 @@ private void setBanOnNodes(String reason, CancellableTask task, Collection<Disco
GroupedActionListener<Void> groupedListener =
new GroupedActionListener<>(ActionListener.map(listener, r -> null), childNodes.size());
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(
new TaskId(clusterService.localNode().getId(), task.getId()), reason);
new TaskId(clusterService.localNode().getId(), task.getId()), reason, waitForCompletion);
for (DiscoveryNode node : childNodes) {
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
Expand Down Expand Up @@ -171,33 +182,41 @@ private static class BanParentTaskRequest extends TransportRequest {

private final TaskId parentTaskId;
private final boolean ban;
private final boolean waitForCompletion;
private final String reason;

static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason) {
return new BanParentTaskRequest(parentTaskId, reason);
static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
return new BanParentTaskRequest(parentTaskId, reason, waitForCompletion);
}

static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) {
return new BanParentTaskRequest(parentTaskId);
}

private BanParentTaskRequest(TaskId parentTaskId, String reason) {
private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
this.parentTaskId = parentTaskId;
this.ban = true;
this.reason = reason;
this.waitForCompletion = waitForCompletion;
}

private BanParentTaskRequest(TaskId parentTaskId) {
this.parentTaskId = parentTaskId;
this.ban = false;
this.reason = null;
this.waitForCompletion = false;
}

private BanParentTaskRequest(StreamInput in) throws IOException {
super(in);
parentTaskId = TaskId.readFromStream(in);
ban = in.readBoolean();
reason = ban ? in.readString() : null;
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
waitForCompletion = in.readBoolean();
} else {
waitForCompletion = false;
}
}

@Override
Expand All @@ -208,6 +227,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (ban) {
out.writeString(reason);
}
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
out.writeBoolean(waitForCompletion);
}
}
}

Expand All @@ -217,13 +239,20 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC
if (request.ban) {
logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId,
clusterService.localNode().getId(), request.reason);
taskManager.setBan(request.parentTaskId, request.reason);
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason);
final GroupedActionListener<Void> listener = new GroupedActionListener<>(ActionListener.map(
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request), r -> TransportResponse.Empty.INSTANCE),
childTasks.size() + 1);
for (CancellableTask childTask : childTasks) {
cancelTaskAndDescendants(childTask, request.reason, request.waitForCompletion, listener);
}
listener.onResponse(null);
} else {
logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId,
clusterService.localNode().getId());
taskManager.removeBan(request.parentTaskId);
channel.sendResponse(TransportResponse.Empty.INSTANCE);
}
channel.sendResponse(TransportResponse.Empty.INSTANCE);
}
}

Expand Down
22 changes: 8 additions & 14 deletions server/src/main/java/org/elasticsearch/tasks/TaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,9 @@ public int getBanCount() {
* Bans all tasks with the specified parent task from execution, cancels all tasks that are currently executing.
* <p>
* This method is called when a parent task that has children is cancelled.
* @return a list of pending cancellable child tasks
*/
public void setBan(TaskId parentTaskId, String reason) {
public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
logger.trace("setting ban for the parent task {} {}", parentTaskId, reason);

// Set the ban first, so the newly created tasks cannot be registered
Expand All @@ -344,14 +345,10 @@ public void setBan(TaskId parentTaskId, String reason) {
banedParents.put(parentTaskId, reason);
}
}

// Now go through already running tasks and cancel them
for (Map.Entry<Long, CancellableTaskHolder> taskEntry : cancellableTasks.entrySet()) {
CancellableTaskHolder holder = taskEntry.getValue();
if (holder.hasParent(parentTaskId)) {
holder.cancel(reason);
}
}
return cancellableTasks.values().stream()
.filter(t -> t.hasParent(parentTaskId))
.map(t -> t.task)
.collect(Collectors.toList());
}

/**
Expand All @@ -365,11 +362,8 @@ public void removeBan(TaskId parentTaskId) {
}

// for testing
public boolean childTasksCancelledOrBanned(TaskId parentTaskId) {
if (banedParents.containsKey(parentTaskId)) {
return true;
}
return cancellableTasks.values().stream().noneMatch(task -> task.hasParent(parentTaskId));
public Set<TaskId> getBannedTaskIds() {
return Collections.unmodifiableSet(banedParents.keySet());
}

/**
Expand Down
Loading

0 comments on commit 96bb116

Please sign in to comment.