Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support hierarchical task cancellation #54757

Merged
merged 3 commits into from
Apr 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/reference/cluster/tasks.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ nodes `nodeId1` and `nodeId2`.

`wait_for_completion`::
(Optional, boolean) If `true`, the request blocks until the cancellation of the
task and its child tasks is completed. Otherwise, the request can return soon
task and its descendant tasks is completed. Otherwise, the request can return soon
after the cancellation is started. Defaults to `false`.

[source,console]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
},
"wait_for_completion": {
"type":"boolean",
"description":"Should the request block until the cancellation of the task and its child tasks is completed. Defaults to false"
"description":"Should the request block until the cancellation of the task and its descendant tasks is completed. Defaults to false"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public String getReason() {
}

/**
* If {@code true}, the request blocks until the cancellation of the task and its child tasks is completed.
* If {@code true}, the request blocks until the cancellation of the task and its descendant tasks is completed.
* Otherwise, the request can return soon after the cancellation is started. Defaults to {@code false}.
*/
public void setWaitForCompletion(boolean waitForCompletion) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
package org.elasticsearch.action.admin.cluster.node.tasks.cancel;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.node.DiscoveryNode;
Expand Down Expand Up @@ -104,34 +106,43 @@ protected void processTasks(CancelTasksRequest request, Consumer<CancellableTask
@Override
protected void taskOperation(CancelTasksRequest request, CancellableTask cancellableTask, ActionListener<TaskInfo> listener) {
String nodeId = clusterService.localNode().getId();
if (cancellableTask.shouldCancelChildrenOnCancellation()) {
cancelTaskAndDescendants(cancellableTask, request.getReason(), request.waitForCompletion(),
ActionListener.map(listener, r -> cancellableTask.taskInfo(nodeId, false)));
}

void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
if (task.shouldCancelChildrenOnCancellation()) {
StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3);
Collection<DiscoveryNode> childrenNodes =
taskManager.startBanOnChildrenNodes(cancellableTask.getId(), () -> groupedListener.onResponse(null));
taskManager.cancel(cancellableTask, request.getReason(), () -> groupedListener.onResponse(null));
taskManager.startBanOnChildrenNodes(task.getId(), () -> groupedListener.onResponse(null));
taskManager.cancel(task, reason, () -> groupedListener.onResponse(null));

StepListener<Void> banOnNodesListener = new StepListener<>();
setBanOnNodes(request.getReason(), cancellableTask, childrenNodes, banOnNodesListener);
setBanOnNodes(reason, waitForCompletion, task, childrenNodes, banOnNodesListener);
banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
// We remove bans after all child tasks are completed although in theory we can do it on a per-node basis.
completedListener.whenComplete(
r -> removeBanOnNodes(cancellableTask, childrenNodes),
e -> removeBanOnNodes(cancellableTask, childrenNodes));
// if wait_for_child_tasks is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
completedListener.whenComplete(r -> removeBanOnNodes(task, childrenNodes), e -> removeBanOnNodes(task, childrenNodes));
// if wait_for_completion is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes.
if (request.waitForCompletion()) {
completedListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure);
if (waitForCompletion) {
completedListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
} else {
banOnNodesListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure);
banOnNodesListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
}
} else {
logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId());
taskManager.cancel(cancellableTask, request.getReason(), () -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)));
logger.trace("task {} doesn't have any children that should be cancelled", task.getId());
if (waitForCompletion) {
taskManager.cancel(task, reason, () -> listener.onResponse(null));
} else {
taskManager.cancel(task, reason, () -> {});
listener.onResponse(null);
}
}
}

private void setBanOnNodes(String reason, CancellableTask task, Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
private void setBanOnNodes(String reason, boolean waitForCompletion, CancellableTask task,
Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
if (childNodes.isEmpty()) {
listener.onResponse(null);
return;
Expand All @@ -140,7 +151,7 @@ private void setBanOnNodes(String reason, CancellableTask task, Collection<Disco
GroupedActionListener<Void> groupedListener =
new GroupedActionListener<>(ActionListener.map(listener, r -> null), childNodes.size());
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(
new TaskId(clusterService.localNode().getId(), task.getId()), reason);
new TaskId(clusterService.localNode().getId(), task.getId()), reason, waitForCompletion);
for (DiscoveryNode node : childNodes) {
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
Expand Down Expand Up @@ -171,33 +182,41 @@ private static class BanParentTaskRequest extends TransportRequest {

private final TaskId parentTaskId;
private final boolean ban;
private final boolean waitForCompletion;
private final String reason;

static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason) {
return new BanParentTaskRequest(parentTaskId, reason);
static BanParentTaskRequest createSetBanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
return new BanParentTaskRequest(parentTaskId, reason, waitForCompletion);
}

static BanParentTaskRequest createRemoveBanParentTaskRequest(TaskId parentTaskId) {
return new BanParentTaskRequest(parentTaskId);
}

private BanParentTaskRequest(TaskId parentTaskId, String reason) {
private BanParentTaskRequest(TaskId parentTaskId, String reason, boolean waitForCompletion) {
this.parentTaskId = parentTaskId;
this.ban = true;
this.reason = reason;
this.waitForCompletion = waitForCompletion;
}

private BanParentTaskRequest(TaskId parentTaskId) {
this.parentTaskId = parentTaskId;
this.ban = false;
this.reason = null;
this.waitForCompletion = false;
}

private BanParentTaskRequest(StreamInput in) throws IOException {
super(in);
parentTaskId = TaskId.readFromStream(in);
ban = in.readBoolean();
reason = ban ? in.readString() : null;
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
waitForCompletion = in.readBoolean();
} else {
waitForCompletion = false;
}
}

@Override
Expand All @@ -208,6 +227,9 @@ public void writeTo(StreamOutput out) throws IOException {
if (ban) {
out.writeString(reason);
}
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
out.writeBoolean(waitForCompletion);
}
}
}

Expand All @@ -217,13 +239,20 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC
if (request.ban) {
logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId,
clusterService.localNode().getId(), request.reason);
taskManager.setBan(request.parentTaskId, request.reason);
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason);
final GroupedActionListener<Void> listener = new GroupedActionListener<>(ActionListener.map(
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request), r -> TransportResponse.Empty.INSTANCE),
childTasks.size() + 1);
for (CancellableTask childTask : childTasks) {
cancelTaskAndDescendants(childTask, request.reason, request.waitForCompletion, listener);
}
listener.onResponse(null);
} else {
logger.debug("Removing ban for the parent [{}] on the node [{}]", request.parentTaskId,
clusterService.localNode().getId());
taskManager.removeBan(request.parentTaskId);
channel.sendResponse(TransportResponse.Empty.INSTANCE);
}
channel.sendResponse(TransportResponse.Empty.INSTANCE);
}
}

Expand Down
22 changes: 8 additions & 14 deletions server/src/main/java/org/elasticsearch/tasks/TaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,9 @@ public int getBanCount() {
* Bans all tasks with the specified parent task from execution, cancels all tasks that are currently executing.
* <p>
* This method is called when a parent task that has children is cancelled.
* @return a list of pending cancellable child tasks
*/
public void setBan(TaskId parentTaskId, String reason) {
public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
logger.trace("setting ban for the parent task {} {}", parentTaskId, reason);

// Set the ban first, so the newly created tasks cannot be registered
Expand All @@ -377,14 +378,10 @@ public void setBan(TaskId parentTaskId, String reason) {
banedParents.put(parentTaskId, reason);
}
}

// Now go through already running tasks and cancel them
for (Map.Entry<Long, CancellableTaskHolder> taskEntry : cancellableTasks.entrySet()) {
CancellableTaskHolder holder = taskEntry.getValue();
if (holder.hasParent(parentTaskId)) {
holder.cancel(reason);
}
}
return cancellableTasks.values().stream()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we need to eventually optimize this, in case there might be a very large list of cancellable tasks. Also, this does not have proper happens-before, as iterating a concurrent map after an element has been added is not guaranteed to yield the element (it's eventually consistent).

.filter(t -> t.hasParent(parentTaskId))
.map(t -> t.task)
.collect(Collectors.toUnmodifiableList());
}

/**
Expand All @@ -398,11 +395,8 @@ public void removeBan(TaskId parentTaskId) {
}

// for testing
public boolean childTasksCancelledOrBanned(TaskId parentTaskId) {
if (banedParents.containsKey(parentTaskId)) {
return true;
}
return cancellableTasks.values().stream().noneMatch(task -> task.hasParent(parentTaskId));
public Set<TaskId> getBannedTaskIds() {
return Collections.unmodifiableSet(banedParents.keySet());
}

/**
Expand Down
Loading