diff --git a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java index c367691098caa..4ed22a9985603 100644 --- a/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java +++ b/extensions/websockets-next/deployment/src/main/java/io/quarkus/websockets/next/deployment/WebSocketProcessor.java @@ -1065,7 +1065,8 @@ private static ResultHandle uniOnFailureDoOnError(ResultHandle endpointThis, Byt private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCreator method, Callback callback, GlobalErrorHandlersBuildItem globalErrorHandlers, WebSocketEndpointBuildItem endpoint, ResultHandle value) { - if (callback.acceptsBinaryMessage()) { + if (callback.acceptsBinaryMessage() + || isOnOpenWithBinaryReturnType(callback)) { // ---------------------- // === Binary message === // ---------------------- @@ -1119,7 +1120,7 @@ private static ResultHandle encodeMessage(ResultHandle endpointThis, BytecodeCre value, fun.getInstance()); } else { - // return sendBinary(buffer,broadcast); + // return sendBinary(encodeBuffer(b),broadcast); ResultHandle buffer = encodeBuffer(method, callback.returnType(), value, endpointThis, callback); return method.invokeVirtualMethod(MethodDescriptor.ofMethod(WebSocketEndpointBase.class, "sendBinary", Uni.class, Buffer.class, boolean.class), endpointThis, buffer, @@ -1407,4 +1408,16 @@ static boolean isByteArray(Type type) { static String methodToString(MethodInfo method) { return method.declaringClass().name() + "#" + method.name() + "()"; } + + private static boolean isOnOpenWithBinaryReturnType(Callback callback) { + if (callback.isOnOpen()) { + Type returnType = callback.returnType(); + if (callback.isReturnTypeUni() || callback.isReturnTypeMulti()) { + returnType = callback.returnType().asParameterizedType().arguments().get(0); + } + return WebSocketDotNames.BUFFER.equals(returnType.name()) + || (returnType.kind() == Kind.ARRAY && PrimitiveType.BYTE.equals(returnType.asArrayType().constituent())); + } + return false; + } } diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/onopenreturntypes/OnOpenReturnTypesTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/onopenreturntypes/OnOpenReturnTypesTest.java new file mode 100644 index 0000000000000..aa5ab58469ab9 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/onopenreturntypes/OnOpenReturnTypesTest.java @@ -0,0 +1,71 @@ +package io.quarkus.websockets.next.test.onopenreturntypes; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.quarkus.websockets.next.test.utils.WSClient.ReceiverMode; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; + +public class OnOpenReturnTypesTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(EndpointText.class, EndpointBinary.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("end-text") + URI endText; + + @TestHTTPResource("end-binary") + URI endBinary; + + @Test + void testReturnTypes() throws Exception { + try (WSClient textClient = WSClient.create(vertx, ReceiverMode.TEXT).connect(endText)) { + textClient.waitForMessages(1); + assertEquals("/end-text", textClient.getMessages().get(0).toString()); + } + try (WSClient binaryClient = WSClient.create(vertx, ReceiverMode.BINARY).connect(endBinary)) { + binaryClient.waitForMessages(1); + assertEquals("/end-binary", binaryClient.getMessages().get(0).toString()); + } + } + + @WebSocket(path = "/end-text") + public static class EndpointText { + + @OnOpen + String open(WebSocketConnection connection) { + return connection.handshakeRequest().path(); + } + + } + + @WebSocket(path = "/end-binary") + public static class EndpointBinary { + + @OnOpen + Buffer open(WebSocketConnection connection) { + return Buffer.buffer(connection.handshakeRequest().path()); + } + + } + +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java index 955eb9c1b315c..2f1974089db90 100644 --- a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/utils/WSClient.java @@ -17,16 +17,26 @@ public class WSClient implements AutoCloseable { + public static WSClient create(Vertx vertx) { + return new WSClient(vertx); + } + + public static WSClient create(Vertx vertx, ReceiverMode mode) { + return new WSClient(vertx, mode); + } + private final WebSocketClient client; private AtomicReference socket = new AtomicReference<>(); private List messages = new CopyOnWriteArrayList<>(); + private final ReceiverMode mode; - public WSClient(Vertx vertx) { + public WSClient(Vertx vertx, ReceiverMode mode) { this.client = vertx.createWebSocketClient(); + this.mode = mode; } - public static WSClient create(Vertx vertx) { - return new WSClient(vertx); + public WSClient(Vertx vertx) { + this(vertx, ReceiverMode.ALL); } public static URI toWS(URI uri, String path) { @@ -52,7 +62,19 @@ public WSClient connect(WebSocketConnectOptions options, URI url) { uri.append("?").append(url.getQuery()); } ClientWebSocket webSocket = client.webSocket(); - webSocket.handler(b -> messages.add(b)); + switch (mode) { + case ALL: + webSocket.handler(b -> messages.add(b)); + break; + case BINARY: + webSocket.binaryMessageHandler(b -> messages.add(b)); + break; + case TEXT: + webSocket.textMessageHandler(b -> messages.add(Buffer.buffer(b))); + break; + default: + throw new IllegalStateException(); + } await(webSocket.connect(options.setPort(url.getPort()).setHost(url.getHost()).setURI(uri.toString()))); var prev = socket.getAndSet(webSocket); if (prev != null) { @@ -135,4 +157,10 @@ public void close() { disconnect(); } + public enum ReceiverMode { + BINARY, + TEXT, + ALL + } + }