Track cancellable tasks by parent ID #76186

Merged
@@ -0,0 +1,135 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.tasks;

import org.elasticsearch.common.util.concurrent.ConcurrentCollections;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

/**
 * Tracks items that are associated with cancellable tasks, supporting efficient lookup by task ID and by parent task ID
 */
public class CancellableTasksTracker<T> {

    private final T[] empty;
Contributor

Can empty be a constant? It doesn't seem to be modified anywhere and it's only used in Arrays.copyOf.

Contributor Author

Again I didn't find a way to do so, but you're welcome to show me how :)
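
For readers following this thread, here is a minimal sketch of why a typed constant is awkward here. A static field cannot be declared as `T[]` because static members cannot refer to the class's type parameter, and `new T[0]` does not compile. An `Object[]` constant with an unchecked cast might be a workable alternative, but it needs a suppressed warning and yields `Object[]` arrays at runtime. The class name `EmptyArrayDemo` and its `append` method are hypothetical and not part of this PR; they only mirror the constructor-argument approach used above.

```java
import java.util.Arrays;

// Hypothetical illustration class; not part of the PR.
final class EmptyArrayDemo<T> {
    // An instance field can be typed T[]; a static constant could not, because
    // static members cannot refer to the class's type parameter T.
    private final T[] empty;

    EmptyArrayDemo(T[] empty) {
        if (empty.length != 0) {
            throw new IllegalArgumentException("expected a zero-length array");
        }
        this.empty = empty;
    }

    // Arrays.copyOf preserves the runtime component type of its argument, so the
    // result is a genuine String[] when the caller supplied a String[].
    T[] append(T[] current, T item) {
        final T[] base = current == null ? empty : current;
        final T[] grown = Arrays.copyOf(base, base.length + 1);
        grown[base.length] = item;
        return grown;
    }
}
```

Passing `new String[0]` to the constructor, as the tests in this PR do for the real tracker, keeps every derived array correctly typed without any unchecked casts.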

Member

Why does this have to be generic? T == CancellableTaskHolder is the only use case we have, isn't it?

Contributor Author

Yeah, in production. I didn't want to make CancellableTaskHolder public just for the tests, and I don't think it makes a performance difference. I could be persuaded.

Member

> I don't think it makes a performance difference. I could be persuaded.

Nah, there isn't any difference I think; let's leave it as is then IMO :)


    public CancellableTasksTracker(T[] empty) {
        assert empty.length == 0;
        this.empty = empty;
    }

    private final Map<Long, T> byTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
    private final Map<TaskId, T[]> byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
Contributor

Can we extend TaskAssertions.assertAllCancellableTasksAreCancelled and perhaps even ESIntegTestCase to verify that the two maps are in sync after tests have completed (at least that the byParentTaskId map only contains items that are also in the byTaskId map)?

We could consider doing it in an assertion too but doing it concurrently while running might be difficult?

Contributor Author

Yes I don't think we can express any true invariants very easily; I added an eventually-true assertion in 9554f12.

arteam marked this conversation as resolved.
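
To illustrate what "eventually true" means in this thread, the following is a rough, self-contained sketch rather than the PR's implementation. `TwoMapTrackerSketch` is a hypothetical simplification in which parent IDs and items are plain strings. The consistency check can fail transiently while `put` and `remove` calls are in flight, so the sketch only evaluates it after every worker thread has been joined.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical, simplified stand-in for the tracker; runs without Elasticsearch classes.
final class TwoMapTrackerSketch {
    private static final String[] EMPTY = new String[0];
    private final Map<Long, String> byId = new ConcurrentHashMap<>();
    private final Map<String, String[]> byParent = new ConcurrentHashMap<>();

    void put(long id, String parent, String item) {
        if (parent != null) {
            // copy-on-write: readers always see a complete snapshot array
            byParent.compute(parent, (k, old) -> {
                final String[] base = old == null ? EMPTY : old;
                final String[] grown = Arrays.copyOf(base, base.length + 1);
                grown[base.length] = item;
                return grown;
            });
        }
        byId.put(id, item);
    }

    String remove(long id, String parent) {
        final String removed = byId.remove(id);
        if (removed != null && parent != null) {
            byParent.compute(parent, (k, old) -> {
                if (old == null) return null;
                // identity comparison is enough because each item is unique per task
                final String[] kept = Arrays.stream(old).filter(s -> s != removed).toArray(String[]::new);
                return kept.length == 0 ? null : kept; // returning null deletes the key
            });
        }
        return removed;
    }

    String[] getByParent(String parent) {
        return byParent.getOrDefault(parent, EMPTY).clone();
    }

    // Not an invariant: only meaningful once all put/remove calls have finished.
    boolean isConsistent() {
        if (byId.isEmpty() && byParent.isEmpty() == false) return false;
        final Set<String> tracked = new HashSet<>(byId.values());
        for (String[] children : byParent.values()) {
            if (children.length == 0) return false;
            for (String child : children) {
                if (tracked.contains(child) == false) return false;
            }
        }
        return true;
    }

    // Runs put+remove on several threads, waits for all of them, then checks consistency.
    static boolean runConcurrentlyThenCheck(int threads) {
        final TwoMapTrackerSketch tracker = new TwoMapTrackerSketch();
        final Thread[] workers = new Thread[threads];
        for (int i = 0; i < threads; i++) {
            final long id = i;
            workers[i] = new Thread(() -> {
                tracker.put(id, "parent-" + (id % 3), "item-" + id);
                tracker.remove(id, "parent-" + (id % 3));
            });
            workers[i].start();
        }
        try {
            for (Thread worker : workers) worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return tracker.isConsistent();
    }
}
```

Calling `isConsistent()` from inside a worker thread could legitimately return `false`, because the by-parent array is updated before the by-ID map in `put` and after it in `remove`.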

    /**
     * Add an item for the given task. Should only be called once for each task, and {@code item} must be unique per task too.
     */
    public void put(Task task, T item) {
        final long taskId = task.getId();
        if (task.getParentTaskId().isSet()) {
            byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> {
                if (oldValue == null) {
                    oldValue = empty;
                }
                final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1);
                newValue[oldValue.length] = item;
                return newValue;
            });
        }
        final T oldItem = byTaskId.put(taskId, item);
        assert oldItem == null : "duplicate entry for task [" + taskId + "]";
    }

    /**
     * Get the item that corresponds with the given task, or {@code null} if there is no such item.
     */
    public T get(long id) {
        return byTaskId.get(id);
    }

    /**
     * Remove (and return) the item that corresponds with the given task. Return {@code null} if not present. Safe to call multiple times
     * for each task. However, {@link #getByParent} may return this task even after a call to this method completes, if the removal is
     * actually being completed by a concurrent call that's still ongoing.
     */
    public T remove(Task task) {
        final long taskId = task.getId();
        final T oldItem = byTaskId.remove(taskId);
        if (oldItem != null && task.getParentTaskId().isSet()) {
            byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> {
                if (oldValue == null) {
                    return null;
                }
                if (oldValue.length == 1) {
                    if (oldValue[0] == oldItem) {
                        return null;
                    } else {
                        return oldValue;
                    }
                }
                if (oldValue[0] == oldItem) {
                    return Arrays.copyOfRange(oldValue, 1, oldValue.length);
                }
                for (int i = 1; i < oldValue.length; i++) {
                    if (oldValue[i] == oldItem) {
                        final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1);
                        System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1);
                        return newValue;
                    }
                }
                return oldValue;
            });
        }
        return oldItem;
    }

    /**
     * Return a collection of all the tracked items. May be large. In the presence of concurrent calls to {@link #put} and {@link #remove}
     * it behaves similarly to {@link ConcurrentHashMap#values()}.
     */
    public Collection<T> values() {
        return byTaskId.values();
    }

    /**
     * Return a collection of all the tracked items with a given parent, which will include at least every item for which {@link #put}
     * completed, but {@link #remove} hasn't started. May include some additional items for which all the calls to {@link #remove} that
     * started before this method was called have not completed.
     */
    public Stream<T> getByParent(TaskId parentTaskId) {
        final T[] byParent = byParentTaskId.get(parentTaskId);
        if (byParent == null) {
            return Stream.empty();
        }
        return Arrays.stream(byParent);
    }

    // assertion for tests, not an invariant but should eventually be true
    boolean assertConsistent() {
        // mustn't leak any items tracked by parent
        assert byTaskId.isEmpty() == false || byParentTaskId.isEmpty();

        // every by-parent value must be tracked by task too; the converse isn't true since we don't track values without a parent
        final Set<T> byTaskValues = new HashSet<>(byTaskId.values());
        for (T[] byParent : byParentTaskId.values()) {
            assert byParent.length > 0;
            for (T t : byParent) {
                assert byTaskValues.contains(t);
            }
        }

        return true;
    }
}
25 changes: 15 additions & 10 deletions server/src/main/java/org/elasticsearch/tasks/TaskManager.java
@@ -79,7 +79,8 @@ public class TaskManager implements ClusterStateApplier {

     private final Map<Long, Task> tasks = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();

-    private final Map<Long, CancellableTaskHolder> cancellableTasks = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
+    private final CancellableTasksTracker<CancellableTaskHolder> cancellableTasks
+        = new CancellableTasksTracker<>(new CancellableTaskHolder[0]);

     private final AtomicLong taskIdGenerator = new AtomicLong();

@@ -184,8 +185,7 @@ public void onFailure(Exception e) {
     private void registerCancellableTask(Task task) {
         CancellableTask cancellableTask = (CancellableTask) task;
         CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask);
-        CancellableTaskHolder oldHolder = cancellableTasks.put(task.getId(), holder);
-        assert oldHolder == null;
+        cancellableTasks.put(task, holder);
         // Check if this task was banned before we start it. The empty check is used to avoid
         // computing the hash code of the parent taskId as most of the time bannedParents is empty.
         if (task.getParentTaskId().isSet() && bannedParents.isEmpty() == false) {
@@ -225,15 +225,18 @@ public void cancel(CancellableTask task, String reason, Runnable listener) {
     public Task unregister(Task task) {
         logger.trace("unregister task for id: {}", task.getId());
         if (task instanceof CancellableTask) {
-            CancellableTaskHolder holder = cancellableTasks.remove(task.getId());
+            CancellableTaskHolder holder = cancellableTasks.remove(task);
             if (holder != null) {
                 holder.finish();
+                assert holder.task == task;
                 return holder.getTask();
             } else {
                 return null;
             }
         } else {
-            return tasks.remove(task.getId());
+            final Task removedTask = tasks.remove(task.getId());
+            assert removedTask == null || removedTask == task;
+            return removedTask;
         }
     }

@@ -388,10 +391,7 @@ public List<CancellableTask> setBan(TaskId parentTaskId, String reason, Transpor
                 ban.registerChannel(DIRECT_CHANNEL_TRACKER);
             }
         }
-        return cancellableTasks.values().stream()
-            .filter(t -> t.hasParent(parentTaskId))
-            .map(t -> t.task)
-            .collect(Collectors.toUnmodifiableList());
+        return cancellableTasks.getByParent(parentTaskId).map(t -> t.task).collect(Collectors.toUnmodifiableList());
     }

     /**
@@ -409,6 +409,11 @@ public Set<TaskId> getBannedTaskIds() {
         return Collections.unmodifiableSet(bannedParents.keySet());
     }

+    // for testing
+    public boolean assertCancellableTaskConsistency() {
+        return cancellableTasks.assertConsistent();
+    }
+
     private class Ban {
         final String reason;
         final Set<ChannelPendingTaskTracker> channels;
@@ -622,7 +627,7 @@ Set<Transport.Connection> startBan(String reason, Runnable onChildTasksCompleted
      * @return a releasable that should be called when this pending task is completed
      */
     public Releasable startTrackingCancellableChannelTask(TcpChannel channel, CancellableTask task) {
-        assert cancellableTasks.containsKey(task.getId()) : "task [" + task.getId() + "] is not registered yet";
+        assert cancellableTasks.get(task.getId()) != null : "task [" + task.getId() + "] is not registered yet";
         final ChannelPendingTaskTracker tracker = startTrackingChannel(channel, trackerChannel -> trackerChannel.addTask(task));
         return () -> tracker.removeTask(task);
     }
@@ -0,0 +1,178 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0 and the Server Side Public License, v 1; you may not use this file except
 * in compliance with, at your election, the Elastic License 2.0 or the Server
 * Side Public License, v 1.
 */

package org.elasticsearch.tasks;

import org.elasticsearch.test.ESTestCase;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;

public class CancellableTasksTrackerTests extends ESTestCase {

    private static class TestTask {
        private final Thread actionThread;
        private final Thread watchThread;
        private final Thread concurrentRemoveThread;

        // 0 == before put, 1 == during put, 2 == after put, before remove, 3 == during remove, 4 == after remove
        private final AtomicInteger state = new AtomicInteger();
        private final boolean concurrentRemove = randomBoolean();

        TestTask(Task task, String item, CancellableTasksTracker<String> tracker, Runnable awaitStart) {
            if (concurrentRemove) {
                concurrentRemoveThread = new Thread(() -> {
                    awaitStart.run();

                    for (int i = 0; i < 10; i++) {
                        if (3 <= state.get()) {
                            final String removed = tracker.remove(task);
                            if (removed != null) {
                                assertSame(item, removed);
                            }
                        }
                    }
                });
            } else {
                concurrentRemoveThread = new Thread(awaitStart);
            }

            actionThread = new Thread(() -> {
                awaitStart.run();

                state.incrementAndGet();
                tracker.put(task, item);
                state.incrementAndGet();

                Thread.yield();

                state.incrementAndGet();
Contributor

I am not entirely sure we would ever see state == 2 in other threads; I would think it would be ok for the JVM to optimize it away.

This is possibly even true for the entire actionThread since there is no synchronization with the watch threads here. The watch threads may only see the final output. I would think a yield should help, but that is outside the JMM I think.

Contributor Author

Ah, an earlier iteration had some assertions in between these calls. In theory you're right that it would be legitimate for the compiler to collapse these tests to something fairly trivial, but in practice they do cover many interleavings and the other threads do indeed observe state==2 cases. Just try adding bugs :) I'll add a yield anyway.
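
As a small illustration of the point being discussed, here is a sketch of what a watcher thread can and cannot rely on when polling an `AtomicInteger`. The class and method names are hypothetical. Successive reads of the counter can never go backwards, which is the same property the test checks with `stateBefore <= stateAfter`. Nothing guarantees that the watcher observes every intermediate value, so the sketch only checks monotonicity.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical illustration class; not part of the PR.
final class StateObservationSketch {
    // Returns true if every pair of successive reads by the watcher was non-decreasing.
    static boolean watcherSeesMonotonicState(int steps) {
        final AtomicInteger state = new AtomicInteger();
        final boolean[] monotonic = { true };
        final Thread watcher = new Thread(() -> {
            int previous = state.get();
            while (previous < steps) {
                final int current = state.get();
                if (current < previous) {
                    monotonic[0] = false;
                }
                previous = current;
                Thread.yield(); // only a scheduling hint, it adds no memory-model guarantee
            }
        });
        watcher.start();
        for (int i = 0; i < steps; i++) {
            state.incrementAndGet();
            Thread.yield();
        }
        try {
            watcher.join(); // join makes the watcher's write to monotonic[0] visible here
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return monotonic[0];
    }
}
```

How many distinct intermediate values the watcher actually sees depends on scheduling, which matches the author's remark that coverage of interleavings is a practical observation rather than a guarantee.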

                final String removed = tracker.remove(task);
                state.incrementAndGet();
                if (concurrentRemove == false || removed != null) {
                    assertSame(item, removed);
                }

                assertNull(tracker.remove(task));
            }, "action-thread-" + item);

            watchThread = new Thread(() -> {
                awaitStart.run();

                for (int i = 0; i < 10; i++) {
Contributor

Should this not always run at least until state.get() == 4 (and then another round)?

Contributor Author

I found that this sometimes noticeably added to the test runtime; we spent ages doing this loop before actually doing the useful work. Again, in practice this covers what we care about.

                    final int stateBefore = state.get();
                    final String getResult = tracker.get(task.getId());
                    final Set<String> getByParentResult = tracker.getByParent(task.getParentTaskId()).collect(Collectors.toSet());
                    final Set<String> values = new HashSet<>(tracker.values());
                    final int stateAfter = state.get();

                    assertThat(stateBefore, lessThanOrEqualTo(stateAfter));

                    if (getResult != null && task.getParentTaskId().isSet() && tracker.get(task.getId()) != null) {
                        assertThat(getByParentResult, hasItem(item));
                    }

                    if (stateAfter == 0) {
                        assertNull(getResult);
                        assertThat(getByParentResult, not(hasItem(item)));
                        assertThat(values, not(hasItem(item)));
                    }

                    if (stateBefore == 2 && stateAfter == 2) {
                        assertSame(item, getResult);
                        if (task.getParentTaskId().isSet()) {
                            assertThat(getByParentResult, hasItem(item));
                        } else {
                            assertThat(getByParentResult, empty());
                        }
                        assertThat(values, hasItem(item));
                    }

                    if (stateBefore == 4) {
                        assertNull(getResult);
                        if (concurrentRemove == false) {
                            assertThat(getByParentResult, not(hasItem(item)));
                        } // else our remove might have completed but the concurrent one hasn't updated the parent ID map yet
                        assertThat(values, not(hasItem(item)));
                    }
                }
            }, "watch-thread-" + item);
        }

        void start() {
            watchThread.start();
            concurrentRemoveThread.start();
            actionThread.start();
        }

        void join() throws InterruptedException {
            actionThread.join();
            concurrentRemoveThread.join();
            watchThread.join();
        }
    }

    public void testCancellableTasksTracker() throws InterruptedException {

        final TaskId[] parentTaskIds
            = randomArray(10, 10, TaskId[]::new, () -> new TaskId(randomAlphaOfLength(5), randomNonNegativeLong()));

        final CancellableTasksTracker<String> tracker = new CancellableTasksTracker<>(new String[0]);
        final TestTask[] tasks = new TestTask[between(1, 100)];

        final Runnable awaitStart = new Runnable() {
            private final CyclicBarrier startBarrier = new CyclicBarrier(tasks.length * 3);

            @Override
            public void run() {
                try {
                    startBarrier.await(10, TimeUnit.SECONDS);
                } catch (InterruptedException | BrokenBarrierException | TimeoutException e) {
                    throw new AssertionError("unexpected", e);
                }
            }
        };

        for (int i = 0; i < tasks.length; i++) {
            tasks[i] = new TestTask(
                new Task(
                    randomNonNegativeLong(),
                    randomAlphaOfLength(5),
                    randomAlphaOfLength(5),
                    randomAlphaOfLength(5),
                    rarely() ? TaskId.EMPTY_TASK_ID : randomFrom(parentTaskIds),
                    Collections.emptyMap()),
                "item-" + i,
                tracker,
                awaitStart
            );
        }

        for (TestTask task : tasks) {
            task.start();
        }

        for (TestTask task : tasks) {
            task.join();
        }

        tracker.assertConsistent();
    }
}
Contributor

I think we should also check the more trivial case where no parent id is set.

Contributor Author

Maybe I misunderstand, I think we already do?

                    rarely() ? TaskId.EMPTY_TASK_ID : randomFrom(parentTaskIds),
