From 2517290b0c15c1979c2c831ccef5a1c308dd2c97 Mon Sep 17 00:00:00 2001 From: Ladislav Thon Date: Mon, 8 Jul 2024 17:13:30 +0200 Subject: [PATCH] WebSockets Next: add support for Kotlin suspend functions Kotlin `suspend` functions are treated like Java methods that return `Uni`. That is, they are considered non-blocking. The implementation uses CDI method invokers (to avoid custom bytecode generation), which actually convert the `suspend` function result into a `Uni` under the hood. With this commit, only single-shot `suspend` functions are supported; `suspend` functions returning `Flow` are not supported yet. --- bom/application/pom.xml | 5 + docs/src/main/asciidoc/kotlin.adoc | 3 + .../asciidoc/websockets-next-reference.adoc | 28 ++- extensions/websockets-next/deployment/pom.xml | 72 ++++++ .../websockets/next/deployment/Callback.java | 20 +- .../KotlinContinuationCallbackArgument.java | 17 ++ .../next/deployment/WebSocketProcessor.java | 225 +++++++++++++----- .../websockets/next/test/utils/WSClient.java | 7 + .../websockets/next/test/kotlin/BinaryEcho.kt | 14 ++ .../next/test/kotlin/BinaryEchoSuspend.kt | 15 ++ .../websockets/next/test/kotlin/Echo.kt | 12 + .../next/test/kotlin/EchoSuspend.kt | 14 ++ .../test/kotlin/KotlinWebSocketClientTest.kt | 93 ++++++++ .../KotlinWebSocketSessionContextTest.kt | 85 +++++++ .../KotlinWebSocketSuspendingClientTest.kt | 99 ++++++++ .../next/test/kotlin/KotlinWebSocketTest.kt | 75 ++++++ .../websockets/next/test/kotlin/Message.kt | 5 + extensions/websockets-next/kotlin/pom.xml | 115 +++++++++ .../kotlin/ApplicationCoroutineScope.kt | 18 ++ .../next/runtime/kotlin/CoroutineInvoker.kt | 37 +++ .../next/runtime/kotlin/VertxDispatcher.kt | 24 ++ extensions/websockets-next/pom.xml | 1 + extensions/websockets-next/runtime/pom.xml | 4 + .../quarkus/arc/processor/KotlinDotNames.java | 7 +- .../io/quarkus/arc/processor/KotlinUtils.java | 25 ++ 25 files changed, 941 insertions(+), 79 deletions(-) create mode 100644 extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/KotlinContinuationCallbackArgument.java create mode 100644 extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEcho.kt create mode 100644 extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEchoSuspend.kt create mode 100644 extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Echo.kt create mode 100644 extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/EchoSuspend.kt create mode 100644 extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketClientTest.kt create mode 100644 extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSessionContextTest.kt create mode 100644 extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSuspendingClientTest.kt create mode 100644 extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketTest.kt create mode 100644 extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Message.kt create mode 100644 extensions/websockets-next/kotlin/pom.xml create mode 100644 extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/ApplicationCoroutineScope.kt create mode 100644 extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/CoroutineInvoker.kt create mode 100644 extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/VertxDispatcher.kt diff --git a/bom/application/pom.xml b/bom/application/pom.xml index d5befe87004f0..cd8cc958ec44b 100644 --- a/bom/application/pom.xml +++ b/bom/application/pom.xml @@ -2166,6 +2166,11 @@ quarkus-websockets-next-deployment ${project.version} + + io.quarkus + quarkus-websockets-next-kotlin + ${project.version} + io.quarkus quarkus-undertow-spi diff --git a/docs/src/main/asciidoc/kotlin.adoc b/docs/src/main/asciidoc/kotlin.adoc index 84e2e61be86fd..2ff3227ba340a 100644 --- a/docs/src/main/asciidoc/kotlin.adoc +++ b/docs/src/main/asciidoc/kotlin.adoc @@ -504,6 +504,9 @@ The following extensions provide support for Kotlin Coroutines by allowing the u |`quarkus-vertx` |Support is provided for `@ConsumeEvent` methods +|`quarkus-websockets-next` +|Support is provided for server-side and client-side endpoint methods + |=== === Kotlin coroutines and Mutiny diff --git a/docs/src/main/asciidoc/websockets-next-reference.adoc b/docs/src/main/asciidoc/websockets-next-reference.adoc index ff4bf6b790da1..3ccf664f55127 100644 --- a/docs/src/main/asciidoc/websockets-next-reference.adoc +++ b/docs/src/main/asciidoc/websockets-next-reference.adoc @@ -180,7 +180,7 @@ The session context remains active until the `@OnClose` method completes executi In cases where a WebSocket endpoint does not declare an `@OnOpen` method, the session context is still created. It remains active until the connection terminates, regardless of the presence of an `@OnClose` method. -Methods annotated with `@OnTextMessage,` `@OnBinaryMessage,` `@OnOpen`, and `@OnClose` also have the request scoped activated for the duration of the method execution (until it produced its result). +Methods annotated with `@OnTextMessage,` `@OnBinaryMessage,` `@OnOpen`, and `@OnClose` also have the request scope activated for the duration of the method execution (until it produced its result). [[callback-methods]] === Callback methods @@ -224,6 +224,7 @@ Here are the rules governing execution: * When `@RunOnVirtualThread` is employed, each invocation spawns a new virtual thread. * Methods returning `CompletionStage`, `Uni` and `Multi` are considered non-blocking. * Methods returning `void` or plain objects are considered blocking. +* Kotlin `suspend` functions are considered non-blocking. ==== Method parameters @@ -248,10 +249,12 @@ The method must subscribe to the `Multi` to receive these items (or return a Mul Methods annotated with `@OnTextMessage` or `@OnBinaryMessage` can return various types to handle WebSocket communication efficiently: * `void`: Indicates a blocking method where no explicit response is sent back to the client. -* `Uni`: Denotes a non-blocking method where the completion of the returned Uni signifies the end of processing. No explicit response is sent back to the client. +* `Uni`: Denotes a non-blocking method where the completion of the returned `Uni` signifies the end of processing. No explicit response is sent back to the client. * An object of type `X` represents a blocking method in which the returned object is serialized and sent back to the client as a response. * `Uni`: Specifies a non-blocking method where the item emitted by the non-null `Uni` is sent to the client as a response. * `Multi`: Indicates a non-blocking method where the items emitted by the non-null `Multi` are sequentially sent to the client until completion or cancellation. +* Kotlin `suspend` function returning `Unit`: Denotes a non-blocking method where no explicit response is sent back to the client. +* Kotlin `suspend` function returning `X`: Specifies a non-blocking method where the returned item is sent to the client as a response. Here are some examples of these methods: @@ -381,6 +384,8 @@ The supported return types for `@OnOpen` methods are: * An object of type `X`: Represents a blocking method where the returned object is serialized and sent back to the client. * `Uni`: Specifies a non-blocking method where the item emitted by the non-null `Uni` is sent to the client. * `Multi`: Indicates a non-blocking method where the items emitted by the non-null `Multi` are sequentially sent to the client until completion or cancellation. +* Kotlin `suspend` function returning `Unit`: Denotes a non-blocking method where no explicit message is sent back to the client. +* Kotlin `suspend` function returning `X`: Specifies a non-blocking method where the returned item is sent to the client. Items sent to the client are <> except for the `String`, `io.vertx.core.json.JsonObject`, `io.vertx.core.json.JsonArray`, `io.vertx.core.buffer.Buffer`, and `byte[]` types. In the case of `Multi`, Quarkus subscribes to the returned `Multi` and writes the items to the `WebSocket` as they are emitted. @@ -391,6 +396,7 @@ For `@OnClose` methods, the supported return types include: * `void`: The method is considered blocking. * `Uni`: The method is considered non-blocking. +* Kotlin `suspend` function returning `Unit`: The method is considered non-blocking. NOTE: `@OnClose` methods declared on a server endpoint cannot send items to the connected client by returning objects. They can only send messages to the other clients by using the `WebSocketConnection` object. @@ -424,7 +430,7 @@ Alternatively, an error message can be logged or no operation performed. The WebSocket Next extension supports automatic serialization and deserialization of messages. -Objects of type `String`, `JsonObject`, `JsonArray`, `Buffer`, and `byte[]` are sent as-is and by-pass the serialization and deserialization. +Objects of type `String`, `JsonObject`, `JsonArray`, `Buffer`, and `byte[]` are sent as-is and bypass the serialization and deserialization. When no codec is provided, the serialization and deserialization convert the message from/to JSON automatically. When you need to customize the serialization and deserialization, you can provide a custom codec. @@ -485,7 +491,7 @@ Item find(Item item) { //.... } ---- -1. Specify the codec to use for both the deserialization of the incoming message +1. Specify the codec to use for the deserialization of the incoming message 2. Specify the codec to use for the serialization of the outgoing message === Ping/pong messages @@ -509,7 +515,7 @@ quarkus.websockets-next.server.auto-ping-interval=2 <1> The `@OnPongMessage` annotation is used to define a callback that consumes pong messages sent from the client/server. An endpoint must declare at most one method annotated with `@OnPongMessage`. -The callback method must return either `void` or `Uni`, and it must accept a single parameter of type `Buffer`. +The callback method must return either `void` or `Uni` (or be a Kotlin `suspend` function returning `Unit`), and it must accept a single parameter of type `Buffer`. [source,java] ---- @@ -539,18 +545,18 @@ This extension reuses the _main_ HTTP server. Thus, the configuration of the WebSocket server is done in the `quarkus.http.` configuration section. -WebSocket paths configured within the application are concatenated with the root path defined by `quarkus.http.root` (which defaults to /). +WebSocket paths configured within the application are concatenated with the root path defined by `quarkus.http.root` (which defaults to `/`). This concatenation ensures that WebSocket endpoints are appropriately positioned within the application's URL structure. Refer to the xref:http-reference.adoc[HTTP guide] for more details. === Sub-websockets endpoints -A `@WebSocket` endpoint can encapsulate static nested classes, which are also annotated with /`@WebSocket` and represent _sub-websockets_. -The resulting path of these sub-web sockets concatenates the path from the enclosing class and the nested class. +A `@WebSocket` endpoint can encapsulate static nested classes, which are also annotated with `@WebSocket` and represent _sub-websockets_. +The resulting path of these sub-websockets concatenates the path from the enclosing class and the nested class. The resulting path is normalized, following the HTTP URL rules. -Sub-web sockets inherit access to the path parameters declared in the `@WebSocket` annotation of both the enclosing and nested classes. +Sub-websockets inherit access to the path parameters declared in the `@WebSocket` annotation of both the enclosing and nested classes. The `consumePrimary` method within the enclosing class can access the `version` parameter in the following example. Meanwhile, the `consumeNested` method within the nested class can access both `version` and `id` parameters: @@ -884,9 +890,9 @@ public class MyBean { <4> Set the execution model for callback handlers. By default, the callback may block the current thread. However in this case, the callback is executed on the event loop and may not block the current thread. <5> The lambda will be called for every text message sent from the server. -The basic connector is closed to a low-level API and is reserved for advanced users. +The basic connector is closer to a low-level API and is reserved for advanced users. However, unlike others low-level WebSocket clients, it is still a CDI bean and can be injected in other beans. -It also provides a way to configure the execution model of the callbacks, ensuring the optimal integration with the rest of Quarkus. +It also provides a way to configure the execution model of the callbacks, ensuring optimal integration with the rest of Quarkus. [[ws-client-connection]] === WebSocket client connection diff --git a/extensions/websockets-next/deployment/pom.xml b/extensions/websockets-next/deployment/pom.xml index 03928d9cbb66e..3b62c8108947a 100644 --- a/extensions/websockets-next/deployment/pom.xml +++ b/extensions/websockets-next/deployment/pom.xml @@ -65,10 +65,56 @@ smallrye-certificate-generator-junit5 test + + org.jetbrains.kotlin + kotlin-stdlib + test + + + org.jetbrains.kotlinx + kotlinx-coroutines-core + test + + + io.smallrye.reactive + mutiny-kotlin + test + + + org.jetbrains.kotlin + kotlin-maven-plugin + ${kotlin.version} + + + compile + + compile + + + + ${project.basedir}/src/main/kotlin + ${project.basedir}/src/main/java + + + + + test-compile + + test-compile + + + + ${project.basedir}/src/test/kotlin + ${project.basedir}/src/test/java + + + + + maven-compiler-plugin @@ -80,6 +126,32 @@ + + + + default-compile + none + + + + default-testCompile + none + + + java-compile + compile + + compile + + + + java-test-compile + test-compile + + testCompile + + + maven-surefire-plugin diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java index 575230301ae2f..5c7df46e344b2 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java @@ -16,7 +16,10 @@ import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; import io.quarkus.arc.processor.Annotations; +import io.quarkus.arc.processor.BeanInfo; import io.quarkus.arc.processor.DotNames; +import io.quarkus.arc.processor.KotlinDotNames; +import io.quarkus.arc.processor.KotlinUtils; import io.quarkus.gizmo.BytecodeCreator; import io.quarkus.gizmo.FieldDescriptor; import io.quarkus.gizmo.ResultHandle; @@ -35,15 +38,17 @@ public class Callback { public final Target target; public final String endpointPath; public final AnnotationInstance annotation; + public final BeanInfo bean; public final MethodInfo method; public final ExecutionModel executionModel; public final MessageType messageType; public final List arguments; - public Callback(Target target, AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - String endpointPath, IndexView index) { + public Callback(Target target, AnnotationInstance annotation, BeanInfo bean, MethodInfo method, + ExecutionModel executionModel, CallbackArgumentsBuildItem callbackArguments, + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, IndexView index) { this.target = target; + this.bean = bean; this.method = method; this.annotation = annotation; this.executionModel = executionModel; @@ -104,6 +109,15 @@ public boolean isReturnTypeMulti() { return WebSocketDotNames.MULTI.equals(returnType().name()); } + public boolean isKotlinSuspendFunction() { + return KotlinUtils.isKotlinSuspendMethod(method); + } + + public boolean isKotlinSuspendFunctionReturningUnit() { + return KotlinUtils.isKotlinSuspendMethod(method) + && KotlinUtils.getKotlinSuspendMethodResult(method).name().equals(KotlinDotNames.UNIT); + } + public boolean acceptsMessage() { return messageType != MessageType.NONE; } diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/KotlinContinuationCallbackArgument.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/KotlinContinuationCallbackArgument.java new file mode 100644 index 0000000000000..fa6eda14b4f14 --- /dev/null +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/KotlinContinuationCallbackArgument.java @@ -0,0 +1,17 @@ +package io.quarkus.websockets.next.deployment; + +import io.quarkus.arc.processor.KotlinUtils; +import io.quarkus.gizmo.ResultHandle; + +public class KotlinContinuationCallbackArgument implements CallbackArgument { + @Override + public boolean matches(ParameterContext context) { + return KotlinUtils.isKotlinContinuationParameter(context.parameter()); + } + + @Override + public ResultHandle get(InvocationBytecodeContext context) { + // the actual value is provided by the invoker + return context.bytecode().loadNull(); + } +} diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index 4ed22a9985603..a6e5c44f375d7 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -17,18 +17,21 @@ import java.util.stream.Collectors; import jakarta.enterprise.context.SessionScoped; +import jakarta.enterprise.invoke.Invoker; import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.AnnotationTransformation; import org.jboss.jandex.AnnotationValue; import org.jboss.jandex.ClassInfo; import org.jboss.jandex.ClassInfo.NestingType; +import org.jboss.jandex.ClassType; import org.jboss.jandex.DotName; import org.jboss.jandex.IndexView; import org.jboss.jandex.MethodInfo; import org.jboss.jandex.PrimitiveType; import org.jboss.jandex.Type; import org.jboss.jandex.Type.Kind; +import org.objectweb.asm.Opcodes; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem; @@ -39,6 +42,7 @@ import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem; import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem.ContextConfiguratorBuildItem; import io.quarkus.arc.deployment.CustomScopeBuildItem; +import io.quarkus.arc.deployment.InvokerFactoryBuildItem; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.arc.deployment.SyntheticBeansRuntimeInitBuildItem; import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; @@ -50,7 +54,11 @@ import io.quarkus.arc.processor.BuiltinScope; import io.quarkus.arc.processor.DotNames; import io.quarkus.arc.processor.InjectionPointInfo; +import io.quarkus.arc.processor.InvokerInfo; +import io.quarkus.arc.processor.KotlinDotNames; +import io.quarkus.arc.processor.KotlinUtils; import io.quarkus.arc.processor.Types; +import io.quarkus.bootstrap.classloading.QuarkusClassLoader; import io.quarkus.deployment.Capabilities; import io.quarkus.deployment.Capability; import io.quarkus.deployment.GeneratedClassGizmoAdaptor; @@ -68,6 +76,7 @@ import io.quarkus.gizmo.CatchBlockCreator; import io.quarkus.gizmo.ClassCreator; import io.quarkus.gizmo.ClassOutput; +import io.quarkus.gizmo.FieldCreator; import io.quarkus.gizmo.FunctionCreator; import io.quarkus.gizmo.MethodCreator; import io.quarkus.gizmo.MethodDescriptor; @@ -108,6 +117,8 @@ import io.quarkus.websockets.next.runtime.WebSocketHttpServerOptionsCustomizer; import io.quarkus.websockets.next.runtime.WebSocketServerRecorder; import io.quarkus.websockets.next.runtime.WebSocketSessionContext; +import io.quarkus.websockets.next.runtime.kotlin.ApplicationCoroutineScope; +import io.quarkus.websockets.next.runtime.kotlin.CoroutineInvoker; import io.smallrye.mutiny.Multi; import io.smallrye.mutiny.Uni; import io.smallrye.mutiny.groups.UniCreate; @@ -190,6 +201,18 @@ void additionalBeans(CombinedIndexBuildItem combinedIndex, BuildProducer additionalBean) { + if (!QuarkusClassLoader.isClassPresentAtRuntime("kotlinx.coroutines.CoroutineScope")) { + return; + } + + additionalBean.produce(AdditionalBeanBuildItem.builder() + .addBeanClass(ApplicationCoroutineScope.class) + .setUnremovable() + .build()); + } + @BuildStep ContextConfiguratorBuildItem registerSessionContext(ContextRegistrationPhaseBuildItem phase) { return new ContextConfiguratorBuildItem(phase.getContext() @@ -211,6 +234,7 @@ void builtinCallbackArguments(BuildProducer providers providers.produce(new CallbackArgumentBuildItem(new HandshakeRequestCallbackArgument())); providers.produce(new CallbackArgumentBuildItem(new ErrorCallbackArgument())); providers.produce(new CallbackArgumentBuildItem(new CloseReasonCallbackArgument())); + providers.produce(new CallbackArgumentBuildItem(new KotlinContinuationCallbackArgument())); } @BuildStep @@ -228,7 +252,7 @@ void collectGlobalErrorHandlers(BeanArchiveIndexBuildItem beanArchiveIndex, ClassInfo beanClass = bean.getTarget().get().asClass(); if (beanClass.declaredAnnotation(WebSocketDotNames.WEB_SOCKET) == null && beanClass.declaredAnnotation(WebSocketDotNames.WEB_SOCKET_CLIENT) == null) { - for (Callback callback : findErrorHandlers(Target.UNDEFINED, index, beanClass, callbackArguments, + for (Callback callback : findErrorHandlers(Target.UNDEFINED, index, bean, beanClass, callbackArguments, transformedAnnotations, null)) { GlobalErrorHandler errorHandler = new GlobalErrorHandler(bean, callback); DotName errorTypeName = callback.argumentType(ErrorCallbackArgument::isError).name(); @@ -327,19 +351,17 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, inboundProcessingMode = webSocketClientAnnotation.value("inboundProcessingMode"); } - Callback onOpen = findCallback(target, beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_OPEN, - callbackArguments, transformedAnnotations, path); - Callback onTextMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass, - WebSocketDotNames.ON_TEXT_MESSAGE, - callbackArguments, transformedAnnotations, path); - Callback onBinaryMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass, + Callback onOpen = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass, + WebSocketDotNames.ON_OPEN, callbackArguments, transformedAnnotations, path); + Callback onTextMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass, + WebSocketDotNames.ON_TEXT_MESSAGE, callbackArguments, transformedAnnotations, path); + Callback onBinaryMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass, WebSocketDotNames.ON_BINARY_MESSAGE, callbackArguments, transformedAnnotations, path); - Callback onPongMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass, - WebSocketDotNames.ON_PONG_MESSAGE, - callbackArguments, transformedAnnotations, path, + Callback onPongMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass, + WebSocketDotNames.ON_PONG_MESSAGE, callbackArguments, transformedAnnotations, path, this::validateOnPongMessage); - Callback onClose = findCallback(target, beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_CLOSE, - callbackArguments, transformedAnnotations, path, + Callback onClose = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass, + WebSocketDotNames.ON_CLOSE, callbackArguments, transformedAnnotations, path, this::validateOnClose); if (onOpen == null && onTextMessage == null && onBinaryMessage == null && onPongMessage == null) { throw new WebSocketServerException( @@ -354,7 +376,7 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, onBinaryMessage, onPongMessage, onClose, - findErrorHandlers(target, index, beanClass, callbackArguments, transformedAnnotations, path))); + findErrorHandlers(target, index, bean, beanClass, callbackArguments, transformedAnnotations, path))); } } @@ -383,6 +405,7 @@ public void generateEndpoints(BeanArchiveIndexBuildItem index, List generatedClasses, BuildProducer generatedEndpoints, BuildProducer reflectiveClasses) { @@ -409,7 +432,8 @@ public String apply(String name) { // and delegates callback invocations to the endpoint bean String generatedName = generateEndpoint(endpoint, argumentProviders, transformedAnnotations, index.getIndex(), classOutput, globalErrorHandlers, - endpoint.isClient() ? CLIENT_ENDPOINT_SUFFIX : SERVER_ENDPOINT_SUFFIX); + endpoint.isClient() ? CLIENT_ENDPOINT_SUFFIX : SERVER_ENDPOINT_SUFFIX, + invokerFactory); reflectiveClasses.produce(ReflectiveClassBuildItem.builder(generatedName).constructors().build()); generatedEndpoints .produce(new GeneratedEndpointBuildItem(endpoint.id, endpoint.bean.getImplClazz().name().toString(), @@ -626,9 +650,16 @@ static String getPathPrefix(IndexView index, DotName enclosingClassName) { } private void validateOnPongMessage(Callback callback) { - if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) { - throw new WebSocketServerException( - "@OnPongMessage callback must return void or Uni: " + callback.asString()); + if (KotlinUtils.isKotlinMethod(callback.method)) { + if (!callback.isReturnTypeVoid() && !callback.isKotlinSuspendFunctionReturningUnit()) { + throw new WebSocketServerException( + "@OnPongMessage callback must return Unit: " + callback.asString()); + } + } else { + if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) { + throw new WebSocketServerException( + "@OnPongMessage callback must return void or Uni: " + callback.asString()); + } } Type messageType = callback.argumentType(MessageCallbackArgument::isMessage); if (messageType == null || !messageType.name().equals(WebSocketDotNames.BUFFER)) { @@ -639,9 +670,16 @@ private void validateOnPongMessage(Callback callback) { } private void validateOnClose(Callback callback) { - if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) { - throw new WebSocketServerException( - "@OnClose callback must return void or Uni: " + callback.asString()); + if (KotlinUtils.isKotlinMethod(callback.method)) { + if (!callback.isReturnTypeVoid() && !callback.isKotlinSuspendFunctionReturningUnit()) { + throw new WebSocketServerException( + "@OnClose callback must return Unit: " + callback.asString()); + } + } else { + if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) { + throw new WebSocketServerException( + "@OnClose callback must return void or Uni: " + callback.asString()); + } } } @@ -710,7 +748,8 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, IndexView index, ClassOutput classOutput, GlobalErrorHandlersBuildItem globalErrorHandlers, - String endpointSuffix) { + String endpointSuffix, + InvokerFactoryBuildItem invokerFactory) { ClassInfo implClazz = endpoint.bean.getImplClazz(); String baseName; if (implClazz.enclosingClass() != null) { @@ -733,7 +772,6 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, Codecs.class, ContextSupport.class, SecuritySupport.class), constructor.getThis(), constructor.getMethodParam(0), constructor.getMethodParam(1), constructor.getMethodParam(2), constructor.getMethodParam(3)); - constructor.returnNull(); MethodCreator inboundProcessingMode = endpointCreator.getMethodCreator("inboundProcessingMode", InboundProcessingMode.class); @@ -751,7 +789,8 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, // Call the business method TryBlock tryBlock = onErrorTryBlock(doOnOpen, doOnOpen.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, "Open", tryBlock, + beanInstance, args, invokerFactory); encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret); MethodCreator onOpenExecutionModel = endpointCreator.getMethodCreator("onOpenExecutionModel", @@ -759,12 +798,12 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, onOpenExecutionModel.returnValue(onOpenExecutionModel.load(callback.executionModel)); } - generateOnMessage(endpointCreator, endpoint, endpoint.onBinaryMessage, argumentProviders, transformedAnnotations, - index, globalErrorHandlers); - generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage, argumentProviders, transformedAnnotations, index, - globalErrorHandlers); - generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage, argumentProviders, transformedAnnotations, index, - globalErrorHandlers); + generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onBinaryMessage, argumentProviders, + transformedAnnotations, index, globalErrorHandlers, invokerFactory); + generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onTextMessage, argumentProviders, + transformedAnnotations, index, globalErrorHandlers, invokerFactory); + generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onPongMessage, argumentProviders, + transformedAnnotations, index, globalErrorHandlers, invokerFactory); if (endpoint.onClose != null) { Callback callback = endpoint.onClose; @@ -775,7 +814,8 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, // Call the business method TryBlock tryBlock = onErrorTryBlock(doOnClose, doOnClose.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, "Close", tryBlock, + beanInstance, args, invokerFactory); encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret); MethodCreator onCloseExecutionModel = endpointCreator.getMethodCreator("onCloseExecutionModel", @@ -783,15 +823,19 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, onCloseExecutionModel.returnValue(onCloseExecutionModel.load(callback.executionModel)); } - generateOnError(endpointCreator, endpoint, argumentProviders, transformedAnnotations, globalErrorHandlers, index); + generateOnError(endpointCreator, constructor, endpoint, transformedAnnotations, globalErrorHandlers, index, + invokerFactory); + + // we write into the constructor when generating callback invokers, so need to finish it late + constructor.returnVoid(); endpointCreator.close(); return generatedName.replace('/', '.'); } - private static void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - GlobalErrorHandlersBuildItem globalErrorHandlers, IndexView index) { + private static void generateOnError(ClassCreator endpointCreator, MethodCreator constructor, + WebSocketEndpointBuildItem endpoint, TransformedAnnotationsBuildItem transformedAnnotations, + GlobalErrorHandlersBuildItem globalErrorHandlers, IndexView index, InvokerFactoryBuildItem invokerFactory) { Map errors = new HashMap<>(); List throwableInfos = new ArrayList<>(); @@ -841,7 +885,8 @@ private static void generateOnError(ClassCreator endpointCreator, WebSocketEndpo MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), endpointThis, funBytecode.load(throwableInfo.bean().getIdentifier())); ResultHandle[] args = callback.generateArguments(endpointThis, tryBlock, transformedAnnotations, index); - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, "Error", tryBlock, + beanInstance, args, invokerFactory); encodeAndReturnResult(endpointThis, tryBlock, callback, globalErrorHandlers, endpoint, ret); // return doErrorExecute() @@ -891,9 +936,10 @@ record GlobalErrorHandler(BeanInfo bean, Callback callback) { } - private static void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback, + private static void generateOnMessage(ClassCreator endpointCreator, MethodCreator constructor, + WebSocketEndpointBuildItem endpoint, Callback callback, CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - IndexView index, GlobalErrorHandlersBuildItem globalErrorHandlers) { + IndexView index, GlobalErrorHandlersBuildItem globalErrorHandlers, InvokerFactoryBuildItem invokerFactory) { if (callback == null) { return; } @@ -924,8 +970,8 @@ private static void generateOnMessage(ClassCreator endpointCreator, WebSocketEnd MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class), tryBlock.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); // Call the business method - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, - args); + ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, messageType, tryBlock, beanInstance, args, + invokerFactory); encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret); MethodCreator onMessageExecutionModel = endpointCreator.getMethodCreator("on" + messageType + "MessageExecutionModel", @@ -946,6 +992,34 @@ private static void generateOnMessage(ClassCreator endpointCreator, WebSocketEnd } } + private static ResultHandle callBusinessMethod(ClassCreator clazz, MethodCreator constructor, Callback callback, + String messageType, BytecodeCreator bytecode, ResultHandle beanInstance, ResultHandle[] args, + InvokerFactoryBuildItem invokerFactory) { + // using CDI method invokers for Kotlin `suspend` functions to minimize custom bytecode generation + // other methods are invoked directly + if (KotlinUtils.isKotlinSuspendMethod(callback.method)) { + InvokerInfo invoker = invokerFactory.createInvoker(callback.bean, callback.method) + .withInvocationWrapper(CoroutineInvoker.class, "inNewCoroutine") + .build(); + // create a field in the endpoint class and put the created invoker into it in the `constructor` + FieldCreator invokerField = clazz.getFieldCreator("invokerFor" + messageType, Invoker.class) + .setModifiers(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL); + constructor.writeInstanceField(invokerField.getFieldDescriptor(), constructor.getThis(), + constructor.newInstance(MethodDescriptor.ofConstructor(invoker.getClassName()))); + // use the invoker to invoke the business method in the endpoint method (`bytecode`) + ResultHandle invokerHandle = bytecode.readInstanceField(invokerField.getFieldDescriptor(), bytecode.getThis()); + ResultHandle argsArray = bytecode.newArray(Object.class, args.length); + for (int i = 0; i < args.length; i++) { + bytecode.writeArrayValue(argsArray, i, args[i]); + } + return bytecode.invokeInterfaceMethod( + MethodDescriptor.ofMethod(Invoker.class, "invoke", Object.class, Object.class, Object[].class), + invokerHandle, beanInstance, argsArray); + } else { + return bytecode.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + } + } + private static TryBlock uniFailureTryBlock(BytecodeCreator method) { TryBlock tryBlock = method.tryBlock(); CatchBlockCreator catchBlock = tryBlock.addCatch(Throwable.class); @@ -1070,8 +1144,15 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre // ---------------------- // === Binary message === // ---------------------- - if (callback.isReturnTypeUni()) { - Type messageType = callback.returnType().asParameterizedType().arguments().get(0); + if (callback.isReturnTypeUni() || callback.isKotlinSuspendFunction()) { + Type messageType = callback.isReturnTypeUni() + ? callback.returnType().asParameterizedType().arguments().get(0) + : KotlinUtils.getKotlinSuspendMethodResult(callback.method); + if (messageType.name().equals(KotlinDotNames.UNIT)) { + value = method.invokeInterfaceMethod(MethodDescriptor.ofMethod(Uni.class, "replaceWithVoid", Uni.class), + value); + messageType = ClassType.create(WebSocketDotNames.VOID); + } if (messageType.name().equals(WebSocketDotNames.VOID)) { // Uni return uniOnFailureDoOnError(endpointThis, method, callback, value, endpoint, globalErrorHandlers); @@ -1082,8 +1163,7 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre // }); FunctionCreator fun = method.createFunction(Function.class); BytecodeCreator funBytecode = fun.getBytecode(); - ResultHandle buffer = encodeBuffer(funBytecode, - callback.returnType().asParameterizedType().arguments().get(0), + ResultHandle buffer = encodeBuffer(funBytecode, messageType, funBytecode.getMethodParam(0), endpointThis, callback); funBytecode.returnValue(funBytecode.invokeVirtualMethod( MethodDescriptor.ofMethod(WebSocketEndpointBase.class, @@ -1130,8 +1210,15 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre // ---------------------- // === Text message === // ---------------------- - if (callback.isReturnTypeUni()) { - Type messageType = callback.returnType().asParameterizedType().arguments().get(0); + if (callback.isReturnTypeUni() || callback.isKotlinSuspendFunction()) { + Type messageType = callback.isReturnTypeUni() + ? callback.returnType().asParameterizedType().arguments().get(0) + : KotlinUtils.getKotlinSuspendMethodResult(callback.method); + if (messageType.name().equals(KotlinDotNames.UNIT)) { + value = method.invokeInterfaceMethod(MethodDescriptor.ofMethod(Uni.class, "replaceWithVoid", Uni.class), + value); + messageType = ClassType.create(WebSocketDotNames.VOID); + } if (messageType.name().equals(WebSocketDotNames.VOID)) { // Uni return uniOnFailureDoOnError(endpointThis, method, callback, value, endpoint, globalErrorHandlers); @@ -1142,7 +1229,7 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre // }); FunctionCreator fun = method.createFunction(Function.class); BytecodeCreator funBytecode = fun.getBytecode(); - ResultHandle text = encodeText(funBytecode, callback.returnType().asParameterizedType().arguments().get(0), + ResultHandle text = encodeText(funBytecode, messageType, funBytecode.getMethodParam(0), endpointThis, callback); funBytecode.returnValue(funBytecode.invokeVirtualMethod( MethodDescriptor.ofMethod(WebSocketEndpointBase.class, @@ -1267,9 +1354,8 @@ private static void encodeAndReturnResult(ResultHandle endpointThis, BytecodeCre } } - static List findErrorHandlers(Target target, IndexView index, ClassInfo beanClass, - CallbackArgumentsBuildItem callbackArguments, - TransformedAnnotationsBuildItem transformedAnnotations, + static List findErrorHandlers(Target target, IndexView index, BeanInfo bean, ClassInfo beanClass, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { List annotations = findCallbackAnnotations(index, beanClass, WebSocketDotNames.ON_ERROR); if (annotations.isEmpty()) { @@ -1285,8 +1371,9 @@ static List findErrorHandlers(Target target, IndexView index, ClassInf .anyMatch(WebSocketDotNames.WEB_SOCKET_CLIENT_CONNECTION::equals)) { target = Target.CLIENT; } - Callback callback = new Callback(target, annotation, method, executionModel(method, transformedAnnotations), - callbackArguments, transformedAnnotations, endpointPath, index); + Callback callback = new Callback(target, annotation, bean, method, + executionModel(method, transformedAnnotations), callbackArguments, transformedAnnotations, + endpointPath, index); long errorArguments = callback.arguments.stream().filter(ca -> ca instanceof ErrorCallbackArgument).count(); if (errorArguments != 1) { throw new WebSocketException( @@ -1316,16 +1403,16 @@ private static List findCallbackAnnotations(IndexView index, return annotations; } - static Callback findCallback(Target target, IndexView index, ClassInfo beanClass, DotName annotationName, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - String endpointPath) { - return findCallback(target, index, beanClass, annotationName, callbackArguments, transformedAnnotations, endpointPath, - null); + static Callback findCallback(Target target, IndexView index, BeanInfo bean, ClassInfo beanClass, + DotName annotationName, CallbackArgumentsBuildItem callbackArguments, + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { + return findCallback(target, index, bean, beanClass, annotationName, callbackArguments, + transformedAnnotations, endpointPath, null); } - private static Callback findCallback(Target target, IndexView index, ClassInfo beanClass, DotName annotationName, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - String endpointPath, + private static Callback findCallback(Target target, IndexView index, BeanInfo bean, ClassInfo beanClass, + DotName annotationName, CallbackArgumentsBuildItem callbackArguments, + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, Consumer validator) { List annotations = findCallbackAnnotations(index, beanClass, annotationName); if (annotations.isEmpty()) { @@ -1333,9 +1420,9 @@ private static Callback findCallback(Target target, IndexView index, ClassInfo b } else if (annotations.size() == 1) { AnnotationInstance annotation = annotations.get(0); MethodInfo method = annotation.target().asMethod(); - Callback callback = new Callback(target, annotation, method, executionModel(method, transformedAnnotations), - callbackArguments, - transformedAnnotations, endpointPath, index); + Callback callback = new Callback(target, annotation, bean, method, + executionModel(method, transformedAnnotations), callbackArguments, transformedAnnotations, + endpointPath, index); long messageArguments = callback.arguments.stream().filter(ca -> ca instanceof MessageCallbackArgument).count(); if (callback.acceptsMessage()) { if (messageArguments > 1) { @@ -1370,6 +1457,14 @@ private static Callback findCallback(Target target, IndexView index, ClassInfo b } private static ExecutionModel executionModel(MethodInfo method, TransformedAnnotationsBuildItem transformedAnnotations) { + if (KotlinUtils.isKotlinSuspendMethod(method) + && (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.RUN_ON_VIRTUAL_THREAD) + || transformedAnnotations.hasAnnotation(method, WebSocketDotNames.BLOCKING) + || transformedAnnotations.hasAnnotation(method, WebSocketDotNames.NON_BLOCKING))) { + throw new WebSocketException("Kotlin `suspend` functions in WebSockets Next endpoints may not be " + + "annotated @Blocking, @NonBlocking or @RunOnVirtualThread: " + method); + } + if (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.RUN_ON_VIRTUAL_THREAD)) { return ExecutionModel.VIRTUAL_THREAD; } else if (transformedAnnotations.hasAnnotation(method, WebSocketDotNames.BLOCKING)) { @@ -1382,6 +1477,10 @@ private static ExecutionModel executionModel(MethodInfo method, TransformedAnnot } static boolean hasBlockingSignature(MethodInfo method) { + if (KotlinUtils.isKotlinSuspendMethod(method)) { + return false; + } + switch (method.returnType().kind()) { case VOID: case CLASS: @@ -1414,6 +1513,8 @@ private static boolean isOnOpenWithBinaryReturnType(Callback callback) { Type returnType = callback.returnType(); if (callback.isReturnTypeUni() || callback.isReturnTypeMulti()) { returnType = callback.returnType().asParameterizedType().arguments().get(0); + } else if (callback.isKotlinSuspendFunction()) { + returnType = KotlinUtils.getKotlinSuspendMethodResult(callback.method); } return WebSocketDotNames.BUFFER.equals(returnType.name()) || (returnType.kind() == Kind.ARRAY && PrimitiveType.BYTE.equals(returnType.asArrayType().constituent())); diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java index 2f1974089db90..926c0d1b82d15 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java @@ -144,6 +144,13 @@ public Buffer sendAndAwaitReply(String message) { return messages.get(c); } + public Buffer sendAndAwaitReply(Buffer message) { + var c = messages.size(); + sendAndAwait(message); + Awaitility.await().until(() -> messages.size() > c); + return messages.get(c); + } + public boolean isClosed() { return socket.get().isClosed(); } diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEcho.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEcho.kt new file mode 100644 index 0000000000000..a42359791fd38 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEcho.kt @@ -0,0 +1,14 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.websockets.next.OnBinaryMessage +import io.quarkus.websockets.next.WebSocket +import io.vertx.core.buffer.Buffer +import kotlinx.coroutines.delay + +@WebSocket(path = "/binary-echo") +class BinaryEcho { + @OnBinaryMessage + fun process(msg: Buffer): Buffer { + return msg + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEchoSuspend.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEchoSuspend.kt new file mode 100644 index 0000000000000..a2182d7607535 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEchoSuspend.kt @@ -0,0 +1,15 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.websockets.next.OnBinaryMessage +import io.quarkus.websockets.next.WebSocket +import io.vertx.core.buffer.Buffer +import kotlinx.coroutines.delay + +@WebSocket(path = "/binary-echo-suspend") +class BinaryEchoSuspend { + @OnBinaryMessage + suspend fun process(msg: Buffer): Buffer { + delay(100) + return msg + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Echo.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Echo.kt new file mode 100644 index 0000000000000..fc60701b4991b --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Echo.kt @@ -0,0 +1,12 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.websockets.next.OnTextMessage +import io.quarkus.websockets.next.WebSocket + +@WebSocket(path = "/echo") +class Echo { + @OnTextMessage + fun process(msg: Message): Message { + return msg + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/EchoSuspend.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/EchoSuspend.kt new file mode 100644 index 0000000000000..92d27b2d9d85b --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/EchoSuspend.kt @@ -0,0 +1,14 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.websockets.next.OnTextMessage +import io.quarkus.websockets.next.WebSocket +import kotlinx.coroutines.delay + +@WebSocket(path = "/echo-suspend") +class EchoSuspend { + @OnTextMessage + suspend fun process(msg: Message): Message { + delay(100) + return msg + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketClientTest.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketClientTest.kt new file mode 100644 index 0000000000000..9ec2cd00baf2c --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketClientTest.kt @@ -0,0 +1,93 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.test.QuarkusUnitTest +import io.quarkus.test.common.http.TestHTTPResource +import io.quarkus.websockets.next.OnClose +import io.quarkus.websockets.next.OnOpen +import io.quarkus.websockets.next.OnTextMessage +import io.quarkus.websockets.next.WebSocket +import io.quarkus.websockets.next.WebSocketClient +import io.quarkus.websockets.next.WebSocketConnector +import jakarta.inject.Inject +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension +import java.net.URI +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit + +class KotlinWebSocketClientTest { + companion object { + @RegisterExtension + val test = QuarkusUnitTest() + .withApplicationRoot { jar -> + jar.addClasses(ServerEndpoint::class.java, ClientEndpoint::class.java) + } + } + + @Inject + lateinit var connector: WebSocketConnector + + @TestHTTPResource("/") + lateinit var uri: URI + + @Test + fun test() { + val connection = connector.baseUri(uri).connectAndAwait() + connection.sendTextAndAwait("Hi!") + + assertTrue(ClientEndpoint.messagesLatch.await(5, TimeUnit.SECONDS)) + assertEquals("Hello there", ClientEndpoint.messages[0]) + assertEquals("Hi!", ClientEndpoint.messages[1]) + + connection.closeAndAwait() + assertTrue(ClientEndpoint.closedLatch.await(5, TimeUnit.SECONDS)) + assertTrue(ServerEndpoint.closedLatch.await(5, TimeUnit.SECONDS)) + } + + @WebSocket(path = "/endpoint") + class ServerEndpoint { + companion object { + val closedLatch: CountDownLatch = CountDownLatch(1) + } + + @OnOpen + fun open(): String { + return "Hello there" + } + + @OnTextMessage + fun echo(message: String): String { + return message + } + + @OnClose + fun close() { + closedLatch.countDown() + } + } + + @WebSocketClient(path = "/endpoint") + class ClientEndpoint { + companion object { + val messages: MutableList = CopyOnWriteArrayList() + + val messagesLatch: CountDownLatch = CountDownLatch(2) + + val closedLatch: CountDownLatch = CountDownLatch(1) + } + + @OnTextMessage + fun onMessage(message: String) { + messages.add(message) + messagesLatch.countDown() + } + + @OnClose + fun onClose() { + closedLatch.countDown() + } + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSessionContextTest.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSessionContextTest.kt new file mode 100644 index 0000000000000..4fd8f29d1b03a --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSessionContextTest.kt @@ -0,0 +1,85 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.test.QuarkusUnitTest +import io.quarkus.test.common.http.TestHTTPResource +import io.quarkus.websockets.next.OnTextMessage +import io.quarkus.websockets.next.WebSocket +import io.quarkus.websockets.next.test.utils.WSClient +import io.vertx.core.Vertx +import jakarta.enterprise.context.SessionScoped +import jakarta.inject.Inject +import kotlinx.coroutines.delay +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension +import java.net.URI +import java.util.UUID + +class KotlinWebSocketSessionContextTest { + companion object { + @RegisterExtension + val test = QuarkusUnitTest() + .withApplicationRoot { jar -> + jar.addClasses(MyData::class.java, Endpoint::class.java, WSClient::class.java) + } + } + + @Inject + lateinit var vertx: Vertx + + @TestHTTPResource("endpoint") + lateinit var endpoint: URI + + @Test + fun testEcho() { + WSClient.create(vertx).connect(endpoint).use { client1 -> + WSClient.create(vertx).connect(endpoint).use { client2 -> + var id1: String? = null + var id2: String? = null + for (i in 1..10) { + val client = if (i % 2 == 0) client1 else client2 + val req = "hello$i" + val resp = client.sendAndAwaitReply(req).toString().split(" ") + assertEquals(3, resp.size) + assertEquals(req, resp[0]) + assertEquals(resp[1], resp[2]) + if (i % 2 == 0) { + if (id1 == null) { + id1 = resp[1] + } + assertEquals(id1, resp[1]) + } else { + if (id2 == null) { + id2 = resp[1] + } + assertEquals(id2, resp[1]) + } + } + assertNotNull(id1) + assertNotNull(id2) + assertNotEquals(id1, id2) + } + } + } + + @SessionScoped + class MyData { + val id = UUID.randomUUID().toString() + } + + @WebSocket(path = "/endpoint") + class Endpoint { + @Inject + lateinit var data: MyData + + @OnTextMessage + suspend fun echo(message: String): String { + val id1 = data.id + delay(100) + val id2 = data.id + return "$message $id1 $id2" + } + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSuspendingClientTest.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSuspendingClientTest.kt new file mode 100644 index 0000000000000..9c4a8fdc9051a --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSuspendingClientTest.kt @@ -0,0 +1,99 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.test.QuarkusUnitTest +import io.quarkus.test.common.http.TestHTTPResource +import io.quarkus.websockets.next.OnClose +import io.quarkus.websockets.next.OnOpen +import io.quarkus.websockets.next.OnTextMessage +import io.quarkus.websockets.next.WebSocket +import io.quarkus.websockets.next.WebSocketClient +import io.quarkus.websockets.next.WebSocketConnector +import jakarta.inject.Inject +import kotlinx.coroutines.delay +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension +import java.net.URI +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit + +class KotlinWebSocketSuspendingClientTest { + companion object { + @RegisterExtension + val test = QuarkusUnitTest() + .withApplicationRoot { jar -> + jar.addClasses(ServerEndpoint::class.java, ClientEndpoint::class.java) + } + } + + @Inject + lateinit var connector: WebSocketConnector + + @TestHTTPResource("/") + lateinit var uri: URI + + @Test + fun test() { + val connection = connector.baseUri(uri).connectAndAwait() + connection.sendTextAndAwait("Hi!") + + assertTrue(ClientEndpoint.messagesLatch.await(5, TimeUnit.SECONDS)) + assertEquals("Hello there", ClientEndpoint.messages[0]) + assertEquals("Hi!", ClientEndpoint.messages[1]) + + connection.closeAndAwait() + assertTrue(ClientEndpoint.closedLatch.await(5, TimeUnit.SECONDS)) + assertTrue(ServerEndpoint.closedLatch.await(5, TimeUnit.SECONDS)) + } + + @WebSocket(path = "/endpoint") + class ServerEndpoint { + companion object { + val closedLatch: CountDownLatch = CountDownLatch(1) + } + + @OnOpen + suspend fun open(): String { + delay(100) + return "Hello there" + } + + @OnTextMessage + suspend fun echo(message: String): String { + delay(100) + return message + } + + @OnClose + suspend fun close() { + delay(100) + closedLatch.countDown() + } + } + + @WebSocketClient(path = "/endpoint") + class ClientEndpoint { + companion object { + val messages: MutableList = CopyOnWriteArrayList() + + val messagesLatch: CountDownLatch = CountDownLatch(2) + + val closedLatch: CountDownLatch = CountDownLatch(1) + } + + @OnTextMessage + suspend fun onMessage(message: String) { + delay(100) + messages.add(message) + messagesLatch.countDown() + } + + @OnClose + suspend fun onClose() { + delay(100) + closedLatch.countDown() + } + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketTest.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketTest.kt new file mode 100644 index 0000000000000..12dd973ee9848 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketTest.kt @@ -0,0 +1,75 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.test.QuarkusUnitTest +import io.quarkus.test.common.http.TestHTTPResource +import io.quarkus.websockets.next.test.utils.WSClient +import io.vertx.core.Vertx +import io.vertx.core.buffer.Buffer +import io.vertx.core.json.JsonObject +import jakarta.inject.Inject +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension +import java.net.URI + +class KotlinWebSocketTest { + companion object { + @RegisterExtension + val test = QuarkusUnitTest() + .withApplicationRoot { jar -> + jar.addClasses(Echo::class.java, EchoSuspend::class.java, BinaryEcho::class.java, + BinaryEchoSuspend::class.java, Message::class.java, WSClient::class.java) + } + } + + @Inject + lateinit var vertx: Vertx + + @TestHTTPResource("echo") + lateinit var echo: URI + + @TestHTTPResource("echo-suspend") + lateinit var echoSuspend: URI + + @TestHTTPResource("binary-echo") + lateinit var binaryEcho: URI + + @TestHTTPResource("binary-echo-suspend") + lateinit var binaryEchoSuspend: URI + + @Test + fun testEcho() { + doTest(echo) + } + + @Test + fun testEchoSuspend() { + doTest(echoSuspend) + } + + private fun doTest(uri: URI) { + WSClient.create(vertx).connect(uri).use { client -> + val req = JsonObject().put("msg", "hello") + val resp = client.sendAndAwaitReply(req.toString()).toJsonObject() + assertEquals(req, resp) + } + } + + @Test + fun testBinaryEcho() { + doTestBinary(binaryEcho) + } + + @Test + fun testBinaryEchoSuspend() { + doTestBinary(binaryEchoSuspend) + } + + private fun doTestBinary(uri: URI) { + WSClient.create(vertx).connect(uri).use { client -> + val req = Buffer.buffer("hello there!") + val resp = client.sendAndAwaitReply(req) + assertEquals(req, resp) + } + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Message.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Message.kt new file mode 100644 index 0000000000000..96563b27e7382 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Message.kt @@ -0,0 +1,5 @@ +package io.quarkus.websockets.next.test.kotlin + +import com.fasterxml.jackson.annotation.JsonCreator + +data class Message @JsonCreator constructor(var msg: String) diff --git a/extensions/websockets-next/kotlin/pom.xml b/extensions/websockets-next/kotlin/pom.xml new file mode 100644 index 0000000000000..bda3816be3a85 --- /dev/null +++ b/extensions/websockets-next/kotlin/pom.xml @@ -0,0 +1,115 @@ + + + + io.quarkus + quarkus-websockets-next-parent + 999-SNAPSHOT + + 4.0.0 + + quarkus-websockets-next-kotlin + Quarkus - WebSockets Next - Kotlin + + + + io.quarkus + quarkus-arc + + + io.quarkus + quarkus-vertx + + + org.jetbrains.kotlin + kotlin-stdlib-jdk8 + true + + + org.jetbrains.kotlinx + kotlinx-coroutines-jdk8 + true + + + io.smallrye.reactive + mutiny-kotlin + true + + + + + ${project.basedir}/src/main/kotlin + ${project.basedir}/src/test/kotlin + + + org.jetbrains.kotlin + kotlin-maven-plugin + ${kotlin.version} + + + compile + + compile + + + + test-compile + + test-compile + + + + + ${maven.compiler.target} + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${project.version} + + + + + + + default-compile + none + + + + default-testCompile + none + + + java-compile + compile + + compile + + + + java-test-compile + test-compile + + testCompile + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + false + + + + + diff --git a/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/ApplicationCoroutineScope.kt b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/ApplicationCoroutineScope.kt new file mode 100644 index 0000000000000..78ee39b924706 --- /dev/null +++ b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/ApplicationCoroutineScope.kt @@ -0,0 +1,18 @@ +package io.quarkus.websockets.next.runtime.kotlin + +import jakarta.annotation.PreDestroy +import jakarta.inject.Singleton +import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel + +@Singleton +class ApplicationCoroutineScope : CoroutineScope, AutoCloseable { + override val coroutineContext: CoroutineContext = SupervisorJob() + + @PreDestroy + override fun close() { + coroutineContext.cancel() + } +} diff --git a/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/CoroutineInvoker.kt b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/CoroutineInvoker.kt new file mode 100644 index 0000000000000..5a6b6f8e28a77 --- /dev/null +++ b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/CoroutineInvoker.kt @@ -0,0 +1,37 @@ +package io.quarkus.websockets.next.runtime.kotlin + +import io.quarkus.arc.Arc +import io.smallrye.mutiny.Uni +import io.smallrye.mutiny.coroutines.asUni +import io.vertx.core.Vertx +import jakarta.enterprise.invoke.Invoker +import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException +import kotlin.coroutines.suspendCoroutine +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.async + +object CoroutineInvoker { + @JvmStatic + @OptIn(ExperimentalCoroutinesApi::class) + fun inNewCoroutine(instance: T, arguments: Array, invoker: Invoker): Uni { + val coroutineScope = Arc.container().instance(ApplicationCoroutineScope::class.java).get() + val dispatcher: CoroutineDispatcher = + Vertx.currentContext()?.let(::VertxDispatcher) + ?: throw IllegalStateException("No Vertx context found") + + return coroutineScope + .async(context = dispatcher) { + suspendCoroutine { continuation -> + arguments[arguments.size - 1] = continuation + try { + continuation.resume(invoker.invoke(instance, arguments)) + } catch (e: Exception) { + continuation.resumeWithException(e) + } + } + } + .asUni() + } +} diff --git a/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/VertxDispatcher.kt b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/VertxDispatcher.kt new file mode 100644 index 0000000000000..1377b1835fe6c --- /dev/null +++ b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/VertxDispatcher.kt @@ -0,0 +1,24 @@ +package io.quarkus.websockets.next.runtime.kotlin + +import io.quarkus.arc.Arc +import io.vertx.core.Context +import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.CoroutineDispatcher + +class VertxDispatcher(private val vertxContext: Context) : CoroutineDispatcher() { + override fun dispatch(context: CoroutineContext, block: Runnable) { + val requestContext = Arc.container().requestContext() + vertxContext.runOnContext { + if (requestContext.isActive) { + block.run() + } else { + try { + requestContext.activate() + block.run() + } finally { + requestContext.terminate() + } + } + } + } +} diff --git a/extensions/websockets-next/pom.xml b/extensions/websockets-next/pom.xml index 5ff9318b3e765..1e149d244734f 100644 --- a/extensions/websockets-next/pom.xml +++ b/extensions/websockets-next/pom.xml @@ -16,6 +16,7 @@ deployment runtime + kotlin diff --git a/extensions/websockets-next/runtime/pom.xml b/extensions/websockets-next/runtime/pom.xml index a72d7832edd42..de16a0cfaacaa 100644 --- a/extensions/websockets-next/runtime/pom.xml +++ b/extensions/websockets-next/runtime/pom.xml @@ -34,6 +34,10 @@ io.quarkus quarkus-tls-registry + + io.quarkus + quarkus-websockets-next-kotlin + io.quarkus.security diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java index cef74e824cf2d..37ab47604ba2d 100644 --- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java +++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java @@ -2,8 +2,9 @@ import org.jboss.jandex.DotName; -class KotlinDotNames { - static final DotName METADATA = DotName.createSimple("kotlin.Metadata"); +public final class KotlinDotNames { + public static final DotName METADATA = DotName.createSimple("kotlin.Metadata"); + public static final DotName UNIT = DotName.createSimple("kotlin.Unit"); - static final DotName CONTINUATION = DotName.createSimple("kotlin.coroutines.Continuation"); + public static final DotName CONTINUATION = DotName.createSimple("kotlin.coroutines.Continuation"); } diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java index 1628a83b3adef..966a7fc2c57ab 100644 --- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java +++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java @@ -4,6 +4,7 @@ import org.jboss.jandex.ClassInfo; import org.jboss.jandex.MethodInfo; +import org.jboss.jandex.MethodParameterInfo; import org.jboss.jandex.Type; public class KotlinUtils { @@ -27,6 +28,10 @@ public static boolean isKotlinSuspendMethod(MethodInfo method) { return KotlinDotNames.CONTINUATION.equals(lastParameter.name()); } + public static boolean isKotlinContinuationParameter(MethodParameterInfo parameter) { + return isKotlinSuspendMethod(parameter.method()) && KotlinDotNames.CONTINUATION.equals(parameter.type().name()); + } + public static boolean isNoninterceptableKotlinMethod(MethodInfo method) { // the Kotlin compiler generates somewhat streamlined bytecode when it determines // that a `suspend` method cannot be overridden, and that bytecode doesn't work @@ -36,4 +41,24 @@ public static boolean isNoninterceptableKotlinMethod(MethodInfo method) { return isKotlinSuspendMethod(method) && (Modifier.isFinal(method.flags()) || Modifier.isFinal(method.declaringClass().flags())); } + + public static Type getKotlinSuspendMethodResult(MethodInfo method) { + if (!isKotlinSuspendMethod(method)) { + throw new IllegalArgumentException("Not a suspend function: " + method); + } + + Type lastParameter = method.parameterType(method.parametersCount() - 1); + if (lastParameter.kind() != Type.Kind.PARAMETERIZED_TYPE) { + throw new IllegalArgumentException("Continuation parameter type not parameterized: " + lastParameter); + } + Type resultType = lastParameter.asParameterizedType().arguments().get(0); + if (resultType.kind() != Type.Kind.WILDCARD_TYPE) { + throw new IllegalArgumentException("Continuation parameter type argument not wildcard: " + resultType); + } + Type lowerBound = resultType.asWildcardType().superBound(); + if (lowerBound == null) { + throw new IllegalArgumentException("Continuation parameter type argument without lower bound: " + resultType); + } + return lowerBound; + } }