diff --git a/integration-tests/rest-client-reactive/pom.xml b/integration-tests/rest-client-reactive/pom.xml index 318f28d271932..55c61d25554a1 100644 --- a/integration-tests/rest-client-reactive/pom.xml +++ b/integration-tests/rest-client-reactive/pom.xml @@ -74,6 +74,11 @@ quarkus-junit5-mockito test + + io.quarkus + quarkus-test-vertx + test + com.github.tomakehurst wiremock-jre8-standalone diff --git a/integration-tests/rest-client-reactive/src/main/java/io/quarkus/it/rest/client/main/ClientCallingResource.java b/integration-tests/rest-client-reactive/src/main/java/io/quarkus/it/rest/client/main/ClientCallingResource.java index 104871aa1ef07..950fe0ea3b9cf 100644 --- a/integration-tests/rest-client-reactive/src/main/java/io/quarkus/it/rest/client/main/ClientCallingResource.java +++ b/integration-tests/rest-client-reactive/src/main/java/io/quarkus/it/rest/client/main/ClientCallingResource.java @@ -74,6 +74,10 @@ void init(@Observes Router router) { rc.response().end("Hello World!"); }); + router.get("/correlation").handler(rc -> { + rc.response().end(rc.request().getHeader(CorrelationIdClient.CORRELATION_ID_HEADER_NAME)); + }); + router.post("/call-client-with-global-client-logger").blockingHandler(rc -> { String url = rc.body().asString(); ClientWithClientLogger client = QuarkusRestClientBuilder.newBuilder().baseUri(URI.create(url)) diff --git a/integration-tests/rest-client-reactive/src/main/java/io/quarkus/it/rest/client/main/CorrelationIdClient.java b/integration-tests/rest-client-reactive/src/main/java/io/quarkus/it/rest/client/main/CorrelationIdClient.java new file mode 100644 index 0000000000000..213fdf4234d04 --- /dev/null +++ b/integration-tests/rest-client-reactive/src/main/java/io/quarkus/it/rest/client/main/CorrelationIdClient.java @@ -0,0 +1,31 @@ +package io.quarkus.it.rest.client.main; + +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.client.ClientRequestContext; +import jakarta.ws.rs.client.ClientRequestFilter; + +import org.eclipse.microprofile.rest.client.annotation.RegisterProvider; +import org.eclipse.microprofile.rest.client.inject.RegisterRestClient; + +import io.smallrye.common.vertx.ContextLocals; + +@Path("/correlation") +@RegisterRestClient(configKey = "correlation") +@RegisterProvider(CorrelationIdClient.CorrelationIdClientFilter.class) +public interface CorrelationIdClient { + + String CORRELATION_ID_HEADER_NAME = "CorrelationId"; + + @GET + String get(); + + class CorrelationIdClientFilter implements ClientRequestFilter { + + @Override + public void filter(ClientRequestContext requestContext) { + String correlationId = ContextLocals. get(CORRELATION_ID_HEADER_NAME).orElse(null); + requestContext.getHeaders().putSingle(CORRELATION_ID_HEADER_NAME, correlationId); + } + } +} diff --git a/integration-tests/rest-client-reactive/src/main/resources/application.properties b/integration-tests/rest-client-reactive/src/main/resources/application.properties index 5a438ed680a25..e76e85e297bd1 100644 --- a/integration-tests/rest-client-reactive/src/main/resources/application.properties +++ b/integration-tests/rest-client-reactive/src/main/resources/application.properties @@ -1,8 +1,8 @@ w-client-logger/mp-rest/url=${test.url} w-exception-mapper/mp-rest/url=${test.url} w-fault-tolerance/mp-rest/url=${test.url} +correlation/mp-rest/url=${test.url} io.quarkus.it.rest.client.main.ParamClient/mp-rest/url=${test.url} -io.quarkus.it.rest.client.multipart.MultipartClient/mp-rest/url=${test.url} # global client logging scope quarkus.rest-client.logging.scope=request-response # Self-Signed client diff --git a/integration-tests/rest-client-reactive/src/test/java/io/quarkus/it/rest/client/RestClientInTestMethodWithContextTest.java b/integration-tests/rest-client-reactive/src/test/java/io/quarkus/it/rest/client/RestClientInTestMethodWithContextTest.java new file mode 100644 index 0000000000000..579fa1e432a29 --- /dev/null +++ b/integration-tests/rest-client-reactive/src/test/java/io/quarkus/it/rest/client/RestClientInTestMethodWithContextTest.java @@ -0,0 +1,25 @@ +package io.quarkus.it.rest.client; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.eclipse.microprofile.rest.client.inject.RestClient; +import org.junit.jupiter.api.Test; + +import io.quarkus.it.rest.client.main.CorrelationIdClient; +import io.quarkus.test.junit.QuarkusTest; +import io.quarkus.test.vertx.RunOnVertxContext; +import io.smallrye.common.vertx.ContextLocals; + +@QuarkusTest +public class RestClientInTestMethodWithContextTest { + + @RestClient + CorrelationIdClient client; + + @RunOnVertxContext(runOnEventLoop = false) + @Test + public void test() { + ContextLocals.put(CorrelationIdClient.CORRELATION_ID_HEADER_NAME, "dummy"); + assertThat(client.get()).isEqualTo("dummy"); + } +} diff --git a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java index 13e35c64e8f65..271c9202357bd 100644 --- a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java +++ b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java @@ -32,4 +32,9 @@ * @return {@code true} by default. */ boolean duplicateContext() default true; + + /** + * If {@code true}, the test method is run on the Event Loop, otherwise it will be run on Vert.x blocking thread pool + */ + boolean runOnEventLoop() default true; } diff --git a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java index 39214de3f94b3..30c027f212135 100644 --- a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java +++ b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java @@ -16,6 +16,7 @@ import io.smallrye.common.vertx.VertxContext; import io.vertx.core.Context; import io.vertx.core.Handler; +import io.vertx.core.Promise; import io.vertx.core.Vertx; public class RunOnVertxContextTestMethodInvoker implements TestMethodInvoker { @@ -70,17 +71,27 @@ public Object invoke(Object actualTestInstance, Method actualTestMethod, List cf = new CompletableFuture<>(); - RunTestMethodOnContextHandler handler = new RunTestMethodOnContextHandler(actualTestInstance, actualTestMethod, - actualTestMethodArgs, uniAsserter, cf); + Context context = vertx.getOrCreateContext(); - boolean shouldDuplicateContext = shouldContextBeDuplicated( - actualTestInstance != null ? actualTestInstance.getClass() : Object.class, actualTestMethod); + Class testClass = actualTestInstance != null ? actualTestInstance.getClass() : Object.class; + boolean shouldDuplicateContext = shouldContextBeDuplicated(testClass, actualTestMethod); if (shouldDuplicateContext) { context = VertxContext.getOrCreateDuplicatedContext(context); VertxContextSafetyToggle.setContextSafe(context, true); } - context.runOnContext(handler); + + CompletableFuture cf; + if (shouldRunOnEventLoop(testClass, actualTestMethod)) { + cf = new CompletableFuture<>(); + var handler = new RunTestMethodOnVertxEventLoopContextHandler(actualTestInstance, actualTestMethod, + actualTestMethodArgs, uniAsserter, cf); + context.runOnContext(handler); + } else { + var handler = new RunTestMethodOnVertxBlockingContextHandler(actualTestInstance, actualTestMethod, + actualTestMethodArgs, uniAsserter); + cf = ((CompletableFuture) context.executeBlocking(handler).toCompletionStage()); + } + try { return cf.get(); } catch (InterruptedException e) { @@ -90,6 +101,7 @@ public Object invoke(Object actualTestInstance, Method actualTestMethod, List c, Method m) { @@ -106,7 +118,19 @@ private boolean shouldContextBeDuplicated(Class c, Method m) { } } - public static class RunTestMethodOnContextHandler implements Handler { + private boolean shouldRunOnEventLoop(Class c, Method m) { + RunOnVertxContext runOnVertxContext = m.getAnnotation(RunOnVertxContext.class); + if (runOnVertxContext == null) { + runOnVertxContext = c.getAnnotation(RunOnVertxContext.class); + } + if (runOnVertxContext == null) { + return true; + } else { + return runOnVertxContext.runOnEventLoop(); + } + } + + public static class RunTestMethodOnVertxEventLoopContextHandler implements Handler { private static final Runnable DO_NOTHING = new Runnable() { @Override public void run() { @@ -119,7 +143,7 @@ public void run() { private final DefaultUniAsserter uniAsserter; private final CompletableFuture future; - public RunTestMethodOnContextHandler(Object testInstance, Method targetMethod, List methodArgs, + public RunTestMethodOnVertxEventLoopContextHandler(Object testInstance, Method targetMethod, List methodArgs, DefaultUniAsserter uniAsserter, CompletableFuture future) { this.testInstance = testInstance; this.future = future; @@ -172,4 +196,68 @@ public void accept(Throwable t) { } } + public static class RunTestMethodOnVertxBlockingContextHandler implements Handler> { + private static final Runnable DO_NOTHING = new Runnable() { + @Override + public void run() { + } + }; + + private final Object testInstance; + private final Method targetMethod; + private final List methodArgs; + private final DefaultUniAsserter uniAsserter; + + public RunTestMethodOnVertxBlockingContextHandler(Object testInstance, Method targetMethod, List methodArgs, + DefaultUniAsserter uniAsserter) { + this.testInstance = testInstance; + this.targetMethod = targetMethod; + this.methodArgs = methodArgs; + this.uniAsserter = uniAsserter; + } + + @Override + public void handle(Promise promise) { + ManagedContext requestContext = Arc.container().requestContext(); + if (requestContext.isActive()) { + doRun(promise, DO_NOTHING); + } else { + requestContext.activate(); + doRun(promise, new Runnable() { + @Override + public void run() { + requestContext.terminate(); + } + }); + } + } + + private void doRun(Promise promise, Runnable onTerminate) { + try { + Object testMethodResult = targetMethod.invoke(testInstance, methodArgs.toArray(new Object[0])); + if (uniAsserter != null) { + uniAsserter.execution.subscribe().with(new Consumer() { + @Override + public void accept(Object o) { + onTerminate.run(); + promise.complete(); + } + }, new Consumer<>() { + @Override + public void accept(Throwable t) { + onTerminate.run(); + promise.fail(t); + } + }); + } else { + onTerminate.run(); + promise.complete(testMethodResult); + } + } catch (Throwable t) { + onTerminate.run(); + promise.fail(t.getCause()); + } + } + } + }