From 78d8e71d29f80cf0fe2b513b0b3f5316cec806df Mon Sep 17 00:00:00 2001
From: Georgios Andrianakis <geoand@gmail.com>
Date: Thu, 22 Jun 2023 13:57:08 +0300
Subject: [PATCH] Allow executing tests on Vert.x blocking thread pool

Relates to: #34222
---
 .../rest-client-reactive/deployment/pom.xml   |  25 +++++
 .../RestClientInTestMethodWithContextVT.java  |  73 ++++++++++++
 .../quarkus/test/vertx/RunOnVertxContext.java |   5 +
 .../RunOnVertxContextTestMethodInvoker.java   | 104 ++++++++++++++++--
 4 files changed, 199 insertions(+), 8 deletions(-)
 create mode 100644 extensions/resteasy-reactive/rest-client-reactive/deployment/src/test/java/io/quarkus/rest/client/reactive/RestClientInTestMethodWithContextVT.java

diff --git a/extensions/resteasy-reactive/rest-client-reactive/deployment/pom.xml b/extensions/resteasy-reactive/rest-client-reactive/deployment/pom.xml
index 3d48dda2d0c4c..bda784d488c4d 100644
--- a/extensions/resteasy-reactive/rest-client-reactive/deployment/pom.xml
+++ b/extensions/resteasy-reactive/rest-client-reactive/deployment/pom.xml
@@ -88,6 +88,11 @@
             <artifactId>stork-service-discovery-static-list</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>io.quarkus</groupId>
+            <artifactId>quarkus-test-vertx</artifactId>
+            <scope>test</scope>
+        </dependency>
    </dependencies>
 
     <build>
@@ -104,6 +109,26 @@
                     </annotationProcessorPaths>
                 </configuration>
             </plugin>
+
+            <plugin>
+                <artifactId>maven-surefire-plugin</artifactId>
+                <executions>
+                    <!--
+                    The tests that use @RunOnVertxContext should not be mixed with regular tests because they can cause class loader leaks
+                    So to avoid having a new maven module for these, we just introduce a new surefire run
+                    -->
+                    <execution>
+                        <id>run-on-vertx</id>
+                        <phase>test</phase>
+                        <goals>
+                            <goal>test</goal>
+                        </goals>
+                        <configuration>
+                            <includes>**/*VT.java</includes>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
         </plugins>
     </build>
 
diff --git a/extensions/resteasy-reactive/rest-client-reactive/deployment/src/test/java/io/quarkus/rest/client/reactive/RestClientInTestMethodWithContextVT.java b/extensions/resteasy-reactive/rest-client-reactive/deployment/src/test/java/io/quarkus/rest/client/reactive/RestClientInTestMethodWithContextVT.java
new file mode 100644
index 0000000000000..625ee1df07c2e
--- /dev/null
+++ b/extensions/resteasy-reactive/rest-client-reactive/deployment/src/test/java/io/quarkus/rest/client/reactive/RestClientInTestMethodWithContextVT.java
@@ -0,0 +1,73 @@
+package io.quarkus.rest.client.reactive;
+
+import static io.quarkus.rest.client.reactive.RestClientTestUtil.setUrlForClass;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.net.MalformedURLException;
+
+import jakarta.ws.rs.GET;
+import jakarta.ws.rs.Path;
+import jakarta.ws.rs.client.ClientRequestContext;
+import jakarta.ws.rs.client.ClientRequestFilter;
+
+import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
+import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;
+import org.eclipse.microprofile.rest.client.inject.RestClient;
+import org.jboss.resteasy.reactive.RestHeader;
+import org.jboss.shrinkwrap.api.asset.StringAsset;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import io.quarkus.test.QuarkusUnitTest;
+import io.quarkus.test.vertx.RunOnVertxContext;
+import io.smallrye.common.vertx.ContextLocals;
+
+public class RestClientInTestMethodWithContextVT {
+
+    public static final String CORRELATION_ID_HEADER_NAME = "correlationId";
+
+    @RegisterExtension
+    static final QuarkusUnitTest config = new QuarkusUnitTest()
+            .withApplicationRoot((jar) -> jar
+                    .addClasses(Client.class, Resource.class, ClientFilter.class)
+                    .addAsResource(
+                            new StringAsset(setUrlForClass(Client.class)), "application.properties"));
+
+    @RestClient
+    Client client;
+
+    @RunOnVertxContext(runOnEventLoop = false)
+    @Test
+    public void test() {
+        ContextLocals.put(CORRELATION_ID_HEADER_NAME, "dummy");
+        assertThat(client.getMock()).isEqualTo("dummy");
+    }
+
+    @Path("/test")
+    @RegisterRestClient(configKey = "mock-client")
+    @RegisterProvider(ClientFilter.class)
+    public interface Client {
+
+        @GET
+        String getMock();
+    }
+
+    @Path("/test")
+    public static class Resource {
+
+        @GET
+        public String get(@RestHeader(CORRELATION_ID_HEADER_NAME) String correlationId) {
+            return correlationId;
+        }
+    }
+
+    public static class ClientFilter implements ClientRequestFilter {
+        @Override
+        public void filter(ClientRequestContext clientRequestContext) throws MalformedURLException {
+
+            String correlationId = ContextLocals.<String> get(CORRELATION_ID_HEADER_NAME).orElse(null);
+            clientRequestContext.getHeaders().putSingle(CORRELATION_ID_HEADER_NAME, correlationId);
+        }
+    }
+
+}
diff --git a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java
index 13e35c64e8f65..271c9202357bd 100644
--- a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java
+++ b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContext.java
@@ -32,4 +32,9 @@
      * @return {@code true} by default.
      */
     boolean duplicateContext() default true;
+
+    /**
+     * If {@code true}, the test method is run on the Event Loop, otherwise it will be run on Vert.x blocking thread pool
+     */
+    boolean runOnEventLoop() default true;
 }
diff --git a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java
index 39214de3f94b3..30c027f212135 100644
--- a/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java
+++ b/test-framework/vertx/src/main/java/io/quarkus/test/vertx/RunOnVertxContextTestMethodInvoker.java
@@ -16,6 +16,7 @@
 import io.smallrye.common.vertx.VertxContext;
 import io.vertx.core.Context;
 import io.vertx.core.Handler;
+import io.vertx.core.Promise;
 import io.vertx.core.Vertx;
 
 public class RunOnVertxContextTestMethodInvoker implements TestMethodInvoker {
@@ -70,17 +71,27 @@ public Object invoke(Object actualTestInstance, Method actualTestMethod, List<Ob
             throw new IllegalStateException("Vert.x instance has not been created before attempting to run test method '"
                     + actualTestMethod.getName() + "' of test class '" + testClassName + "'");
         }
-        CompletableFuture<Object> cf = new CompletableFuture<>();
-        RunTestMethodOnContextHandler handler = new RunTestMethodOnContextHandler(actualTestInstance, actualTestMethod,
-                actualTestMethodArgs, uniAsserter, cf);
+
         Context context = vertx.getOrCreateContext();
-        boolean shouldDuplicateContext = shouldContextBeDuplicated(
-                actualTestInstance != null ? actualTestInstance.getClass() : Object.class, actualTestMethod);
+        Class<?> testClass = actualTestInstance != null ? actualTestInstance.getClass() : Object.class;
+        boolean shouldDuplicateContext = shouldContextBeDuplicated(testClass, actualTestMethod);
         if (shouldDuplicateContext) {
             context = VertxContext.getOrCreateDuplicatedContext(context);
             VertxContextSafetyToggle.setContextSafe(context, true);
         }
-        context.runOnContext(handler);
+
+        CompletableFuture<Object> cf;
+        if (shouldRunOnEventLoop(testClass, actualTestMethod)) {
+            cf = new CompletableFuture<>();
+            var handler = new RunTestMethodOnVertxEventLoopContextHandler(actualTestInstance, actualTestMethod,
+                    actualTestMethodArgs, uniAsserter, cf);
+            context.runOnContext(handler);
+        } else {
+            var handler = new RunTestMethodOnVertxBlockingContextHandler(actualTestInstance, actualTestMethod,
+                    actualTestMethodArgs, uniAsserter);
+            cf = ((CompletableFuture) context.executeBlocking(handler).toCompletionStage());
+        }
+
         try {
             return cf.get();
         } catch (InterruptedException e) {
@@ -90,6 +101,7 @@ public Object invoke(Object actualTestInstance, Method actualTestMethod, List<Ob
             // the test itself threw an exception
             throw e.getCause();
         }
+
     }
 
     private boolean shouldContextBeDuplicated(Class<?> c, Method m) {
@@ -106,7 +118,19 @@ private boolean shouldContextBeDuplicated(Class<?> c, Method m) {
         }
     }
 
-    public static class RunTestMethodOnContextHandler implements Handler<Void> {
+    private boolean shouldRunOnEventLoop(Class<?> c, Method m) {
+        RunOnVertxContext runOnVertxContext = m.getAnnotation(RunOnVertxContext.class);
+        if (runOnVertxContext == null) {
+            runOnVertxContext = c.getAnnotation(RunOnVertxContext.class);
+        }
+        if (runOnVertxContext == null) {
+            return true;
+        } else {
+            return runOnVertxContext.runOnEventLoop();
+        }
+    }
+
+    public static class RunTestMethodOnVertxEventLoopContextHandler implements Handler<Void> {
         private static final Runnable DO_NOTHING = new Runnable() {
             @Override
             public void run() {
@@ -119,7 +143,7 @@ public void run() {
         private final DefaultUniAsserter uniAsserter;
         private final CompletableFuture<Object> future;
 
-        public RunTestMethodOnContextHandler(Object testInstance, Method targetMethod, List<Object> methodArgs,
+        public RunTestMethodOnVertxEventLoopContextHandler(Object testInstance, Method targetMethod, List<Object> methodArgs,
                 DefaultUniAsserter uniAsserter, CompletableFuture<Object> future) {
             this.testInstance = testInstance;
             this.future = future;
@@ -172,4 +196,68 @@ public void accept(Throwable t) {
         }
     }
 
+    public static class RunTestMethodOnVertxBlockingContextHandler implements Handler<Promise<Object>> {
+        private static final Runnable DO_NOTHING = new Runnable() {
+            @Override
+            public void run() {
+            }
+        };
+
+        private final Object testInstance;
+        private final Method targetMethod;
+        private final List<Object> methodArgs;
+        private final DefaultUniAsserter uniAsserter;
+
+        public RunTestMethodOnVertxBlockingContextHandler(Object testInstance, Method targetMethod, List<Object> methodArgs,
+                DefaultUniAsserter uniAsserter) {
+            this.testInstance = testInstance;
+            this.targetMethod = targetMethod;
+            this.methodArgs = methodArgs;
+            this.uniAsserter = uniAsserter;
+        }
+
+        @Override
+        public void handle(Promise<Object> promise) {
+            ManagedContext requestContext = Arc.container().requestContext();
+            if (requestContext.isActive()) {
+                doRun(promise, DO_NOTHING);
+            } else {
+                requestContext.activate();
+                doRun(promise, new Runnable() {
+                    @Override
+                    public void run() {
+                        requestContext.terminate();
+                    }
+                });
+            }
+        }
+
+        private void doRun(Promise<Object> promise, Runnable onTerminate) {
+            try {
+                Object testMethodResult = targetMethod.invoke(testInstance, methodArgs.toArray(new Object[0]));
+                if (uniAsserter != null) {
+                    uniAsserter.execution.subscribe().with(new Consumer<Object>() {
+                        @Override
+                        public void accept(Object o) {
+                            onTerminate.run();
+                            promise.complete();
+                        }
+                    }, new Consumer<>() {
+                        @Override
+                        public void accept(Throwable t) {
+                            onTerminate.run();
+                            promise.fail(t);
+                        }
+                    });
+                } else {
+                    onTerminate.run();
+                    promise.complete(testMethodResult);
+                }
+            } catch (Throwable t) {
+                onTerminate.run();
+                promise.fail(t.getCause());
+            }
+        }
+    }
+
 }