Skip to content

Commit

Permalink
Track cancellable tasks by parent ID (#76186)
Browse files Browse the repository at this point in the history
Today when cancelling a task with its descendants we perform a linear
scan through all the tasks looking for the few that have the right
parent ID. With potentially hundreds of thousands of tasks this takes
quite some time, particularly if there are many tasks to cancel.

This commit introduces a second map that tracks the tasks by their
parent ID so that it's super-cheap to find the descendants that need to
be cancelled.

Closes #75316
  • Loading branch information
DaveCTurner committed Aug 9, 2021
1 parent aaa5fc0 commit 0b041d6
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.tasks;

import org.elasticsearch.common.util.concurrent.ConcurrentCollections;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

/**
 * Keeps track of the items associated with cancellable tasks, indexing them both by the ID of their own task and by the ID of their
 * parent task so that either lookup is cheap.
 */
public class CancellableTasksTracker<T> {

    // Zero-length array used as the starting point when the first child of a parent is added; it supplies the runtime array type.
    private final T[] empty;

    // Every tracked item, keyed by the ID of the task it belongs to.
    private final Map<Long, T> byTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();

    // Items belonging to tasks that have a parent, grouped by parent task ID. The arrays are never modified in place: each update
    // installs a fresh copy, and a parent with no remaining children has no entry at all.
    private final Map<TaskId, T[]> byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();

    /**
     * @param empty a zero-length array of the item type
     */
    public CancellableTasksTracker(T[] empty) {
        assert empty.length == 0;
        this.empty = empty;
    }

    /**
     * Start tracking {@code item} on behalf of {@code task}. Call this at most once per task, and never share one {@code item}
     * between several tasks.
     */
    public void put(Task task, T item) {
        final long id = task.getId();
        final TaskId parentId = task.getParentTaskId();
        if (parentId.isSet()) {
            // the by-parent index is updated first, then the by-ID map
            byParentTaskId.compute(parentId, (key, siblings) -> {
                final T[] current = siblings == null ? empty : siblings;
                final T[] extended = Arrays.copyOf(current, current.length + 1);
                extended[current.length] = item;
                return extended;
            });
        }
        final T previous = byTaskId.put(id, item);
        assert previous == null : "duplicate entry for task [" + id + "]";
    }

    /**
     * Look up the item tracked for the task with the given ID.
     *
     * @return the item, or {@code null} if no item is tracked for that ID
     */
    public T get(long id) {
        return byTaskId.get(id);
    }

    /**
     * Stop tracking the item for the given task and return it, or return {@code null} if nothing was tracked. Calling this repeatedly
     * for the same task is safe. Note that {@link #getByParent} may still report the item after this method returns if the removal is
     * really being carried out by another, concurrent, invocation which has not finished yet.
     */
    public T remove(Task task) {
        final long id = task.getId();
        final T removed = byTaskId.remove(id);
        if (removed == null) {
            // never tracked, or another caller got there first and is responsible for cleaning up the by-parent index
            return null;
        }
        final TaskId parentId = task.getParentTaskId();
        if (parentId.isSet()) {
            byParentTaskId.compute(parentId, (key, siblings) -> siblings == null ? null : without(siblings, removed));
        }
        return removed;
    }

    /**
     * Compute the replacement for {@code items} once the first element identical to {@code target} has been dropped.
     *
     * @return {@code items} itself if it has no such element, {@code null} if that element was the only one, otherwise a shorter copy
     */
    private static <E> E[] without(E[] items, E target) {
        int position = 0;
        while (position < items.length && items[position] != target) {
            position++;
        }
        if (position == items.length) {
            return items;
        }
        if (items.length == 1) {
            return null;
        }
        // copy everything but the last slot, then shift the tail left over the removed position
        final E[] shrunk = Arrays.copyOf(items, items.length - 1);
        System.arraycopy(items, position + 1, shrunk, position, items.length - position - 1);
        return shrunk;
    }

    /**
     * All tracked items, which may be a large collection. When racing with {@link #put} and {@link #remove} its behaviour is similar
     * to that of {@link ConcurrentHashMap#values()}.
     */
    public Collection<T> values() {
        return byTaskId.values();
    }

    /**
     * The tracked items whose task has the given parent. The result contains at least every item for which {@link #put} has completed
     * and {@link #remove} has not yet started. It may also contain items for which not every {@link #remove} call that started before
     * this method was invoked has completed.
     */
    public Stream<T> getByParent(TaskId parentTaskId) {
        final T[] children = byParentTaskId.get(parentTaskId);
        return children == null ? Stream.empty() : Arrays.stream(children);
    }

    // assertion for tests, not an invariant but should eventually be true
    boolean assertConsistent() {
        // once nothing is tracked by ID, nothing may remain tracked by parent either
        assert byTaskId.isEmpty() == false || byParentTaskId.isEmpty();

        // everything tracked by parent must also be tracked by ID; not vice versa, since parentless tasks are only tracked by ID
        final Set<T> tracked = new HashSet<>(byTaskId.values());
        for (T[] children : byParentTaskId.values()) {
            assert children.length > 0;
            for (T child : children) {
                assert tracked.contains(child);
            }
        }

        return true;
    }
}
26 changes: 15 additions & 11 deletions server/src/main/java/org/elasticsearch/tasks/TaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ public class TaskManager implements ClusterStateApplier {

private final ConcurrentMapLong<Task> tasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency();

private final ConcurrentMapLong<CancellableTaskHolder> cancellableTasks = ConcurrentCollections
.newConcurrentMapLongWithAggressiveConcurrency();
private final CancellableTasksTracker<CancellableTaskHolder> cancellableTasks
= new CancellableTasksTracker<>(new CancellableTaskHolder[0]);

private final AtomicLong taskIdGenerator = new AtomicLong();

Expand Down Expand Up @@ -146,8 +146,7 @@ public Task register(String type, String action, TaskAwareRequest request) {
private void registerCancellableTask(Task task) {
CancellableTask cancellableTask = (CancellableTask) task;
CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask);
CancellableTaskHolder oldHolder = cancellableTasks.put(task.getId(), holder);
assert oldHolder == null;
cancellableTasks.put(task, holder);
// Check if this task was banned before we start it. The empty check is used to avoid
// computing the hash code of the parent taskId as most of the time bannedParents is empty.
if (task.getParentTaskId().isSet() && bannedParents.isEmpty() == false) {
Expand Down Expand Up @@ -187,15 +186,18 @@ public void cancel(CancellableTask task, String reason, Runnable listener) {
public Task unregister(Task task) {
logger.trace("unregister task for id: {}", task.getId());
if (task instanceof CancellableTask) {
CancellableTaskHolder holder = cancellableTasks.remove(task.getId());
CancellableTaskHolder holder = cancellableTasks.remove(task);
if (holder != null) {
holder.finish();
assert holder.task == task;
return holder.getTask();
} else {
return null;
}
} else {
return tasks.remove(task.getId());
final Task removedTask = tasks.remove(task.getId());
assert removedTask == null || removedTask == task;
return removedTask;
}
}

Expand Down Expand Up @@ -372,10 +374,7 @@ public List<CancellableTask> setBan(TaskId parentTaskId, String reason, Transpor
}
}
}
return cancellableTasks.values().stream()
.filter(t -> t.hasParent(parentTaskId))
.map(t -> t.task)
.collect(Collectors.toList());
return cancellableTasks.getByParent(parentTaskId).map(t -> t.task).collect(Collectors.toList());
}

/**
Expand All @@ -393,6 +392,11 @@ public Set<TaskId> getBannedTaskIds() {
return Collections.unmodifiableSet(bannedParents.keySet());
}

// for testing
public boolean assertCancellableTaskConsistency() {
return cancellableTasks.assertConsistent();
}

private class Ban {
final String reason;
final boolean perChannel; // TODO: Remove this in 8.0
Expand Down Expand Up @@ -630,7 +634,7 @@ Set<Transport.Connection> startBan(String reason, Runnable onChildTasksCompleted
* @return a releasable that should be called when this pending task is completed
*/
public Releasable startTrackingCancellableChannelTask(TcpChannel channel, CancellableTask task) {
assert cancellableTasks.containsKey(task.getId()) : "task [" + task.getId() + "] is not registered yet";
assert cancellableTasks.get(task.getId()) != null : "task [" + task.getId() + "] is not registered yet";
final ChannelPendingTaskTracker tracker = startTrackingChannel(channel, trackerChannel -> trackerChannel.addTask(task));
return () -> tracker.removeTask(task);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.tasks;

import org.elasticsearch.test.ESTestCase;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;

/**
 * Concurrency test for {@link CancellableTasksTracker}: many tasks are put into and removed from a shared tracker at the same time
 * while other threads observe the tracker and check that what they see is compatible with how far each task has progressed.
 */
public class CancellableTasksTrackerTests extends ESTestCase {

/**
 * The three threads that exercise the tracker for a single task: one that puts and then removes the item, one that watches the
 * tracker, and one that (if {@link #concurrentRemove} is set) also tries to remove the item.
 */
private static class TestTask {
private final Thread actionThread;
private final Thread watchThread;
private final Thread concurrentRemoveThread;

// 0 == before put, 1 == during put, 2 == after put, before remove, 3 == during remove, 4 == after remove
private final AtomicInteger state = new AtomicInteger();
// chosen at random per task: whether a second thread races with the action thread to remove the item
private final boolean concurrentRemove = randomBoolean();

TestTask(Task task, String item, CancellableTasksTracker<String> tracker, Runnable awaitStart) {
if (concurrentRemove) {
concurrentRemoveThread = new Thread(() -> {
awaitStart.run();

// only attempt the removal once the action thread has at least begun its own remove (state 3 or later)
for (int i = 0; i < 10; i++) {
if (3 <= state.get()) {
final String removed = tracker.remove(task);
if (removed != null) {
assertSame(item, removed);
}
}
}
});
} else {
// no concurrent removal: the thread only takes its place at the start barrier
concurrentRemoveThread = new Thread(awaitStart);
}

actionThread = new Thread(() -> {
awaitStart.run();

// each increment moves to the next state listed above, bracketing the put and the remove
state.incrementAndGet();
tracker.put(task, item);
state.incrementAndGet();

Thread.yield();

state.incrementAndGet();
final String removed = tracker.remove(task);
state.incrementAndGet();
// with a concurrent remover either thread may win the item; without one this thread must get it back
if (concurrentRemove == false || removed != null) {
assertSame(item, removed);
}

// a further removal finds nothing
assertNull(tracker.remove(task));
}, "action-thread-" + item);

watchThread = new Thread(() -> {
awaitStart.run();

for (int i = 0; i < 10; i++) {
// sample the state on both sides of the reads so the assertions know the range of states the reads may have seen
final int stateBefore = state.get();
final String getResult = tracker.get(task.getId());
final Set<String> getByParentResult = tracker.getByParent(task.getParentTaskId()).collect(Collectors.toSet());
final Set<String> values = new HashSet<>(tracker.values());
final int stateAfter = state.get();

// the state only ever increases
assertThat(stateBefore, lessThanOrEqualTo(stateAfter));

// tracked by ID both before and after the by-parent read, so the by-parent read must have seen it too
if (getResult != null && task.getParentTaskId().isSet() && tracker.get(task.getId()) != null) {
assertThat(getByParentResult, hasItem(item));
}

// put has not started yet: the item is visible nowhere
if (stateAfter == 0) {
assertNull(getResult);
assertThat(getByParentResult, not(hasItem(item)));
assertThat(values, not(hasItem(item)));
}

// put completed and remove not yet started for the whole of the observation: the item is visible everywhere
if (stateBefore == 2 && stateAfter == 2) {
assertSame(item, getResult);
if (task.getParentTaskId().isSet()) {
assertThat(getByParentResult, hasItem(item));
} else {
assertThat(getByParentResult, empty());
}
assertThat(values, hasItem(item));
}

// the action thread's remove had already returned before the observation began
if (stateBefore == 4) {
assertNull(getResult);
if (concurrentRemove == false) {
assertThat(getByParentResult, not(hasItem(item)));
} // else our remove might have completed but the concurrent one hasn't updated the parent ID map yet
assertThat(values, not(hasItem(item)));
}
}
}, "watch-thread-" + item);
}

// start all three threads; each one blocks in awaitStart until released
void start() {
watchThread.start();
concurrentRemoveThread.start();
actionThread.start();
}

// wait for all three threads to finish
void join() throws InterruptedException {
actionThread.join();
concurrentRemoveThread.join();
watchThread.join();
}
}

/**
 * Runs between 1 and 100 {@link TestTask}s against one tracker, releasing all of their threads at the same moment, and finally
 * checks that the tracker is left in a consistent state.
 */
public void testCancellableTasksTracker() throws InterruptedException {

// pool of parent IDs shared between the tasks, so that several tasks may end up with the same parent
// NOTE(review): presumably randomArray(10, 10, ...) yields exactly ten entries - confirm against ESTestCase
final TaskId[] parentTaskIds
= randomArray(10, 10, TaskId[]::new, () -> new TaskId(randomAlphaOfLength(5), randomNonNegativeLong()));

final CancellableTasksTracker<String> tracker = new CancellableTasksTracker<>(new String[0]);
final TestTask[] tasks = new TestTask[between(1, 100)];

// barrier shared by every thread of every task: three threads per task, hence tasks.length * 3 parties
final Runnable awaitStart = new Runnable() {
private final CyclicBarrier startBarrier = new CyclicBarrier(tasks.length * 3);

@Override
public void run() {
try {
startBarrier.await(10, TimeUnit.SECONDS);
} catch (InterruptedException | BrokenBarrierException | TimeoutException e) {
// failing to get through the barrier is reported as a test failure
throw new AssertionError("unexpected", e);
}
}
};

for (int i = 0; i < tasks.length; i++) {
tasks[i] = new TestTask(
new Task(
randomNonNegativeLong(),
randomAlphaOfLength(5),
randomAlphaOfLength(5),
randomAlphaOfLength(5),
// NOTE(review): presumably EMPTY_TASK_ID is the "no parent" marker, so a few tasks are parentless - confirm
rarely() ? TaskId.EMPTY_TASK_ID : randomFrom(parentTaskIds),
Collections.emptyMap()),
"item-" + i,
tracker,
awaitStart
);
}

for (TestTask task : tasks) {
task.start();
}

for (TestTask task : tasks) {
task.join();
}

// every thread has finished, so every item should have been removed from both indices
tracker.assertConsistent();
}
}
Loading

0 comments on commit 0b041d6

Please sign in to comment.