Skip to content

Commit

Permalink
Child requests proactively cancel children tasks (#92588)
Browse files Browse the repository at this point in the history
To make this possible we modify the CancellableTasksTracker to track children tasks by the Request ID as well. That way, we can send an Action to cancel a child based on the parent task and the Request ID.

This is especially useful when a parent's child requests time out on the parent's side.

Fixes #90353
Relates #66992
  • Loading branch information
kingherc authored Apr 3, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent f353be2 commit 400b7ec
Showing 25 changed files with 435 additions and 95 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/92588.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 92588
summary: Failed tasks proactively cancel children tasks
area: Snapshot/Restore
type: enhancement
issues:
- 90353

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1335,6 +1335,16 @@ public TaskId getParentTask() {
return request.getParentTask();
}

/**
* Delegates to {@code request.setRequestId(requestId)}.
*
* @param requestId the request ID to pass on to {@code request}
*/
@Override
public void setRequestId(long requestId) {
request.setRequestId(requestId);
}

/**
* Delegates to {@code request.getRequestId()}.
*
* @return whatever request ID {@code request} reports
*/
@Override
public long getRequestId() {
return request.getRequestId();
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return request.createTask(id, type, action, parentTaskId, headers);
Original file line number Diff line number Diff line change
@@ -254,6 +254,9 @@ private <T extends ClusterStateTaskListener> void executeAndPublishBatch(
@Override
public void setParentTask(TaskId taskId) {}

// No-op: the supplied request ID is discarded.
@Override
public void setRequestId(long requestId) {}

@Override
public TaskId getParentTask() {
return TaskId.EMPTY_TASK_ID;
Original file line number Diff line number Diff line change
@@ -175,6 +175,11 @@ public void setParentTask(TaskId taskId) {
throw new UnsupportedOperationException("parent task if for persistent tasks shouldn't change");
}

/**
* Not supported by this implementation: every call fails.
*
* @param requestId ignored
* @throws UnsupportedOperationException always
*/
@Override
public void setRequestId(long requestId) {
throw new UnsupportedOperationException("does not have a request ID");
}

@Override
public TaskId getParentTask() {
return parentTaskId;
Original file line number Diff line number Diff line change
@@ -31,21 +31,45 @@ public CancellableTasksTracker(T[] empty) {
}

private final Map<Long, T> byTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
private final Map<TaskId, T[]> byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
private final Map<TaskId, Map<Long, T[]>> byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();

/**
 * Looks up the cancellable children registered under the given parent task and child request ID.
 *
 * Note: children of non-positive request IDs (e.g., -1) may be grouped together.
 *
 * @param parentTaskId   the parent task whose children are wanted
 * @param childRequestId the request ID the children were registered under
 * @return a stream over the matching children, or an empty stream when nothing is tracked for that parent or request ID
 */
public Stream<T> getChildrenByRequestId(TaskId parentTaskId, long childRequestId) {
    final Map<Long, T[]> childrenByRequest = byParentTaskId.get(parentTaskId);
    if (childrenByRequest == null) {
        // nothing tracked for this parent at all
        return Stream.empty();
    }
    final T[] matches = childrenByRequest.get(childRequestId);
    if (matches == null) {
        // the parent is tracked, but not under this request ID
        return Stream.empty();
    }
    return Arrays.stream(matches);
}

/**
* Add an item for the given task. Should only be called once for each task, and {@code item} must be unique per task too.
*/
public void put(Task task, T item) {
public void put(Task task, long requestId, T item) {
final long taskId = task.getId();
if (task.getParentTaskId().isSet()) {
byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> {
if (oldValue == null) {
oldValue = empty;
byParentTaskId.compute(task.getParentTaskId(), (taskKey, oldRequestIdMap) -> {
if (oldRequestIdMap == null) {
oldRequestIdMap = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
}
final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1);
newValue[oldValue.length] = item;
return newValue;

oldRequestIdMap.compute(requestId, (requestIdKey, oldValue) -> {
if (oldValue == null) {
oldValue = empty;
}
final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1);
newValue[oldValue.length] = item;
return newValue;
});

return oldRequestIdMap;
});
}
final T oldItem = byTaskId.put(taskId, item);
@@ -60,36 +84,50 @@ public T get(long id) {
}

/**
* Remove (and return) the item that corresponds with the given task. Return {@code null} if not present. Safe to call multiple times
* for each task. However, {@link #getByParent} may return this task even after a call to this method completes, if the removal is
* actually being completed by a concurrent call that's still ongoing.
* Remove (and return) the item that corresponds with the given task, whichever request ID it was registered under. Return
* {@code null} if not present. Safe to call multiple times for each task. However, {@link #getByParent} may return this task even
* after a call to this method completes, if the removal is actually being completed by a concurrent call that's still ongoing.
*/
public T remove(Task task) {
final long taskId = task.getId();
final T oldItem = byTaskId.remove(taskId);
if (oldItem != null && task.getParentTaskId().isSet()) {
byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> {
if (oldValue == null) {
byParentTaskId.compute(task.getParentTaskId(), (taskKey, oldRequestIdMap) -> {
if (oldRequestIdMap == null) {
return null;
}
if (oldValue.length == 1) {
if (oldValue[0] == oldItem) {
return null;
} else {

for (Long requestId : oldRequestIdMap.keySet()) {
oldRequestIdMap.compute(requestId, (requestIdKey, oldValue) -> {
if (oldValue == null) {
return null;
}
if (oldValue.length == 1) {
if (oldValue[0] == oldItem) {
return null;
} else {
return oldValue;
}
}
if (oldValue[0] == oldItem) {
return Arrays.copyOfRange(oldValue, 1, oldValue.length);
}
for (int i = 1; i < oldValue.length; i++) {
if (oldValue[i] == oldItem) {
final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1);
System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1);
return newValue;
}
}
return oldValue;
}
}
if (oldValue[0] == oldItem) {
return Arrays.copyOfRange(oldValue, 1, oldValue.length);
});
}
for (int i = 1; i < oldValue.length; i++) {
if (oldValue[i] == oldItem) {
final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1);
System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1);
return newValue;
}

if (oldRequestIdMap.keySet().isEmpty()) {
return null;
}
return oldValue;

return oldRequestIdMap;
});
}
return oldItem;
@@ -109,11 +147,11 @@ public Collection<T> values() {
* started before this method was called have not completed.
*/
public Stream<T> getByParent(TaskId parentTaskId) {
final T[] byParent = byParentTaskId.get(parentTaskId);
final Map<Long, T[]> byParent = byParentTaskId.get(parentTaskId);
if (byParent == null) {
return Stream.empty();
}
return Arrays.stream(byParent);
return byParent.values().stream().flatMap(Stream::of);
}

// assertion for tests, not an invariant but should eventually be true
@@ -123,12 +161,14 @@ boolean assertConsistent() {

// every by-parent value must be tracked by task too; the converse isn't true since we don't track values without a parent
final Set<T> byTaskValues = new HashSet<>(byTaskId.values());
for (T[] byParent : byParentTaskId.values()) {
assert byParent.length > 0;
for (T t : byParent) {
assert byTaskValues.contains(t);
}
}
byParentTaskId.values().forEach(byParentMap -> {
byParentMap.forEach((requestId, byParentArray) -> {
assert byParentArray.length > 0;
for (T t : byParentArray) {
assert byTaskValues.contains(t);
}
});
});

return true;
}
12 changes: 12 additions & 0 deletions server/src/main/java/org/elasticsearch/tasks/TaskAwareRequest.java
Original file line number Diff line number Diff line change
@@ -26,6 +26,18 @@ default void setParentTask(String parentTaskNode, long parentTaskId) {
*/
void setParentTask(TaskId taskId);

/**
* Gets the request ID. Defaults to -1, meaning "no request ID is set".
*
* @return the request ID; this default implementation always returns -1
*/
default long getRequestId() {
return -1;
}

/**
* Set the request ID related to this task.
*
* @param requestId the request ID to record
*/
void setRequestId(long requestId);

/**
* Get a reference to the task that created this request. Implementers should default to
* {@link TaskId#EMPTY_TASK_ID}, meaning "there is no parent".
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.threadpool.ThreadPool;
@@ -44,6 +45,8 @@

public class TaskCancellationService {
public static final String BAN_PARENT_ACTION_NAME = "internal:admin/tasks/ban";
public static final String CANCEL_CHILD_ACTION_NAME = "internal:admin/tasks/cancel_child";
public static final TransportVersion VERSION_SUPPORTING_CANCEL_CHILD_ACTION = TransportVersion.V_8_8_0;
private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
private final TransportService transportService;
private final TaskManager taskManager;
@@ -59,6 +62,12 @@ public TaskCancellationService(TransportService transportService) {
BanParentTaskRequest::new,
new BanParentRequestHandler()
);
transportService.registerRequestHandler(
CANCEL_CHILD_ACTION_NAME,
ThreadPool.Names.SAME,
CancelChildRequest::new,
new CancelChildRequestHandler()
);
}

private String localNodeId() {
@@ -341,4 +350,69 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC
}
}
}

/**
* Transport request carrying the parent task ID, the child request ID and the cancellation reason that
* {@code CancelChildRequestHandler} hands to {@code taskManager.cancelChildLocal}.
*/
private static class CancelChildRequest extends TransportRequest {

private final TaskId parentTaskId;
private final long childRequestId;
private final String reason;

/**
* Factory for a new request; simply forwards its arguments to the private constructor.
*/
static CancelChildRequest createCancelChildRequest(TaskId parentTaskId, long childRequestId, String reason) {
return new CancelChildRequest(parentTaskId, childRequestId, reason);
}

private CancelChildRequest(TaskId parentTaskId, long childRequestId, String reason) {
this.parentTaskId = parentTaskId;
this.childRequestId = childRequestId;
this.reason = reason;
}

/**
* Deserializing constructor. Reads the fields in the same order {@link #writeTo} writes them.
*/
private CancelChildRequest(StreamInput in) throws IOException {
super(in);
parentTaskId = TaskId.readFromStream(in);
childRequestId = in.readLong();
reason = in.readString();
}

/**
* Serializes this request: the superclass state first, then parent task ID, child request ID and reason.
*/
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
// keep this order in sync with the StreamInput constructor
parentTaskId.writeTo(out);
out.writeLong(childRequestId);
out.writeString(reason);
}
}

/**
* Handler registered for {@code CANCEL_CHILD_ACTION_NAME}. Passes the request's parent task ID, child request ID and reason to
* {@code taskManager.cancelChildLocal}, then replies with an empty response.
*/
private class CancelChildRequestHandler implements TransportRequestHandler<CancelChildRequest> {
@Override
public void messageReceived(final CancelChildRequest request, final TransportChannel channel, Task task) throws Exception {
taskManager.cancelChildLocal(request.parentTaskId, request.childRequestId, request.reason);
// the response is only sent once cancelChildLocal has returned
channel.sendResponse(TransportResponse.Empty.INSTANCE);
}
}

/**
 * Sends an action to cancel a child task, associated with the given request ID and parent task, to the node returned by
 * {@code childConnection.getNode()}.
 * <p>
 * Does nothing when the connection's transport version is older than {@link #VERSION_SUPPORTING_CANCEL_CHILD_ACTION}.
 *
 * @param parentTask      the parent task whose child should be cancelled
 * @param childRequestId  the request ID identifying the child
 * @param childConnection the connection whose transport version is checked and whose node receives the request
 * @param reason          the cancellation reason, logged and sent along with the request
 */
public void cancelChildRemote(TaskId parentTask, long childRequestId, Transport.Connection childConnection, String reason) {
    final boolean actionSupported = childConnection.getTransportVersion().onOrAfter(VERSION_SUPPORTING_CANCEL_CHILD_ACTION);
    if (actionSupported == false) {
        // the connection's version predates the cancel-child action, so nothing is sent
        return;
    }
    final DiscoveryNode targetNode = childConnection.getNode();
    logger.debug(
        "sending cancellation of child of parent task [{}] with request ID [{}] to node [{}] because of [{}]",
        parentTask,
        childRequestId,
        targetNode,
        reason
    );
    transportService.sendRequest(
        targetNode,
        CANCEL_CHILD_ACTION_NAME,
        CancelChildRequest.createCancelChildRequest(parentTask, childRequestId, reason),
        TransportRequestOptions.EMPTY,
        EmptyTransportResponseHandler.INSTANCE_SAME
    );
}

}
Loading

0 comments on commit 400b7ec

Please sign in to comment.