From 98e65b19db393243dc59fee8bfd232d9481500f6 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 1 Apr 2020 11:22:13 -0400 Subject: [PATCH] Broadcast cancellation to only nodes have outstanding child tasks (#54312) Today when canceling a task we broadcast ban/unban requests to all nodes in the cluster. This strategy does not scale well for hierarchical cancellation. With this change, we will track outstanding child requests and broadcast the cancellation to only nodes that have outstanding child tasks. This change also prevents a parent task from sending child requests once it got canceled. Relates #50990 Supersedes #51157 Co-authored-by: Igor Motov Co-authored-by: Yannick Welsch --- .../client/TasksRequestConverters.java | 3 + .../client/tasks/CancelTasksRequest.java | 24 +- .../client/TasksRequestConvertersTests.java | 14 +- .../TasksClientDocumentationIT.java | 3 +- .../high-level/tasks/cancel_tasks.asciidoc | 4 +- docs/reference/cluster/tasks.asciidoc | 5 + .../rest-api-spec/api/tasks.cancel.json | 4 + .../node/tasks/cancel/CancelTasksRequest.java | 21 + .../cancel/CancelTasksRequestBuilder.java | 4 + .../cancel/TransportCancelTasksAction.java | 155 ++---- .../action/support/TransportAction.java | 19 +- .../admin/cluster/RestCancelTasksAction.java | 1 + .../elasticsearch/tasks/CancellableTask.java | 17 +- .../org/elasticsearch/tasks/TaskManager.java | 203 ++++++-- .../transport/TransportService.java | 29 ++ .../node/tasks/CancellableTasksIT.java | 475 ++++++++++++++++++ .../node/tasks/CancellableTasksTests.java | 178 +++++-- 17 files changed, 934 insertions(+), 225 deletions(-) create mode 100644 server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/TasksRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/TasksRequestConverters.java index 9099e8a854121..6cf2b615a43cb 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/TasksRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/TasksRequestConverters.java @@ -39,6 +39,9 @@ static Request cancelTasks(CancelTasksRequest req) { params .withNodes(req.getNodes()) .withActions(req.getActions()); + if (req.getWaitForCompletion() != null) { + params.withWaitForCompletion(req.getWaitForCompletion()); + } request.addParameters(params.asMap()); return request; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/tasks/CancelTasksRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/tasks/CancelTasksRequest.java index 9677c72928195..15eb7cfd20f71 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/tasks/CancelTasksRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/tasks/CancelTasksRequest.java @@ -33,6 +33,7 @@ public class CancelTasksRequest implements Validatable { private Optional timeout = Optional.empty(); private Optional parentTaskId = Optional.empty(); private Optional taskId = Optional.empty(); + private Boolean waitForCompletion; CancelTasksRequest(){} @@ -76,6 +77,14 @@ public Optional getTaskId() { return taskId; } + public Boolean getWaitForCompletion() { + return waitForCompletion; + } + + public void setWaitForCompletion(boolean waitForCompletion) { + this.waitForCompletion = waitForCompletion; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -85,12 +94,13 @@ public boolean equals(Object o) { Objects.equals(getActions(), that.getActions()) && Objects.equals(getTimeout(), that.getTimeout()) && Objects.equals(getParentTaskId(), that.getParentTaskId()) && - Objects.equals(getTaskId(), that.getTaskId()) ; + Objects.equals(getTaskId(), that.getTaskId()) && + Objects.equals(waitForCompletion, that.waitForCompletion); } @Override public int hashCode() { - return Objects.hash(getNodes(), getActions(), getTimeout(), getParentTaskId(), getTaskId()); + return Objects.hash(getNodes(), getActions(), getTimeout(), getParentTaskId(), getTaskId(), waitForCompletion); } @Override @@ -101,6 +111,7 @@ public String toString() { ", timeout=" + timeout + ", parentTaskId=" + parentTaskId + ", taskId=" + taskId + + ", waitForCompletion=" + waitForCompletion + '}'; } @@ -110,6 +121,7 @@ public static class Builder { private Optional parentTaskId = Optional.empty(); private List actionsFilter = new ArrayList<>(); private List nodesFilter = new ArrayList<>(); + private Boolean waitForCompletion; public Builder withTimeout(TimeValue timeout){ this.timeout = Optional.of(timeout); @@ -138,6 +150,11 @@ public Builder withNodesFiltered(List nodes){ return this; } + public Builder withWaitForCompletion(boolean waitForCompletion) { + this.waitForCompletion = waitForCompletion; + return this; + } + public CancelTasksRequest build() { CancelTasksRequest request = new CancelTasksRequest(); timeout.ifPresent(request::setTimeout); @@ -145,6 +162,9 @@ public CancelTasksRequest build() { parentTaskId.ifPresent(request::setParentTaskId); request.setNodes(nodesFilter); request.setActions(actionsFilter); + if (waitForCompletion != null) { + request.setWaitForCompletion(waitForCompletion); + } return request; } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/TasksRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/TasksRequestConvertersTests.java index 054cd6a4ea009..3802383164960 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/TasksRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/TasksRequestConvertersTests.java @@ -22,6 +22,7 @@ import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequest; +import org.elasticsearch.client.tasks.CancelTasksRequest; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; @@ -40,14 +41,15 @@ public void testCancelTasks() { new org.elasticsearch.client.tasks.TaskId(randomAlphaOfLength(5), randomNonNegativeLong()); org.elasticsearch.client.tasks.TaskId parentTaskId = new org.elasticsearch.client.tasks.TaskId(randomAlphaOfLength(5), randomNonNegativeLong()); - org.elasticsearch.client.tasks.CancelTasksRequest request = - new org.elasticsearch.client.tasks.CancelTasksRequest.Builder() - .withTaskId(taskId) - .withParentTaskId(parentTaskId) - .build(); + CancelTasksRequest.Builder builder = new CancelTasksRequest.Builder().withTaskId(taskId).withParentTaskId(parentTaskId); expectedParams.put("task_id", taskId.toString()); expectedParams.put("parent_task_id", parentTaskId.toString()); - Request httpRequest = TasksRequestConverters.cancelTasks(request); + if (randomBoolean()) { + boolean waitForCompletion = randomBoolean(); + builder.withWaitForCompletion(waitForCompletion); + expectedParams.put("wait_for_completion", Boolean.toString(waitForCompletion)); + } + Request httpRequest = TasksRequestConverters.cancelTasks(builder.build()); assertThat(httpRequest, notNullValue()); assertThat(httpRequest.getMethod(), equalTo(HttpPost.METHOD_NAME)); assertThat(httpRequest.getEntity(), nullValue()); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/TasksClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/TasksClientDocumentationIT.java index dd99067968e5d..11068a72c0bed 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/TasksClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/TasksClientDocumentationIT.java @@ -166,7 +166,8 @@ public void testCancelTasks() throws IOException { // tag::cancel-tasks-request-filter CancelTasksRequest byTaskIdRequest = new org.elasticsearch.client.tasks.CancelTasksRequest.Builder() // <1> .withTaskId(new org.elasticsearch.client.tasks.TaskId("myNode",44L)) // <2> - .build(); // <3> + .withWaitForCompletion(true) // <3> + .build(); // <4> // end::cancel-tasks-request-filter } diff --git a/docs/java-rest/high-level/tasks/cancel_tasks.asciidoc b/docs/java-rest/high-level/tasks/cancel_tasks.asciidoc index 42f31322896e8..69c317efa82c4 100644 --- a/docs/java-rest/high-level/tasks/cancel_tasks.asciidoc +++ b/docs/java-rest/high-level/tasks/cancel_tasks.asciidoc @@ -22,7 +22,9 @@ include-tagged::{doc-tests}/TasksClientDocumentationIT.java[cancel-tasks-request -------------------------------------------------- <1> Cancel a task <2> Cancel only cluster-related tasks -<3> Cancel all tasks running on nodes nodeId1 and nodeId2 +<3> Should the request block until the cancellation of the task and its child tasks is completed. +Otherwise, the request can return soon after the cancellation is started. Defaults to `false`. +<4> Cancel all tasks running on nodes nodeId1 and nodeId2 ==== Synchronous Execution diff --git a/docs/reference/cluster/tasks.asciidoc b/docs/reference/cluster/tasks.asciidoc index 0dada2a98a12b..197da4a09ef56 100644 --- a/docs/reference/cluster/tasks.asciidoc +++ b/docs/reference/cluster/tasks.asciidoc @@ -232,6 +232,11 @@ list tasks command, so multiple tasks can be cancelled at the same time. For example, the following command will cancel all reindex tasks running on the nodes `nodeId1` and `nodeId2`. +`wait_for_completion`:: +(Optional, boolean) If `true`, the request blocks until the cancellation of the +task and its child tasks is completed. Otherwise, the request can return soon +after the cancellation is started. Defaults to `false`. + [source,console] -------------------------------------------------- POST _tasks/_cancel?nodes=nodeId1,nodeId2&actions=*reindex diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/tasks.cancel.json b/rest-api-spec/src/main/resources/rest-api-spec/api/tasks.cancel.json index d9d681899868d..6c3d424050b50 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/tasks.cancel.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/tasks.cancel.json @@ -39,6 +39,10 @@ "parent_task_id":{ "type":"string", "description":"Cancel tasks with specified parent task id (node_id:task_number). Set to -1 to cancel all." + }, + "wait_for_completion": { + "type":"boolean", + "description":"Should the request block until the cancellation of the task and its child tasks is completed. Defaults to false" } } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequest.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequest.java index 5c87b1da45d12..712f6bd65889b 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequest.java @@ -19,6 +19,7 @@ package org.elasticsearch.action.admin.cluster.node.tasks.cancel; +import org.elasticsearch.Version; import org.elasticsearch.action.support.tasks.BaseTasksRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -33,20 +34,28 @@ public class CancelTasksRequest extends BaseTasksRequest { public static final String DEFAULT_REASON = "by user request"; + public static final boolean DEFAULT_WAIT_FOR_COMPLETION = false; private String reason = DEFAULT_REASON; + private boolean waitForCompletion = DEFAULT_WAIT_FOR_COMPLETION; public CancelTasksRequest() {} public CancelTasksRequest(StreamInput in) throws IOException { super(in); this.reason = in.readString(); + if (in.getVersion().onOrAfter(Version.V_7_8_0)) { + waitForCompletion = in.readBoolean(); + } } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(reason); + if (out.getVersion().onOrAfter(Version.V_7_8_0)) { + out.writeBoolean(waitForCompletion); + } } @Override @@ -68,4 +77,16 @@ public CancelTasksRequest setReason(String reason) { public String getReason() { return reason; } + + /** + * If {@code true}, the request blocks until the cancellation of the task and its child tasks is completed. + * Otherwise, the request can return soon after the cancellation is started. Defaults to {@code false}. + */ + public void setWaitForCompletion(boolean waitForCompletion) { + this.waitForCompletion = waitForCompletion; + } + + public boolean waitForCompletion() { + return waitForCompletion; + } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequestBuilder.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequestBuilder.java index 982ebc38e490d..4fc722125b8e7 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequestBuilder.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/CancelTasksRequestBuilder.java @@ -31,4 +31,8 @@ public CancelTasksRequestBuilder(ElasticsearchClient client, CancelTasksAction a super(client, action, new CancelTasksRequest()); } + public CancelTasksRequestBuilder waitForCompletion(boolean waitForCompletion) { + request.setWaitForCompletion(waitForCompletion); + return this; + } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/TransportCancelTasksAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/TransportCancelTasksAction.java index ba009644dcd74..b2130dae9e568 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/TransportCancelTasksAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/tasks/cancel/TransportCancelTasksAction.java @@ -19,15 +19,15 @@ package org.elasticsearch.action.admin.cluster.node.tasks.cancel; -import com.carrotsearch.hppc.cursors.ObjectObjectCursor; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.StepListener; import org.elasticsearch.action.TaskOperationFailure; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.StreamInput; @@ -46,9 +46,8 @@ import org.elasticsearch.transport.TransportService; import java.io.IOException; -import java.util.ArrayList; +import java.util.Collection; import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; /** @@ -90,7 +89,7 @@ protected void processTasks(CancelTasksRequest request, Consumer listener) { + protected void taskOperation(CancelTasksRequest request, CancellableTask cancellableTask, ActionListener listener) { String nodeId = clusterService.localNode().getId(); - final boolean canceled; if (cancellableTask.shouldCancelChildrenOnCancellation()) { - DiscoveryNodes childNodes = clusterService.state().nodes(); - final BanLock banLock = new BanLock(childNodes.getSize(), () -> removeBanOnNodes(cancellableTask, childNodes)); - canceled = taskManager.cancel(cancellableTask, request.getReason(), banLock::onTaskFinished); - if (canceled) { - // /In case the task has some child tasks, we need to wait for until ban is set on all nodes - logger.trace("cancelling task {} on child nodes", cancellableTask.getId()); - AtomicInteger responses = new AtomicInteger(childNodes.getSize()); - List failures = new ArrayList<>(); - setBanOnNodes(request.getReason(), cancellableTask, childNodes, new ActionListener() { - @Override - public void onResponse(Void aVoid) { - processResponse(); - } - - @Override - public void onFailure(Exception e) { - synchronized (failures) { - failures.add(e); - } - processResponse(); - } - - private void processResponse() { - banLock.onBanSet(); - if (responses.decrementAndGet() == 0) { - if (failures.isEmpty() == false) { - IllegalStateException exception = new IllegalStateException("failed to cancel children of the task [" + - cancellableTask.getId() + "]"); - failures.forEach(exception::addSuppressed); - listener.onFailure(exception); - } else { - listener.onResponse(cancellableTask.taskInfo(nodeId, false)); - } - } - } - }); - } - } else { - canceled = taskManager.cancel(cancellableTask, request.getReason(), - () -> listener.onResponse(cancellableTask.taskInfo(nodeId, false))); - if (canceled) { - logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId()); + StepListener completedListener = new StepListener<>(); + GroupedActionListener groupedListener = new GroupedActionListener<>(ActionListener.map(completedListener, r -> null), 3); + Collection childrenNodes = + taskManager.startBanOnChildrenNodes(cancellableTask.getId(), () -> groupedListener.onResponse(null)); + taskManager.cancel(cancellableTask, request.getReason(), () -> groupedListener.onResponse(null)); + + StepListener banOnNodesListener = new StepListener<>(); + setBanOnNodes(request.getReason(), cancellableTask, childrenNodes, banOnNodesListener); + banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure); + // We remove bans after all child tasks are completed although in theory we can do it on a per-node basis. + completedListener.whenComplete( + r -> removeBanOnNodes(cancellableTask, childrenNodes), + e -> removeBanOnNodes(cancellableTask, childrenNodes)); + // if wait_for_child_tasks is true, then only return when (1) bans are placed on child nodes, (2) child tasks are + // completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes. + if (request.waitForCompletion()) { + completedListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure); + } else { + banOnNodesListener.whenComplete(r -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)), listener::onFailure); } - } - if (canceled == false) { - logger.trace("task {} is already cancelled", cancellableTask.getId()); - throw new IllegalStateException("task with id " + cancellableTask.getId() + " is already cancelled"); + } else { + logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId()); + taskManager.cancel(cancellableTask, request.getReason(), () -> listener.onResponse(cancellableTask.taskInfo(nodeId, false))); } } - private void setBanOnNodes(String reason, CancellableTask task, DiscoveryNodes nodes, ActionListener listener) { - sendSetBanRequest(nodes, - BanParentTaskRequest.createSetBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId()), reason), - listener); - } - - private void removeBanOnNodes(CancellableTask task, DiscoveryNodes nodes) { - sendRemoveBanRequest(nodes, - BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId()))); - } - - private void sendSetBanRequest(DiscoveryNodes nodes, BanParentTaskRequest request, ActionListener listener) { - for (ObjectObjectCursor node : nodes.getNodes()) { - logger.trace("Sending ban for tasks with the parent [{}] to the node [{}], ban [{}]", request.parentTaskId, node.key, - request.ban); - transportService.sendRequest(node.value, BAN_PARENT_ACTION_NAME, request, + private void setBanOnNodes(String reason, CancellableTask task, Collection childNodes, ActionListener listener) { + if (childNodes.isEmpty()) { + listener.onResponse(null); + return; + } + logger.trace("cancelling task {} on child nodes {}", task.getId(), childNodes); + GroupedActionListener groupedListener = + new GroupedActionListener<>(ActionListener.map(listener, r -> null), childNodes.size()); + final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest( + new TaskId(clusterService.localNode().getId(), task.getId()), reason); + for (DiscoveryNode node : childNodes) { + transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) { @Override public void handleResponse(TransportResponse.Empty response) { - listener.onResponse(null); + groupedListener.onResponse(null); } @Override public void handleException(TransportException exp) { - logger.warn("Cannot send ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node.key); - listener.onFailure(exp); + logger.warn("Cannot send ban for tasks with the parent [{}] to the node [{}]", banRequest.parentTaskId, node); + groupedListener.onFailure(exp); } }); } } - private void sendRemoveBanRequest(DiscoveryNodes nodes, BanParentTaskRequest request) { - for (ObjectObjectCursor node : nodes.getNodes()) { - logger.debug("Sending remove ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node.key); - transportService.sendRequest(node.value, BAN_PARENT_ACTION_NAME, request, EmptyTransportResponseHandler - .INSTANCE_SAME); - } - } - - private static class BanLock { - private final Runnable finish; - private final AtomicInteger counter; - private final int nodesSize; - - BanLock(int nodesSize, Runnable finish) { - counter = new AtomicInteger(0); - this.finish = finish; - this.nodesSize = nodesSize; - } - - public void onBanSet() { - if (counter.decrementAndGet() == 0) { - finish(); - } - } - - public void onTaskFinished() { - if (counter.addAndGet(nodesSize) == 0) { - finish(); - } + private void removeBanOnNodes(CancellableTask task, Collection childNodes) { + final BanParentTaskRequest request = + BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(clusterService.localNode().getId(), task.getId())); + for (DiscoveryNode node : childNodes) { + logger.trace("Sending remove ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node); + transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, request, EmptyTransportResponseHandler.INSTANCE_SAME); } - - public void finish() { - finish.run(); - } - } private static class BanParentTaskRequest extends TransportRequest { diff --git a/server/src/main/java/org/elasticsearch/action/support/TransportAction.java b/server/src/main/java/org/elasticsearch/action/support/TransportAction.java index abba923f72a41..c99e19a83afad 100644 --- a/server/src/main/java/org/elasticsearch/action/support/TransportAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/TransportAction.java @@ -24,7 +24,10 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskListener; import org.elasticsearch.tasks.TaskManager; @@ -47,6 +50,14 @@ protected TransportAction(String actionName, ActionFilters actionFilters, TaskMa this.taskManager = taskManager; } + private Releasable registerChildNode(TaskId parentTask) { + if (parentTask.isSet()) { + return taskManager.registerChildNode(parentTask.getId(), taskManager.localNode()); + } else { + return () -> {}; + } + } + /** * Use this method when the transport action call should result in creation of a new task associated with the call. * @@ -60,12 +71,14 @@ public final Task execute(Request request, ActionListener listener) { * task. That just seems like too many objects. Thus the two versions of * this method. */ + final Releasable unregisterChildNode = registerChildNode(request.getParentTask()); Task task = taskManager.register("transport", actionName, request); execute(task, request, new ActionListener() { @Override public void onResponse(Response response) { try { taskManager.unregister(task); + unregisterChildNode.close(); } finally { listener.onResponse(response); } @@ -74,7 +87,7 @@ public void onResponse(Response response) { @Override public void onFailure(Exception e) { try { - taskManager.unregister(task); + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); } finally { listener.onFailure(e); } @@ -88,12 +101,14 @@ public void onFailure(Exception e) { * {@link TaskListener} which listens for the completion of the action. */ public final Task execute(Request request, TaskListener listener) { + final Releasable unregisterChildNode = registerChildNode(request.getParentTask()); Task task = taskManager.register("transport", actionName, request); execute(task, request, new ActionListener() { @Override public void onResponse(Response response) { try { taskManager.unregister(task); + unregisterChildNode.close(); } finally { listener.onResponse(task, response); } @@ -102,7 +117,7 @@ public void onResponse(Response response) { @Override public void onFailure(Exception e) { try { - taskManager.unregister(task); + Releasables.close(unregisterChildNode, () -> taskManager.unregister(task)); } finally { listener.onFailure(task, e); } diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestCancelTasksAction.java b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestCancelTasksAction.java index 99321627cf7fb..10dd78361c8f2 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestCancelTasksAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/admin/cluster/RestCancelTasksAction.java @@ -68,6 +68,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC cancelTasksRequest.setNodes(nodesIds); cancelTasksRequest.setActions(actions); cancelTasksRequest.setParentTaskId(parentTaskId); + cancelTasksRequest.setWaitForCompletion(request.paramAsBoolean("wait_for_completion", cancelTasksRequest.waitForCompletion())); return channel -> client.admin().cluster().cancelTasks(cancelTasksRequest, listTasksResponseListener(nodesInCluster, groupBy, channel)); } diff --git a/server/src/main/java/org/elasticsearch/tasks/CancellableTask.java b/server/src/main/java/org/elasticsearch/tasks/CancellableTask.java index 1d43076305ccd..57c09d7f22723 100644 --- a/server/src/main/java/org/elasticsearch/tasks/CancellableTask.java +++ b/server/src/main/java/org/elasticsearch/tasks/CancellableTask.java @@ -22,14 +22,15 @@ import org.elasticsearch.common.Nullable; import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicBoolean; /** * A task that can be canceled */ public abstract class CancellableTask extends Task { - private final AtomicReference reason = new AtomicReference<>(); + private volatile String reason; + private final AtomicBoolean cancelled = new AtomicBoolean(false); public CancellableTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers) { super(id, type, action, description, parentTaskId, headers); @@ -40,8 +41,10 @@ public CancellableTask(long id, String type, String action, String description, */ final void cancel(String reason) { assert reason != null; - this.reason.compareAndSet(null, reason); - onCancelled(); + if (cancelled.compareAndSet(false, true)) { + this.reason = reason; + onCancelled(); + } } /** @@ -58,15 +61,15 @@ public boolean cancelOnParentLeaving() { public abstract boolean shouldCancelChildrenOnCancellation(); public boolean isCancelled() { - return reason.get() != null; + return cancelled.get(); } /** * The reason the task was cancelled or null if it hasn't been cancelled. */ @Nullable - public String getReasonCancelled() { - return reason.get(); + public final String getReasonCancelled() { + return reason; } /** diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java index 92c86a04cdb5d..adc3ce6661009 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -19,6 +19,8 @@ package org.elasticsearch.tasks; +import com.carrotsearch.hppc.ObjectIntHashMap; +import com.carrotsearch.hppc.ObjectIntMap; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; @@ -31,6 +33,8 @@ import org.elasticsearch.cluster.ClusterStateApplier; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; @@ -41,6 +45,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; @@ -50,6 +55,9 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiConsumer; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import static org.elasticsearch.common.unit.TimeValue.timeValueMillis; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEADER_SIZE; @@ -78,7 +86,7 @@ public class TaskManager implements ClusterStateApplier { private TaskResultsService taskResultsService; - private DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES; + private volatile DiscoveryNodes lastDiscoveryNodes = DiscoveryNodes.EMPTY_NODES; private final ByteSizeValue maxHeaderSize; @@ -132,13 +140,14 @@ private void registerCancellableTask(Task task) { CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask); CancellableTaskHolder oldHolder = cancellableTasks.put(task.getId(), holder); assert oldHolder == null; - // Check if this task was banned before we start it + // Check if this task was banned before we start it. The empty check is used to avoid + // computing the hash code of the parent taskId as most of the time banedParents is empty. if (task.getParentTaskId().isSet() && banedParents.isEmpty() == false) { String reason = banedParents.get(task.getParentTaskId()); if (reason != null) { try { holder.cancel(reason); - throw new IllegalStateException("Task cancelled before it started: " + reason); + throw new TaskCancelledException("Task cancelled before it started: " + reason); } finally { // let's clean up the registration unregister(task); @@ -150,18 +159,18 @@ private void registerCancellableTask(Task task) { /** * Cancels a task *

- * Returns true if cancellation was started successful, null otherwise. - * * After starting cancellation on the parent task, the task manager tries to cancel all children tasks * of the current task. Once cancellation of the children tasks is done, the listener is triggered. + * If the task is completed or unregistered from TaskManager, then the listener is called immediately. */ - public boolean cancel(CancellableTask task, String reason, Runnable listener) { + public void cancel(CancellableTask task, String reason, Runnable listener) { CancellableTaskHolder holder = cancellableTasks.get(task.getId()); if (holder != null) { logger.trace("cancelling task with id {}", task.getId()); - return holder.cancel(reason, listener); + holder.cancel(reason, listener); + } else { + listener.run(); } - return false; } /** @@ -182,6 +191,23 @@ public Task unregister(Task task) { } } + /** + * Register a node on which a child task will execute. The returned {@link Releasable} must be called + * to unregister the child node once the child task is completed or failed. + */ + public Releasable registerChildNode(long taskId, DiscoveryNode node) { + final CancellableTaskHolder holder = cancellableTasks.get(taskId); + if (holder != null) { + holder.registerChildNode(node); + return Releasables.releaseOnce(() -> holder.unregisterChildNode(node)); + } + return () -> {}; + } + + public DiscoveryNode localNode() { + return lastDiscoveryNodes.getLocalNode(); + } + /** * Stores the task failure */ @@ -339,6 +365,31 @@ public void removeBan(TaskId parentTaskId) { banedParents.remove(parentTaskId); } + // for testing + public boolean childTasksCancelledOrBanned(TaskId parentTaskId) { + if (banedParents.containsKey(parentTaskId)) { + return true; + } + return cancellableTasks.values().stream().noneMatch(task -> task.hasParent(parentTaskId)); + } + + /** + * Start rejecting new child requests as the parent task was cancelled. + * + * @param taskId the parent task id + * @param onChildTasksCompleted called when all child tasks are completed or failed + * @return the set of current nodes that have outstanding child tasks + */ + public Collection startBanOnChildrenNodes(long taskId, Runnable onChildTasksCompleted) { + final CancellableTaskHolder holder = cancellableTasks.get(taskId); + if (holder != null) { + return holder.startBan(onChildTasksCompleted); + } else { + onChildTasksCompleted.run(); + return Collections.emptySet(); + } + } + @Override public void applyClusterState(ClusterChangedEvent event) { lastDiscoveryNodes = event.state().getNodes(); @@ -388,74 +439,76 @@ public void waitForTaskCompletion(Task task, long untilInNanos) { } private static class CancellableTaskHolder { - - private static final String TASK_FINISHED_MARKER = "task finished"; - private final CancellableTask task; - - private volatile String cancellationReason = null; - - private volatile Runnable cancellationListener = null; + private boolean finished = false; + private List cancellationListeners = null; + private ObjectIntMap childTasksPerNode = null; + private boolean banChildren = false; + private List childTaskCompletedListeners = null; CancellableTaskHolder(CancellableTask task) { this.task = task; } - /** - * Marks task as cancelled. - *

- * Returns true if cancellation was successful, false otherwise. - */ - public boolean cancel(String reason, Runnable listener) { - final boolean cancelled; + void cancel(String reason, Runnable listener) { + final Runnable toRun; synchronized (this) { - assert reason != null; - if (cancellationReason == null) { - cancellationReason = reason; - cancellationListener = listener; - cancelled = true; + if (finished) { + assert cancellationListeners == null; + toRun = listener; } else { - // Already cancelled by somebody else - cancelled = false; + toRun = () -> {}; + if (listener != null) { + if (cancellationListeners == null) { + cancellationListeners = new ArrayList<>(); + } + cancellationListeners.add(listener); + } } } - if (cancelled) { + try { task.cancel(reason); + } finally { + if (toRun != null) { + toRun.run(); + } } - return cancelled; } - /** - * Marks task as cancelled. - *

- * Returns true if cancellation was successful, false otherwise. - */ - public boolean cancel(String reason) { - return cancel(reason, null); + void cancel(String reason) { + task.cancel(reason); } /** * Marks task as finished. */ public void finish() { - Runnable listener = null; + final List listeners; synchronized (this) { - if (cancellationReason != null) { - // The task was cancelled, we need to notify the listener - if (cancellationListener != null) { - listener = cancellationListener; - cancellationListener = null; - } + this.finished = true; + if (cancellationListeners != null) { + listeners = cancellationListeners; + cancellationListeners = null; } else { - cancellationReason = TASK_FINISHED_MARKER; + listeners = Collections.emptyList(); } } // We need to call the listener outside of the synchronised section to avoid potential bottle necks // in the listener synchronization - if (listener != null) { - listener.run(); - } + notifyListeners(listeners); + } + private void notifyListeners(List listeners) { + assert Thread.holdsLock(this) == false; + Exception rootException = null; + for (Runnable listener : listeners) { + try { + listener.run(); + } catch (RuntimeException inner) { + rootException = ExceptionsHelper.useOrSuppress(rootException, inner); + } + } + ExceptionsHelper.reThrowIfNotNull(rootException); } public boolean hasParent(TaskId parentTaskId) { @@ -465,6 +518,58 @@ public boolean hasParent(TaskId parentTaskId) { public CancellableTask getTask() { return task; } + + synchronized void registerChildNode(DiscoveryNode node) { + if (banChildren) { + throw new TaskCancelledException("The parent task was cancelled, shouldn't start any child tasks"); + } + if (childTasksPerNode == null) { + childTasksPerNode = new ObjectIntHashMap<>(); + } + childTasksPerNode.addTo(node, 1); + } + + void unregisterChildNode(DiscoveryNode node) { + final List listeners; + synchronized (this) { + if (childTasksPerNode.addTo(node, -1) == 0) { + childTasksPerNode.remove(node); + } + if (childTasksPerNode.isEmpty() && this.childTaskCompletedListeners != null) { + listeners = childTaskCompletedListeners; + childTaskCompletedListeners = null; + } else { + listeners = Collections.emptyList(); + } + } + notifyListeners(listeners); + } + + Set startBan(Runnable onChildTasksCompleted) { + final Set pendingChildNodes; + final Runnable toRun; + synchronized (this) { + banChildren = true; + if (childTasksPerNode == null) { + pendingChildNodes = Collections.emptySet(); + } else { + pendingChildNodes = StreamSupport.stream(childTasksPerNode.spliterator(), false) + .map(e -> e.key).collect(Collectors.toSet()); + } + if (pendingChildNodes.isEmpty()) { + assert childTaskCompletedListeners == null; + toRun = onChildTasksCompleted; + } else { + toRun = () -> {}; + if (childTaskCompletedListeners == null) { + childTaskCompletedListeners = new ArrayList<>(); + } + childTaskCompletedListeners.add(onChildTasksCompleted); + } + } + toRun.run(); + return pendingChildNodes; + } } } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportService.java b/server/src/main/java/org/elasticsearch/transport/TransportService.java index 66ef8dd663df2..8da2c1380b64d 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportService.java @@ -36,6 +36,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.settings.ClusterSettings; @@ -618,6 +619,34 @@ public final void sendRequest(final Transport.Conn final TransportRequestOptions options, TransportResponseHandler handler) { try { + if (request.getParentTask().isSet()) { + // TODO: capture the connection instead so that we can cancel child tasks on the remote connections. + final Releasable unregisterChildNode = taskManager.registerChildNode(request.getParentTask().getId(), connection.getNode()); + final TransportResponseHandler delegate = handler; + handler = new TransportResponseHandler() { + @Override + public void handleResponse(T response) { + unregisterChildNode.close(); + delegate.handleResponse(response); + } + + @Override + public void handleException(TransportException exp) { + unregisterChildNode.close(); + delegate.handleException(exp); + } + + @Override + public String executor() { + return delegate.executor(); + } + + @Override + public T read(StreamInput in) throws IOException { + return delegate.read(in); + } + }; + } asyncSender.sendRequest(connection, action, request, options, handler); } catch (final Exception ex) { // the caller might not handle this so we invoke the handler diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java new file mode 100644 index 0000000000000..ee81e9f3f9439 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksIT.java @@ -0,0 +1,475 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.action.admin.cluster.node.tasks; + +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.LatchedActionListener; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; +import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskManager; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportException; +import org.elasticsearch.transport.TransportResponse; +import org.elasticsearch.transport.TransportResponseHandler; +import org.elasticsearch.transport.TransportService; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.either; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; + +public class CancellableTasksIT extends ESIntegTestCase { + static final Map arrivedLatches = ConcurrentCollections.newConcurrentMap(); + static final Map beforeExecuteLatches = ConcurrentCollections.newConcurrentMap(); + static final Map completedLatches = ConcurrentCollections.newConcurrentMap(); + + @Before + public void resetTestStates() { + arrivedLatches.clear(); + beforeExecuteLatches.clear(); + completedLatches.clear(); + } + + List setupChildRequests(Set nodes) { + int numRequests = randomIntBetween(1, 30); + List childRequests = new ArrayList<>(); + for (int i = 0; i < numRequests; i++) { + ChildRequest req = new ChildRequest(i, randomFrom(nodes)); + childRequests.add(req); + arrivedLatches.put(req, new CountDownLatch(1)); + beforeExecuteLatches.put(req, new CountDownLatch(1)); + completedLatches.put(req, new CountDownLatch(1)); + } + return childRequests; + } + + public void testBanOnlyNodesWithOutstandingChildTasks() throws Exception { + if (randomBoolean()) { + internalCluster().startNodes(randomIntBetween(1, 3)); + } + Set nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); + List childRequests = setupChildRequests(nodes); + ActionFuture mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests)); + List completedRequests = randomSubsetOf(between(0, childRequests.size() - 1), childRequests); + for (ChildRequest req : completedRequests) { + beforeExecuteLatches.get(req).countDown(); + completedLatches.get(req).await(); + } + List outstandingRequests = childRequests.stream(). + filter(r -> completedRequests.contains(r) == false) + .collect(Collectors.toList()); + for (ChildRequest req : outstandingRequests) { + arrivedLatches.get(req).await(); + } + TaskId taskId = getMainTaskId(); + ActionFuture cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId) + .waitForCompletion(true).execute(); + Set nodesWithOutstandingChildTask = outstandingRequests.stream().map(r -> r.targetNode).collect(Collectors.toSet()); + assertBusy(() -> { + for (DiscoveryNode node : nodes) { + TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager(); + if (nodesWithOutstandingChildTask.contains(node)) { + assertThat(taskManager.getBanCount(), equalTo(1)); + } else { + assertThat(taskManager.getBanCount(), equalTo(0)); + } + } + }); + // failed to spawn child tasks after cancelling + if (randomBoolean()) { + DiscoveryNode nodeWithParentTask = nodes.stream().filter(n -> n.getId().equals(taskId.getNodeId())).findFirst().get(); + TransportMainAction mainAction = internalCluster().getInstance(TransportMainAction.class, nodeWithParentTask.getName()); + PlainActionFuture future = new PlainActionFuture<>(); + ChildRequest req = new ChildRequest(-1, randomFrom(nodes)); + completedLatches.put(req, new CountDownLatch(1)); + mainAction.startChildTask(taskId, req, future); + TransportException te = expectThrows(TransportException.class, future::actionGet); + assertThat(te.getCause(), instanceOf(TaskCancelledException.class)); + assertThat(te.getCause().getMessage(), equalTo("The parent task was cancelled, shouldn't start any child tasks")); + } + for (ChildRequest req : outstandingRequests) { + beforeExecuteLatches.get(req).countDown(); + } + cancelFuture.actionGet(); + waitForMainTask(mainTaskFuture); + assertBusy(() -> { + for (DiscoveryNode node : nodes) { + TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager(); + assertThat(taskManager.getBanCount(), equalTo(0)); + } + }); + } + + public void testCancelTaskMultipleTimes() throws Exception { + Set nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); + List childRequests = setupChildRequests(nodes); + ActionFuture mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests)); + for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) { + arrivedLatches.get(r).await(); + } + TaskId taskId = getMainTaskId(); + ActionFuture cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId) + .waitForCompletion(true).execute(); + ensureChildTasksCancelledOrBanned(taskId); + if (randomBoolean()) { + CancelTasksResponse resp = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).waitForCompletion(false).get(); + assertThat(resp.getTaskFailures(), empty()); + assertThat(resp.getNodeFailures(), empty()); + } + assertFalse(cancelFuture.isDone()); + for (ChildRequest r : childRequests) { + beforeExecuteLatches.get(r).countDown(); + } + assertThat(cancelFuture.actionGet().getTaskFailures(), empty()); + assertThat(cancelFuture.actionGet().getTaskFailures(), empty()); + waitForMainTask(mainTaskFuture); + CancelTasksResponse cancelError = client().admin().cluster().prepareCancelTasks() + .setTaskId(taskId).waitForCompletion(randomBoolean()).get(); + assertThat(cancelError.getNodeFailures(), hasSize(1)); + final Throwable notFound = ExceptionsHelper.unwrap(cancelError.getNodeFailures().get(0), ResourceNotFoundException.class); + assertThat(notFound.getMessage(), equalTo("task [" + taskId + "] is not found")); + } + + public void testDoNotWaitForCompletion() throws Exception { + Set nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet()); + List childRequests = setupChildRequests(nodes); + ActionFuture mainTaskFuture = client().execute(TransportMainAction.ACTION, new MainRequest(childRequests)); + for (ChildRequest r : randomSubsetOf(between(1, childRequests.size()), childRequests)) { + arrivedLatches.get(r).await(); + } + TaskId taskId = getMainTaskId(); + boolean waitForCompletion = randomBoolean(); + ActionFuture cancelFuture = client().admin().cluster().prepareCancelTasks().setTaskId(taskId) + .waitForCompletion(waitForCompletion).execute(); + if (waitForCompletion) { + assertFalse(cancelFuture.isDone()); + } else { + assertBusy(() -> assertTrue(cancelFuture.isDone())); + } + for (ChildRequest r : childRequests) { + beforeExecuteLatches.get(r).countDown(); + } + waitForMainTask(mainTaskFuture); + } + + TaskId getMainTaskId() { + ListTasksResponse listTasksResponse = client().admin().cluster().prepareListTasks() + .setActions(TransportMainAction.ACTION.name()).setDetailed(true).get(); + assertThat(listTasksResponse.getTasks(), hasSize(1)); + return listTasksResponse.getTasks().get(0).getTaskId(); + } + + void waitForMainTask(ActionFuture mainTask) { + try { + mainTask.actionGet(); + } catch (Exception e) { + final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class); + assertThat(cause.getMessage(), + either(equalTo("The parent task was cancelled, shouldn't start any child tasks")) + .or(containsString("Task cancelled before it started:"))); + } + } + + public static class MainRequest extends ActionRequest { + final List childRequests; + + public MainRequest(List childRequests) { + this.childRequests = childRequests; + } + + public MainRequest(StreamInput in) throws IOException { + super(in); + this.childRequests = in.readList(ChildRequest::new); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeList(childRequests); + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return true; + } + }; + } + } + + public static class MainResponse extends ActionResponse { + public MainResponse() { + } + + public MainResponse(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + } + } + + public static class ChildRequest extends ActionRequest { + final int id; + final DiscoveryNode targetNode; + + public ChildRequest(int id, DiscoveryNode targetNode) { + this.id = id; + this.targetNode = targetNode; + } + + + public ChildRequest(StreamInput in) throws IOException { + super(in); + this.id = in.readInt(); + this.targetNode = new DiscoveryNode(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeInt(id); + targetNode.writeTo(out); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public String getDescription() { + return "childTask[" + id + "]"; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + if (randomBoolean()) { + boolean shouldCancelChildrenOnCancellation = randomBoolean(); + return new CancellableTask(id, type, action, getDescription(), parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return shouldCancelChildrenOnCancellation; + } + }; + } else { + return super.createTask(id, type, action, parentTaskId, headers); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ChildRequest that = (ChildRequest) o; + return id == that.id && targetNode.equals(that.targetNode); + } + + @Override + public int hashCode() { + return Objects.hash(id, targetNode); + } + } + + public static class ChildResponse extends ActionResponse { + public ChildResponse() { + } + + public ChildResponse(StreamInput in) throws IOException { + super(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + } + } + + public static class TransportMainAction extends HandledTransportAction { + + public static ActionType ACTION = new ActionType<>("internal::main_action", MainResponse::new); + private final TransportService transportService; + private final NodeClient client; + + @Inject + public TransportMainAction(TransportService transportService, NodeClient client, ActionFilters actionFilters) { + super(ACTION.name(), transportService, actionFilters, MainRequest::new, ThreadPool.Names.GENERIC); + this.transportService = transportService; + this.client = client; + } + + @Override + protected void doExecute(Task task, MainRequest request, ActionListener listener) { + GroupedActionListener groupedListener = + new GroupedActionListener<>(ActionListener.map(listener, r -> new MainResponse()), request.childRequests.size()); + for (ChildRequest childRequest : request.childRequests) { + TaskId parentTaskId = new TaskId(client.getLocalNodeId(), task.getId()); + startChildTask(parentTaskId, childRequest, groupedListener); + } + } + + protected void startChildTask(TaskId parentTaskId, ChildRequest childRequest, ActionListener listener) { + childRequest.setParentTask(parentTaskId); + final CountDownLatch completeLatch = completedLatches.get(childRequest); + LatchedActionListener latchedListener = new LatchedActionListener<>(listener, completeLatch); + transportService.getThreadPool().generic().execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + + @Override + protected void doRun() { + if (client.getLocalNodeId().equals(childRequest.targetNode.getId()) && randomBoolean()) { + try { + client.executeLocally(TransportChildAction.ACTION, childRequest, latchedListener); + } catch (TaskCancelledException e) { + latchedListener.onFailure(new TransportException(e)); + } + } else { + transportService.sendRequest(childRequest.targetNode, TransportChildAction.ACTION.name(), childRequest, + new TransportResponseHandler() { + + @Override + public void handleResponse(ChildResponse response) { + latchedListener.onResponse(new ChildResponse()); + } + + @Override + public void handleException(TransportException exp) { + latchedListener.onFailure(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public ChildResponse read(StreamInput in) throws IOException { + return new ChildResponse(in); + } + }); + } + } + }); + } + } + + public static class TransportChildAction extends HandledTransportAction { + public static ActionType ACTION = new ActionType<>("internal:child_action", ChildResponse::new); + private final TransportService transportService; + + + @Inject + public TransportChildAction(TransportService transportService, ActionFilters actionFilters) { + super(ACTION.name(), transportService, actionFilters, ChildRequest::new, ThreadPool.Names.GENERIC); + this.transportService = transportService; + } + + @Override + protected void doExecute(Task task, ChildRequest request, ActionListener listener) { + assertThat(request.targetNode, equalTo(transportService.getLocalNode())); + arrivedLatches.get(request).countDown(); + transportService.getThreadPool().executor(ThreadPool.Names.GENERIC).execute(ActionRunnable.supply(listener, () -> { + beforeExecuteLatches.get(request).await(); + return new ChildResponse(); + })); + } + } + + public static class TaskPlugin extends Plugin implements ActionPlugin { + @Override + public List> getActions() { + return Arrays.asList( + new ActionHandler<>(TransportMainAction.ACTION, TransportMainAction.class), + new ActionHandler<>(TransportChildAction.ACTION, TransportChildAction.class) + ); + } + } + + @Override + protected Collection> nodePlugins() { + final List> plugins = new ArrayList<>(super.nodePlugins()); + plugins.add(TaskPlugin.class); + return plugins; + } + + /** + * Ensures that all outstanding child tasks of the given parent task are banned or being cancelled. + */ + protected static void ensureChildTasksCancelledOrBanned(TaskId taskId) throws Exception { + assertBusy(() -> { + for (String nodeName : internalCluster().getNodeNames()) { + final TaskManager taskManager = internalCluster().getInstance(TransportService.class, nodeName).getTaskManager(); + assertTrue(taskManager.childTasksCancelledOrBanned(taskId)); + } + }); + } +} diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java index 03fa65fde4f6a..f7be1080c2df0 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java @@ -21,6 +21,7 @@ import com.carrotsearch.randomizedtesting.RandomizedContext; import com.carrotsearch.randomizedtesting.generators.RandomNumbers; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksAction; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequest; @@ -39,16 +40,21 @@ import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskInfo; +import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Random; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Phaser; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicIntegerArray; import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.test.ClusterServiceUtils.setState; @@ -189,19 +195,24 @@ protected NodeResponse nodeOperation(CancellableNodeRequest request) { } } - private Task startCancellableTestNodesAction(boolean waitForActionToStart, int blockedNodesCount, ActionListener - listener) throws InterruptedException { - return startCancellableTestNodesAction(waitForActionToStart, randomSubsetOf(blockedNodesCount, testNodes), new - CancellableNodesRequest("Test Request"), listener); + private Task startCancellableTestNodesAction(boolean waitForActionToStart, int runNodesCount, int blockedNodesCount, + ActionListener listener) throws InterruptedException { + List runOnNodes = randomSubsetOf(runNodesCount, testNodes); + + return startCancellableTestNodesAction(waitForActionToStart, runOnNodes, randomSubsetOf(blockedNodesCount, runOnNodes), new + CancellableNodesRequest("Test Request",runOnNodes.stream().map(TestNode::getNodeId).toArray(String[]::new)), listener); } - private Task startCancellableTestNodesAction(boolean waitForActionToStart, Collection blockOnNodes, CancellableNodesRequest - request, ActionListener listener) throws InterruptedException { - CountDownLatch actionLatch = waitForActionToStart ? new CountDownLatch(nodesCount) : null; + private Task startCancellableTestNodesAction(boolean waitForActionToStart, List runOnNodes, + Collection blockOnNodes, CancellableNodesRequest + request, ActionListener listener) throws InterruptedException { + CountDownLatch actionLatch = waitForActionToStart ? new CountDownLatch(runOnNodes.size()) : null; CancellableTestNodesAction[] actions = new CancellableTestNodesAction[nodesCount]; for (int i = 0; i < testNodes.length; i++) { boolean shouldBlock = blockOnNodes.contains(testNodes[i]); - logger.info("The action in the node [{}] should block: [{}]", testNodes[i].getNodeId(), shouldBlock); + boolean shouldRun = runOnNodes.contains(testNodes[i]); + logger.info("The action on the node [{}] should run: [{}] should block: [{}]", testNodes[i].getNodeId(), shouldRun, + shouldBlock); actions[i] = new CancellableTestNodesAction("internal:testAction", threadPool, testNodes[i] .clusterService, testNodes[i].transportService, shouldBlock, actionLatch); } @@ -222,20 +233,22 @@ public void testBasicTaskCancellation() throws Exception { logger.info("waitForActionToStart is set to {}", waitForActionToStart); final AtomicReference responseReference = new AtomicReference<>(); final AtomicReference throwableReference = new AtomicReference<>(); - int blockedNodesCount = randomIntBetween(0, nodesCount); - Task mainTask = startCancellableTestNodesAction(waitForActionToStart, blockedNodesCount, new ActionListener() { - @Override - public void onResponse(NodesResponse listTasksResponse) { - responseReference.set(listTasksResponse); - responseLatch.countDown(); - } + int runNodesCount = randomIntBetween(1, nodesCount); + int blockedNodesCount = randomIntBetween(0, runNodesCount); + Task mainTask = startCancellableTestNodesAction(waitForActionToStart, runNodesCount, blockedNodesCount, + new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + responseLatch.countDown(); + } - @Override - public void onFailure(Exception e) { - throwableReference.set(e); - responseLatch.countDown(); - } - }); + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + responseLatch.countDown(); + } + }); // Cancel main task CancelTasksRequest request = new CancelTasksRequest(); @@ -255,12 +268,12 @@ public void onFailure(Exception e) { // Make sure that the request was successful assertNull(throwableReference.get()); assertNotNull(responseReference.get()); - assertEquals(nodesCount, responseReference.get().getNodes().size()); + assertEquals(runNodesCount, responseReference.get().getNodes().size()); assertEquals(0, responseReference.get().failureCount()); } else { // We canceled the request, in this case it should have fail, but we should get partial response assertNull(throwableReference.get()); - assertEquals(nodesCount, responseReference.get().failureCount() + responseReference.get().getNodes().size()); + assertEquals(runNodesCount, responseReference.get().failureCount() + responseReference.get().getNodes().size()); // and we should have at least as many failures as the number of blocked operations // (we might have cancelled some non-blocked operations before they even started and that's ok) assertThat(responseReference.get().failureCount(), greaterThanOrEqualTo(blockedNodesCount)); @@ -295,19 +308,23 @@ public void testChildTasksCancellation() throws Exception { CountDownLatch responseLatch = new CountDownLatch(1); final AtomicReference responseReference = new AtomicReference<>(); final AtomicReference throwableReference = new AtomicReference<>(); - Task mainTask = startCancellableTestNodesAction(true, nodesCount, new ActionListener() { - @Override - public void onResponse(NodesResponse listTasksResponse) { - responseReference.set(listTasksResponse); - responseLatch.countDown(); - } + int runNodesCount = randomIntBetween(1, nodesCount); + int blockedNodesCount = randomIntBetween(0, runNodesCount); + Task mainTask = startCancellableTestNodesAction(true, runNodesCount, blockedNodesCount, + new ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + responseLatch.countDown(); + } - @Override - public void onFailure(Exception e) { - throwableReference.set(e); - responseLatch.countDown(); + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + responseLatch.countDown(); + } } - }); + ); // Cancel all child tasks without cancelling the main task, which should quit on its own CancelTasksRequest request = new CancelTasksRequest(); @@ -320,8 +337,10 @@ public void onFailure(Exception e) { // Awaiting for the main task to finish responseLatch.await(); - // Should have cancelled tasks on all nodes - assertThat(response.getTasks().size(), equalTo(testNodes.length)); + // Should have cancelled tasks at least on all nodes where it was blocked + assertThat(response.getTasks().size(), lessThanOrEqualTo(runNodesCount)); + // but may also encounter some nodes where it was still running + assertThat(response.getTasks().size(), greaterThanOrEqualTo(blockedNodesCount)); assertBusy(() -> { // Make sure that main task is no longer running @@ -343,20 +362,22 @@ public void testTaskCancellationOnCoordinatingNodeLeavingTheCluster() throws Exc // We shouldn't block on the first node since it's leaving the cluster anyway so it doesn't matter List blockOnNodes = randomSubsetOf(blockedNodesCount, Arrays.copyOfRange(testNodes, 1, nodesCount)); - Task mainTask = startCancellableTestNodesAction(true, blockOnNodes, new CancellableNodesRequest("Test Request"), new - ActionListener() { - @Override - public void onResponse(NodesResponse listTasksResponse) { - responseReference.set(listTasksResponse); - responseLatch.countDown(); - } - - @Override - public void onFailure(Exception e) { - throwableReference.set(e); - responseLatch.countDown(); - } - }); + Task mainTask = startCancellableTestNodesAction(true, Arrays.asList(testNodes), blockOnNodes, + new CancellableNodesRequest("Test Request"), new + ActionListener() { + @Override + public void onResponse(NodesResponse listTasksResponse) { + responseReference.set(listTasksResponse); + responseLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + throwableReference.set(e); + responseLatch.countDown(); + } + } + ); String mainNode = testNodes[0].getNodeId(); @@ -415,6 +436,63 @@ public void onFailure(Exception e) { } + public void testNonExistingTaskCancellation() throws Exception { + setupTestNodes(Settings.EMPTY); + connectNodes(testNodes); + + // Cancel a task that doesn't exist + CancelTasksRequest request = new CancelTasksRequest(); + request.setReason("Testing Cancellation"); + request.setActions("do-not-match-anything"); + request.setNodes( + randomSubsetOf(randomIntBetween(1,testNodes.length - 1), testNodes).stream().map(TestNode::getNodeId).toArray(String[]::new)); + // And send the cancellation request to a random node + CancelTasksResponse response = ActionTestUtils.executeBlocking( + testNodes[randomIntBetween(1, testNodes.length - 1)].transportCancelTasksAction, request); + + // Shouldn't have cancelled anything + assertThat(response.getTasks().size(), equalTo(0)); + + assertBusy(() -> { + // Make sure that main task is no longer running + ListTasksResponse listTasksResponse = ActionTestUtils.executeBlocking( + testNodes[randomIntBetween(0, testNodes.length - 1)].transportListTasksAction, + new ListTasksRequest().setActions(CancelTasksAction.NAME + "*")); + assertEquals(0, listTasksResponse.getTasks().size()); + }); + } + + public void testCancelConcurrently() throws Exception { + setupTestNodes(Settings.EMPTY); + final TaskManager taskManager = testNodes[0].transportService.getTaskManager(); + int numTasks = randomIntBetween(1, 10); + List tasks = new ArrayList<>(numTasks); + for (int i = 0; i < numTasks; i++) { + tasks.add((CancellableTask) taskManager.register("type-" + i, "action-" + i, new CancellableNodeRequest())); + } + Thread[] threads = new Thread[randomIntBetween(1, 8)]; + AtomicIntegerArray notified = new AtomicIntegerArray(threads.length); + Phaser phaser = new Phaser(threads.length + 1); + final CancellableTask cancellingTask = randomFrom(tasks); + for (int i = 0; i < threads.length; i++) { + int idx = i; + threads[i] = new Thread(() -> { + phaser.arriveAndAwaitAdvance(); + taskManager.cancel(cancellingTask, "test", () -> assertTrue(notified.compareAndSet(idx, 0, 1))); + }); + threads[i].start(); + } + phaser.arriveAndAwaitAdvance(); + taskManager.unregister(cancellingTask); + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + assertThat(notified.get(i), equalTo(1)); + } + AtomicBoolean called = new AtomicBoolean(); + taskManager.cancel(cancellingTask, "test", () -> assertTrue(called.compareAndSet(false, true))); + assertTrue(called.get()); + } + private static void debugDelay(String name) { // Introduce an additional pseudo random repeatable race conditions String delayName = RandomizedContext.current().getRunnerSeedAsString() + ":" + name;