Skip to content

Commit

Permalink
Handle rejection in PrioritizedThrottledTaskRunner (#92621)
Browse files Browse the repository at this point in the history
Today `PrioritizedThrottledTaskRunner` submits a bare `Runnable` to the
executor, wrapping around the `Runnable` received from the caller, which
effectively assumes that tasks are never rejected from the threadpool.
However this utility accepts any `Executor`, so this is not a safe
assumption to make. This commit moves to submitting an
`AbstractRunnable` to the executor, and to requiring callers to pass in
an `AbstractRunnable`, so that failures and rejections can be handled
properly.
  • Loading branch information
DaveCTurner authored Jan 3, 2023
1 parent 8ae6359 commit b990254
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.core.Strings;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executor;
Expand All @@ -21,7 +22,7 @@
* natural ordering of the tasks, limiting the max number of concurrently running tasks. Each new task
* that is dequeued to be run, is forked off to the given executor.
*/
public class PrioritizedThrottledTaskRunner<T extends Comparable<T> & Runnable> {
public class PrioritizedThrottledTaskRunner<T extends AbstractRunnable & Comparable<T>> {
private static final Logger logger = LogManager.getLogger(PrioritizedThrottledTaskRunner.class);

private final String taskRunnerName;
Expand Down Expand Up @@ -70,7 +71,54 @@ protected void pollAndSpawn() {
// non-empty queue and no workers!
if (tasks.peek() == null) break;
} else {
executor.execute(() -> runTask(task));
executor.execute(new AbstractRunnable() {
private boolean rejected; // need not be volatile - if we're rejected then that happens-before calling onAfter

@Override
public boolean isForceExecution() {
return task.isForceExecution();
}

@Override
public void onRejection(Exception e) {
logger.trace("[{}] task {} rejected", taskRunnerName, task);
rejected = true;
task.onRejection(e);
}

@Override
public void onFailure(Exception e) {
logger.trace(() -> Strings.format("[%s] task %s failed", taskRunnerName, task), e);
task.onFailure(e);
}

@Override
protected void doRun() throws Exception {
logger.trace("[{}] running task {}", taskRunnerName, task);
task.doRun();
}

@Override
public void onAfter() {
try {
task.onAfter();
} finally {
// To avoid missing to run tasks that are enqueued and waiting, we check the queue again once running
// a task is finished.
int decremented = runningTasks.decrementAndGet();
assert decremented >= 0;

if (rejected == false) {
pollAndSpawn();
}
}
}

@Override
public String toString() {
return task.toString();
}
});
}
}
}
Expand All @@ -91,17 +139,4 @@ public int runningTasks() {
public int queueSize() {
return tasks.size();
}

private void runTask(final T task) {
try {
logger.trace("[{}] running task {}", taskRunnerName, task);
task.run();
} finally {
// To avoid missing to run tasks that are enqueued and waiting, we check the queue again once running
// a task is finished.
int decremented = runningTasks.decrementAndGet();
assert decremented >= 0;
pollAndSpawn();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@

package org.elasticsearch.repositories.blobstore;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.common.CheckedBiConsumer;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.PrioritizedThrottledTaskRunner;
import org.elasticsearch.core.Strings;
import org.elasticsearch.repositories.SnapshotShardContext;

import java.io.IOException;
Expand All @@ -28,11 +32,12 @@
* and zero or more {@link FileSnapshotTask}s.
*/
public class ShardSnapshotTaskRunner {
private static final Logger logger = LogManager.getLogger(ShardSnapshotTaskRunner.class);
private final PrioritizedThrottledTaskRunner<SnapshotTask> taskRunner;
private final Consumer<SnapshotShardContext> shardSnapshotter;
private final CheckedBiConsumer<SnapshotShardContext, FileInfo, IOException> fileSnapshotter;

abstract static class SnapshotTask implements Comparable<SnapshotTask>, Runnable {
abstract static class SnapshotTask extends AbstractRunnable implements Comparable<SnapshotTask> {

private static final Comparator<SnapshotTask> COMPARATOR = Comparator.comparingLong(
(SnapshotTask t) -> t.context().snapshotStartTime()
Expand All @@ -54,6 +59,12 @@ public SnapshotShardContext context() {
public final int compareTo(SnapshotTask other) {
return COMPARATOR.compare(this, other);
}

@Override
public void onFailure(Exception e) {
assert false : e;
logger.error(Strings.format("snapshot task [%s] unexpectedly failed", this), e);
}
}

class ShardSnapshotTask extends SnapshotTask {
Expand All @@ -62,7 +73,7 @@ class ShardSnapshotTask extends SnapshotTask {
}

@Override
public void run() {
public void doRun() {
shardSnapshotter.accept(context);
}

Expand All @@ -88,7 +99,7 @@ class FileSnapshotTask extends SnapshotTask {
}

@Override
public void run() {
public void doRun() {
ActionRunnable.run(fileSnapshotListener, () -> {
FileInfo fileInfo = fileInfos.get();
if (fileInfo != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.elasticsearch.common.util.concurrent;

import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
Expand All @@ -17,6 +18,7 @@
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

Expand All @@ -40,7 +42,7 @@ public void tearDown() throws Exception {
TestThreadPool.terminate(threadPool, 30, TimeUnit.SECONDS);
}

static class TestTask implements Comparable<TestTask>, Runnable {
static class TestTask extends AbstractRunnable implements Comparable<TestTask> {

private final Runnable runnable;
private final int priority;
Expand All @@ -56,9 +58,14 @@ public int compareTo(TestTask o) {
}

@Override
public void run() {
public void doRun() {
runnable.run();
}

@Override
public void onFailure(Exception e) {
throw new AssertionError("unexpected", e);
}
}

public void testMultiThreadedEnqueue() throws Exception {
Expand Down Expand Up @@ -171,6 +178,51 @@ public void testEnqueueSpawnsNewTasksUpToMax() throws Exception {
assertThat(taskRunner.queueSize(), equalTo(0));
}

public void testFailsTasksOnRejectionOrShutdown() throws Exception {
final var maxThreads = between(1, 5);
final var threadFactory = EsExecutors.daemonThreadFactory("test");
final var threadContext = new ThreadContext(Settings.EMPTY);
final var executor = randomBoolean()
? EsExecutors.newScaling("test", 1, maxThreads, 0, TimeUnit.MILLISECONDS, true, threadFactory, threadContext)
: EsExecutors.newFixed("test", maxThreads, between(1, 5), threadFactory, threadContext, false);
final var taskRunner = new PrioritizedThrottledTaskRunner<TestTask>("test", between(1, maxThreads * 2), executor);
final var totalPermits = between(1, maxThreads * 2);
final var permits = new Semaphore(totalPermits);
final var taskCompleted = new CountDownLatch(between(1, maxThreads * 2));
final var rejectionCountDown = new CountDownLatch(between(1, maxThreads * 2));

final var spawnThread = new Thread(() -> {
try {
while (true) {
assertTrue(permits.tryAcquire(10, TimeUnit.SECONDS));
taskRunner.enqueueTask(new TestTask(taskCompleted::countDown, getRandomPriority()) {
@Override
public void onRejection(Exception e) {
rejectionCountDown.countDown();
}

@Override
public void onAfter() {
permits.release();
}
});
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});
spawnThread.start();
assertTrue(taskCompleted.await(10, TimeUnit.SECONDS));
executor.shutdown();
assertTrue(executor.awaitTermination(30, TimeUnit.SECONDS));
assertTrue(rejectionCountDown.await(10, TimeUnit.SECONDS));
spawnThread.interrupt();
spawnThread.join();
assertThat(taskRunner.runningTasks(), equalTo(0));
assertThat(taskRunner.queueSize(), equalTo(0));
assertTrue(permits.tryAcquire(totalPermits));
}

private int getRandomPriority() {
return randomIntBetween(-1000, 1000);
}
Expand Down

0 comments on commit b990254

Please sign in to comment.