Skip to content

Commit

Permalink
Make sure a WsListener supplier is called exactly once per connection…
Browse files Browse the repository at this point in the history
…. In particular, that the same WsListener is shared between the upgrade and connection phases. See issue helidon-io#8039.
  • Loading branch information
spericas committed Dec 8, 2023
1 parent 90a0cc4 commit 4911ab9
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ private static DataReader reader(ArrayBlockingQueue<byte[]> queue) {

void start() {
if (serverStarted.compareAndSet(false, true)) {
WsConnection serverConnection = WsConnection.create(ctx, prologue, WritableHeaders.create(), "", serverRoute);
WsConnection serverConnection = WsConnection.create(ctx, prologue, WritableHeaders.create(), "", serverRoute.listener());

ClientWsConnection clientConnection = ClientWsConnection.create(new DirectConnect(clientReader, clientWriter),
clientListener);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (c) 2023 Oracle and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.helidon.webserver.tests.websocket;

import java.net.URI;
import java.net.http.HttpClient;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

import io.helidon.webserver.Router;
import io.helidon.webserver.WebServer;
import io.helidon.webserver.testing.junit5.ServerTest;
import io.helidon.webserver.testing.junit5.SetUpRoute;
import io.helidon.webserver.websocket.WsRouting;
import io.helidon.websocket.WsCloseCodes;
import io.helidon.websocket.WsListener;
import org.junit.jupiter.api.Test;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;

/**
* Checks that a {@code WsListener} supplier is called exactly once per connection.
* In particular, that the same listener is shared between the connection upgrade
* and the connection handling phases.
*/
@ServerTest
class WebSocketSupplierTest {

private final int port;
private final HttpClient client;

private static final AtomicInteger supplierCalls = new AtomicInteger();

WebSocketSupplierTest(WebServer server) {
port = server.port();
client = HttpClient.newBuilder()
.connectTimeout(Duration.ofSeconds(5))
.build();
}

@SetUpRoute
static void router(Router.RouterBuilder<?> router) {
Supplier<WsListener> supplier = () -> {
EchoService service = new EchoService();
supplierCalls.getAndIncrement();
return service;
};
router.addRouting(WsRouting.builder().endpoint("/echo", supplier));
}

@Test
void testSingleSupplier() throws Exception {
java.net.http.WebSocket ws = client.newWebSocketBuilder()
.buildAsync(URI.create("ws://localhost:" + port + "/echo"),
new java.net.http.WebSocket.Listener() {})
.get(5, TimeUnit.SECONDS);
ws.request(10);
ws.sendText("Hello", true).get(5, TimeUnit.SECONDS);
ws.sendClose(WsCloseCodes.NORMAL_CLOSE, "normal").get(5, TimeUnit.SECONDS);

// enforce one listener per connection -- single call to supplier
assertThat(supplierCalls.get(), is(1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ private WsConnection(ConnectionContext ctx,
HttpPrologue prologue,
Headers upgradeHeaders,
String wsKey,
WsRoute wsRoute) {
WsListener wsListener) {
this.ctx = ctx;
this.prologue = prologue;
this.upgradeHeaders = upgradeHeaders;
this.wsKey = wsKey;
this.listener = wsRoute.listener();
this.listener = wsListener;
this.dataReader = ctx.dataReader();
this.lastRequestTimestamp = DateTime.timestamp();
this.wsConfig = (WsConfig) ctx.listenerContext()
Expand All @@ -94,15 +94,15 @@ private WsConnection(ConnectionContext ctx,
* @param prologue prologue of this request
* @param upgradeHeaders headers for
* @param wsKey ws key
* @param wsRoute route to use
* @param wsListener a ws listener
* @return a new connection
*/
public static WsConnection create(ConnectionContext ctx,
HttpPrologue prologue,
Headers upgradeHeaders,
String wsKey,
WsRoute wsRoute) {
return new WsConnection(ctx, prologue, upgradeHeaders, wsKey, wsRoute);
WsListener wsListener) {
return new WsConnection(ctx, prologue, upgradeHeaders, wsKey, wsListener);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import io.helidon.webserver.ConnectionContext;
import io.helidon.webserver.http1.spi.Http1Upgrader;
import io.helidon.webserver.spi.ServerConnection;
import io.helidon.websocket.WsListener;
import io.helidon.websocket.WsUpgradeException;

import static java.nio.charset.StandardCharsets.US_ASCII;
Expand Down Expand Up @@ -167,8 +168,9 @@ public ServerConnection upgrade(ConnectionContext ctx, HttpPrologue prologue, Wr

// invoke user-provided HTTP upgrade handler
Optional<Headers> upgradeHeaders;
WsListener wsListener = route.listener();
try {
upgradeHeaders = route.listener().onHttpUpgrade(prologue, headers);
upgradeHeaders = wsListener.onHttpUpgrade(prologue, headers);
} catch (WsUpgradeException e) {
LOGGER.log(Level.TRACE, "Websocket upgrade rejected", e);
return null;
Expand All @@ -191,7 +193,7 @@ public ServerConnection upgrade(ConnectionContext ctx, HttpPrologue prologue, Wr
LOGGER.log(Level.TRACE, "Upgraded to websocket version " + version);
}

return WsConnection.create(ctx, prologue, upgradeHeaders.orElse(EMPTY_HEADERS), wsKey, route);
return WsConnection.create(ctx, prologue, upgradeHeaders.orElse(EMPTY_HEADERS), wsKey, wsListener);
}

protected boolean anyOrigin() {
Expand Down

0 comments on commit 4911ab9

Please sign in to comment.