diff --git a/consul/src/main/java/com/linecorp/armeria/client/consul/ConsulEndpointGroup.java b/consul/src/main/java/com/linecorp/armeria/client/consul/ConsulEndpointGroup.java index 76632bec3b3..75685681d38 100644 --- a/consul/src/main/java/com/linecorp/armeria/client/consul/ConsulEndpointGroup.java +++ b/consul/src/main/java/com/linecorp/armeria/client/consul/ConsulEndpointGroup.java @@ -142,10 +142,14 @@ protected void doCloseAsync(CompletableFuture future) { @Override public String toString() { - return toStringHelper() - .add("serviceName", serviceName) - .add("datacenter", datacenter) - .add("filter", filter) - .toString(); + return toString(buf -> { + buf.append(", serviceName=").append(serviceName); + if (datacenter != null) { + buf.append(", datacenter=").append(datacenter); + } + if (filter != null) { + buf.append(", filter=").append(filter); + } + }); } } diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java index 45c7dd3251b..0ece3e2a9bc 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestHandler.java @@ -26,7 +26,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.linecorp.armeria.client.HttpResponseDecoder.HttpResponseWrapper; import com.linecorp.armeria.common.ClosedSessionException; import com.linecorp.armeria.common.HttpData; import com.linecorp.armeria.common.HttpHeaderNames; @@ -59,6 +58,7 @@ abstract class AbstractHttpRequestHandler implements ChannelFutureListener { enum State { NEEDS_TO_WRITE_FIRST_HEADER, + NEEDS_DATA, NEEDS_DATA_OR_TRAILERS, DONE } @@ -71,6 +71,8 @@ enum State { private final RequestLogBuilder logBuilder; private final long timeoutMillis; private final boolean headersOnly; + private final boolean allowTrailers; + private final boolean keepAlive; // session, id and responseWrapper are assigned in tryInitialize() @Nullable @@ -86,7 +88,8 @@ enum State { AbstractHttpRequestHandler(Channel ch, ClientHttpObjectEncoder encoder, HttpResponseDecoder responseDecoder, DecodedHttpResponse originalRes, - ClientRequestContext ctx, long timeoutMillis, boolean headersOnly) { + ClientRequestContext ctx, long timeoutMillis, boolean headersOnly, + boolean allowTrailers, boolean keepAlive) { this.ch = ch; this.encoder = encoder; this.responseDecoder = responseDecoder; @@ -95,6 +98,8 @@ enum State { logBuilder = ctx.logBuilder(); this.timeoutMillis = timeoutMillis; this.headersOnly = headersOnly; + this.allowTrailers = allowTrailers; + this.keepAlive = keepAlive; } abstract void onWriteSuccess(); @@ -169,7 +174,7 @@ final boolean tryInitialize() { } this.session = session; - addResponseToDecoder(); + responseWrapper = responseDecoder.addResponse(id, originalRes, ctx, ch.eventLoop()); if (timeoutMillis > 0) { // The timer would be executed if the first message has not been sent out within the timeout. @@ -180,13 +185,6 @@ final boolean tryInitialize() { return true; } - private void addResponseToDecoder() { - final long responseTimeoutMillis = ctx.responseTimeoutMillis(); - final long maxContentLength = ctx.maxResponseLength(); - responseWrapper = responseDecoder.addResponse(id, originalRes, ctx, - ch.eventLoop(), responseTimeoutMillis, maxContentLength); - } - /** * Writes the {@link RequestHeaders} to the {@link Channel}. * The {@link RequestHeaders} is merged with {@link ClientRequestContext#additionalRequestHeaders()} @@ -199,8 +197,10 @@ final void writeHeaders(RequestHeaders headers) { assert protocol != null; if (headersOnly) { state = State.DONE; - } else { + } else if (allowTrailers) { state = State.NEEDS_DATA_OR_TRAILERS; + } else { + state = State.NEEDS_DATA; } final HttpHeaders internalHeaders; @@ -215,7 +215,7 @@ final void writeHeaders(RequestHeaders headers) { logBuilder.requestHeaders(merged); final String connectionOption = headers.get(HttpHeaderNames.CONNECTION); - if (CLOSE_STRING.equalsIgnoreCase(connectionOption)) { + if (CLOSE_STRING.equalsIgnoreCase(connectionOption) || !keepAlive) { // Make the session unhealthy so that subsequent requests do not use it. // In HTTP/2 request, the "Connection: close" is just interpreted as a signal to close the // connection by sending a GOAWAY frame that will be sent after receiving the corresponding diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestSubscriber.java b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestSubscriber.java new file mode 100644 index 00000000000..4ff251be998 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpRequestSubscriber.java @@ -0,0 +1,129 @@ +/* + * Copyright 2016 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpObject; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.client.DecodedHttpResponse; + +import io.netty.channel.Channel; + +abstract class AbstractHttpRequestSubscriber extends AbstractHttpRequestHandler + implements Subscriber { + + private static final HttpData EMPTY_EOS = HttpData.empty().withEndOfStream(); + + static AbstractHttpRequestSubscriber of(Channel channel, ClientHttpObjectEncoder requestEncoder, + HttpResponseDecoder responseDecoder, SessionProtocol protocol, + ClientRequestContext ctx, HttpRequest req, + DecodedHttpResponse res, long writeTimeoutMillis, + boolean webSocket) { + if (webSocket) { + if (protocol.isExplicitHttp1()) { + return new WebSocketHttp1RequestSubscriber( + channel, requestEncoder, responseDecoder, req, res, ctx, writeTimeoutMillis); + } + assert protocol.isExplicitHttp2(); + return new WebSocketHttp2RequestSubscriber( + channel, requestEncoder, responseDecoder, req, res, ctx, writeTimeoutMillis); + } + return new HttpRequestSubscriber( + channel, requestEncoder, responseDecoder, req, res, ctx, writeTimeoutMillis); + } + + private final HttpRequest request; + + @Nullable + private Subscription subscription; + private boolean isSubscriptionCompleted; + + AbstractHttpRequestSubscriber(Channel ch, ClientHttpObjectEncoder encoder, + HttpResponseDecoder responseDecoder, + HttpRequest request, DecodedHttpResponse originalRes, + ClientRequestContext ctx, long timeoutMillis, boolean allowTrailers, + boolean keepAlive) { + super(ch, encoder, responseDecoder, originalRes, ctx, timeoutMillis, request.isEmpty(), allowTrailers, + keepAlive); + this.request = request; + } + + @Override + public void onSubscribe(Subscription subscription) { + assert this.subscription == null; + this.subscription = subscription; + if (state() == State.DONE) { + cancel(); + return; + } + + if (!tryInitialize()) { + return; + } + + // NB: This must be invoked at the end of this method because otherwise the callback methods in this + // class can be called before the member fields (subscription, id, responseWrapper and + // timeoutFuture) are initialized. + // It is because the successful write of the first headers will trigger subscription.request(1). + writeHeaders(mapHeaders(request.headers())); + channel().flush(); + } + + RequestHeaders mapHeaders(RequestHeaders headers) { + return headers; + } + + @Override + public void onError(Throwable cause) { + isSubscriptionCompleted = true; + failRequest(cause); + } + + @Override + public void onComplete() { + isSubscriptionCompleted = true; + + if (state() != State.DONE) { + writeData(EMPTY_EOS); + channel().flush(); + } + } + + @Override + void onWriteSuccess() { + // Request more messages regardless whether the state is DONE. It makes the producer have + // a chance to produce the last call such as 'onComplete' and 'onError' when there are + // no more messages it can produce. + if (!isSubscriptionCompleted) { + assert subscription != null; + subscription.request(1); + } + } + + @Override + void cancel() { + isSubscriptionCompleted = true; + assert subscription != null; + subscription.cancel(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractHttpResponseDecoder.java b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpResponseDecoder.java new file mode 100644 index 00000000000..93a16941f0d --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractHttpResponseDecoder.java @@ -0,0 +1,162 @@ +/* + * Copyright 2016 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import java.util.Iterator; + +import com.linecorp.armeria.common.ContentTooLargeException; +import com.linecorp.armeria.common.ContentTooLargeExceptionBuilder; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.client.DecodedHttpResponse; +import com.linecorp.armeria.internal.client.HttpSession; +import com.linecorp.armeria.internal.common.InboundTrafficController; +import com.linecorp.armeria.internal.common.KeepAliveHandler; + +import io.netty.channel.Channel; +import io.netty.channel.EventLoop; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; + +abstract class AbstractHttpResponseDecoder implements HttpResponseDecoder { + + private final IntObjectMap responses = new IntObjectHashMap<>(); + private final Channel channel; + private final InboundTrafficController inboundTrafficController; + + @Nullable + private HttpSession httpSession; + + private int unfinishedResponses; + private boolean closing; + + AbstractHttpResponseDecoder(Channel channel, InboundTrafficController inboundTrafficController) { + this.channel = channel; + this.inboundTrafficController = inboundTrafficController; + } + + @Override + public Channel channel() { + return channel; + } + + @Override + public InboundTrafficController inboundTrafficController() { + return inboundTrafficController; + } + + @Override + public HttpResponseWrapper addResponse( + int id, DecodedHttpResponse res, ClientRequestContext ctx, EventLoop eventLoop) { + final HttpResponseWrapper newRes = + new HttpResponseWrapper(res, eventLoop, ctx, + ctx.responseTimeoutMillis(), ctx.maxResponseLength()); + final HttpResponseWrapper oldRes = responses.put(id, newRes); + final KeepAliveHandler keepAliveHandler = keepAliveHandler(); + if (keepAliveHandler != null) { + keepAliveHandler.increaseNumRequests(); + } + + assert oldRes == null : "addResponse(" + id + ", " + res + ", " + ctx + "): " + oldRes; + onResponseAdded(id, eventLoop, newRes); + return newRes; + } + + abstract void onResponseAdded(int id, EventLoop eventLoop, HttpResponseWrapper responseWrapper); + + @Nullable + @Override + public HttpResponseWrapper getResponse(int id) { + return responses.get(id); + } + + @Nullable + @Override + public HttpResponseWrapper removeResponse(int id) { + if (closing) { + // `unfinishedResponses` will be removed by `failUnfinishedResponses()` + return null; + } + + final HttpResponseWrapper removed = responses.remove(id); + if (removed != null) { + unfinishedResponses--; + assert unfinishedResponses >= 0 : unfinishedResponses; + } + return removed; + } + + @Override + public boolean hasUnfinishedResponses() { + return unfinishedResponses != 0; + } + + @Override + public boolean reserveUnfinishedResponse(int maxUnfinishedResponses) { + if (unfinishedResponses >= maxUnfinishedResponses) { + return false; + } + + unfinishedResponses++; + return true; + } + + @Override + public void decrementUnfinishedResponses() { + unfinishedResponses--; + } + + @Override + public void failUnfinishedResponses(Throwable cause) { + if (closing) { + return; + } + closing = true; + + for (final Iterator iterator = responses.values().iterator(); + iterator.hasNext();) { + final HttpResponseWrapper res = iterator.next(); + // To avoid calling removeResponse by res.close(cause), remove before closing. + iterator.remove(); + unfinishedResponses--; + res.close(cause); + } + } + + @Override + public HttpSession session() { + if (httpSession != null) { + return httpSession; + } + return httpSession = HttpSession.get(channel); + } + + @Override + public boolean needsToDisconnectNow() { + return !session().isAcquirable() && !hasUnfinishedResponses(); + } + + static ContentTooLargeException contentTooLargeException(HttpResponseWrapper res, long transferred) { + final ContentTooLargeExceptionBuilder builder = + ContentTooLargeException.builder() + .maxContentLength(res.maxContentLength()) + .transferred(transferred); + if (res.contentLengthHeaderValue() >= 0) { + builder.contentLength(res.contentLengthHeaderValue()); + } + return builder.build(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java b/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java index b22f40d2b52..04ca8523611 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java @@ -16,15 +16,13 @@ package com.linecorp.armeria.client; import static com.google.common.base.Preconditions.checkArgument; +import static com.linecorp.armeria.common.SessionProtocol.httpAndHttpsValues; +import static com.linecorp.armeria.internal.client.ClientUtil.UNDEFINED_URI; import static java.util.Objects.requireNonNull; import java.net.URI; -import java.util.Set; import java.util.function.Function; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Sets; - import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.common.Scheme; import com.linecorp.armeria.common.SerializationFormat; @@ -36,18 +34,6 @@ */ public abstract class AbstractWebClientBuilder extends AbstractClientOptionsBuilder { - /** - * An undefined {@link URI} to create {@link WebClient} without specifying {@link URI}. - */ - static final URI UNDEFINED_URI = URI.create("http://undefined"); - - private static final Set SUPPORTED_PROTOCOLS = - Sets.immutableEnumSet( - ImmutableList.builder() - .addAll(SessionProtocol.httpValues()) - .addAll(SessionProtocol.httpsValues()) - .build()); - @Nullable private final URI uri; @Nullable @@ -61,10 +47,7 @@ public abstract class AbstractWebClientBuilder extends AbstractClientOptionsBuil * Creates a new instance. */ protected AbstractWebClientBuilder() { - uri = UNDEFINED_URI; - scheme = null; - endpointGroup = null; - path = null; + this(UNDEFINED_URI, null, null, null); } /** @@ -74,24 +57,7 @@ protected AbstractWebClientBuilder() { * in {@link SessionProtocol} */ protected AbstractWebClientBuilder(URI uri) { - if (Clients.isUndefinedUri(uri)) { - this.uri = uri; - } else { - final String givenScheme = requireNonNull(uri, "uri").getScheme(); - final Scheme scheme = validateScheme(givenScheme); - if (scheme.uriText().equals(givenScheme)) { - // No need to replace the user-specified scheme because it's already in its normalized form. - this.uri = uri; - } else { - // Replace the user-specified scheme with the normalized one. - // e.g. http://foo.com/ -> none+http://foo.com/ - this.uri = URI.create(scheme.uriText() + - uri.toString().substring(givenScheme.length())); - } - } - scheme = null; - endpointGroup = null; - path = null; + this(validateUri(uri), null, null, null); } /** @@ -102,29 +68,65 @@ protected AbstractWebClientBuilder(URI uri) { */ protected AbstractWebClientBuilder(SessionProtocol sessionProtocol, EndpointGroup endpointGroup, @Nullable String path) { - validateScheme(requireNonNull(sessionProtocol, "sessionProtocol").uriText()); - if (path != null) { - checkArgument(path.startsWith("/"), - "path: %s (expected: an absolute path starting with '/')", path); + this(null, validateSessionProtocol(sessionProtocol), + requireNonNull(endpointGroup, "endpointGroup"), path); + } + + /** + * Creates a new instance. + */ + protected AbstractWebClientBuilder(@Nullable URI uri, @Nullable Scheme scheme, + @Nullable EndpointGroup endpointGroup, @Nullable String path) { + assert uri != null || (scheme != null && endpointGroup != null); + assert path == null || uri == null; + this.uri = uri; + this.scheme = scheme; + this.endpointGroup = endpointGroup; + this.path = validatePath(path); + } + + private static URI validateUri(URI uri) { + requireNonNull(uri, "uri"); + if (Clients.isUndefinedUri(uri)) { + return uri; + } + final String givenScheme = requireNonNull(uri, "uri").getScheme(); + final Scheme scheme = validateScheme(givenScheme); + if (scheme.uriText().equals(givenScheme)) { + // No need to replace the user-specified scheme because it's already in its normalized form. + return uri; } + // Replace the user-specified scheme with the normalized one. + // e.g. http://foo.com/ -> none+http://foo.com/ + return URI.create(scheme.uriText() + uri.toString().substring(givenScheme.length())); + } - uri = null; - scheme = Scheme.of(SerializationFormat.NONE, sessionProtocol); - this.endpointGroup = requireNonNull(endpointGroup, "endpointGroup"); - this.path = path; + private static Scheme validateSessionProtocol(SessionProtocol sessionProtocol) { + requireNonNull(sessionProtocol, "sessionProtocol"); + validateScheme(sessionProtocol.uriText()); + return Scheme.of(SerializationFormat.NONE, sessionProtocol); } private static Scheme validateScheme(String scheme) { final Scheme parsedScheme = Scheme.tryParse(scheme); if (parsedScheme != null) { if (parsedScheme.serializationFormat() == SerializationFormat.NONE && - SUPPORTED_PROTOCOLS.contains(parsedScheme.sessionProtocol())) { + httpAndHttpsValues().contains(parsedScheme.sessionProtocol())) { return parsedScheme; } } - throw new IllegalArgumentException("scheme : " + scheme + - " (expected: one of " + SUPPORTED_PROTOCOLS + ')'); + throw new IllegalArgumentException("scheme: " + scheme + + " (expected: one of " + httpAndHttpsValues() + ')'); + } + + @Nullable + private static String validatePath(@Nullable String path) { + if (path != null) { + checkArgument(path.startsWith("/"), + "path: %s (expected: an absolute path starting with '/')", path); + } + return path; } /** diff --git a/core/src/main/java/com/linecorp/armeria/client/AggregatedHttpRequestHandler.java b/core/src/main/java/com/linecorp/armeria/client/AggregatedHttpRequestHandler.java index 50a4be80265..bc90fc4561c 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AggregatedHttpRequestHandler.java +++ b/core/src/main/java/com/linecorp/armeria/client/AggregatedHttpRequestHandler.java @@ -37,7 +37,7 @@ final class AggregatedHttpRequestHandler extends AbstractHttpRequestHandler HttpResponseDecoder responseDecoder, HttpRequest request, DecodedHttpResponse originalRes, ClientRequestContext ctx, long timeoutMillis) { - super(ch, encoder, responseDecoder, originalRes, ctx, timeoutMillis, request.isEmpty()); + super(ch, encoder, responseDecoder, originalRes, ctx, timeoutMillis, request.isEmpty(), true, true); } @Override diff --git a/core/src/main/java/com/linecorp/armeria/client/BlockingWebClientRequestPreparation.java b/core/src/main/java/com/linecorp/armeria/client/BlockingWebClientRequestPreparation.java index 302c91dbbd7..014e5deb1bc 100644 --- a/core/src/main/java/com/linecorp/armeria/client/BlockingWebClientRequestPreparation.java +++ b/core/src/main/java/com/linecorp/armeria/client/BlockingWebClientRequestPreparation.java @@ -336,6 +336,12 @@ public BlockingWebClientRequestPreparation content(MediaType contentType, HttpDa return this; } + @Override + public BlockingWebClientRequestPreparation content(Publisher content) { + delegate.content(content); + return this; + } + @Override public BlockingWebClientRequestPreparation content(MediaType contentType, Publisher content) { diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientHttp1ObjectEncoder.java b/core/src/main/java/com/linecorp/armeria/client/ClientHttp1ObjectEncoder.java index a80b7b4b9e3..3ebb1e14a09 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientHttp1ObjectEncoder.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientHttp1ObjectEncoder.java @@ -42,12 +42,14 @@ final class ClientHttp1ObjectEncoder extends Http1ObjectEncoder implements Clien private final Http1HeaderNaming http1HeaderNaming; private final KeepAliveHandler keepAliveHandler; + private final boolean webSocket; ClientHttp1ObjectEncoder(Channel ch, SessionProtocol protocol, Http1HeaderNaming http1HeaderNaming, - KeepAliveHandler keepAliveHandler) { + KeepAliveHandler keepAliveHandler, boolean webSocket) { super(ch, protocol); this.http1HeaderNaming = http1HeaderNaming; this.keepAliveHandler = keepAliveHandler; + this.webSocket = webSocket; } @Override @@ -71,6 +73,12 @@ private HttpObject convertHeaders(RequestHeaders headers, boolean endStream) { protocol().defaultPort())); } + if (webSocket) { + nettyHeaders.remove(HttpHeaderNames.TRANSFER_ENCODING); + nettyHeaders.remove(HttpHeaderNames.CONTENT_LENGTH); + return req; + } + if (endStream) { nettyHeaders.remove(HttpHeaderNames.TRANSFER_ENCODING); diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientOptions.java b/core/src/main/java/com/linecorp/armeria/client/ClientOptions.java index 84719264ce7..e422acd9c33 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientOptions.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientOptions.java @@ -82,6 +82,16 @@ public final class ClientOptions ClientOption.define("REQUEST_AUTO_ABORT_DELAY_MILLIS", Flags.defaultRequestAutoAbortDelayMillis()); + /** + * Whether to add an {@link HttpHeaderNames#ORIGIN} header automatically when sending + * an {@link HttpRequest} when the {@link HttpRequest#headers()} does not have it. + * + * @see The Web Origin Concept + */ + @UnstableApi + public static final ClientOption AUTO_FILL_ORIGIN_HEADER = + ClientOption.define("AUTO_FILL_ORIGIN_HEADER", false); // TODO(minwoox): Add to Flags + /** * The redirect configuration. */ @@ -306,6 +316,15 @@ public long requestAutoAbortDelayMillis() { return get(REQUEST_AUTO_ABORT_DELAY_MILLIS); } + /** + * Returns whether to add an {@link HttpHeaderNames#ORIGIN} header automatically when sending + * an {@link HttpRequest} when the {@link HttpRequest#headers()} does not have it. + */ + @UnstableApi + public boolean autoFillOriginHeader() { + return get(AUTO_FILL_ORIGIN_HEADER); + } + /** * Returns the {@link RedirectConfig}. */ diff --git a/core/src/main/java/com/linecorp/armeria/client/Clients.java b/core/src/main/java/com/linecorp/armeria/client/Clients.java index 9765583bac6..ee54d9f4ee8 100644 --- a/core/src/main/java/com/linecorp/armeria/client/Clients.java +++ b/core/src/main/java/com/linecorp/armeria/client/Clients.java @@ -15,6 +15,7 @@ */ package com.linecorp.armeria.client; +import static com.linecorp.armeria.internal.client.ClientUtil.UNDEFINED_URI; import static java.util.Objects.requireNonNull; import java.net.URI; @@ -603,7 +604,7 @@ public static ClientRequestContextCaptor newContextCaptor() { * {@code isUndefinedUri(WebClient.of().uri())} will return {@code true}. */ public static boolean isUndefinedUri(URI uri) { - return uri == AbstractWebClientBuilder.UNDEFINED_URI; + return uri == UNDEFINED_URI; } private Clients() {} diff --git a/core/src/main/java/com/linecorp/armeria/client/FutureTransformingRequestPreparation.java b/core/src/main/java/com/linecorp/armeria/client/FutureTransformingRequestPreparation.java index f4ff35d42d0..633ee0b4180 100644 --- a/core/src/main/java/com/linecorp/armeria/client/FutureTransformingRequestPreparation.java +++ b/core/src/main/java/com/linecorp/armeria/client/FutureTransformingRequestPreparation.java @@ -274,6 +274,12 @@ public FutureTransformingRequestPreparation content(MediaType contentType, Ht return this; } + @Override + public FutureTransformingRequestPreparation content(Publisher content) { + delegate.content(content); + return this; + } + @Override public FutureTransformingRequestPreparation content(MediaType contentType, Publisher content) { diff --git a/core/src/main/java/com/linecorp/armeria/client/Http1ResponseDecoder.java b/core/src/main/java/com/linecorp/armeria/client/Http1ResponseDecoder.java index 508e30c39be..c341df6e5e2 100644 --- a/core/src/main/java/com/linecorp/armeria/client/Http1ResponseDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/client/Http1ResponseDecoder.java @@ -16,21 +16,30 @@ package com.linecorp.armeria.client; +import static com.linecorp.armeria.internal.common.KeepAliveHandlerUtil.needsKeepAliveHandler; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.collect.ImmutableList; import com.google.common.math.LongMath; import com.linecorp.armeria.common.ClosedSessionException; import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpStatusClass; import com.linecorp.armeria.common.ProtocolViolationException; import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.metric.MoreMeters; import com.linecorp.armeria.internal.common.ArmeriaHttpUtil; import com.linecorp.armeria.internal.common.InboundTrafficController; import com.linecorp.armeria.internal.common.KeepAliveHandler; +import com.linecorp.armeria.internal.common.NoopKeepAliveHandler; import com.linecorp.armeria.internal.common.util.TemporaryThreadLocals; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Timer; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; @@ -41,12 +50,11 @@ import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObject; import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpStatusClass; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.LastHttpContent; import io.netty.util.ReferenceCountUtil; -final class Http1ResponseDecoder extends HttpResponseDecoder implements ChannelInboundHandler { +final class Http1ResponseDecoder extends AbstractHttpResponseDecoder implements ChannelInboundHandler { private static final Logger logger = LoggerFactory.getLogger(Http1ResponseDecoder.class); @@ -60,14 +68,34 @@ private enum State { /** The response being decoded currently. */ @Nullable private HttpResponseWrapper res; - @Nullable - private KeepAliveHandler keepAliveHandler; + private final KeepAliveHandler keepAliveHandler; private int resId = 1; private int lastPingReqId = -1; private State state = State.NEED_HEADERS; - Http1ResponseDecoder(Channel channel) { + Http1ResponseDecoder(Channel channel, HttpClientFactory clientFactory, SessionProtocol protocol) { super(channel, InboundTrafficController.ofHttp1(channel)); + final long idleTimeoutMillis = clientFactory.idleTimeoutMillis(); + final long pingIntervalMillis = clientFactory.pingIntervalMillis(); + final long maxConnectionAgeMillis = clientFactory.maxConnectionAgeMillis(); + final int maxNumRequestsPerConnection = clientFactory.maxNumRequestsPerConnection(); + final boolean keepAliveOnPing = clientFactory.keepAliveOnPing(); + final boolean needsKeepAliveHandler = + needsKeepAliveHandler(idleTimeoutMillis, pingIntervalMillis, + maxConnectionAgeMillis, maxNumRequestsPerConnection); + + if (needsKeepAliveHandler) { + final Timer keepAliveTimer = + MoreMeters.newTimer(clientFactory.meterRegistry(), + "armeria.client.connections.lifespan", + ImmutableList.of(Tag.of("protocol", protocol.uriText()))); + keepAliveHandler = new Http1ClientKeepAliveHandler( + channel, this, keepAliveTimer, idleTimeoutMillis, + pingIntervalMillis, maxConnectionAgeMillis, maxNumRequestsPerConnection, + keepAliveOnPing); + } else { + keepAliveHandler = new NoopKeepAliveHandler(); + } } @Override @@ -100,7 +128,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { @Override public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - destroyKeepAliveHandler(); + keepAliveHandler.destroy(); } @Override @@ -125,7 +153,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { if (res != null) { res.close(ClosedSessionException.get()); } - destroyKeepAliveHandler(); + keepAliveHandler.destroy(); ctx.fireChannelInactive(); } @@ -141,6 +169,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception ReferenceCountUtil.release(msg); return; } + keepAliveHandler.onReadOrWrite(); try { switch (state) { @@ -168,7 +197,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception res.startResponse(); final ResponseHeaders responseHeaders = ArmeriaHttpUtil.toArmeria(nettyRes); final boolean written; - if (nettyRes.status().codeClass() == HttpStatusClass.INFORMATIONAL) { + if (responseHeaders.status().codeClass() == HttpStatusClass.INFORMATIONAL) { state = State.NEED_INFORMATIONAL_DATA; written = res.tryWrite(responseHeaders); } else { @@ -250,6 +279,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } private void failWithUnexpectedMessageType(ChannelHandlerContext ctx, Object msg, Class expected) { + final String message; try (TemporaryThreadLocals tempThreadLocals = TemporaryThreadLocals.acquire()) { final StringBuilder buf = tempThreadLocals.stringBuilder(); buf.append("unexpected message type: " + msg.getClass().getName() + @@ -260,8 +290,9 @@ private void failWithUnexpectedMessageType(ChannelHandlerContext ctx, Object msg } else { buf.append(", lastPingReqId: " + lastPingReqId + ')'); } - fail(ctx, new ProtocolViolationException(buf.toString())); + message = buf.toString(); } + fail(ctx, new ProtocolViolationException(message)); } private void fail(ChannelHandlerContext ctx, Throwable cause) { @@ -305,40 +336,19 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E } @Override - KeepAliveHandler keepAliveHandler() { + public KeepAliveHandler keepAliveHandler() { return keepAliveHandler; } - void setKeepAliveHandler(ChannelHandlerContext ctx, KeepAliveHandler keepAliveHandler) { - this.keepAliveHandler = keepAliveHandler; - if (keepAliveHandler instanceof Http1ClientKeepAliveHandler) { - maybeInitializeKeepAliveHandler(ctx); - } - } - - private void maybeInitializeKeepAliveHandler(ChannelHandlerContext ctx) { + void maybeInitializeKeepAliveHandler(ChannelHandlerContext ctx) { if (ctx.channel().isActive()) { - final KeepAliveHandler keepAliveHandler = keepAliveHandler(); - if (keepAliveHandler != null) { - keepAliveHandler.initialize(ctx); - } - } - } - - private void destroyKeepAliveHandler() { - final KeepAliveHandler keepAliveHandler = keepAliveHandler(); - if (keepAliveHandler != null) { - keepAliveHandler.destroy(); + keepAliveHandler.initialize(ctx); } } private void onPingRead(Object msg) { if (msg instanceof HttpResponse) { - final KeepAliveHandler keepAliveHandler = keepAliveHandler(); - // Ping can not be activated with NoopKeepAliveHandler. - if (keepAliveHandler instanceof Http1ClientKeepAliveHandler) { - keepAliveHandler.onPing(); - } + keepAliveHandler.onPing(); } if (msg instanceof LastHttpContent) { onPingComplete(); diff --git a/core/src/main/java/com/linecorp/armeria/client/Http2ResponseDecoder.java b/core/src/main/java/com/linecorp/armeria/client/Http2ResponseDecoder.java index 14e6c0e115f..9560d8cf89b 100644 --- a/core/src/main/java/com/linecorp/armeria/client/Http2ResponseDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/client/Http2ResponseDecoder.java @@ -20,8 +20,6 @@ import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; import static io.netty.handler.codec.http2.Http2Exception.connectionError; -import javax.annotation.Nonnull; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -54,8 +52,8 @@ import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2Stream; -final class Http2ResponseDecoder extends HttpResponseDecoder implements Http2Connection.Listener, - Http2FrameListener { +final class Http2ResponseDecoder extends AbstractHttpResponseDecoder implements Http2Connection.Listener, + Http2FrameListener { private static final Logger logger = LoggerFactory.getLogger(Http2ResponseDecoder.class); @@ -345,9 +343,8 @@ public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int wind public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload) {} - @Nonnull @Override - KeepAliveHandler keepAliveHandler() { + public KeepAliveHandler keepAliveHandler() { return keepAliveHandler; } diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpChannelPool.java b/core/src/main/java/com/linecorp/armeria/client/HttpChannelPool.java index eb0b4683035..3556290eb2f 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpChannelPool.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpChannelPool.java @@ -15,6 +15,8 @@ */ package com.linecorp.armeria.client; +import static com.linecorp.armeria.common.SessionProtocol.httpAndHttpsValues; + import java.lang.reflect.Array; import java.net.InetSocketAddress; import java.net.SocketAddress; @@ -26,16 +28,18 @@ import java.util.IdentityHashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Consumer; -import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.collect.ImmutableSet; + import com.linecorp.armeria.client.proxy.ConnectProxyConfig; import com.linecorp.armeria.client.proxy.HAProxyConfig; import com.linecorp.armeria.client.proxy.ProxyConfig; @@ -44,6 +48,7 @@ import com.linecorp.armeria.client.proxy.Socks4ProxyConfig; import com.linecorp.armeria.client.proxy.Socks5ProxyConfig; import com.linecorp.armeria.common.ClosedSessionException; +import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.logging.ClientConnectionTimingsBuilder; @@ -88,9 +93,9 @@ final class HttpChannelPool implements AsyncCloseable { private final ConnectionPoolListener listener; // Fields for creating a new connection: - private final Bootstrap[] inetBootstraps; + private final Bootstrap[][] inetBootstraps; @Nullable - private final Bootstrap[] unixBootstraps; + private final Bootstrap[][] unixBootstraps; private final int connectTimeoutMillis; private final SslContext sslCtxHttp1Or2; @@ -101,17 +106,9 @@ final class HttpChannelPool implements AsyncCloseable { ConnectionPoolListener listener) { this.clientFactory = clientFactory; this.eventLoop = eventLoop; - pool = newEnumMap( - Map.class, - unused -> new HashMap<>(), - SessionProtocol.H1, SessionProtocol.H1C, - SessionProtocol.H2, SessionProtocol.H2C); - pendingAcquisitions = newEnumMap( - Map.class, - unused -> new HashMap<>(), - SessionProtocol.HTTP, SessionProtocol.HTTPS, - SessionProtocol.H1, SessionProtocol.H1C, - SessionProtocol.H2, SessionProtocol.H2C); + pool = newEnumMap(ImmutableSet.of(SessionProtocol.H1, SessionProtocol.H1C, + SessionProtocol.H2, SessionProtocol.H2C)); + pendingAcquisitions = newEnumMap(httpAndHttpsValues()); allChannels = new IdentityHashMap<>(); this.listener = listener; this.sslCtxHttp1Only = sslCtxHttp1Only; @@ -131,26 +128,41 @@ final class HttpChannelPool implements AsyncCloseable { .get(ChannelOption.CONNECT_TIMEOUT_MILLIS); } - private Bootstrap[] newBootstrapMap(Bootstrap baseBootstrap, - HttpClientFactory clientFactory, - EventLoop eventLoop) { + private Bootstrap[][] newBootstrapMap(Bootstrap baseBootstrap, + HttpClientFactory clientFactory, + EventLoop eventLoop) { baseBootstrap.group(eventLoop); - return newEnumMap(Bootstrap.class, - desiredProtocol -> { - final SslContext sslCtx = determineSslContext(desiredProtocol); - final Bootstrap bootstrap = baseBootstrap.clone(); - bootstrap.handler(new ChannelInitializer() { - @Override - protected void initChannel(Channel ch) throws Exception { - ch.pipeline().addLast( - new HttpClientPipelineConfigurator(clientFactory, desiredProtocol, sslCtx)); - } - }); - return bootstrap; - }, - SessionProtocol.HTTP, SessionProtocol.HTTPS, - SessionProtocol.H1, SessionProtocol.H1C, - SessionProtocol.H2, SessionProtocol.H2C); + final Set sessionProtocols = httpAndHttpsValues(); + final Bootstrap[][] maps = (Bootstrap[][]) Array.newInstance( + Bootstrap.class, SessionProtocol.values().length, 2); + // Attempting to access the array with an unallowed protocol will trigger NPE, + // which will help us find a bug. + for (SessionProtocol p : sessionProtocols) { + final SslContext sslCtx = determineSslContext(p); + setBootstrap(baseBootstrap.clone(), clientFactory, maps, p, sslCtx, true); + setBootstrap(baseBootstrap.clone(), clientFactory, maps, p, sslCtx, false); + } + return maps; + } + + private static void setBootstrap(Bootstrap bootstrap, HttpClientFactory clientFactory, Bootstrap[][] maps, + SessionProtocol p, SslContext sslCtx, boolean webSocket) { + bootstrap.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast( + new HttpClientPipelineConfigurator(clientFactory, webSocket, p, sslCtx)); + } + }); + maps[p.ordinal()][toIndex(webSocket)] = bootstrap; + } + + private static int toIndex(boolean webSocket) { + return webSocket ? 1 : 0; + } + + private static int toIndex(SerializationFormat serializationFormat) { + return toIndex(serializationFormat == SerializationFormat.WS); } private SslContext determineSslContext(SessionProtocol desiredProtocol) { @@ -203,22 +215,23 @@ private void configureProxy(Channel ch, ProxyConfig proxyConfig, SessionProtocol /** * Returns an array whose index signifies {@link SessionProtocol#ordinal()}. Similar to {@link EnumMap}. */ - private static T[] newEnumMap(Class elementType, - Function factory, - SessionProtocol... allowedProtocols) { + private static Map[] newEnumMap(Set allowedProtocols) { @SuppressWarnings("unchecked") - final T[] maps = (T[]) Array.newInstance(elementType, SessionProtocol.values().length); + final Map[] maps = + (Map[]) Array.newInstance(Map.class, SessionProtocol.values().length); // Attempting to access the array with an unallowed protocol will trigger NPE, // which will help us find a bug. for (SessionProtocol p : allowedProtocols) { - maps[p.ordinal()] = factory.apply(p); + maps[p.ordinal()] = new HashMap<>(); } return maps; } - private Bootstrap getBootstrap(SessionProtocol desiredProtocol, SocketAddress remoteAddress) { + // TODO(minwoox): refactor this. https://github.com/line/armeria/issues/5129 + private Bootstrap getBootstrap(SessionProtocol desiredProtocol, SocketAddress remoteAddress, + SerializationFormat serializationFormat) { if (remoteAddress instanceof InetSocketAddress) { - return inetBootstraps[desiredProtocol.ordinal()]; + return inetBootstraps[desiredProtocol.ordinal()][toIndex(serializationFormat)]; } assert remoteAddress instanceof DomainSocketAddress : remoteAddress; @@ -228,7 +241,7 @@ private Bootstrap getBootstrap(SessionProtocol desiredProtocol, SocketAddress re eventLoop.getClass().getName()); } - return unixBootstraps[desiredProtocol.ordinal()]; + return unixBootstraps[desiredProtocol.ordinal()][toIndex(serializationFormat)]; } @Nullable @@ -258,32 +271,39 @@ private void removePendingAcquisition(SessionProtocol desiredProtocol, PoolKey k * Attempts to acquire a {@link Channel} which is matched by the specified condition immediately. * * @return {@code null} is there's no match left in the pool and thus a new connection has to be - * requested via {@link #acquireLater(SessionProtocol, PoolKey, ClientConnectionTimingsBuilder)}. + * requested via {@link #acquireLater(SessionProtocol, SerializationFormat, + * PoolKey, ClientConnectionTimingsBuilder)}. */ @Nullable - PooledChannel acquireNow(SessionProtocol desiredProtocol, PoolKey key) { + @SuppressWarnings("checkstyle:FallThrough") + PooledChannel acquireNow(SessionProtocol desiredProtocol, SerializationFormat serializationFormat, + PoolKey key) { PooledChannel ch; switch (desiredProtocol) { case HTTP: - ch = acquireNowExact(key, SessionProtocol.H2C); + ch = acquireNowExact(key, SessionProtocol.H2C, serializationFormat); if (ch == null) { - ch = acquireNowExact(key, SessionProtocol.H1C); + ch = acquireNowExact(key, SessionProtocol.H1C, serializationFormat); } break; case HTTPS: - ch = acquireNowExact(key, SessionProtocol.H2); + ch = acquireNowExact(key, SessionProtocol.H2, serializationFormat); if (ch == null) { - ch = acquireNowExact(key, SessionProtocol.H1); + ch = acquireNowExact(key, SessionProtocol.H1, serializationFormat); } break; default: - ch = acquireNowExact(key, desiredProtocol); + ch = acquireNowExact(key, desiredProtocol, serializationFormat); } return ch; } @Nullable - private PooledChannel acquireNowExact(PoolKey key, SessionProtocol protocol) { + private PooledChannel acquireNowExact(PoolKey key, SessionProtocol protocol, + SerializationFormat serializationFormat) { + if (serializationFormat.requiresNewConnection(protocol)) { + return null; + } final Deque queue = getPool(protocol, key); if (queue == null) { return null; @@ -336,11 +356,13 @@ private static SessionProtocol getProtocolIfHealthy(Channel ch) { * Acquires a new {@link Channel} which is matched by the specified condition by making a connection * attempt or waiting for the current connection attempt in progress. */ - CompletableFuture acquireLater(SessionProtocol desiredProtocol, PoolKey key, + CompletableFuture acquireLater(SessionProtocol desiredProtocol, + SerializationFormat serializationFormat, + PoolKey key, ClientConnectionTimingsBuilder timingsBuilder) { final ChannelAcquisitionFuture promise = new ChannelAcquisitionFuture(); - if (!usePendingAcquisition(desiredProtocol, key, promise, timingsBuilder)) { - connect(desiredProtocol, key, promise, timingsBuilder); + if (!usePendingAcquisition(desiredProtocol, serializationFormat, key, promise, timingsBuilder)) { + connect(desiredProtocol, serializationFormat, key, promise, timingsBuilder); } return promise; } @@ -350,11 +372,13 @@ CompletableFuture acquireLater(SessionProtocol desiredProtocol, P * * @return {@code true} if succeeded to reuse the pending connection. */ - private boolean usePendingAcquisition(SessionProtocol desiredProtocol, PoolKey key, + private boolean usePendingAcquisition(SessionProtocol desiredProtocol, + SerializationFormat serializationFormat, + PoolKey key, ChannelAcquisitionFuture promise, ClientConnectionTimingsBuilder timingsBuilder) { - if (desiredProtocol == SessionProtocol.H1 || desiredProtocol == SessionProtocol.H1C) { + if (desiredProtocol.isExplicitHttp1()) { // Can't use HTTP/1 connections because they will not be available in the pool until // the request is done. return false; @@ -366,11 +390,12 @@ private boolean usePendingAcquisition(SessionProtocol desiredProtocol, PoolKey k } timingsBuilder.pendingAcquisitionStart(); - pendingAcquisition.piggyback(desiredProtocol, key, promise, timingsBuilder); + pendingAcquisition.piggyback(desiredProtocol, serializationFormat, key, promise, timingsBuilder); return true; } - private void connect(SessionProtocol desiredProtocol, PoolKey key, ChannelAcquisitionFuture promise, + private void connect(SessionProtocol desiredProtocol, SerializationFormat serializationFormat, + PoolKey key, ChannelAcquisitionFuture promise, ClientConnectionTimingsBuilder timingsBuilder) { setPendingAcquisition(desiredProtocol, key, promise); timingsBuilder.socketConnectStart(); @@ -388,7 +413,7 @@ private void connect(SessionProtocol desiredProtocol, PoolKey key, ChannelAcquis // Create a new connection. final Promise sessionPromise = eventLoop.newPromise(); - connect(remoteAddress, desiredProtocol, key, sessionPromise); + connect(remoteAddress, desiredProtocol, serializationFormat, key, sessionPromise); if (sessionPromise.isDone()) { notifyConnect(desiredProtocol, key, sessionPromise, promise, timingsBuilder); @@ -402,17 +427,18 @@ private void connect(SessionProtocol desiredProtocol, PoolKey key, ChannelAcquis /** * A low-level operation that triggers a new connection attempt. Used only by: *
    - *
  • {@link #connect(SessionProtocol, PoolKey, ChannelAcquisitionFuture, - * ClientConnectionTimingsBuilder)} - The pool has been exhausted.
  • + *
  • {@link #connect(SessionProtocol, SerializationFormat, PoolKey, ChannelAcquisitionFuture, + * ClientConnectionTimingsBuilder)} - The pool has been exhausted.
  • *
  • {@link HttpSessionHandler} - HTTP/2 upgrade has failed.
  • *
*/ void connect(SocketAddress remoteAddress, SessionProtocol desiredProtocol, + SerializationFormat serializationFormat, PoolKey poolKey, Promise sessionPromise) { final Bootstrap bootstrap; try { - bootstrap = getBootstrap(desiredProtocol, remoteAddress); + bootstrap = getBootstrap(desiredProtocol, remoteAddress, serializationFormat); } catch (Exception e) { sessionPromise.tryFailure(e); return; @@ -433,7 +459,8 @@ void connect(SocketAddress remoteAddress, SessionProtocol desiredProtocol, channel.connect(remoteAddress).addListener((ChannelFuture connectFuture) -> { if (connectFuture.isSuccess()) { - initSession(desiredProtocol, poolKey, connectFuture, sessionPromise); + initSession(desiredProtocol, serializationFormat, + poolKey, connectFuture, sessionPromise); } else { maybeHandleProxyFailure(desiredProtocol, poolKey, connectFuture.cause()); sessionPromise.tryFailure(connectFuture.cause()); @@ -469,8 +496,8 @@ void maybeHandleProxyFailure(SessionProtocol protocol, PoolKey poolKey, Throwabl } } - private void initSession(SessionProtocol desiredProtocol, PoolKey poolKey, - ChannelFuture connectFuture, Promise sessionPromise) { + private void initSession(SessionProtocol desiredProtocol, SerializationFormat serializationFormat, + PoolKey poolKey, ChannelFuture connectFuture, Promise sessionPromise) { assert connectFuture.isSuccess(); final Channel ch = connectFuture.channel(); @@ -486,10 +513,11 @@ private void initSession(SessionProtocol desiredProtocol, PoolKey poolKey, ch.pipeline().addLast( new HttpSessionHandler(this, ch, sessionPromise, timeoutFuture, - desiredProtocol, poolKey, clientFactory)); + desiredProtocol, serializationFormat, poolKey, clientFactory)); } - private void notifyConnect(SessionProtocol desiredProtocol, PoolKey key, Future future, + private void notifyConnect(SessionProtocol desiredProtocol, + PoolKey key, Future future, ChannelAcquisitionFuture promise, ClientConnectionTimingsBuilder timingsBuilder) { assert future.isDone(); @@ -777,14 +805,15 @@ private final class ChannelAcquisitionFuture extends CompletableFuture handler = - pch -> handlePiggyback(desiredProtocol, key, childPromise, timingsBuilder, pch); + pch -> handlePiggyback(desiredProtocol, serializationFormat, key, + childPromise, timingsBuilder, pch); if (pendingPiggybackHandlers == null) { // The 1st handler @@ -813,11 +842,13 @@ void piggyback(SessionProtocol desiredProtocol, PoolKey key, } // Handle immediately if complete already. - handlePiggyback(desiredProtocol, key, childPromise, timingsBuilder, + handlePiggyback(desiredProtocol, serializationFormat, key, childPromise, timingsBuilder, isCompletedExceptionally() ? null : getNow(null)); } - private void handlePiggyback(SessionProtocol desiredProtocol, PoolKey key, + private void handlePiggyback(SessionProtocol desiredProtocol, + SerializationFormat serializationFormat, + PoolKey key, ChannelAcquisitionFuture childPromise, ClientConnectionTimingsBuilder timingsBuilder, @Nullable PooledChannel pch) { @@ -829,7 +860,8 @@ private void handlePiggyback(SessionProtocol desiredProtocol, PoolKey key, final HttpSession session = HttpSession.get(pch.get()); if (session.incrementNumUnfinishedResponses()) { result = PiggybackedChannelAcquisitionResult.SUCCESS; - } else if (usePendingAcquisition(actualProtocol, key, childPromise, timingsBuilder)) { + } else if (usePendingAcquisition(actualProtocol, serializationFormat, + key, childPromise, timingsBuilder)) { result = PiggybackedChannelAcquisitionResult.PIGGYBACKED_AGAIN; } else { result = PiggybackedChannelAcquisitionResult.NEW_CONNECTION; @@ -839,7 +871,7 @@ private void handlePiggyback(SessionProtocol desiredProtocol, PoolKey key, // We use the exact protocol (H1 or H1C) instead of 'desiredProtocol' so that // we do not waste our time looking for pending acquisitions for the host // that does not support HTTP/2. - final PooledChannel ch = acquireNow(actualProtocol, key); + final PooledChannel ch = acquireNow(actualProtocol, serializationFormat, key); if (ch != null) { pch = ch; result = PiggybackedChannelAcquisitionResult.SUCCESS; @@ -858,7 +890,7 @@ private void handlePiggyback(SessionProtocol desiredProtocol, PoolKey key, break; case NEW_CONNECTION: timingsBuilder.pendingAcquisitionEnd(); - connect(desiredProtocol, key, childPromise, timingsBuilder); + connect(desiredProtocol, serializationFormat, key, childPromise, timingsBuilder); break; case PIGGYBACKED_AGAIN: // There's nothing to do because usePendingAcquisition() was called successfully above. diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java index 884c7bc705f..72f953d9d5d 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java @@ -28,6 +28,7 @@ import com.linecorp.armeria.client.proxy.ProxyType; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.logging.ClientConnectionTimings; @@ -167,25 +168,27 @@ private void acquireConnectionAndExecute0(ClientRequestContext ctx, Endpoint end HttpRequest req, DecodedHttpResponse res, ClientConnectionTimingsBuilder timingsBuilder, ProxyConfig proxyConfig) { - final SessionProtocol protocol = ctx.sessionProtocol(); final PoolKey key = new PoolKey(endpoint, proxyConfig); final HttpChannelPool pool = factory.pool(ctx.eventLoop().withoutContext()); - final PooledChannel pooledChannel = pool.acquireNow(protocol, key); + final SessionProtocol protocol = ctx.sessionProtocol(); + final SerializationFormat serializationFormat = ctx.log().partial().serializationFormat(); + final PooledChannel pooledChannel = pool.acquireNow(protocol, serializationFormat, key); if (pooledChannel != null) { logSession(ctx, pooledChannel, null); doExecute(pooledChannel, ctx, req, res); } else { - pool.acquireLater(protocol, key, timingsBuilder).handle((newPooledChannel, cause) -> { - logSession(ctx, newPooledChannel, timingsBuilder.build()); - if (cause == null) { - doExecute(newPooledChannel, ctx, req, res); - } else { - final UnprocessedRequestException wrapped = UnprocessedRequestException.of(cause); - handleEarlyRequestException(ctx, req, wrapped); - res.close(wrapped); - } - return null; - }); + pool.acquireLater(protocol, serializationFormat, key, timingsBuilder) + .handle((newPooledChannel, cause) -> { + logSession(ctx, newPooledChannel, timingsBuilder.build()); + if (cause == null) { + doExecute(newPooledChannel, ctx, req, res); + } else { + final UnprocessedRequestException wrapped = UnprocessedRequestException.of(cause); + handleEarlyRequestException(ctx, req, wrapped); + res.close(wrapped); + } + return null; + }); } } diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java index 8ae62a321c7..a05eb38e9d2 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java @@ -29,6 +29,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.function.Consumer; import java.util.function.Supplier; +import java.util.stream.Stream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -78,7 +79,8 @@ final class HttpClientFactory implements ClientFactory { private static final Set SUPPORTED_SCHEMES = Arrays.stream(SessionProtocol.values()) - .map(p -> Scheme.of(SerializationFormat.NONE, p)) + .flatMap(p -> Stream.of(Scheme.of(SerializationFormat.NONE, p), + Scheme.of(SerializationFormat.WS, p))) .collect(toImmutableSet()); private final EventLoopGroup workerGroup; diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java index 5e6b8f9c988..5009687de82 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java @@ -124,6 +124,11 @@ final class HttpClientPipelineConfigurator extends ChannelDuplexHandler { */ private static final long UPGRADE_RESPONSE_MAX_LENGTH = 16384; + private static final RequestOptions REQUEST_OPTIONS_FOR_UPGRADE_REQUEST = + RequestOptions.builder() + .responseTimeoutMillis(0) + .maxResponseLength(UPGRADE_RESPONSE_MAX_LENGTH).build(); + private enum HttpPreference { HTTP1_REQUIRED, HTTP2_PREFERRED, @@ -131,6 +136,7 @@ private enum HttpPreference { } private final HttpClientFactory clientFactory; + private final boolean webSocket; @Nullable private final SslContext sslCtx; private final HttpPreference httpPreference; @@ -141,9 +147,10 @@ private enum HttpPreference { private final SessionProtocol http2; HttpClientPipelineConfigurator(HttpClientFactory clientFactory, - SessionProtocol sessionProtocol, + boolean webSocket, SessionProtocol sessionProtocol, @Nullable SslContext sslCtx) { this.clientFactory = clientFactory; + this.webSocket = webSocket; if (sessionProtocol == HTTP || sessionProtocol == HTTPS) { httpPreference = HttpPreference.HTTP2_PREFERRED; @@ -400,7 +407,10 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { void finishSuccessfully(ChannelPipeline pipeline, SessionProtocol protocol) { if (protocol == H1 || protocol == H1C) { - addBeforeSessionHandler(pipeline, new Http1ResponseDecoder(pipeline.channel())); + addBeforeSessionHandler( + pipeline, webSocket ? new WebSocketHttp1ClientChannelHandler(pipeline.channel()) + : new Http1ResponseDecoder(pipeline.channel(), + clientFactory, protocol)); } else if (protocol == H2 || protocol == H2C) { final int initialWindow = clientFactory.http2InitialConnectionWindowSize(); if (initialWindow > DEFAULT_WINDOW_SIZE) { @@ -420,7 +430,7 @@ private static void incrementLocalWindowSize(ChannelPipeline pipeline, int delta } } - private void addBeforeSessionHandler(ChannelPipeline pipeline, ChannelHandler handler) { + private static void addBeforeSessionHandler(ChannelPipeline pipeline, ChannelHandler handler) { final ChannelHandlerContext lastContext = pipeline.lastContext(); if (lastContext.handler().getClass() == HttpSessionHandler.class) { // Get the name of the HttpSessionHandler so that we can put our handlers before it. @@ -532,12 +542,11 @@ public void onComplete() {} com.linecorp.armeria.common.HttpMethod.OPTIONS, RequestTarget.forClient("*"), ClientOptions.of(), HttpRequest.of(com.linecorp.armeria.common.HttpMethod.OPTIONS, "*"), - null, RequestOptions.of(), noopResponseCancellationScheduler, + null, REQUEST_OPTIONS_FOR_UPGRADE_REQUEST, noopResponseCancellationScheduler, System.nanoTime(), SystemInfo.currentTimeMicros()); // NB: No need to set the response timeout because we have session creation timeout. - responseDecoder.addResponse(0, res, reqCtx, ctx.channel().eventLoop(), /* response timeout */ 0, - UPGRADE_RESPONSE_MAX_LENGTH); + responseDecoder.addResponse(0, res, reqCtx, ctx.channel().eventLoop()); ctx.fireChannelActive(); } diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpRequestSubscriber.java b/core/src/main/java/com/linecorp/armeria/client/HttpRequestSubscriber.java index 89b16e54088..e4ddfb1b284 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpRequestSubscriber.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpRequestSubscriber.java @@ -16,57 +16,22 @@ package com.linecorp.armeria.client; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - import com.linecorp.armeria.common.HttpData; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; import com.linecorp.armeria.common.HttpObject; import com.linecorp.armeria.common.HttpRequest; -import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.internal.client.DecodedHttpResponse; import com.linecorp.armeria.unsafe.PooledObjects; import io.netty.channel.Channel; -final class HttpRequestSubscriber extends AbstractHttpRequestHandler implements Subscriber { - - private static final HttpData EMPTY_EOS = HttpData.empty().withEndOfStream(); - - private final HttpRequest request; - - // subscription, id and responseWrapper are assigned in onSubscribe() - @Nullable - private Subscription subscription; - private boolean isSubscriptionCompleted; +class HttpRequestSubscriber extends AbstractHttpRequestSubscriber { HttpRequestSubscriber(Channel ch, ClientHttpObjectEncoder encoder, HttpResponseDecoder responseDecoder, HttpRequest request, DecodedHttpResponse originalRes, ClientRequestContext ctx, long timeoutMillis) { - super(ch, encoder, responseDecoder, originalRes, ctx, timeoutMillis, request.isEmpty()); - this.request = request; - } - - @Override - public void onSubscribe(Subscription subscription) { - assert this.subscription == null; - this.subscription = subscription; - if (state() == State.DONE) { - cancel(); - return; - } - - if (!tryInitialize()) { - return; - } - - // NB: This must be invoked at the end of this method because otherwise the callback methods in this - // class can be called before the member fields (subscription, id, responseWrapper and - // timeoutFuture) are initialized. - // It is because the successful write of the first headers will trigger subscription.request(1). - writeHeaders(request.headers()); - channel().flush(); + super(ch, encoder, responseDecoder, request, originalRes, ctx, timeoutMillis, true, true); } @Override @@ -74,6 +39,7 @@ public void onNext(HttpObject o) { if (!(o instanceof HttpData) && !(o instanceof HttpHeaders)) { failAndReset(new IllegalArgumentException( "published an HttpObject that's neither Http2Headers nor Http2Data: " + o)); + PooledObjects.close(o); return; } @@ -100,38 +66,4 @@ public void onNext(HttpObject o) { break; } } - - @Override - public void onError(Throwable cause) { - isSubscriptionCompleted = true; - failRequest(cause); - } - - @Override - public void onComplete() { - isSubscriptionCompleted = true; - - if (state() != State.DONE) { - writeData(EMPTY_EOS); - channel().flush(); - } - } - - @Override - void onWriteSuccess() { - // Request more messages regardless whether the state is DONE. It makes the producer have - // a chance to produce the last call such as 'onComplete' and 'onError' when there are - // no more messages it can produce. - if (!isSubscriptionCompleted) { - assert subscription != null; - subscription.request(1); - } - } - - @Override - void cancel() { - isSubscriptionCompleted = true; - assert subscription != null; - subscription.cancel(); - } } diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpResponseDecoder.java b/core/src/main/java/com/linecorp/armeria/client/HttpResponseDecoder.java index decffa52eba..125340aa115 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpResponseDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpResponseDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2016 LINE Corporation + * Copyright 2023 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance @@ -16,425 +16,43 @@ package com.linecorp.armeria.client; -import java.util.Iterator; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; - -import org.reactivestreams.Subscriber; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import com.linecorp.armeria.common.ContentTooLargeException; -import com.linecorp.armeria.common.ContentTooLargeExceptionBuilder; -import com.linecorp.armeria.common.HttpData; -import com.linecorp.armeria.common.HttpHeaders; -import com.linecorp.armeria.common.HttpObject; -import com.linecorp.armeria.common.HttpRequest; -import com.linecorp.armeria.common.HttpStatusClass; -import com.linecorp.armeria.common.ResponseCompleteException; -import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.common.logging.RequestLogProperty; -import com.linecorp.armeria.common.stream.CancelledSubscriptionException; -import com.linecorp.armeria.common.stream.StreamWriter; -import com.linecorp.armeria.common.stream.SubscriptionOption; -import com.linecorp.armeria.common.util.Exceptions; -import com.linecorp.armeria.internal.client.ClientRequestContextExtension; import com.linecorp.armeria.internal.client.DecodedHttpResponse; import com.linecorp.armeria.internal.client.HttpSession; -import com.linecorp.armeria.internal.common.CancellationScheduler; -import com.linecorp.armeria.internal.common.CancellationScheduler.CancellationTask; import com.linecorp.armeria.internal.common.InboundTrafficController; import com.linecorp.armeria.internal.common.KeepAliveHandler; -import com.linecorp.armeria.unsafe.PooledObjects; import io.netty.channel.Channel; import io.netty.channel.EventLoop; -import io.netty.util.collection.IntObjectHashMap; -import io.netty.util.collection.IntObjectMap; -import io.netty.util.concurrent.EventExecutor; - -abstract class HttpResponseDecoder { - private static final Logger logger = LoggerFactory.getLogger(HttpResponseDecoder.class); +interface HttpResponseDecoder { - private final IntObjectMap responses = new IntObjectHashMap<>(); - private final Channel channel; - private final InboundTrafficController inboundTrafficController; - - @Nullable - private HttpSession httpSession; + Channel channel(); - private int unfinishedResponses; - private boolean closing; - - HttpResponseDecoder(Channel channel, InboundTrafficController inboundTrafficController) { - this.channel = channel; - this.inboundTrafficController = inboundTrafficController; - } - - final Channel channel() { - return channel; - } - - final InboundTrafficController inboundTrafficController() { - return inboundTrafficController; - } + InboundTrafficController inboundTrafficController(); HttpResponseWrapper addResponse( - int id, DecodedHttpResponse res, ClientRequestContext ctx, - EventLoop eventLoop, long responseTimeoutMillis, long maxContentLength) { - - final HttpResponseWrapper newRes = - new HttpResponseWrapper(res, ctx, responseTimeoutMillis, maxContentLength); - final HttpResponseWrapper oldRes = responses.put(id, newRes); - final KeepAliveHandler keepAliveHandler = keepAliveHandler(); - if (keepAliveHandler != null) { - keepAliveHandler.increaseNumRequests(); - } - - assert oldRes == null : "addResponse(" + id + ", " + res + ", " + responseTimeoutMillis + "): " + - oldRes; - onResponseAdded(id, eventLoop, newRes); - return newRes; - } - - abstract void onResponseAdded(int id, EventLoop eventLoop, HttpResponseWrapper responseWrapper); + int id, DecodedHttpResponse res, ClientRequestContext ctx, EventLoop eventLoop); @Nullable - final HttpResponseWrapper getResponse(int id) { - return responses.get(id); - } + HttpResponseWrapper getResponse(int id); @Nullable - final HttpResponseWrapper removeResponse(int id) { - if (closing) { - // `unfinishedResponses` will be removed by `failUnfinishedResponses()` - return null; - } + HttpResponseWrapper removeResponse(int id); - final HttpResponseWrapper removed = responses.remove(id); - if (removed != null) { - unfinishedResponses--; - assert unfinishedResponses >= 0 : unfinishedResponses; - } - return removed; - } + boolean hasUnfinishedResponses(); - final boolean hasUnfinishedResponses() { - return unfinishedResponses != 0; - } - - final boolean reserveUnfinishedResponse(int maxUnfinishedResponses) { - if (unfinishedResponses >= maxUnfinishedResponses) { - return false; - } - - unfinishedResponses++; - return true; - } + boolean reserveUnfinishedResponse(int maxUnfinishedResponses); - final void decrementUnfinishedResponses() { - unfinishedResponses--; - } - - final void failUnfinishedResponses(Throwable cause) { - if (closing) { - return; - } - closing = true; + void decrementUnfinishedResponses(); - for (final Iterator iterator = responses.values().iterator(); - iterator.hasNext();) { - final HttpResponseWrapper res = iterator.next(); - // To avoid calling removeResponse by res.close(cause), remove before closing. - iterator.remove(); - unfinishedResponses--; - res.close(cause); - } - } + void failUnfinishedResponses(Throwable cause); - HttpSession session() { - if (httpSession != null) { - return httpSession; - } - return httpSession = HttpSession.get(channel); - } + HttpSession session(); - @Nullable - abstract KeepAliveHandler keepAliveHandler(); + KeepAliveHandler keepAliveHandler(); - final boolean needsToDisconnectNow() { + default boolean needsToDisconnectNow() { return !session().isAcquirable() && !hasUnfinishedResponses(); } - - static final class HttpResponseWrapper implements StreamWriter { - - private final DecodedHttpResponse delegate; - private final ClientRequestContext ctx; - private final long maxContentLength; - private final long responseTimeoutMillis; - - private boolean responseStarted; - private long contentLengthHeaderValue = -1; - - private boolean done; - private boolean closed; - - HttpResponseWrapper(DecodedHttpResponse delegate, ClientRequestContext ctx, - long responseTimeoutMillis, long maxContentLength) { - this.delegate = delegate; - this.ctx = ctx; - this.maxContentLength = maxContentLength; - this.responseTimeoutMillis = responseTimeoutMillis; - } - - long maxContentLength() { - return maxContentLength; - } - - long writtenBytes() { - return delegate.writtenBytes(); - } - - long contentLengthHeaderValue() { - return contentLengthHeaderValue; - } - - @Override - public boolean isOpen() { - return delegate.isOpen(); - } - - @Override - public boolean isEmpty() { - throw new UnsupportedOperationException(); - } - - @Override - public long demand() { - throw new UnsupportedOperationException(); - } - - @Override - public CompletableFuture whenComplete() { - return delegate.whenComplete(); - } - - @Override - public void subscribe(Subscriber subscriber, EventExecutor executor, - SubscriptionOption... options) { - throw new UnsupportedOperationException(); - } - - @Override - public void abort() { - throw new UnsupportedOperationException(); - } - - @Override - public void abort(Throwable cause) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean tryWrite(HttpObject o) { - if (done) { - PooledObjects.close(o); - return false; - } - return delegate.tryWrite(o); - } - - void startResponse() { - if (responseStarted) { - return; - } - responseStarted = true; - ctx.logBuilder().startResponse(); - ctx.logBuilder().responseFirstBytesTransferred(); - initTimeout(); - } - - boolean tryWriteResponseHeaders(ResponseHeaders responseHeaders) { - assert responseHeaders.status().codeClass() != HttpStatusClass.INFORMATIONAL; - contentLengthHeaderValue = responseHeaders.contentLength(); - ctx.logBuilder().defer(RequestLogProperty.RESPONSE_HEADERS); - try { - return delegate.tryWrite(responseHeaders); - } finally { - ctx.logBuilder().responseHeaders(responseHeaders); - } - } - - boolean tryWriteData(HttpData data) { - if (done) { - PooledObjects.close(data); - return false; - } - data.touch(ctx); - ctx.logBuilder().increaseResponseLength(data); - return delegate.tryWrite(data); - } - - boolean tryWriteTrailers(HttpHeaders trailers) { - if (done) { - return false; - } - done = true; - ctx.logBuilder().defer(RequestLogProperty.RESPONSE_TRAILERS); - try { - return delegate.tryWrite(trailers); - } finally { - ctx.logBuilder().responseTrailers(trailers); - } - } - - @Override - public CompletableFuture whenConsumed() { - return delegate.whenConsumed(); - } - - void onSubscriptionCancelled(@Nullable Throwable cause) { - close(cause, true); - } - - @Override - public void close() { - close(null, false); - } - - @Override - public void close(Throwable cause) { - close(cause, false); - } - - private void close(@Nullable Throwable cause, boolean cancel) { - if (closed) { - return; - } - done = true; - closed = true; - cancelTimeoutOrLog(cause, cancel); - final HttpRequest request = ctx.request(); - assert request != null; - if (cause != null) { - request.abort(cause); - return; - } - final long requestAutoAbortDelayMillis = ctx.requestAutoAbortDelayMillis(); - if (requestAutoAbortDelayMillis == 0) { - request.abort(ResponseCompleteException.get()); - return; - } - if (requestAutoAbortDelayMillis > 0 && - requestAutoAbortDelayMillis < Long.MAX_VALUE) { - ctx.eventLoop().schedule(() -> request.abort(ResponseCompleteException.get()), - requestAutoAbortDelayMillis, TimeUnit.MILLISECONDS); - } - } - - private void closeAction(@Nullable Throwable cause) { - if (cause != null) { - delegate.close(cause); - ctx.logBuilder().endResponse(cause); - } else { - delegate.close(); - ctx.logBuilder().endResponse(); - } - } - - private void cancelAction(@Nullable Throwable cause) { - if (cause != null && !(cause instanceof CancelledSubscriptionException)) { - ctx.logBuilder().endResponse(cause); - } else { - ctx.logBuilder().endResponse(); - } - } - - private void cancelTimeoutOrLog(@Nullable Throwable cause, boolean cancel) { - CancellationScheduler responseCancellationScheduler = null; - final ClientRequestContextExtension ctxExtension = ctx.as(ClientRequestContextExtension.class); - if (ctxExtension != null) { - responseCancellationScheduler = ctxExtension.responseCancellationScheduler(); - } - - if (responseCancellationScheduler == null || !responseCancellationScheduler.isFinished()) { - if (responseCancellationScheduler != null) { - responseCancellationScheduler.clearTimeout(false); - } - // There's no timeout or the response has not been timed out. - if (cancel) { - cancelAction(cause); - } else { - closeAction(cause); - } - return; - } - if (delegate.isOpen()) { - closeAction(cause); - } - - // Response has been timed out already. - // Log only when it's not a ResponseTimeoutException. - if (cause instanceof ResponseTimeoutException) { - return; - } - - if (cause == null || !logger.isWarnEnabled() || Exceptions.isExpected(cause)) { - return; - } - - final StringBuilder logMsg = new StringBuilder("Unexpected exception while closing a request"); - final String authority = ctx.request().authority(); - if (authority != null) { - logMsg.append(" to ").append(authority); - } - - logger.warn(logMsg.append(':').toString(), cause); - } - - void initTimeout() { - final ClientRequestContextExtension ctxExtension = ctx.as(ClientRequestContextExtension.class); - if (ctxExtension != null) { - final CancellationScheduler responseCancellationScheduler = - ctxExtension.responseCancellationScheduler(); - responseCancellationScheduler.init( - ctx.eventLoop(), newCancellationTask(), - TimeUnit.MILLISECONDS.toNanos(responseTimeoutMillis), /* server */ false); - } - } - - private CancellationTask newCancellationTask() { - return new CancellationTask() { - @Override - public boolean canSchedule() { - return delegate.isOpen() && !done; - } - - @Override - public void run(Throwable cause) { - delegate.close(cause); - ctx.request().abort(cause); - ctx.logBuilder().endResponse(cause); - } - }; - } - - @Override - public String toString() { - return delegate.toString(); - } - } - - static Exception contentTooLargeException(HttpResponseWrapper res, long transferred) { - final ContentTooLargeExceptionBuilder builder = - ContentTooLargeException.builder() - .maxContentLength(res.maxContentLength()) - .transferred(transferred); - if (res.contentLengthHeaderValue() >= 0) { - builder.contentLength(res.contentLengthHeaderValue()); - } - return builder.build(); - } } diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java b/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java new file mode 100644 index 00000000000..0795416df6d --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/HttpResponseWrapper.java @@ -0,0 +1,328 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static com.google.common.base.MoreObjects.toStringHelper; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.reactivestreams.Subscriber; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpHeaders; +import com.linecorp.armeria.common.HttpObject; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.ResponseCompleteException; +import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.logging.RequestLogProperty; +import com.linecorp.armeria.common.stream.CancelledSubscriptionException; +import com.linecorp.armeria.common.stream.StreamWriter; +import com.linecorp.armeria.common.stream.SubscriptionOption; +import com.linecorp.armeria.common.util.Exceptions; +import com.linecorp.armeria.internal.client.ClientRequestContextExtension; +import com.linecorp.armeria.internal.client.DecodedHttpResponse; +import com.linecorp.armeria.internal.common.CancellationScheduler; +import com.linecorp.armeria.internal.common.CancellationScheduler.CancellationTask; +import com.linecorp.armeria.unsafe.PooledObjects; + +import io.netty.channel.EventLoop; +import io.netty.util.concurrent.EventExecutor; + +class HttpResponseWrapper implements StreamWriter { + + private static final Logger logger = LoggerFactory.getLogger(HttpResponseWrapper.class); + + private final DecodedHttpResponse delegate; + private final EventLoop eventLoop; + private final ClientRequestContext ctx; + private final long maxContentLength; + private final long responseTimeoutMillis; + + private boolean responseStarted; + private long contentLengthHeaderValue = -1; + + private boolean done; + private boolean closed; + + HttpResponseWrapper(DecodedHttpResponse delegate, EventLoop eventLoop, ClientRequestContext ctx, + long responseTimeoutMillis, long maxContentLength) { + this.delegate = delegate; + this.eventLoop = eventLoop; + this.ctx = ctx; + this.maxContentLength = maxContentLength; + this.responseTimeoutMillis = responseTimeoutMillis; + } + + DecodedHttpResponse delegate() { + return delegate; + } + + EventLoop eventLoop() { + return eventLoop; + } + + long maxContentLength() { + return maxContentLength; + } + + long writtenBytes() { + return delegate.writtenBytes(); + } + + long contentLengthHeaderValue() { + return contentLengthHeaderValue; + } + + @Override + public boolean isOpen() { + return delegate.isOpen(); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(); + } + + @Override + public long demand() { + throw new UnsupportedOperationException(); + } + + @Override + public CompletableFuture whenComplete() { + return delegate.whenComplete(); + } + + @Override + public void subscribe(Subscriber subscriber, EventExecutor executor, + SubscriptionOption... options) { + throw new UnsupportedOperationException(); + } + + @Override + public void abort() { + throw new UnsupportedOperationException(); + } + + @Override + public void abort(Throwable cause) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean tryWrite(HttpObject o) { + if (done) { + PooledObjects.close(o); + return false; + } + return delegate.tryWrite(o); + } + + void startResponse() { + if (responseStarted) { + return; + } + responseStarted = true; + ctx.logBuilder().startResponse(); + ctx.logBuilder().responseFirstBytesTransferred(); + initTimeout(); + } + + boolean tryWriteResponseHeaders(ResponseHeaders responseHeaders) { + contentLengthHeaderValue = responseHeaders.contentLength(); + ctx.logBuilder().defer(RequestLogProperty.RESPONSE_HEADERS); + try { + return delegate.tryWrite(responseHeaders); + } finally { + ctx.logBuilder().responseHeaders(responseHeaders); + } + } + + boolean tryWriteData(HttpData data) { + if (done) { + PooledObjects.close(data); + return false; + } + data.touch(ctx); + ctx.logBuilder().increaseResponseLength(data); + return delegate.tryWrite(data); + } + + boolean tryWriteTrailers(HttpHeaders trailers) { + if (done) { + return false; + } + done = true; + ctx.logBuilder().defer(RequestLogProperty.RESPONSE_TRAILERS); + try { + return delegate.tryWrite(trailers); + } finally { + ctx.logBuilder().responseTrailers(trailers); + } + } + + @Override + public CompletableFuture whenConsumed() { + return delegate.whenConsumed(); + } + + /** + * This method is called when the delegate is completed. + */ + void onSubscriptionCancelled(@Nullable Throwable cause) { + close(cause, true); + } + + @Override + public void close() { + close(null, false); + } + + @Override + public void close(Throwable cause) { + close(cause, false); + } + + void close(@Nullable Throwable cause, boolean cancel) { + if (closed) { + return; + } + done = true; + closed = true; + cancelTimeoutOrLog(cause, cancel); + final HttpRequest request = ctx.request(); + assert request != null; + if (cause != null) { + request.abort(cause); + return; + } + final long requestAutoAbortDelayMillis = ctx.requestAutoAbortDelayMillis(); + if (requestAutoAbortDelayMillis < 0 || requestAutoAbortDelayMillis == Long.MAX_VALUE) { + return; + } + if (requestAutoAbortDelayMillis == 0) { + request.abort(ResponseCompleteException.get()); + return; + } + ctx.eventLoop().schedule(() -> request.abort(ResponseCompleteException.get()), + requestAutoAbortDelayMillis, TimeUnit.MILLISECONDS); + } + + private void closeAction(@Nullable Throwable cause) { + if (cause != null) { + delegate.close(cause); + ctx.logBuilder().endResponse(cause); + } else { + delegate.close(); + ctx.logBuilder().endResponse(); + } + } + + private void cancelAction(@Nullable Throwable cause) { + if (cause != null && !(cause instanceof CancelledSubscriptionException)) { + ctx.logBuilder().endResponse(cause); + } else { + ctx.logBuilder().endResponse(); + } + } + + private void cancelTimeoutOrLog(@Nullable Throwable cause, boolean cancel) { + CancellationScheduler responseCancellationScheduler = null; + final ClientRequestContextExtension ctxExtension = ctx.as(ClientRequestContextExtension.class); + if (ctxExtension != null) { + responseCancellationScheduler = ctxExtension.responseCancellationScheduler(); + } + + if (responseCancellationScheduler == null || !responseCancellationScheduler.isFinished()) { + if (responseCancellationScheduler != null) { + responseCancellationScheduler.clearTimeout(false); + } + // There's no timeout or the response has not been timed out. + if (cancel) { + cancelAction(cause); + } else { + closeAction(cause); + } + return; + } + if (delegate.isOpen()) { + closeAction(cause); + } + + // Response has been timed out already. + // Log only when it's not a ResponseTimeoutException. + if (cause instanceof ResponseTimeoutException) { + return; + } + + if (cause == null || !logger.isWarnEnabled() || Exceptions.isExpected(cause)) { + return; + } + + final StringBuilder logMsg = new StringBuilder("Unexpected exception while closing a request"); + final String authority = ctx.request().authority(); + if (authority != null) { + logMsg.append(" to ").append(authority); + } + + logger.warn(logMsg.append(':').toString(), cause); + } + + void initTimeout() { + final ClientRequestContextExtension ctxExtension = ctx.as(ClientRequestContextExtension.class); + if (ctxExtension != null) { + final CancellationScheduler responseCancellationScheduler = + ctxExtension.responseCancellationScheduler(); + responseCancellationScheduler.init( + ctx.eventLoop(), newCancellationTask(), + TimeUnit.MILLISECONDS.toNanos(responseTimeoutMillis), /* server */ false); + } + } + + private CancellationTask newCancellationTask() { + return new CancellationTask() { + @Override + public boolean canSchedule() { + return delegate.isOpen() && !done; + } + + @Override + public void run(Throwable cause) { + delegate.close(cause); + ctx.request().abort(cause); + ctx.logBuilder().endResponse(cause); + } + }; + } + + @Override + public String toString() { + return toStringHelper(this).omitNullValues() + .add("ctx", ctx) + .add("eventLoop", eventLoop) + .add("responseStarted", responseStarted) + .add("maxContentLength", maxContentLength) + .add("responseTimeoutMillis", responseTimeoutMillis) + .add("contentLengthHeaderValue", contentLengthHeaderValue) + .add("delegate", delegate) + .toString(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpSessionHandler.java b/core/src/main/java/com/linecorp/armeria/client/HttpSessionHandler.java index afed2bde1fe..59a7ee6f81a 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpSessionHandler.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpSessionHandler.java @@ -19,7 +19,6 @@ import static com.linecorp.armeria.common.SessionProtocol.H1C; import static com.linecorp.armeria.common.SessionProtocol.H2; import static com.linecorp.armeria.common.SessionProtocol.H2C; -import static com.linecorp.armeria.internal.common.KeepAliveHandlerUtil.needsKeepAliveHandler; import static java.util.Objects.requireNonNull; import java.io.IOException; @@ -30,16 +29,14 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.google.common.collect.ImmutableList; - import com.linecorp.armeria.client.HttpChannelPool.PoolKey; import com.linecorp.armeria.client.proxy.ProxyType; import com.linecorp.armeria.common.AggregationOptions; import com.linecorp.armeria.common.ClosedSessionException; import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.common.metric.MoreMeters; import com.linecorp.armeria.common.stream.CancelledSubscriptionException; import com.linecorp.armeria.common.stream.SubscriptionOption; import com.linecorp.armeria.common.util.SafeCloseable; @@ -49,11 +46,8 @@ import com.linecorp.armeria.internal.common.Http2GoAwayHandler; import com.linecorp.armeria.internal.common.InboundTrafficController; import com.linecorp.armeria.internal.common.KeepAliveHandler; -import com.linecorp.armeria.internal.common.NoopKeepAliveHandler; import com.linecorp.armeria.internal.common.RequestContextUtil; -import io.micrometer.core.instrument.Tag; -import io.micrometer.core.instrument.Timer; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.channel.Channel; @@ -82,6 +76,7 @@ final class HttpSessionHandler extends ChannelDuplexHandler implements HttpSessi private final Promise sessionPromise; private final ScheduledFuture sessionTimeoutFuture; private final SessionProtocol desiredProtocol; + private final SerializationFormat serializationFormat; private final PoolKey poolKey; private final HttpClientFactory clientFactory; @@ -124,18 +119,24 @@ final class HttpSessionHandler extends ChannelDuplexHandler implements HttpSessi HttpSessionHandler(HttpChannelPool channelPool, Channel channel, Promise sessionPromise, ScheduledFuture sessionTimeoutFuture, - SessionProtocol desiredProtocol, PoolKey poolKey, - HttpClientFactory clientFactory) { + SessionProtocol desiredProtocol, SerializationFormat serializationFormat, + PoolKey poolKey, HttpClientFactory clientFactory) { this.channelPool = requireNonNull(channelPool, "channelPool"); this.channel = requireNonNull(channel, "channel"); remoteAddress = channel.remoteAddress(); this.sessionPromise = requireNonNull(sessionPromise, "sessionPromise"); this.sessionTimeoutFuture = requireNonNull(sessionTimeoutFuture, "sessionTimeoutFuture"); this.desiredProtocol = desiredProtocol; + this.serializationFormat = serializationFormat; this.poolKey = poolKey; this.clientFactory = clientFactory; } + @Override + public SerializationFormat serializationFormat() { + return serializationFormat; + } + @Override public SessionProtocol protocol() { return protocol; @@ -195,8 +196,9 @@ public void invoke(PooledChannel pooledChannel, ClientRequestContext ctx, assert protocol != null; assert responseDecoder != null; assert requestEncoder != null; - if (!protocol.isMultiplex()) { - // When HTTP/1.1 is used: + if (!protocol.isMultiplex() && !serializationFormat.requiresNewConnection(protocol)) { + // When HTTP/1.1 is used and the serialization format does not require + // a new connection (w.g. WebSocket): // If pipelining is enabled, return as soon as the request is fully sent. // If pipelining is disabled, // return after the response is fully received and the request is fully sent. @@ -212,23 +214,26 @@ public void invoke(PooledChannel pooledChannel, ClientRequestContext ctx, }); } - if (ctx.exchangeType().isRequestStreaming()) { - final HttpRequestSubscriber reqSubscriber = new HttpRequestSubscriber( - channel, requestEncoder, responseDecoder, req, res, ctx, writeTimeoutMillis); - // A StreamMessage of a request body uses RequestContext to get the default SubscriberExecutor. - try (SafeCloseable ignored = ctx.push()) { - req.subscribe(reqSubscriber, channel.eventLoop(), SubscriptionOption.WITH_POOLED_OBJECTS); - } - } else { - final AggregatedHttpRequestHandler reqHandler = new AggregatedHttpRequestHandler( - channel, requestEncoder, responseDecoder, req, res, ctx, writeTimeoutMillis); - try (SafeCloseable ignored = ctx.push()) { + try (SafeCloseable ignored = ctx.push()) { + if (!ctx.exchangeType().isRequestStreaming()) { + final AggregatedHttpRequestHandler reqHandler = new AggregatedHttpRequestHandler( + channel, requestEncoder, responseDecoder, req, res, ctx, writeTimeoutMillis); req.aggregate(AggregationOptions.usePooledObjects(ctx.alloc(), channel.eventLoop())) .handle(reqHandler); + return; } + + final AbstractHttpRequestSubscriber subscriber = AbstractHttpRequestSubscriber.of( + channel, requestEncoder, responseDecoder, protocol, + ctx, req, res, writeTimeoutMillis, isWebSocket()); + req.subscribe(subscriber, channel.eventLoop(), SubscriptionOption.WITH_POOLED_OBJECTS); } } + private boolean isWebSocket() { + return serializationFormat == SerializationFormat.WS; + } + @Override public int incrementAndGetNumRequestsSent() { return ++numRequestsSent; @@ -345,40 +350,23 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc final SessionProtocol protocol = (SessionProtocol) evt; this.protocol = protocol; if (protocol == H1 || protocol == H1C) { - final Http1ResponseDecoder responseDecoder = ctx.pipeline().get(Http1ResponseDecoder.class); - - final long idleTimeoutMillis = clientFactory.idleTimeoutMillis(); - final long pingIntervalMillis = clientFactory.pingIntervalMillis(); - final long maxConnectionAgeMillis = clientFactory.maxConnectionAgeMillis(); - final int maxNumRequestsPerConnection = clientFactory.maxNumRequestsPerConnection(); - final boolean keepAliveOnPing = clientFactory.keepAliveOnPing(); - final boolean needsKeepAliveHandler = - needsKeepAliveHandler(idleTimeoutMillis, pingIntervalMillis, - maxConnectionAgeMillis, maxNumRequestsPerConnection); - - final KeepAliveHandler keepAliveHandler; - if (needsKeepAliveHandler) { - final Timer keepAliveTimer = - MoreMeters.newTimer(clientFactory.meterRegistry(), - "armeria.client.connections.lifespan", - ImmutableList.of(Tag.of("protocol", protocol.uriText()))); - keepAliveHandler = new Http1ClientKeepAliveHandler( - channel, responseDecoder, keepAliveTimer, idleTimeoutMillis, - pingIntervalMillis, maxConnectionAgeMillis, maxNumRequestsPerConnection, - keepAliveOnPing); + final HttpResponseDecoder responseDecoder; + if (isWebSocket()) { + responseDecoder = ctx.pipeline().get(WebSocketHttp1ClientChannelHandler.class); } else { - keepAliveHandler = new NoopKeepAliveHandler(); + responseDecoder = ctx.pipeline().get(Http1ResponseDecoder.class); } + final KeepAliveHandler keepAliveHandler = responseDecoder.keepAliveHandler(); + keepAliveHandler.initialize(ctx); final ClientHttp1ObjectEncoder requestEncoder = new ClientHttp1ObjectEncoder(channel, protocol, clientFactory.http1HeaderNaming(), - keepAliveHandler); + keepAliveHandler, + isWebSocket()); if (keepAliveHandler instanceof Http1ClientKeepAliveHandler) { ((Http1ClientKeepAliveHandler) keepAliveHandler).setEncoder(requestEncoder); } - responseDecoder.setKeepAliveHandler(ctx, keepAliveHandler); - this.requestEncoder = requestEncoder; this.responseDecoder = responseDecoder; } else if (protocol == H2 || protocol == H2C) { @@ -465,9 +453,10 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { assert responseDecoder == null || !responseDecoder.hasUnfinishedResponses(); sessionTimeoutFuture.cancel(false); if (proxyDestinationAddress != null) { - channelPool.connect(proxyDestinationAddress, retryProtocol, poolKey, sessionPromise); + channelPool.connect(proxyDestinationAddress, retryProtocol, serializationFormat, + poolKey, sessionPromise); } else { - channelPool.connect(remoteAddress, retryProtocol, poolKey, sessionPromise); + channelPool.connect(remoteAddress, retryProtocol, serializationFormat, poolKey, sessionPromise); } } else { // Fail all pending responses. diff --git a/core/src/main/java/com/linecorp/armeria/client/RestClientPreparation.java b/core/src/main/java/com/linecorp/armeria/client/RestClientPreparation.java index 5109e1632e8..86aea236e12 100644 --- a/core/src/main/java/com/linecorp/armeria/client/RestClientPreparation.java +++ b/core/src/main/java/com/linecorp/armeria/client/RestClientPreparation.java @@ -198,6 +198,12 @@ public RestClientPreparation content(MediaType contentType, HttpData content) { return this; } + @Override + public RestClientPreparation content(Publisher content) { + delegate.content(content); + return this; + } + @Override public RestClientPreparation content(MediaType contentType, Publisher content) { delegate.content(contentType, content); diff --git a/core/src/main/java/com/linecorp/armeria/client/TransformingRequestPreparation.java b/core/src/main/java/com/linecorp/armeria/client/TransformingRequestPreparation.java index 74a88494ae2..852131f6a5d 100644 --- a/core/src/main/java/com/linecorp/armeria/client/TransformingRequestPreparation.java +++ b/core/src/main/java/com/linecorp/armeria/client/TransformingRequestPreparation.java @@ -205,6 +205,12 @@ public TransformingRequestPreparation content(MediaType contentType, return this; } + @Override + public TransformingRequestPreparation content(Publisher content) { + delegate.content(content); + return this; + } + @Override public TransformingRequestPreparation content(MediaType contentType, Publisher content) { diff --git a/core/src/main/java/com/linecorp/armeria/client/WebClientRequestPreparation.java b/core/src/main/java/com/linecorp/armeria/client/WebClientRequestPreparation.java index de7a06f2e90..15cf47cfae6 100644 --- a/core/src/main/java/com/linecorp/armeria/client/WebClientRequestPreparation.java +++ b/core/src/main/java/com/linecorp/armeria/client/WebClientRequestPreparation.java @@ -491,6 +491,11 @@ public WebClientRequestPreparation content(MediaType contentType, HttpData conte return (WebClientRequestPreparation) super.content(contentType, content); } + @Override + public WebClientRequestPreparation content(Publisher publisher) { + return (WebClientRequestPreparation) super.content(publisher); + } + @Override public WebClientRequestPreparation content(MediaType contentType, Publisher publisher) { return (WebClientRequestPreparation) super.content(contentType, publisher); diff --git a/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ClientChannelHandler.java b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ClientChannelHandler.java new file mode 100644 index 00000000000..3158fcd5dd3 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ClientChannelHandler.java @@ -0,0 +1,263 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static com.linecorp.armeria.client.AbstractHttpResponseDecoder.contentTooLargeException; +import static io.netty.handler.codec.http.LastHttpContent.EMPTY_LAST_CONTENT; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.math.LongMath; + +import com.linecorp.armeria.common.ClosedSessionException; +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.ProtocolViolationException; +import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.client.DecodedHttpResponse; +import com.linecorp.armeria.internal.client.HttpSession; +import com.linecorp.armeria.internal.common.ArmeriaHttpUtil; +import com.linecorp.armeria.internal.common.InboundTrafficController; +import com.linecorp.armeria.internal.common.KeepAliveHandler; +import com.linecorp.armeria.internal.common.NoopKeepAliveHandler; +import com.linecorp.armeria.internal.common.util.TemporaryThreadLocals; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoop; +import io.netty.handler.codec.DecoderResult; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.util.ReferenceCountUtil; + +final class WebSocketHttp1ClientChannelHandler extends ChannelDuplexHandler implements HttpResponseDecoder { + + private static final Logger logger = LoggerFactory.getLogger(WebSocketHttp1ClientChannelHandler.class); + + private enum State { + NEEDS_HANDSHAKE_RESPONSE, + NEEDS_HANDSHAKE_RESPONSE_END, + UPGRADE_COMPLETE + } + + private final Channel channel; + private final InboundTrafficController inboundTrafficController; + @Nullable + private HttpResponseWrapper res; + private final KeepAliveHandler keepAliveHandler; + + private State state = State.NEEDS_HANDSHAKE_RESPONSE; + @Nullable + private HttpSession httpSession; + + WebSocketHttp1ClientChannelHandler(Channel channel) { + this.channel = channel; + inboundTrafficController = InboundTrafficController.ofHttp1(channel); + + // Use NoopKeepAliveHandler because + // - hasRequestsInProgress is always true for WebSocket + // - a Ping frame is not sent by the keepAliveHandler but by the upper layer. + // TODO(minwoox): Provide a dedicated KeepAliveHandler to the upper layer (e.g. WebSocketClient) + // that handles Ping frames for WebSocket. + keepAliveHandler = new NoopKeepAliveHandler(); + } + + @Override + public Channel channel() { + return channel; + } + + @Override + public InboundTrafficController inboundTrafficController() { + return inboundTrafficController; + } + + @Override + public HttpResponseWrapper addResponse(int id, DecodedHttpResponse decodedHttpResponse, + ClientRequestContext ctx, EventLoop eventLoop) { + assert res == null; + res = new WebSocketHttp1ResponseWrapper(decodedHttpResponse, eventLoop, ctx, + ctx.responseTimeoutMillis(), ctx.maxResponseLength()); + return res; + } + + @Nullable + @Override + public HttpResponseWrapper getResponse(int unused) { + return res; + } + + @Nullable + @Override + public HttpResponseWrapper removeResponse(int unused) { + return res; + } + + @Override + public boolean hasUnfinishedResponses() { + return res != null; + } + + @Override + public boolean reserveUnfinishedResponse(int unused) { + return true; + } + + @Override + public void decrementUnfinishedResponses() {} + + @Override + public void failUnfinishedResponses(Throwable cause) { + if (res != null) { + res.close(cause); + } + } + + @Override + public HttpSession session() { + if (httpSession != null) { + return httpSession; + } + return httpSession = HttpSession.get(channel); + } + + @Override + public KeepAliveHandler keepAliveHandler() { + return keepAliveHandler; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + keepAliveHandler.destroy(); + if (res != null) { + res.close(ClosedSessionException.get()); + } + ctx.fireChannelInactive(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + try { + switch (state) { + case NEEDS_HANDSHAKE_RESPONSE: + if (!(msg instanceof HttpObject)) { + ctx.fireChannelRead(msg); + return; + } + if (!(msg instanceof HttpResponse)) { + failWithUnexpectedMessageType(ctx, msg, HttpResponse.class); + return; + } + + final HttpResponse nettyRes = (HttpResponse) msg; + final DecoderResult decoderResult = nettyRes.decoderResult(); + if (!decoderResult.isSuccess()) { + fail(ctx, new ProtocolViolationException(decoderResult.cause())); + return; + } + + if (!HttpUtil.isKeepAlive(nettyRes)) { + session().deactivate(); + } + + if (res == null && ArmeriaHttpUtil.isRequestTimeoutResponse(nettyRes)) { + ctx.close(); + return; + } + + res.startResponse(); + final ResponseHeaders responseHeaders = ArmeriaHttpUtil.toArmeria(nettyRes); + if (responseHeaders.status() == HttpStatus.SWITCHING_PROTOCOLS) { + final ChannelPipeline pipeline = ctx.pipeline(); + pipeline.remove(HttpClientCodec.class); + state = State.NEEDS_HANDSHAKE_RESPONSE_END; + } + if (!res.tryWriteResponseHeaders(responseHeaders)) { + fail(ctx, ClosedSessionException.get()); + } + break; + case NEEDS_HANDSHAKE_RESPONSE_END: + // HttpClientCodec produces this after creating the headers. We can just ignore it. + if (msg != EMPTY_LAST_CONTENT) { + failWithUnexpectedMessageType(ctx, msg, EMPTY_LAST_CONTENT.getClass()); + return; + } + state = State.UPGRADE_COMPLETE; + break; + case UPGRADE_COMPLETE: + assert msg instanceof ByteBuf; + final ByteBuf data = (ByteBuf) msg; + final int dataLength = data.readableBytes(); + if (dataLength > 0) { + final long maxContentLength = res.maxContentLength(); + final long writtenBytes = res.writtenBytes(); + if (maxContentLength > 0 && writtenBytes > maxContentLength - dataLength) { + final long transferred = LongMath.saturatedAdd(writtenBytes, dataLength); + res.close(contentTooLargeException(res, transferred)); + ctx.close(); + return; + } + if (!res.tryWriteData(HttpData.wrap(data.retain()))) { + ctx.close(); + } + } + break; + } + } finally { + ReferenceCountUtil.release(msg); + } + } + + private void failWithUnexpectedMessageType(ChannelHandlerContext ctx, Object msg, Class expected) { + final String message; + try (TemporaryThreadLocals tempThreadLocals = TemporaryThreadLocals.acquire()) { + final StringBuilder buf = tempThreadLocals.stringBuilder(); + buf.append("unexpected message type: " + msg.getClass().getName() + + " (expected: " + expected.getName() + ", channel: " + ctx.channel() + ')'); + message = buf.toString(); + } + fail(ctx, new ProtocolViolationException(message)); + } + + private void fail(ChannelHandlerContext ctx, Throwable cause) { + if (res != null) { + res.close(cause); + } else { + logger.warn("Unexpected exception:", cause); + } + + ctx.close(); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof HttpContent) { + ctx.write(((HttpContent) msg).content(), promise); + return; + } + ctx.write(msg, promise); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1RequestSubscriber.java b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1RequestSubscriber.java new file mode 100644 index 00000000000..83b733a60ac --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1RequestSubscriber.java @@ -0,0 +1,58 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.client; + +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpObject; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.internal.client.DecodedHttpResponse; +import com.linecorp.armeria.unsafe.PooledObjects; + +import io.netty.channel.Channel; + +final class WebSocketHttp1RequestSubscriber extends AbstractHttpRequestSubscriber { + + WebSocketHttp1RequestSubscriber(Channel ch, ClientHttpObjectEncoder encoder, + HttpResponseDecoder responseDecoder, + HttpRequest request, DecodedHttpResponse originalRes, + ClientRequestContext ctx, long timeoutMillis) { + super(ch, encoder, responseDecoder, request, originalRes, ctx, timeoutMillis, false, false); + } + + @Override + public void onNext(HttpObject o) { + if (!(o instanceof HttpData)) { + failAndReset(new IllegalArgumentException( + "published an HttpObject that's not HttpData: " + o)); + PooledObjects.close(o); + return; + } + + switch (state()) { + case NEEDS_DATA: { + writeData((HttpData) o); + channel().flush(); + break; + } + case DONE: + // Cancel the subscription if any message comes here after the state has been changed to DONE. + cancel(); + PooledObjects.close(o); + break; + } + } +} + diff --git a/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ResponseWrapper.java b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ResponseWrapper.java new file mode 100644 index 00000000000..735e640ae8c --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp1ResponseWrapper.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import com.linecorp.armeria.common.ClosedSessionException; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.client.DecodedHttpResponse; +import com.linecorp.armeria.internal.client.websocket.WebSocketClientUtil; + +import io.netty.channel.EventLoop; + +final class WebSocketHttp1ResponseWrapper extends HttpResponseWrapper { + + WebSocketHttp1ResponseWrapper(DecodedHttpResponse delegate, + EventLoop eventLoop, ClientRequestContext ctx, + long responseTimeoutMillis, long maxContentLength) { + super(delegate, eventLoop, ctx, responseTimeoutMillis, maxContentLength); + WebSocketClientUtil.setClosingResponseTask(ctx, cause -> { + super.close(cause, false); + }); + } + + @Override + void close(@Nullable Throwable cause, boolean cancel) { + if (cancel || !(cause instanceof ClosedSessionException)) { + super.close(cause, cancel); + return; + } + // Close the delegate directly so that we can give a chance to WebSocketFrameDecoder to close the + // response normally if it receives a close frame before the ClosedSessionException is raised. + delegate().close(cause); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp2RequestSubscriber.java b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp2RequestSubscriber.java new file mode 100644 index 00000000000..21798b4825c --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/WebSocketHttp2RequestSubscriber.java @@ -0,0 +1,50 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.client; + +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.internal.client.DecodedHttpResponse; + +import io.netty.channel.Channel; +import io.netty.handler.codec.http.HttpHeaderValues; + +final class WebSocketHttp2RequestSubscriber extends HttpRequestSubscriber { + + WebSocketHttp2RequestSubscriber(Channel ch, ClientHttpObjectEncoder encoder, + HttpResponseDecoder responseDecoder, + HttpRequest request, DecodedHttpResponse originalRes, + ClientRequestContext ctx, long timeoutMillis) { + super(ch, encoder, responseDecoder, request, originalRes, ctx, timeoutMillis); + } + + @Override + RequestHeaders mapHeaders(RequestHeaders headers) { + if (headers.method() == HttpMethod.CONNECT) { + return headers; + } + return headers.toBuilder() + .method(HttpMethod.CONNECT) + .removeAndThen(HttpHeaderNames.CONNECTION) + .removeAndThen(HttpHeaderNames.UPGRADE) + .removeAndThen(HttpHeaderNames.SEC_WEBSOCKET_KEY) + .set(HttpHeaderNames.PROTOCOL, HttpHeaderValues.WEBSOCKET.toString()) + .build(); + } +} + diff --git a/core/src/main/java/com/linecorp/armeria/client/endpoint/DynamicEndpointGroup.java b/core/src/main/java/com/linecorp/armeria/client/endpoint/DynamicEndpointGroup.java index 91765b810f8..e4a89d0e1ab 100644 --- a/core/src/main/java/com/linecorp/armeria/client/endpoint/DynamicEndpointGroup.java +++ b/core/src/main/java/com/linecorp/armeria/client/endpoint/DynamicEndpointGroup.java @@ -30,9 +30,8 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Lock; +import java.util.function.Consumer; -import com.google.common.base.MoreObjects; -import com.google.common.base.MoreObjects.ToStringHelper; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -340,20 +339,36 @@ public final void close() { @Override public String toString() { - return toStringHelper().toString(); + return toString(unused -> {}); } /** - * Returns {@link ToStringHelper} that contains fields information. + * Returns the string representation of this {@link DynamicEndpointGroup}. Specify a {@link Consumer} + * to add more fields to the returned string, e.g. + *
{@code
+     * > @Override
+     * > public String toString() {
+     * >     return toString(buf -> {
+     * >         buf.append(", foo=").append(foo);
+     * >         buf.append(", bar=").append(bar);
+     * >     });
+     * > }
+     * }
+ * + * @param builderMutator the {@link Consumer} that appends the additional fields into the given + * {@link StringBuilder}. */ - protected ToStringHelper toStringHelper() { - return MoreObjects.toStringHelper(this) - .omitNullValues() - .add("selectionStrategy", selectionStrategy.getClass()) - .add("allowsEmptyEndpoints", allowEmptyEndpoints) - .add("endpoints", truncate(endpoints, 10)) - .add("numEndpoints", endpoints.size()) - .add("initialized", initialEndpointsFuture.isDone()); + @UnstableApi + protected final String toString(Consumer builderMutator) { + final StringBuilder buf = new StringBuilder(); + buf.append(getClass().getSimpleName()); + buf.append("{selectionStrategy=").append(selectionStrategy.getClass()); + buf.append(", allowsEmptyEndpoints=").append(allowEmptyEndpoints); + buf.append(", initialized=").append(initialEndpointsFuture.isDone()); + buf.append(", numEndpoints=").append(endpoints.size()); + buf.append(", endpoints=").append(truncate(endpoints, 10)); + builderMutator.accept(buf); + return buf.append('}').toString(); } private class InitialEndpointsFuture extends EventLoopCheckingFuture> { diff --git a/core/src/main/java/com/linecorp/armeria/client/endpoint/PropertiesEndpointGroup.java b/core/src/main/java/com/linecorp/armeria/client/endpoint/PropertiesEndpointGroup.java index 0bfa926ef79..40539aa67db 100644 --- a/core/src/main/java/com/linecorp/armeria/client/endpoint/PropertiesEndpointGroup.java +++ b/core/src/main/java/com/linecorp/armeria/client/endpoint/PropertiesEndpointGroup.java @@ -233,8 +233,6 @@ protected void doCloseAsync(CompletableFuture future) { @Override public String toString() { - return toStringHelper() - .add("watchRegisterKey", watchRegisterKey) - .toString(); + return toString(buf -> buf.append(", watchRegisterKey=").append(watchRegisterKey)); } } diff --git a/core/src/main/java/com/linecorp/armeria/client/endpoint/dns/DnsEndpointGroup.java b/core/src/main/java/com/linecorp/armeria/client/endpoint/dns/DnsEndpointGroup.java index 93b271cb2bb..aeee34feaf7 100644 --- a/core/src/main/java/com/linecorp/armeria/client/endpoint/dns/DnsEndpointGroup.java +++ b/core/src/main/java/com/linecorp/armeria/client/endpoint/dns/DnsEndpointGroup.java @@ -242,10 +242,10 @@ final void logDnsResolutionResult(Collection endpoints, int ttl) { @Override public String toString() { - return toStringHelper() - .add("questions", questions) - .add("logPrefix", logPrefix) - .add("attemptsSoFar", attemptsSoFar) - .toString(); + return toString(buf -> { + buf.append(", questions=").append(questions); + buf.append(", logPrefix=").append(logPrefix); + buf.append(", attemptsSoFar=").append(attemptsSoFar); + }); } } diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/DefaultWebSocketClient.java b/core/src/main/java/com/linecorp/armeria/client/websocket/DefaultWebSocketClient.java new file mode 100644 index 00000000000..e3f2eed09e6 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/DefaultWebSocketClient.java @@ -0,0 +1,261 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client.websocket; + +import static com.linecorp.armeria.internal.client.ClientUtil.UNDEFINED_URI; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.generateSecWebSocketAccept; +import static java.util.Objects.requireNonNull; + +import java.net.URI; +import java.util.Base64; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ThreadLocalRandom; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Joiner; + +import com.linecorp.armeria.client.ClientOptions; +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.ClientRequestContextCaptor; +import com.linecorp.armeria.client.Clients; +import com.linecorp.armeria.client.WebClient; +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.HttpMethod; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.RequestHeadersBuilder; +import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.Scheme; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.SplitHttpResponse; +import com.linecorp.armeria.common.logging.RequestLogProperty; +import com.linecorp.armeria.common.stream.StreamMessage; +import com.linecorp.armeria.internal.common.DefaultSplitHttpResponse; +import com.linecorp.armeria.internal.common.websocket.WebSocketFrameEncoder; +import com.linecorp.armeria.internal.common.websocket.WebSocketWrapper; + +import io.netty.handler.codec.http.HttpHeaderValues; + +final class DefaultWebSocketClient implements WebSocketClient { + + static final WebSocketClient DEFAULT = WebSocketClient.of(UNDEFINED_URI); + + private static final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(true); + + private final WebClient webClient; + private final int maxFramePayloadLength; + private final boolean allowMaskMismatch; + private final List subprotocols; + private final String joinedSubprotocols; + + DefaultWebSocketClient(WebClient webClient, int maxFramePayloadLength, boolean allowMaskMismatch, + List subprotocols) { + this.webClient = webClient; + this.maxFramePayloadLength = maxFramePayloadLength; + this.allowMaskMismatch = allowMaskMismatch; + this.subprotocols = subprotocols; + if (!subprotocols.isEmpty()) { + joinedSubprotocols = Joiner.on(", ").join(subprotocols); + } else { + joinedSubprotocols = ""; + } + } + + @Override + public CompletableFuture connect(String path) { + requireNonNull(path, "path"); + final RequestHeaders requestHeaders = webSocketHeaders(path); + + final CompletableFuture> outboundFuture = new CompletableFuture<>(); + final HttpRequest request = HttpRequest.of(requestHeaders, StreamMessage.of(outboundFuture)); + final HttpResponse response; + final ClientRequestContext ctx; + try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) { + response = webClient.execute(request); + ctx = captor.get(); + } + final SplitHttpResponse split = + new DefaultSplitHttpResponse(response, ctx.eventLoop(), responseHeaders -> { + final SessionProtocol actualSessionProtocol = actualSessionProtocol(ctx); + if (actualSessionProtocol.isExplicitHttp1()) { + return true; + } + assert actualSessionProtocol.isExplicitHttp2(); + return !responseHeaders.status().isInformational(); + }); + + final CompletableFuture result = new CompletableFuture<>(); + split.headers().handle((responseHeaders, cause) -> { + if (cause != null) { + fail(outboundFuture, response, result, cause); + return null; + } + if (!validateResponseHeaders(ctx, requestHeaders, responseHeaders, outboundFuture, + response, result)) { + return null; + } + + final WebSocketClientFrameDecoder decoder = + new WebSocketClientFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch); + final WebSocketWrapper inbound = new WebSocketWrapper(split.body().decode(decoder, ctx.alloc())); + + result.complete(new WebSocketSession(ctx, responseHeaders, inbound, outboundFuture, encoder)); + return null; + }); + return result; + } + + private RequestHeaders webSocketHeaders(String path) { + final RequestHeadersBuilder builder; + if (scheme().sessionProtocol().isExplicitHttp2()) { + builder = RequestHeaders.builder(HttpMethod.CONNECT, path) + .set(HttpHeaderNames.PROTOCOL, HttpHeaderValues.WEBSOCKET.toString()); + } else { + builder = RequestHeaders.builder(HttpMethod.GET, path) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString()) + .set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString()); + final String secWebSocketKey = generateSecWebSocketKey(); + builder.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, secWebSocketKey); + } + + builder.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13"); + if (!subprotocols.isEmpty()) { + builder.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, joinedSubprotocols); + } + + return builder.build(); + } + + private boolean validateResponseHeaders( + ClientRequestContext ctx, RequestHeaders requestHeaders, ResponseHeaders responseHeaders, + CompletableFuture> outboundFuture, HttpResponse response, + CompletableFuture result) { + if (actualSessionProtocol(ctx).isExplicitHttp2()) { + final HttpStatus status = responseHeaders.status(); + if (status != HttpStatus.OK) { + fail(outboundFuture, response, result, new WebSocketClientHandshakeException( + "invalid status: " + status + " (expected: " + HttpStatus.OK + ')', + responseHeaders)); + return false; + } + } else { + if (!isHttp1WebSocketResponse(responseHeaders)) { + fail(outboundFuture, response, result, new WebSocketClientHandshakeException( + "invalid response headers: " + responseHeaders, responseHeaders)); + return false; + } + final String secWebSocketKey = requestHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY); + assert secWebSocketKey != null; + final String secWebSocketAccept = responseHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT); + if (secWebSocketAccept == null) { + fail(outboundFuture, response, result, new WebSocketClientHandshakeException( + HttpHeaderNames.SEC_WEBSOCKET_ACCEPT + " is null.", responseHeaders)); + return false; + } + if (!secWebSocketAccept.equals(generateSecWebSocketAccept(secWebSocketKey))) { + fail(outboundFuture, response, result, new WebSocketClientHandshakeException( + "invalid " + HttpHeaderNames.SEC_WEBSOCKET_ACCEPT + " header: " + + secWebSocketAccept, responseHeaders)); + return false; + } + } + + if (!subprotocols.isEmpty()) { + final String responseSubprotocol = responseHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL); + // null is allowed if the server does not agree to any of the client's requested + // subprotocols. + // https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.2 + + if (responseSubprotocol != null && !subprotocols.contains(responseSubprotocol)) { + fail(outboundFuture, response, result, new WebSocketClientHandshakeException( + "invalid " + HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL + " header: " + + responseSubprotocol + " (expected: one of " + subprotocols + ')', + responseHeaders)); + return false; + } + } + return true; + } + + private static SessionProtocol actualSessionProtocol(ClientRequestContext ctx) { + // This is always called after a ResponseHeaders is received which means + // RequestLogProperty.SESSION is already set. + return ctx.log().ensureAvailable(RequestLogProperty.SESSION).sessionProtocol(); + } + + private static void fail(CompletableFuture> outboundFuture, HttpResponse response, + CompletableFuture result, Throwable cause) { + outboundFuture.completeExceptionally(cause); + response.abort(cause); + result.completeExceptionally(cause); + } + + @VisibleForTesting + static String generateSecWebSocketKey() { + final byte[] bytes = new byte[16]; + ThreadLocalRandom.current().nextBytes(bytes); + return Base64.getEncoder().encodeToString(bytes); + } + + private static boolean isHttp1WebSocketResponse(ResponseHeaders responseHeaders) { + return responseHeaders.status() == HttpStatus.SWITCHING_PROTOCOLS && + HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase( + responseHeaders.get(HttpHeaderNames.UPGRADE)) && + HttpHeaderValues.UPGRADE.contentEqualsIgnoreCase( + responseHeaders.get(HttpHeaderNames.CONNECTION)); + } + + @Override + public Scheme scheme() { + return webClient.scheme(); + } + + @Override + public EndpointGroup endpointGroup() { + return webClient.endpointGroup(); + } + + @Override + public String absolutePathRef() { + return webClient.absolutePathRef(); + } + + @Override + public URI uri() { + return webClient.uri(); + } + + @Override + public Class clientType() { + return webClient.clientType(); + } + + @Override + public ClientOptions options() { + return webClient.options(); + } + + @Override + public WebClient unwrap() { + return webClient; + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClient.java b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClient.java new file mode 100644 index 00000000000..cc11f86ded9 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClient.java @@ -0,0 +1,222 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client.websocket; + +import static java.util.Objects.requireNonNull; + +import java.net.URI; +import java.util.concurrent.CompletableFuture; + +import com.linecorp.armeria.client.ClientBuilderParams; +import com.linecorp.armeria.client.ClientOptions; +import com.linecorp.armeria.client.WebClient; +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.common.Scheme; +import com.linecorp.armeria.common.SerializationFormat; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.util.Unwrappable; + +/** + * A WebSocket client. + * This client has a few different default values for {@link ClientOptions} from {@link WebClient} + * because of the nature of WebSocket. See {@link WebSocketClientBuilder} for more information. + * + *

WebSocket client example: + *

{@code
+ * WebSocketClient client = WebSocketClient.of("ws://www.example.com");
+ * client.connect("/chat").thenAccept(webSocketSession -> {
+ *     // Write messages to the server.
+ *     WebSocketWriter writer = WebSocket.streaming();
+ *     webSocketSessions.setOutbound(writer);
+ *     outbound.write("Hello ");
+ *     // You can also use backpressure using whenConsumed().
+ *     outbound.whenConsumed().thenRun(() -> outbound.write("world!"));
+ *
+ *     // Read messages from the server.
+ *     Subscriber myWebSocketSubscriber = new Subscriber() {
+ *         @Override
+ *         public void onSubscribe(Subscription s) {
+ *             s.request(Long.MAX_VALUE);
+ *         }
+ *         @Override
+ *         public void onNext(WebSocketFrame webSocketFrame) {
+ *             if (webSocketFrame.type() == WebSocketFrameType.TEXT) {
+ *                 System.out.println(webSocketFrame.text());
+ *             }
+ *             ...
+ *         }
+ *         ...
+ *     };
+ *     webSocketSessions.inbound().subscribe(myWebSocketSubscriber);
+ * });
+ * }
+ * + * @see The WebSocket Protocol + */ +@UnstableApi +public interface WebSocketClient extends ClientBuilderParams, Unwrappable { + + /** + * Returns a {@link WebSocketClient} without a base URI. + */ + static WebSocketClient of() { + return DefaultWebSocketClient.DEFAULT; + } + + /** + * Returns a new {@link WebSocketClient} that connects to the specified {@code uri} using the + * default options. + */ + static WebSocketClient of(String uri) { + return builder(uri).build(); + } + + /** + * Returns a new {@link WebSocketClient} that connects to the specified {@link URI} using the + * default options. + */ + static WebSocketClient of(URI uri) { + return builder(uri).build(); + } + + /** + * Returns a new {@link WebSocketClient} that connects to the specified {@link EndpointGroup} with + * the specified {@code scheme} using the default {@link ClientOptions}. + */ + static WebSocketClient of(String scheme, EndpointGroup endpointGroup) { + return builder(scheme, endpointGroup).build(); + } + + /** + * Returns a new {@link WebSocketClient} that connects to the specified {@link EndpointGroup} with + * the specified {@link Scheme} using the default {@link ClientOptions}. + */ + static WebSocketClient of(Scheme scheme, EndpointGroup endpointGroup) { + return builder(scheme, endpointGroup).build(); + } + + /** + * Returns a new {@link WebSocketClient} that connects to the specified {@link EndpointGroup} with + * the specified {@link SessionProtocol} using the default {@link ClientOptions}. + */ + static WebSocketClient of(SessionProtocol protocol, EndpointGroup endpointGroup) { + return builder(protocol, endpointGroup).build(); + } + + /** + * Returns a new {@link WebSocketClient} that connects to the specified {@link EndpointGroup} with + * the specified {@code scheme} and {@code path} using the default {@link ClientOptions}. + */ + static WebSocketClient of(String scheme, EndpointGroup endpointGroup, String path) { + return builder(scheme, endpointGroup, path).build(); + } + + /** + * Returns a new {@link WebSocketClient} that connects to the specified {@link EndpointGroup} with + * the specified {@link Scheme} and {@code path} using the default {@link ClientOptions}. + */ + static WebSocketClient of(Scheme scheme, EndpointGroup endpointGroup, String path) { + return builder(scheme, endpointGroup, path).build(); + } + + /** + * Returns a new {@link WebSocketClient} that connects to the specified {@link EndpointGroup} with + * the specified {@code scheme} and {@code path} using the default {@link ClientOptions}. + */ + static WebSocketClient of(SessionProtocol protocol, EndpointGroup endpointGroup, String path) { + return builder(protocol, endpointGroup, path).build(); + } + + /** + * Returns a new {@link WebSocketClientBuilder} created with the specified base {@code uri}. + */ + static WebSocketClientBuilder builder(String uri) { + return builder(URI.create(requireNonNull(uri, "uri"))); + } + + /** + * Returns a new {@link WebSocketClientBuilder} created with the specified base {@link URI}. + */ + static WebSocketClientBuilder builder(URI uri) { + return new WebSocketClientBuilder(requireNonNull(uri, "uri")); + } + + /** + * Returns a new {@link WebSocketClientBuilder} created with the specified {@code scheme} + * and the {@link EndpointGroup}. + */ + static WebSocketClientBuilder builder(String scheme, EndpointGroup endpointGroup) { + requireNonNull(scheme, "scheme"); + return builder(Scheme.parse(scheme), endpointGroup); + } + + /** + * Returns a new {@link WebSocketClientBuilder} created with the specified {@link Scheme} + * and the {@link EndpointGroup}. + */ + static WebSocketClientBuilder builder(Scheme scheme, EndpointGroup endpointGroup) { + requireNonNull(scheme, "scheme"); + requireNonNull(endpointGroup, "endpointGroup"); + return new WebSocketClientBuilder(scheme, endpointGroup, null); + } + + /** + * Returns a new {@link WebSocketClientBuilder} created with the specified {@link SessionProtocol} + * and the {@link EndpointGroup}. + */ + static WebSocketClientBuilder builder(SessionProtocol protocol, EndpointGroup endpointGroup) { + requireNonNull(protocol, "protocol"); + return builder(Scheme.of(SerializationFormat.WS, protocol), endpointGroup); + } + + /** + * Returns a new {@link WebSocketClientBuilder} created with the specified {@code scheme}, + * the {@link EndpointGroup}, and the {@code path}. + */ + static WebSocketClientBuilder builder(String scheme, EndpointGroup endpointGroup, String path) { + requireNonNull(scheme, "scheme"); + return builder(Scheme.parse(scheme), endpointGroup, path); + } + + /** + * Returns a new {@link WebSocketClientBuilder} created with the specified {@link Scheme}, + * the {@link EndpointGroup}, and the {@code path}. + */ + static WebSocketClientBuilder builder(Scheme scheme, EndpointGroup endpointGroup, String path) { + requireNonNull(scheme, "scheme"); + requireNonNull(endpointGroup, "endpointGroup"); + return new WebSocketClientBuilder(scheme, endpointGroup, path); + } + + /** + * Returns a new {@link WebSocketClientBuilder} created with the specified {@link SessionProtocol}, + * the {@link EndpointGroup}, and the {@code path}. + */ + static WebSocketClientBuilder builder(SessionProtocol protocol, EndpointGroup endpointGroup, String path) { + requireNonNull(protocol, "protocol"); + return builder(Scheme.of(SerializationFormat.WS, protocol), endpointGroup, path); + } + + /** + * Connects to the specified {@code path}. + */ + CompletableFuture connect(String path); + + @Override + WebClient unwrap(); +} diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java new file mode 100644 index 00000000000..3a32e48c58c --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilder.java @@ -0,0 +1,374 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client.websocket; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.linecorp.armeria.common.SessionProtocol.httpAndHttpsValues; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.DEFAULT_MAX_REQUEST_RESPONSE_LENGTH; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.DEFAULT_REQUEST_RESPONSE_TIMEOUT_MILLIS; +import static java.util.Objects.requireNonNull; + +import java.net.URI; +import java.time.Duration; +import java.util.List; +import java.util.Map.Entry; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import com.linecorp.armeria.client.AbstractWebClientBuilder; +import com.linecorp.armeria.client.ClientFactory; +import com.linecorp.armeria.client.ClientOption; +import com.linecorp.armeria.client.ClientOptionValue; +import com.linecorp.armeria.client.ClientOptions; +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.client.Clients; +import com.linecorp.armeria.client.DecoratingHttpClientFunction; +import com.linecorp.armeria.client.DecoratingRpcClientFunction; +import com.linecorp.armeria.client.Endpoint; +import com.linecorp.armeria.client.HttpClient; +import com.linecorp.armeria.client.RpcClient; +import com.linecorp.armeria.client.WebClient; +import com.linecorp.armeria.client.endpoint.EndpointGroup; +import com.linecorp.armeria.client.redirect.RedirectConfig; +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.RequestId; +import com.linecorp.armeria.common.Scheme; +import com.linecorp.armeria.common.SerializationFormat; +import com.linecorp.armeria.common.SuccessFunction; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.auth.AuthToken; +import com.linecorp.armeria.common.auth.BasicToken; +import com.linecorp.armeria.common.auth.OAuth1aToken; +import com.linecorp.armeria.common.auth.OAuth2Token; + +/** + * Builds a {@link WebSocketClient}. + * This client has the different default options from {@link WebClient}. Here are the differences: + *
    + *
  • {@link ClientOptions#RESPONSE_TIMEOUT_MILLIS} is {@code 0}.
  • + *
  • {@link ClientOptions#MAX_RESPONSE_LENGTH} is {@code 0}.
  • + *
  • {@link ClientOptions#REQUEST_AUTO_ABORT_DELAY_MILLIS} is {@code 5000}.
  • + *
  • {@link ClientOptions#AUTO_FILL_ORIGIN_HEADER} is {@code true}.
  • + *
+ */ +@UnstableApi +public final class WebSocketClientBuilder extends AbstractWebClientBuilder { + + static final int DEFAULT_MAX_FRAME_PAYLOAD_LENGTH = 65535; // 64 * 1024 -1 + + private int maxFramePayloadLength = DEFAULT_MAX_FRAME_PAYLOAD_LENGTH; + private boolean allowMaskMismatch; + private List subprotocols = ImmutableList.of(); + + WebSocketClientBuilder(URI uri) { + super(validateUri(requireNonNull(uri, "uri")), null, null, null); + setWebSocketDefaultOption(); + } + + WebSocketClientBuilder(Scheme scheme, EndpointGroup endpointGroup, @Nullable String path) { + super(null, validateScheme(requireNonNull(scheme, "scheme")), endpointGroup, path); + setWebSocketDefaultOption(); + } + + private static URI validateUri(URI uri) { + if (Clients.isUndefinedUri(uri)) { + return uri; + } + final String givenScheme = requireNonNull(uri, "uri").getScheme(); + final Scheme scheme = validateScheme(givenScheme); + if (scheme.uriText().equals(givenScheme)) { + // No need to replace the user-specified scheme because it's already in its normalized form. + return uri; + } + // Replace the user-specified scheme with the normalized one. + // e.g. http://foo.com/ -> ws+http://foo.com/ + return URI.create(scheme.uriText() + uri.toString().substring(givenScheme.length())); + } + + private static Scheme validateScheme(String scheme) { + final Scheme parsedScheme = Scheme.tryParse(scheme); + if (parsedScheme != null) { + return validateScheme(parsedScheme); + } + + throw invalidSchemeException(scheme); + } + + private static Scheme validateScheme(Scheme scheme) { + final SerializationFormat serializationFormat = scheme.serializationFormat(); + if ((serializationFormat == SerializationFormat.WS || + serializationFormat == SerializationFormat.NONE) && + httpAndHttpsValues().contains(scheme.sessionProtocol())) { + if (serializationFormat == SerializationFormat.WS) { + return scheme; + } + return Scheme.of(SerializationFormat.WS, scheme.sessionProtocol()); + } + throw invalidSchemeException(scheme.toString()); + } + + private static IllegalArgumentException invalidSchemeException(String scheme) { + return new IllegalArgumentException( + String.format("scheme: %s (expected serialization format: %s or %s," + + " expected session protocol: one of %s)", scheme, SerializationFormat.WS, + SerializationFormat.NONE, httpAndHttpsValues())); + } + + private void setWebSocketDefaultOption() { + responseTimeoutMillis(DEFAULT_REQUEST_RESPONSE_TIMEOUT_MILLIS); + maxResponseLength(DEFAULT_MAX_REQUEST_RESPONSE_LENGTH); + requestAutoAbortDelayMillis(DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS); + autoFillOriginHeader(true); + contextCustomizer(ctx -> ctx.logBuilder().serializationFormat(SerializationFormat.WS)); + } + + /** + * Sets the maximum length of a frame's payload. + * {@value DEFAULT_MAX_FRAME_PAYLOAD_LENGTH} is used by default. + */ + public WebSocketClientBuilder maxFramePayloadLength(int maxFramePayloadLength) { + checkArgument(maxFramePayloadLength > 0, + "maxFramePayloadLength: %s (expected: > 0)", maxFramePayloadLength); + this.maxFramePayloadLength = maxFramePayloadLength; + return this; + } + + /** + * Sets whether the decoder allows to loosen the masking requirement on received frames. + * It's not allowed by default. + */ + public WebSocketClientBuilder allowMaskMismatch(boolean allowMaskMismatch) { + this.allowMaskMismatch = allowMaskMismatch; + return this; + } + + /** + * Sets the subprotocols to use with the WebSocket Protocol. + * + * @see + * Subprotocols Using the WebSocket Protocol + */ + public WebSocketClientBuilder subprotocols(String... subprotocols) { + return subprotocols(ImmutableSet.copyOf(requireNonNull(subprotocols, "subprotocols"))); + } + + /** + * Sets the subprotocols to use with the WebSocket Protocol. + * + * @see + * Subprotocols Using the WebSocket Protocol + */ + public WebSocketClientBuilder subprotocols(Iterable subprotocols) { + this.subprotocols = ImmutableList.copyOf(requireNonNull(subprotocols, "subprotocols")); + return this; + } + + /** + * Sets whether to add an {@link HttpHeaderNames#ORIGIN} header automatically when sending + * an {@link HttpRequest} when the {@link HttpRequest#headers()} does not have it. + * It's {@code true} by default. + */ + public WebSocketClientBuilder autoFillOriginHeader(boolean autoFillOriginHeader) { + //TODO(minwoox): Promote this to AbstractClientOptionsBuilder. + option(ClientOptions.AUTO_FILL_ORIGIN_HEADER, autoFillOriginHeader); + return this; + } + + /** + * Returns a newly-created {@link WebSocketClient} based on the properties of this builder. + */ + public WebSocketClient build() { + final WebClient webClient = buildWebClient(); + return new DefaultWebSocketClient(webClient, maxFramePayloadLength, allowMaskMismatch, subprotocols); + } + + // Override the return type of the chaining methods in the superclass. + + @Deprecated + @Override + public WebSocketClientBuilder rpcDecorator(Function decorator) { + return (WebSocketClientBuilder) super.rpcDecorator(decorator); + } + + @Deprecated + @Override + public WebSocketClientBuilder rpcDecorator(DecoratingRpcClientFunction decorator) { + return (WebSocketClientBuilder) super.rpcDecorator(decorator); + } + + @Override + public WebSocketClientBuilder options(ClientOptions options) { + return (WebSocketClientBuilder) super.options(options); + } + + @Override + public WebSocketClientBuilder options(ClientOptionValue... options) { + return (WebSocketClientBuilder) super.options(options); + } + + @Override + public WebSocketClientBuilder options(Iterable> options) { + return (WebSocketClientBuilder) super.options(options); + } + + @Override + public WebSocketClientBuilder option(ClientOption option, T value) { + return (WebSocketClientBuilder) super.option(option, value); + } + + @Override + public WebSocketClientBuilder option(ClientOptionValue optionValue) { + return (WebSocketClientBuilder) super.option(optionValue); + } + + @Override + public WebSocketClientBuilder factory(ClientFactory factory) { + return (WebSocketClientBuilder) super.factory(factory); + } + + @Override + public WebSocketClientBuilder writeTimeout(Duration writeTimeout) { + return (WebSocketClientBuilder) super.writeTimeout(writeTimeout); + } + + @Override + public WebSocketClientBuilder writeTimeoutMillis(long writeTimeoutMillis) { + return (WebSocketClientBuilder) super.writeTimeoutMillis(writeTimeoutMillis); + } + + @Override + public WebSocketClientBuilder responseTimeout(Duration responseTimeout) { + return (WebSocketClientBuilder) super.responseTimeout(responseTimeout); + } + + @Override + public WebSocketClientBuilder responseTimeoutMillis(long responseTimeoutMillis) { + return (WebSocketClientBuilder) super.responseTimeoutMillis(responseTimeoutMillis); + } + + @Override + public WebSocketClientBuilder maxResponseLength(long maxResponseLength) { + return (WebSocketClientBuilder) super.maxResponseLength(maxResponseLength); + } + + @Override + public WebSocketClientBuilder requestAutoAbortDelay(Duration delay) { + return (WebSocketClientBuilder) super.requestAutoAbortDelay(delay); + } + + @Override + public WebSocketClientBuilder requestAutoAbortDelayMillis(long delayMillis) { + return (WebSocketClientBuilder) super.requestAutoAbortDelayMillis(delayMillis); + } + + @Override + public WebSocketClientBuilder requestIdGenerator(Supplier requestIdGenerator) { + return (WebSocketClientBuilder) super.requestIdGenerator(requestIdGenerator); + } + + @Override + public WebSocketClientBuilder successFunction(SuccessFunction successFunction) { + return (WebSocketClientBuilder) super.successFunction(successFunction); + } + + @Override + public WebSocketClientBuilder endpointRemapper( + Function endpointRemapper) { + return (WebSocketClientBuilder) super.endpointRemapper(endpointRemapper); + } + + @Override + public WebSocketClientBuilder decorator( + Function decorator) { + return (WebSocketClientBuilder) super.decorator(decorator); + } + + @Override + public WebSocketClientBuilder decorator(DecoratingHttpClientFunction decorator) { + return (WebSocketClientBuilder) super.decorator(decorator); + } + + @Override + public WebSocketClientBuilder clearDecorators() { + return (WebSocketClientBuilder) super.clearDecorators(); + } + + @Override + public WebSocketClientBuilder addHeader(CharSequence name, Object value) { + return (WebSocketClientBuilder) super.addHeader(name, value); + } + + @Override + public WebSocketClientBuilder addHeaders( + Iterable> headers) { + return (WebSocketClientBuilder) super.addHeaders(headers); + } + + @Override + public WebSocketClientBuilder setHeader(CharSequence name, Object value) { + return (WebSocketClientBuilder) super.setHeader(name, value); + } + + @Override + public WebSocketClientBuilder setHeaders( + Iterable> headers) { + return (WebSocketClientBuilder) super.setHeaders(headers); + } + + @Override + public WebSocketClientBuilder auth(BasicToken token) { + return (WebSocketClientBuilder) super.auth(token); + } + + @Override + public WebSocketClientBuilder auth(OAuth1aToken token) { + return (WebSocketClientBuilder) super.auth(token); + } + + @Override + public WebSocketClientBuilder auth(OAuth2Token token) { + return (WebSocketClientBuilder) super.auth(token); + } + + @Override + public WebSocketClientBuilder auth(AuthToken token) { + return (WebSocketClientBuilder) super.auth(token); + } + + @Override + public WebSocketClientBuilder followRedirects() { + return (WebSocketClientBuilder) super.followRedirects(); + } + + @Override + public WebSocketClientBuilder followRedirects(RedirectConfig redirectConfig) { + return (WebSocketClientBuilder) super.followRedirects(redirectConfig); + } + + @Override + public WebSocketClientBuilder contextCustomizer( + Consumer contextCustomizer) { + return (WebSocketClientBuilder) super.contextCustomizer(contextCustomizer); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java new file mode 100644 index 00000000000..7390953ac5f --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientFrameDecoder.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.client.websocket; + +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.internal.client.websocket.WebSocketClientUtil; +import com.linecorp.armeria.internal.common.websocket.WebSocketFrameDecoder; + +final class WebSocketClientFrameDecoder extends WebSocketFrameDecoder { + + private final ClientRequestContext ctx; + + WebSocketClientFrameDecoder(ClientRequestContext ctx, int maxFramePayloadLength, + boolean allowMaskMismatch) { + super(ctx, maxFramePayloadLength, allowMaskMismatch); + this.ctx = ctx; + } + + @Override + protected boolean expectMaskedFrames() { + return false; + } + + @Override + protected void onCloseFrameRead() { + // Need to close the response when HTTP/1.1 is used. + WebSocketClientUtil.closingResponse(ctx, null); + } + + @Override + protected void onProcessOnError(Throwable cause) { + // Need to close the response when HTTP/1.1 is used. + WebSocketClientUtil.closingResponse(ctx, cause); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientHandshakeException.java b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientHandshakeException.java new file mode 100644 index 00000000000..cba714a6426 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketClientHandshakeException.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.client.websocket; + +import static java.util.Objects.requireNonNull; + +import com.linecorp.armeria.client.InvalidResponseException; +import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * An {@link InvalidResponseException} raised when a client received a response with invalid headers. + */ +@UnstableApi +public final class WebSocketClientHandshakeException extends InvalidResponseException { + + private static final long serialVersionUID = -8521952766254225005L; + + private final ResponseHeaders headers; + + /** + * Creates a new instance. + */ + public WebSocketClientHandshakeException(String message, ResponseHeaders headers) { + super(message); + this.headers = requireNonNull(headers, "headers"); + } + + /** + * Returns the {@link ResponseHeaders} of the handshake response. + */ + public ResponseHeaders headers() { + return headers; + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketSession.java b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketSession.java new file mode 100644 index 00000000000..db2e7da5a5a --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/WebSocketSession.java @@ -0,0 +1,144 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.client.websocket; + +import static java.util.Objects.requireNonNull; + +import java.util.concurrent.CompletableFuture; + +import org.reactivestreams.Publisher; + +import com.google.common.base.MoreObjects; + +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.common.HttpData; +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.ResponseHeaders; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.stream.PublisherBasedStreamMessage; +import com.linecorp.armeria.common.stream.StreamMessage; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.common.websocket.WebSocketFrame; +import com.linecorp.armeria.common.websocket.WebSocketWriter; +import com.linecorp.armeria.internal.common.websocket.WebSocketFrameEncoder; + +/** + * A WebSocket session that is created after {@link WebSocketClient#connect(String)} succeeds. + * You can start sending {@link WebSocketFrame}s via {@link #setOutbound(Publisher)}. You can also subscribe to + * {@link #inbound()} to receive {@link WebSocketFrame}s from the server. + */ +@UnstableApi +public final class WebSocketSession { + + private final ClientRequestContext ctx; + private final ResponseHeaders responseHeaders; + @Nullable + private final String subprotocol; + private final WebSocket inbound; + private final CompletableFuture> outboundFuture; + private final WebSocketFrameEncoder encoder; + + WebSocketSession(ClientRequestContext ctx, ResponseHeaders responseHeaders, WebSocket inbound, + CompletableFuture> outboundFuture, + WebSocketFrameEncoder encoder) { + this.ctx = ctx; + this.responseHeaders = responseHeaders; + subprotocol = responseHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL); + this.inbound = inbound; + this.outboundFuture = outboundFuture; + this.encoder = encoder; + } + + /** + * Returns the {@link ClientRequestContext}. + */ + public ClientRequestContext context() { + return ctx; + } + + /** + * Returns the {@link ResponseHeaders}. + */ + public ResponseHeaders responseHeaders() { + return responseHeaders; + } + + /** + * Returns the subprotocol negotiated between the client and the server. + */ + @Nullable + public String subprotocol() { + return subprotocol; + } + + /** + * Returns the {@link WebSocket} that is used to receive WebSocket frames from the server. + */ + public WebSocket inbound() { + return inbound; + } + + /** + * Returns the {@link WebSocketWriter} that is used to send WebSocket frames to the server. + * + * @throws IllegalStateException if this method or {@link #setOutbound(Publisher)} has been called already. + */ + public WebSocketWriter outbound() { + final WebSocketWriter writer = WebSocket.streaming(); + setOutbound(writer); + return writer; + } + + /** + * Sets the {@link WebSocket} that is used to send WebSocket frames to the server. + * + * @throws IllegalStateException if this method or {@link #outbound()} has been called already. + */ + public void setOutbound(Publisher outbound) { + requireNonNull(outbound, "outbound"); + if (outboundFuture.isDone()) { + if (outbound instanceof StreamMessage) { + ((StreamMessage) outbound).abort(); + } + throw new IllegalStateException("outbound() or setOutbound() has been already called."); + } + final StreamMessage streamMessage; + if (outbound instanceof StreamMessage) { + streamMessage = (StreamMessage) outbound; + } else { + streamMessage = new PublisherBasedStreamMessage<>(outbound); + } + + if (!outboundFuture.complete( + streamMessage.map(webSocketFrame -> HttpData.wrap(encoder.encode(ctx, webSocketFrame))))) { + streamMessage.abort(); + throw new IllegalStateException("outbound() or setOutbound() has been already called."); + } + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("ctx", ctx) + .add("responseHeaders", responseHeaders) + .add("subprotocol", subprotocol) + .add("inbound", inbound) + .add("outboundFuture", outboundFuture) + .add("encoder", encoder) + .toString(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/websocket/package-info.java b/core/src/main/java/com/linecorp/armeria/client/websocket/package-info.java new file mode 100644 index 00000000000..848479aec63 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/websocket/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Client-side classes for the WebSocket Protocol. + */ +@NonNullByDefault +package com.linecorp.armeria.client.websocket; + +import com.linecorp.armeria.common.annotation.NonNullByDefault; diff --git a/core/src/main/java/com/linecorp/armeria/common/AbstractHttpMessageBuilder.java b/core/src/main/java/com/linecorp/armeria/common/AbstractHttpMessageBuilder.java index 0e7f8f0efcc..91e6e49e358 100644 --- a/core/src/main/java/com/linecorp/armeria/common/AbstractHttpMessageBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/common/AbstractHttpMessageBuilder.java @@ -143,6 +143,14 @@ public AbstractHttpMessageBuilder content(MediaType contentType, HttpData conten return this; } + @Override + public AbstractHttpMessageBuilder content(Publisher publisher) { + requireNonNull(publisher, "publisher"); + checkState(content == null, "content has been set already"); + this.publisher = publisher; + return this; + } + @Override public AbstractHttpMessageBuilder content(MediaType contentType, Publisher publisher) { requireNonNull(contentType, "contentType"); diff --git a/core/src/main/java/com/linecorp/armeria/common/AbstractHttpRequestBuilder.java b/core/src/main/java/com/linecorp/armeria/common/AbstractHttpRequestBuilder.java index 2fd02b9a36f..5dee216529f 100644 --- a/core/src/main/java/com/linecorp/armeria/common/AbstractHttpRequestBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/common/AbstractHttpRequestBuilder.java @@ -141,6 +141,11 @@ public AbstractHttpRequestBuilder content(MediaType contentType, HttpData conten return (AbstractHttpRequestBuilder) super.content(contentType, content); } + @Override + public AbstractHttpRequestBuilder content(Publisher content) { + return (AbstractHttpRequestBuilder) super.content(content); + } + @Override public AbstractHttpRequestBuilder content(MediaType contentType, Publisher content) { return (AbstractHttpRequestBuilder) super.content(contentType, content); @@ -288,7 +293,9 @@ private RequestHeaders requestHeaders() { } private String buildPath() { - checkState(path != null, "path must be set."); + final String headerPath = requestHeadersBuilder.get(HttpHeaderNames.PATH); + checkState(path != null || headerPath != null, "path must be set."); + final String path = firstNonNull(this.path, headerPath); if (!disablePathParams) { // Path parameter substitution is enabled. Look for : or { first. diff --git a/core/src/main/java/com/linecorp/armeria/common/DeferredHttpResponse.java b/core/src/main/java/com/linecorp/armeria/common/DeferredHttpResponse.java index 4c3b7dbbfbc..72b8d455036 100644 --- a/core/src/main/java/com/linecorp/armeria/common/DeferredHttpResponse.java +++ b/core/src/main/java/com/linecorp/armeria/common/DeferredHttpResponse.java @@ -40,7 +40,7 @@ void delegate(HttpResponse delegate) { } void delegateWhenComplete(CompletionStage stage) { - delegateWhenCompleteStage(stage); + delegateOnCompletion(stage); } @SuppressWarnings("unchecked") diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpMessageSetters.java b/core/src/main/java/com/linecorp/armeria/common/HttpMessageSetters.java index 55a0a1af8ef..5374f40c518 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpMessageSetters.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpMessageSetters.java @@ -74,6 +74,11 @@ HttpMessageSetters content(MediaType contentType, @FormatString String format, */ HttpMessageSetters content(MediaType contentType, HttpData content); + /** + * Sets the {@link Publisher} for this message. + */ + HttpMessageSetters content(Publisher content); + /** * Sets the {@link Publisher} for this message. */ diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpRequestBuilder.java b/core/src/main/java/com/linecorp/armeria/common/HttpRequestBuilder.java index 5e06d71b17d..0c42271bf8e 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpRequestBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpRequestBuilder.java @@ -127,6 +127,11 @@ public HttpRequestBuilder content(MediaType contentType, HttpData content) { return (HttpRequestBuilder) super.content(contentType, content); } + @Override + public HttpRequestBuilder content(Publisher publisher) { + return (HttpRequestBuilder) super.content(publisher); + } + @Override public HttpRequestBuilder content(MediaType contentType, Publisher publisher) { return (HttpRequestBuilder) super.content(contentType, publisher); diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpRequestSetters.java b/core/src/main/java/com/linecorp/armeria/common/HttpRequestSetters.java index 487a95bef50..fa438cab2ff 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpRequestSetters.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpRequestSetters.java @@ -82,6 +82,12 @@ HttpRequestSetters content(MediaType contentType, @FormatString String format, @Override HttpRequestSetters content(MediaType contentType, HttpData content); + /** + * Sets the {@link Publisher} for this request. + */ + @Override + HttpRequestSetters content(Publisher content); + /** * Sets the {@link Publisher} for this request. */ diff --git a/core/src/main/java/com/linecorp/armeria/common/HttpResponseBuilder.java b/core/src/main/java/com/linecorp/armeria/common/HttpResponseBuilder.java index 98d918906aa..4fb9982f13d 100644 --- a/core/src/main/java/com/linecorp/armeria/common/HttpResponseBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/common/HttpResponseBuilder.java @@ -181,6 +181,14 @@ public HttpResponseBuilder content(MediaType contentType, HttpData content) { return (HttpResponseBuilder) super.content(contentType, content); } + /** + * Sets the {@link Publisher} for this response. + */ + @Override + public HttpResponseBuilder content(Publisher content) { + return (HttpResponseBuilder) super.content(content); + } + /** * Sets the {@link Publisher} for this response. */ diff --git a/core/src/main/java/com/linecorp/armeria/common/Scheme.java b/core/src/main/java/com/linecorp/armeria/common/Scheme.java index 4c4a72c4237..2767fe09283 100644 --- a/core/src/main/java/com/linecorp/armeria/common/Scheme.java +++ b/core/src/main/java/com/linecorp/armeria/common/Scheme.java @@ -60,6 +60,13 @@ public final class Scheme implements Comparable { final Scheme scheme = new Scheme(f, p); schemes.put(ftxt + '+' + ptxt, scheme); schemes.put(ptxt + '+' + ftxt, scheme); + if (SerializationFormat.WS == f) { + if (SessionProtocol.HTTP == p) { + schemes.put("ws", scheme); + } else if (SessionProtocol.HTTPS == p) { + schemes.put("wss", scheme); + } + } } } diff --git a/core/src/main/java/com/linecorp/armeria/common/SerializationFormat.java b/core/src/main/java/com/linecorp/armeria/common/SerializationFormat.java index 05cf0f1da86..98e29570e22 100644 --- a/core/src/main/java/com/linecorp/armeria/common/SerializationFormat.java +++ b/core/src/main/java/com/linecorp/armeria/common/SerializationFormat.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.linecorp.armeria.common.MediaType.OCTET_STREAM; import static com.linecorp.armeria.common.MediaType.create; import static java.util.Objects.requireNonNull; @@ -39,6 +40,7 @@ import com.google.common.collect.Multimap; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; /** * Serialization format of a remote procedure call and its reply. @@ -55,6 +57,12 @@ public final class SerializationFormat implements Comparable HTTPS_VALUES = Sets.immutableEnumSet(HTTPS, H1, H2); + private static final Set HTTP_AND_HTTPS_VALUES = + Sets.immutableEnumSet(HTTPS, HTTP, H1, H1C, H2, H2C); + private static final Map uriTextToProtocols; static { @@ -117,6 +121,14 @@ public static Set httpsValues() { return HTTPS_VALUES; } + /** + * Returns an immutable {@link Set} that contains {@link #httpValues()} and {@link #httpsValues()}. + */ + @UnstableApi + public static Set httpAndHttpsValues() { + return HTTP_AND_HTTPS_VALUES; + } + private final String uriText; private final boolean useTls; private final boolean isMultiplex; diff --git a/core/src/main/java/com/linecorp/armeria/common/annotation/UnstableApi.java b/core/src/main/java/com/linecorp/armeria/common/annotation/UnstableApi.java index 5a5091952e6..432324dcc14 100644 --- a/core/src/main/java/com/linecorp/armeria/common/annotation/UnstableApi.java +++ b/core/src/main/java/com/linecorp/armeria/common/annotation/UnstableApi.java @@ -25,7 +25,7 @@ * Indicates the API of the target is not mature enough to guarantee the compatibility between releases. * Its behavior, signature or even existence might change without a prior notice at any point. */ -@Retention(RetentionPolicy.SOURCE) +@Retention(RetentionPolicy.CLASS) @Target({ ElementType.ANNOTATION_TYPE, ElementType.CONSTRUCTOR, diff --git a/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java b/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java index 460f2aeb46b..67c7d672d70 100644 --- a/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java +++ b/core/src/main/java/com/linecorp/armeria/common/logging/DefaultRequestLog.java @@ -731,9 +731,21 @@ private void session0(@Nullable Channel channel, SessionProtocol sessionProtocol this.sslSession = sslSession; this.sessionProtocol = sessionProtocol; this.connectionTimings = connectionTimings; + maybeSetScheme(); updateFlags(RequestLogProperty.SESSION); } + private void maybeSetScheme() { + if (isAvailable(RequestLogProperty.SCHEME) || + serializationFormat == SerializationFormat.NONE) { + return; + } + + assert sessionProtocol != null; + scheme = Scheme.of(serializationFormat, sessionProtocol); + updateFlags(RequestLogProperty.SCHEME); + } + @Override public Channel channel() { ensureAvailable(RequestLogProperty.SESSION); diff --git a/core/src/main/java/com/linecorp/armeria/common/stream/DeferredStreamMessage.java b/core/src/main/java/com/linecorp/armeria/common/stream/DeferredStreamMessage.java index bc94735cc8e..723539500dd 100644 --- a/core/src/main/java/com/linecorp/armeria/common/stream/DeferredStreamMessage.java +++ b/core/src/main/java/com/linecorp/armeria/common/stream/DeferredStreamMessage.java @@ -137,7 +137,7 @@ public EventExecutor defaultSubscriberExecutor() { /** * Delegates when the specified {@link CompletionStage} is complete. */ - protected final void delegateWhenCompleteStage(CompletionStage> stage) { + protected final void delegateOnCompletion(CompletionStage> stage) { requireNonNull(stage, "stage"); stage.handle((upstream, thrown) -> { if (thrown != null) { diff --git a/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java b/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java index 9f1280a8449..4a5bfe1526e 100644 --- a/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java +++ b/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessage.java @@ -202,7 +202,7 @@ static StreamMessage of(CompletionStage> } else { final DeferredStreamMessage deferred = new DeferredStreamMessage<>(); //noinspection unchecked - deferred.delegateWhenCompleteStage((CompletionStage>) stage); + deferred.delegateOnCompletion((CompletionStage>) stage); return deferred; } } @@ -224,7 +224,7 @@ static StreamMessage of(CompletionStage deferred = new DeferredStreamMessage<>(subscriberExecutor); //noinspection unchecked - deferred.delegateWhenCompleteStage((CompletionStage>) stage); + deferred.delegateOnCompletion((CompletionStage>) stage); return deferred; } diff --git a/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessageUtil.java b/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessageUtil.java index bee8a461186..539f9ce2a5f 100644 --- a/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessageUtil.java +++ b/core/src/main/java/com/linecorp/armeria/common/stream/StreamMessageUtil.java @@ -57,7 +57,7 @@ static StreamMessage createStreamMessageFrom( final DeferredStreamMessage deferred = new DeferredStreamMessage<>(); //noinspection unchecked - deferred.delegateWhenCompleteStage((CompletionStage>) future); + deferred.delegateOnCompletion((CompletionStage>) future); return deferred; } diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java b/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java index 661fd471326..1c4ff86b472 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/ClientUtil.java @@ -18,6 +18,7 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static java.util.Objects.requireNonNull; +import java.net.URI; import java.util.concurrent.CompletableFuture; import java.util.function.BiFunction; import java.util.function.Function; @@ -26,6 +27,7 @@ import com.linecorp.armeria.client.ClientRequestContext; import com.linecorp.armeria.client.Endpoint; import com.linecorp.armeria.client.UnprocessedRequestException; +import com.linecorp.armeria.client.WebClient; import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.Request; @@ -43,6 +45,11 @@ public final class ClientUtil { + /** + * An undefined {@link URI} to create {@link WebClient} without specifying {@link URI}. + */ + public static final URI UNDEFINED_URI = URI.create("http://undefined"); + public static > O initContextAndExecuteWithFallback( U delegate, diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java index 8c9d406c36f..eef72c48c05 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java @@ -404,7 +404,7 @@ public void finishInitialization(boolean success) { private void updateEndpoint(@Nullable Endpoint endpoint) { this.endpoint = endpoint; - autoFillSchemeAndAuthority(); + autoFillSchemeAuthorityAndOrigin(); } private void acquireEventLoop(EndpointGroup endpointGroup) { @@ -428,7 +428,7 @@ private void failEarly(Throwable cause) { final UnprocessedRequestException wrapped = UnprocessedRequestException.of(cause); final HttpRequest req = request(); if (req != null) { - autoFillSchemeAndAuthority(); + autoFillSchemeAuthorityAndOrigin(); req.abort(wrapped); } @@ -438,7 +438,7 @@ private void failEarly(Throwable cause) { } // TODO(ikhoon): Consider moving the logic for filling authority to `HttpClientDelegate.exceute()`. - private void autoFillSchemeAndAuthority() { + private void autoFillSchemeAuthorityAndOrigin() { final String authority = authority(); if (authority != null && endpoint != null && endpoint.isIpAddrOnly()) { // The connection will be established with the IP address but `host` set to the `Endpoint` @@ -453,7 +453,16 @@ private void autoFillSchemeAndAuthority() { final HttpHeadersBuilder headersBuilder = internalRequestHeaders.toBuilder(); headersBuilder.set(HttpHeaderNames.SCHEME, getScheme(sessionProtocol())); if (endpoint != null) { - headersBuilder.set(HttpHeaderNames.AUTHORITY, endpoint.authority()); + final String endpointAuthority = endpoint.authority(); + headersBuilder.set(HttpHeaderNames.AUTHORITY, endpointAuthority); + final String origin = origin(); + if (origin != null) { + headersBuilder.set(HttpHeaderNames.ORIGIN, origin); + } else if (options().autoFillOriginHeader()) { + final String uriText = sessionProtocol().isTls() ? SessionProtocol.HTTPS.uriText() + : SessionProtocol.HTTP.uriText(); + headersBuilder.set(HttpHeaderNames.ORIGIN, uriText + "://" + endpointAuthority); + } } internalRequestHeaders = headersBuilder.build(); } @@ -576,7 +585,6 @@ public ClientRequestContext newDerivedContext(RequestId id, protocol, newHeaders.method(), reqTarget); } } - return new DefaultClientRequestContext(this, id, req, rpcReq, endpoint, endpointGroup(), sessionProtocol(), method(), requestTarget()); } @@ -698,6 +706,23 @@ public String authority() { return authority; } + @Nullable + private String origin() { + final HttpHeaders additionalRequestHeaders = this.additionalRequestHeaders; + String origin = additionalRequestHeaders.get(HttpHeaderNames.ORIGIN); + final HttpRequest request = request(); + if (origin == null && request != null) { + origin = request.headers().get(HttpHeaderNames.ORIGIN); + } + if (origin == null) { + origin = defaultRequestHeaders.get(HttpHeaderNames.ORIGIN); + } + if (origin == null) { + origin = internalRequestHeaders.get(HttpHeaderNames.ORIGIN); + } + return origin; + } + @Override public URI uri() { final String scheme = getScheme(sessionProtocol()); diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/HttpSession.java b/core/src/main/java/com/linecorp/armeria/internal/client/HttpSession.java index 768fcdf534f..29ff3f02932 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/HttpSession.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/HttpSession.java @@ -19,6 +19,7 @@ import com.linecorp.armeria.client.ClientRequestContext; import com.linecorp.armeria.common.ClosedSessionException; import com.linecorp.armeria.common.HttpRequest; +import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.internal.common.InboundTrafficController; @@ -34,6 +35,12 @@ public interface HttpSession { int MAX_NUM_REQUESTS_SENT = 536870912; HttpSession INACTIVE = new HttpSession() { + + @Override + public SerializationFormat serializationFormat() { + return SerializationFormat.UNKNOWN; + } + @Nullable @Override public SessionProtocol protocol() { @@ -93,6 +100,13 @@ static HttpSession get(Channel ch) { return INACTIVE; } + SerializationFormat serializationFormat(); + + /** + * Returns the explicit {@link SessionProtocol} of this {@link HttpSession}. + * This is one of {@link SessionProtocol#H1}, {@link SessionProtocol#H1C}, {@link SessionProtocol#H2} and + * {@link SessionProtocol#H2C}. + */ @Nullable SessionProtocol protocol(); diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/websocket/WebSocketClientUtil.java b/core/src/main/java/com/linecorp/armeria/internal/client/websocket/WebSocketClientUtil.java new file mode 100644 index 00000000000..623b8ce5560 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/client/websocket/WebSocketClientUtil.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.client.websocket; + +import static java.util.Objects.requireNonNull; + +import java.util.function.Consumer; + +import com.linecorp.armeria.client.ClientRequestContext; +import com.linecorp.armeria.common.annotation.Nullable; + +import io.netty.util.AttributeKey; + +public final class WebSocketClientUtil { + + private static final AttributeKey> CLOSING_RESPONSE_TASK = + AttributeKey.valueOf(WebSocketClientUtil.class, "CLOSING_RESPONSE_TASK"); + + public static void setClosingResponseTask(ClientRequestContext ctx, Consumer task) { + requireNonNull(ctx, "ctx"); + requireNonNull(task, "task"); + ctx.setAttr(CLOSING_RESPONSE_TASK, task); + } + + public static void closingResponse(ClientRequestContext ctx, @Nullable Throwable cause) { + requireNonNull(ctx, "ctx"); + final Consumer task = ctx.attr(CLOSING_RESPONSE_TASK); + if (task != null) { + task.accept(cause); + } + } + + private WebSocketClientUtil() {} +} diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/websocket/package-info.java b/core/src/main/java/com/linecorp/armeria/internal/client/websocket/package-info.java new file mode 100644 index 00000000000..c0062c76749 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/client/websocket/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Internal client classes for + * the WebSocket Protocol. + */ +@NonNullByDefault +package com.linecorp.armeria.internal.client.websocket; + +import com.linecorp.armeria.common.annotation.NonNullByDefault; diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultSplitHttpResponse.java b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultSplitHttpResponse.java index 431f8297192..62316790cfc 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/DefaultSplitHttpResponse.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/DefaultSplitHttpResponse.java @@ -17,6 +17,7 @@ package com.linecorp.armeria.internal.common; import java.util.concurrent.CompletableFuture; +import java.util.function.Predicate; import org.reactivestreams.Subscription; @@ -38,7 +39,21 @@ public class DefaultSplitHttpResponse extends AbstractSplitHttpMessage implement private final SplitHttpResponseBodySubscriber bodySubscriber; public DefaultSplitHttpResponse(HttpResponse response, EventExecutor upstreamExecutor) { - this(response, upstreamExecutor, new SplitHttpResponseBodySubscriber(response, upstreamExecutor)); + this(response, upstreamExecutor, headers -> !headers.status().isInformational()); + } + + /** + * Creates a new {@link DefaultSplitHttpResponse} from the specified {@link HttpResponse}. + * The specified {@link Predicate} is used to determine if the {@link ResponseHeaders} is the final one. + * For example, if there are multiple informational {@link ResponseHeaders} and the {@link Predicate} + * will be {@code headers -> !headers.status().isInformational()}. + * However, if the {@link ResponseHeaders} is only one, and it can be an informational one such as a + * WebSocket response, {@link Predicate} will be {@code headers -> true}. + */ + public DefaultSplitHttpResponse(HttpResponse response, EventExecutor upstreamExecutor, + Predicate finalResponseHeadersPredicate) { + this(response, upstreamExecutor, + new SplitHttpResponseBodySubscriber(response, upstreamExecutor, finalResponseHeadersPredicate)); } private DefaultSplitHttpResponse(HttpResponse response, EventExecutor upstreamExecutor, @@ -55,9 +70,12 @@ public final CompletableFuture headers() { private static final class SplitHttpResponseBodySubscriber extends SplitHttpMessageSubscriber { private final HeadersFuture headersFuture = new HeadersFuture<>(); + private final Predicate finalResponseHeadersPredicate; - SplitHttpResponseBodySubscriber(HttpResponse response, EventExecutor upstreamExecutor) { + SplitHttpResponseBodySubscriber(HttpResponse response, EventExecutor upstreamExecutor, + Predicate finalResponseHeadersPredicate) { super(1, response, upstreamExecutor); + this.finalResponseHeadersPredicate = finalResponseHeadersPredicate; } CompletableFuture headersFuture() { @@ -68,14 +86,12 @@ CompletableFuture headersFuture() { public void onNext(HttpObject httpObject) { if (httpObject instanceof ResponseHeaders) { final ResponseHeaders headers = (ResponseHeaders) httpObject; - final HttpStatus status = headers.status(); - if (status.isInformational()) { - // Ignore informational headers + if (finalResponseHeadersPredicate.test(headers)) { + headersFuture.doComplete(headers); + } else { final Subscription upstream = upstream(); assert upstream != null; upstream.request(1); - } else { - headersFuture.doComplete(headers); } return; } diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/HttpHeadersUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/HttpHeadersUtil.java index 33d989f02cf..8fade4ac277 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/HttpHeadersUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/HttpHeadersUtil.java @@ -92,6 +92,19 @@ public static RequestHeaders mergeRequestHeaders(RequestHeaders headers, headers.contains(HttpHeaderNames.USER_AGENT)) { return headers; } + if (defaultHeaders.isEmpty() && additionalHeaders.isEmpty()) { + boolean containAllInternalHeaders = true; + for (AsciiString name : internalHeaders.names()) { + if (!headers.contains(name)) { + containAllInternalHeaders = false; + break; + } + } + + if (containAllInternalHeaders) { + return headers; + } + } final RequestHeadersBuilder builder = headers.toBuilder(); diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java index d8c2f21b6e2..ad31efcb993 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameDecoder.java @@ -34,8 +34,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.linecorp.armeria.common.HttpRequestWriter; -import com.linecorp.armeria.common.Request; +import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.stream.HttpDecoder; import com.linecorp.armeria.common.stream.StreamDecoderInput; @@ -45,14 +44,12 @@ import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; import com.linecorp.armeria.common.websocket.WebSocketFrame; import com.linecorp.armeria.common.websocket.WebSocketFrameType; -import com.linecorp.armeria.internal.common.RequestContextExtension; -import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.server.websocket.WebSocketProtocolViolationException; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; -public final class WebSocketFrameDecoder implements HttpDecoder { +public abstract class WebSocketFrameDecoder implements HttpDecoder { // Forked from Netty 4.1.92 https://github.com/netty/netty/blob/e8df52e442629214e0355528c00e873e213f0139/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java @@ -67,10 +64,9 @@ enum State { CORRUPT } - private final ServiceRequestContext ctx; + private final RequestContext ctx; private final int maxFramePayloadLength; private final boolean allowMaskMismatch; - private final boolean expectMaskedFrames; @Nullable private WebSocket outboundFrames; @@ -85,12 +81,10 @@ enum State { private boolean receivedClosingHandshake; private State state = State.READING_FIRST; - public WebSocketFrameDecoder(ServiceRequestContext ctx, int maxFramePayloadLength, - boolean allowMaskMismatch, boolean expectMaskedFrames) { + protected WebSocketFrameDecoder(RequestContext ctx, int maxFramePayloadLength, boolean allowMaskMismatch) { this.ctx = ctx; this.maxFramePayloadLength = maxFramePayloadLength; this.allowMaskMismatch = allowMaskMismatch; - this.expectMaskedFrames = expectMaskedFrames; } public void setOutboundWebSocket(WebSocket outboundFrames) { @@ -136,7 +130,7 @@ public void process(StreamDecoderInput in, StreamDecoderOutput o throw protocolViolation("RSV != 0 and no extension negotiated, RSV:" + frameRsv); } - if (!allowMaskMismatch && expectMaskedFrames != frameMasked) { + if (!allowMaskMismatch && expectMaskedFrames() != frameMasked) { throw protocolViolation("received a frame that is not masked as expected"); } @@ -273,7 +267,7 @@ public void process(StreamDecoderInput in, StreamDecoderOutput o final CloseWebSocketFrame decodedFrame = WebSocketFrame.ofPooledClose(payloadBuffer); out.add(decodedFrame); logger.trace("{} is decoded.", decodedFrame); - closeRequest(); + onCloseFrameRead(); continue; // to while loop } @@ -304,6 +298,10 @@ public void process(StreamDecoderInput in, StreamDecoderOutput o } } + protected abstract boolean expectMaskedFrames(); + + protected abstract void onCloseFrameRead(); + private void unmask(ByteBuf frame) { long longMask = mask & 0xFFFFFFFFL; longMask |= longMask << 32; @@ -366,19 +364,17 @@ private void validateCloseFrame(ByteBuf buffer) { } } - private void closeRequest() { - final RequestContextExtension ctxExtension = ctx.as(RequestContextExtension.class); - assert ctxExtension != null; - final Request request = ctxExtension.originalRequest(); - assert request instanceof HttpRequestWriter : request; - //noinspection OverlyStrongTypeCast - ((HttpRequestWriter) request).close(); - } - @Override public void processOnError(Throwable cause) { - if (outboundFrames != null) { - outboundFrames.abort(cause); + // If an exception from the inbound stream is raised after receiving a close frame, + // we should not abort the outbound stream. + if (!receivedClosingHandshake) { + if (outboundFrames != null) { + outboundFrames.abort(cause); + } } + onProcessOnError(cause); } + + protected void onProcessOnError(Throwable cause) {} } diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtil.java index dd18293c5fa..544208f7bb2 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/websocket/WebSocketUtil.java @@ -17,6 +17,11 @@ import static java.util.Objects.requireNonNull; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +import com.google.common.hash.Hashing; + import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.RequestHeaders; @@ -29,8 +34,9 @@ public final class WebSocketUtil { - public static final long DEFAULT_REQUEST_TIMEOUT_MILLIS = 0; - public static final long DEFAULT_MAX_REQUEST_LENGTH = 0; + private static final String MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + public static final long DEFAULT_REQUEST_RESPONSE_TIMEOUT_MILLIS = 0; + public static final long DEFAULT_MAX_REQUEST_RESPONSE_LENGTH = 0; public static final long DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS = 5000; public static boolean isHttp1WebSocketUpgradeRequest(RequestHeaders headers) { @@ -54,6 +60,17 @@ public static boolean isHttp2WebSocketUpgradeRequest(RequestHeaders headers) { HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(headers.get(HttpHeaderNames.PROTOCOL)); } + /** + * Generates Sec-WebSocket-Accept using Sec-WebSocket-Key. + * + * @see Opening Handshake + */ + public static String generateSecWebSocketAccept(String webSocketKey) { + final String acceptSeed = webSocketKey + MAGIC_GUID; + final byte[] sha1 = Hashing.sha1().hashBytes(acceptSeed.getBytes(StandardCharsets.US_ASCII)).asBytes(); + return Base64.getEncoder().encodeToString(sha1); + } + static int byteAtIndex(int mask, int index) { return (mask >> 8 * (3 - index)) & 0xFF; } diff --git a/core/src/main/java/com/linecorp/armeria/server/AbstractHttpResponseSubscriber.java b/core/src/main/java/com/linecorp/armeria/server/AbstractHttpResponseSubscriber.java index 344789073c2..4a8ad179f35 100644 --- a/core/src/main/java/com/linecorp/armeria/server/AbstractHttpResponseSubscriber.java +++ b/core/src/main/java/com/linecorp/armeria/server/AbstractHttpResponseSubscriber.java @@ -111,6 +111,7 @@ public void onNext(HttpObject o) { req.abortResponse(new IllegalArgumentException( "published an HttpObject that's neither HttpHeaders nor HttpData: " + o + " (service: " + service() + ')'), true); + PooledObjects.close(o); return; } diff --git a/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java b/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java index f564480c32a..f5de6f98b25 100644 --- a/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java +++ b/core/src/main/java/com/linecorp/armeria/server/Http1RequestDecoder.java @@ -248,7 +248,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception assert encoder instanceof ServerHttp1ObjectEncoder; ((ServerHttp1ObjectEncoder) encoder).webSocketUpgrading(); final ChannelPipeline pipeline = ctx.pipeline(); - pipeline.replace(this, null, new WebSocketSessionChannelHandler( + pipeline.replace(this, null, new WebSocketServiceChannelHandler( webSocketRequest, encoder, serviceConfig)); if (pipeline.get(HttpServerUpgradeHandler.class) != null) { pipeline.remove(HttpServerUpgradeHandler.class); diff --git a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java index bfb8c8291b6..6e6b4ff3fd9 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServiceConfigBuilder.java @@ -312,7 +312,7 @@ ServiceConfig build(ServiceNaming defaultServiceNaming, } else if (!webSocket || defaultRequestTimeoutMillis != Flags.defaultRequestTimeoutMillis()) { requestTimeoutMillis = defaultRequestTimeoutMillis; } else { - requestTimeoutMillis = WebSocketUtil.DEFAULT_REQUEST_TIMEOUT_MILLIS; + requestTimeoutMillis = WebSocketUtil.DEFAULT_REQUEST_RESPONSE_TIMEOUT_MILLIS; } final long maxRequestLength; @@ -321,7 +321,7 @@ ServiceConfig build(ServiceNaming defaultServiceNaming, } else if (!webSocket || defaultMaxRequestLength != Flags.defaultMaxRequestLength()) { maxRequestLength = defaultMaxRequestLength; } else { - maxRequestLength = WebSocketUtil.DEFAULT_MAX_REQUEST_LENGTH; + maxRequestLength = WebSocketUtil.DEFAULT_MAX_REQUEST_RESPONSE_LENGTH; } final long requestAutoAbortDelayMillis; diff --git a/core/src/main/java/com/linecorp/armeria/server/WebSocketSessionChannelHandler.java b/core/src/main/java/com/linecorp/armeria/server/WebSocketServiceChannelHandler.java similarity index 94% rename from core/src/main/java/com/linecorp/armeria/server/WebSocketSessionChannelHandler.java rename to core/src/main/java/com/linecorp/armeria/server/WebSocketServiceChannelHandler.java index e348d1c1b06..79bee0f88fe 100644 --- a/core/src/main/java/com/linecorp/armeria/server/WebSocketSessionChannelHandler.java +++ b/core/src/main/java/com/linecorp/armeria/server/WebSocketServiceChannelHandler.java @@ -38,15 +38,15 @@ import io.netty.handler.codec.http2.Http2Error; import io.netty.util.ReferenceCountUtil; -final class WebSocketSessionChannelHandler extends ChannelDuplexHandler { +final class WebSocketServiceChannelHandler extends ChannelDuplexHandler { - private static final Logger logger = LoggerFactory.getLogger(WebSocketSessionChannelHandler.class); + private static final Logger logger = LoggerFactory.getLogger(WebSocketServiceChannelHandler.class); private final StreamingDecodedHttpRequest req; private final ServerHttpObjectEncoder encoder; private final ServiceConfig serviceConfig; - WebSocketSessionChannelHandler(StreamingDecodedHttpRequest req, ServerHttpObjectEncoder encoder, + WebSocketServiceChannelHandler(StreamingDecodedHttpRequest req, ServerHttpObjectEncoder encoder, ServiceConfig serviceConfig) { this.req = req; this.encoder = encoder; @@ -82,6 +82,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { logger.warn("{} Unexpected msg: {}", ctx.channel(), msg); return; } + encoder.keepAliveHandler().onReadOrWrite(); try { final ByteBuf data = (ByteBuf) msg; final int dataLength = data.readableBytes(); @@ -123,7 +124,7 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) final HttpResponse response = (HttpResponse) msg; final HttpResponseStatus status = response.status(); ctx.write(msg, promise); - if (status == HttpResponseStatus.SWITCHING_PROTOCOLS) { + if (status.code() == HttpResponseStatus.SWITCHING_PROTOCOLS.code()) { ctx.pipeline().remove(HttpServerCodec.class); } return; diff --git a/core/src/main/java/com/linecorp/armeria/server/metric/PrometheusExpositionService.java b/core/src/main/java/com/linecorp/armeria/server/metric/PrometheusExpositionService.java index e6afdc782ca..311acd83841 100644 --- a/core/src/main/java/com/linecorp/armeria/server/metric/PrometheusExpositionService.java +++ b/core/src/main/java/com/linecorp/armeria/server/metric/PrometheusExpositionService.java @@ -13,7 +13,6 @@ * License for the specific language governing permissions and limitations * under the License. */ - package com.linecorp.armeria.server.metric; import static java.util.Objects.requireNonNull; @@ -46,6 +45,14 @@ */ public final class PrometheusExpositionService extends AbstractHttpService implements TransientHttpService { + /** + * Returns a new {@link PrometheusExpositionService} that exposes Prometheus metrics from + * {@link CollectorRegistry#defaultRegistry}. + */ + public static PrometheusExpositionService of() { + return of(CollectorRegistry.defaultRegistry); + } + /** * Returns a new {@link PrometheusExpositionService} that exposes Prometheus metrics from the specified * {@link CollectorRegistry}. @@ -54,6 +61,14 @@ public static PrometheusExpositionService of(CollectorRegistry collectorRegistry return new PrometheusExpositionService(collectorRegistry, Flags.transientServiceOptions()); } + /** + * Returns a new {@link PrometheusExpositionServiceBuilder} created with + * {@link CollectorRegistry#defaultRegistry}. + */ + public static PrometheusExpositionServiceBuilder builder() { + return builder(CollectorRegistry.defaultRegistry); + } + /** * Returns a new {@link PrometheusExpositionServiceBuilder} created with the specified * {@link CollectorRegistry}. diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java index fef607e1933..9c53cf2e553 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketService.java @@ -15,19 +15,17 @@ */ package com.linecorp.armeria.server.websocket; +import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.generateSecWebSocketAccept; import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.isHttp1WebSocketUpgradeRequest; import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.isHttp2WebSocketUpgradeRequest; import static com.linecorp.armeria.internal.common.websocket.WebSocketUtil.newCloseWebSocketFrame; -import java.nio.charset.StandardCharsets; -import java.util.Base64; import java.util.Set; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.base.Splitter; -import com.google.common.hash.Hashing; import com.google.common.net.HostAndPort; import com.linecorp.armeria.common.HttpData; @@ -46,7 +44,6 @@ import com.linecorp.armeria.common.stream.StreamMessage; import com.linecorp.armeria.common.websocket.WebSocket; import com.linecorp.armeria.common.websocket.WebSocketFrame; -import com.linecorp.armeria.internal.common.websocket.WebSocketFrameDecoder; import com.linecorp.armeria.internal.common.websocket.WebSocketFrameEncoder; import com.linecorp.armeria.internal.common.websocket.WebSocketWrapper; import com.linecorp.armeria.server.AbstractHttpService; @@ -68,8 +65,6 @@ public final class WebSocketService extends AbstractHttpService { private static final Logger logger = LoggerFactory.getLogger(WebSocketService.class); - private static final String WEBSOCKET_13_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - private static final String SUB_PROTOCOL_WILDCARD = "*"; private static final ResponseHeaders UNSUPPORTED_WEB_SOCKET_VERSION = @@ -192,19 +187,10 @@ private void maybeAddSubprotocol(RequestHeaders headers, HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selectedSubprotocol)); } - // Generate Sec-WebSocket-Accept using Sec-WebSocket-Key. - // See https://datatracker.ietf.org/doc/html/rfc6455#section-11.3.3 - private static String generateSecWebSocketAccept(String webSocketKey) { - final String acceptSeed = webSocketKey + WEBSOCKET_13_ACCEPT_GUID; - final byte[] sha1 = Hashing.sha1().hashBytes(acceptSeed.getBytes(StandardCharsets.US_ASCII)).asBytes(); - return Base64.getEncoder().encodeToString(sha1); - } - private HttpResponse handleUpgradeRequest(ServiceRequestContext ctx, HttpRequest req, ResponseHeaders responseHeaders) { - final WebSocketFrameDecoder decoder = - new WebSocketFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch, - true); // client sends masked frames. + final WebSocketServiceFrameDecoder decoder = + new WebSocketServiceFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch); final StreamMessage inboundFrames = req.decode(decoder, ctx.alloc()); final WebSocket outboundFrames = handler.handle(ctx, new WebSocketWrapper(inboundFrames)); decoder.setOutboundWebSocket(outboundFrames); diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java index 3256706caf6..4003a934257 100644 --- a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceBuilder.java @@ -37,9 +37,9 @@ * This service has the different default configs from a normal {@link HttpService}. Here are the differences: *
    *
  • {@link ServiceConfig#requestTimeoutMillis()} is - * {@value WebSocketUtil#DEFAULT_REQUEST_TIMEOUT_MILLIS}.
  • + * {@value WebSocketUtil#DEFAULT_REQUEST_RESPONSE_TIMEOUT_MILLIS}. *
  • {@link ServiceConfig#maxRequestLength()} is - * {@value WebSocketUtil#DEFAULT_MAX_REQUEST_LENGTH}.
  • + * {@value WebSocketUtil#DEFAULT_MAX_REQUEST_RESPONSE_LENGTH}. *
  • {@link ServiceConfig#requestAutoAbortDelayMillis()} is * {@value WebSocketUtil#DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS}.
  • *
diff --git a/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceFrameDecoder.java b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceFrameDecoder.java new file mode 100644 index 00000000000..28bbd71aa2c --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/server/websocket/WebSocketServiceFrameDecoder.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.server.websocket; + +import com.linecorp.armeria.common.HttpRequestWriter; +import com.linecorp.armeria.common.Request; +import com.linecorp.armeria.internal.common.RequestContextExtension; +import com.linecorp.armeria.internal.common.websocket.WebSocketFrameDecoder; +import com.linecorp.armeria.server.ServiceRequestContext; + +final class WebSocketServiceFrameDecoder extends WebSocketFrameDecoder { + + private final ServiceRequestContext ctx; + + WebSocketServiceFrameDecoder(ServiceRequestContext ctx, int maxFramePayloadLength, + boolean allowMaskMismatch) { + super(ctx, maxFramePayloadLength, allowMaskMismatch); + this.ctx = ctx; + } + + @Override + protected boolean expectMaskedFrames() { + return true; + } + + @Override + protected void onCloseFrameRead() { + final RequestContextExtension ctxExtension = ctx.as(RequestContextExtension.class); + assert ctxExtension != null; + final Request request = ctxExtension.originalRequest(); + assert request instanceof HttpRequestWriter : request; + //noinspection OverlyStrongTypeCast + ((HttpRequestWriter) request).close(); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/DefaultWebClientTest.java b/core/src/test/java/com/linecorp/armeria/client/DefaultWebClientTest.java index fb6d397f0df..013b0dfe196 100644 --- a/core/src/test/java/com/linecorp/armeria/client/DefaultWebClientTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/DefaultWebClientTest.java @@ -15,6 +15,7 @@ */ package com.linecorp.armeria.client; +import static com.linecorp.armeria.internal.client.ClientUtil.UNDEFINED_URI; import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; @@ -43,7 +44,7 @@ void testConcatenateRequestPath() { @Test void testRequestParamsUndefinedEndPoint() { final String path = "http://127.0.0.1/helloWorld/test?q1=foo"; - final WebClient client = WebClient.of(AbstractWebClientBuilder.UNDEFINED_URI); + final WebClient client = WebClient.of(UNDEFINED_URI); try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) { client.execute(HttpRequest.of(RequestHeaders.of(HttpMethod.GET, path))).aggregate(); @@ -54,7 +55,7 @@ void testRequestParamsUndefinedEndPoint() { @Test void testWithoutRequestParamsUndefinedEndPoint() { final String path = "http://127.0.0.1/helloWorld/test"; - final WebClient client = WebClient.of(AbstractWebClientBuilder.UNDEFINED_URI); + final WebClient client = WebClient.of(UNDEFINED_URI); try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) { client.execute(HttpRequest.of(RequestHeaders.of(HttpMethod.GET, path))).aggregate(); @@ -118,7 +119,7 @@ void testWithQueryParams() { final QueryParams queryParams = QueryParams.builder() .add("q1", "foo") .build(); - final WebClient client = WebClient.of(AbstractWebClientBuilder.UNDEFINED_URI); + final WebClient client = WebClient.of(UNDEFINED_URI); try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) { client.get(path, queryParams).aggregate(); assertThat(captor.get().request().path()).isEqualTo("/helloWorld/test?q1=foo"); diff --git a/core/src/test/java/com/linecorp/armeria/client/Http1ResponseDecoderTest.java b/core/src/test/java/com/linecorp/armeria/client/Http1ResponseDecoderTest.java index 360a9083273..c600209cc17 100644 --- a/core/src/test/java/com/linecorp/armeria/client/Http1ResponseDecoderTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/Http1ResponseDecoderTest.java @@ -20,6 +20,8 @@ import org.junit.jupiter.api.Test; +import com.linecorp.armeria.common.SessionProtocol; + import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.DefaultHttpResponse; @@ -33,8 +35,9 @@ class Http1ResponseDecoderTest { @Test void testRequestTimeoutClosesImmediately() throws Exception { final EmbeddedChannel channel = new EmbeddedChannel(); - try { - final Http1ResponseDecoder decoder = new Http1ResponseDecoder(channel); + try (HttpClientFactory httpClientFactory = new HttpClientFactory(ClientFactoryOptions.of())) { + final Http1ResponseDecoder decoder = new Http1ResponseDecoder( + channel, httpClientFactory, SessionProtocol.H1); channel.pipeline().addLast(decoder); final HttpHeaders httpHeaders = new DefaultHttpHeaders(); diff --git a/core/src/test/java/com/linecorp/armeria/client/HttpResponseDecoderTest.java b/core/src/test/java/com/linecorp/armeria/client/HttpResponseDecoderTest.java index b7f5e99b760..ebc99f9be09 100644 --- a/core/src/test/java/com/linecorp/armeria/client/HttpResponseDecoderTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/HttpResponseDecoderTest.java @@ -29,7 +29,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.linecorp.armeria.client.HttpResponseDecoder.HttpResponseWrapper; import com.linecorp.armeria.client.retry.Backoff; import com.linecorp.armeria.client.retry.RetryDecision; import com.linecorp.armeria.client.retry.RetryRule; @@ -59,8 +58,8 @@ protected void configure(ServerBuilder sb) { }; /** - * This test would be passed because the {@code cancelAction} method of the {@link HttpResponseWrapper} is - * invoked in the event loop of the {@link Channel}. + * This test would be passed because the {@code cancelAction} method of the + * {@link HttpResponseWrapper} is invoked in the event loop of the {@link Channel}. */ @ParameterizedTest @EnumSource(value = SessionProtocol.class, names = {"H1C", "H2C"}) diff --git a/core/src/test/java/com/linecorp/armeria/client/HttpResponseWrapperTest.java b/core/src/test/java/com/linecorp/armeria/client/HttpResponseWrapperTest.java index 98de15c9a09..092094e43fe 100644 --- a/core/src/test/java/com/linecorp/armeria/client/HttpResponseWrapperTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/HttpResponseWrapperTest.java @@ -20,7 +20,6 @@ import org.junit.jupiter.api.Test; -import com.linecorp.armeria.client.HttpResponseDecoder.HttpResponseWrapper; import com.linecorp.armeria.common.CommonPools; import com.linecorp.armeria.common.HttpData; import com.linecorp.armeria.common.HttpHeaderNames; @@ -161,11 +160,10 @@ private static HttpResponseWrapper httpResponseWrapper(DecodedHttpResponse res) final TestHttpResponseDecoder decoder = new TestHttpResponseDecoder(channel, controller); res.init(controller); - return decoder.addResponse(1, res, cctx, cctx.eventLoop(), cctx.responseTimeoutMillis(), - cctx.maxResponseLength()); + return decoder.addResponse(1, res, cctx, cctx.eventLoop()); } - private static class TestHttpResponseDecoder extends HttpResponseDecoder { + private static class TestHttpResponseDecoder extends AbstractHttpResponseDecoder { private final KeepAliveHandler keepAliveHandler = new NoopKeepAliveHandler(); TestHttpResponseDecoder(Channel channel, InboundTrafficController inboundTrafficController) { @@ -176,7 +174,7 @@ private static class TestHttpResponseDecoder extends HttpResponseDecoder { void onResponseAdded(int id, EventLoop eventLoop, HttpResponseWrapper responseWrapper) {} @Override - KeepAliveHandler keepAliveHandler() { + public KeepAliveHandler keepAliveHandler() { return keepAliveHandler; } } diff --git a/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilderTest.java b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilderTest.java new file mode 100644 index 00000000000..df22c6d51c3 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientBuilderTest.java @@ -0,0 +1,111 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client.websocket; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import com.linecorp.armeria.client.ClientOptions; +import com.linecorp.armeria.client.Endpoint; + +class WebSocketClientBuilderTest { + + @CsvSource({ + "http, ws+http", + "https, ws+https", + "h1, ws+h1", + "h1c, ws+h1c", + "h2, ws+h2", + "h2c, ws+h2c", + "http, ws+http", + "https, ws+https", + "ws, ws+http", + "wss, ws+https", + "ws+h1, ws+h1", + "ws+h1c, ws+h1c", + "ws+h2, ws+h2", + "ws+h2c, ws+h2c", + "ws+http, ws+http", + "ws+https, ws+https", + }) + @ParameterizedTest + void uriWithWsPlusProtocol(String scheme, String convertedScheme) { + final WebSocketClient client = WebSocketClient.builder(scheme + "://google.com/").build(); + assertThat(client.uri().toString()).isEqualTo(convertedScheme + "://google.com/"); + } + + @CsvSource({ + "http, ws+http", + "https, ws+https", + "h1, ws+h1", + "h1c, ws+h1c", + "h2, ws+h2", + "h2c, ws+h2c", + "http, ws+http", + "https, ws+https", + "ws, ws+http", + "wss, ws+https", + "ws+h1, ws+h1", + "ws+h1c, ws+h1c", + "ws+h2, ws+h2", + "ws+h2c, ws+h2c", + "ws+http, ws+http", + "ws+https, ws+https", + }) + @ParameterizedTest + void endpointWithoutPath(String scheme, String convertedScheme) { + final WebSocketClient client = WebSocketClient.builder(scheme, Endpoint.of("127.0.0.1")).build(); + assertThat(client.uri().toString()).isEqualTo(convertedScheme + "://127.0.0.1/"); + } + + @CsvSource({ + "http, ws+http", + "https, ws+https", + "h1, ws+h1", + "h1c, ws+h1c", + "h2, ws+h2", + "h2c, ws+h2c", + "http, ws+http", + "https, ws+https", + "ws, ws+http", + "wss, ws+https", + "ws+h1, ws+h1", + "ws+h1c, ws+h1c", + "ws+h2, ws+h2", + "ws+h2c, ws+h2c", + "ws+http, ws+http", + "ws+https, ws+https", + }) + @ParameterizedTest + void endpointWithPath(String scheme, String convertedScheme) { + final WebSocketClient client = WebSocketClient.builder(scheme, Endpoint.of("127.0.0.1"), "/foo") + .build(); + assertThat(client.uri().toString()).isEqualTo(convertedScheme + "://127.0.0.1/foo"); + } + + @Test + void webSocketClientDefaultOptions() { + final WebSocketClient client = WebSocketClient.builder("wss://google.com/").build(); + assertThat(client.options().get(ClientOptions.RESPONSE_TIMEOUT_MILLIS)).isEqualTo(0); + assertThat(client.options().get(ClientOptions.MAX_RESPONSE_LENGTH)).isEqualTo(0); + assertThat(client.options().get(ClientOptions.REQUEST_AUTO_ABORT_DELAY_MILLIS)).isEqualTo(5000); + assertThat(client.options().get(ClientOptions.AUTO_FILL_ORIGIN_HEADER)).isTrue(); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientHandshakeTest.java b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientHandshakeTest.java new file mode 100644 index 00000000000..b45f1c23de2 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientHandshakeTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client.websocket; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import com.linecorp.armeria.client.websocket.WebSocketClientTest.WebSocketServiceEchoHandler; +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.SerializationFormat; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.common.websocket.WebSocketWriter; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.websocket.WebSocketService; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +class WebSocketClientHandshakeTest { + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + sb.service("/chat", WebSocketService.builder(new WebSocketServiceEchoHandler()) + .subprotocols("foo", "foo1", "foo2") + .build()); + } + }; + + @CsvSource({ + "H1C, foo2, foo1, foo2", + "H1C, bar1, bar2, ", + "H2C, foo2, foo1, foo2", + "H2C, bar1, bar2, " + }) + @ParameterizedTest + void subprotocol(SessionProtocol sessionProtocol, + String subprotocol1, String subprotocol2, @Nullable String selected) { + final WebSocketClient client = + WebSocketClient.builder(server.uri(sessionProtocol, SerializationFormat.WS)) + .subprotocols(subprotocol1, subprotocol2) + .build(); + final WebSocketSession session = client.connect("/chat").join(); + if (selected == null) { + assertThat(session.responseHeaders().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)).isNull(); + } else { + assertThat(session.responseHeaders().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)) + .isEqualTo(selected); + } + // Abort the session to close the connection. + final WebSocketWriter outbound = WebSocket.streaming(); + outbound.abort(); + session.setOutbound(outbound); + session.inbound().abort(); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientTest.java b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientTest.java new file mode 100644 index 00000000000..c98417de6a5 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/websocket/WebSocketClientTest.java @@ -0,0 +1,196 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client.websocket; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.linecorp.armeria.client.ClientFactory; +import com.linecorp.armeria.common.ClosedSessionException; +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.SerializationFormat; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.logging.RequestLogProperty; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; +import com.linecorp.armeria.common.websocket.WebSocketFrame; +import com.linecorp.armeria.common.websocket.WebSocketFrameType; +import com.linecorp.armeria.common.websocket.WebSocketWriter; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.websocket.WebSocketService; +import com.linecorp.armeria.server.websocket.WebSocketServiceHandler; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +class WebSocketClientTest { + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + sb.http(0) + .https(0) + .tlsSelfSigned(); + sb.route() + .get("/chat") + .connect("/chat") + .requestAutoAbortDelayMillis(5000) + .build(WebSocketService.of(new WebSocketServiceEchoHandler())); + } + }; + + @CsvSource({ + "H1, false", + "H1C, true", + "H1C, false", + "H2, false", + "H2C, true", + "H2C, false", + "HTTP, true", + "HTTP, false", + "HTTPS, false" + }) + @ParameterizedTest + void webSocketClient(SessionProtocol protocol, boolean defaultClient) throws InterruptedException { + // TODO(minwoox): Add server.webSocketClient(); + final CompletableFuture future; + if (defaultClient) { + future = WebSocketClient.of().connect(server.uri(protocol, SerializationFormat.WS) + "/chat"); + } else { + final WebSocketClient webSocketClient = + WebSocketClient.builder(server.uri(protocol, SerializationFormat.WS)) + .factory(ClientFactory.insecure()) + .build(); + future = webSocketClient.connect("/chat"); + } + final WebSocketSession webSocketSession = future.join(); + final ServiceRequestContext sctx = server.requestContextCaptor().take(); + final RequestHeaders headers = sctx.log().ensureAvailable(RequestLogProperty.REQUEST_HEADERS) + .requestHeaders(); + assertThat(headers.get(HttpHeaderNames.ORIGIN)).isEqualTo( + protocol.isHttps() ? server.httpsUri().toString() : server.httpUri().toString()); + + final WebSocketWriter outbound = webSocketSession.outbound(); + outbound.write(WebSocketFrame.ofText("hello")); + + final WebSocketInboundHandler inboundHandler = new WebSocketInboundHandler( + webSocketSession.inbound(), protocol); + + WebSocketFrame frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofText("hello")); + + frame = inboundHandler.inboundQueue().poll(1, TimeUnit.SECONDS); + assertThat(frame).isNull(); + + outbound.write(WebSocketFrame.ofText("armeria")); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofText("armeria")); + + outbound.close(WebSocketCloseStatus.NORMAL_CLOSURE); + frame = inboundHandler.inboundQueue().take(); + assertThat(frame).isEqualTo(WebSocketFrame.ofClose(WebSocketCloseStatus.NORMAL_CLOSURE)); + inboundHandler.completionFuture().join(); + await().until(outbound::isComplete); + } + + static final class WebSocketInboundHandler { + + private final ArrayBlockingQueue inboundQueue = new ArrayBlockingQueue<>(4); + private final CompletableFuture completionFuture = new CompletableFuture<>(); + + WebSocketInboundHandler(WebSocket inbound, SessionProtocol protocol) { + inbound.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(WebSocketFrame webSocketFrame) { + inboundQueue.add(webSocketFrame); + } + + @Override + public void onError(Throwable t) { + if (protocol.isExplicitHttp1()) { + // After receiving a close frame, ClosedSessionException can be raised for HTTP/1.1 + // before onComplete is called. + assertThat(t).isExactlyInstanceOf(ClosedSessionException.class); + } + completionFuture.complete(null); + } + + @Override + public void onComplete() { + completionFuture.complete(null); + } + }); + } + + ArrayBlockingQueue inboundQueue() { + return inboundQueue; + } + + CompletableFuture completionFuture() { + return completionFuture; + } + } + + static final class WebSocketServiceEchoHandler implements WebSocketServiceHandler { + + @Override + public WebSocket handle(ServiceRequestContext ctx, WebSocket in) { + final WebSocketWriter writer = WebSocket.streaming(); + in.subscribe(new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(WebSocketFrame webSocketFrame) { + if (webSocketFrame.type() != WebSocketFrameType.PING && + webSocketFrame.type() != WebSocketFrameType.PONG) { + writer.write(webSocketFrame); + } + } + + @Override + public void onError(Throwable t) { + writer.close(t); + } + + @Override + public void onComplete() { + writer.close(); + } + }); + return writer; + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java index aecc8ea7278..409e0ebe583 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/common/websocket/WebSocketFrameEncoderAndDecoderTest.java @@ -52,6 +52,7 @@ import com.linecorp.armeria.common.HttpRequestWriter; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpResponseWriter; +import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.RequestHeaders; import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; import com.linecorp.armeria.common.websocket.WebSocketFrame; @@ -114,7 +115,7 @@ public void testWebSocketProtocolViolation() throws InterruptedException { final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(true); final HttpRequestWriter requestWriter = HttpRequest.streaming(RequestHeaders.of(HttpMethod.GET, "/")); final WebSocketFrameDecoder decoder = - new WebSocketFrameDecoder(ctx, maxPayloadLength, false, true); + new TestWebSocketFrameDecoder(ctx, maxPayloadLength, false, true); final CompletableFuture whenComplete = new CompletableFuture<>(); requestWriter.decode(decoder, ctx.alloc()).subscribe(subscriber(whenComplete)); @@ -140,8 +141,8 @@ public void testWebSocketEncodingAndDecoding(boolean maskPayload, boolean allowM final HttpResponseWriter httpResponseWriter = HttpResponse.streaming(); final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(maskPayload); final HttpRequestWriter requestWriter = HttpRequest.streaming(RequestHeaders.of(HttpMethod.GET, "/")); - final WebSocketFrameDecoder decoder = new WebSocketFrameDecoder(ctx, 1024 * 1024, allowMaskMismatch, - maskPayload); + final WebSocketFrameDecoder decoder = new TestWebSocketFrameDecoder( + ctx, 1024 * 1024, allowMaskMismatch, maskPayload); requestWriter.decode(decoder, ctx.alloc()).subscribe(subscriber(new CompletableFuture<>())); executeTests(encoder, requestWriter); httpResponseWriter.abort(); @@ -229,4 +230,23 @@ public void onComplete() { } }; } + + private static class TestWebSocketFrameDecoder extends WebSocketFrameDecoder { + + private final boolean expectMaskedFrames; + + TestWebSocketFrameDecoder(RequestContext ctx, int maxFramePayloadLength, + boolean allowMaskMismatch, boolean expectMaskedFrames) { + super(ctx, maxFramePayloadLength, allowMaskMismatch); + this.expectMaskedFrames = expectMaskedFrames; + } + + @Override + protected boolean expectMaskedFrames() { + return expectMaskedFrames; + } + + @Override + protected void onCloseFrameRead() {} + } } diff --git a/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceConfigTest.java b/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceConfigTest.java index 7303cfcc7b3..31862ce50c4 100644 --- a/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceConfigTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/websocket/WebSocketServiceConfigTest.java @@ -34,8 +34,9 @@ void webSocketServiceDefaultConfigValues() { assertThat(server.config().serviceConfigs()).hasSize(1); ServiceConfig serviceConfig = server.config().serviceConfigs().get(0); assertThat(serviceConfig.requestTimeoutMillis()).isEqualTo( - WebSocketUtil.DEFAULT_REQUEST_TIMEOUT_MILLIS); - assertThat(serviceConfig.maxRequestLength()).isEqualTo(WebSocketUtil.DEFAULT_MAX_REQUEST_LENGTH); + WebSocketUtil.DEFAULT_REQUEST_RESPONSE_TIMEOUT_MILLIS); + assertThat(serviceConfig.maxRequestLength()).isEqualTo( + WebSocketUtil.DEFAULT_MAX_REQUEST_RESPONSE_LENGTH); assertThat(serviceConfig.requestAutoAbortDelayMillis()).isEqualTo( WebSocketUtil.DEFAULT_REQUEST_AUTO_ABORT_DELAY_MILLIS); @@ -48,7 +49,8 @@ void webSocketServiceDefaultConfigValues() { assertThat(server.config().serviceConfigs()).hasSize(1); serviceConfig = server.config().serviceConfigs().get(0); assertThat(serviceConfig.requestTimeoutMillis()).isEqualTo(2000); - assertThat(serviceConfig.maxRequestLength()).isEqualTo(WebSocketUtil.DEFAULT_MAX_REQUEST_LENGTH); + assertThat(serviceConfig.maxRequestLength()).isEqualTo( + WebSocketUtil.DEFAULT_MAX_REQUEST_RESPONSE_LENGTH); assertThat(serviceConfig.requestAutoAbortDelayMillis()).isEqualTo(1000); } } diff --git a/dependencies.toml b/dependencies.toml index cd44cd974b9..fbf0321d9ec 100644 --- a/dependencies.toml +++ b/dependencies.toml @@ -610,6 +610,9 @@ version.ref = "jetty11" [libraries.jetty11-apache-jstl] module = "org.eclipse.jetty:apache-jstl" version.ref = "jetty11-jstl" +[libraries.jetty11-http2-server] +module = "org.eclipse.jetty.http2:http2-server" +version.ref = "jetty11" [libraries.jetty11-server] module = "org.eclipse.jetty:jetty-server" version.ref = "jetty11" @@ -617,6 +620,10 @@ version.ref = "jetty11" [libraries.jetty11-webapp] module = "org.eclipse.jetty:jetty-webapp" version.ref = "jetty11" +# jetty-websocket for testing WebSocket interoperability. +[libraries.jetty11-websocket] +module = "org.eclipse.jetty.websocket:websocket-jakarta-server" +version.ref = "jetty11" [libraries.jetty93-annotations] module = "org.eclipse.jetty:jetty-annotations" diff --git a/eureka/src/main/java/com/linecorp/armeria/client/eureka/EurekaEndpointGroup.java b/eureka/src/main/java/com/linecorp/armeria/client/eureka/EurekaEndpointGroup.java index 216f8ce291d..d615d758da1 100644 --- a/eureka/src/main/java/com/linecorp/armeria/client/eureka/EurekaEndpointGroup.java +++ b/eureka/src/main/java/com/linecorp/armeria/client/eureka/EurekaEndpointGroup.java @@ -419,8 +419,6 @@ private static Endpoint endpoint(InstanceInfo instanceInfo, boolean secureVip) { @Override public String toString() { - return toStringHelper() - .add("requestHeaders", requestHeaders) - .toString(); + return toString(buf -> buf.append(", requestHeaders=").append(requestHeaders)); } } diff --git a/gradle.properties b/gradle.properties index 73cf970bceb..f2d192c8ec8 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,5 +1,5 @@ group=com.linecorp.armeria -version=1.24.3-SNAPSHOT +version=1.25.1-SNAPSHOT projectName=Armeria projectUrl=https://armeria.dev/ projectDescription=Asynchronous HTTP/2 RPC/REST client/server library built on top of Java 8, Netty, Thrift and gRPC diff --git a/gradle/scripts/lib/java-shade.gradle b/gradle/scripts/lib/java-shade.gradle index 07b8577a893..dc371be3b49 100644 --- a/gradle/scripts/lib/java-shade.gradle +++ b/gradle/scripts/lib/java-shade.gradle @@ -137,7 +137,7 @@ configure(relocatedProjects) { if (jmodDir.isDirectory()) { jmodDir.listFiles().findAll { File f -> f.isFile() && f.name.toLowerCase(Locale.ENGLISH).endsWith(".jmod") - }.each { libraryjars it } + }.sort().each { libraryjars it } } else { libraryjars file("${System.getProperty('java.home')}/lib/rt.jar") } diff --git a/it/websocket/build.gradle b/it/websocket/build.gradle index b3c1d0ccb1a..21c8e4ab41e 100644 --- a/it/websocket/build.gradle +++ b/it/websocket/build.gradle @@ -1,3 +1,7 @@ dependencies { + testImplementation libs.jakarta.websocket + testImplementation libs.jetty11.http2.server + testImplementation libs.jetty11.websocket testImplementation libs.java.websocket + testImplementation libs.logback14 } diff --git a/it/websocket/src/test/java/com/linecorp/armeria/it/websocket/WebSocketClientItTest.java b/it/websocket/src/test/java/com/linecorp/armeria/it/websocket/WebSocketClientItTest.java new file mode 100644 index 00000000000..7cb329be1fe --- /dev/null +++ b/it/websocket/src/test/java/com/linecorp/armeria/it/websocket/WebSocketClientItTest.java @@ -0,0 +1,202 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.it.websocket; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Queue; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CountDownLatch; + +import org.eclipse.jetty.http2.server.HTTP2ServerConnectionFactory; +import org.eclipse.jetty.server.HttpConfiguration; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.websocket.jakarta.server.config.JakartaWebSocketServletContainerInitializer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.client.websocket.WebSocketClient; +import com.linecorp.armeria.client.websocket.WebSocketSession; +import com.linecorp.armeria.common.ClosedSessionException; +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.websocket.CloseWebSocketFrame; +import com.linecorp.armeria.common.websocket.WebSocket; +import com.linecorp.armeria.common.websocket.WebSocketCloseStatus; +import com.linecorp.armeria.common.websocket.WebSocketFrame; +import com.linecorp.armeria.common.websocket.WebSocketWriter; + +import jakarta.websocket.CloseReason; +import jakarta.websocket.OnClose; +import jakarta.websocket.OnError; +import jakarta.websocket.OnMessage; +import jakarta.websocket.OnOpen; +import jakarta.websocket.Session; +import jakarta.websocket.server.ServerEndpoint; +import jakarta.websocket.server.ServerEndpointConfig; + +class WebSocketClientItTest { + + private static final Queue sentMessages = new ArrayBlockingQueue<>(2); + + @BeforeEach + void setUp() { + sentMessages.clear(); + } + + @CsvSource({ "h1c", "h2c" }) + @ParameterizedTest + void webSocketClientIt(String protocol) throws Exception { + final Server server = new Server(); + final ServerConnector connector = createConnector(protocol, server); + server.addConnector(connector); + setupJettyWebSocket(server); + server.start(); + + final WebSocketClient client = WebSocketClient.of( + protocol + "://127.0.0.1:" + connector.getLocalPort()); + final WebSocketSession webSocketSession = client.connect("/chat").join(); + + final WebSocketWriter writer = WebSocket.streaming(); + webSocketSession.setOutbound(writer); + writer.write("Hello, world!"); + writer.write("bye"); + writer.close(); + + final CountDownLatch latch = new CountDownLatch(1); + final List frames = new ArrayList<>(); + webSocketSession.inbound().subscribe( + new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(WebSocketFrame webSocketFrame) { + frames.add(webSocketFrame); + } + + @Override + public void onError(Throwable t) { + // The connection is closed by the server if HTTP/1.1 + assertThat(t).isExactlyInstanceOf(ClosedSessionException.class); + latch.countDown(); + } + + @Override + public void onComplete() { + latch.countDown(); + } + }); + latch.await(); + assertThat(frames.size()).isOne(); + final WebSocketFrame frame = frames.get(0); + assertThat(frame).isInstanceOf(CloseWebSocketFrame.class); + assertThat(((CloseWebSocketFrame) frame).status()).isSameAs(WebSocketCloseStatus.NORMAL_CLOSURE); + + assertThat(sentMessages).containsExactly("Hello, world!", "bye"); + server.stop(); + } + + @CsvSource({ + "h1c, foo2, foo1, foo2", + "h1c, bar1, bar2, ", + "h2c, foo2, foo1, foo2", + "h2c, bar1, bar2, " + }) + @ParameterizedTest + void subprotocol(String sessionProtocol, + String subprotocol1, String subprotocol2, @Nullable String selected) throws Exception { + final Server server = new Server(); + final ServerConnector connector = createConnector(sessionProtocol, server); + server.addConnector(connector); + setupJettyWebSocket(server); + server.start(); + + final WebSocketClient client = + WebSocketClient.builder(sessionProtocol + "://127.0.0.1:" + connector.getLocalPort()) + .subprotocols(subprotocol1, subprotocol2) + .build(); + final WebSocketSession session = client.connect("/chat").join(); + if (selected == null) { + assertThat(session.responseHeaders().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)).isNull(); + } else { + assertThat(session.responseHeaders().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)) + .isEqualTo(selected); + } + // Abort the session to close the connection. + final WebSocketWriter outbound = WebSocket.streaming(); + outbound.abort(); + session.setOutbound(outbound); + session.inbound().abort(); + + server.stop(); + } + + private static void setupJettyWebSocket(Server server) { + final ServletContextHandler context = new ServletContextHandler(ServletContextHandler.SESSIONS); + context.setContextPath("/"); + server.setHandler(context); + + JakartaWebSocketServletContainerInitializer.configure(context, (servletContext, wsContainer) -> { + // Add WebSocket endpoint + wsContainer.addEndpoint( + ServerEndpointConfig.Builder.create(EventSocket.class, "/chat") + .subprotocols(ImmutableList.of("foo", "foo1", "foo2")) + .build()); + }); + } + + private static ServerConnector createConnector(String protocol, Server server) { + if ("h1c".equals(protocol)) { + return new ServerConnector(server); + } + return new ServerConnector(server, new HTTP2ServerConnectionFactory(new HttpConfiguration())); + } + + @ServerEndpoint("/") + public static class EventSocket { + @OnOpen + public void onWebSocketConnect(Session sess) {} + + @OnMessage + public void onWebSocketText(Session sess, String message) throws IOException { + sentMessages.add(message); + if (message.toLowerCase(Locale.US).contains("bye")) { + sess.close(new CloseReason(CloseReason.CloseCodes.NORMAL_CLOSURE, "Thanks")); + } + } + + @OnClose + public void onWebSocketClose(CloseReason reason) {} + + @OnError + public void onWebSocketError(Throwable cause) {} + } +} diff --git a/javadoc/build.gradle b/javadoc/build.gradle index 2cf1394b86c..ff579f48f28 100644 --- a/javadoc/build.gradle +++ b/javadoc/build.gradle @@ -31,8 +31,13 @@ task checkJavadoc( 'nested.classes.inherited.from.class.' ] def allowedListPrefixes = ['java.', 'javax.'] - def disallowedListPrefixes = [ 'com.linecorp.armeria.internal.' ] + - rootProject.ext.relocations.collect { it[1] + '.' } + def disallowedListPrefixes = ['com.linecorp.armeria.internal.'] + disallowedListPrefixes.addAll(rootProject.ext.relocations.collect { + def packageName = it["from"] + assert packageName != null + packageName + '.' + }) + def errors = [] reportFile.parentFile.mkdirs() diff --git a/scala/scala_2.13/src/main/scala/com/linecorp/armeria/client/scala/ScalaRestClientPreparation.scala b/scala/scala_2.13/src/main/scala/com/linecorp/armeria/client/scala/ScalaRestClientPreparation.scala index a40c44dd8f2..4f068d7e93a 100644 --- a/scala/scala_2.13/src/main/scala/com/linecorp/armeria/client/scala/ScalaRestClientPreparation.scala +++ b/scala/scala_2.13/src/main/scala/com/linecorp/armeria/client/scala/ScalaRestClientPreparation.scala @@ -190,6 +190,11 @@ final class ScalaRestClientPreparation private[scala] (delegate: RestClientPrepa this } + override def content(content: Publisher[_ <: HttpData]): ScalaRestClientPreparation = { + delegate.content(content) + this + } + override def content( contentType: MediaType, content: Publisher[_ <: HttpData]): ScalaRestClientPreparation = { diff --git a/settings.gradle b/settings.gradle index be2da1855b7..00d66d1bc91 100644 --- a/settings.gradle +++ b/settings.gradle @@ -198,7 +198,7 @@ includeWithFlags ':it:spring:webflux-security', 'java17', 'reloca includeWithFlags ':it:thrift-fullcamel', 'java', 'relocate' includeWithFlags ':it:thrift0.9.1', 'java', 'relocate' includeWithFlags ':it:trace-context-leak', 'java', 'relocate' -includeWithFlags ':it:websocket', 'java', 'relocate' +includeWithFlags ':it:websocket', 'java11', 'relocate' includeWithFlags ':jetty9.3', 'java', 'relocate' project(':jetty9.3').projectDir = file('jetty/jetty9.3') includeWithFlags ':testing-internal', 'java', 'relocate' diff --git a/site/src/pages/release-notes/1.25.0.mdx b/site/src/pages/release-notes/1.25.0.mdx new file mode 100644 index 00000000000..f82dd1217c2 --- /dev/null +++ b/site/src/pages/release-notes/1.25.0.mdx @@ -0,0 +1,162 @@ +--- +date: 2023-08-22 +--- + +## 🌟 New features + +- **GraalVM Support**: Armeria now provides [GraalVM](https://www.graalvm.org/) + [reachability metadata](https://www.graalvm.org/latest/reference-manual/native-image/metadata/) to easily build + [GraalVM](https://www.graalvm.org/) native images. #5005 +- **Micrometer Observation Support**: Support for [Micrometer Observation](https://micrometer.io/docs/observation) is added. + Refer to or for details on how to integrate with Armeria. #4659 #4980 + ```java + ObservationRegistry observationRegistry = ... + WebClient.builder() + .decorator(ObservationClient.newDecorator(observationRegistry)) + ... + Server.builder() + .decorator(ObservationService.newDecorator(observationRegistry)) + ... + ``` +- **WebSocket Client Support**: You can now send and receive data over [WebSocket](https://en.wikipedia.org/wiki/WebSocket) + using . #4972 + ```java + WebSocketClient client = WebSocketClient.of("ws://..."); + client.connect("/").thenAccept(webSocketSession -> { + WebSocketWriter writer = WebSocket.streaming(); + webSocketSessions.setOutbound(writer); + outbound.write("Hello!"); + + Subscriber subscriber = new Subscriber() { + ... + } + webSocketSessions.inbound().subscribe(subscriber); + }); + ``` +- **Implement gRPC Richer Error Model More Easily**: You can now easily use gRPC + [Richer Error Model](https://grpc.io/docs/guides/error/#richer-error-model) via . #4614 #4986 + ```java + GoogleGrpcStatusFunction statusFunction = (ctx, throwable, metadata) -> { + if (throwable instanceof MyException) { + return com.google.rpc.Status.newBuilder() + .setCode(Code.UNAUTHENTICATED.getNumber()) + .addDetails(detail(throwable)) + .build(); + } + ... + }; + Server.builder().service( + GrpcService.builder() + .exceptionMapping(statusFunction)) + ``` +- **Set HTTP Trailers Easily** You can now easily set trailers to be sent after the data stream using + or + . #3959 #4727 +- **New API for Multipart Headers**: You can now retrieve headers from a multipart request in an annotated service + using . #5106 +- **Access RequestLogProperty Values More Easily**: + has been introduced, which allows users to access a immediately if available. #4956 #4966 +- **Keep an Idle Connection Alive on PING**: The `keepAliveOnPing` option has been introduced. Enabling this option will keep + an idle connection alive when an HTTP/2 PING frame or `OPTIONS * HTTP/1.1` is received. The option can be configured + by or . #4794 #4806 +- **Create a StreamMessage from Future**: You can now easily create a from a `CompletionStage` + using )>. #4995 +- **More Shortcuts for PrometheusExpositionService**: You can now create a without + specifying the default `CollectorRegistry` explicitly. #5134 + +## 📈 Improvements + +- The number of event loops is equal to the number of cores by default when `io_uring` is used as the transport type. #5089 +- You can now customize error responses when a service for a request is not found + using . #4996 +- Redirection for a trailing slash is done correctly even if a reverse proxy rewrites the path. #4994 +- now tries to guess the correct route behind a reverse proxy. #4987 +- The `RetentionPolicy` of annotation is now `CLASS` so that + bytecode analysis tools can detect the declaration and usage of unstable APIs. #5131 + +## 🛠️ Bug fixes + +- now returns an `INTERNAL` error code if an error occurs while serializing gRPC metadata. #4625 #4686 +- now allows zero TTL for resolved DNS records. #5119 +- Armeria's DNS resolver doesn't cache a DNS whose query was timed out. #5117 +- Fixed a bug where headers could be written twice if `Content-Length` was exceeded during HTTP/2 cleartext upgrade. #5113 +- and now return + correct values when using domain sockets in abstract namespace. #5096 +- `armeria-logback12`, `armeria-logback13`, and `armeria-logback14` have been introduced for better + compatibility with [Logback](https://logback.qos.ch/). #5045 #5079 #5078 #5077 +- You can now use either an inline debug form or a modal debug form when using . #5072 +- When using Spring integrations, even if `internal-services.port` and `management.server.port` + are set to the same value internal services are bound to the port only once. #4796 #5022 +- Exceptions that occurred during a TLS handshake are properly propagated to users. #4950 +- now respects the `charset` attribute in the + `Content-Type` header if available. #4931 #4948 +- Routes with dynamic predicates are not incorrectly cached anymore. #4927 #4934 + +## 📃 Documentation + +- A new page has been added which describes how to integrate Armeria with Spring Boot. #4670 #4957 +- Documentation on how work in Armeria has been added. #4870 +- A new example on how to use [krotoDC](https://github.com/mscheong01/krotoDC) with Armeria has been added. #5092 + +## ☢️ Breaking changes + +- The `toStringHelper()` method in has been replaced + with `toString(Consumer)` to avoid exposing an internal API in the public API. #5132 + +## 🏚️ Deprecations + +- and its variants methods are deprecated. #5075 + - Use and its variants instead. + +## ⛓ Dependencies + +- gRPC-Java 1.56.0 → 1.57.2 +- GraphQL Kotlin 6.5.2 → 6.5.3 +- Guava 32.0.1-jre → 32.1.2-jre +- Jakarta Websocket 2.1.0 → 2.1.1 +- Kafka client 3.4.0 → 3.4.1 +- Kotlin 1.8.22 → 1.9.0 +- Kotlin Coroutine 1.7.1 → 1.7.3 +- Logback 1.4.7 → 1.4.11 +- Micrometer 1.11.1 → 1.11.3 +- Netty 4.1.94.Final → 4.1.96.Final +- Protobuf 3.22.3 → 3.24.0 +- Reactor 3.5.7 → 3.5.8 +- Resilience4j 2.0.2 → 2.1.0 +- Resteasy 5.0.5.Final → 5.0.7.Final +- Sangria 4.0.0 → 4.0.1 +- scala-collection-compat 2.10.0 → 2.11.0 +- Spring 6.0.9 → 6.0.11 +- Spring Boot 2.7.12 → 2.7.14, 3.1.0 → 3.1.1 +- Tomcat 10.1.10 → 10.1.12 + +## 🙇 Thank you + +