diff --git a/extensions/resteasy-reactive/rest-client-reactive/deployment/pom.xml b/extensions/resteasy-reactive/rest-client-reactive/deployment/pom.xml index 3d48dda2d0c4c1..e0808a4ba7bc41 100644 --- a/extensions/resteasy-reactive/rest-client-reactive/deployment/pom.xml +++ b/extensions/resteasy-reactive/rest-client-reactive/deployment/pom.xml @@ -88,6 +88,11 @@ stork-service-discovery-static-list test + + io.quarkus + quarkus-test-vertx + test + @@ -104,6 +109,26 @@ + + + maven-surefire-plugin + + + + prod-mode + test + + test + + + **/*VT.java + + + + diff --git a/extensions/resteasy-reactive/rest-client-reactive/deployment/src/test/java/io/quarkus/rest/client/reactive/RestClientInTestMethodWithContextVT.java b/extensions/resteasy-reactive/rest-client-reactive/deployment/src/test/java/io/quarkus/rest/client/reactive/RestClientInTestMethodWithContextVT.java new file mode 100644 index 00000000000000..625ee1df07c2ec --- /dev/null +++ b/extensions/resteasy-reactive/rest-client-reactive/deployment/src/test/java/io/quarkus/rest/client/reactive/RestClientInTestMethodWithContextVT.java @@ -0,0 +1,73 @@ +package io.quarkus.rest.client.reactive; + +import static io.quarkus.rest.client.reactive.RestClientTestUtil.setUrlForClass; +import static org.assertj.core.api.Assertions.assertThat; + +import java.net.MalformedURLException; + +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.client.ClientRequestContext; +import jakarta.ws.rs.client.ClientRequestFilter; + +import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; +import org.eclipse.microprofile.rest.client.inject.RestClient; +import org.jboss.resteasy.reactive.RestHeader; +import org.jboss.shrinkwrap.api.asset.StringAsset; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.vertx.RunOnVertxContext; +import io.smallrye.common.vertx.ContextLocals; + +public class RestClientInTestMethodWithContextVT { + + public static final String CORRELATION_ID_HEADER_NAME = "correlationId"; + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addClasses(Client.class, Resource.class, ClientFilter.class) + .addAsResource( + new StringAsset(setUrlForClass(Client.class)), "application.properties")); + + @RestClient + Client client; + + @RunOnVertxContext(runOnEventLoop = false) + @Test + public void test() { + ContextLocals.put(CORRELATION_ID_HEADER_NAME, "dummy"); + assertThat(client.getMock()).isEqualTo("dummy"); + } + + @Path("/test") + @RegisterRestClient(configKey = "mock-client") + @RegisterProvider(ClientFilter.class) + public interface Client { + + @GET + String getMock(); + } + + @Path("/test") + public static class Resource { + + @GET + public String get(@RestHeader(CORRELATION_ID_HEADER_NAME) String correlationId) { + return correlationId; + } + } + + public static class ClientFilter implements ClientRequestFilter { + @Override + public void filter(ClientRequestContext clientRequestContext) throws MalformedURLException { + + String correlationId = ContextLocals. get(CORRELATION_ID_HEADER_NAME).orElse(null); + clientRequestContext.getHeaders().putSingle(CORRELATION_ID_HEADER_NAME, correlationId); + } + } + +} diff --git a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java index 13e35c64e8f65e..271c9202357bd4 100644 --- a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java +++ b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java @@ -32,4 +32,9 @@ * @return {@code true} by default. */ boolean duplicateContext() default true; + + /** + * If {@code true}, the test method is run on the Event Loop, otherwise it will be run on Vert.x blocking thread pool + */ + boolean runOnEventLoop() default true; } diff --git a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java index 39214de3f94b30..30c027f2121352 100644 --- a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java +++ b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java @@ -16,6 +16,7 @@ import io.smallrye.common.vertx.VertxContext; import io.vertx.core.Context; import io.vertx.core.Handler; +import io.vertx.core.Promise; import io.vertx.core.Vertx; public class RunOnVertxContextTestMethodInvoker implements TestMethodInvoker { @@ -70,17 +71,27 @@ public Object invoke(Object actualTestInstance, Method actualTestMethod, List cf = new CompletableFuture<>(); - RunTestMethodOnContextHandler handler = new RunTestMethodOnContextHandler(actualTestInstance, actualTestMethod, - actualTestMethodArgs, uniAsserter, cf); + Context context = vertx.getOrCreateContext(); - boolean shouldDuplicateContext = shouldContextBeDuplicated( - actualTestInstance != null ? actualTestInstance.getClass() : Object.class, actualTestMethod); + Class testClass = actualTestInstance != null ? actualTestInstance.getClass() : Object.class; + boolean shouldDuplicateContext = shouldContextBeDuplicated(testClass, actualTestMethod); if (shouldDuplicateContext) { context = VertxContext.getOrCreateDuplicatedContext(context); VertxContextSafetyToggle.setContextSafe(context, true); } - context.runOnContext(handler); + + CompletableFuture cf; + if (shouldRunOnEventLoop(testClass, actualTestMethod)) { + cf = new CompletableFuture<>(); + var handler = new RunTestMethodOnVertxEventLoopContextHandler(actualTestInstance, actualTestMethod, + actualTestMethodArgs, uniAsserter, cf); + context.runOnContext(handler); + } else { + var handler = new RunTestMethodOnVertxBlockingContextHandler(actualTestInstance, actualTestMethod, + actualTestMethodArgs, uniAsserter); + cf = ((CompletableFuture) context.executeBlocking(handler).toCompletionStage()); + } + try { return cf.get(); } catch (InterruptedException e) { @@ -90,6 +101,7 @@ public Object invoke(Object actualTestInstance, Method actualTestMethod, List c, Method m) { @@ -106,7 +118,19 @@ private boolean shouldContextBeDuplicated(Class c, Method m) { } } - public static class RunTestMethodOnContextHandler implements Handler { + private boolean shouldRunOnEventLoop(Class c, Method m) { + RunOnVertxContext runOnVertxContext = m.getAnnotation(RunOnVertxContext.class); + if (runOnVertxContext == null) { + runOnVertxContext = c.getAnnotation(RunOnVertxContext.class); + } + if (runOnVertxContext == null) { + return true; + } else { + return runOnVertxContext.runOnEventLoop(); + } + } + + public static class RunTestMethodOnVertxEventLoopContextHandler implements Handler { private static final Runnable DO_NOTHING = new Runnable() { @Override public void run() { @@ -119,7 +143,7 @@ public void run() { private final DefaultUniAsserter uniAsserter; private final CompletableFuture future; - public RunTestMethodOnContextHandler(Object testInstance, Method targetMethod, List methodArgs, + public RunTestMethodOnVertxEventLoopContextHandler(Object testInstance, Method targetMethod, List methodArgs, DefaultUniAsserter uniAsserter, CompletableFuture future) { this.testInstance = testInstance; this.future = future; @@ -172,4 +196,68 @@ public void accept(Throwable t) { } } + public static class RunTestMethodOnVertxBlockingContextHandler implements Handler> { + private static final Runnable DO_NOTHING = new Runnable() { + @Override + public void run() { + } + }; + + private final Object testInstance; + private final Method targetMethod; + private final List methodArgs; + private final DefaultUniAsserter uniAsserter; + + public RunTestMethodOnVertxBlockingContextHandler(Object testInstance, Method targetMethod, List methodArgs, + DefaultUniAsserter uniAsserter) { + this.testInstance = testInstance; + this.targetMethod = targetMethod; + this.methodArgs = methodArgs; + this.uniAsserter = uniAsserter; + } + + @Override + public void handle(Promise promise) { + ManagedContext requestContext = Arc.container().requestContext(); + if (requestContext.isActive()) { + doRun(promise, DO_NOTHING); + } else { + requestContext.activate(); + doRun(promise, new Runnable() { + @Override + public void run() { + requestContext.terminate(); + } + }); + } + } + + private void doRun(Promise promise, Runnable onTerminate) { + try { + Object testMethodResult = targetMethod.invoke(testInstance, methodArgs.toArray(new Object[0])); + if (uniAsserter != null) { + uniAsserter.execution.subscribe().with(new Consumer() { + @Override + public void accept(Object o) { + onTerminate.run(); + promise.complete(); + } + }, new Consumer<>() { + @Override + public void accept(Throwable t) { + onTerminate.run(); + promise.fail(t); + } + }); + } else { + onTerminate.run(); + promise.complete(testMethodResult); + } + } catch (Throwable t) { + onTerminate.run(); + promise.fail(t.getCause()); + } + } + } + }