diff --git a/bom/application/pom.xml b/bom/application/pom.xml
index 6a4cbeaf71ea30..8ab4c8af1ebf47 100644
--- a/bom/application/pom.xml
+++ b/bom/application/pom.xml
@@ -2166,6 +2166,11 @@
quarkus-websockets-next-deployment
${project.version}
+
+ io.quarkus
+ quarkus-websockets-next-kotlin
+ ${project.version}
+
io.quarkus
quarkus-undertow-spi
diff --git a/docs/src/main/asciidoc/kotlin.adoc b/docs/src/main/asciidoc/kotlin.adoc
index 84e2e61be86fd3..2ff3227ba340aa 100644
--- a/docs/src/main/asciidoc/kotlin.adoc
+++ b/docs/src/main/asciidoc/kotlin.adoc
@@ -504,6 +504,9 @@ The following extensions provide support for Kotlin Coroutines by allowing the u
|`quarkus-vertx`
|Support is provided for `@ConsumeEvent` methods
+|`quarkus-websockets-next`
+|Support is provided for server-side and client-side endpoint methods
+
|===
=== Kotlin coroutines and Mutiny
diff --git a/extensions/websockets-next/deployment/pom.xml b/extensions/websockets-next/deployment/pom.xml
index 03928d9cbb66e2..3b62c8108947ac 100644
--- a/extensions/websockets-next/deployment/pom.xml
+++ b/extensions/websockets-next/deployment/pom.xml
@@ -65,10 +65,56 @@
smallrye-certificate-generator-junit5
test
+
+ org.jetbrains.kotlin
+ kotlin-stdlib
+ test
+
+
+ org.jetbrains.kotlinx
+ kotlinx-coroutines-core
+ test
+
+
+ io.smallrye.reactive
+ mutiny-kotlin
+ test
+
+
+ org.jetbrains.kotlin
+ kotlin-maven-plugin
+ ${kotlin.version}
+
+
+ compile
+
+ compile
+
+
+
+ ${project.basedir}/src/main/kotlin
+ ${project.basedir}/src/main/java
+
+
+
+
+ test-compile
+
+ test-compile
+
+
+
+ ${project.basedir}/src/test/kotlin
+ ${project.basedir}/src/test/java
+
+
+
+
+
maven-compiler-plugin
@@ -80,6 +126,32 @@
+
+
+
+ default-compile
+ none
+
+
+
+ default-testCompile
+ none
+
+
+ java-compile
+ compile
+
+ compile
+
+
+
+ java-test-compile
+ test-compile
+
+ testCompile
+
+
+
maven-surefire-plugin
diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java
index 575230301ae2fb..5c7df46e344b20 100644
--- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java
+++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/Callback.java
@@ -16,7 +16,10 @@
import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem;
import io.quarkus.arc.processor.Annotations;
+import io.quarkus.arc.processor.BeanInfo;
import io.quarkus.arc.processor.DotNames;
+import io.quarkus.arc.processor.KotlinDotNames;
+import io.quarkus.arc.processor.KotlinUtils;
import io.quarkus.gizmo.BytecodeCreator;
import io.quarkus.gizmo.FieldDescriptor;
import io.quarkus.gizmo.ResultHandle;
@@ -35,15 +38,17 @@ public class Callback {
public final Target target;
public final String endpointPath;
public final AnnotationInstance annotation;
+ public final BeanInfo bean;
public final MethodInfo method;
public final ExecutionModel executionModel;
public final MessageType messageType;
public final List arguments;
- public Callback(Target target, AnnotationInstance annotation, MethodInfo method, ExecutionModel executionModel,
- CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
- String endpointPath, IndexView index) {
+ public Callback(Target target, AnnotationInstance annotation, BeanInfo bean, MethodInfo method,
+ ExecutionModel executionModel, CallbackArgumentsBuildItem callbackArguments,
+ TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath, IndexView index) {
this.target = target;
+ this.bean = bean;
this.method = method;
this.annotation = annotation;
this.executionModel = executionModel;
@@ -104,6 +109,15 @@ public boolean isReturnTypeMulti() {
return WebSocketDotNames.MULTI.equals(returnType().name());
}
+ public boolean isKotlinSuspendFunction() {
+ return KotlinUtils.isKotlinSuspendMethod(method);
+ }
+
+ public boolean isKotlinSuspendFunctionReturningUnit() {
+ return KotlinUtils.isKotlinSuspendMethod(method)
+ && KotlinUtils.getKotlinSuspendMethodResult(method).name().equals(KotlinDotNames.UNIT);
+ }
+
public boolean acceptsMessage() {
return messageType != MessageType.NONE;
}
diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/KotlinContinuationCallbackArgument.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/KotlinContinuationCallbackArgument.java
new file mode 100644
index 00000000000000..fa6eda14b4f14c
--- /dev/null
+++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/KotlinContinuationCallbackArgument.java
@@ -0,0 +1,17 @@
+package io.quarkus.websockets.next.deployment;
+
+import io.quarkus.arc.processor.KotlinUtils;
+import io.quarkus.gizmo.ResultHandle;
+
+public class KotlinContinuationCallbackArgument implements CallbackArgument {
+ @Override
+ public boolean matches(ParameterContext context) {
+ return KotlinUtils.isKotlinContinuationParameter(context.parameter());
+ }
+
+ @Override
+ public ResultHandle get(InvocationBytecodeContext context) {
+ // the actual value is provided by the invoker
+ return context.bytecode().loadNull();
+ }
+}
diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java
index 4ed22a99856039..0b3d23dfe7627b 100644
--- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java
+++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java
@@ -17,18 +17,21 @@
import java.util.stream.Collectors;
import jakarta.enterprise.context.SessionScoped;
+import jakarta.enterprise.invoke.Invoker;
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTransformation;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.ClassInfo.NestingType;
+import org.jboss.jandex.ClassType;
import org.jboss.jandex.DotName;
import org.jboss.jandex.IndexView;
import org.jboss.jandex.MethodInfo;
import org.jboss.jandex.PrimitiveType;
import org.jboss.jandex.Type;
import org.jboss.jandex.Type.Kind;
+import org.objectweb.asm.Opcodes;
import io.quarkus.arc.deployment.AdditionalBeanBuildItem;
import io.quarkus.arc.deployment.AnnotationsTransformerBuildItem;
@@ -39,6 +42,7 @@
import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem;
import io.quarkus.arc.deployment.ContextRegistrationPhaseBuildItem.ContextConfiguratorBuildItem;
import io.quarkus.arc.deployment.CustomScopeBuildItem;
+import io.quarkus.arc.deployment.InvokerFactoryBuildItem;
import io.quarkus.arc.deployment.SyntheticBeanBuildItem;
import io.quarkus.arc.deployment.SyntheticBeansRuntimeInitBuildItem;
import io.quarkus.arc.deployment.TransformedAnnotationsBuildItem;
@@ -50,7 +54,11 @@
import io.quarkus.arc.processor.BuiltinScope;
import io.quarkus.arc.processor.DotNames;
import io.quarkus.arc.processor.InjectionPointInfo;
+import io.quarkus.arc.processor.InvokerInfo;
+import io.quarkus.arc.processor.KotlinDotNames;
+import io.quarkus.arc.processor.KotlinUtils;
import io.quarkus.arc.processor.Types;
+import io.quarkus.bootstrap.classloading.QuarkusClassLoader;
import io.quarkus.deployment.Capabilities;
import io.quarkus.deployment.Capability;
import io.quarkus.deployment.GeneratedClassGizmoAdaptor;
@@ -68,6 +76,7 @@
import io.quarkus.gizmo.CatchBlockCreator;
import io.quarkus.gizmo.ClassCreator;
import io.quarkus.gizmo.ClassOutput;
+import io.quarkus.gizmo.FieldCreator;
import io.quarkus.gizmo.FunctionCreator;
import io.quarkus.gizmo.MethodCreator;
import io.quarkus.gizmo.MethodDescriptor;
@@ -108,6 +117,8 @@
import io.quarkus.websockets.next.runtime.WebSocketHttpServerOptionsCustomizer;
import io.quarkus.websockets.next.runtime.WebSocketServerRecorder;
import io.quarkus.websockets.next.runtime.WebSocketSessionContext;
+import io.quarkus.websockets.next.runtime.kotlin.ApplicationCoroutineScope;
+import io.quarkus.websockets.next.runtime.kotlin.CoroutineInvoker;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.groups.UniCreate;
@@ -190,6 +201,18 @@ void additionalBeans(CombinedIndexBuildItem combinedIndex, BuildProducer additionalBean) {
+ if (!QuarkusClassLoader.isClassPresentAtRuntime("kotlinx.coroutines.CoroutineScope")) {
+ return;
+ }
+
+ additionalBean.produce(AdditionalBeanBuildItem.builder()
+ .addBeanClass(ApplicationCoroutineScope.class)
+ .setUnremovable()
+ .build());
+ }
+
@BuildStep
ContextConfiguratorBuildItem registerSessionContext(ContextRegistrationPhaseBuildItem phase) {
return new ContextConfiguratorBuildItem(phase.getContext()
@@ -211,6 +234,7 @@ void builtinCallbackArguments(BuildProducer providers
providers.produce(new CallbackArgumentBuildItem(new HandshakeRequestCallbackArgument()));
providers.produce(new CallbackArgumentBuildItem(new ErrorCallbackArgument()));
providers.produce(new CallbackArgumentBuildItem(new CloseReasonCallbackArgument()));
+ providers.produce(new CallbackArgumentBuildItem(new KotlinContinuationCallbackArgument()));
}
@BuildStep
@@ -228,7 +252,7 @@ void collectGlobalErrorHandlers(BeanArchiveIndexBuildItem beanArchiveIndex,
ClassInfo beanClass = bean.getTarget().get().asClass();
if (beanClass.declaredAnnotation(WebSocketDotNames.WEB_SOCKET) == null
&& beanClass.declaredAnnotation(WebSocketDotNames.WEB_SOCKET_CLIENT) == null) {
- for (Callback callback : findErrorHandlers(Target.UNDEFINED, index, beanClass, callbackArguments,
+ for (Callback callback : findErrorHandlers(Target.UNDEFINED, index, bean, beanClass, callbackArguments,
transformedAnnotations, null)) {
GlobalErrorHandler errorHandler = new GlobalErrorHandler(bean, callback);
DotName errorTypeName = callback.argumentType(ErrorCallbackArgument::isError).name();
@@ -327,19 +351,17 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex,
inboundProcessingMode = webSocketClientAnnotation.value("inboundProcessingMode");
}
- Callback onOpen = findCallback(target, beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_OPEN,
- callbackArguments, transformedAnnotations, path);
- Callback onTextMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass,
- WebSocketDotNames.ON_TEXT_MESSAGE,
- callbackArguments, transformedAnnotations, path);
- Callback onBinaryMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass,
+ Callback onOpen = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass,
+ WebSocketDotNames.ON_OPEN, callbackArguments, transformedAnnotations, path);
+ Callback onTextMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass,
+ WebSocketDotNames.ON_TEXT_MESSAGE, callbackArguments, transformedAnnotations, path);
+ Callback onBinaryMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass,
WebSocketDotNames.ON_BINARY_MESSAGE, callbackArguments, transformedAnnotations, path);
- Callback onPongMessage = findCallback(target, beanArchiveIndex.getIndex(), beanClass,
- WebSocketDotNames.ON_PONG_MESSAGE,
- callbackArguments, transformedAnnotations, path,
+ Callback onPongMessage = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass,
+ WebSocketDotNames.ON_PONG_MESSAGE, callbackArguments, transformedAnnotations, path,
this::validateOnPongMessage);
- Callback onClose = findCallback(target, beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_CLOSE,
- callbackArguments, transformedAnnotations, path,
+ Callback onClose = findCallback(target, beanArchiveIndex.getIndex(), bean, beanClass,
+ WebSocketDotNames.ON_CLOSE, callbackArguments, transformedAnnotations, path,
this::validateOnClose);
if (onOpen == null && onTextMessage == null && onBinaryMessage == null && onPongMessage == null) {
throw new WebSocketServerException(
@@ -354,7 +376,7 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex,
onBinaryMessage,
onPongMessage,
onClose,
- findErrorHandlers(target, index, beanClass, callbackArguments, transformedAnnotations, path)));
+ findErrorHandlers(target, index, bean, beanClass, callbackArguments, transformedAnnotations, path)));
}
}
@@ -383,6 +405,7 @@ public void generateEndpoints(BeanArchiveIndexBuildItem index, List generatedClasses,
BuildProducer generatedEndpoints,
BuildProducer reflectiveClasses) {
@@ -409,7 +432,8 @@ public String apply(String name) {
// and delegates callback invocations to the endpoint bean
String generatedName = generateEndpoint(endpoint, argumentProviders, transformedAnnotations,
index.getIndex(), classOutput, globalErrorHandlers,
- endpoint.isClient() ? CLIENT_ENDPOINT_SUFFIX : SERVER_ENDPOINT_SUFFIX);
+ endpoint.isClient() ? CLIENT_ENDPOINT_SUFFIX : SERVER_ENDPOINT_SUFFIX,
+ invokerFactory);
reflectiveClasses.produce(ReflectiveClassBuildItem.builder(generatedName).constructors().build());
generatedEndpoints
.produce(new GeneratedEndpointBuildItem(endpoint.id, endpoint.bean.getImplClazz().name().toString(),
@@ -626,9 +650,16 @@ static String getPathPrefix(IndexView index, DotName enclosingClassName) {
}
private void validateOnPongMessage(Callback callback) {
- if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) {
- throw new WebSocketServerException(
- "@OnPongMessage callback must return void or Uni: " + callback.asString());
+ if (KotlinUtils.isKotlinMethod(callback.method)) {
+ if (!callback.isReturnTypeVoid() && !callback.isKotlinSuspendFunctionReturningUnit()) {
+ throw new WebSocketServerException(
+ "@OnPongMessage callback must return Unit: " + callback.asString());
+ }
+ } else {
+ if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) {
+ throw new WebSocketServerException(
+ "@OnPongMessage callback must return void or Uni: " + callback.asString());
+ }
}
Type messageType = callback.argumentType(MessageCallbackArgument::isMessage);
if (messageType == null || !messageType.name().equals(WebSocketDotNames.BUFFER)) {
@@ -639,9 +670,16 @@ private void validateOnPongMessage(Callback callback) {
}
private void validateOnClose(Callback callback) {
- if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) {
- throw new WebSocketServerException(
- "@OnClose callback must return void or Uni: " + callback.asString());
+ if (KotlinUtils.isKotlinMethod(callback.method)) {
+ if (!callback.isReturnTypeVoid() && !callback.isKotlinSuspendFunctionReturningUnit()) {
+ throw new WebSocketServerException(
+ "@OnClose callback must return Unit: " + callback.asString());
+ }
+ } else {
+ if (callback.returnType().kind() != Kind.VOID && !WebSocketProcessor.isUniVoid(callback.returnType())) {
+ throw new WebSocketServerException(
+ "@OnClose callback must return void or Uni: " + callback.asString());
+ }
}
}
@@ -710,7 +748,8 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint,
IndexView index,
ClassOutput classOutput,
GlobalErrorHandlersBuildItem globalErrorHandlers,
- String endpointSuffix) {
+ String endpointSuffix,
+ InvokerFactoryBuildItem invokerFactory) {
ClassInfo implClazz = endpoint.bean.getImplClazz();
String baseName;
if (implClazz.enclosingClass() != null) {
@@ -733,7 +772,6 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint,
Codecs.class, ContextSupport.class, SecuritySupport.class),
constructor.getThis(), constructor.getMethodParam(0), constructor.getMethodParam(1),
constructor.getMethodParam(2), constructor.getMethodParam(3));
- constructor.returnNull();
MethodCreator inboundProcessingMode = endpointCreator.getMethodCreator("inboundProcessingMode",
InboundProcessingMode.class);
@@ -751,7 +789,8 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint,
// Call the business method
TryBlock tryBlock = onErrorTryBlock(doOnOpen, doOnOpen.getThis());
ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index);
- ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args);
+ ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, "Open", tryBlock,
+ beanInstance, args, invokerFactory);
encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret);
MethodCreator onOpenExecutionModel = endpointCreator.getMethodCreator("onOpenExecutionModel",
@@ -759,12 +798,12 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint,
onOpenExecutionModel.returnValue(onOpenExecutionModel.load(callback.executionModel));
}
- generateOnMessage(endpointCreator, endpoint, endpoint.onBinaryMessage, argumentProviders, transformedAnnotations,
- index, globalErrorHandlers);
- generateOnMessage(endpointCreator, endpoint, endpoint.onTextMessage, argumentProviders, transformedAnnotations, index,
- globalErrorHandlers);
- generateOnMessage(endpointCreator, endpoint, endpoint.onPongMessage, argumentProviders, transformedAnnotations, index,
- globalErrorHandlers);
+ generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onBinaryMessage, argumentProviders,
+ transformedAnnotations, index, globalErrorHandlers, invokerFactory);
+ generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onTextMessage, argumentProviders,
+ transformedAnnotations, index, globalErrorHandlers, invokerFactory);
+ generateOnMessage(endpointCreator, constructor, endpoint, endpoint.onPongMessage, argumentProviders,
+ transformedAnnotations, index, globalErrorHandlers, invokerFactory);
if (endpoint.onClose != null) {
Callback callback = endpoint.onClose;
@@ -775,7 +814,8 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint,
// Call the business method
TryBlock tryBlock = onErrorTryBlock(doOnClose, doOnClose.getThis());
ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index);
- ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args);
+ ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, "Close", tryBlock,
+ beanInstance, args, invokerFactory);
encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret);
MethodCreator onCloseExecutionModel = endpointCreator.getMethodCreator("onCloseExecutionModel",
@@ -783,15 +823,19 @@ static String generateEndpoint(WebSocketEndpointBuildItem endpoint,
onCloseExecutionModel.returnValue(onCloseExecutionModel.load(callback.executionModel));
}
- generateOnError(endpointCreator, endpoint, argumentProviders, transformedAnnotations, globalErrorHandlers, index);
+ generateOnError(endpointCreator, constructor, endpoint, transformedAnnotations, globalErrorHandlers, index,
+ invokerFactory);
+
+ // we write into the constructor when generating callback invokers, so need to finish it late
+ constructor.returnVoid();
endpointCreator.close();
return generatedName.replace('/', '.');
}
- private static void generateOnError(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint,
- CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
- GlobalErrorHandlersBuildItem globalErrorHandlers, IndexView index) {
+ private static void generateOnError(ClassCreator endpointCreator, MethodCreator constructor,
+ WebSocketEndpointBuildItem endpoint, TransformedAnnotationsBuildItem transformedAnnotations,
+ GlobalErrorHandlersBuildItem globalErrorHandlers, IndexView index, InvokerFactoryBuildItem invokerFactory) {
Map errors = new HashMap<>();
List throwableInfos = new ArrayList<>();
@@ -841,7 +885,8 @@ private static void generateOnError(ClassCreator endpointCreator, WebSocketEndpo
MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class, String.class),
endpointThis, funBytecode.load(throwableInfo.bean().getIdentifier()));
ResultHandle[] args = callback.generateArguments(endpointThis, tryBlock, transformedAnnotations, index);
- ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args);
+ ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, "Error", tryBlock,
+ beanInstance, args, invokerFactory);
encodeAndReturnResult(endpointThis, tryBlock, callback, globalErrorHandlers, endpoint, ret);
// return doErrorExecute()
@@ -891,9 +936,10 @@ record GlobalErrorHandler(BeanInfo bean, Callback callback) {
}
- private static void generateOnMessage(ClassCreator endpointCreator, WebSocketEndpointBuildItem endpoint, Callback callback,
+ private static void generateOnMessage(ClassCreator endpointCreator, MethodCreator constructor,
+ WebSocketEndpointBuildItem endpoint, Callback callback,
CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
- IndexView index, GlobalErrorHandlersBuildItem globalErrorHandlers) {
+ IndexView index, GlobalErrorHandlersBuildItem globalErrorHandlers, InvokerFactoryBuildItem invokerFactory) {
if (callback == null) {
return;
}
@@ -924,8 +970,8 @@ private static void generateOnMessage(ClassCreator endpointCreator, WebSocketEnd
MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "beanInstance", Object.class), tryBlock.getThis());
ResultHandle[] args = callback.generateArguments(tryBlock.getThis(), tryBlock, transformedAnnotations, index);
// Call the business method
- ResultHandle ret = tryBlock.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance,
- args);
+ ResultHandle ret = callBusinessMethod(endpointCreator, constructor, callback, messageType, tryBlock, beanInstance, args,
+ invokerFactory);
encodeAndReturnResult(tryBlock.getThis(), tryBlock, callback, globalErrorHandlers, endpoint, ret);
MethodCreator onMessageExecutionModel = endpointCreator.getMethodCreator("on" + messageType + "MessageExecutionModel",
@@ -946,6 +992,30 @@ private static void generateOnMessage(ClassCreator endpointCreator, WebSocketEnd
}
}
+ private static ResultHandle callBusinessMethod(ClassCreator clazz, MethodCreator constructor, Callback callback,
+ String messageType, BytecodeCreator bytecode, ResultHandle beanInstance, ResultHandle[] args,
+ InvokerFactoryBuildItem invokerFactory) {
+ if (KotlinUtils.isKotlinSuspendMethod(callback.method)) {
+ InvokerInfo invoker = invokerFactory.createInvoker(callback.bean, callback.method)
+ .withInvocationWrapper(CoroutineInvoker.class, "inNewCoroutine")
+ .build();
+ FieldCreator invokerField = clazz.getFieldCreator("invokerFor" + messageType, Invoker.class)
+ .setModifiers(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL);
+ constructor.writeInstanceField(invokerField.getFieldDescriptor(), constructor.getThis(),
+ constructor.newInstance(MethodDescriptor.ofConstructor(invoker.getClassName())));
+ ResultHandle invokerHandle = bytecode.readInstanceField(invokerField.getFieldDescriptor(), bytecode.getThis());
+ ResultHandle argsArray = bytecode.newArray(Object.class, args.length);
+ for (int i = 0; i < args.length; i++) {
+ bytecode.writeArrayValue(argsArray, i, args[i]);
+ }
+ return bytecode.invokeInterfaceMethod(
+ MethodDescriptor.ofMethod(Invoker.class, "invoke", Object.class, Object.class, Object[].class),
+ invokerHandle, beanInstance, argsArray);
+ } else {
+ return bytecode.invokeVirtualMethod(MethodDescriptor.of(callback.method), beanInstance, args);
+ }
+ }
+
private static TryBlock uniFailureTryBlock(BytecodeCreator method) {
TryBlock tryBlock = method.tryBlock();
CatchBlockCreator catchBlock = tryBlock.addCatch(Throwable.class);
@@ -1070,8 +1140,15 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre
// ----------------------
// === Binary message ===
// ----------------------
- if (callback.isReturnTypeUni()) {
- Type messageType = callback.returnType().asParameterizedType().arguments().get(0);
+ if (callback.isReturnTypeUni() || callback.isKotlinSuspendFunction()) {
+ Type messageType = callback.isReturnTypeUni()
+ ? callback.returnType().asParameterizedType().arguments().get(0)
+ : KotlinUtils.getKotlinSuspendMethodResult(callback.method);
+ if (messageType.name().equals(KotlinDotNames.UNIT)) {
+ value = method.invokeInterfaceMethod(MethodDescriptor.ofMethod(Uni.class, "replaceWithVoid", Uni.class),
+ value);
+ messageType = ClassType.create(WebSocketDotNames.VOID);
+ }
if (messageType.name().equals(WebSocketDotNames.VOID)) {
// Uni
return uniOnFailureDoOnError(endpointThis, method, callback, value, endpoint, globalErrorHandlers);
@@ -1082,8 +1159,7 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre
// });
FunctionCreator fun = method.createFunction(Function.class);
BytecodeCreator funBytecode = fun.getBytecode();
- ResultHandle buffer = encodeBuffer(funBytecode,
- callback.returnType().asParameterizedType().arguments().get(0),
+ ResultHandle buffer = encodeBuffer(funBytecode, messageType,
funBytecode.getMethodParam(0), endpointThis, callback);
funBytecode.returnValue(funBytecode.invokeVirtualMethod(
MethodDescriptor.ofMethod(WebSocketEndpointBase.class,
@@ -1130,8 +1206,15 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre
// ----------------------
// === Text message ===
// ----------------------
- if (callback.isReturnTypeUni()) {
- Type messageType = callback.returnType().asParameterizedType().arguments().get(0);
+ if (callback.isReturnTypeUni() || callback.isKotlinSuspendFunction()) {
+ Type messageType = callback.isReturnTypeUni()
+ ? callback.returnType().asParameterizedType().arguments().get(0)
+ : KotlinUtils.getKotlinSuspendMethodResult(callback.method);
+ if (messageType.name().equals(KotlinDotNames.UNIT)) {
+ value = method.invokeInterfaceMethod(MethodDescriptor.ofMethod(Uni.class, "replaceWithVoid", Uni.class),
+ value);
+ messageType = ClassType.create(WebSocketDotNames.VOID);
+ }
if (messageType.name().equals(WebSocketDotNames.VOID)) {
// Uni
return uniOnFailureDoOnError(endpointThis, method, callback, value, endpoint, globalErrorHandlers);
@@ -1142,7 +1225,7 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre
// });
FunctionCreator fun = method.createFunction(Function.class);
BytecodeCreator funBytecode = fun.getBytecode();
- ResultHandle text = encodeText(funBytecode, callback.returnType().asParameterizedType().arguments().get(0),
+ ResultHandle text = encodeText(funBytecode, messageType,
funBytecode.getMethodParam(0), endpointThis, callback);
funBytecode.returnValue(funBytecode.invokeVirtualMethod(
MethodDescriptor.ofMethod(WebSocketEndpointBase.class,
@@ -1267,9 +1350,8 @@ private static void encodeAndReturnResult(ResultHandle endpointThis, BytecodeCre
}
}
- static List findErrorHandlers(Target target, IndexView index, ClassInfo beanClass,
- CallbackArgumentsBuildItem callbackArguments,
- TransformedAnnotationsBuildItem transformedAnnotations,
+ static List findErrorHandlers(Target target, IndexView index, BeanInfo bean, ClassInfo beanClass,
+ CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
String endpointPath) {
List annotations = findCallbackAnnotations(index, beanClass, WebSocketDotNames.ON_ERROR);
if (annotations.isEmpty()) {
@@ -1285,8 +1367,9 @@ static List findErrorHandlers(Target target, IndexView index, ClassInf
.anyMatch(WebSocketDotNames.WEB_SOCKET_CLIENT_CONNECTION::equals)) {
target = Target.CLIENT;
}
- Callback callback = new Callback(target, annotation, method, executionModel(method, transformedAnnotations),
- callbackArguments, transformedAnnotations, endpointPath, index);
+ Callback callback = new Callback(target, annotation, bean, method,
+ executionModel(method, transformedAnnotations), callbackArguments, transformedAnnotations,
+ endpointPath, index);
long errorArguments = callback.arguments.stream().filter(ca -> ca instanceof ErrorCallbackArgument).count();
if (errorArguments != 1) {
throw new WebSocketException(
@@ -1316,16 +1399,16 @@ private static List findCallbackAnnotations(IndexView index,
return annotations;
}
- static Callback findCallback(Target target, IndexView index, ClassInfo beanClass, DotName annotationName,
- CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
- String endpointPath) {
- return findCallback(target, index, beanClass, annotationName, callbackArguments, transformedAnnotations, endpointPath,
- null);
+ static Callback findCallback(Target target, IndexView index, BeanInfo bean, ClassInfo beanClass,
+ DotName annotationName, CallbackArgumentsBuildItem callbackArguments,
+ TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath) {
+ return findCallback(target, index, bean, beanClass, annotationName, callbackArguments,
+ transformedAnnotations, endpointPath, null);
}
- private static Callback findCallback(Target target, IndexView index, ClassInfo beanClass, DotName annotationName,
- CallbackArgumentsBuildItem callbackArguments, TransformedAnnotationsBuildItem transformedAnnotations,
- String endpointPath,
+ private static Callback findCallback(Target target, IndexView index, BeanInfo bean, ClassInfo beanClass,
+ DotName annotationName, CallbackArgumentsBuildItem callbackArguments,
+ TransformedAnnotationsBuildItem transformedAnnotations, String endpointPath,
Consumer validator) {
List annotations = findCallbackAnnotations(index, beanClass, annotationName);
if (annotations.isEmpty()) {
@@ -1333,9 +1416,9 @@ private static Callback findCallback(Target target, IndexView index, ClassInfo b
} else if (annotations.size() == 1) {
AnnotationInstance annotation = annotations.get(0);
MethodInfo method = annotation.target().asMethod();
- Callback callback = new Callback(target, annotation, method, executionModel(method, transformedAnnotations),
- callbackArguments,
- transformedAnnotations, endpointPath, index);
+ Callback callback = new Callback(target, annotation, bean, method,
+ executionModel(method, transformedAnnotations), callbackArguments, transformedAnnotations,
+ endpointPath, index);
long messageArguments = callback.arguments.stream().filter(ca -> ca instanceof MessageCallbackArgument).count();
if (callback.acceptsMessage()) {
if (messageArguments > 1) {
@@ -1382,6 +1465,10 @@ private static ExecutionModel executionModel(MethodInfo method, TransformedAnnot
}
static boolean hasBlockingSignature(MethodInfo method) {
+ if (KotlinUtils.isKotlinSuspendMethod(method)) {
+ return false;
+ }
+
switch (method.returnType().kind()) {
case VOID:
case CLASS:
@@ -1414,6 +1501,8 @@ private static boolean isOnOpenWithBinaryReturnType(Callback callback) {
Type returnType = callback.returnType();
if (callback.isReturnTypeUni() || callback.isReturnTypeMulti()) {
returnType = callback.returnType().asParameterizedType().arguments().get(0);
+ } else if (callback.isKotlinSuspendFunction()) {
+ returnType = KotlinUtils.getKotlinSuspendMethodResult(callback.method);
}
return WebSocketDotNames.BUFFER.equals(returnType.name())
|| (returnType.kind() == Kind.ARRAY && PrimitiveType.BYTE.equals(returnType.asArrayType().constituent()));
diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java
index 2f1974089db901..926c0d1b82d153 100644
--- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java
+++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java
@@ -144,6 +144,13 @@ public Buffer sendAndAwaitReply(String message) {
return messages.get(c);
}
+ public Buffer sendAndAwaitReply(Buffer message) {
+ var c = messages.size();
+ sendAndAwait(message);
+ Awaitility.await().until(() -> messages.size() > c);
+ return messages.get(c);
+ }
+
public boolean isClosed() {
return socket.get().isClosed();
}
diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEcho.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEcho.kt
new file mode 100644
index 00000000000000..a42359791fd386
--- /dev/null
+++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEcho.kt
@@ -0,0 +1,14 @@
+package io.quarkus.websockets.next.test.kotlin
+
+import io.quarkus.websockets.next.OnBinaryMessage
+import io.quarkus.websockets.next.WebSocket
+import io.vertx.core.buffer.Buffer
+import kotlinx.coroutines.delay
+
+@WebSocket(path = "/binary-echo")
+class BinaryEcho {
+ @OnBinaryMessage
+ fun process(msg: Buffer): Buffer {
+ return msg
+ }
+}
diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEchoSuspend.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEchoSuspend.kt
new file mode 100644
index 00000000000000..a2182d7607535e
--- /dev/null
+++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/BinaryEchoSuspend.kt
@@ -0,0 +1,15 @@
+package io.quarkus.websockets.next.test.kotlin
+
+import io.quarkus.websockets.next.OnBinaryMessage
+import io.quarkus.websockets.next.WebSocket
+import io.vertx.core.buffer.Buffer
+import kotlinx.coroutines.delay
+
+@WebSocket(path = "/binary-echo-suspend")
+class BinaryEchoSuspend {
+ @OnBinaryMessage
+ suspend fun process(msg: Buffer): Buffer {
+ delay(100)
+ return msg
+ }
+}
diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Echo.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Echo.kt
new file mode 100644
index 00000000000000..fc60701b4991be
--- /dev/null
+++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Echo.kt
@@ -0,0 +1,12 @@
+package io.quarkus.websockets.next.test.kotlin
+
+import io.quarkus.websockets.next.OnTextMessage
+import io.quarkus.websockets.next.WebSocket
+
+@WebSocket(path = "/echo")
+class Echo {
+ @OnTextMessage
+ fun process(msg: Message): Message {
+ return msg
+ }
+}
diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/EchoSuspend.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/EchoSuspend.kt
new file mode 100644
index 00000000000000..92d27b2d9d85bb
--- /dev/null
+++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/EchoSuspend.kt
@@ -0,0 +1,14 @@
+package io.quarkus.websockets.next.test.kotlin
+
+import io.quarkus.websockets.next.OnTextMessage
+import io.quarkus.websockets.next.WebSocket
+import kotlinx.coroutines.delay
+
+@WebSocket(path = "/echo-suspend")
+class EchoSuspend {
+ @OnTextMessage
+ suspend fun process(msg: Message): Message {
+ delay(100)
+ return msg
+ }
+}
diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketClientTest.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketClientTest.kt
new file mode 100644
index 00000000000000..9ec2cd00baf2c3
--- /dev/null
+++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketClientTest.kt
@@ -0,0 +1,93 @@
+package io.quarkus.websockets.next.test.kotlin
+
+import io.quarkus.test.QuarkusUnitTest
+import io.quarkus.test.common.http.TestHTTPResource
+import io.quarkus.websockets.next.OnClose
+import io.quarkus.websockets.next.OnOpen
+import io.quarkus.websockets.next.OnTextMessage
+import io.quarkus.websockets.next.WebSocket
+import io.quarkus.websockets.next.WebSocketClient
+import io.quarkus.websockets.next.WebSocketConnector
+import jakarta.inject.Inject
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertTrue
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.extension.RegisterExtension
+import java.net.URI
+import java.util.concurrent.CopyOnWriteArrayList
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+
+class KotlinWebSocketClientTest {
+ companion object {
+ @RegisterExtension
+ val test = QuarkusUnitTest()
+ .withApplicationRoot { jar ->
+ jar.addClasses(ServerEndpoint::class.java, ClientEndpoint::class.java)
+ }
+ }
+
+ @Inject
+ lateinit var connector: WebSocketConnector
+
+ @TestHTTPResource("/")
+ lateinit var uri: URI
+
+ @Test
+ fun test() {
+ val connection = connector.baseUri(uri).connectAndAwait()
+ connection.sendTextAndAwait("Hi!")
+
+ assertTrue(ClientEndpoint.messagesLatch.await(5, TimeUnit.SECONDS))
+ assertEquals("Hello there", ClientEndpoint.messages[0])
+ assertEquals("Hi!", ClientEndpoint.messages[1])
+
+ connection.closeAndAwait()
+ assertTrue(ClientEndpoint.closedLatch.await(5, TimeUnit.SECONDS))
+ assertTrue(ServerEndpoint.closedLatch.await(5, TimeUnit.SECONDS))
+ }
+
+ @WebSocket(path = "/endpoint")
+ class ServerEndpoint {
+ companion object {
+ val closedLatch: CountDownLatch = CountDownLatch(1)
+ }
+
+ @OnOpen
+ fun open(): String {
+ return "Hello there"
+ }
+
+ @OnTextMessage
+ fun echo(message: String): String {
+ return message
+ }
+
+ @OnClose
+ fun close() {
+ closedLatch.countDown()
+ }
+ }
+
+ @WebSocketClient(path = "/endpoint")
+ class ClientEndpoint {
+ companion object {
+ val messages: MutableList = CopyOnWriteArrayList()
+
+ val messagesLatch: CountDownLatch = CountDownLatch(2)
+
+ val closedLatch: CountDownLatch = CountDownLatch(1)
+ }
+
+ @OnTextMessage
+ fun onMessage(message: String) {
+ messages.add(message)
+ messagesLatch.countDown()
+ }
+
+ @OnClose
+ fun onClose() {
+ closedLatch.countDown()
+ }
+ }
+}
diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSuspendingClientTest.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSuspendingClientTest.kt
new file mode 100644
index 00000000000000..9c4a8fdc9051a9
--- /dev/null
+++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketSuspendingClientTest.kt
@@ -0,0 +1,99 @@
+package io.quarkus.websockets.next.test.kotlin
+
+import io.quarkus.test.QuarkusUnitTest
+import io.quarkus.test.common.http.TestHTTPResource
+import io.quarkus.websockets.next.OnClose
+import io.quarkus.websockets.next.OnOpen
+import io.quarkus.websockets.next.OnTextMessage
+import io.quarkus.websockets.next.WebSocket
+import io.quarkus.websockets.next.WebSocketClient
+import io.quarkus.websockets.next.WebSocketConnector
+import jakarta.inject.Inject
+import kotlinx.coroutines.delay
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertTrue
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.extension.RegisterExtension
+import java.net.URI
+import java.util.concurrent.CopyOnWriteArrayList
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.TimeUnit
+
+class KotlinWebSocketSuspendingClientTest {
+ companion object {
+ @RegisterExtension
+ val test = QuarkusUnitTest()
+ .withApplicationRoot { jar ->
+ jar.addClasses(ServerEndpoint::class.java, ClientEndpoint::class.java)
+ }
+ }
+
+ @Inject
+ lateinit var connector: WebSocketConnector
+
+ @TestHTTPResource("/")
+ lateinit var uri: URI
+
+ @Test
+ fun test() {
+ val connection = connector.baseUri(uri).connectAndAwait()
+ connection.sendTextAndAwait("Hi!")
+
+ assertTrue(ClientEndpoint.messagesLatch.await(5, TimeUnit.SECONDS))
+ assertEquals("Hello there", ClientEndpoint.messages[0])
+ assertEquals("Hi!", ClientEndpoint.messages[1])
+
+ connection.closeAndAwait()
+ assertTrue(ClientEndpoint.closedLatch.await(5, TimeUnit.SECONDS))
+ assertTrue(ServerEndpoint.closedLatch.await(5, TimeUnit.SECONDS))
+ }
+
+ @WebSocket(path = "/endpoint")
+ class ServerEndpoint {
+ companion object {
+ val closedLatch: CountDownLatch = CountDownLatch(1)
+ }
+
+ @OnOpen
+ suspend fun open(): String {
+ delay(100)
+ return "Hello there"
+ }
+
+ @OnTextMessage
+ suspend fun echo(message: String): String {
+ delay(100)
+ return message
+ }
+
+ @OnClose
+ suspend fun close() {
+ delay(100)
+ closedLatch.countDown()
+ }
+ }
+
+ @WebSocketClient(path = "/endpoint")
+ class ClientEndpoint {
+ companion object {
+ val messages: MutableList = CopyOnWriteArrayList()
+
+ val messagesLatch: CountDownLatch = CountDownLatch(2)
+
+ val closedLatch: CountDownLatch = CountDownLatch(1)
+ }
+
+ @OnTextMessage
+ suspend fun onMessage(message: String) {
+ delay(100)
+ messages.add(message)
+ messagesLatch.countDown()
+ }
+
+ @OnClose
+ suspend fun onClose() {
+ delay(100)
+ closedLatch.countDown()
+ }
+ }
+}
diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketTest.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketTest.kt
new file mode 100644
index 00000000000000..12dd973ee9848e
--- /dev/null
+++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/KotlinWebSocketTest.kt
@@ -0,0 +1,75 @@
+package io.quarkus.websockets.next.test.kotlin
+
+import io.quarkus.test.QuarkusUnitTest
+import io.quarkus.test.common.http.TestHTTPResource
+import io.quarkus.websockets.next.test.utils.WSClient
+import io.vertx.core.Vertx
+import io.vertx.core.buffer.Buffer
+import io.vertx.core.json.JsonObject
+import jakarta.inject.Inject
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.extension.RegisterExtension
+import java.net.URI
+
+class KotlinWebSocketTest {
+ companion object {
+ @RegisterExtension
+ val test = QuarkusUnitTest()
+ .withApplicationRoot { jar ->
+ jar.addClasses(Echo::class.java, EchoSuspend::class.java, BinaryEcho::class.java,
+ BinaryEchoSuspend::class.java, Message::class.java, WSClient::class.java)
+ }
+ }
+
+ @Inject
+ lateinit var vertx: Vertx
+
+ @TestHTTPResource("echo")
+ lateinit var echo: URI
+
+ @TestHTTPResource("echo-suspend")
+ lateinit var echoSuspend: URI
+
+ @TestHTTPResource("binary-echo")
+ lateinit var binaryEcho: URI
+
+ @TestHTTPResource("binary-echo-suspend")
+ lateinit var binaryEchoSuspend: URI
+
+ @Test
+ fun testEcho() {
+ doTest(echo)
+ }
+
+ @Test
+ fun testEchoSuspend() {
+ doTest(echoSuspend)
+ }
+
+ private fun doTest(uri: URI) {
+ WSClient.create(vertx).connect(uri).use { client ->
+ val req = JsonObject().put("msg", "hello")
+ val resp = client.sendAndAwaitReply(req.toString()).toJsonObject()
+ assertEquals(req, resp)
+ }
+ }
+
+ @Test
+ fun testBinaryEcho() {
+ doTestBinary(binaryEcho)
+ }
+
+ @Test
+ fun testBinaryEchoSuspend() {
+ doTestBinary(binaryEchoSuspend)
+ }
+
+ private fun doTestBinary(uri: URI) {
+ WSClient.create(vertx).connect(uri).use { client ->
+ val req = Buffer.buffer("hello there!")
+ val resp = client.sendAndAwaitReply(req)
+ assertEquals(req, resp)
+ }
+ }
+}
diff --git a/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Message.kt b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Message.kt
new file mode 100644
index 00000000000000..96563b27e7382d
--- /dev/null
+++ b/extensions/websockets-next/deployment/src/test/kotlin/io/quarkus/websockets/next/test/kotlin/Message.kt
@@ -0,0 +1,5 @@
+package io.quarkus.websockets.next.test.kotlin
+
+import com.fasterxml.jackson.annotation.JsonCreator
+
+data class Message @JsonCreator constructor(var msg: String)
diff --git a/extensions/websockets-next/kotlin/pom.xml b/extensions/websockets-next/kotlin/pom.xml
new file mode 100644
index 00000000000000..bda3816be3a85e
--- /dev/null
+++ b/extensions/websockets-next/kotlin/pom.xml
@@ -0,0 +1,115 @@
+
+
+
+ io.quarkus
+ quarkus-websockets-next-parent
+ 999-SNAPSHOT
+
+ 4.0.0
+
+ quarkus-websockets-next-kotlin
+ Quarkus - WebSockets Next - Kotlin
+
+
+
+ io.quarkus
+ quarkus-arc
+
+
+ io.quarkus
+ quarkus-vertx
+
+
+ org.jetbrains.kotlin
+ kotlin-stdlib-jdk8
+ true
+
+
+ org.jetbrains.kotlinx
+ kotlinx-coroutines-jdk8
+ true
+
+
+ io.smallrye.reactive
+ mutiny-kotlin
+ true
+
+
+
+
+ ${project.basedir}/src/main/kotlin
+ ${project.basedir}/src/test/kotlin
+
+
+ org.jetbrains.kotlin
+ kotlin-maven-plugin
+ ${kotlin.version}
+
+
+ compile
+
+ compile
+
+
+
+ test-compile
+
+ test-compile
+
+
+
+
+ ${maven.compiler.target}
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+
+
+
+ io.quarkus
+ quarkus-extension-processor
+ ${project.version}
+
+
+
+
+
+
+ default-compile
+ none
+
+
+
+ default-testCompile
+ none
+
+
+ java-compile
+ compile
+
+ compile
+
+
+
+ java-test-compile
+ test-compile
+
+ testCompile
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-javadoc-plugin
+
+ false
+
+
+
+
+
diff --git a/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/ApplicationCoroutineScope.kt b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/ApplicationCoroutineScope.kt
new file mode 100644
index 00000000000000..78ee39b9247065
--- /dev/null
+++ b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/ApplicationCoroutineScope.kt
@@ -0,0 +1,18 @@
+package io.quarkus.websockets.next.runtime.kotlin
+
+import jakarta.annotation.PreDestroy
+import jakarta.inject.Singleton
+import kotlin.coroutines.CoroutineContext
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.SupervisorJob
+import kotlinx.coroutines.cancel
+
+@Singleton
+class ApplicationCoroutineScope : CoroutineScope, AutoCloseable {
+ override val coroutineContext: CoroutineContext = SupervisorJob()
+
+ @PreDestroy
+ override fun close() {
+ coroutineContext.cancel()
+ }
+}
diff --git a/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/CoroutineInvoker.kt b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/CoroutineInvoker.kt
new file mode 100644
index 00000000000000..5a6b6f8e28a772
--- /dev/null
+++ b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/CoroutineInvoker.kt
@@ -0,0 +1,37 @@
+package io.quarkus.websockets.next.runtime.kotlin
+
+import io.quarkus.arc.Arc
+import io.smallrye.mutiny.Uni
+import io.smallrye.mutiny.coroutines.asUni
+import io.vertx.core.Vertx
+import jakarta.enterprise.invoke.Invoker
+import kotlin.coroutines.resume
+import kotlin.coroutines.resumeWithException
+import kotlin.coroutines.suspendCoroutine
+import kotlinx.coroutines.CoroutineDispatcher
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.async
+
+object CoroutineInvoker {
+ @JvmStatic
+ @OptIn(ExperimentalCoroutinesApi::class)
+ fun inNewCoroutine(instance: T, arguments: Array, invoker: Invoker): Uni {
+ val coroutineScope = Arc.container().instance(ApplicationCoroutineScope::class.java).get()
+ val dispatcher: CoroutineDispatcher =
+ Vertx.currentContext()?.let(::VertxDispatcher)
+ ?: throw IllegalStateException("No Vertx context found")
+
+ return coroutineScope
+ .async(context = dispatcher) {
+ suspendCoroutine { continuation ->
+ arguments[arguments.size - 1] = continuation
+ try {
+ continuation.resume(invoker.invoke(instance, arguments))
+ } catch (e: Exception) {
+ continuation.resumeWithException(e)
+ }
+ }
+ }
+ .asUni()
+ }
+}
diff --git a/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/VertxDispatcher.kt b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/VertxDispatcher.kt
new file mode 100644
index 00000000000000..1377b1835fe6c2
--- /dev/null
+++ b/extensions/websockets-next/kotlin/src/main/kotlin/io/quarkus/websockets/next/runtime/kotlin/VertxDispatcher.kt
@@ -0,0 +1,24 @@
+package io.quarkus.websockets.next.runtime.kotlin
+
+import io.quarkus.arc.Arc
+import io.vertx.core.Context
+import kotlin.coroutines.CoroutineContext
+import kotlinx.coroutines.CoroutineDispatcher
+
+class VertxDispatcher(private val vertxContext: Context) : CoroutineDispatcher() {
+ override fun dispatch(context: CoroutineContext, block: Runnable) {
+ val requestContext = Arc.container().requestContext()
+ vertxContext.runOnContext {
+ if (requestContext.isActive) {
+ block.run()
+ } else {
+ try {
+ requestContext.activate()
+ block.run()
+ } finally {
+ requestContext.terminate()
+ }
+ }
+ }
+ }
+}
diff --git a/extensions/websockets-next/pom.xml b/extensions/websockets-next/pom.xml
index 5ff9318b3e7658..1e149d244734fa 100644
--- a/extensions/websockets-next/pom.xml
+++ b/extensions/websockets-next/pom.xml
@@ -16,6 +16,7 @@
deployment
runtime
+ kotlin
diff --git a/extensions/websockets-next/runtime/pom.xml b/extensions/websockets-next/runtime/pom.xml
index a72d7832edd42f..de16a0cfaacaa1 100644
--- a/extensions/websockets-next/runtime/pom.xml
+++ b/extensions/websockets-next/runtime/pom.xml
@@ -34,6 +34,10 @@
io.quarkus
quarkus-tls-registry
+
+ io.quarkus
+ quarkus-websockets-next-kotlin
+
io.quarkus.security
diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java
index cef74e824cf2d9..37ab47604ba2dd 100644
--- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java
+++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinDotNames.java
@@ -2,8 +2,9 @@
import org.jboss.jandex.DotName;
-class KotlinDotNames {
- static final DotName METADATA = DotName.createSimple("kotlin.Metadata");
+public final class KotlinDotNames {
+ public static final DotName METADATA = DotName.createSimple("kotlin.Metadata");
+ public static final DotName UNIT = DotName.createSimple("kotlin.Unit");
- static final DotName CONTINUATION = DotName.createSimple("kotlin.coroutines.Continuation");
+ public static final DotName CONTINUATION = DotName.createSimple("kotlin.coroutines.Continuation");
}
diff --git a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java
index 1628a83b3adefb..966a7fc2c57abf 100644
--- a/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java
+++ b/independent-projects/arc/processor/src/main/java/io/quarkus/arc/processor/KotlinUtils.java
@@ -4,6 +4,7 @@
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.MethodInfo;
+import org.jboss.jandex.MethodParameterInfo;
import org.jboss.jandex.Type;
public class KotlinUtils {
@@ -27,6 +28,10 @@ public static boolean isKotlinSuspendMethod(MethodInfo method) {
return KotlinDotNames.CONTINUATION.equals(lastParameter.name());
}
+ public static boolean isKotlinContinuationParameter(MethodParameterInfo parameter) {
+ return isKotlinSuspendMethod(parameter.method()) && KotlinDotNames.CONTINUATION.equals(parameter.type().name());
+ }
+
public static boolean isNoninterceptableKotlinMethod(MethodInfo method) {
// the Kotlin compiler generates somewhat streamlined bytecode when it determines
// that a `suspend` method cannot be overridden, and that bytecode doesn't work
@@ -36,4 +41,24 @@ public static boolean isNoninterceptableKotlinMethod(MethodInfo method) {
return isKotlinSuspendMethod(method)
&& (Modifier.isFinal(method.flags()) || Modifier.isFinal(method.declaringClass().flags()));
}
+
+ public static Type getKotlinSuspendMethodResult(MethodInfo method) {
+ if (!isKotlinSuspendMethod(method)) {
+ throw new IllegalArgumentException("Not a suspend function: " + method);
+ }
+
+ Type lastParameter = method.parameterType(method.parametersCount() - 1);
+ if (lastParameter.kind() != Type.Kind.PARAMETERIZED_TYPE) {
+ throw new IllegalArgumentException("Continuation parameter type not parameterized: " + lastParameter);
+ }
+ Type resultType = lastParameter.asParameterizedType().arguments().get(0);
+ if (resultType.kind() != Type.Kind.WILDCARD_TYPE) {
+ throw new IllegalArgumentException("Continuation parameter type argument not wildcard: " + resultType);
+ }
+ Type lowerBound = resultType.asWildcardType().superBound();
+ if (lowerBound == null) {
+ throw new IllegalArgumentException("Continuation parameter type argument without lower bound: " + resultType);
+ }
+ return lowerBound;
+ }
}