Skip to content

Commit

Permalink
Fix ExchangeType detection for HeaderOverridingHttpRequest. (#5787)
Browse files Browse the repository at this point in the history
Motivation:
The `RetryingClient` uses `ExchangeType` to decide whether to use `HttpRequestDuplicator`: https://github.com/line/armeria/blob/7474525b8cf25f02be6df7c38510a8fb6a88cb1f/core/src/main/java/com/linecorp/armeria/client/retry/RetryingClient.java#L245-L246 Thus, setting the proper `ExchangeType` is important.

`ExchangeType` is currently inferred from the `HttpRequest` implementation: https://github.com/line/armeria/blob/7474525b8cf25f02be6df7c38510a8fb6a88cb1f/core/src/main/java/com/linecorp/armeria/internal/client/DefaultClientRequestContext.java#L285-L288

However, `DefaultWebClient` wraps the request, which results in incorrect `ExchangeType` detection: https://github.com/line/armeria/blob/7474525b8cf25f02be6df7c38510a8fb6a88cb1f/core/src/main/java/com/linecorp/armeria/client/DefaultWebClient.java#L113

Modifications:
- Unwrapped `HeaderOverridingHttpRequest` to detect the correct `ExchangeType`.
- Added `RequestOptions` when sending a request where applicable.

Result:
- The `ExchangeType` is now correctly detected for the default `WebClient`.
  • Loading branch information
minwoox authored Jun 27, 2024
1 parent 9cbfb28 commit 369614f
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 13 deletions.
14 changes: 8 additions & 6 deletions core/src/main/java/com/linecorp/armeria/client/WebClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.linecorp.armeria.client;

import static com.linecorp.armeria.client.DefaultWebClient.RESPONSE_STREAMING_REQUEST_OPTIONS;
import static java.util.Objects.requireNonNull;

import java.net.URI;
Expand Down Expand Up @@ -249,47 +250,48 @@ default HttpResponse execute(HttpRequest req) {
@CheckReturnValue
default HttpResponse execute(AggregatedHttpRequest aggregatedReq) {
requireNonNull(aggregatedReq, "aggregatedReq");
return execute(aggregatedReq.toHttpRequest());
return execute(aggregatedReq.toHttpRequest(), RESPONSE_STREAMING_REQUEST_OPTIONS);
}

/**
* Sends an empty HTTP request with the specified headers.
*/
@CheckReturnValue
default HttpResponse execute(RequestHeaders headers) {
return execute(HttpRequest.of(headers));
return execute(HttpRequest.of(headers), RESPONSE_STREAMING_REQUEST_OPTIONS);
}

/**
* Sends an HTTP request with the specified headers and content.
*/
@CheckReturnValue
default HttpResponse execute(RequestHeaders headers, HttpData content) {
return execute(HttpRequest.of(headers, content));
return execute(HttpRequest.of(headers, content), RESPONSE_STREAMING_REQUEST_OPTIONS);
}

/**
* Sends an HTTP request with the specified headers and content.
*/
@CheckReturnValue
default HttpResponse execute(RequestHeaders headers, byte[] content) {
return execute(HttpRequest.of(headers, HttpData.wrap(content)));
return execute(HttpRequest.of(headers, HttpData.wrap(content)), RESPONSE_STREAMING_REQUEST_OPTIONS);
}

/**
* Sends an HTTP request with the specified headers and content.
*/
@CheckReturnValue
default HttpResponse execute(RequestHeaders headers, String content) {
return execute(HttpRequest.of(headers, HttpData.ofUtf8(content)));
return execute(HttpRequest.of(headers, HttpData.ofUtf8(content)), RESPONSE_STREAMING_REQUEST_OPTIONS);
}

/**
* Sends an HTTP request with the specified headers and content.
*/
@CheckReturnValue
default HttpResponse execute(RequestHeaders headers, String content, Charset charset) {
return execute(HttpRequest.of(headers, HttpData.of(charset, content)));
return execute(HttpRequest.of(headers, HttpData.of(charset, content)),
RESPONSE_STREAMING_REQUEST_OPTIONS);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import com.linecorp.armeria.common.stream.SubscriptionOption;
import com.linecorp.armeria.internal.common.DefaultHttpRequest;
import com.linecorp.armeria.internal.common.DefaultSplitHttpRequest;
import com.linecorp.armeria.internal.common.HeaderOverridingHttpRequest;
import com.linecorp.armeria.internal.common.stream.SurroundingPublisher;
import com.linecorp.armeria.unsafe.PooledObjects;

Expand Down Expand Up @@ -478,8 +479,7 @@ default HttpRequest withHeaders(RequestHeaders newHeaders) {
// Just check the reference only to avoid heavy comparison.
return this;
}

return new HeaderOverridingHttpRequest(this, newHeaders);
return HeaderOverridingHttpRequest.of(this, newHeaders);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import com.linecorp.armeria.common.util.TimeoutMode;
import com.linecorp.armeria.common.util.UnmodifiableFuture;
import com.linecorp.armeria.internal.common.CancellationScheduler;
import com.linecorp.armeria.internal.common.HeaderOverridingHttpRequest;
import com.linecorp.armeria.internal.common.NonWrappingRequestContext;
import com.linecorp.armeria.internal.common.RequestContextExtension;
import com.linecorp.armeria.internal.common.SchemeAndAuthority;
Expand Down Expand Up @@ -285,6 +286,11 @@ private static ExchangeType guessExchangeType(RequestOptions requestOptions, @Nu
if (req instanceof FixedStreamMessage) {
return ExchangeType.RESPONSE_STREAMING;
}
if (req instanceof HeaderOverridingHttpRequest) {
if (((HeaderOverridingHttpRequest) req).delegate() instanceof FixedStreamMessage) {
return ExchangeType.RESPONSE_STREAMING;
}
}
return ExchangeType.BIDI_STREAMING;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.linecorp.armeria.common;
package com.linecorp.armeria.internal.common;

import static java.util.Objects.requireNonNull;

Expand All @@ -25,6 +25,13 @@

import com.google.common.base.MoreObjects;

import com.linecorp.armeria.common.AggregatedHttpRequest;
import com.linecorp.armeria.common.AggregationOptions;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpObject;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.stream.SubscriptionOption;

Expand All @@ -33,16 +40,30 @@
/**
* An {@link HttpRequest} that overrides the {@link RequestHeaders}.
*/
final class HeaderOverridingHttpRequest implements HttpRequest {
public final class HeaderOverridingHttpRequest implements HttpRequest {

private final HttpRequest delegate;
private final RequestHeaders headers;

public static HeaderOverridingHttpRequest of(HttpRequest delegate, RequestHeaders headers) {
requireNonNull(delegate, "delegate");
requireNonNull(headers, "headers");
if (delegate instanceof HeaderOverridingHttpRequest) {
return new HeaderOverridingHttpRequest(
((HeaderOverridingHttpRequest) delegate).delegate(), headers);
}
return new HeaderOverridingHttpRequest(delegate, headers);
}

HeaderOverridingHttpRequest(HttpRequest delegate, RequestHeaders headers) {
this.delegate = delegate;
this.headers = headers;
}

public HttpRequest delegate() {
return delegate;
}

@Override
public HttpRequest withHeaders(RequestHeaders newHeaders) {
requireNonNull(newHeaders, "newHeaders");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ void fixedMessage() {
}).isEqualTo(ExchangeType.RESPONSE_STREAMING);
}

@Test
void headerOverridingFixedMessage() {
assertExchangeType(() -> {
client.execute(HttpRequest.of(HttpMethod.POST, "/",
MediaType.PLAIN_TEXT, "foo")
.withHeaders(RequestHeaders.builder(HttpMethod.POST, "/")
.add("foo", "bar")
.build()))
.aggregate();
}).isEqualTo(ExchangeType.RESPONSE_STREAMING);
}

@Test
void fixedMessageWithCustomRequestOptions() {
assertExchangeType(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.google.common.collect.ImmutableList;

import com.linecorp.armeria.common.stream.StreamMessage;
import com.linecorp.armeria.internal.common.HeaderOverridingHttpRequest;

import io.netty.buffer.ByteBufAllocator;
import reactor.core.publisher.Flux;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
import com.linecorp.armeria.client.ClientRequestContext;
import com.linecorp.armeria.client.Clients;
import com.linecorp.armeria.client.HttpClient;
import com.linecorp.armeria.client.RequestOptions;
import com.linecorp.armeria.client.SimpleDecoratingHttpClient;
import com.linecorp.armeria.client.WebClient;
import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.AggregationOptions;
import com.linecorp.armeria.common.ExchangeType;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpHeaders;
Expand Down Expand Up @@ -72,12 +74,17 @@
*/
@UnstableApi
public final class UnaryGrpcClient {

private static final Logger logger = LoggerFactory.getLogger(UnaryGrpcClient.class);

private static final Set<SerializationFormat> SUPPORTED_SERIALIZATION_FORMATS =
UnaryGrpcSerializationFormats.values();

private static final RequestOptions REQUEST_OPTIONS =
RequestOptions.builder().exchangeType(ExchangeType.UNARY).build();

private final SerializationFormat serializationFormat;
private final WebClient webClient;
private static final Logger logger = LoggerFactory.getLogger(UnaryGrpcClient.class);

/**
* Constructs a {@link UnaryGrpcClient} for the given {@link WebClient}.
Expand Down Expand Up @@ -131,7 +138,7 @@ public CompletableFuture<byte[]> execute(String uri, byte[] payload) {
RequestHeaders.builder(HttpMethod.POST, uri).contentType(serializationFormat.mediaType())
.add(HttpHeaderNames.TE, HttpHeaderValues.TRAILERS.toString()).build(),
HttpData.wrap(payload));
return webClient.execute(request).aggregate(
return webClient.execute(request, REQUEST_OPTIONS).aggregate(
AggregationOptions.builder()
.usePooledObjects(PooledByteBufAllocator.DEFAULT)
.build())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import java.util.concurrent.CompletableFuture;

import com.linecorp.armeria.client.RequestOptions;
import com.linecorp.armeria.client.WebClient;
import com.linecorp.armeria.common.ExchangeType;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.QueryParams;
import com.linecorp.armeria.common.auth.oauth2.OAuth2Request;
Expand All @@ -29,6 +31,11 @@
*/
public final class OAuth2Endpoint<T> {

private static final RequestOptions UNARY_REQUEST_OPTIONS =
RequestOptions.builder()
.exchangeType(ExchangeType.UNARY)
.build();

private final WebClient endpoint;
private final String endpointPath;
private final OAuth2ResponseHandler<T> responseHandler;
Expand All @@ -43,7 +50,7 @@ public OAuth2Endpoint(WebClient endpoint, String endpointPath,
public CompletableFuture<T> execute(OAuth2Request oAuth2Request) {
final HttpRequest request = oAuth2Request.asHttpRequest(endpointPath);
final QueryParams requestParams = oAuth2Request.bodyParams();
return endpoint.execute(request)
return endpoint.execute(request, UNARY_REQUEST_OPTIONS)
.aggregate()
.thenApply(response -> responseHandler.handle(response, requestParams));
}
Expand Down

0 comments on commit 369614f

Please sign in to comment.