Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into feature/client-contex…
Browse files Browse the repository at this point in the history
…t-cancel
  • Loading branch information
jrhee17 committed Aug 2, 2024
2 parents 06012a1 + 4c6f0ff commit 88c4c18
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.linecorp.armeria.common.RequestId;
import com.linecorp.armeria.common.Response;
import com.linecorp.armeria.common.RpcRequest;
import com.linecorp.armeria.common.TimeoutException;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.annotation.UnstableApi;
import com.linecorp.armeria.common.logging.RequestLog;
Expand Down Expand Up @@ -514,6 +515,21 @@ default void timeoutNow() {
cancel(ResponseTimeoutException.get());
}

/**
* Returns whether this {@link ClientRequestContext} has been timed-out, that is the cancellation cause
* is an instance of {@link TimeoutException} or
* {@link UnprocessedRequestException} and wrapped cause is {@link TimeoutException}.
*/
@Override
default boolean isTimedOut() {
if (RequestContext.super.isTimedOut()) {
return true;
}
final Throwable cause = cancellationCause();
return cause instanceof TimeoutException ||
cause instanceof UnprocessedRequestException && cause.getCause() instanceof TimeoutException;
}

/**
* Returns the maximum length of the received {@link Response}.
* This value is initially set from {@link ClientOptions#MAX_RESPONSE_LENGTH}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public final class SessionProtocolNegotiationException extends RuntimeException
* Creates a new instance with the specified expected {@link SessionProtocol}.
*/
public SessionProtocolNegotiationException(SessionProtocol expected, @Nullable String reason) {
super("expected: " + requireNonNull(expected, "expected") + ", reason: " + reason);
super(appendReason("expected: " + requireNonNull(expected, "expected"), reason));
this.expected = expected;
actual = null;
}
Expand All @@ -48,8 +48,8 @@ public SessionProtocolNegotiationException(SessionProtocol expected, @Nullable S
public SessionProtocolNegotiationException(SessionProtocol expected,
@Nullable SessionProtocol actual, @Nullable String reason) {

super("expected: " + requireNonNull(expected, "expected") +
", actual: " + requireNonNull(actual, "actual") + ", reason: " + reason);
super(appendReason("expected: " + requireNonNull(expected, "expected") +
", actual: " + actual, reason));
this.expected = expected;
this.actual = actual;
}
Expand Down Expand Up @@ -78,4 +78,11 @@ public Throwable fillInStackTrace() {
}
return this;
}

private static String appendReason(String message, @Nullable String reason) {
if (reason == null) {
return message;
}
return message + ", reason: " + reason;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,21 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.util.function.Function;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
import org.junit.jupiter.params.provider.ValueSource;

import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.TimeoutException;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.server.ServiceRequestContext;
Expand Down Expand Up @@ -276,6 +282,24 @@ void updateRequestWithInvalidPath(String path) {
.hasMessageContaining("invalid path");
}

@ParameterizedTest
@ArgumentsSource(TimedOutExceptionProvider.class)
void isTimedOut_true(Throwable cause) {
final ClientRequestContext cctx = clientRequestContext();
cctx.cancel(cause);
cctx.whenResponseCancelled().join();
assertThat(cctx.isTimedOut()).isTrue();
}

@ParameterizedTest
@ArgumentsSource(NotTimedOutExceptionProvider.class)
void isTimedOut_false(Throwable cause) {
final ClientRequestContext cctx = clientRequestContext();
cctx.cancel(cause);
cctx.whenResponseCancelled().join();
assertThat(cctx.isTimedOut()).isFalse();
}

private static void assertUnwrapAllCurrentCtx(@Nullable RequestContext ctx) {
final RequestContext current = RequestContext.currentOrNull();
if (current == null) {
Expand All @@ -292,4 +316,25 @@ private static ServiceRequestContext serviceRequestContext() {
private static ClientRequestContext clientRequestContext() {
return ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/"));
}

private static class TimedOutExceptionProvider implements ArgumentsProvider {

@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
return Stream.of(new TimeoutException(),
ResponseTimeoutException.get(),
UnprocessedRequestException.of(ResponseTimeoutException.get()))
.map(Arguments::of);
}
}

private static class NotTimedOutExceptionProvider implements ArgumentsProvider {

@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
return Stream.of(new RuntimeException(),
UnprocessedRequestException.of(new RuntimeException()))
.map(Arguments::of);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,39 @@

package com.linecorp.armeria.client;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.awaitility.Awaitility.await;

import java.util.concurrent.CompletionException;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.google.common.collect.ImmutableMap;

import com.linecorp.armeria.client.endpoint.dns.TestDnsServer;
import com.linecorp.armeria.common.CommonPools;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.metric.PrometheusMeterRegistries;
import com.linecorp.armeria.server.AbstractHttpService;
import com.linecorp.armeria.server.Server;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.dns.DatagramDnsQuery;
import io.netty.resolver.ResolvedAddressTypes;
import io.netty.resolver.dns.DnsServerAddressStreamProvider;
import io.netty.resolver.dns.DnsServerAddresses;
import io.netty.util.ReferenceCountUtil;

class HttpClientFactoryTest {
@RegisterExtension
public static final ServerExtension server = new ServerExtension() {
Expand Down Expand Up @@ -103,4 +121,61 @@ protected HttpResponse doGet(ServiceRequestContext ctx,
});
}
}

@Test
void execute_dnsTimeout_clientRequestContext_isTimedOut() {
try (TestDnsServer dnsServer = new TestDnsServer(ImmutableMap.of(), new AlwaysTimeoutHandler())) {
try (RefreshingAddressResolverGroup group = dnsTimeoutBuilder(dnsServer)
.build(CommonPools.workerGroup().next())) {
final ClientFactory clientFactory = ClientFactory
.builder()
.addressResolverGroupFactory(eventExecutors -> group)
.build();
final Endpoint endpoint = Endpoint
.of("test")
.withIpAddr(null); // to invoke dns resolve address
final WebClient client = WebClient
.builder(endpoint.toUri(SessionProtocol.H1C))
.factory(clientFactory)
.build();

try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) {
assertThatThrownBy(() -> client.get("/").aggregate().join())
.isInstanceOf(CompletionException.class)
.hasCauseInstanceOf(UnprocessedRequestException.class)
.hasRootCauseInstanceOf(DnsTimeoutException.class);
captor.get().whenResponseCancelled().join();
assertThat(captor.get().isTimedOut()).isTrue();
}

clientFactory.close();
endpoint.close();
}
}
}

private static class AlwaysTimeoutHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof DatagramDnsQuery) {
// Just release the msg and return so that the client request is timed out.
ReferenceCountUtil.safeRelease(msg);
return;
}
super.channelRead(ctx, msg);
}
}

private static DnsResolverGroupBuilder dnsTimeoutBuilder(TestDnsServer... servers) {
final DnsServerAddressStreamProvider dnsServerAddressStreamProvider =
hostname -> DnsServerAddresses.sequential(
Stream.of(servers).map(TestDnsServer::addr).collect(toImmutableList())).stream();
final DnsResolverGroupBuilder builder = new DnsResolverGroupBuilder()
.serverAddressStreamProvider(dnsServerAddressStreamProvider)
.meterRegistry(PrometheusMeterRegistries.newRegistry())
.resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY)
.traceEnabled(false)
.queryTimeoutMillis(1); // dns timeout
return builder;
}
}

0 comments on commit 88c4c18

Please sign in to comment.