diff --git a/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java b/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java new file mode 100644 index 0000000000000..c1723e492dde3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java @@ -0,0 +1,135 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.tasks; + +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Stream; + +/** + * Tracks items that are associated with cancellable tasks, supporting efficient lookup by task ID and by parent task ID + */ +public class CancellableTasksTracker<T> { + + private final T[] empty; + + public CancellableTasksTracker(T[] empty) { + assert empty.length == 0; + this.empty = empty; + } + + private final Map<Long, T> byTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); + private final Map<TaskId, T[]> byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency(); + + /** + * Add an item for the given task. Should only be called once for each task, and {@code item} must be unique per task too. + */ + public void put(Task task, T item) { + final long taskId = task.getId(); + if (task.getParentTaskId().isSet()) { + byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> { + if (oldValue == null) { + oldValue = empty; + } + final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1); + newValue[oldValue.length] = item; + return newValue; + }); + } + final T oldItem = byTaskId.put(taskId, item); + assert oldItem == null : "duplicate entry for task [" + taskId + "]"; + } + + /** + * Get the item that corresponds with the given task, or {@code null} if there is no such item. + */ + public T get(long id) { + return byTaskId.get(id); + } + + /** + * Remove (and return) the item that corresponds with the given task. Return {@code null} if not present. Safe to call multiple times + * for each task. However, {@link #getByParent} may return this task even after a call to this method completes, if the removal is + * actually being completed by a concurrent call that's still ongoing. + */ + public T remove(Task task) { + final long taskId = task.getId(); + final T oldItem = byTaskId.remove(taskId); + if (oldItem != null && task.getParentTaskId().isSet()) { + byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> { + if (oldValue == null) { + return null; + } + if (oldValue.length == 1) { + if (oldValue[0] == oldItem) { + return null; + } else { + return oldValue; + } + } + if (oldValue[0] == oldItem) { + return Arrays.copyOfRange(oldValue, 1, oldValue.length); + } + for (int i = 1; i < oldValue.length; i++) { + if (oldValue[i] == oldItem) { + final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1); + System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1); + return newValue; + } + } + return oldValue; + }); + } + return oldItem; + } + + /** + * Return a collection of all the tracked items. May be large. In the presence of concurrent calls to {@link #put} and {@link #remove} + * it behaves similarly to {@link ConcurrentHashMap#values()}. + */ + public Collection<T> values() { + return byTaskId.values(); + } + + /** + * Return a collection of all the tracked items with a given parent, which will include at least every item for which {@link #put} + * completed, but {@link #remove} hasn't started. May include some additional items for which all the calls to {@link #remove} that + * started before this method was called have not completed. + */ + public Stream<T> getByParent(TaskId parentTaskId) { + final T[] byParent = byParentTaskId.get(parentTaskId); + if (byParent == null) { + return Stream.empty(); + } + return Arrays.stream(byParent); + } + + // assertion for tests, not an invariant but should eventually be true + boolean assertConsistent() { + // mustn't leak any items tracked by parent + assert byTaskId.isEmpty() == false || byParentTaskId.isEmpty(); + + // every by-parent value must be tracked by task too; the converse isn't true since we don't track values without a parent + final Set<T> byTaskValues = new HashSet<>(byTaskId.values()); + for (T[] byParent : byParentTaskId.values()) { + assert byParent.length > 0; + for (T t : byParent) { + assert byTaskValues.contains(t); + } + } + + return true; + } +} diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java index 52ca1a3889911..6aaf68b433f47 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -79,8 +79,8 @@ public class TaskManager implements ClusterStateApplier { private final ConcurrentMapLong<Task> tasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency(); - private final ConcurrentMapLong<CancellableTaskHolder> cancellableTasks = ConcurrentCollections - .newConcurrentMapLongWithAggressiveConcurrency(); + private final CancellableTasksTracker<CancellableTaskHolder> cancellableTasks + = new CancellableTasksTracker<>(new CancellableTaskHolder[0]); private final AtomicLong taskIdGenerator = new AtomicLong(); @@ -146,8 +146,7 @@ public Task register(String type, String action, TaskAwareRequest request) { private void registerCancellableTask(Task task) { CancellableTask cancellableTask = (CancellableTask) task; CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask); - CancellableTaskHolder oldHolder = cancellableTasks.put(task.getId(), holder); - assert oldHolder == null; + cancellableTasks.put(task, holder); // Check if this task was banned before we start it. The empty check is used to avoid // computing the hash code of the parent taskId as most of the time bannedParents is empty. if (task.getParentTaskId().isSet() && bannedParents.isEmpty() == false) { @@ -187,15 +186,18 @@ public void cancel(CancellableTask task, String reason, Runnable listener) { public Task unregister(Task task) { logger.trace("unregister task for id: {}", task.getId()); if (task instanceof CancellableTask) { - CancellableTaskHolder holder = cancellableTasks.remove(task.getId()); + CancellableTaskHolder holder = cancellableTasks.remove(task); if (holder != null) { holder.finish(); + assert holder.task == task; return holder.getTask(); } else { return null; } } else { - return tasks.remove(task.getId()); + final Task removedTask = tasks.remove(task.getId()); + assert removedTask == null || removedTask == task; + return removedTask; } } @@ -372,10 +374,7 @@ public List<CancellableTask> setBan(TaskId parentTaskId, String reason, Transpor } } } - return cancellableTasks.values().stream() - .filter(t -> t.hasParent(parentTaskId)) - .map(t -> t.task) - .collect(Collectors.toList()); + return cancellableTasks.getByParent(parentTaskId).map(t -> t.task).collect(Collectors.toList()); } /** @@ -393,6 +392,11 @@ public Set<TaskId> getBannedTaskIds() { return Collections.unmodifiableSet(bannedParents.keySet()); } + // for testing + public boolean assertCancellableTaskConsistency() { + return cancellableTasks.assertConsistent(); + } + private class Ban { final String reason; final boolean perChannel; // TODO: Remove this in 8.0 @@ -630,7 +634,7 @@ Set<Transport.Connection> startBan(String reason, Runnable onChildTasksCompleted * @return a releasable that should be called when this pending task is completed */ public Releasable startTrackingCancellableChannelTask(TcpChannel channel, CancellableTask task) { - assert cancellableTasks.containsKey(task.getId()) : "task [" + task.getId() + "] is not registered yet"; + assert cancellableTasks.get(task.getId()) != null : "task [" + task.getId() + "] is not registered yet"; final ChannelPendingTaskTracker tracker = startTrackingChannel(channel, trackerChannel -> trackerChannel.addTask(task)); return () -> tracker.removeTask(task); } diff --git a/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java b/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java new file mode 100644 index 0000000000000..31a4be55b05a1 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java @@ -0,0 +1,178 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.tasks; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.Matchers.not; + +public class CancellableTasksTrackerTests extends ESTestCase { + + private static class TestTask { + private final Thread actionThread; + private final Thread watchThread; + private final Thread concurrentRemoveThread; + + // 0 == before put, 1 == during put, 2 == after put, before remove, 3 == during remove, 4 == after remove + private final AtomicInteger state = new AtomicInteger(); + private final boolean concurrentRemove = randomBoolean(); + + TestTask(Task task, String item, CancellableTasksTracker<String> tracker, Runnable awaitStart) { + if (concurrentRemove) { + concurrentRemoveThread = new Thread(() -> { + awaitStart.run(); + + for (int i = 0; i < 10; i++) { + if (3 <= state.get()) { + final String removed = tracker.remove(task); + if (removed != null) { + assertSame(item, removed); + } + } + } + }); + } else { + concurrentRemoveThread = new Thread(awaitStart); + } + + actionThread = new Thread(() -> { + awaitStart.run(); + + state.incrementAndGet(); + tracker.put(task, item); + state.incrementAndGet(); + + Thread.yield(); + + state.incrementAndGet(); + final String removed = tracker.remove(task); + state.incrementAndGet(); + if (concurrentRemove == false || removed != null) { + assertSame(item, removed); + } + + assertNull(tracker.remove(task)); + }, "action-thread-" + item); + + watchThread = new Thread(() -> { + awaitStart.run(); + + for (int i = 0; i < 10; i++) { + final int stateBefore = state.get(); + final String getResult = tracker.get(task.getId()); + final Set<String> getByParentResult = tracker.getByParent(task.getParentTaskId()).collect(Collectors.toSet()); + final Set<String> values = new HashSet<>(tracker.values()); + final int stateAfter = state.get(); + + assertThat(stateBefore, lessThanOrEqualTo(stateAfter)); + + if (getResult != null && task.getParentTaskId().isSet() && tracker.get(task.getId()) != null) { + assertThat(getByParentResult, hasItem(item)); + } + + if (stateAfter == 0) { + assertNull(getResult); + assertThat(getByParentResult, not(hasItem(item))); + assertThat(values, not(hasItem(item))); + } + + if (stateBefore == 2 && stateAfter == 2) { + assertSame(item, getResult); + if (task.getParentTaskId().isSet()) { + assertThat(getByParentResult, hasItem(item)); + } else { + assertThat(getByParentResult, empty()); + } + assertThat(values, hasItem(item)); + } + + if (stateBefore == 4) { + assertNull(getResult); + if (concurrentRemove == false) { + assertThat(getByParentResult, not(hasItem(item))); + } // else our remove might have completed but the concurrent one hasn't updated the parent ID map yet + assertThat(values, not(hasItem(item))); + } + } + }, "watch-thread-" + item); + } + + void start() { + watchThread.start(); + concurrentRemoveThread.start(); + actionThread.start(); + } + + void join() throws InterruptedException { + actionThread.join(); + concurrentRemoveThread.join(); + watchThread.join(); + } + } + + public void testCancellableTasksTracker() throws InterruptedException { + + final TaskId[] parentTaskIds + = randomArray(10, 10, TaskId[]::new, () -> new TaskId(randomAlphaOfLength(5), randomNonNegativeLong())); + + final CancellableTasksTracker<String> tracker = new CancellableTasksTracker<>(new String[0]); + final TestTask[] tasks = new TestTask[between(1, 100)]; + + final Runnable awaitStart = new Runnable() { + private final CyclicBarrier startBarrier = new CyclicBarrier(tasks.length * 3); + + @Override + public void run() { + try { + startBarrier.await(10, TimeUnit.SECONDS); + } catch (InterruptedException | BrokenBarrierException | TimeoutException e) { + throw new AssertionError("unexpected", e); + } + } + }; + + for (int i = 0; i < tasks.length; i++) { + tasks[i] = new TestTask( + new Task( + randomNonNegativeLong(), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + randomAlphaOfLength(5), + rarely() ? TaskId.EMPTY_TASK_ID : randomFrom(parentTaskIds), + Collections.emptyMap()), + "item-" + i, + tracker, + awaitStart + ); + } + + for (TestTask task : tasks) { + task.start(); + } + + for (TestTask task : tasks) { + task.join(); + } + + tracker.assertConsistent(); + } +} diff --git a/test/framework/src/main/java/org/elasticsearch/test/TaskAssertions.java b/test/framework/src/main/java/org/elasticsearch/test/TaskAssertions.java index 3b7cc3c07080b..e99f0ea33af5a 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/TaskAssertions.java +++ b/test/framework/src/main/java/org/elasticsearch/test/TaskAssertions.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskInfo; +import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.transport.TransportService; import java.util.List; @@ -46,7 +47,9 @@ public static void assertAllCancellableTasksAreCancelled(String actionPrefix) th assertBusy(() -> { boolean foundTask = false; for (TransportService transportService : internalCluster().getInstances(TransportService.class)) { - for (CancellableTask cancellableTask : transportService.getTaskManager().getCancellableTasks().values()) { + final TaskManager taskManager = transportService.getTaskManager(); + assertTrue(taskManager.assertCancellableTaskConsistency()); + for (CancellableTask cancellableTask : taskManager.getCancellableTasks().values()) { if (cancellableTask.getAction().startsWith(actionPrefix)) { foundTask = true; assertTrue(