diff --git a/bom/application/pom.xml b/bom/application/pom.xml index 6a4cbeaf71ea30..8ab4c8af1ebf47 100644 --- a/bom/application/pom.xml +++ b/bom/application/pom.xml @@ -2166,6 +2166,11 @@ quarkus-websockets-next-deployment ${project.version} + + io.quarkus + quarkus-websockets-next-kotlin + ${project.version} + io.quarkus quarkus-undertow-spi diff --git a/docs/src/main/asciidoc/kotlin.adoc b/docs/src/main/asciidoc/kotlin.adoc index 84e2e61be86fd3..2ff3227ba340aa 100644 --- a/docs/src/main/asciidoc/kotlin.adoc +++ b/docs/src/main/asciidoc/kotlin.adoc @@ -504,6 +504,9 @@ The following extensions provide support for Kotlin Coroutines by allowing the u |`quarkus-vertx` |Support is provided for `@ConsumeEvent` methods +|`quarkus-websockets-next` +|Support is provided for server-side and client-side endpoint methods + |=== === Kotlin coroutines and Mutiny diff --git a/extensions/websockets-next/deployment/pom.xml b/extensions/websockets-next/deployment/pom.xml index 03928d9cbb66e2..3b62c8108947ac 100644 --- a/extensions/websockets-next/deployment/pom.xml +++ b/extensions/websockets-next/deployment/pom.xml @@ -65,10 +65,56 @@ smallrye-certificate-generator-junit5 test + + org.jetbrains.kotlin + kotlin-stdlib + test + + + org.jetbrains.kotlinx + kotlinx-coroutines-core + test + + + io.smallrye.reactive + mutiny-kotlin + test + + + org.jetbrains.kotlin + kotlin-maven-plugin + ${kotlin.version} + + + compile + + compile + + + + ${project.basedir}/src/main/kotlin + ${project.basedir}/src/main/java + + + + + test-compile + + test-compile + + + + ${project.basedir}/src/test/kotlin + ${project.basedir}/src/test/java + + + + + maven-compiler-plugin @@ -80,6 +126,32 @@ + + + + default-compile + none + + + + default-testCompile + none + + + java-compile + compile + + compile + + + + java-test-compile + test-compile + + testCompile + + + maven-surefire-plugin diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java index 575230301ae2fb..5c7df46e344b20 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java @@ -16,7 +16,10 @@ import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; import io.quarkus.arc.processor.Annotations; +import io.quarkus.arc.processor.BeanInfo; import io.quarkus.arc.processor.DotNames; +import io.quarkus.arc.processor.KotlinDotNames; +import io.quarkus.arc.processor.KotlinUtils; import io.quarkus.gizmo.BytecodeCreator; import io.quarkus.gizmo.FieldDescriptor; import io.quarkus.gizmo.ResultHandle; @@ -35,15 +38,17 @@ public class Callback { public final Target target; public final String endpointPath; public final AnnotationInstance annotation; + public final BeanInfo bean; public final MethodInfo method; public final ExecutionModel executionModel; public final MessageType messageType; public final List arguments; - public Callback(Target target, AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - String endpointPath, IndexView index) { + public Callback(Target target, AnnotationInstance annotation, BeanInfo bean, MethodInfo method, + ExecutionModel executionModel, CallbackArgumentsBuildItem callbackArguments, + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, IndexView index) { this.target = target; + this.bean = bean; this.method = method; this.annotation = annotation; this.executionModel = executionModel; @@ -104,6 +109,15 @@ public boolean isReturnTypeMulti() { return WebSocketDotNames.MULTI.equals(returnType().name()); } + public boolean isKotlinSuspendFunction() { + return KotlinUtils.isKotlinSuspendMethod(method); + } + + public boolean isKotlinSuspendFunctionReturningUnit() { + return KotlinUtils.isKotlinSuspendMethod(method) + && KotlinUtils.getKotlinSuspendMethodResult(method).name().equals(KotlinDotNames.UNIT); + } + public boolean acceptsMessage() { return messageType != MessageType.NONE; } diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/KotlinContinuationCallbackArgument.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/KotlinContinuationCallbackArgument.java new file mode 100644 index 00000000000000..fa6eda14b4f14c --- /dev/null +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/KotlinContinuationCallbackArgument.java @@ -0,0 +1,17 @@ +package io.quarkus.websockets.next.deployment; + +import io.quarkus.arc.processor.KotlinUtils; +import io.quarkus.gizmo.ResultHandle; + +public class KotlinContinuationCallbackArgument implements CallbackArgument { + @Override + public boolean matches(ParameterContext context) { + return KotlinUtils.isKotlinContinuationParameter(context.parameter()); + } + + @Override + public ResultHandle get(InvocationBytecodeContext context) { + // the actual value is provided by the invoker + return context.bytecode().loadNull(); + } +} diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index 4ed22a99856039..0b3d23dfe7627b 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -17,18 +17,21 @@ import java.util.stream.Collectors; import jakarta.enterprise.context.SessionScoped; +import jakarta.enterprise.invoke.Invoker; import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.AnnotationTransformation; import org.jboss.jandex.AnnotationValue; import org.jboss.jandex.ClassInfo; import org.jboss.jandex.ClassInfo.NestingType; +import org.jboss.jandex.ClassType; import org.jboss.jandex.DotName; import org.jboss.jandex.IndexView; import org.jboss.jandex.MethodInfo; import org.jboss.jandex.PrimitiveType; import org.jboss.jandex.Type; import org.jboss.jandex.Type.Kind; +import org.objectweb.asm.Opcodes; import io.quarkus.arc.deployment.AdditionalBeanBuildItem; import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem; @@ -39,6 +42,7 @@ import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem; import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem.ContextConfiguratorBuildItem; import io.quarkus.arc.deployment.CustomScopeBuildItem; +import io.quarkus.arc.deployment.InvokerFactoryBuildItem; import io.quarkus.arc.deployment.SyntheticBeanBuildItem; import io.quarkus.arc.deployment.SyntheticBeansRuntimeInitBuildItem; import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem; @@ -50,7 +54,11 @@ import io.quarkus.arc.processor.BuiltinScope; import io.quarkus.arc.processor.DotNames; import io.quarkus.arc.processor.InjectionPointInfo; +import io.quarkus.arc.processor.InvokerInfo; +import io.quarkus.arc.processor.KotlinDotNames; +import io.quarkus.arc.processor.KotlinUtils; import io.quarkus.arc.processor.Types; +import io.quarkus.bootstrap.classloading.QuarkusClassLoader; import io.quarkus.deployment.Capabilities; import io.quarkus.deployment.Capability; import io.quarkus.deployment.GeneratedClassGizmoAdaptor; @@ -68,6 +76,7 @@ import io.quarkus.gizmo.CatchBlockCreator; import io.quarkus.gizmo.ClassCreator; import io.quarkus.gizmo.ClassOutput; +import io.quarkus.gizmo.FieldCreator; import io.quarkus.gizmo.FunctionCreator; import io.quarkus.gizmo.MethodCreator; import io.quarkus.gizmo.MethodDescriptor; @@ -108,6 +117,8 @@ import io.quarkus.websockets.next.runtime.WebSocketHttpServerOptionsCustomizer; import io.quarkus.websockets.next.runtime.WebSocketServerRecorder; import io.quarkus.websockets.next.runtime.WebSocketSessionContext; +import io.quarkus.websockets.next.runtime.kotlin.ApplicationCoroutineScope; +import io.quarkus.websockets.next.runtime.kotlin.CoroutineInvoker; import io.smallrye.mutiny.Multi; import io.smallrye.mutiny.Uni; import io.smallrye.mutiny.groups.UniCreate; @@ -190,6 +201,18 @@ void additionalBeans(CombinedIndexBuildItem combinedIndex, BuildProducer additionalBean) { + if (!QuarkusClassLoader.isClassPresentAtRuntime("kotlinx.coroutines.CoroutineScope")) { + return; + } + + additionalBean.produce(AdditionalBeanBuildItem.builder() + .addBeanClass(ApplicationCoroutineScope.class) + .setUnremovable() + .build()); + } + @BuildStep ContextConfiguratorBuildItem registerSessionContext(ContextRegistrationPhaseBuildItem phase) { return new ContextConfiguratorBuildItem(phase.getContext() @@ -211,6 +234,7 @@ void builtinCallbackArguments(BuildProducer providers providers.produce(new CallbackArgumentBuildItem(new HandshakeRequestCallbackArgument())); providers.produce(new CallbackArgumentBuildItem(new ErrorCallbackArgument())); providers.produce(new CallbackArgumentBuildItem(new CloseReasonCallbackArgument())); + providers.produce(new CallbackArgumentBuildItem(new KotlinContinuationCallbackArgument())); } @BuildStep @@ -228,7 +252,7 @@ void collectGlobalErrorHandlers(BeanArchiveIndexBuildItem beanArchiveIndex, ClassInfo beanClass = bean.getTarget().get().asClass(); if (beanClass.declaredAnnotation(WebSocketDotNames.WEB_SOCKET) == null && beanClass.declaredAnnotation(WebSocketDotNames.WEB_SOCKET_CLIENT) == null) { - for (Callback callback : findErrorHandlers(Target.UNDEFINED, index, beanClass, callbackArguments, + for (Callback callback : findErrorHandlers(Target.UNDEFINED, index, bean, beanClass, callbackArguments, transformedAnnotations, null)) { GlobalErrorHandler errorHandler = new GlobalErrorHandler(bean, callback); DotName errorTypeName = callback.argumentType(ErrorCallbackArgument::isError).name(); @@ -327,19 +351,17 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, inboundProcessingMode = webSocketClientAnnotation.value("inboundProcessingMode"); } - Callback onOpen = findCallback(target, beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_OPEN, - callbackArguments, transformedAnnotations, path); - Callback onTextMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass, - WebSocketDotNames.ON_TEXT_MESSAGE, - callbackArguments, transformedAnnotations, path); - Callback onBinaryMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass, + Callback onOpen = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass, + WebSocketDotNames.ON_OPEN, callbackArguments, transformedAnnotations, path); + Callback onTextMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass, + WebSocketDotNames.ON_TEXT_MESSAGE, callbackArguments, transformedAnnotations, path); + Callback onBinaryMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass, WebSocketDotNames.ON_BINARY_MESSAGE, callbackArguments, transformedAnnotations, path); - Callback onPongMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass, - WebSocketDotNames.ON_PONG_MESSAGE, - callbackArguments, transformedAnnotations, path, + Callback onPongMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass, + WebSocketDotNames.ON_PONG_MESSAGE, callbackArguments, transformedAnnotations, path, this::validateOnPongMessage); - Callback onClose = findCallback(target, beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_CLOSE, - callbackArguments, transformedAnnotations, path, + Callback onClose = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass, + WebSocketDotNames.ON_CLOSE, callbackArguments, transformedAnnotations, path, this::validateOnClose); if (onOpen == null && onTextMessage == null && onBinaryMessage == null && onPongMessage == null) { throw new WebSocketServerException( @@ -354,7 +376,7 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex, onBinaryMessage, onPongMessage, onClose, - findErrorHandlers(target, index, beanClass, callbackArguments, transformedAnnotations, path))); + findErrorHandlers(target, index, bean, beanClass, callbackArguments, transformedAnnotations, path))); } } @@ -383,6 +405,7 @@ public void generateEndpoints(BeanArchiveIndexBuildItem index, List generatedClasses, BuildProducer generatedEndpoints, BuildProducer reflectiveClasses) { @@ -409,7 +432,8 @@ public String apply(String name) { // and delegates callback invocations to the endpoint bean String generatedName = generateEndpoint(endpoint, argumentProviders, transformedAnnotations, index.getIndex(), classOutput, globalErrorHandlers, - endpoint.isClient() ? CLIENT_ENDPOINT_SUFFIX : SERVER_ENDPOINT_SUFFIX); + endpoint.isClient() ? CLIENT_ENDPOINT_SUFFIX : SERVER_ENDPOINT_SUFFIX, + invokerFactory); reflectiveClasses.produce(ReflectiveClassBuildItem.builder(generatedName).constructors().build()); generatedEndpoints .produce(new GeneratedEndpointBuildItem(endpoint.id, endpoint.bean.getImplClazz().name().toString(), @@ -626,9 +650,16 @@ static String getPathPrefix(IndexView index, DotName enclosingClassName) { } private void validateOnPongMessage(Callback callback) { - if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) { - throw new WebSocketServerException( - "@OnPongMessage callback must return void or Uni: " + callback.asString()); + if (KotlinUtils.isKotlinMethod(callback.method)) { + if (!callback.isReturnTypeVoid() && !callback.isKotlinSuspendFunctionReturningUnit()) { + throw new WebSocketServerException( + "@OnPongMessage callback must return Unit: " + callback.asString()); + } + } else { + if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) { + throw new WebSocketServerException( + "@OnPongMessage callback must return void or Uni: " + callback.asString()); + } } Type messageType = callback.argumentType(MessageCallbackArgument::isMessage); if (messageType == null || !messageType.name().equals(WebSocketDotNames.BUFFER)) { @@ -639,9 +670,16 @@ private void validateOnPongMessage(Callback callback) { } private void validateOnClose(Callback callback) { - if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) { - throw new WebSocketServerException( - "@OnClose callback must return void or Uni: " + callback.asString()); + if (KotlinUtils.isKotlinMethod(callback.method)) { + if (!callback.isReturnTypeVoid() && !callback.isKotlinSuspendFunctionReturningUnit()) { + throw new WebSocketServerException( + "@OnClose callback must return Unit: " + callback.asString()); + } + } else { + if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) { + throw new WebSocketServerException( + "@OnClose callback must return void or Uni: " + callback.asString()); + } } } @@ -710,7 +748,8 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, IndexView index, ClassOutput classOutput, GlobalErrorHandlersBuildItem globalErrorHandlers, - String endpointSuffix) { + String endpointSuffix, + InvokerFactoryBuildItem invokerFactory) { ClassInfo implClazz = endpoint.bean.getImplClazz(); String baseName; if (implClazz.enclosingClass() != null) { @@ -733,7 +772,6 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, Codecs.class, ContextSupport.class, SecuritySupport.class), constructor.getThis(), constructor.getMethodParam(0), constructor.getMethodParam(1), constructor.getMethodParam(2), constructor.getMethodParam(3)); - constructor.returnNull(); MethodCreator inboundProcessingMode = endpointCreator.getMethodCreator("inboundProcessingMode", InboundProcessingMode.class); @@ -751,7 +789,8 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, // Call the business method TryBlock tryBlock = onErrorTryBlock(doOnOpen, doOnOpen.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, "Open", tryBlock, + beanInstance, args, invokerFactory); encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret); MethodCreator onOpenExecutionModel = endpointCreator.getMethodCreator("onOpenExecutionModel", @@ -759,12 +798,12 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, onOpenExecutionModel.returnValue(onOpenExecutionModel.load(callback.executionModel)); } - generateOnMessage(endpointCreator, endpoint, endpoint.onBinaryMessage, argumentProviders, transformedAnnotations, - index, globalErrorHandlers); - generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage, argumentProviders, transformedAnnotations, index, - globalErrorHandlers); - generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage, argumentProviders, transformedAnnotations, index, - globalErrorHandlers); + generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onBinaryMessage, argumentProviders, + transformedAnnotations, index, globalErrorHandlers, invokerFactory); + generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onTextMessage, argumentProviders, + transformedAnnotations, index, globalErrorHandlers, invokerFactory); + generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onPongMessage, argumentProviders, + transformedAnnotations, index, globalErrorHandlers, invokerFactory); if (endpoint.onClose != null) { Callback callback = endpoint.onClose; @@ -775,7 +814,8 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, // Call the business method TryBlock tryBlock = onErrorTryBlock(doOnClose, doOnClose.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, "Close", tryBlock, + beanInstance, args, invokerFactory); encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret); MethodCreator onCloseExecutionModel = endpointCreator.getMethodCreator("onCloseExecutionModel", @@ -783,15 +823,19 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint, onCloseExecutionModel.returnValue(onCloseExecutionModel.load(callback.executionModel)); } - generateOnError(endpointCreator, endpoint, argumentProviders, transformedAnnotations, globalErrorHandlers, index); + generateOnError(endpointCreator, constructor, endpoint, transformedAnnotations, globalErrorHandlers, index, + invokerFactory); + + // we write into the constructor when generating callback invokers, so need to finish it late + constructor.returnVoid(); endpointCreator.close(); return generatedName.replace('/', '.'); } - private static void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - GlobalErrorHandlersBuildItem globalErrorHandlers, IndexView index) { + private static void generateOnError(ClassCreator endpointCreator, MethodCreator constructor, + WebSocketEndpointBuildItem endpoint, TransformedAnnotationsBuildItem transformedAnnotations, + GlobalErrorHandlersBuildItem globalErrorHandlers, IndexView index, InvokerFactoryBuildItem invokerFactory) { Map errors = new HashMap<>(); List throwableInfos = new ArrayList<>(); @@ -841,7 +885,8 @@ private static void generateOnError(ClassCreator endpointCreator, WebSocketEndpo MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class), endpointThis, funBytecode.load(throwableInfo.bean().getIdentifier())); ResultHandle[] args = callback.generateArguments(endpointThis, tryBlock, transformedAnnotations, index); - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, "Error", tryBlock, + beanInstance, args, invokerFactory); encodeAndReturnResult(endpointThis, tryBlock, callback, globalErrorHandlers, endpoint, ret); // return doErrorExecute() @@ -891,9 +936,10 @@ record GlobalErrorHandler(BeanInfo bean, Callback callback) { } - private static void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback, + private static void generateOnMessage(ClassCreator endpointCreator, MethodCreator constructor, + WebSocketEndpointBuildItem endpoint, Callback callback, CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - IndexView index, GlobalErrorHandlersBuildItem globalErrorHandlers) { + IndexView index, GlobalErrorHandlersBuildItem globalErrorHandlers, InvokerFactoryBuildItem invokerFactory) { if (callback == null) { return; } @@ -924,8 +970,8 @@ private static void generateOnMessage(ClassCreator endpointCreator, WebSocketEnd MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class), tryBlock.getThis()); ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index); // Call the business method - ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, - args); + ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, messageType, tryBlock, beanInstance, args, + invokerFactory); encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret); MethodCreator onMessageExecutionModel = endpointCreator.getMethodCreator("on" + messageType + "MessageExecutionModel", @@ -946,6 +992,30 @@ private static void generateOnMessage(ClassCreator endpointCreator, WebSocketEnd } } + private static ResultHandle callBusinessMethod(ClassCreator clazz, MethodCreator constructor, Callback callback, + String messageType, BytecodeCreator bytecode, ResultHandle beanInstance, ResultHandle[] args, + InvokerFactoryBuildItem invokerFactory) { + if (KotlinUtils.isKotlinSuspendMethod(callback.method)) { + InvokerInfo invoker = invokerFactory.createInvoker(callback.bean, callback.method) + .withInvocationWrapper(CoroutineInvoker.class, "inNewCoroutine") + .build(); + FieldCreator invokerField = clazz.getFieldCreator("invokerFor" + messageType, Invoker.class) + .setModifiers(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL); + constructor.writeInstanceField(invokerField.getFieldDescriptor(), constructor.getThis(), + constructor.newInstance(MethodDescriptor.ofConstructor(invoker.getClassName()))); + ResultHandle invokerHandle = bytecode.readInstanceField(invokerField.getFieldDescriptor(), bytecode.getThis()); + ResultHandle argsArray = bytecode.newArray(Object.class, args.length); + for (int i = 0; i < args.length; i++) { + bytecode.writeArrayValue(argsArray, i, args[i]); + } + return bytecode.invokeInterfaceMethod( + MethodDescriptor.ofMethod(Invoker.class, "invoke", Object.class, Object.class, Object[].class), + invokerHandle, beanInstance, argsArray); + } else { + return bytecode.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args); + } + } + private static TryBlock uniFailureTryBlock(BytecodeCreator method) { TryBlock tryBlock = method.tryBlock(); CatchBlockCreator catchBlock = tryBlock.addCatch(Throwable.class); @@ -1070,8 +1140,15 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre // ---------------------- // === Binary message === // ---------------------- - if (callback.isReturnTypeUni()) { - Type messageType = callback.returnType().asParameterizedType().arguments().get(0); + if (callback.isReturnTypeUni() || callback.isKotlinSuspendFunction()) { + Type messageType = callback.isReturnTypeUni() + ? callback.returnType().asParameterizedType().arguments().get(0) + : KotlinUtils.getKotlinSuspendMethodResult(callback.method); + if (messageType.name().equals(KotlinDotNames.UNIT)) { + value = method.invokeInterfaceMethod(MethodDescriptor.ofMethod(Uni.class, "replaceWithVoid", Uni.class), + value); + messageType = ClassType.create(WebSocketDotNames.VOID); + } if (messageType.name().equals(WebSocketDotNames.VOID)) { // Uni return uniOnFailureDoOnError(endpointThis, method, callback, value, endpoint, globalErrorHandlers); @@ -1082,8 +1159,7 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre // }); FunctionCreator fun = method.createFunction(Function.class); BytecodeCreator funBytecode = fun.getBytecode(); - ResultHandle buffer = encodeBuffer(funBytecode, - callback.returnType().asParameterizedType().arguments().get(0), + ResultHandle buffer = encodeBuffer(funBytecode, messageType, funBytecode.getMethodParam(0), endpointThis, callback); funBytecode.returnValue(funBytecode.invokeVirtualMethod( MethodDescriptor.ofMethod(WebSocketEndpointBase.class, @@ -1130,8 +1206,15 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre // ---------------------- // === Text message === // ---------------------- - if (callback.isReturnTypeUni()) { - Type messageType = callback.returnType().asParameterizedType().arguments().get(0); + if (callback.isReturnTypeUni() || callback.isKotlinSuspendFunction()) { + Type messageType = callback.isReturnTypeUni() + ? callback.returnType().asParameterizedType().arguments().get(0) + : KotlinUtils.getKotlinSuspendMethodResult(callback.method); + if (messageType.name().equals(KotlinDotNames.UNIT)) { + value = method.invokeInterfaceMethod(MethodDescriptor.ofMethod(Uni.class, "replaceWithVoid", Uni.class), + value); + messageType = ClassType.create(WebSocketDotNames.VOID); + } if (messageType.name().equals(WebSocketDotNames.VOID)) { // Uni return uniOnFailureDoOnError(endpointThis, method, callback, value, endpoint, globalErrorHandlers); @@ -1142,7 +1225,7 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre // }); FunctionCreator fun = method.createFunction(Function.class); BytecodeCreator funBytecode = fun.getBytecode(); - ResultHandle text = encodeText(funBytecode, callback.returnType().asParameterizedType().arguments().get(0), + ResultHandle text = encodeText(funBytecode, messageType, funBytecode.getMethodParam(0), endpointThis, callback); funBytecode.returnValue(funBytecode.invokeVirtualMethod( MethodDescriptor.ofMethod(WebSocketEndpointBase.class, @@ -1267,9 +1350,8 @@ private static void encodeAndReturnResult(ResultHandle endpointThis, BytecodeCre } } - static List findErrorHandlers(Target target, IndexView index, ClassInfo beanClass, - CallbackArgumentsBuildItem callbackArguments, - TransformedAnnotationsBuildItem transformedAnnotations, + static List findErrorHandlers(Target target, IndexView index, BeanInfo bean, ClassInfo beanClass, + CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { List annotations = findCallbackAnnotations(index, beanClass, WebSocketDotNames.ON_ERROR); if (annotations.isEmpty()) { @@ -1285,8 +1367,9 @@ static List findErrorHandlers(Target target, IndexView index, ClassInf .anyMatch(WebSocketDotNames.WEB_SOCKET_CLIENT_CONNECTION::equals)) { target = Target.CLIENT; } - Callback callback = new Callback(target, annotation, method, executionModel(method, transformedAnnotations), - callbackArguments, transformedAnnotations, endpointPath, index); + Callback callback = new Callback(target, annotation, bean, method, + executionModel(method, transformedAnnotations), callbackArguments, transformedAnnotations, + endpointPath, index); long errorArguments = callback.arguments.stream().filter(ca -> ca instanceof ErrorCallbackArgument).count(); if (errorArguments != 1) { throw new WebSocketException( @@ -1316,16 +1399,16 @@ private static List findCallbackAnnotations(IndexView index, return annotations; } - static Callback findCallback(Target target, IndexView index, ClassInfo beanClass, DotName annotationName, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - String endpointPath) { - return findCallback(target, index, beanClass, annotationName, callbackArguments, transformedAnnotations, endpointPath, - null); + static Callback findCallback(Target target, IndexView index, BeanInfo bean, ClassInfo beanClass, + DotName annotationName, CallbackArgumentsBuildItem callbackArguments, + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) { + return findCallback(target, index, bean, beanClass, annotationName, callbackArguments, + transformedAnnotations, endpointPath, null); } - private static Callback findCallback(Target target, IndexView index, ClassInfo beanClass, DotName annotationName, - CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations, - String endpointPath, + private static Callback findCallback(Target target, IndexView index, BeanInfo bean, ClassInfo beanClass, + DotName annotationName, CallbackArgumentsBuildItem callbackArguments, + TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, Consumer validator) { List annotations = findCallbackAnnotations(index, beanClass, annotationName); if (annotations.isEmpty()) { @@ -1333,9 +1416,9 @@ private static Callback findCallback(Target target, IndexView index, ClassInfo b } else if (annotations.size() == 1) { AnnotationInstance annotation = annotations.get(0); MethodInfo method = annotation.target().asMethod(); - Callback callback = new Callback(target, annotation, method, executionModel(method, transformedAnnotations), - callbackArguments, - transformedAnnotations, endpointPath, index); + Callback callback = new Callback(target, annotation, bean, method, + executionModel(method, transformedAnnotations), callbackArguments, transformedAnnotations, + endpointPath, index); long messageArguments = callback.arguments.stream().filter(ca -> ca instanceof MessageCallbackArgument).count(); if (callback.acceptsMessage()) { if (messageArguments > 1) { @@ -1382,6 +1465,10 @@ private static ExecutionModel executionModel(MethodInfo method, TransformedAnnot } static boolean hasBlockingSignature(MethodInfo method) { + if (KotlinUtils.isKotlinSuspendMethod(method)) { + return false; + } + switch (method.returnType().kind()) { case VOID: case CLASS: @@ -1414,6 +1501,8 @@ private static boolean isOnOpenWithBinaryReturnType(Callback callback) { Type returnType = callback.returnType(); if (callback.isReturnTypeUni() || callback.isReturnTypeMulti()) { returnType = callback.returnType().asParameterizedType().arguments().get(0); + } else if (callback.isKotlinSuspendFunction()) { + returnType = KotlinUtils.getKotlinSuspendMethodResult(callback.method); } return WebSocketDotNames.BUFFER.equals(returnType.name()) || (returnType.kind() == Kind.ARRAY && PrimitiveType.BYTE.equals(returnType.asArrayType().constituent())); diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java index 2f1974089db901..926c0d1b82d153 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java @@ -144,6 +144,13 @@ public Buffer sendAndAwaitReply(String message) { return messages.get(c); } + public Buffer sendAndAwaitReply(Buffer message) { + var c = messages.size(); + sendAndAwait(message); + Awaitility.await().until(() -> messages.size() > c); + return messages.get(c); + } + public boolean isClosed() { return socket.get().isClosed(); } diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEcho.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEcho.kt new file mode 100644 index 00000000000000..a42359791fd386 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEcho.kt @@ -0,0 +1,14 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.websockets.next.OnBinaryMessage +import io.quarkus.websockets.next.WebSocket +import io.vertx.core.buffer.Buffer +import kotlinx.coroutines.delay + +@WebSocket(path = "/binary-echo") +class BinaryEcho { + @OnBinaryMessage + fun process(msg: Buffer): Buffer { + return msg + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEchoSuspend.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEchoSuspend.kt new file mode 100644 index 00000000000000..a2182d7607535e --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEchoSuspend.kt @@ -0,0 +1,15 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.websockets.next.OnBinaryMessage +import io.quarkus.websockets.next.WebSocket +import io.vertx.core.buffer.Buffer +import kotlinx.coroutines.delay + +@WebSocket(path = "/binary-echo-suspend") +class BinaryEchoSuspend { + @OnBinaryMessage + suspend fun process(msg: Buffer): Buffer { + delay(100) + return msg + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Echo.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Echo.kt new file mode 100644 index 00000000000000..fc60701b4991be --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Echo.kt @@ -0,0 +1,12 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.websockets.next.OnTextMessage +import io.quarkus.websockets.next.WebSocket + +@WebSocket(path = "/echo") +class Echo { + @OnTextMessage + fun process(msg: Message): Message { + return msg + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/EchoSuspend.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/EchoSuspend.kt new file mode 100644 index 00000000000000..92d27b2d9d85bb --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/EchoSuspend.kt @@ -0,0 +1,14 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.websockets.next.OnTextMessage +import io.quarkus.websockets.next.WebSocket +import kotlinx.coroutines.delay + +@WebSocket(path = "/echo-suspend") +class EchoSuspend { + @OnTextMessage + suspend fun process(msg: Message): Message { + delay(100) + return msg + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketClientTest.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketClientTest.kt new file mode 100644 index 00000000000000..9ec2cd00baf2c3 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketClientTest.kt @@ -0,0 +1,93 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.test.QuarkusUnitTest +import io.quarkus.test.common.http.TestHTTPResource +import io.quarkus.websockets.next.OnClose +import io.quarkus.websockets.next.OnOpen +import io.quarkus.websockets.next.OnTextMessage +import io.quarkus.websockets.next.WebSocket +import io.quarkus.websockets.next.WebSocketClient +import io.quarkus.websockets.next.WebSocketConnector +import jakarta.inject.Inject +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension +import java.net.URI +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit + +class KotlinWebSocketClientTest { + companion object { + @RegisterExtension + val test = QuarkusUnitTest() + .withApplicationRoot { jar -> + jar.addClasses(ServerEndpoint::class.java, ClientEndpoint::class.java) + } + } + + @Inject + lateinit var connector: WebSocketConnector + + @TestHTTPResource("/") + lateinit var uri: URI + + @Test + fun test() { + val connection = connector.baseUri(uri).connectAndAwait() + connection.sendTextAndAwait("Hi!") + + assertTrue(ClientEndpoint.messagesLatch.await(5, TimeUnit.SECONDS)) + assertEquals("Hello there", ClientEndpoint.messages[0]) + assertEquals("Hi!", ClientEndpoint.messages[1]) + + connection.closeAndAwait() + assertTrue(ClientEndpoint.closedLatch.await(5, TimeUnit.SECONDS)) + assertTrue(ServerEndpoint.closedLatch.await(5, TimeUnit.SECONDS)) + } + + @WebSocket(path = "/endpoint") + class ServerEndpoint { + companion object { + val closedLatch: CountDownLatch = CountDownLatch(1) + } + + @OnOpen + fun open(): String { + return "Hello there" + } + + @OnTextMessage + fun echo(message: String): String { + return message + } + + @OnClose + fun close() { + closedLatch.countDown() + } + } + + @WebSocketClient(path = "/endpoint") + class ClientEndpoint { + companion object { + val messages: MutableList = CopyOnWriteArrayList() + + val messagesLatch: CountDownLatch = CountDownLatch(2) + + val closedLatch: CountDownLatch = CountDownLatch(1) + } + + @OnTextMessage + fun onMessage(message: String) { + messages.add(message) + messagesLatch.countDown() + } + + @OnClose + fun onClose() { + closedLatch.countDown() + } + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSuspendingClientTest.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSuspendingClientTest.kt new file mode 100644 index 00000000000000..9c4a8fdc9051a9 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSuspendingClientTest.kt @@ -0,0 +1,99 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.test.QuarkusUnitTest +import io.quarkus.test.common.http.TestHTTPResource +import io.quarkus.websockets.next.OnClose +import io.quarkus.websockets.next.OnOpen +import io.quarkus.websockets.next.OnTextMessage +import io.quarkus.websockets.next.WebSocket +import io.quarkus.websockets.next.WebSocketClient +import io.quarkus.websockets.next.WebSocketConnector +import jakarta.inject.Inject +import kotlinx.coroutines.delay +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension +import java.net.URI +import java.util.concurrent.CopyOnWriteArrayList +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit + +class KotlinWebSocketSuspendingClientTest { + companion object { + @RegisterExtension + val test = QuarkusUnitTest() + .withApplicationRoot { jar -> + jar.addClasses(ServerEndpoint::class.java, ClientEndpoint::class.java) + } + } + + @Inject + lateinit var connector: WebSocketConnector + + @TestHTTPResource("/") + lateinit var uri: URI + + @Test + fun test() { + val connection = connector.baseUri(uri).connectAndAwait() + connection.sendTextAndAwait("Hi!") + + assertTrue(ClientEndpoint.messagesLatch.await(5, TimeUnit.SECONDS)) + assertEquals("Hello there", ClientEndpoint.messages[0]) + assertEquals("Hi!", ClientEndpoint.messages[1]) + + connection.closeAndAwait() + assertTrue(ClientEndpoint.closedLatch.await(5, TimeUnit.SECONDS)) + assertTrue(ServerEndpoint.closedLatch.await(5, TimeUnit.SECONDS)) + } + + @WebSocket(path = "/endpoint") + class ServerEndpoint { + companion object { + val closedLatch: CountDownLatch = CountDownLatch(1) + } + + @OnOpen + suspend fun open(): String { + delay(100) + return "Hello there" + } + + @OnTextMessage + suspend fun echo(message: String): String { + delay(100) + return message + } + + @OnClose + suspend fun close() { + delay(100) + closedLatch.countDown() + } + } + + @WebSocketClient(path = "/endpoint") + class ClientEndpoint { + companion object { + val messages: MutableList = CopyOnWriteArrayList() + + val messagesLatch: CountDownLatch = CountDownLatch(2) + + val closedLatch: CountDownLatch = CountDownLatch(1) + } + + @OnTextMessage + suspend fun onMessage(message: String) { + delay(100) + messages.add(message) + messagesLatch.countDown() + } + + @OnClose + suspend fun onClose() { + delay(100) + closedLatch.countDown() + } + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketTest.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketTest.kt new file mode 100644 index 00000000000000..12dd973ee9848e --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketTest.kt @@ -0,0 +1,75 @@ +package io.quarkus.websockets.next.test.kotlin + +import io.quarkus.test.QuarkusUnitTest +import io.quarkus.test.common.http.TestHTTPResource +import io.quarkus.websockets.next.test.utils.WSClient +import io.vertx.core.Vertx +import io.vertx.core.buffer.Buffer +import io.vertx.core.json.JsonObject +import jakarta.inject.Inject +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension +import java.net.URI + +class KotlinWebSocketTest { + companion object { + @RegisterExtension + val test = QuarkusUnitTest() + .withApplicationRoot { jar -> + jar.addClasses(Echo::class.java, EchoSuspend::class.java, BinaryEcho::class.java, + BinaryEchoSuspend::class.java, Message::class.java, WSClient::class.java) + } + } + + @Inject + lateinit var vertx: Vertx + + @TestHTTPResource("echo") + lateinit var echo: URI + + @TestHTTPResource("echo-suspend") + lateinit var echoSuspend: URI + + @TestHTTPResource("binary-echo") + lateinit var binaryEcho: URI + + @TestHTTPResource("binary-echo-suspend") + lateinit var binaryEchoSuspend: URI + + @Test + fun testEcho() { + doTest(echo) + } + + @Test + fun testEchoSuspend() { + doTest(echoSuspend) + } + + private fun doTest(uri: URI) { + WSClient.create(vertx).connect(uri).use { client -> + val req = JsonObject().put("msg", "hello") + val resp = client.sendAndAwaitReply(req.toString()).toJsonObject() + assertEquals(req, resp) + } + } + + @Test + fun testBinaryEcho() { + doTestBinary(binaryEcho) + } + + @Test + fun testBinaryEchoSuspend() { + doTestBinary(binaryEchoSuspend) + } + + private fun doTestBinary(uri: URI) { + WSClient.create(vertx).connect(uri).use { client -> + val req = Buffer.buffer("hello there!") + val resp = client.sendAndAwaitReply(req) + assertEquals(req, resp) + } + } +} diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Message.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Message.kt new file mode 100644 index 00000000000000..96563b27e7382d --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Message.kt @@ -0,0 +1,5 @@ +package io.quarkus.websockets.next.test.kotlin + +import com.fasterxml.jackson.annotation.JsonCreator + +data class Message @JsonCreator constructor(var msg: String) diff --git a/extensions/websockets-next/kotlin/pom.xml b/extensions/websockets-next/kotlin/pom.xml new file mode 100644 index 00000000000000..bda3816be3a85e --- /dev/null +++ b/extensions/websockets-next/kotlin/pom.xml @@ -0,0 +1,115 @@ + + + + io.quarkus + quarkus-websockets-next-parent + 999-SNAPSHOT + + 4.0.0 + + quarkus-websockets-next-kotlin + Quarkus - WebSockets Next - Kotlin + + + + io.quarkus + quarkus-arc + + + io.quarkus + quarkus-vertx + + + org.jetbrains.kotlin + kotlin-stdlib-jdk8 + true + + + org.jetbrains.kotlinx + kotlinx-coroutines-jdk8 + true + + + io.smallrye.reactive + mutiny-kotlin + true + + + + + ${project.basedir}/src/main/kotlin + ${project.basedir}/src/test/kotlin + + + org.jetbrains.kotlin + kotlin-maven-plugin + ${kotlin.version} + + + compile + + compile + + + + test-compile + + test-compile + + + + + ${maven.compiler.target} + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + io.quarkus + quarkus-extension-processor + ${project.version} + + + + + + + default-compile + none + + + + default-testCompile + none + + + java-compile + compile + + compile + + + + java-test-compile + test-compile + + testCompile + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + false + + + + + diff --git a/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/ApplicationCoroutineScope.kt b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/ApplicationCoroutineScope.kt new file mode 100644 index 00000000000000..78ee39b9247065 --- /dev/null +++ b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/ApplicationCoroutineScope.kt @@ -0,0 +1,18 @@ +package io.quarkus.websockets.next.runtime.kotlin + +import jakarta.annotation.PreDestroy +import jakarta.inject.Singleton +import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel + +@Singleton +class ApplicationCoroutineScope : CoroutineScope, AutoCloseable { + override val coroutineContext: CoroutineContext = SupervisorJob() + + @PreDestroy + override fun close() { + coroutineContext.cancel() + } +} diff --git a/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/CoroutineInvoker.kt b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/CoroutineInvoker.kt new file mode 100644 index 00000000000000..5a6b6f8e28a772 --- /dev/null +++ b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/CoroutineInvoker.kt @@ -0,0 +1,37 @@ +package io.quarkus.websockets.next.runtime.kotlin + +import io.quarkus.arc.Arc +import io.smallrye.mutiny.Uni +import io.smallrye.mutiny.coroutines.asUni +import io.vertx.core.Vertx +import jakarta.enterprise.invoke.Invoker +import kotlin.coroutines.resume +import kotlin.coroutines.resumeWithException +import kotlin.coroutines.suspendCoroutine +import kotlinx.coroutines.CoroutineDispatcher +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.async + +object CoroutineInvoker { + @JvmStatic + @OptIn(ExperimentalCoroutinesApi::class) + fun inNewCoroutine(instance: T, arguments: Array, invoker: Invoker): Uni { + val coroutineScope = Arc.container().instance(ApplicationCoroutineScope::class.java).get() + val dispatcher: CoroutineDispatcher = + Vertx.currentContext()?.let(::VertxDispatcher) + ?: throw IllegalStateException("No Vertx context found") + + return coroutineScope + .async(context = dispatcher) { + suspendCoroutine { continuation -> + arguments[arguments.size - 1] = continuation + try { + continuation.resume(invoker.invoke(instance, arguments)) + } catch (e: Exception) { + continuation.resumeWithException(e) + } + } + } + .asUni() + } +} diff --git a/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/VertxDispatcher.kt b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/VertxDispatcher.kt new file mode 100644 index 00000000000000..1377b1835fe6c2 --- /dev/null +++ b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/VertxDispatcher.kt @@ -0,0 +1,24 @@ +package io.quarkus.websockets.next.runtime.kotlin + +import io.quarkus.arc.Arc +import io.vertx.core.Context +import kotlin.coroutines.CoroutineContext +import kotlinx.coroutines.CoroutineDispatcher + +class VertxDispatcher(private val vertxContext: Context) : CoroutineDispatcher() { + override fun dispatch(context: CoroutineContext, block: Runnable) { + val requestContext = Arc.container().requestContext() + vertxContext.runOnContext { + if (requestContext.isActive) { + block.run() + } else { + try { + requestContext.activate() + block.run() + } finally { + requestContext.terminate() + } + } + } + } +} diff --git a/extensions/websockets-next/pom.xml b/extensions/websockets-next/pom.xml index 5ff9318b3e7658..1e149d244734fa 100644 --- a/extensions/websockets-next/pom.xml +++ b/extensions/websockets-next/pom.xml @@ -16,6 +16,7 @@ deployment runtime + kotlin diff --git a/extensions/websockets-next/runtime/pom.xml b/extensions/websockets-next/runtime/pom.xml index a72d7832edd42f..de16a0cfaacaa1 100644 --- a/extensions/websockets-next/runtime/pom.xml +++ b/extensions/websockets-next/runtime/pom.xml @@ -34,6 +34,10 @@ io.quarkus quarkus-tls-registry + + io.quarkus + quarkus-websockets-next-kotlin + io.quarkus.security diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java index cef74e824cf2d9..37ab47604ba2dd 100644 --- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java +++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java @@ -2,8 +2,9 @@ import org.jboss.jandex.DotName; -class KotlinDotNames { - static final DotName METADATA = DotName.createSimple("kotlin.Metadata"); +public final class KotlinDotNames { + public static final DotName METADATA = DotName.createSimple("kotlin.Metadata"); + public static final DotName UNIT = DotName.createSimple("kotlin.Unit"); - static final DotName CONTINUATION = DotName.createSimple("kotlin.coroutines.Continuation"); + public static final DotName CONTINUATION = DotName.createSimple("kotlin.coroutines.Continuation"); } diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java index 1628a83b3adefb..966a7fc2c57abf 100644 --- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java +++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java @@ -4,6 +4,7 @@ import org.jboss.jandex.ClassInfo; import org.jboss.jandex.MethodInfo; +import org.jboss.jandex.MethodParameterInfo; import org.jboss.jandex.Type; public class KotlinUtils { @@ -27,6 +28,10 @@ public static boolean isKotlinSuspendMethod(MethodInfo method) { return KotlinDotNames.CONTINUATION.equals(lastParameter.name()); } + public static boolean isKotlinContinuationParameter(MethodParameterInfo parameter) { + return isKotlinSuspendMethod(parameter.method()) && KotlinDotNames.CONTINUATION.equals(parameter.type().name()); + } + public static boolean isNoninterceptableKotlinMethod(MethodInfo method) { // the Kotlin compiler generates somewhat streamlined bytecode when it determines // that a `suspend` method cannot be overridden, and that bytecode doesn't work @@ -36,4 +41,24 @@ public static boolean isNoninterceptableKotlinMethod(MethodInfo method) { return isKotlinSuspendMethod(method) && (Modifier.isFinal(method.flags()) || Modifier.isFinal(method.declaringClass().flags())); } + + public static Type getKotlinSuspendMethodResult(MethodInfo method) { + if (!isKotlinSuspendMethod(method)) { + throw new IllegalArgumentException("Not a suspend function: " + method); + } + + Type lastParameter = method.parameterType(method.parametersCount() - 1); + if (lastParameter.kind() != Type.Kind.PARAMETERIZED_TYPE) { + throw new IllegalArgumentException("Continuation parameter type not parameterized: " + lastParameter); + } + Type resultType = lastParameter.asParameterizedType().arguments().get(0); + if (resultType.kind() != Type.Kind.WILDCARD_TYPE) { + throw new IllegalArgumentException("Continuation parameter type argument not wildcard: " + resultType); + } + Type lowerBound = resultType.asWildcardType().superBound(); + if (lowerBound == null) { + throw new IllegalArgumentException("Continuation parameter type argument without lower bound: " + resultType); + } + return lowerBound; + } }