Skip to content

Commit

Permalink
Propagate Vert.x context on all ExecutorService methods for VirtualTh…
Browse files Browse the repository at this point in the history
…readExecutor
  • Loading branch information
ozangunalp committed Feb 19, 2024
1 parent a786faf commit 1ea93cc
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@

import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

import io.vertx.core.Context;
import io.vertx.core.Vertx;
import io.vertx.core.impl.ContextInternal;

Expand All @@ -22,26 +26,77 @@ class ContextPreservingExecutorService implements ExecutorService {
this.delegate = delegate;
}

public void execute(final Runnable command) {
var context = Vertx.currentContext();
if (!(context instanceof ContextInternal)) {
delegate.execute(command);
} else {
ContextInternal contextInternal = (ContextInternal) context;
delegate.execute(new Runnable() {
@Override
public void run() {
final var previousContext = contextInternal.beginDispatch();
try {
command.run();
} finally {
contextInternal.endDispatch(previousContext);
}
private static final class ContextPreservingRunnable implements Runnable {

private final Runnable task;
private final Context context;

public ContextPreservingRunnable(Runnable task) {
this.task = task;
this.context = Vertx.currentContext();
}

@Override
public void run() {
if (context instanceof ContextInternal) {
ContextInternal contextInternal = (ContextInternal) context;
final var previousContext = contextInternal.beginDispatch();
try {
task.run();
} finally {
contextInternal.endDispatch(previousContext);
}
} else {
task.run();
}
}
}

private static final class ContextPreservingCallable<T> implements Callable<T> {

private final Callable<T> task;
private final Context context;

public ContextPreservingCallable(Callable<T> task) {
this.task = task;
this.context = Vertx.currentContext();
}

@Override
public T call() throws Exception {
if (context instanceof ContextInternal) {
ContextInternal contextInternal = (ContextInternal) context;
final var previousContext = contextInternal.beginDispatch();
try {
return task.call();
} finally {
contextInternal.endDispatch(previousContext);
}
});
} else {
return task.call();
}
}
}

private static Runnable decorate(Runnable command) {
Objects.requireNonNull(command);
return new ContextPreservingRunnable(command);
}

private static <T> Callable<T> decorate(Callable<T> task) {
Objects.requireNonNull(task);
return new ContextPreservingCallable<>(task);
}

private static <T> Collection<? extends Callable<T>> decorateAll(Collection<? extends Callable<T>> tasks) {
Objects.requireNonNull(tasks);
return tasks.stream().map(ContextPreservingExecutorService::decorate).collect(Collectors.toList());
}

public void execute(final Runnable command) {
delegate.execute(decorate(command));
}

public boolean isShutdown() {
return delegate.isShutdown();
}
Expand All @@ -56,39 +111,39 @@ public boolean awaitTermination(final long timeout, final TimeUnit unit) throws

@Override
public <T> Future<T> submit(Callable<T> task) {
return delegate.submit(task);
return delegate.submit(decorate(task));
}

@Override
public <T> Future<T> submit(Runnable task, T result) {
return delegate.submit(task, result);
return submit(Executors.callable(task, result));
}

@Override
public Future<?> submit(Runnable task) {
return delegate.submit(task);
return delegate.submit(decorate(task));
}

@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException {
return delegate.invokeAll(tasks);
return delegate.invokeAll(decorateAll(tasks));
}

@Override
public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
throws InterruptedException {
return delegate.invokeAll(tasks, timeout, unit);
return delegate.invokeAll(decorateAll(tasks), timeout, unit);
}

@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException {
return delegate.invokeAny(tasks);
return delegate.invokeAny(decorateAll(tasks));
}

@Override
public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException {
return delegate.invokeAny(tasks, timeout, unit);
return delegate.invokeAny(decorateAll(tasks), timeout, unit);
}

public void shutdown() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,34 @@
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledForJreRange;
import org.junit.jupiter.api.condition.JRE;

import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.helpers.test.UniAssertSubscriber;
import io.vertx.core.Context;
import io.vertx.core.Vertx;

class VirtualThreadExecutorSupplierTest {

@BeforeEach
void configRecorder() {
VirtualThreadsRecorder.config = new VirtualThreadsConfig();
VirtualThreadsRecorder.config.enabled = true;
VirtualThreadsRecorder.config.namePrefix = Optional.empty();
}

@Test
@EnabledForJreRange(min = JRE.JAVA_20, disabledReason = "Virtual Threads are a preview feature starting from Java 20")
void virtualThreadCustomScheduler()
Expand Down Expand Up @@ -44,6 +61,74 @@ void execute() throws ClassNotFoundException, InvocationTargetException, NoSuchM
assertSubscriber.awaitItem(Duration.ofSeconds(1)).assertCompleted();
}

@Test
@EnabledForJreRange(min = JRE.JAVA_20, disabledReason = "Virtual Threads are a preview feature starting from Java 20")
void executePropagatesVertxContext() throws ExecutionException, InterruptedException {
ExecutorService executorService = VirtualThreadsRecorder.getCurrent();
Vertx vertx = Vertx.vertx();
CompletableFuture<Context> future = new CompletableFuture<>();
vertx.executeBlocking(() -> {
executorService.execute(() -> {
assertThatItRunsOnVirtualThread();
future.complete(Vertx.currentContext());
});
return null;
}).toCompletionStage().toCompletableFuture().get();
assertThat(future.get()).isNotNull();
}

@Test
@EnabledForJreRange(min = JRE.JAVA_20, disabledReason = "Virtual Threads are a preview feature starting from Java 20")
void executePropagatesVertxContextMutiny() {
ExecutorService executorService = VirtualThreadsRecorder.getCurrent();
Vertx vertx = Vertx.vertx();
var assertSubscriber = Uni.createFrom().voidItem()
.runSubscriptionOn(command -> vertx.executeBlocking(() -> {
command.run();
return null;
}))
.emitOn(executorService)
.map(x -> {
assertThatItRunsOnVirtualThread();
return Vertx.currentContext();
})
.subscribe().withSubscriber(UniAssertSubscriber.create());
assertThat(assertSubscriber.awaitItem().assertCompleted().getItem()).isNotNull();
}

@Test
@EnabledForJreRange(min = JRE.JAVA_20, disabledReason = "Virtual Threads are a preview feature starting from Java 20")
void submitPropagatesVertxContext() throws ExecutionException, InterruptedException {
ExecutorService executorService = VirtualThreadsRecorder.getCurrent();
Vertx vertx = Vertx.vertx();
CompletableFuture<Context> future = new CompletableFuture<>();
vertx.executeBlocking(() -> {
executorService.submit(() -> {
assertThatItRunsOnVirtualThread();
future.complete(Vertx.currentContext());
});
return null;
}).toCompletionStage().toCompletableFuture().get();
assertThat(future.get()).isNotNull();
}

@Test
@EnabledForJreRange(min = JRE.JAVA_20, disabledReason = "Virtual Threads are a preview feature starting from Java 20")
void invokeAllPropagatesVertxContext() throws ExecutionException, InterruptedException {
ExecutorService executorService = VirtualThreadsRecorder.getCurrent();
Vertx vertx = Vertx.vertx();
List<Future<Context>> futures = vertx.executeBlocking(() -> {
return executorService.invokeAll(List.of((Callable<Context>) () -> {
assertThatItRunsOnVirtualThread();
return Vertx.currentContext();
}, (Callable<Context>) () -> {
assertThatItRunsOnVirtualThread();
return Vertx.currentContext();
}));
}).toCompletionStage().toCompletableFuture().get();
assertThat(futures).allSatisfy(contextFuture -> assertThat(contextFuture.get()).isNotNull());
}

public static void assertThatItRunsOnVirtualThread() {
// We cannot depend on a Java 20.
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package io.quarkus.virtual.rr;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;

import jakarta.inject.Inject;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.Response;

import org.jboss.logmanager.MDC;

import io.quarkus.arc.Arc;
import io.quarkus.test.vertx.VirtualThreadsAssertions;
import io.quarkus.virtual.threads.VirtualThreads;
import io.smallrye.common.annotation.RunOnVirtualThread;
import io.vertx.core.Vertx;

Expand All @@ -17,14 +23,29 @@ public class FilteredResource {
@Inject
Counter counter;

@Inject
@VirtualThreads
ExecutorService vt;

@GET
@RunOnVirtualThread
public Response filtered() {
public Response filtered() throws ExecutionException, InterruptedException {
VirtualThreadsAssertions.assertEverything();

// Request scope
assert counter.increment() == 2;

// Request scope propagated
assert vt.submit(() -> counter.increment()).get() == 3;

// Request scope active
assert Arc.container().requestContext().isActive();
assert vt.submit(() -> Arc.container().requestContext().isActive()).get();

CompletableFuture<Boolean> requestContextActive = new CompletableFuture<>();
vt.execute(() -> requestContextActive.complete(Arc.container().requestContext().isActive()));
assert requestContextActive.get();

// DC
assert Vertx.currentContext().getLocal("filter").equals("test");
Vertx.currentContext().putLocal("test", "test test");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public void getFilter(ContainerResponseContext responseContext) {
if (responseContext.getHeaders().get("X-filter") != null) {
VirtualThreadsAssertions.assertEverything();
// the request filter, the method, and here.
assert CDI.current().select(Counter.class).get().increment() == 3;
assert CDI.current().select(Counter.class).get().increment() == 4;
assert Vertx.currentContext().getLocal("test").equals("test test");
assert MDC.get("mdc").equals("test test");
}
Expand Down

0 comments on commit 1ea93cc

Please sign in to comment.