diff --git a/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java b/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java
new file mode 100644
index 0000000000000..c1723e492dde3
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/tasks/CancellableTasksTracker.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.tasks;
+
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Stream;
+
+/**
+ * Tracks items that are associated with cancellable tasks, supporting efficient lookup by task ID and by parent task ID
+ */
+public class CancellableTasksTracker<T> {
+
+    private final T[] empty;
+
+    public CancellableTasksTracker(T[] empty) {
+        assert empty.length == 0;
+        this.empty = empty;
+    }
+
+    private final Map<Long, T> byTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
+    private final Map<TaskId, T[]> byParentTaskId = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
+
+    /**
+     * Add an item for the given task. Should only be called once for each task, and {@code item} must be unique per task too.
+     */
+    public void put(Task task, T item) {
+        final long taskId = task.getId();
+        if (task.getParentTaskId().isSet()) {
+            byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> {
+                if (oldValue == null) {
+                    oldValue = empty;
+                }
+                final T[] newValue = Arrays.copyOf(oldValue, oldValue.length + 1);
+                newValue[oldValue.length] = item;
+                return newValue;
+            });
+        }
+        final T oldItem = byTaskId.put(taskId, item);
+        assert oldItem == null : "duplicate entry for task [" + taskId + "]";
+    }
+
+    /**
+     * Get the item that corresponds with the given task, or {@code null} if there is no such item.
+     */
+    public T get(long id) {
+        return byTaskId.get(id);
+    }
+
+    /**
+     * Remove (and return) the item that corresponds with the given task. Return {@code null} if not present. Safe to call multiple times
+     * for each task. However, {@link #getByParent} may return this task even after a call to this method completes, if the removal is
+     * actually being completed by a concurrent call that's still ongoing.
+     */
+    public T remove(Task task) {
+        final long taskId = task.getId();
+        final T oldItem = byTaskId.remove(taskId);
+        if (oldItem != null && task.getParentTaskId().isSet()) {
+            byParentTaskId.compute(task.getParentTaskId(), (ignored, oldValue) -> {
+                if (oldValue == null) {
+                    return null;
+                }
+                if (oldValue.length == 1) {
+                    if (oldValue[0] == oldItem) {
+                        return null;
+                    } else {
+                        return oldValue;
+                    }
+                }
+                if (oldValue[0] == oldItem) {
+                    return Arrays.copyOfRange(oldValue, 1, oldValue.length);
+                }
+                for (int i = 1; i < oldValue.length; i++) {
+                    if (oldValue[i] == oldItem) {
+                        final T[] newValue = Arrays.copyOf(oldValue, oldValue.length - 1);
+                        System.arraycopy(oldValue, i + 1, newValue, i, oldValue.length - i - 1);
+                        return newValue;
+                    }
+                }
+                return oldValue;
+            });
+        }
+        return oldItem;
+    }
+
+    /**
+     * Return a collection of all the tracked items. May be large. In the presence of concurrent calls to {@link #put} and {@link #remove}
+     * it behaves similarly to {@link ConcurrentHashMap#values()}.
+     */
+    public Collection<T> values() {
+        return byTaskId.values();
+    }
+
+    /**
+     * Return a collection of all the tracked items with a given parent, which will include at least every item for which {@link #put}
+     * completed, but {@link #remove} hasn't started. May include some additional items for which all the calls to {@link #remove} that
+     * started before this method was called have not completed.
+     */
+    public Stream<T> getByParent(TaskId parentTaskId) {
+        final T[] byParent = byParentTaskId.get(parentTaskId);
+        if (byParent == null) {
+            return Stream.empty();
+        }
+        return Arrays.stream(byParent);
+    }
+
+    // assertion for tests, not an invariant but should eventually be true
+    boolean assertConsistent() {
+        // mustn't leak any items tracked by parent
+        assert byTaskId.isEmpty() == false || byParentTaskId.isEmpty();
+
+        // every by-parent value must be tracked by task too; the converse isn't true since we don't track values without a parent
+        final Set<T> byTaskValues = new HashSet<>(byTaskId.values());
+        for (T[] byParent : byParentTaskId.values()) {
+            assert byParent.length > 0;
+            for (T t : byParent) {
+                assert byTaskValues.contains(t);
+            }
+        }
+
+        return true;
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java
index 52ca1a3889911..6aaf68b433f47 100644
--- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java
+++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java
@@ -79,8 +79,8 @@ public class TaskManager implements ClusterStateApplier {
 
     private final ConcurrentMapLong<Task> tasks = ConcurrentCollections.newConcurrentMapLongWithAggressiveConcurrency();
 
-    private final ConcurrentMapLong<CancellableTaskHolder> cancellableTasks = ConcurrentCollections
-        .newConcurrentMapLongWithAggressiveConcurrency();
+    private final CancellableTasksTracker<CancellableTaskHolder> cancellableTasks
+        = new CancellableTasksTracker<>(new CancellableTaskHolder[0]);
 
     private final AtomicLong taskIdGenerator = new AtomicLong();
 
@@ -146,8 +146,7 @@ public Task register(String type, String action, TaskAwareRequest request) {
     private void registerCancellableTask(Task task) {
         CancellableTask cancellableTask = (CancellableTask) task;
         CancellableTaskHolder holder = new CancellableTaskHolder(cancellableTask);
-        CancellableTaskHolder oldHolder = cancellableTasks.put(task.getId(), holder);
-        assert oldHolder == null;
+        cancellableTasks.put(task, holder);
         // Check if this task was banned before we start it. The empty check is used to avoid
         // computing the hash code of the parent taskId as most of the time bannedParents is empty.
         if (task.getParentTaskId().isSet() && bannedParents.isEmpty() == false) {
@@ -187,15 +186,18 @@ public void cancel(CancellableTask task, String reason, Runnable listener) {
     public Task unregister(Task task) {
         logger.trace("unregister task for id: {}", task.getId());
         if (task instanceof CancellableTask) {
-            CancellableTaskHolder holder = cancellableTasks.remove(task.getId());
+            CancellableTaskHolder holder = cancellableTasks.remove(task);
             if (holder != null) {
                 holder.finish();
+                assert holder.task == task;
                 return holder.getTask();
             } else {
                 return null;
             }
         } else {
-            return tasks.remove(task.getId());
+            final Task removedTask = tasks.remove(task.getId());
+            assert removedTask == null || removedTask == task;
+            return removedTask;
         }
     }
 
@@ -372,10 +374,7 @@ public List<CancellableTask> setBan(TaskId parentTaskId, String reason, Transpor
                 }
             }
         }
-        return cancellableTasks.values().stream()
-            .filter(t -> t.hasParent(parentTaskId))
-            .map(t -> t.task)
-            .collect(Collectors.toList());
+        return cancellableTasks.getByParent(parentTaskId).map(t -> t.task).collect(Collectors.toList());
     }
 
     /**
@@ -393,6 +392,11 @@ public Set<TaskId> getBannedTaskIds() {
         return Collections.unmodifiableSet(bannedParents.keySet());
     }
 
+    // for testing
+    public boolean assertCancellableTaskConsistency() {
+        return cancellableTasks.assertConsistent();
+    }
+
     private class Ban {
         final String reason;
         final boolean perChannel; // TODO: Remove this in 8.0
@@ -630,7 +634,7 @@ Set<Transport.Connection> startBan(String reason, Runnable onChildTasksCompleted
      * @return a releasable that should be called when this pending task is completed
      */
     public Releasable startTrackingCancellableChannelTask(TcpChannel channel, CancellableTask task) {
-        assert cancellableTasks.containsKey(task.getId()) : "task [" + task.getId() + "] is not registered yet";
+        assert cancellableTasks.get(task.getId()) != null : "task [" + task.getId() + "] is not registered yet";
         final ChannelPendingTaskTracker tracker = startTrackingChannel(channel, trackerChannel -> trackerChannel.addTask(task));
         return () -> tracker.removeTask(task);
     }
diff --git a/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java b/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java
new file mode 100644
index 0000000000000..31a4be55b05a1
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/tasks/CancellableTasksTrackerTests.java
@@ -0,0 +1,178 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.tasks;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.BrokenBarrierException;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.hasItem;
+import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.hamcrest.Matchers.not;
+
+public class CancellableTasksTrackerTests extends ESTestCase {
+
+    private static class TestTask {
+        private final Thread actionThread;
+        private final Thread watchThread;
+        private final Thread concurrentRemoveThread;
+
+        // 0 == before put, 1 == during put, 2 == after put, before remove, 3 == during remove, 4 == after remove
+        private final AtomicInteger state = new AtomicInteger();
+        private final boolean concurrentRemove = randomBoolean();
+
+        TestTask(Task task, String item, CancellableTasksTracker<String> tracker, Runnable awaitStart) {
+            if (concurrentRemove) {
+                concurrentRemoveThread = new Thread(() -> {
+                    awaitStart.run();
+
+                    for (int i = 0; i < 10; i++) {
+                        if (3 <= state.get()) {
+                            final String removed = tracker.remove(task);
+                            if (removed != null) {
+                                assertSame(item, removed);
+                            }
+                        }
+                    }
+                });
+            } else {
+                concurrentRemoveThread = new Thread(awaitStart);
+            }
+
+            actionThread = new Thread(() -> {
+                awaitStart.run();
+
+                state.incrementAndGet();
+                tracker.put(task, item);
+                state.incrementAndGet();
+
+                Thread.yield();
+
+                state.incrementAndGet();
+                final String removed = tracker.remove(task);
+                state.incrementAndGet();
+                if (concurrentRemove == false || removed != null) {
+                    assertSame(item, removed);
+                }
+
+                assertNull(tracker.remove(task));
+            }, "action-thread-" + item);
+
+            watchThread = new Thread(() -> {
+                awaitStart.run();
+
+                for (int i = 0; i < 10; i++) {
+                    final int stateBefore = state.get();
+                    final String getResult = tracker.get(task.getId());
+                    final Set<String> getByParentResult = tracker.getByParent(task.getParentTaskId()).collect(Collectors.toSet());
+                    final Set<String> values = new HashSet<>(tracker.values());
+                    final int stateAfter = state.get();
+
+                    assertThat(stateBefore, lessThanOrEqualTo(stateAfter));
+
+                    if (getResult != null && task.getParentTaskId().isSet() && tracker.get(task.getId()) != null) {
+                        assertThat(getByParentResult, hasItem(item));
+                    }
+
+                    if (stateAfter == 0) {
+                        assertNull(getResult);
+                        assertThat(getByParentResult, not(hasItem(item)));
+                        assertThat(values, not(hasItem(item)));
+                    }
+
+                    if (stateBefore == 2 && stateAfter == 2) {
+                        assertSame(item, getResult);
+                        if (task.getParentTaskId().isSet()) {
+                            assertThat(getByParentResult, hasItem(item));
+                        } else {
+                            assertThat(getByParentResult, empty());
+                        }
+                        assertThat(values, hasItem(item));
+                    }
+
+                    if (stateBefore == 4) {
+                        assertNull(getResult);
+                        if (concurrentRemove == false) {
+                            assertThat(getByParentResult, not(hasItem(item)));
+                        } // else our remove might have completed but the concurrent one hasn't updated the parent ID map yet
+                        assertThat(values, not(hasItem(item)));
+                    }
+                }
+            }, "watch-thread-" + item);
+        }
+
+        void start() {
+            watchThread.start();
+            concurrentRemoveThread.start();
+            actionThread.start();
+        }
+
+        void join() throws InterruptedException {
+            actionThread.join();
+            concurrentRemoveThread.join();
+            watchThread.join();
+        }
+    }
+
+    public void testCancellableTasksTracker() throws InterruptedException {
+
+        final TaskId[] parentTaskIds
+            = randomArray(10, 10, TaskId[]::new, () -> new TaskId(randomAlphaOfLength(5), randomNonNegativeLong()));
+
+        final CancellableTasksTracker<String> tracker = new CancellableTasksTracker<>(new String[0]);
+        final TestTask[] tasks = new TestTask[between(1, 100)];
+
+        final Runnable awaitStart = new Runnable() {
+            private final CyclicBarrier startBarrier = new CyclicBarrier(tasks.length * 3);
+
+            @Override
+            public void run() {
+                try {
+                    startBarrier.await(10, TimeUnit.SECONDS);
+                } catch (InterruptedException | BrokenBarrierException | TimeoutException e) {
+                    throw new AssertionError("unexpected", e);
+                }
+            }
+        };
+
+        for (int i = 0; i < tasks.length; i++) {
+            tasks[i] = new TestTask(
+                new Task(
+                    randomNonNegativeLong(),
+                    randomAlphaOfLength(5),
+                    randomAlphaOfLength(5),
+                    randomAlphaOfLength(5),
+                    rarely() ? TaskId.EMPTY_TASK_ID : randomFrom(parentTaskIds),
+                    Collections.emptyMap()),
+                "item-" + i,
+                tracker,
+                awaitStart
+            );
+        }
+
+        for (TestTask task : tasks) {
+            task.start();
+        }
+
+        for (TestTask task : tasks) {
+            task.join();
+        }
+
+        tracker.assertConsistent();
+    }
+}
diff --git a/test/framework/src/main/java/org/elasticsearch/test/TaskAssertions.java b/test/framework/src/main/java/org/elasticsearch/test/TaskAssertions.java
index 3b7cc3c07080b..e99f0ea33af5a 100644
--- a/test/framework/src/main/java/org/elasticsearch/test/TaskAssertions.java
+++ b/test/framework/src/main/java/org/elasticsearch/test/TaskAssertions.java
@@ -12,6 +12,7 @@
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.TaskInfo;
+import org.elasticsearch.tasks.TaskManager;
 import org.elasticsearch.transport.TransportService;
 
 import java.util.List;
@@ -46,7 +47,9 @@ public static void assertAllCancellableTasksAreCancelled(String actionPrefix) th
         assertBusy(() -> {
             boolean foundTask = false;
             for (TransportService transportService : internalCluster().getInstances(TransportService.class)) {
-                for (CancellableTask cancellableTask : transportService.getTaskManager().getCancellableTasks().values()) {
+                final TaskManager taskManager = transportService.getTaskManager();
+                assertTrue(taskManager.assertCancellableTaskConsistency());
+                for (CancellableTask cancellableTask : taskManager.getCancellableTasks().values()) {
                     if (cancellableTask.getAction().startsWith(actionPrefix)) {
                         foundTask = true;
                         assertTrue(