Skip to content

Commit

Permalink
Merge pull request #4588 from eclipse/jetty-10.0.x-4538-MessageReader…
Browse files Browse the repository at this point in the history
…Writer

Issue #4538 - rework of websocket message reader and writers
  • Loading branch information
lachlan-roberts authored Mar 11, 2020
2 parents b0ddba4 + fef25e7 commit b1d30fc
Show file tree
Hide file tree
Showing 16 changed files with 703 additions and 509 deletions.
16 changes: 16 additions & 0 deletions jetty-util/src/main/java/org/eclipse/jetty/util/StringUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,22 @@ public static void append2digits(StringBuilder buf, int i)
}
}

/**
* Generate a string from another string repeated n times.
*
* @param s the string to use
* @param n the number of times this string should be appended
*/
public static String stringFrom(String s, int n)
{
StringBuilder stringBuilder = new StringBuilder(s.length() * n);
for (int i = 0; i < n; i++)
{
stringBuilder.append(s);
}
return stringBuilder.toString();
}

/**
* Return a non null string.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
Expand All @@ -37,15 +36,13 @@
import javax.websocket.WebSocketContainer;

import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.websocket.core.FrameHandler;
import org.eclipse.jetty.websocket.core.MessageHandler;
import org.eclipse.jetty.websocket.core.server.Negotiation;
import org.eclipse.jetty.websocket.core.server.WebSocketNegotiator;
import org.eclipse.jetty.websocket.javax.tests.CoreServer;
import org.eclipse.jetty.websocket.javax.tests.WSEventTracker;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand All @@ -58,21 +55,14 @@ public class DecoderReaderManySmallTest
@BeforeEach
public void setUp() throws Exception
{
server = new CoreServer(new CoreServer.BaseNegotiator()
server = new CoreServer(WebSocketNegotiator.from((negotiation) ->
{
@Override
public FrameHandler negotiate(Negotiation negotiation) throws IOException
{
List<String> offeredSubProtocols = negotiation.getOfferedSubprotocols();
List<String> offeredSubProtocols = negotiation.getOfferedSubprotocols();
if (!offeredSubProtocols.isEmpty())
negotiation.setSubprotocol(offeredSubProtocols.get(0));

if (!offeredSubProtocols.isEmpty())
{
negotiation.setSubprotocol(offeredSubProtocols.get(0));
}

return new EventIdFrameHandler();
}
});
return new EventIdFrameHandler();
}));
server.start();

client = ContainerProvider.getWebSocketContainer();
Expand All @@ -86,15 +76,13 @@ public void tearDown() throws Exception
}

@Test
public void testManyIds(TestInfo testInfo) throws Exception
public void testManyIds() throws Exception
{
URI wsUri = server.getWsUri().resolve("/eventids");
EventIdSocket clientSocket = new EventIdSocket(testInfo.getTestMethod().toString());

final int from = 1000;
final int to = 2000;

try (Session clientSession = client.connectToServer(clientSocket, wsUri))
EventIdSocket clientSocket = new EventIdSocket();
try (Session clientSession = client.connectToServer(clientSocket, server.getWsUri()))
{
clientSession.getAsyncRemote().sendText("seq|" + from + "|" + to);
}
Expand Down Expand Up @@ -154,12 +142,6 @@ public static class EventIdSocket extends WSEventTracker
{
public BlockingQueue<EventId> messageQueue = new LinkedBlockingDeque<>();

public EventIdSocket(String id)
{
super(id);
}

@SuppressWarnings("unused")
@OnMessage
public void onMessage(EventId msg)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,47 +20,76 @@

import java.io.IOException;
import java.io.Reader;
import java.io.StringWriter;
import java.io.Writer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.websocket.ClientEndpointConfig;
import javax.websocket.ContainerProvider;
import javax.websocket.EndpointConfig;
import javax.websocket.MessageHandler;
import javax.websocket.OnMessage;
import javax.websocket.Session;
import javax.websocket.WebSocketContainer;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;

import org.eclipse.jetty.util.BlockingArrayQueue;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.core.CloseStatus;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.javax.common.JavaxWebSocketSession;
import org.eclipse.jetty.websocket.javax.tests.DataUtils;
import org.eclipse.jetty.websocket.javax.tests.Fuzzer;
import org.eclipse.jetty.websocket.javax.tests.LocalServer;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.eclipse.jetty.websocket.javax.tests.WSEndpointTracker;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class TextStreamTest
{
private static final Logger LOG = Log.getLogger(TextStreamTest.class);
private static final BlockingArrayQueue<QueuedTextStreamer> serverEndpoints = new BlockingArrayQueue<>();

private static LocalServer server;
private static ServerContainer container;
private final ClientEndpointConfig clientConfig = ClientEndpointConfig.Builder.create().build();
private LocalServer server;
private ServerContainer container;
private WebSocketContainer wsClient;

@BeforeAll
public static void startServer() throws Exception
@BeforeEach
public void startServer() throws Exception
{
server = new LocalServer();
server.start();
container = server.getServerContainer();
container.addEndpoint(ServerTextStreamer.class);
container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedTextStreamer.class, "/test").build());
container.addEndpoint(ServerEndpointConfig.Builder.create(QueuedPartialTextStreamer.class, "/partial").build());

wsClient = ContainerProvider.getWebSocketContainer();
}

@AfterAll
public static void stopServer() throws Exception
@AfterEach
public void stopServer() throws Exception
{
server.stop();
}
Expand Down Expand Up @@ -145,6 +174,121 @@ public void testLargerThenMaxDefaultMessageBufferSize() throws Exception
}
}

@Test
public void testMessageOrdering() throws Exception
{
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, clientConfig, server.getWsUri().resolve("/test"));

final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText(Integer.toString(i));
}
session.close();

QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is(Integer.toString(i)));
}
}

@Test
public void testFragmentedMessageOrdering() throws Exception
{
ClientTextStreamer client = new ClientTextStreamer();
Session session = wsClient.connectToServer(client, clientConfig, server.getWsUri().resolve("/test"));

final int numLoops = 20;
for (int i = 0; i < numLoops; i++)
{
session.getBasicRemote().sendText("firstFrame" + i, false);
session.getBasicRemote().sendText("|secondFrame" + i, false);
session.getBasicRemote().sendText("|finalFrame" + i, true);
}
session.close();

QueuedTextStreamer queuedTextStreamer = serverEndpoints.poll(5, TimeUnit.SECONDS);
assertNotNull(queuedTextStreamer);
for (int i = 0; i < numLoops; i++)
{
String msg = queuedTextStreamer.messages.poll(5, TimeUnit.SECONDS);
String expected = "firstFrame" + i + "|secondFrame" + i + "|finalFrame" + i;
assertThat(msg, Matchers.is(expected));
}
}

@Test
public void testMessageOrderingDoNotReadToEOF() throws Exception
{
ClientTextStreamer clientEndpoint = new ClientTextStreamer();
Session session = wsClient.connectToServer(clientEndpoint, clientConfig, server.getWsUri().resolve("/partial"));
QueuedTextStreamer serverEndpoint = Objects.requireNonNull(serverEndpoints.poll(5, TimeUnit.SECONDS));

int serverInputBufferSize = 1024;
JavaxWebSocketSession serverSession = (JavaxWebSocketSession)serverEndpoint.session;
serverSession.getCoreSession().setInputBufferSize(serverInputBufferSize);

// Write some initial data.
Writer writer = session.getBasicRemote().getSendWriter();
writer.write("first frame");
writer.flush();

// Signal to stop reading.
writer.write("|");
writer.flush();

// Lots of data after we have stopped reading and onMessage exits.
final String largePayload = StringUtil.stringFrom("x", serverInputBufferSize * 2);
writer.write(largePayload);
writer.close();

session.close();
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertTrue(serverEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertNull(clientEndpoint.error.get());
assertNull(serverEndpoint.error.get());

String msg = serverEndpoint.messages.poll(5, TimeUnit.SECONDS);
assertThat(msg, Matchers.is("first frame"));
}

public static class ClientTextStreamer extends WSEndpointTracker implements MessageHandler.Whole<Reader>
{
private final CountDownLatch latch = new CountDownLatch(1);
private final StringBuilder output = new StringBuilder();

@Override
public void onOpen(Session session, EndpointConfig config)
{
session.addMessageHandler(this);
super.onOpen(session, config);
}

@Override
public void onMessage(Reader input)
{
try
{
while (true)
{
int read = input.read();
if (read < 0)
break;
output.append((char)read);
}
latch.countDown();
}
catch (IOException e)
{
throw new RuntimeException(e);
}
}
}

@ServerEndpoint("/echo")
public static class ServerTextStreamer
{
Expand All @@ -166,4 +310,59 @@ public void echo(Session session, Reader input) throws IOException
}
}
}

public static class QueuedTextStreamer extends WSEndpointTracker implements MessageHandler.Whole<Reader>
{
protected BlockingArrayQueue<String> messages = new BlockingArrayQueue<>();

@Override
public void onOpen(Session session, EndpointConfig config)
{
session.addMessageHandler(this);
super.onOpen(session, config);
serverEndpoints.add(this);
}

@Override
public void onMessage(Reader input)
{
try
{
Thread.sleep(Math.abs(new Random().nextLong() % 200));
messages.add(IO.toString(input));
}
catch (Exception e)
{
e.printStackTrace();
}
}
}

public static class QueuedPartialTextStreamer extends QueuedTextStreamer
{
@Override
public void onMessage(Reader input)
{
try
{
Thread.sleep(Math.abs(new Random().nextLong() % 200));

// Do not read to EOF but just the first '|'.
StringWriter writer = new StringWriter();
while (true)
{
int read = input.read();
if (read < 0 || read == '|')
break;
writer.write(read);
}

messages.add(writer.toString());
}
catch (Exception e)
{
e.printStackTrace();
}
}
}
}
Loading

0 comments on commit b1d30fc

Please sign in to comment.