From 6e9ee38fd8fd0b03738387d6b958edd071e0118e Mon Sep 17 00:00:00 2001 From: Thomas Canava Date: Wed, 4 Dec 2024 22:57:07 +0100 Subject: [PATCH] feat: Add support of @RunOnVirtualThread on class for websockets next server --- .../next/deployment/WebSocketProcessor.java | 5 ++- .../RunOnVirtualThreadTest.java | 42 ++++++++++++++++++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index 6b33ff21fd719..1c5ecbb6df7ce 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -1613,12 +1613,15 @@ private static Callback findCallback(Target target, IndexView index, BeanInfo be private static ExecutionModel executionModel(MethodInfo method, TransformedAnnotationsBuildItem transformedAnnotations) { if (KotlinUtils.isKotlinSuspendMethod(method) && (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.RUN_ON_VIRTUAL_THREAD) + || transformedAnnotations.hasAnnotation(method.declaringClass(), + WebSocketDotNames.RUN_ON_VIRTUAL_THREAD) || transformedAnnotations.hasAnnotation(method, WebSocketDotNames.BLOCKING) || transformedAnnotations.hasAnnotation(method, WebSocketDotNames.NON_BLOCKING))) { throw new WebSocketException("Kotlin `suspend` functions in WebSockets Next endpoints may not be " + "annotated @Blocking, @NonBlocking or @RunOnVirtualThread: " + method); } - if (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.RUN_ON_VIRTUAL_THREAD)) { + if (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.RUN_ON_VIRTUAL_THREAD) + || transformedAnnotations.hasAnnotation(method.declaringClass(), WebSocketDotNames.RUN_ON_VIRTUAL_THREAD)) { return ExecutionModel.VIRTUAL_THREAD; } else if (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.BLOCKING)) { return ExecutionModel.WORKER_THREAD; diff --git a/extensions/websockets-next/deployment/src/test/java21/io/quarkus/websockets/next/test/virtualthreads/RunOnVirtualThreadTest.java b/extensions/websockets-next/deployment/src/test/java21/io/quarkus/websockets/next/test/virtualthreads/RunOnVirtualThreadTest.java index 0c767e18834cd..80676788b7dbf 100644 --- a/extensions/websockets-next/deployment/src/test/java21/io/quarkus/websockets/next/test/virtualthreads/RunOnVirtualThreadTest.java +++ b/extensions/websockets-next/deployment/src/test/java21/io/quarkus/websockets/next/test/virtualthreads/RunOnVirtualThreadTest.java @@ -13,6 +13,7 @@ import io.quarkus.test.common.http.TestHTTPResource; import io.quarkus.test.vertx.VirtualThreadsAssertions; import io.quarkus.websockets.next.OnError; +import io.quarkus.websockets.next.OnOpen; import io.quarkus.websockets.next.OnTextMessage; import io.quarkus.websockets.next.WebSocket; import io.quarkus.websockets.next.test.utils.WSClient; @@ -38,6 +39,9 @@ public class RunOnVirtualThreadTest { @TestHTTPResource("end") URI endUri; + @TestHTTPResource("virt-on-class") + URI onClassUri; + @Test void testVirtualThreads() { try (WSClient client = new WSClient(vertx).connect(endUri)) { @@ -52,6 +56,22 @@ void testVirtualThreads() { } } + @Test + void testVirtualThreadsOnClass() { + try (WSClient client = new WSClient(vertx).connect(onClassUri)) { + client.sendAndAwait("foo"); + client.sendAndAwait("bar"); + client.waitForMessages(3); + String open = client.getMessages().get(0).toString(); + String message1 = client.getMessages().get(1).toString(); + String message2 = client.getMessages().get(2).toString(); + assertNotEquals(open, message1, message2); + assertTrue(open.startsWith("wsnext-virtual-thread-")); + assertTrue(message1.startsWith("wsnext-virtual-thread-")); + assertTrue(message2.startsWith("wsnext-virtual-thread-")); + } + } + @WebSocket(path = "/end") public static class Endpoint { @@ -71,7 +91,27 @@ String error(Throwable t) { } } - + + @RunOnVirtualThread + @WebSocket(path = "/virt-on-class") + public static class EndpointVirtOnClass { + + @Inject + RequestScopedBean bean; + + @OnOpen + String open() { + VirtualThreadsAssertions.assertEverything(); + return Thread.currentThread().getName(); + } + + @OnTextMessage + String text(String ignored) { + VirtualThreadsAssertions.assertEverything(); + return Thread.currentThread().getName(); + } + } + @RequestScoped public static class RequestScopedBean {