Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug where a connection may be not reused when using RetryingClient #5290

Merged
merged 11 commits into from
Nov 8, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ private static SessionProtocol desiredSessionProtocol(SessionProtocol protocol,
* e.g. {@code System.currentTimeMillis() * 1000}.
*/
public DefaultClientRequestContext(
EventLoop eventLoop, MeterRegistry meterRegistry, SessionProtocol sessionProtocol,
@Nullable EventLoop eventLoop, MeterRegistry meterRegistry, SessionProtocol sessionProtocol,
RequestId id, HttpMethod method, RequestTarget reqTarget,
ClientOptions options, @Nullable HttpRequest req, @Nullable RpcRequest rpcReq,
RequestOptions requestOptions, CancellationScheduler responseCancellationScheduler,
Expand Down Expand Up @@ -511,7 +511,6 @@ private DefaultClientRequestContext(DefaultClientRequestContext ctx,
// So we don't check the nullness of rpcRequest unlike request.
// See https://github.com/line/armeria/pull/3251 and https://github.com/line/armeria/issues/3248.

eventLoop = ctx.eventLoop().withoutContext();
options = ctx.options();
root = ctx.root();

Expand All @@ -531,6 +530,13 @@ private DefaultClientRequestContext(DefaultClientRequestContext ctx,

this.endpointGroup = endpointGroup;
updateEndpoint(endpoint);
// We don't need to acquire an EventLoop for the initial attempt because it's already acquired by
// the root context.
if (endpoint == null || ctx.endpoint() == endpoint && ctx.log.children().isEmpty()) {
eventLoop = ctx.eventLoop().withoutContext();
} else {
acquireEventLoop(endpoint);
}
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2023 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.client.retry;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static org.assertj.core.api.Assertions.assertThat;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.linecorp.armeria.client.BlockingWebClient;
import com.linecorp.armeria.client.ClientRequestContext;
import com.linecorp.armeria.client.ClientRequestContextCaptor;
import com.linecorp.armeria.client.Clients;
import com.linecorp.armeria.client.Endpoint;
import com.linecorp.armeria.client.WebClient;
import com.linecorp.armeria.client.endpoint.EndpointGroup;
import com.linecorp.armeria.common.AggregatedHttpResponse;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.logging.RequestLogAccess;
import com.linecorp.armeria.internal.testing.AnticipatedException;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

import io.netty.channel.EventLoop;

class RetryingClientEventLoopSchedulerTest {

@RegisterExtension
static final ServerExtension server = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) {
sb.http(0);
sb.http(0);
sb.http(0);
sb.service("/fail", (ctx, req) -> {
throw new AnticipatedException();
});
sb.service("/ok", (ctx, req) -> {
return HttpResponse.of(200);
});
}
};

@Test
void shouldReturnCorrectEventLoop() {
final List<Endpoint> endpoints = server.server().activePorts().values().stream()
.map(port -> Endpoint.of(port.localAddress()))
.collect(toImmutableList());
assertThat(endpoints).hasSize(3);
final Map<Endpoint, EventLoop> eventLoopMapping = new HashMap<>();

for (Endpoint endpoint : endpoints) {
// Acquire the event loops for each endpoint.
try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) {
final AggregatedHttpResponse res = WebClient.of(SessionProtocol.H2C, endpoint)
.blocking()
.get("/ok");
assertThat(res.status()).isEqualTo(HttpStatus.OK);
eventLoopMapping.put(endpoint, captor.get().eventLoop().withoutContext());
}
}

// Check that the event loops are correctly mapped for each attempt.
final EndpointGroup endpointGroup = EndpointGroup.of(endpoints);
final RetryRule retryRule = RetryRule.builder()
.onServerErrorStatus()
.thenBackoff(Backoff.withoutDelay());
final BlockingWebClient client =
WebClient.builder(SessionProtocol.H2C, endpointGroup)
// Make retries until the maxTotalAttempts is reached.
.responseTimeoutMillis(0)
.decorator(RetryingClient.newDecorator(
RetryConfig.builder(retryRule)
.maxTotalAttempts(6)
.build()))
.build()
.blocking();
try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) {
assertThat(client.get("/fail").status()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR);
final List<RequestLogAccess> children = captor.get().log().children();
assertThat(children.size()).isEqualTo(6);
for (int i = 0; i < 6; i++) {
final ClientRequestContext childCtx = (ClientRequestContext) children.get(i).context();
assertThat(childCtx.eventLoop().withoutContext())
.isSameAs(eventLoopMapping.get(childCtx.endpoint()));
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright 2023 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.internal.client;

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import com.linecorp.armeria.client.ClientOptions;
import com.linecorp.armeria.client.ClientRequestContext;
import com.linecorp.armeria.client.Endpoint;
import com.linecorp.armeria.client.RequestOptions;
import com.linecorp.armeria.client.endpoint.DynamicEndpointGroup;
import com.linecorp.armeria.client.endpoint.EndpointSelectionStrategy;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.RequestId;
import com.linecorp.armeria.common.RequestTarget;
import com.linecorp.armeria.common.SessionProtocol;

import io.micrometer.core.instrument.simple.SimpleMeterRegistry;

class DerivedClientRequestContextClientTest {

private final Endpoint endpointA = Endpoint.of("a.com", 8080);
private final Endpoint endpointB = Endpoint.of("b.com", 8080);
private final Endpoint endpointC = Endpoint.of("c.com", 8080);
private SettableEndpointGroup group;

@BeforeEach
void setUp() {
group = new SettableEndpointGroup();
group.add(endpointA);
group.add(endpointB);
group.add(endpointC);
}

@Test
void shouldAcquireNewEventLoopForNewEndpoint() {
final HttpRequest request = HttpRequest.of(HttpMethod.GET, "/");
final DefaultClientRequestContext parent = new DefaultClientRequestContext(
new SimpleMeterRegistry(), SessionProtocol.H2C, RequestId.random(), HttpMethod.GET,
RequestTarget.forClient("/"), ClientOptions.of(), request, null, RequestOptions.of(), 0, 0);
parent.init(group);
assertThat(parent.endpoint()).isEqualTo(endpointA);
final ClientRequestContext child =
ClientUtil.newDerivedContext(parent, request, null, false);
assertThat(child.endpoint()).isEqualTo(endpointB);
assertThat(parent.endpoint()).isNotSameAs(child.endpoint());
assertThat(parent.eventLoop().withoutContext()).isNotSameAs(child.eventLoop().withoutContext());
}

@Test
void shouldAcquireSameEventLoopForSameEndpoint() {
final HttpRequest request = HttpRequest.of(HttpMethod.GET, "/");
final DefaultClientRequestContext parent = new DefaultClientRequestContext(
new SimpleMeterRegistry(), SessionProtocol.H2C, RequestId.random(), HttpMethod.GET,
RequestTarget.forClient("/"), ClientOptions.of(), request, null, RequestOptions.of(), 0, 0);
parent.init(group);
assertThat(parent.endpoint()).isEqualTo(endpointA);
final ClientRequestContext childA0 =
ClientUtil.newDerivedContext(parent, HttpRequest.of(HttpMethod.GET, "/"), null, true);
assertThat(childA0.endpoint()).isEqualTo(endpointA);
final ClientRequestContext childB0 =
ClientUtil.newDerivedContext(parent, HttpRequest.of(HttpMethod.GET, "/"), null, false);
assertThat(childB0.endpoint()).isEqualTo(endpointB);
final ClientRequestContext childC0 =
ClientUtil.newDerivedContext(parent, HttpRequest.of(HttpMethod.GET, "/"), null, false);
assertThat(childC0.endpoint()).isEqualTo(endpointC);

for (int i = 0; i < 3; i++) {
final ClientRequestContext childA1 =
ClientUtil.newDerivedContext(parent, HttpRequest.of(HttpMethod.GET, "/"), null, false);
assertThat(childA1.endpoint()).isEqualTo(endpointA);
assertThat(childA1.eventLoop().withoutContext()).isSameAs(childA0.eventLoop().withoutContext());
final ClientRequestContext childB1 =
ClientUtil.newDerivedContext(parent, HttpRequest.of(HttpMethod.GET, "/"), null, false);
assertThat(childB1.endpoint()).isEqualTo(endpointB);
assertThat(childB1.eventLoop().withoutContext()).isSameAs(childB0.eventLoop().withoutContext());
final ClientRequestContext childC1 =
ClientUtil.newDerivedContext(parent, HttpRequest.of(HttpMethod.GET, "/"), null, false);
assertThat(childC1.endpoint()).isEqualTo(endpointC);
assertThat(childC1.eventLoop().withoutContext()).isSameAs(childC0.eventLoop().withoutContext());
}
}

@Test
void shouldNotAcquireNewEventLoopForInitialAttempt() {
final HttpRequest request = HttpRequest.of(HttpMethod.GET, "/");
final DefaultClientRequestContext parent = new DefaultClientRequestContext(
new SimpleMeterRegistry(), SessionProtocol.H2C, RequestId.random(), HttpMethod.GET,
RequestTarget.forClient("/"), ClientOptions.of(), request, null, RequestOptions.of(), 0, 0);
parent.init(group);
assertThat(parent.endpoint()).isEqualTo(endpointA);
final ClientRequestContext child =
ClientUtil.newDerivedContext(parent, HttpRequest.of(HttpMethod.GET, "/"), null, true);
assertThat(child.endpoint()).isEqualTo(endpointA);
assertThat(parent.endpoint()).isSameAs(child.endpoint());
assertThat(parent.eventLoop().withoutContext()).isSameAs(child.eventLoop().withoutContext());
}

private static class SettableEndpointGroup extends DynamicEndpointGroup {

SettableEndpointGroup() {
super(EndpointSelectionStrategy.roundRobin());
}

void add(Endpoint endpoint) {
addEndpoint(endpoint);
}
}
}