From b58d2cdae81ce3b465c18ef9fb7d1f84d4e2f0b2 Mon Sep 17 00:00:00 2001 From: jrhee17 Date: Wed, 4 Dec 2024 22:36:20 +0900 Subject: [PATCH] more cleanups --- .../client/AbstractWebClientBuilder.java | 17 +----- .../com/linecorp/armeria/client/Clients.java | 11 +++- .../linecorp/armeria/client/WebClient.java | 18 +++++- .../armeria/client/WebClientBuilder.java | 3 +- .../client/EndpointGroupExecutionFactory.java | 15 +++++ ...oryTest.java => ExecutionFactoryTest.java} | 58 ++++++++++++++++++- 6 files changed, 101 insertions(+), 21 deletions(-) rename core/src/test/java/com/linecorp/armeria/client/{CustomExecutionFactoryTest.java => ExecutionFactoryTest.java} (54%) diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java b/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java index 45b9b375984..2a2945845d6 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractWebClientBuilder.java @@ -69,8 +69,7 @@ protected AbstractWebClientBuilder(URI uri) { protected AbstractWebClientBuilder(SessionProtocol sessionProtocol, EndpointGroup endpointGroup, @Nullable String path) { this(null, null, new EndpointGroupExecutionFactory( - validateSessionProtocol(sessionProtocol), - requireNonNull(endpointGroup, "endpointGroup")), path); + sessionProtocol, requireNonNull(endpointGroup, "endpointGroup")), path); } /** @@ -83,13 +82,9 @@ protected AbstractWebClientBuilder(Scheme scheme, new EndpointGroupExecutionFactory(scheme.sessionProtocol(), endpointGroup), path); } - /** - * Creates a new instance. - */ - AbstractWebClientBuilder(SerializationFormat serializationFormat, - RequestExecutionFactory executionFactory, + AbstractWebClientBuilder(RequestExecutionFactory executionFactory, @Nullable String path) { - this(null, serializationFormat, executionFactory, path); + this(null, SerializationFormat.NONE, executionFactory, path); } /** @@ -124,12 +119,6 @@ private static URI validateUri(URI uri) { return URI.create(scheme.uriText() + uri.toString().substring(givenScheme.length())); } - private static SessionProtocol validateSessionProtocol(SessionProtocol sessionProtocol) { - requireNonNull(sessionProtocol, "sessionProtocol"); - validateScheme(sessionProtocol.uriText()); - return sessionProtocol; - } - private static Scheme validateScheme(String scheme) { final Scheme parsedScheme = Scheme.tryParse(scheme); if (parsedScheme != null) { diff --git a/core/src/main/java/com/linecorp/armeria/client/Clients.java b/core/src/main/java/com/linecorp/armeria/client/Clients.java index bf5d4792e1e..f13115a4771 100644 --- a/core/src/main/java/com/linecorp/armeria/client/Clients.java +++ b/core/src/main/java/com/linecorp/armeria/client/Clients.java @@ -231,6 +231,15 @@ public static ClientBuilder builder(SessionProtocol protocol, EndpointGroup endp endpointGroup, path); } + /** + * Returns a new {@link ClientBuilder} that builds the client that connects to the specified + * {@link RequestExecutionFactory} with the specified {@link SerializationFormat} and {@code path}. + */ + public static ClientBuilder builder(SerializationFormat serializationFormat, + RequestExecutionFactory executionFactory) { + return new ClientBuilder(serializationFormat, executionFactory, null); + } + /** * Returns a new {@link ClientBuilder} that builds the client that connects to the specified * {@link RequestExecutionFactory} with the specified {@link SerializationFormat} and {@code path}. @@ -238,7 +247,7 @@ public static ClientBuilder builder(SessionProtocol protocol, EndpointGroup endp public static ClientBuilder builder(SerializationFormat serializationFormat, RequestExecutionFactory executionFactory, String path) { - return new ClientBuilder(serializationFormat, executionFactory, path); + return new ClientBuilder(serializationFormat, executionFactory, requireNonNull(path, "path")); } /** diff --git a/core/src/main/java/com/linecorp/armeria/client/WebClient.java b/core/src/main/java/com/linecorp/armeria/client/WebClient.java index c5553ad6295..ecb6148aa2d 100644 --- a/core/src/main/java/com/linecorp/armeria/client/WebClient.java +++ b/core/src/main/java/com/linecorp/armeria/client/WebClient.java @@ -96,7 +96,7 @@ static WebClient of(String protocol, EndpointGroup endpointGroup) { } /** - * Returns a new {@link WebClient} that connects to the specified {@link EndpointGroup} with + * Returns a new {@link WebClient} that connects to the specified {@link RequestExecutionFactory} with * the specified {@code protocol} using the default {@link ClientFactory} and the default * {@link ClientOptions}. * @@ -160,6 +160,22 @@ static WebClient of(SessionProtocol protocol, EndpointGroup endpointGroup, Strin return builder(protocol, endpointGroup, path).build(); } + /** + * Returns a new {@link WebClient} that connects to the specified {@link RequestExecutionFactory} with + * the specified {@link SessionProtocol} and {@code path} using the default {@link ClientFactory} and + * the default {@link ClientOptions}. + * + * @param executionFactory the server {@link RequestExecutionFactory} + * @param path the path to the endpoint + * + * @throws IllegalArgumentException if the {@code protocol} is not one of the values in + * {@link SessionProtocol#httpValues()} or + * {@link SessionProtocol#httpsValues()}. + */ + static WebClient of(RequestExecutionFactory executionFactory, String path) { + return builder(executionFactory, path).build(); + } + /** * Returns a new {@link WebClientBuilder} created without a base {@link URI}. */ diff --git a/core/src/main/java/com/linecorp/armeria/client/WebClientBuilder.java b/core/src/main/java/com/linecorp/armeria/client/WebClientBuilder.java index d5f26787dba..a6baa8f3361 100644 --- a/core/src/main/java/com/linecorp/armeria/client/WebClientBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/WebClientBuilder.java @@ -26,7 +26,6 @@ import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.client.redirect.RedirectConfig; import com.linecorp.armeria.common.RequestId; -import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.SuccessFunction; import com.linecorp.armeria.common.annotation.Nullable; @@ -74,7 +73,7 @@ public final class WebClientBuilder extends AbstractWebClientBuilder { * in {@link SessionProtocol} */ WebClientBuilder(RequestExecutionFactory executionFactory, @Nullable String path) { - super(SerializationFormat.NONE, executionFactory, path); + super(executionFactory, path); } /** diff --git a/core/src/main/java/com/linecorp/armeria/internal/client/EndpointGroupExecutionFactory.java b/core/src/main/java/com/linecorp/armeria/internal/client/EndpointGroupExecutionFactory.java index 435b841089e..ed677755b0c 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/client/EndpointGroupExecutionFactory.java +++ b/core/src/main/java/com/linecorp/armeria/internal/client/EndpointGroupExecutionFactory.java @@ -16,6 +16,11 @@ package com.linecorp.armeria.internal.client; +import static com.google.common.base.Preconditions.checkArgument; +import static com.linecorp.armeria.common.SessionProtocol.httpAndHttpsValues; + +import com.google.common.base.MoreObjects; + import com.linecorp.armeria.client.ClientOptions; import com.linecorp.armeria.client.RequestExecution; import com.linecorp.armeria.client.RequestExecutionFactory; @@ -33,6 +38,8 @@ public final class EndpointGroupExecutionFactory implements RequestExecutionFact private final EndpointGroup endpointGroup; public EndpointGroupExecutionFactory(SessionProtocol sessionProtocol, EndpointGroup endpointGroup) { + checkArgument(httpAndHttpsValues().contains(sessionProtocol), + "sessionProtocol: '%s' (expected: one of '%s'", sessionProtocol, httpAndHttpsValues()); this.sessionProtocol = sessionProtocol; this.endpointGroup = endpointGroup; } @@ -55,4 +62,12 @@ public EndpointGroup endpointGroup() { public SessionProtocol sessionProtocol() { return sessionProtocol; } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("sessionProtocol", sessionProtocol) + .add("endpointGroup", endpointGroup) + .toString(); + } } diff --git a/core/src/test/java/com/linecorp/armeria/client/CustomExecutionFactoryTest.java b/core/src/test/java/com/linecorp/armeria/client/ExecutionFactoryTest.java similarity index 54% rename from core/src/test/java/com/linecorp/armeria/client/CustomExecutionFactoryTest.java rename to core/src/test/java/com/linecorp/armeria/client/ExecutionFactoryTest.java index be64ebc2376..72238a19730 100644 --- a/core/src/test/java/com/linecorp/armeria/client/CustomExecutionFactoryTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/ExecutionFactoryTest.java @@ -22,9 +22,13 @@ import java.util.HashSet; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import com.google.common.collect.Lists; @@ -33,6 +37,7 @@ import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.RequestTarget; import com.linecorp.armeria.common.RpcRequest; +import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.server.ServerBuilder; @@ -43,7 +48,10 @@ import io.netty.channel.EventLoop; import io.netty.util.concurrent.EventExecutor; -class CustomExecutionFactoryTest { +class ExecutionFactoryTest { + + private static final ClientOptionValue CUSTOM_CLIENT_OPTION = + ClientOptions.MAX_RESPONSE_LENGTH.newValue(Long.MAX_VALUE); @RegisterExtension static ServerExtension server = new ServerExtension() { @@ -54,12 +62,33 @@ protected void configure(ServerBuilder sb) throws Exception { assert channel != null; return HttpResponse.of(channel.id().asShortText()); }); + sb.service("/webClient", (ctx, req) -> HttpResponse.of("/webClient")); + sb.service("/prefix/webClient", (ctx, req) -> HttpResponse.of("/prefix/webClient")); } }; @RegisterExtension static EventLoopGroupExtension eventLoopGroup = new EventLoopGroupExtension(4); + private static RequestExecutionFactory executionFactory() { + return new RequestExecutionFactory() { + @Override + public RequestExecution prepare(HttpRequest httpRequest, + @Nullable RpcRequest rpcRequest, RequestTarget requestTarget, + RequestOptions requestOptions, ClientOptions clientOptions) { + final ClientRequestContext ctx = + ClientRequestContext + .builder(httpRequest, rpcRequest, requestTarget) + .requestOptions(requestOptions) + .endpointGroup(server.httpEndpoint()) + .options(clientOptions) + .sessionProtocol(SessionProtocol.HTTP) + .build(); + return RequestExecution.of(ctx, server.httpEndpoint()); + } + }; + } + @Test void specifyEventLoops() { final AtomicInteger counter = new AtomicInteger(); @@ -67,8 +96,8 @@ void specifyEventLoops() { final WebClient client = WebClient.of(new RequestExecutionFactory() { @Override public RequestExecution prepare(HttpRequest httpRequest, - @Nullable RpcRequest rpcRequest, RequestTarget requestTarget, - RequestOptions requestOptions, ClientOptions clientOptions) { + @Nullable RpcRequest rpcRequest, RequestTarget requestTarget, + RequestOptions requestOptions, ClientOptions clientOptions) { final ClientRequestContext ctx = ClientRequestContext .builder(httpRequest, rpcRequest, requestTarget) @@ -89,4 +118,27 @@ public RequestExecution prepare(HttpRequest httpRequest, } assertThat(channelIds).hasSize(4); } + + private static Stream webClientCompatArgs() { + final RequestExecutionFactory executionFactory = executionFactory(); + + return Stream.of( + Arguments.of(WebClient.of(executionFactory), "/webClient"), + Arguments.of(Clients.newDerivedClient(WebClient.of(executionFactory), + CUSTOM_CLIENT_OPTION), "/webClient"), + Arguments.of(WebClient.of(executionFactory, "/prefix"), "/prefix/webClient"), + Arguments.of(Clients.builder(SerializationFormat.NONE, executionFactory) + .build(WebClient.class), "/webClient"), + Arguments.of(Clients.builder(SerializationFormat.NONE, executionFactory, "/prefix") + .build(WebClient.class), "/prefix/webClient") + ); + } + + @ParameterizedTest + @MethodSource("webClientCompatArgs") + void webClientCompat(WebClient webClient, String expected) { + final AggregatedHttpResponse res = webClient.blocking().get("/webClient"); + assertThat(res.status().code()).isEqualTo(200); + assertThat(res.contentUtf8()).isEqualTo(expected); + } }