Skip to content

Commit

Permalink
Merge pull request #40207 from mkouba/issue-39862
Browse files Browse the repository at this point in the history
WebSockets Next: send ping message from the server automatically
  • Loading branch information
cescoffier authored Apr 24, 2024
2 parents cfcd5a3 + 8357938 commit 323774b
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketServerException;
import io.quarkus.websockets.next.WebSocketsRuntimeConfig;
import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig;
import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback;
import io.quarkus.websockets.next.deployment.WebSocketEndpointBuildItem.Callback.MessageType;
import io.quarkus.websockets.next.runtime.Codecs;
Expand Down Expand Up @@ -383,7 +383,7 @@ private void validateOnPongMessage(Callback callback) {
"@OnPongMessage callback must return void or Uni<Void>: " + callbackToString(callback.method));
}
Type messageType = callback.argumentType(MessageCallbackArgument::isMessage);
if (!messageType.name().equals(WebSocketDotNames.BUFFER)) {
if (messageType == null || !messageType.name().equals(WebSocketDotNames.BUFFER)) {
throw new WebSocketServerException(
"@OnPongMessage callback must accept exactly one message parameter of type io.vertx.core.buffer.Buffer: "
+ callbackToString(callback.method));
Expand Down Expand Up @@ -478,10 +478,10 @@ private String generateEndpoint(WebSocketEndpointBuildItem endpoint,
.build();

MethodCreator constructor = endpointCreator.getConstructorCreator(WebSocketConnection.class,
Codecs.class, WebSocketsRuntimeConfig.class, ContextSupport.class);
Codecs.class, WebSocketsServerRuntimeConfig.class, ContextSupport.class);
constructor.invokeSpecialMethod(
MethodDescriptor.ofConstructor(WebSocketEndpointBase.class, WebSocketConnection.class,
Codecs.class, WebSocketsRuntimeConfig.class, ContextSupport.class),
Codecs.class, WebSocketsServerRuntimeConfig.class, ContextSupport.class),
constructor.getThis(), constructor.getMethodParam(0), constructor.getMethodParam(1),
constructor.getMethodParam(2), constructor.getMethodParam(3));
constructor.returnNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class MaxMessageSizeTest {
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Echo.class, WSClient.class);
}).overrideConfigKey("quarkus.websockets-next.max-message-size", "10");
}).overrideConfigKey("quarkus.websockets-next.server.max-message-size", "10");

@Inject
Vertx vertx;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package io.quarkus.websockets.next.test.pingpong;

import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnPongMessage;
import io.quarkus.websockets.next.WebSocket;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.WebSocketClient;

public class AutoPingIntervalTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class);
}).overrideConfigKey("quarkus.websockets-next.server.auto-ping-interval", "200ms");

@Inject
Vertx vertx;

@TestHTTPResource("end")
URI endUri;

@Test
public void testPingPong() throws InterruptedException, ExecutionException {
WebSocketClient client = vertx.createWebSocketClient();
try {
CountDownLatch connectedLatch = new CountDownLatch(1);
client
.connect(endUri.getPort(), endUri.getHost(), endUri.getPath())
.onComplete(r -> {
if (r.succeeded()) {
connectedLatch.countDown();
} else {
throw new IllegalStateException(r.cause());
}
});
assertTrue(connectedLatch.await(5, TimeUnit.SECONDS));
// The pong message should be sent by the client automatically and should be identical to the ping message
assertTrue(Endpoint.PONG.await(5, TimeUnit.SECONDS));
} finally {
client.close().toCompletionStage().toCompletableFuture().get();
}
}

@WebSocket(path = "/end")
public static class Endpoint {

static final CountDownLatch PONG = new CountDownLatch(3);

@OnOpen
public String open() {
return "ok";
}

@OnPongMessage
void pong(Buffer data) {
PONG.countDown();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class SubprotocolSelectedTest {
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class);
}).overrideConfigKey("quarkus.websockets-next.supported-subprotocols", "oak,larch");
}).overrideConfigKey("quarkus.websockets-next.server.supported-subprotocols", "oak,larch");

@Inject
Vertx vertx;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.quarkus.websockets.next;

import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.OptionalInt;
Expand All @@ -10,9 +11,9 @@
import io.smallrye.config.WithDefault;
import io.vertx.core.http.HttpServerOptions;

@ConfigMapping(prefix = "quarkus.websockets-next")
@ConfigMapping(prefix = "quarkus.websockets-next.server")
@ConfigRoot(phase = ConfigPhase.RUN_TIME)
public interface WebSocketsRuntimeConfig {
public interface WebSocketsServerRuntimeConfig {

/**
* See <a href="https://datatracker.ietf.org/doc/html/rfc6455#page-12">The WebSocket Protocol</a>
Expand All @@ -39,4 +40,11 @@ public interface WebSocketsRuntimeConfig {
*/
OptionalInt maxMessageSize();

/**
* The interval after which, when set, the server sends a ping message to a connected client automatically.
* <p>
* Ping messages are not sent automatically by default.
*/
Optional<Duration> autoPingInterval();

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ void add(String endpoint, WebSocketConnection connection) {
if (endpointToConnections.computeIfAbsent(endpoint, e -> ConcurrentHashMap.newKeySet()).add(connection)) {
if (!listeners.isEmpty()) {
for (ConnectionListener listener : listeners) {
listener.connectionAdded(endpoint, connection);
try {
listener.connectionAdded(endpoint, connection);
} catch (Exception e) {
LOG.warnf("Unable to call listener#connectionAdded() on [%s]: %s", listener.getClass(),
e.toString());
}
}
}
}
Expand All @@ -53,7 +58,12 @@ void remove(String endpoint, WebSocketConnection connection) {
if (connections.remove(connection)) {
if (!listeners.isEmpty()) {
for (ConnectionListener listener : listeners) {
listener.connectionRemoved(endpoint, connection.id());
try {
listener.connectionRemoved(endpoint, connection.id());
} catch (Exception e) {
LOG.warnf("Unable to call listener#connectionRemoved() on [%s]: %s", listener.getClass(),
e.toString());
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.jboss.logging.Logger;

import io.quarkus.vertx.core.runtime.VertxBufferImpl;
import io.quarkus.websockets.next.WebSocketConnection;
import io.smallrye.mutiny.Uni;
Expand All @@ -26,6 +28,8 @@

class WebSocketConnectionImpl implements WebSocketConnection {

private static final Logger LOG = Logger.getLogger(WebSocketConnectionImpl.class);

private final String generatedEndpointClass;

private final String endpointId;
Expand Down Expand Up @@ -106,6 +110,14 @@ public Uni<Void> sendPing(Buffer data) {
return UniHelper.toUni(webSocket.writePing(data));
}

void sendAutoPing() {
webSocket.writePing(Buffer.buffer("ping")).onComplete(r -> {
if (r.failed()) {
LOG.warnf("Unable to send auto-ping for %s: %s", this, r.cause().toString());
}
});
}

@Override
public Uni<Void> sendPong(Buffer data) {
return UniHelper.toUni(webSocket.writePong(data));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import io.quarkus.virtual.threads.VirtualThreadsRecorder;
import io.quarkus.websockets.next.WebSocket.ExecutionMode;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketsRuntimeConfig;
import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig;
import io.quarkus.websockets.next.runtime.ConcurrencyLimiter.PromiseComplete;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.Uni;
Expand All @@ -41,7 +41,7 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint {
private final ConcurrencyLimiter limiter;

@SuppressWarnings("unused")
private final WebSocketsRuntimeConfig config;
private final WebSocketsServerRuntimeConfig config;

private final ArcContainer container;

Expand All @@ -51,7 +51,7 @@ public abstract class WebSocketEndpointBase implements WebSocketEndpoint {
private final Object beanInstance;

public WebSocketEndpointBase(WebSocketConnection connection, Codecs codecs,
WebSocketsRuntimeConfig config, ContextSupport contextSupport) {
WebSocketsServerRuntimeConfig config, ContextSupport contextSupport) {
this.connection = connection;
this.codecs = codecs;
this.limiter = executionMode() == ExecutionMode.SERIAL ? new ConcurrencyLimiter(connection) : null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import jakarta.inject.Inject;

import io.quarkus.vertx.http.HttpServerOptionsCustomizer;
import io.quarkus.websockets.next.WebSocketsRuntimeConfig;
import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig;
import io.vertx.core.http.HttpServerOptions;

@Dependent
public class WebSocketHttpServerOptionsCustomizer implements HttpServerOptionsCustomizer {

@Inject
WebSocketsRuntimeConfig config;
WebSocketsServerRuntimeConfig config;

@Override
public void customizeHttpServer(HttpServerOptions options) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import io.quarkus.vertx.core.runtime.VertxCoreRecorder;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.WebSocketServerException;
import io.quarkus.websockets.next.WebSocketsRuntimeConfig;
import io.quarkus.websockets.next.WebSocketsServerRuntimeConfig;
import io.quarkus.websockets.next.runtime.WebSocketSessionContext.SessionContextState;
import io.smallrye.common.vertx.VertxContext;
import io.smallrye.mutiny.Multi;
Expand All @@ -34,9 +34,9 @@ public class WebSocketServerRecorder {

static final String WEB_SOCKET_CONN_KEY = WebSocketConnection.class.getName();

private final WebSocketsRuntimeConfig config;
private final WebSocketsServerRuntimeConfig config;

public WebSocketServerRecorder(WebSocketsRuntimeConfig config) {
public WebSocketServerRecorder(WebSocketsServerRuntimeConfig config) {
this.config = config;
}

Expand Down Expand Up @@ -67,12 +67,13 @@ public Handler<RoutingContext> createEndpointHandler(String generatedEndpointCla
public void handle(RoutingContext ctx) {
Future<ServerWebSocket> future = ctx.request().toWebSocket();
future.onSuccess(ws -> {
Context context = VertxCoreRecorder.getVertx().get().getOrCreateContext();
Vertx vertx = VertxCoreRecorder.getVertx().get();
Context context = vertx.getOrCreateContext();

WebSocketConnection connection = new WebSocketConnectionImpl(generatedEndpointClass, endpointId, ws,
WebSocketConnectionImpl connection = new WebSocketConnectionImpl(generatedEndpointClass, endpointId, ws,
connectionManager, codecs, ctx);
connectionManager.add(generatedEndpointClass, connection);
LOG.debugf("Connnected: %s", connection);
LOG.debugf("Connection created: %s", connection);

// Initialize and capture the session context state that will be activated
// during message processing
Expand Down Expand Up @@ -216,6 +217,18 @@ public void handle(Void event) {
});
});

Long timerId;
if (config.autoPingInterval().isPresent()) {
timerId = vertx.setPeriodic(config.autoPingInterval().get().toMillis(), new Handler<Long>() {
@Override
public void handle(Long timerId) {
connection.sendAutoPing();
}
});
} else {
timerId = null;
}

ws.closeHandler(new Handler<Void>() {
@Override
public void handle(Void event) {
Expand All @@ -229,6 +242,9 @@ public void handle(Void event) {
LOG.errorf(r.cause(), "Unable to complete @OnClose callback: %s", connection);
}
connectionManager.remove(generatedEndpointClass, connection);
if (timerId != null) {
vertx.cancelTimer(timerId);
}
});
}
});
Expand All @@ -249,6 +265,7 @@ public void handle(Void event) {
});
}
});

});
}
};
Expand Down Expand Up @@ -307,7 +324,7 @@ public void handle(Void event) {
}

private WebSocketEndpoint createEndpoint(String endpointClassName, Context context, WebSocketConnection connection,
Codecs codecs, WebSocketsRuntimeConfig config, ContextSupport contextSupport) {
Codecs codecs, WebSocketsServerRuntimeConfig config, ContextSupport contextSupport) {
try {
ClassLoader cl = Thread.currentThread().getContextClassLoader();
if (cl == null) {
Expand All @@ -318,7 +335,7 @@ private WebSocketEndpoint createEndpoint(String endpointClassName, Context conte
.loadClass(endpointClassName);
WebSocketEndpoint endpoint = (WebSocketEndpoint) endpointClazz
.getDeclaredConstructor(WebSocketConnection.class, Codecs.class,
WebSocketsRuntimeConfig.class, ContextSupport.class)
WebSocketsServerRuntimeConfig.class, ContextSupport.class)
.newInstance(connection, codecs, config, contextSupport);
return endpoint;
} catch (Exception e) {
Expand Down

0 comments on commit 323774b

Please sign in to comment.