diff --git a/bom/application/pom.xml b/bom/application/pom.xml index 96378cfc56a55..b9877fd00f56e 100644 --- a/bom/application/pom.xml +++ b/bom/application/pom.xml @@ -3094,6 +3094,11 @@ quarkus-project-core-extension-codestarts ${project.version} + + io.quarkus.junit5 + junit5-virtual-threads + ${project.version} + diff --git a/docs/src/main/asciidoc/virtual-threads.adoc b/docs/src/main/asciidoc/virtual-threads.adoc index 946c009dbc637..88718b399face 100644 --- a/docs/src/main/asciidoc/virtual-threads.adoc +++ b/docs/src/main/asciidoc/virtual-threads.adoc @@ -462,6 +462,73 @@ quarkus.virtual-threads.name-prefix= ---- +== Testing virtual thread applications + +As mentioned above, virtual threads have a few limitations that can drastically affect your application performance and memory usage. +The _junit5-virtual-threads_ extension provides a way to detect pinned carrier threads while running your tests. +Thus, you can eliminate one of the most prominent limitations or be aware of the problem. + +To enable this detection: + +* 1) Add the `junit5-virtual-threads` dependency to your project: +[source, xml] +---- + + io.quarkus.junit5 + junit5-virtual-threads + test + +---- + +* 2) In your test case, add the `io.quarkus.test.junit5.virtual.VirtualThreadUnit` and `io.quarkus.test.junit.virtual.ShouldNotPin` annotations: +[source, java] +---- +@QuarkusTest +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +@VirtualThreadUnit // Use the extension +@ShouldNotPin // Detect pinned carrier thread +class TodoResourceTest { + // ... +} +---- + +When you run your test (remember to use Java 21+), Quarkus detects pinned carrier threads. +When it happens, the test fails. + +The `@ShouldNotPin` can also be used on methods directly. + +The _junit5-virtual-threads_ also provides a `@ShouldPin` annotation for cases where pinning is unavoidable. +The following snippet demonstrates the `@ShouldPin` annotation usage and the possibility to inject a `ThreadPinnedEvents` instance in your test to verify when the carrier thread was pinned manually. + +[source, java] +---- +@VirtualThreadUnit // Use the extension +public class LoomUnitExampleTest { + + CodeUnderTest codeUnderTest = new CodeUnderTest(); + + @Test + @ShouldNotPin + public void testThatShouldNotPin() { + // ... + } + + @Test + @ShouldPin(atMost = 1) + public void testThatShouldPinAtMostOnce() { + codeUnderTest.pin(); + } + + @Test + public void testThatShouldNotPin(ThreadPinnedEvents events) { // Inject an object to check the pin events + Assertions.assertTrue(events.getEvents().isEmpty()); + codeUnderTest.pin(); + await().until(() -> events.getEvents().size() > 0); + Assertions.assertEquals(events.getEvents().size(), 1); + } + +} +---- == Additional references diff --git a/independent-projects/junit5-virtual-threads/pom.xml b/independent-projects/junit5-virtual-threads/pom.xml new file mode 100644 index 0000000000000..f3cac0d65d5bc --- /dev/null +++ b/independent-projects/junit5-virtual-threads/pom.xml @@ -0,0 +1,298 @@ + + + 4.0.0 + + + io.quarkus + quarkus-parent + 999-SNAPSHOT + ../parent/pom.xml + + + io.quarkus.junit5 + junit5-virtual-threads + + Quarkus - JUnit 5 Extension - Virtual Threads + Module that allows detecting virtual threads pinning + https://github.com/quarkusio/quarkus + + + + Apache License, Version 2.0 + repo + https://www.apache.org/licenses/LICENSE-2.0.html + + + + + https://github.com/quarkusio/quarkus + scm:git:git@github.com:quarkusio/quarkus.git + scm:git:git@github.com:quarkusio/quarkus.git + HEAD + + + + UTF-8 + 11 + 11 + 11 + + 3.11.0 + 3.2.1 + 3.1.2 + 3.1.3 + 2.23.0 + 1.9.0 + + 5.9.3 + + + + + + org.junit.jupiter + junit-jupiter + compile + ${junit.jupiter.version} + + + + + + + + maven-compiler-plugin + ${compiler.plugin.version} + + + io.smallrye + jandex-maven-plugin + ${jandex.version} + + + maven-javadoc-plugin + + true + none + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + io.quarkus + quarkus-enforcer-rules + ${project.version} + + + + + enforce + + + + + classpath:enforcer-rules/quarkus-require-java-version.xml + + + classpath:enforcer-rules/quarkus-require-maven-version.xml + + + classpath:enforcer-rules/quarkus-banned-dependencies.xml + + + + + com.google.code.findbugs:jsr305 + + com.google.guava:listenablefuture + + + + + + enforce + + + + + + maven-surefire-plugin + ${surefire.plugin.version} + + + + + -Djava.io.tmpdir="${project.build.directory}" + MAVEN_OPTS + + + + + net.revelc.code.formatter + formatter-maven-plugin + ${formatter-maven-plugin.version} + + + quarkus-ide-config + io.quarkus + ${project.version} + + + + + .cache/formatter-maven-plugin-${formatter-maven-plugin.version} + eclipse-format.xml + LF + ${format.skip} + + + + net.revelc.code + impsort-maven-plugin + ${impsort-maven-plugin.version} + + + .cache/impsort-maven-plugin-${impsort-maven-plugin.version} + java.,javax.,jakarta.,org.,com. + * + ${format.skip} + true + + + + + + + + + sonatype-nexus-snapshots + https://s01.oss.sonatype.org/content/repositories/snapshots + + + sonatype-nexus-release + https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ + + + + + + quick-build + + + quickly + + + + true + true + true + + + clean install + + + + + + quick-build-ci + + + quickly-ci + + + + true + true + true + true + + + + format + + true + + !no-format + + + + + + net.revelc.code.formatter + formatter-maven-plugin + + + process-sources + + format + + + + + + net.revelc.code + impsort-maven-plugin + + + sort-imports + + sort + + + + + true + + + + + + + validate + + true + + no-format + + + + + + net.revelc.code.formatter + formatter-maven-plugin + + + process-sources + + validate + + + + + + net.revelc.code + impsort-maven-plugin + + true + + + + check-imports + + check + + + + + + + + + + \ No newline at end of file diff --git a/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/ShouldNotPin.java b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/ShouldNotPin.java new file mode 100644 index 0000000000000..bf4228f87c3b8 --- /dev/null +++ b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/ShouldNotPin.java @@ -0,0 +1,16 @@ +package io.quarkus.test.junit5.virtual; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Marker indicating that the test method or class should not pin the carrier thread. + * If, during the execution of the test, a virtual thread pins the carrier thread, the test fails. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ ElementType.METHOD, ElementType.TYPE }) +public @interface ShouldNotPin { + +} diff --git a/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/ShouldPin.java b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/ShouldPin.java new file mode 100644 index 0000000000000..d98f5166a8c3a --- /dev/null +++ b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/ShouldPin.java @@ -0,0 +1,18 @@ +package io.quarkus.test.junit5.virtual; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Indicates that the method or class can pin. At most can be set to indicate the maximum number of events. + * If, during the execution of the test, a virtual thread does not pin the carrier thread, or pins it more than + * the given {@code atMost} value, the test fails. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ ElementType.METHOD, ElementType.TYPE }) +public @interface ShouldPin { + int atMost() default Integer.MAX_VALUE; + +} diff --git a/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/ThreadPinnedEvents.java b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/ThreadPinnedEvents.java new file mode 100644 index 0000000000000..296566af8d31f --- /dev/null +++ b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/ThreadPinnedEvents.java @@ -0,0 +1,17 @@ +package io.quarkus.test.junit5.virtual; + +import java.util.List; + +import jdk.jfr.consumer.RecordedEvent; + +/** + * Object that can be injected in a test method. + * It gives controlled on the captured events, and so let you do manual checks. + *

+ * The returned list is a copy of the list of captured events. + */ +public interface ThreadPinnedEvents { + + List getEvents(); + +} diff --git a/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/VirtualThreadUnit.java b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/VirtualThreadUnit.java new file mode 100644 index 0000000000000..b37efa15ef7fc --- /dev/null +++ b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/VirtualThreadUnit.java @@ -0,0 +1,19 @@ +package io.quarkus.test.junit5.virtual; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.junit.jupiter.api.extension.ExtendWith; + +import io.quarkus.test.junit5.virtual.internal.VirtualThreadExtension; + +/** + * Extends the test case to detect pinned carrier thread. + */ +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.RUNTIME) +@ExtendWith(VirtualThreadExtension.class) +public @interface VirtualThreadUnit { +} diff --git a/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/Collector.java b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/Collector.java new file mode 100644 index 0000000000000..74bf7bcf27fab --- /dev/null +++ b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/Collector.java @@ -0,0 +1,192 @@ +package io.quarkus.test.junit5.virtual.internal; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.logging.Level; +import java.util.logging.Logger; + +import jdk.jfr.consumer.RecordedEvent; + +public class Collector implements Consumer { + public static final String CARRIER_PINNED_EVENT_NAME = "jdk.VirtualThreadPinned"; + private static final Logger LOGGER = Logger.getLogger(Collector.class.getName()); + + private final List> observers = new CopyOnWriteArrayList<>(); + + private final List events = new CopyOnWriteArrayList<>(); + + private final EventStreamFacade facade; + + volatile State state = State.INIT; + + public Collector() { + if (EventStreamFacade.available) { + facade = new EventStreamFacade(); + facade.enable(CARRIER_PINNED_EVENT_NAME).withStackTrace(); + facade.enable(InternalEvents.SHUTDOWN_EVENT_NAME).withoutStackTrace(); + facade.enable(InternalEvents.CAPTURING_STARTED_EVENT_NAME).withoutStackTrace(); + facade.enable(InternalEvents.CAPTURING_STOPPED_EVENT_NAME).withoutStackTrace(); + facade.enable(InternalEvents.INITIALIZATION_EVENT_NAME).withoutStackTrace(); + facade.setOrdered(true); + facade.setMaxSize(100); + facade.onEvent(this); + } else { + facade = null; + } + } + + public void init() { + if (facade != null) { + long begin = System.nanoTime(); + CountDownLatch latch = new CountDownLatch(1); + observers.add(re -> { + if (re.getEventType().getName().equals(InternalEvents.INITIALIZATION_EVENT_NAME)) { + latch.countDown(); + return true; + } + return false; + }); + facade.startAsync(); + new InternalEvents.InitializationEvent().commit(); + try { + if (latch.await(10, TimeUnit.SECONDS)) { + long end = System.nanoTime(); + state = State.STARTED; + LOGGER.log(Level.FINE, "Event collection started in {0}s", (end - begin) / 1000000); + } else { + throw new IllegalStateException( + "Unable to start JFR collection, RecordingStartedEvent event not received after 10s"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + } + + public void start() { + if (facade != null) { + CountDownLatch latch = new CountDownLatch(1); + String id = UUID.randomUUID().toString(); + long begin = System.nanoTime(); + observers.add(re -> { + if (re.getEventType().getName().equals(InternalEvents.CAPTURING_STARTED_EVENT_NAME)) { + if (id.equals(re.getString("id"))) { + events.clear(); + state = State.COLLECTING; + latch.countDown(); + return true; + } + } + return false; + }); + + new InternalEvents.CapturingStartedEvent(id).commit(); + + try { + if (!latch.await(10, TimeUnit.SECONDS)) { + throw new IllegalStateException("Unable to start JFR collection, START_EVENT event not received after 10s"); + } + long end = System.nanoTime(); + LOGGER.log(Level.FINE, "Event capturing started in {0}s", (end - begin) / 1000000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + } + + public List stop() { + if (facade != null) { + CountDownLatch latch = new CountDownLatch(1); + String id = UUID.randomUUID().toString(); + var begin = System.nanoTime(); + observers.add(re -> { + if (re.getEventType().getName().equals(InternalEvents.CAPTURING_STOPPED_EVENT_NAME)) { + state = State.STARTED; + latch.countDown(); + return true; + } + return false; + }); + + new InternalEvents.CapturingStoppedEvent(id).commit(); + + try { + if (!latch.await(10, TimeUnit.SECONDS)) { + throw new IllegalStateException("Unable to start JFR collection, STOP_EVENT event not received after 10s"); + } + var end = System.nanoTime(); + LOGGER.log(Level.FINE, "Event collection stopped in {0}s", (end - begin) / 1000000); + return new ArrayList<>(events); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + return Collections.emptyList(); + } + + public void shutdown() { + if (facade != null) { + CountDownLatch latch = new CountDownLatch(1); + var begin = System.nanoTime(); + observers.add(re -> { + if (re.getEventType().getName().equals(InternalEvents.SHUTDOWN_EVENT_NAME)) { + latch.countDown(); + return true; + } + return false; + }); + InternalEvents.ShutdownEvent event = new InternalEvents.ShutdownEvent(); + event.commit(); + try { + if (latch.await(10, TimeUnit.SECONDS)) { + state = State.INIT; + var end = System.nanoTime(); + LOGGER.log(Level.FINE, "Event collector shutdown in {0}s", (end - begin) / 1000000); + facade.stop(); + } else { + throw new IllegalStateException( + "Unable to stop JFR collection, RecordingStoppedEvent event not received at 10s"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + } + + @Override + public void accept(RecordedEvent re) { + if (state == State.COLLECTING) { + events.add(re); + } + List> toBeRemoved = new ArrayList<>(); + observers.forEach(c -> { + if (c.apply(re)) { + toBeRemoved.add(c); + } + }); + observers.removeAll(toBeRemoved); + } + + public List getEvents() { + return new ArrayList<>(events); + } + + enum State { + INIT, + STARTED, + COLLECTING + } + +} diff --git a/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/EventStreamFacade.java b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/EventStreamFacade.java new file mode 100644 index 0000000000000..10061444ecf18 --- /dev/null +++ b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/EventStreamFacade.java @@ -0,0 +1,118 @@ +package io.quarkus.test.junit5.virtual.internal; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.util.function.Consumer; + +import org.junit.jupiter.api.function.ThrowingSupplier; + +import jdk.jfr.EventSettings; +import jdk.jfr.consumer.RecordedEvent; + +/** + * The RecordingStream is only Java 14+, and the code must be Java 11. + * This class provides the used API, but under the hood use MethodHandle. + */ +public class EventStreamFacade { + public static final String CARRIER_PINNED_EVENT_NAME = "jdk.VirtualThreadPinned"; + + /** + * Whether the RecordingStream API is available. + */ + public static final boolean available; + + private static final MethodHandle constructor; + + private static final MethodHandle enableMethod; + + private static final MethodHandle stopMethod; + + private static final MethodHandle startAsyncMethod; + + private static final MethodHandle setMaxSizeMethod; + + private static final MethodHandle setOrderedMethod; + + private static final MethodHandle onEventMethod; + + static { + boolean en; + MethodHandle tempConstructor = null; + MethodHandle tempEnable = null; + MethodHandle tempStartAsync = null; + MethodHandle tempStop = null; + MethodHandle tempSetMaxSize = null; + MethodHandle tempSetOrdered = null; + MethodHandle tempOnEvent = null; + try { + MethodHandles.Lookup lookup = MethodHandles.publicLookup(); + var clazz = EventStreamFacade.class.getClassLoader().loadClass("jdk.jfr.consumer.RecordingStream"); + tempConstructor = lookup.findConstructor(clazz, MethodType.methodType(void.class)); + tempEnable = lookup.findVirtual(clazz, "enable", MethodType.methodType(EventSettings.class, String.class)); + tempSetMaxSize = lookup.findVirtual(clazz, "setMaxSize", MethodType.methodType(void.class, long.class)); + tempSetOrdered = lookup.findVirtual(clazz, "setOrdered", MethodType.methodType(void.class, boolean.class)); + tempOnEvent = lookup.findVirtual(clazz, "onEvent", MethodType.methodType(void.class, Consumer.class)); + tempStartAsync = lookup.findVirtual(clazz, "startAsync", MethodType.methodType(void.class)); + tempStop = lookup.findVirtual(clazz, "stop", MethodType.methodType(boolean.class)); + en = true; + } catch (Throwable e) { + e.printStackTrace(); + en = false; + } + available = en; + constructor = tempConstructor; + enableMethod = tempEnable; + startAsyncMethod = tempStartAsync; + stopMethod = tempStop; + setMaxSizeMethod = tempSetMaxSize; + setOrderedMethod = tempSetOrdered; + onEventMethod = tempOnEvent; + } + + private final Object stream; + + public EventStreamFacade() { + try { + this.stream = constructor.invoke(); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + private T invoke(ThrowingSupplier invocation) { + if (!available) { + throw new UnsupportedOperationException("Stream recording not configured correctly, make sure you use Java 14+"); + } + try { + return invocation.get(); + } catch (Throwable e) { + throw new RuntimeException("Unable to invoke event stream method", e); + } + } + + public EventSettings enable(String event) { + return invoke(() -> (EventSettings) enableMethod.invoke(stream, event)); + } + + public void startAsync() { + invoke(() -> startAsyncMethod.invoke(stream)); + } + + public void setMaxSize(int max) { + invoke(() -> setMaxSizeMethod.invoke(stream, max)); + } + + public void setOrdered(boolean ordered) { + invoke(() -> setOrderedMethod.invoke(stream, ordered)); + } + + public void onEvent(Consumer consumer) { + invoke(() -> onEventMethod.invoke(stream, consumer)); + } + + public boolean stop() { + return invoke(() -> (boolean) stopMethod.invoke(stream)); + } + +} diff --git a/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/InternalEvents.java b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/InternalEvents.java new file mode 100644 index 0000000000000..1128e67ae5ed1 --- /dev/null +++ b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/InternalEvents.java @@ -0,0 +1,61 @@ +package io.quarkus.test.junit5.virtual.internal; + +import jdk.jfr.Category; +import jdk.jfr.Event; +import jdk.jfr.Label; +import jdk.jfr.Name; +import jdk.jfr.StackTrace; + +/** + * Internal events used during the capture. + */ +public interface InternalEvents { + + String INITIALIZATION_EVENT_NAME = "io.quarkus.test.junit5.virtual.internal.InternalEvents.InitializationEvent"; + String SHUTDOWN_EVENT_NAME = "io.quarkus.test.junit5.virtual.internal.InternalEvents.ShutdownEvent"; + + String CAPTURING_STARTED_EVENT_NAME = "io.quarkus.test.junit5.virtual.internal.InternalEvents.CapturingStartedEvent"; + String CAPTURING_STOPPED_EVENT_NAME = "io.quarkus.test.junit5.virtual.internal.InternalEvents.CapturingStoppedEvent"; + + @Name(INITIALIZATION_EVENT_NAME) + @Category("virtual-thread-unit") + @StackTrace(value = false) + class InitializationEvent extends Event { + // Marker event + } + + @Name(SHUTDOWN_EVENT_NAME) + @Category("virtual-thread-unit") + @StackTrace(value = false) + class ShutdownEvent extends Event { + // Marker event + } + + @Name(CAPTURING_STARTED_EVENT_NAME) + @Category("virtual-thread-unit") + @StackTrace(value = false) + class CapturingStartedEvent extends Event { + + @Name("id") + @Label("id") + public final String id; + + public CapturingStartedEvent(String id) { + this.id = id; + } + } + + @Name(CAPTURING_STOPPED_EVENT_NAME) + @Category("virtual-thread-unit") + @StackTrace(value = false) + class CapturingStoppedEvent extends Event { + + @Name("id") + @Label("id") + public final String id; + + public CapturingStoppedEvent(String id) { + this.id = id; + } + } +} diff --git a/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/VirtualThreadExtension.java b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/VirtualThreadExtension.java new file mode 100644 index 0000000000000..593c0951b6632 --- /dev/null +++ b/independent-projects/junit5-virtual-threads/src/main/java/io/quarkus/test/junit5/virtual/internal/VirtualThreadExtension.java @@ -0,0 +1,174 @@ +package io.quarkus.test.junit5.virtual.internal; + +import static io.quarkus.test.junit5.virtual.internal.EventStreamFacade.CARRIER_PINNED_EVENT_NAME; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.extension.AfterAllCallback; +import org.junit.jupiter.api.extension.AfterEachCallback; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.BeforeEachCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ParameterContext; +import org.junit.jupiter.api.extension.ParameterResolutionException; +import org.junit.jupiter.api.extension.ParameterResolver; +import org.junit.jupiter.api.extension.TestInstantiationException; + +import io.quarkus.test.junit5.virtual.ShouldNotPin; +import io.quarkus.test.junit5.virtual.ShouldPin; +import io.quarkus.test.junit5.virtual.ThreadPinnedEvents; +import jdk.jfr.consumer.RecordedEvent; +import jdk.jfr.consumer.RecordedFrame; + +public class VirtualThreadExtension + implements BeforeAllCallback, AfterAllCallback, BeforeEachCallback, AfterEachCallback, ParameterResolver { + + public static final String _COLLECTOR_KEY = "collector"; + private ExtensionContext.Namespace namespace; + + @Override + public void beforeAll(ExtensionContext extensionContext) { + Collector collector = new Collector(); + namespace = ExtensionContext.Namespace.create("loom-unit"); + var store = extensionContext.getStore(namespace); + store.put(_COLLECTOR_KEY, collector); + collector.init(); + } + + @Override + public void afterAll(ExtensionContext extensionContext) { + var store = extensionContext.getStore(namespace); + store.get(_COLLECTOR_KEY, Collector.class).shutdown(); + } + + @Override + public void beforeEach(ExtensionContext extensionContext) { + var clazz = extensionContext.getRequiredTestClass(); + var method = extensionContext.getRequiredTestMethod(); + if (requiresRecording(clazz, method)) { + var store = extensionContext.getStore(namespace); + store.get(_COLLECTOR_KEY, Collector.class).start(); + + if (getShouldPin(extensionContext.getRequiredTestClass(), extensionContext.getRequiredTestMethod()) != null + && getShouldNotPin(extensionContext.getRequiredTestClass(), + extensionContext.getRequiredTestMethod()) != null) { + throw new TestInstantiationException("Cannot execute test " + extensionContext.getDisplayName() + + ": @ShouldPin and @ShouldNotPin are used on the method or class"); + } + } + } + + private boolean requiresRecording(Class clazz, Method method) { + if (clazz.isAnnotationPresent(ShouldNotPin.class) || clazz.isAnnotationPresent(ShouldPin.class) + || method.isAnnotationPresent(ShouldNotPin.class) || method.isAnnotationPresent(ShouldPin.class)) { + return true; + } + return Arrays.asList(method.getParameterTypes()).contains(ThreadPinnedEvents.class); + } + + private ShouldPin getShouldPin(Class clazz, Method method) { + if (method.isAnnotationPresent(ShouldPin.class)) { + return method.getAnnotation(ShouldPin.class); + } + + if (method.isAnnotationPresent(ShouldNotPin.class)) { + // If the method overrides the class annotation. + return null; + } + + if (clazz.isAnnotationPresent(ShouldPin.class)) { + return clazz.getAnnotation(ShouldPin.class); + } + + return null; + } + + private ShouldNotPin getShouldNotPin(Class clazz, Method method) { + if (method.isAnnotationPresent(ShouldNotPin.class)) { + return method.getAnnotation(ShouldNotPin.class); + } + + if (method.isAnnotationPresent(ShouldPin.class)) { + // If the method overrides the class annotation. + return null; + } + + if (clazz.isAnnotationPresent(ShouldNotPin.class)) { + return clazz.getAnnotation(ShouldNotPin.class); + } + + return null; + } + + @Override + public void afterEach(ExtensionContext extensionContext) { + Method method = extensionContext.getRequiredTestMethod(); + Class clazz = extensionContext.getRequiredTestClass(); + if (!requiresRecording(clazz, method)) { + return; + } + var store = extensionContext.getStore(namespace); + List captured = store.get(_COLLECTOR_KEY, Collector.class).stop(); + List pinEvents = captured.stream() + .filter(re -> re.getEventType().getName().equals(CARRIER_PINNED_EVENT_NAME)).collect(Collectors.toList()); + + ShouldPin pin = getShouldPin(clazz, method); + ShouldNotPin notpin = getShouldNotPin(clazz, method); + + if (pin != null) { + if (pinEvents.isEmpty()) { + throw new AssertionError( + "The test " + extensionContext.getDisplayName() + " was expected to pin the carrier thread, it didn't"); + } + if (pin.atMost() != Integer.MAX_VALUE && pinEvents.size() > pin.atMost()) { + throw new AssertionError("The test " + extensionContext.getDisplayName() + + " was expected to pin the carrier thread at most " + pin.atMost() + + ", but we collected " + pinEvents.size() + " events\n" + dump(pinEvents)); + } + } + + if (notpin != null) { + if (!pinEvents.isEmpty()) { + throw new AssertionError( + "The test " + extensionContext.getDisplayName() + " was expected to NOT pin the carrier thread" + + ", but we collected " + pinEvents.size() + " event(s)\n" + dump(pinEvents)); + } + } + + } + + private static final String STACK_TRACE_TEMPLATE = "\t%s.%s(%s.java:%d)\n"; + + private String dump(List pinEvents) { + StringBuilder builder = new StringBuilder(); + for (RecordedEvent pinEvent : pinEvents) { + builder.append("* Pinning event captured: \n"); + for (RecordedFrame recordedFrame : pinEvent.getStackTrace().getFrames()) { + String output = String.format(STACK_TRACE_TEMPLATE, recordedFrame.getMethod().getType().getName(), + recordedFrame.getMethod().getName(), recordedFrame.getMethod().getType().getName(), + recordedFrame.getLineNumber()); + builder.append(output); + } + } + return builder.toString(); + } + + @Override + public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + return parameterContext.getParameter().getType().equals(ThreadPinnedEvents.class); + } + + @Override + public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) + throws ParameterResolutionException { + return (ThreadPinnedEvents) () -> { + var store = extensionContext.getStore(namespace); + return store.get(_COLLECTOR_KEY, Collector.class).getEvents(); + }; + } + +} diff --git a/pom.xml b/pom.xml index 67663d7b763ab..70d8be7009311 100644 --- a/pom.xml +++ b/pom.xml @@ -95,6 +95,7 @@ independent-projects/enforcer-rules independent-projects/resteasy-reactive independent-projects/extension-maven-plugin + independent-projects/junit5-virtual-threads bom/application