Skip to content

Commit

Permalink
Optimize Task Manager parent bans
Browse files Browse the repository at this point in the history
Reduces network traffic when cancelling parent tasks. Instead of
broadcasting a parent ban request to all nodes, we now keep track of
the nodes that have child tasks and only send ban requests to those nodes.

Relates to elastic#50990
  • Loading branch information
imotov committed Jan 17, 2020
1 parent 394f09c commit 0556bc4
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@

package org.elasticsearch.action.admin.cluster.node.tasks.cancel;

import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
Expand All @@ -47,7 +45,10 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

Expand Down Expand Up @@ -105,16 +106,17 @@ protected void processTasks(CancelTasksRequest request, Consumer<CancellableTask
@Override
protected synchronized void taskOperation(CancelTasksRequest request, CancellableTask cancellableTask,
ActionListener<TaskInfo> listener) {
String nodeId = clusterService.localNode().getId();
DiscoveryNode localNode = clusterService.localNode();
String nodeId = localNode.getId();
final boolean canceled;
if (cancellableTask.shouldCancelChildrenOnCancellation()) {
DiscoveryNodes childNodes = clusterService.state().nodes();
final BanLock banLock = new BanLock(childNodes.getSize(), () -> removeBanOnNodes(cancellableTask, childNodes));
Set<DiscoveryNode> childNodes = withLocalNode(localNode, taskManager.startBanOnChildrenNodes(cancellableTask.getId()));
final BanLock banLock = new BanLock(childNodes.size(), () -> removeBanOnNodes(cancellableTask, childNodes));
canceled = taskManager.cancel(cancellableTask, request.getReason(), banLock::onTaskFinished);
if (canceled) {
// In case the task has some child tasks, we need to wait until the ban is set on all nodes
logger.trace("cancelling task {} on child nodes", cancellableTask.getId());
AtomicInteger responses = new AtomicInteger(childNodes.getSize());
logger.info("cancelling task {} on child nodes", cancellableTask.getId());
AtomicInteger responses = new AtomicInteger(childNodes.size());
List<Exception> failures = new ArrayList<>();
setBanOnNodes(request.getReason(), cancellableTask, childNodes, new ActionListener<Void>() {
@Override
Expand Down Expand Up @@ -145,11 +147,12 @@ private void processResponse() {
}
});
}
// }
} else {
canceled = taskManager.cancel(cancellableTask, request.getReason(),
() -> listener.onResponse(cancellableTask.taskInfo(nodeId, false)));
if (canceled) {
logger.trace("task {} doesn't have any children that should be cancelled", cancellableTask.getId());
logger.trace("task {} doesn't require child task cancellation", cancellableTask.getId());
}
}
if (canceled == false) {
Expand All @@ -158,22 +161,36 @@ private void processResponse() {
}
}

private void setBanOnNodes(String reason, CancellableTask task, DiscoveryNodes nodes, ActionListener<Void> listener) {
/**
 * Returns a set that contains every node in {@code nodes} plus {@code localNode}.
 * <p>
 * The input set is never modified. If it already contains the local node it is returned as is;
 * otherwise a new set is returned.
 *
 * @param localNode the local node that must be present in the result
 * @param nodes     the nodes to extend; left untouched
 * @return a set containing all of {@code nodes} and {@code localNode}
 */
private static Set<DiscoveryNode> withLocalNode(DiscoveryNode localNode, Set<DiscoveryNode> nodes) {
    if (nodes.contains(localNode)) {
        return nodes;
    }
    if (nodes.isEmpty()) {
        // Nothing to copy, a single-element set is enough
        return Collections.singleton(localNode);
    }
    // Add the local node to the copy, not to the caller's set: the input may be shared
    // (or immutable) and the returned set is the one that has to contain the local node.
    Set<DiscoveryNode> nodesWithLocal = new HashSet<>(nodes);
    nodesWithLocal.add(localNode);
    return nodesWithLocal;
}

/**
 * Creates a "set ban" request for the children of {@code task}, identified by the local node id
 * and the task id, and hands it to {@link #sendSetBanRequest} for every node in {@code nodes}.
 *
 * @param reason   the cancellation reason carried by the ban request
 * @param task     the parent task whose children should be banned
 * @param nodes    the nodes the ban request is sent to
 * @param listener passed through to {@code sendSetBanRequest}
 */
private void setBanOnNodes(String reason, CancellableTask task, Set<DiscoveryNode> nodes, ActionListener<Void> listener) {
    final TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
    final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(parentTaskId, reason);
    sendSetBanRequest(nodes, banRequest, listener);
}

private void removeBanOnNodes(CancellableTask task, DiscoveryNodes nodes) {
/**
 * Creates a "remove ban" request for the children of {@code task}, identified by the local node id
 * and the task id, and hands it to {@link #sendRemoveBanRequest} for every node in {@code nodes}.
 *
 * @param task  the parent task whose ban should be lifted
 * @param nodes the nodes the remove-ban request is sent to
 */
private void removeBanOnNodes(CancellableTask task, Set<DiscoveryNode> nodes) {
    final TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
    final BanParentTaskRequest unbanRequest = BanParentTaskRequest.createRemoveBanParentTaskRequest(parentTaskId);
    sendRemoveBanRequest(nodes, unbanRequest);
}

private void sendSetBanRequest(DiscoveryNodes nodes, BanParentTaskRequest request, ActionListener<Void> listener) {
for (ObjectObjectCursor<String, DiscoveryNode> node : nodes.getNodes()) {
logger.trace("Sending ban for tasks with the parent [{}] to the node [{}], ban [{}]", request.parentTaskId, node.key,
private void sendSetBanRequest(Set<DiscoveryNode> nodes, BanParentTaskRequest request, ActionListener<Void> listener) {
for (DiscoveryNode node : nodes) {
logger.trace("Sending ban for tasks with the parent [{}] to the node [{}], ban [{}]", request.parentTaskId, node.getId(),
request.ban);
transportService.sendRequest(node.value, BAN_PARENT_ACTION_NAME, request,
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, request,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleResponse(TransportResponse.Empty response) {
Expand All @@ -182,17 +199,17 @@ public void handleResponse(TransportResponse.Empty response) {

@Override
public void handleException(TransportException exp) {
logger.warn("Cannot send ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node.key);
logger.warn("Cannot send ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node.getId());
listener.onFailure(exp);
}
});
}
}

private void sendRemoveBanRequest(DiscoveryNodes nodes, BanParentTaskRequest request) {
for (ObjectObjectCursor<String, DiscoveryNode> node : nodes.getNodes()) {
logger.debug("Sending remove ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node.key);
transportService.sendRequest(node.value, BAN_PARENT_ACTION_NAME, request, EmptyTransportResponseHandler
/**
 * Sends the given remove-ban request to each node in {@code nodes}, logging every send at debug level.
 * Responses are handled by {@code EmptyTransportResponseHandler.INSTANCE_SAME}.
 *
 * @param nodes   the nodes to notify
 * @param request the remove-ban request to send
 */
private void sendRemoveBanRequest(Set<DiscoveryNode> nodes, BanParentTaskRequest request) {
    for (DiscoveryNode target : nodes) {
        logger.debug(
            "Sending remove ban for tasks with the parent [{}] to the node [{}]",
            request.parentTaskId,
            target.getId()
        );
        transportService.sendRequest(target, BAN_PARENT_ACTION_NAME, request, EmptyTransportResponseHandler.INSTANCE_SAME);
    }
}
Expand Down
39 changes: 39 additions & 0 deletions server/src/main/java/org/elasticsearch/tasks/TaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -417,12 +418,31 @@ public void waitForTaskCompletion(Task task, long untilInNanos) {
throw new ElasticsearchTimeoutException("Timed out waiting for completion of [{}]", task);
}

/**
 * Records {@code node} as a node with child tasks of the cancellable task with the given id.
 * Does nothing if no cancellable task is registered under {@code taskId}.
 *
 * @param taskId the id of the parent task
 * @param node   the node to record for that task
 */
public void registerChildNode(long taskId, DiscoveryNode node) {
    final CancellableTaskHolder taskHolder = cancellableTasks.get(taskId);
    if (taskHolder == null) {
        return;
    }
    taskHolder.registerChildNode(node);
}

/**
 * Starts the ban on the cancellable task with the given id and returns the nodes registered for it.
 *
 * @param taskId the id of the parent task
 * @return the nodes returned by the task holder's {@code startBan()}, or an empty set if no
 *         cancellable task is registered under {@code taskId}
 */
public Set<DiscoveryNode> startBanOnChildrenNodes(long taskId) {
    final CancellableTaskHolder taskHolder = cancellableTasks.get(taskId);
    if (taskHolder == null) {
        return Collections.emptySet();
    }
    return taskHolder.startBan();
}

private static class CancellableTaskHolder {

private static final String TASK_FINISHED_MARKER = "task finished";

private final CancellableTask task;

private final Set<DiscoveryNode> nodes = new HashSet<>();

private volatile boolean banChildren = false;

private volatile String cancellationReason = null;

private volatile Runnable cancellationListener = null;
Expand All @@ -431,6 +451,25 @@ private static class CancellableTaskHolder {
this.task = task;
}

/**
 * Adds {@code node} to the set of nodes tracked by this holder.
 *
 * @param node the node to add
 * @throws TaskCancelledException if the ban on children has already been started
 */
public void registerChildNode(DiscoveryNode node) {
    synchronized (this) {
        if (banChildren == false) {
            nodes.add(node);
            return;
        }
        throw new TaskCancelledException("The parent task was cancelled, shouldn't start any children tasks");
    }
}

/**
 * Marks this holder as banning children, so that later {@code registerChildNode} calls are rejected,
 * and returns the nodes registered up to that point.
 *
 * @return a snapshot of the registered nodes; changes to it do not affect this holder
 * @throws TaskCancelledException if the ban has already been started
 */
public Set<DiscoveryNode> startBan() {
    synchronized (this) {
        if (banChildren) {
            throw new TaskCancelledException("The parent task was cancelled, shouldn't start any children tasks");
        }
        banChildren = true;
        // Copy under the lock so callers never see, or modify, the holder's internal set
        return new HashSet<>(nodes);
    }
}

/**
* Marks task as cancelled.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,10 +542,15 @@ public final <T extends TransportResponse> void sendRequest(final Transport.Conn
final TransportRequestOptions options,
TransportResponseHandler<T> handler) {
try {
if (request.getParentTask().isSet()) {
taskManager.registerChildNode(request.getParentTask().getId(), connection.getNode());
}
asyncSender.sendRequest(connection, action, request, options, handler);
} catch (NodeNotConnectedException ex) {
// the caller might not handle this so we invoke the handler
handler.handleException(ex);
} catch (TaskCancelledException ex) {
handler.handleException(new TransportException(ex));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.test.ClusterServiceUtils.setState;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

Expand Down Expand Up @@ -184,19 +183,24 @@ protected NodeResponse nodeOperation(CancellableNodeRequest request, Task task)
}
}

private Task startCancellableTestNodesAction(boolean waitForActionToStart, int blockedNodesCount, ActionListener<NodesResponse>
listener) throws InterruptedException {
return startCancellableTestNodesAction(waitForActionToStart, randomSubsetOf(blockedNodesCount, testNodes), new
CancellableNodesRequest("Test Request"), listener);
/**
 * Starts the cancellable test action on a random subset of the test nodes.
 * <p>
 * First picks {@code runNodesCount} nodes to run on, then picks {@code blockedNodesCount} of those
 * to block on, and delegates to the overload that takes explicit node collections.
 */
private Task startCancellableTestNodesAction(boolean waitForActionToStart, int runNodesCount, int blockedNodesCount,
                                             ActionListener<NodesResponse> listener) throws InterruptedException {
    // The blocked nodes are always drawn from the nodes the action runs on
    final List<TestNode> runOnNodes = randomSubsetOf(runNodesCount, testNodes);
    final List<TestNode> blockOnNodes = randomSubsetOf(blockedNodesCount, runOnNodes);
    final String[] runOnNodeIds = runOnNodes.stream().map(TestNode::getNodeId).toArray(String[]::new);
    final CancellableNodesRequest request = new CancellableNodesRequest("Test Request", runOnNodeIds);
    return startCancellableTestNodesAction(waitForActionToStart, runOnNodes, blockOnNodes, request, listener);
}

private Task startCancellableTestNodesAction(boolean waitForActionToStart, Collection<TestNode> blockOnNodes, CancellableNodesRequest
private Task startCancellableTestNodesAction(boolean waitForActionToStart, List<TestNode> runOnNodes,
Collection<TestNode> blockOnNodes, CancellableNodesRequest
request, ActionListener<NodesResponse> listener) throws InterruptedException {
CountDownLatch actionLatch = waitForActionToStart ? new CountDownLatch(nodesCount) : null;
CountDownLatch actionLatch = waitForActionToStart ? new CountDownLatch(runOnNodes.size()) : null;
CancellableTestNodesAction[] actions = new CancellableTestNodesAction[nodesCount];
for (int i = 0; i < testNodes.length; i++) {
boolean shouldBlock = blockOnNodes.contains(testNodes[i]);
logger.info("The action in the node [{}] should block: [{}]", testNodes[i].getNodeId(), shouldBlock);
boolean shouldRun = runOnNodes.contains(testNodes[i]);
logger.info("The action on the node [{}] should run: [{}] should block: [{}]", testNodes[i].getNodeId(), shouldRun,
shouldBlock);
actions[i] = new CancellableTestNodesAction("internal:testAction", threadPool, testNodes[i]
.clusterService, testNodes[i].transportService, shouldBlock, actionLatch);
}
Expand All @@ -218,20 +222,22 @@ public void testBasicTaskCancellation() throws Exception {
logger.info("waitForActionToStart is set to {}", waitForActionToStart);
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
int blockedNodesCount = randomIntBetween(0, nodesCount);
Task mainTask = startCancellableTestNodesAction(waitForActionToStart, blockedNodesCount, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}
int runNodesCount = randomIntBetween(1, nodesCount);
int blockedNodesCount = randomIntBetween(0, runNodesCount);
Task mainTask = startCancellableTestNodesAction(waitForActionToStart, runNodesCount, blockedNodesCount,
new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}

@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
}
});
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
}
});

// Cancel main task
CancelTasksRequest request = new CancelTasksRequest();
Expand All @@ -251,12 +257,12 @@ public void onFailure(Exception e) {
// Make sure that the request was successful
assertNull(throwableReference.get());
assertNotNull(responseReference.get());
assertEquals(nodesCount, responseReference.get().getNodes().size());
assertEquals(runNodesCount, responseReference.get().getNodes().size());
assertEquals(0, responseReference.get().failureCount());
} else {
// We cancelled the request; in this case it should have failed, but we should still get a partial response
assertNull(throwableReference.get());
assertEquals(nodesCount, responseReference.get().failureCount() + responseReference.get().getNodes().size());
assertEquals(runNodesCount, responseReference.get().failureCount() + responseReference.get().getNodes().size());
// and we should have at least as many failures as the number of blocked operations
// (we might have cancelled some non-blocked operations before they even started and that's ok)
assertThat(responseReference.get().failureCount(), greaterThanOrEqualTo(blockedNodesCount));
Expand Down Expand Up @@ -291,19 +297,23 @@ public void testChildTasksCancellation() throws Exception {
CountDownLatch responseLatch = new CountDownLatch(1);
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
Task mainTask = startCancellableTestNodesAction(true, nodesCount, new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}
int runNodesCount = randomIntBetween(1, nodesCount);
int blockedNodesCount = randomIntBetween(0, runNodesCount);
Task mainTask = startCancellableTestNodesAction(true, runNodesCount, blockedNodesCount,
new ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}

@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
}
}
});
);

// Cancel all child tasks without cancelling the main task, which should quit on its own
CancelTasksRequest request = new CancelTasksRequest();
Expand All @@ -316,8 +326,10 @@ public void onFailure(Exception e) {
// Awaiting for the main task to finish
responseLatch.await();

// Should have cancelled tasks on all nodes
assertThat(response.getTasks().size(), equalTo(testNodes.length));
// Should not have cancelled tasks on more nodes than the action was run on
assertThat(response.getTasks().size(), lessThanOrEqualTo(runNodesCount));
// Should have cancelled tasks at least on all nodes where it was blocked; may also include nodes where it was still running
assertThat(response.getTasks().size(), greaterThanOrEqualTo(blockedNodesCount));

assertBusy(() -> {
// Make sure that main task is no longer running
Expand All @@ -339,20 +351,22 @@ public void testTaskCancellationOnCoordinatingNodeLeavingTheCluster() throws Exc

// We shouldn't block on the first node since it's leaving the cluster anyway so it doesn't matter
List<TestNode> blockOnNodes = randomSubsetOf(blockedNodesCount, Arrays.copyOfRange(testNodes, 1, nodesCount));
Task mainTask = startCancellableTestNodesAction(true, blockOnNodes, new CancellableNodesRequest("Test Request"), new
ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}

@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
}
});
Task mainTask = startCancellableTestNodesAction(true, Arrays.asList(testNodes), blockOnNodes,
new CancellableNodesRequest("Test Request"), new
ActionListener<NodesResponse>() {
@Override
public void onResponse(NodesResponse listTasksResponse) {
responseReference.set(listTasksResponse);
responseLatch.countDown();
}

@Override
public void onFailure(Exception e) {
throwableReference.set(e);
responseLatch.countDown();
}
}
);

String mainNode = testNodes[0].getNodeId();

Expand Down

0 comments on commit 0556bc4

Please sign in to comment.