Skip to content

Commit

Permalink
Merge pull request #34245 from geoand/#34222
Browse files Browse the repository at this point in the history
Allow executing tests on Vert.x blocking thread pool
  • Loading branch information
geoand authored Jun 22, 2023
2 parents 6ab6016 + db2292e commit cc27b73
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 9 deletions.
5 changes: 5 additions & 0 deletions integration-tests/rest-client-reactive/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@
<artifactId>quarkus-junit5-mockito</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-test-vertx</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.github.tomakehurst</groupId>
<artifactId>wiremock-jre8-standalone</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ void init(@Observes Router router) {
rc.response().end("Hello World!");
});

router.get("/correlation").handler(rc -> {
rc.response().end(rc.request().getHeader(CorrelationIdClient.CORRELATION_ID_HEADER_NAME));
});

router.post("/call-client-with-global-client-logger").blockingHandler(rc -> {
String url = rc.body().asString();
ClientWithClientLogger client = QuarkusRestClientBuilder.newBuilder().baseUri(URI.create(url))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package io.quarkus.it.rest.client.main;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.client.ClientRequestContext;
import jakarta.ws.rs.client.ClientRequestFilter;

import org.eclipse.microprofile.rest.client.annotation.RegisterProvider;
import org.eclipse.microprofile.rest.client.inject.RegisterRestClient;

import io.smallrye.common.vertx.ContextLocals;

@Path("/correlation")
@RegisterRestClient(configKey = "correlation")
@RegisterProvider(CorrelationIdClient.CorrelationIdClientFilter.class)
public interface CorrelationIdClient {

String CORRELATION_ID_HEADER_NAME = "CorrelationId";

@GET
String get();

class CorrelationIdClientFilter implements ClientRequestFilter {

@Override
public void filter(ClientRequestContext requestContext) {
String correlationId = ContextLocals.<String> get(CORRELATION_ID_HEADER_NAME).orElse(null);
requestContext.getHeaders().putSingle(CORRELATION_ID_HEADER_NAME, correlationId);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
w-client-logger/mp-rest/url=${test.url}
w-exception-mapper/mp-rest/url=${test.url}
w-fault-tolerance/mp-rest/url=${test.url}
correlation/mp-rest/url=${test.url}
io.quarkus.it.rest.client.main.ParamClient/mp-rest/url=${test.url}
io.quarkus.it.rest.client.multipart.MultipartClient/mp-rest/url=${test.url}
# global client logging scope
quarkus.rest-client.logging.scope=request-response
# Self-Signed client
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.quarkus.it.rest.client;

import static org.assertj.core.api.Assertions.assertThat;

import org.eclipse.microprofile.rest.client.inject.RestClient;
import org.junit.jupiter.api.Test;

import io.quarkus.it.rest.client.main.CorrelationIdClient;
import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.vertx.RunOnVertxContext;
import io.smallrye.common.vertx.ContextLocals;

@QuarkusTest
public class RestClientInTestMethodWithContextTest {

@RestClient
CorrelationIdClient client;

@RunOnVertxContext(runOnEventLoop = false)
@Test
public void test() {
ContextLocals.put(CorrelationIdClient.CORRELATION_ID_HEADER_NAME, "dummy");
assertThat(client.get()).isEqualTo("dummy");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,9 @@
* @return {@code true} by default.
*/
boolean duplicateContext() default true;

/**
* If {@code true}, the test method is run on the Event Loop, otherwise it will be run on Vert.x blocking thread pool
*/
boolean runOnEventLoop() default true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.smallrye.common.vertx.VertxContext;
import io.vertx.core.Context;
import io.vertx.core.Handler;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;

public class RunOnVertxContextTestMethodInvoker implements TestMethodInvoker {
Expand Down Expand Up @@ -70,17 +71,27 @@ public Object invoke(Object actualTestInstance, Method actualTestMethod, List<Ob
throw new IllegalStateException("Vert.x instance has not been created before attempting to run test method '"
+ actualTestMethod.getName() + "' of test class '" + testClassName + "'");
}
CompletableFuture<Object> cf = new CompletableFuture<>();
RunTestMethodOnContextHandler handler = new RunTestMethodOnContextHandler(actualTestInstance, actualTestMethod,
actualTestMethodArgs, uniAsserter, cf);

Context context = vertx.getOrCreateContext();
boolean shouldDuplicateContext = shouldContextBeDuplicated(
actualTestInstance != null ? actualTestInstance.getClass() : Object.class, actualTestMethod);
Class<?> testClass = actualTestInstance != null ? actualTestInstance.getClass() : Object.class;
boolean shouldDuplicateContext = shouldContextBeDuplicated(testClass, actualTestMethod);
if (shouldDuplicateContext) {
context = VertxContext.getOrCreateDuplicatedContext(context);
VertxContextSafetyToggle.setContextSafe(context, true);
}
context.runOnContext(handler);

CompletableFuture<Object> cf;
if (shouldRunOnEventLoop(testClass, actualTestMethod)) {
cf = new CompletableFuture<>();
var handler = new RunTestMethodOnVertxEventLoopContextHandler(actualTestInstance, actualTestMethod,
actualTestMethodArgs, uniAsserter, cf);
context.runOnContext(handler);
} else {
var handler = new RunTestMethodOnVertxBlockingContextHandler(actualTestInstance, actualTestMethod,
actualTestMethodArgs, uniAsserter);
cf = ((CompletableFuture) context.executeBlocking(handler).toCompletionStage());
}

try {
return cf.get();
} catch (InterruptedException e) {
Expand All @@ -90,6 +101,7 @@ public Object invoke(Object actualTestInstance, Method actualTestMethod, List<Ob
// the test itself threw an exception
throw e.getCause();
}

}

private boolean shouldContextBeDuplicated(Class<?> c, Method m) {
Expand All @@ -106,7 +118,19 @@ private boolean shouldContextBeDuplicated(Class<?> c, Method m) {
}
}

public static class RunTestMethodOnContextHandler implements Handler<Void> {
private boolean shouldRunOnEventLoop(Class<?> c, Method m) {
RunOnVertxContext runOnVertxContext = m.getAnnotation(RunOnVertxContext.class);
if (runOnVertxContext == null) {
runOnVertxContext = c.getAnnotation(RunOnVertxContext.class);
}
if (runOnVertxContext == null) {
return true;
} else {
return runOnVertxContext.runOnEventLoop();
}
}

public static class RunTestMethodOnVertxEventLoopContextHandler implements Handler<Void> {
private static final Runnable DO_NOTHING = new Runnable() {
@Override
public void run() {
Expand All @@ -119,7 +143,7 @@ public void run() {
private final DefaultUniAsserter uniAsserter;
private final CompletableFuture<Object> future;

public RunTestMethodOnContextHandler(Object testInstance, Method targetMethod, List<Object> methodArgs,
public RunTestMethodOnVertxEventLoopContextHandler(Object testInstance, Method targetMethod, List<Object> methodArgs,
DefaultUniAsserter uniAsserter, CompletableFuture<Object> future) {
this.testInstance = testInstance;
this.future = future;
Expand Down Expand Up @@ -172,4 +196,68 @@ public void accept(Throwable t) {
}
}

public static class RunTestMethodOnVertxBlockingContextHandler implements Handler<Promise<Object>> {
private static final Runnable DO_NOTHING = new Runnable() {
@Override
public void run() {
}
};

private final Object testInstance;
private final Method targetMethod;
private final List<Object> methodArgs;
private final DefaultUniAsserter uniAsserter;

public RunTestMethodOnVertxBlockingContextHandler(Object testInstance, Method targetMethod, List<Object> methodArgs,
DefaultUniAsserter uniAsserter) {
this.testInstance = testInstance;
this.targetMethod = targetMethod;
this.methodArgs = methodArgs;
this.uniAsserter = uniAsserter;
}

@Override
public void handle(Promise<Object> promise) {
ManagedContext requestContext = Arc.container().requestContext();
if (requestContext.isActive()) {
doRun(promise, DO_NOTHING);
} else {
requestContext.activate();
doRun(promise, new Runnable() {
@Override
public void run() {
requestContext.terminate();
}
});
}
}

private void doRun(Promise<Object> promise, Runnable onTerminate) {
try {
Object testMethodResult = targetMethod.invoke(testInstance, methodArgs.toArray(new Object[0]));
if (uniAsserter != null) {
uniAsserter.execution.subscribe().with(new Consumer<Object>() {
@Override
public void accept(Object o) {
onTerminate.run();
promise.complete();
}
}, new Consumer<>() {
@Override
public void accept(Throwable t) {
onTerminate.run();
promise.fail(t);
}
});
} else {
onTerminate.run();
promise.complete(testMethodResult);
}
} catch (Throwable t) {
onTerminate.run();
promise.fail(t.getCause());
}
}
}

}

0 comments on commit cc27b73

Please sign in to comment.