diff --git a/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketSupplierTest.java b/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketSupplierTest.java new file mode 100644 index 00000000000..f9ee326cb01 --- /dev/null +++ b/webserver/tests/websocket/src/test/java/io/helidon/webserver/tests/websocket/WebSocketSupplierTest.java @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2023 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.webserver.tests.websocket; + +import java.net.URI; +import java.net.http.HttpClient; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import io.helidon.webserver.Router; +import io.helidon.webserver.WebServer; +import io.helidon.webserver.testing.junit5.ServerTest; +import io.helidon.webserver.testing.junit5.SetUpRoute; +import io.helidon.webserver.websocket.WsRouting; +import io.helidon.websocket.WsCloseCodes; +import io.helidon.websocket.WsListener; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +/** + * Checks that a {@code WsListener} supplier is called exactly once per connection. + * In particular, that the same listener is shared between the connection upgrade + * and the connection handling phases. + */ +@ServerTest +class WebSocketSupplierTest { + + private final int port; + private final HttpClient client; + + private static final AtomicInteger supplierCalls = new AtomicInteger(); + + WebSocketSupplierTest(WebServer server) { + port = server.port(); + client = HttpClient.newBuilder() + .connectTimeout(Duration.ofSeconds(5)) + .build(); + } + + @SetUpRoute + static void router(Router.RouterBuilder router) { + Supplier supplier = () -> { + EchoService service = new EchoService(); + supplierCalls.getAndIncrement(); + return service; + }; + router.addRouting(WsRouting.builder().endpoint("/echo", supplier)); + } + + @Test + void testSingleSupplier() throws Exception { + java.net.http.WebSocket ws = client.newWebSocketBuilder() + .buildAsync(URI.create("ws://localhost:" + port + "/echo"), + new java.net.http.WebSocket.Listener() {}) + .get(5, TimeUnit.SECONDS); + ws.request(10); + ws.sendText("Hello", true).get(5, TimeUnit.SECONDS); + ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS); + + // enforce one listener per connection -- single call to supplier + assertThat(supplierCalls.get(), is(1)); + } +} diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java index 32d3f442d88..198d2ac4d0a 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsConnection.java @@ -70,12 +70,12 @@ private WsConnection(ConnectionContext ctx, HttpPrologue prologue, Headers upgradeHeaders, String wsKey, - WsRoute wsRoute) { + WsListener wsListener) { this.ctx = ctx; this.prologue = prologue; this.upgradeHeaders = upgradeHeaders; this.wsKey = wsKey; - this.listener = wsRoute.listener(); + this.listener = wsListener; this.dataReader = ctx.dataReader(); this.lastRequestTimestamp = DateTime.timestamp(); this.wsConfig = (WsConfig) ctx.listenerContext() @@ -94,15 +94,15 @@ private WsConnection(ConnectionContext ctx, * @param prologue prologue of this request * @param upgradeHeaders headers for * @param wsKey ws key - * @param wsRoute route to use + * @param wsListener a ws listener * @return a new connection */ public static WsConnection create(ConnectionContext ctx, HttpPrologue prologue, Headers upgradeHeaders, String wsKey, - WsRoute wsRoute) { - return new WsConnection(ctx, prologue, upgradeHeaders, wsKey, wsRoute); + WsListener wsListener) { + return new WsConnection(ctx, prologue, upgradeHeaders, wsKey, wsListener); } @Override diff --git a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgrader.java b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgrader.java index 476e351ecfc..71f21ca49a0 100644 --- a/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgrader.java +++ b/webserver/websocket/src/main/java/io/helidon/webserver/websocket/WsUpgrader.java @@ -38,6 +38,7 @@ import io.helidon.webserver.ConnectionContext; import io.helidon.webserver.http1.spi.Http1Upgrader; import io.helidon.webserver.spi.ServerConnection; +import io.helidon.websocket.WsListener; import io.helidon.websocket.WsUpgradeException; import static java.nio.charset.StandardCharsets.US_ASCII; @@ -167,8 +168,9 @@ public ServerConnection upgrade(ConnectionContext ctx, HttpPrologue prologue, Wr // invoke user-provided HTTP upgrade handler Optional upgradeHeaders; + WsListener wsListener = route.listener(); try { - upgradeHeaders = route.listener().onHttpUpgrade(prologue, headers); + upgradeHeaders = wsListener.onHttpUpgrade(prologue, headers); } catch (WsUpgradeException e) { LOGGER.log(Level.TRACE, "Websocket upgrade rejected", e); return null; @@ -191,7 +193,7 @@ public ServerConnection upgrade(ConnectionContext ctx, HttpPrologue prologue, Wr LOGGER.log(Level.TRACE, "Upgraded to websocket version " + version); } - return WsConnection.create(ctx, prologue, upgradeHeaders.orElse(EMPTY_HEADERS), wsKey, route); + return WsConnection.create(ctx, prologue, upgradeHeaders.orElse(EMPTY_HEADERS), wsKey, wsListener); } protected boolean anyOrigin() {