From b99025433a0451549a6b2ae9b5e892e794d9d922 Mon Sep 17 00:00:00 2001 From: David Turner Date: Tue, 3 Jan 2023 10:58:20 +0000 Subject: [PATCH] Handle rejection in PrioritizedThrottledTaskRunner (#92621) Today `PrioritizedThrottledTaskRunner` submits a bare `Runnable` to the executor, wrapping around the `Runnable` received from the caller, which effectively assumes that tasks are never rejected from the threadpool. However this utility accepts any `Executor`, so this is not a safe assumption to make. This commit moves to submitting an `AbstractRunnable` to the executor, and to requiring callers to pass in an `AbstractRunnable`, so that failures and rejections can be handled properly. --- .../PrioritizedThrottledTaskRunner.java | 65 ++++++++++++++----- .../blobstore/ShardSnapshotTaskRunner.java | 17 ++++- .../PrioritizedThrottledTaskRunnerTests.java | 56 +++++++++++++++- 3 files changed, 118 insertions(+), 20 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java index 9bc7617a4c796..55639481306b4 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunner.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.core.Strings; import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executor; @@ -21,7 +22,7 @@ * natural ordering of the tasks, limiting the max number of concurrently running tasks. Each new task * that is dequeued to be run, is forked off to the given executor. */ -public class PrioritizedThrottledTaskRunner & Runnable> { +public class PrioritizedThrottledTaskRunner> { private static final Logger logger = LogManager.getLogger(PrioritizedThrottledTaskRunner.class); private final String taskRunnerName; @@ -70,7 +71,54 @@ protected void pollAndSpawn() { // non-empty queue and no workers! if (tasks.peek() == null) break; } else { - executor.execute(() -> runTask(task)); + executor.execute(new AbstractRunnable() { + private boolean rejected; // need not be volatile - if we're rejected then that happens-before calling onAfter + + @Override + public boolean isForceExecution() { + return task.isForceExecution(); + } + + @Override + public void onRejection(Exception e) { + logger.trace("[{}] task {} rejected", taskRunnerName, task); + rejected = true; + task.onRejection(e); + } + + @Override + public void onFailure(Exception e) { + logger.trace(() -> Strings.format("[%s] task %s failed", taskRunnerName, task), e); + task.onFailure(e); + } + + @Override + protected void doRun() throws Exception { + logger.trace("[{}] running task {}", taskRunnerName, task); + task.doRun(); + } + + @Override + public void onAfter() { + try { + task.onAfter(); + } finally { + // To avoid missing to run tasks that are enqueued and waiting, we check the queue again once running + // a task is finished. + int decremented = runningTasks.decrementAndGet(); + assert decremented >= 0; + + if (rejected == false) { + pollAndSpawn(); + } + } + } + + @Override + public String toString() { + return task.toString(); + } + }); } } } @@ -91,17 +139,4 @@ public int runningTasks() { public int queueSize() { return tasks.size(); } - - private void runTask(final T task) { - try { - logger.trace("[{}] running task {}", taskRunnerName, task); - task.run(); - } finally { - // To avoid missing to run tasks that are enqueued and waiting, we check the queue again once running - // a task is finished. - int decremented = runningTasks.decrementAndGet(); - assert decremented >= 0; - pollAndSpawn(); - } - } } diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunner.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunner.java index 7c0a2ed9c65ab..bd4084f1a4015 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunner.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunner.java @@ -8,10 +8,14 @@ package org.elasticsearch.repositories.blobstore; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.common.CheckedBiConsumer; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.PrioritizedThrottledTaskRunner; +import org.elasticsearch.core.Strings; import org.elasticsearch.repositories.SnapshotShardContext; import java.io.IOException; @@ -28,11 +32,12 @@ * and zero or more {@link FileSnapshotTask}s. */ public class ShardSnapshotTaskRunner { + private static final Logger logger = LogManager.getLogger(ShardSnapshotTaskRunner.class); private final PrioritizedThrottledTaskRunner taskRunner; private final Consumer shardSnapshotter; private final CheckedBiConsumer fileSnapshotter; - abstract static class SnapshotTask implements Comparable, Runnable { + abstract static class SnapshotTask extends AbstractRunnable implements Comparable { private static final Comparator COMPARATOR = Comparator.comparingLong( (SnapshotTask t) -> t.context().snapshotStartTime() @@ -54,6 +59,12 @@ public SnapshotShardContext context() { public final int compareTo(SnapshotTask other) { return COMPARATOR.compare(this, other); } + + @Override + public void onFailure(Exception e) { + assert false : e; + logger.error(Strings.format("snapshot task [%s] unexpectedly failed", this), e); + } } class ShardSnapshotTask extends SnapshotTask { @@ -62,7 +73,7 @@ class ShardSnapshotTask extends SnapshotTask { } @Override - public void run() { + public void doRun() { shardSnapshotter.accept(context); } @@ -88,7 +99,7 @@ class FileSnapshotTask extends SnapshotTask { } @Override - public void run() { + public void doRun() { ActionRunnable.run(fileSnapshotListener, () -> { FileInfo fileInfo = fileInfos.get(); if (fileInfo != null) { diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java index 3768f5556b4e9..c52380b4dc126 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/PrioritizedThrottledTaskRunnerTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.common.util.concurrent; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -17,6 +18,7 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -40,7 +42,7 @@ public void tearDown() throws Exception { TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS); } - static class TestTask implements Comparable, Runnable { + static class TestTask extends AbstractRunnable implements Comparable { private final Runnable runnable; private final int priority; @@ -56,9 +58,14 @@ public int compareTo(TestTask o) { } @Override - public void run() { + public void doRun() { runnable.run(); } + + @Override + public void onFailure(Exception e) { + throw new AssertionError("unexpected", e); + } } public void testMultiThreadedEnqueue() throws Exception { @@ -171,6 +178,51 @@ public void testEnqueueSpawnsNewTasksUpToMax() throws Exception { assertThat(taskRunner.queueSize(), equalTo(0)); } + public void testFailsTasksOnRejectionOrShutdown() throws Exception { + final var maxThreads = between(1, 5); + final var threadFactory = EsExecutors.daemonThreadFactory("test"); + final var threadContext = new ThreadContext(Settings.EMPTY); + final var executor = randomBoolean() + ? EsExecutors.newScaling("test", 1, maxThreads, 0, TimeUnit.MILLISECONDS, true, threadFactory, threadContext) + : EsExecutors.newFixed("test", maxThreads, between(1, 5), threadFactory, threadContext, false); + final var taskRunner = new PrioritizedThrottledTaskRunner("test", between(1, maxThreads * 2), executor); + final var totalPermits = between(1, maxThreads * 2); + final var permits = new Semaphore(totalPermits); + final var taskCompleted = new CountDownLatch(between(1, maxThreads * 2)); + final var rejectionCountDown = new CountDownLatch(between(1, maxThreads * 2)); + + final var spawnThread = new Thread(() -> { + try { + while (true) { + assertTrue(permits.tryAcquire(10, TimeUnit.SECONDS)); + taskRunner.enqueueTask(new TestTask(taskCompleted::countDown, getRandomPriority()) { + @Override + public void onRejection(Exception e) { + rejectionCountDown.countDown(); + } + + @Override + public void onAfter() { + permits.release(); + } + }); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + spawnThread.start(); + assertTrue(taskCompleted.await(10, TimeUnit.SECONDS)); + executor.shutdown(); + assertTrue(executor.awaitTermination(30, TimeUnit.SECONDS)); + assertTrue(rejectionCountDown.await(10, TimeUnit.SECONDS)); + spawnThread.interrupt(); + spawnThread.join(); + assertThat(taskRunner.runningTasks(), equalTo(0)); + assertThat(taskRunner.queueSize(), equalTo(0)); + assertTrue(permits.tryAcquire(totalPermits)); + } + private int getRandomPriority() { return randomIntBetween(-1000, 1000); }