diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java index 96e02bfa4f50f..5d6fda57b19f0 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java @@ -19,6 +19,7 @@ import org.elasticsearch.action.support.ChannelActionListener; import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.threadpool.ThreadPool; @@ -100,30 +101,42 @@ void doCancelTaskAndDescendants(CancellableTask task, String reason, boolean wai if (task.shouldCancelChildrenOnCancellation()) { logger.trace("cancelling task [{}] and its descendants", taskId); StepListener completedListener = new StepListener<>(); - CountDownActionListener countDownListener = new CountDownActionListener(3, completedListener); - Collection childConnections = taskManager.startBanOnChildTasks(task.getId(), reason, () -> { - logger.trace("child tasks of parent [{}] are completed", taskId); - countDownListener.onResponse(null); - }); - taskManager.cancel(task, reason, () -> { - logger.trace("task [{}] is cancelled", taskId); - countDownListener.onResponse(null); - }); StepListener setBanListener = new StepListener<>(); + + Collection childConnections; + try (var refs = new RefCountingRunnable(() -> setBanListener.addListener(completedListener))) { + var banChildrenRef = refs.acquire(); + var cancelTaskRef = refs.acquire(); + + childConnections = taskManager.startBanOnChildTasks(task.getId(), reason, () -> { + logger.trace("child tasks of parent [{}] are completed", taskId); + banChildrenRef.close(); + }); + + taskManager.cancel(task, reason, () -> { + logger.trace("task [{}] is cancelled", taskId); + cancelTaskRef.close(); + }); + } setBanOnChildConnections(reason, waitForCompletion, task, childConnections, setBanListener); - setBanListener.addListener(countDownListener); - // If we start unbanning when the last child task completed and that child task executed with a specific user, then unban - // requests are denied because internal requests can't run with a user. We need to remove bans with the current thread context. - final Runnable removeBansRunnable = transportService.getThreadPool() - .getThreadContext() - .preserveContext(() -> removeBanOnChildConnections(task, childConnections)); + // We remove bans after all child tasks are completed although in theory we can do it on a per-connection basis. - completedListener.whenComplete(r -> removeBansRunnable.run(), e -> removeBansRunnable.run()); - // if wait_for_completion is true, then only return when (1) bans are placed on child connections, (2) child tasks are - // completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child connections. + completedListener.addListener( + ActionListener.wrap( + transportService.getThreadPool() + .getThreadContext() + // If we start unbanning when the last child task completed and that child task executed with a specific user, then + // unban requests are denied because internal requests can't run with a user. We need to remove bans with the + // current thread context. + .preserveContext(() -> removeBanOnChildConnections(task, childConnections)) + ) + ); + if (waitForCompletion) { + // Wait until (1) bans are placed on child connections, (2) child tasks are completed or failed, (3) main task is cancelled. completedListener.addListener(listener); } else { + // Only wait until bans are placed on child connections setBanListener.addListener(listener); } } else {