Skip to content

Commit

Permalink
Add WebSocketService (#3904)
Browse files Browse the repository at this point in the history
Motivation:
It would be nice if we can provide a service that supports [The
WebSocket Protocol](https://datatracker.ietf.org/doc/html/rfc6455)

Modifications:
- Add `WebSocketService` and `WebSocketHandler` to implement the
WebSocket service using back pressure.
- Add `WebSocketFrame` that represents the web socket frames.
- Add `RequestLogBuilder.responseCause(cause)` to set the cause while
sending a normal response.
- Forked `HttpServerCodec` from netty to support WebSocket upgrade.
- Forked WebSocker encoder and decoder from Netty to use it without
channels.

Result:
- You can now use `WebSocketService` to send WebSocket messages.
  ```
  ServerBuilder sb = ...
  WebSocketHandler backpressureHandler = (ctx, messages) -> { 
      WebSocketWriter webSocketWriter = WebSocket.streaming();
      // Write frames using back pressure.
      return webSocketWriter;
  };
  sb.service("/chat", WebSocketService.of(backpressureHandler));
  ```

Todos:
- Provide another abstract class for the `WebSocketHandler` so that a
user can implement WebSocket easily without considering back pressure.
(next PR)
- Support WebSocketClient. (next PR)
- Support extensions. (next PR)
  • Loading branch information
minwoox authored May 18, 2023
1 parent e15c7b7 commit b78d951
Show file tree
Hide file tree
Showing 76 changed files with 5,432 additions and 714 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,24 +63,25 @@ public class RoutersBenchmark {
new ServiceConfig(route1, route1,
SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0,
false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(),
SuccessFunction.always(), multipartUploadsLocation, ImmutableList.of(),
SuccessFunction.always(), 0, multipartUploadsLocation, ImmutableList.of(),
HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler),
new ServiceConfig(route2, route2,
SERVICE, defaultLogName, defaultServiceName, defaultServiceNaming, 0, 0,
false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(),
SuccessFunction.always(), multipartUploadsLocation, ImmutableList.of(),
SuccessFunction.always(), 0, multipartUploadsLocation, ImmutableList.of(),
HttpHeaders.of(), ctx -> RequestId.random(), serviceErrorHandler));
FALLBACK_SERVICE = new ServiceConfig(Route.ofCatchAll(), Route.ofCatchAll(), SERVICE,
defaultLogName, defaultServiceName,
defaultServiceNaming, 0, 0, false, AccessLogWriter.disabled(),
CommonPools.blockingTaskExecutor(),
SuccessFunction.always(), multipartUploadsLocation,
SuccessFunction.always(), 0, multipartUploadsLocation,
ImmutableList.of(), HttpHeaders.of(), ctx -> RequestId.random(),
serviceErrorHandler);
HOST = new VirtualHost(
"localhost", "localhost", 0, null, SERVICES, FALLBACK_SERVICE, RejectedRouteHandler.DISABLED,
unused -> NOPLogger.NOP_LOGGER, defaultServiceNaming, 0, 0, false,
AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), multipartUploadsLocation,
AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), 0,
multipartUploadsLocation,
ImmutableList.of(),
ctx -> RequestId.random());
ROUTER = Routers.ofVirtualHost(HOST, SERVICES, RejectedRouteHandler.DISABLED);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ public ClientFactoryBuilder proxyConfig(ProxyConfigSelector proxyConfigSelector)
/**
* Sets the {@link Http1HeaderNaming} which converts a lower-cased HTTP/2 header name into
* another HTTP/1 header name. This is useful when communicating with a legacy system that only supports
* case sensitive HTTP/1 headers.
* case-sensitive HTTP/1 headers.
*/
public ClientFactoryBuilder http1HeaderNaming(Http1HeaderNaming http1HeaderNaming) {
requireNonNull(http1HeaderNaming, "http1HeaderNaming");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ public ProxyConfigSelector proxyConfigSelector() {
/**
* Returns the {@link Http1HeaderNaming} which converts a lower-cased HTTP/2 header name into
* another header name. This is useful when communicating with a legacy system that only supports
* case sensitive HTTP/1 headers.
* case-sensitive HTTP/1 headers.
*/
public Http1HeaderNaming http1HeaderNaming() {
return get(HTTP1_HEADER_NAMING);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ private HttpObject convertHeaders(RequestHeaders headers, boolean endStream) {
return req;
}

@Override
protected ChannelFuture write(HttpObject obj, ChannelPromise promise) {
return channel().write(obj, promise);
}

@Override
protected void convertTrailers(HttpHeaders inputHeaders,
io.netty.handler.codec.http.HttpHeaders outputHeaders) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,31 @@
package com.linecorp.armeria.common;

import com.linecorp.armeria.common.annotation.UnstableApi;
import com.linecorp.armeria.common.websocket.WebSocketFrame;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;

/**
* Specifies the way a {@link ByteBuf} is retrieved from an {@link HttpData}.
* Specifies the way a {@link ByteBuf} is retrieved from a {@link Bytes}, such as {@link HttpData} or
* {@link WebSocketFrame}.
*/
@UnstableApi
public enum ByteBufAccessMode {
/**
* Gets the duplicate (or slice) of the underlying {@link ByteBuf}. This mode is useful when you access
* the {@link ByteBuf} within the life cycle of the {@link HttpData}:
* the {@link ByteBuf} within the life cycle of the {@link Bytes}, such as {@link HttpData} or
* {@link WebSocketFrame}:
* <pre>{@code
* try (HttpData content = ...) {
* ByteBuf buf = content.byteBuf(ByteBufAccessMode.DUPLICATE);
* // Read something from 'buf' here.
* }
* // WebSocket frame.
* try (WebSocketFrame frame = ...) {
* ByteBuf buf = frame.byteBuf(ByteBufAccessMode.DUPLICATE);
* // Read something from 'buf' here.
* }
* }</pre>
*
* @see ByteBuf#duplicate()
Expand All @@ -41,11 +49,16 @@ public enum ByteBufAccessMode {
DUPLICATE,
/**
* Gets the retained duplicate (or slice) of the underlying {@link ByteBuf}. This mode is useful when
* you access the {@link ByteBuf} beyond the life cycle of the {@link HttpData}, such as creating
* another {@link HttpData} that shares the {@link ByteBuf}'s memory region:
* you access the {@link ByteBuf} beyond the life cycle of the {@link Bytes}, such as {@link HttpData} or
* {@link WebSocketFrame}, by creating another {@link Bytes} that shares the {@link ByteBuf}'s
* memory region:
* <pre>{@code
* HttpData data1 = HttpData.wrap(byteBuf);
* HttpData data2 = HttpData.wrap(data1.byteBuf(ByteBufAccessMode.RETAINED_DUPLICATE));
*
* WebSocketFrame binaryFrame1 = WebSocketFrame.ofPooledBinary(byteBuf);
* WebSocketFrame binaryFrame2 = WebSocketFrame.ofPooledBinary(
* binaryFrame1.byteBuf(ByteBufAccessMode.RETAINED_DUPLICATE));
* }</pre>
*
* @see ByteBuf#retainedDuplicate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
/**
* A {@link RuntimeException} raised when a remote peer violated the current {@link SessionProtocol}.
*/
public final class ProtocolViolationException extends RuntimeException {
public class ProtocolViolationException extends RuntimeException {

private static final long serialVersionUID = 4674394621849790490L;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,22 @@ public boolean isHttps() {
return HTTPS_VALUES.contains(this);
}

/**
* Returns {@code true} if this {@link SessionProtocol} is {@link #H1} or {@link #H1C}.
* Note that this method returns {@code false} for {@link #HTTP} and {@link #HTTPS}.
*/
public boolean isExplicitHttp1() {
return this == H1 || this == H1C;
}

/**
* Returns {@code true} if this {@link SessionProtocol} is {@link #H2} or {@link #H2C}.
* Note that this method returns {@code false} for {@link #HTTP} and {@link #HTTPS}.
*/
public boolean isExplicitHttp2() {
return this == H2 || this == H2C;
}

/**
* Returns {@code true} if and only if this protocol uses TLS as its transport-level security layer.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,17 @@ public Throwable responseCause() {
return responseCause;
}

@Override
public void responseCause(Throwable cause) {
if (isAvailable(RequestLogProperty.RESPONSE_CAUSE)) {
return;
}

requireNonNull(cause, "cause");
responseCause = cause;
updateFlags(RequestLogProperty.RESPONSE_CAUSE);
}

@Override
public long responseLength() {
ensureAvailable(RequestLogProperty.RESPONSE_LENGTH);
Expand Down Expand Up @@ -1281,9 +1292,9 @@ public void responseContent(@Nullable Object responseContent, @Nullable Object r
if (!rpcResponse.isDone()) {
throw new IllegalArgumentException("responseContent must be complete: " + responseContent);
}
if (rpcResponse.cause() != null) {
responseCause = rpcResponse.cause();
updateFlags(RequestLogProperty.RESPONSE_CAUSE);
final Throwable cause = rpcResponse.cause();
if (cause != null) {
responseCause(cause);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,14 @@ void session(@Nullable Channel channel, SessionProtocol sessionProtocol, @Nullab
*/
void responseTrailers(HttpHeaders responseTrailers);

/**
* Sets the {@link RequestLog#responseCause()} without completing the response log.
* This method may be useful if you want to send additional data even after an exception is raised.
* If you want to end the response log right away when an exception is raised,
* please use {@link #endResponse(Throwable)}.
*/
void responseCause(Throwable cause);

/**
* Finishes the collection of the {@link Response} information. If a {@link Throwable} cause has been set
* with {@link #responseContent(Object, Object)}, it will be treated as the {@code responseCause} for this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ private void handleCloseEvent(SubscriptionImpl subscription, CloseEvent o) {
}

@Override
public final void close() {
public void close() {
if (setState(State.OPEN, State.CLOSED)) {
addObjectOrEvent(SUCCESSFUL_CLOSE);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2022 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.common.websocket;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import com.google.common.base.MoreObjects;

import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.internal.common.ByteArrayBytes;

class ByteArrayWebSocketFrame extends ByteArrayBytes implements WebSocketFrame {

private static final byte[] EMPTY_BYTES = {};

static final WebSocketFrame EMPTY_PING = new ByteArrayWebSocketFrame(EMPTY_BYTES, WebSocketFrameType.PING);
static final WebSocketFrame EMPTY_PONG = new ByteArrayWebSocketFrame(EMPTY_BYTES, WebSocketFrameType.PONG);

private final WebSocketFrameType type;
private final boolean finalFragment;

@Nullable
private String text;

ByteArrayWebSocketFrame(byte[] array, WebSocketFrameType type) {
this(array, type, true);
}

ByteArrayWebSocketFrame(byte[] array, WebSocketFrameType type, boolean finalFragment) {
this(array, type, finalFragment, null);
}

ByteArrayWebSocketFrame(byte[] array, WebSocketFrameType type,
boolean finalFragment, @Nullable String text) {
super(array);
this.type = type;
this.finalFragment = finalFragment;
this.text = text;
}

@Override
public WebSocketFrameType type() {
return type;
}

@Override
public boolean isFinalFragment() {
return finalFragment;
}

@Override
public String text() {
if (text != null) {
return text;
}
return text = toString(StandardCharsets.UTF_8);
}

@Override
public int hashCode() {
return (super.hashCode() * 31 + type.hashCode()) * 31 + Boolean.hashCode(finalFragment);
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof WebSocketFrame)) {
return false;
}

final WebSocketFrame that = (WebSocketFrame) o;
if (length() != that.length()) {
return false;
}

return type == that.type() &&
finalFragment == that.isFinalFragment() &&
Arrays.equals(array(), that.array());
}

@Override
public String toString() {
return toString(super.toString());
}

private String toString(String bytes) {
return MoreObjects.toStringHelper(this).omitNullValues()
.add("type", type)
.add("finalFragment", finalFragment)
.add("bytes", bytes)
.toString();
}
}
Loading

0 comments on commit b78d951

Please sign in to comment.