Skip to content

Commit

Permalink
more cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
jrhee17 committed Dec 4, 2024
1 parent e9db243 commit b58d2cd
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ protected AbstractWebClientBuilder(URI uri) {
protected AbstractWebClientBuilder(SessionProtocol sessionProtocol, EndpointGroup endpointGroup,
@Nullable String path) {
this(null, null, new EndpointGroupExecutionFactory(
validateSessionProtocol(sessionProtocol),
requireNonNull(endpointGroup, "endpointGroup")), path);
sessionProtocol, requireNonNull(endpointGroup, "endpointGroup")), path);
}

/**
Expand All @@ -83,13 +82,9 @@ protected AbstractWebClientBuilder(Scheme scheme,
new EndpointGroupExecutionFactory(scheme.sessionProtocol(), endpointGroup), path);
}

/**
* Creates a new instance.
*/
AbstractWebClientBuilder(SerializationFormat serializationFormat,
RequestExecutionFactory executionFactory,
AbstractWebClientBuilder(RequestExecutionFactory executionFactory,
@Nullable String path) {
this(null, serializationFormat, executionFactory, path);
this(null, SerializationFormat.NONE, executionFactory, path);
}

/**
Expand Down Expand Up @@ -124,12 +119,6 @@ private static URI validateUri(URI uri) {
return URI.create(scheme.uriText() + uri.toString().substring(givenScheme.length()));
}

private static SessionProtocol validateSessionProtocol(SessionProtocol sessionProtocol) {
requireNonNull(sessionProtocol, "sessionProtocol");
validateScheme(sessionProtocol.uriText());
return sessionProtocol;
}

private static Scheme validateScheme(String scheme) {
final Scheme parsedScheme = Scheme.tryParse(scheme);
if (parsedScheme != null) {
Expand Down
11 changes: 10 additions & 1 deletion core/src/main/java/com/linecorp/armeria/client/Clients.java
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,23 @@ public static ClientBuilder builder(SessionProtocol protocol, EndpointGroup endp
endpointGroup, path);
}

/**
* Returns a new {@link ClientBuilder} that builds the client that connects to the specified
* {@link RequestExecutionFactory} with the specified {@link SerializationFormat} and {@code path}.
*/
public static ClientBuilder builder(SerializationFormat serializationFormat,
RequestExecutionFactory executionFactory) {
return new ClientBuilder(serializationFormat, executionFactory, null);
}

/**
* Returns a new {@link ClientBuilder} that builds the client that connects to the specified
* {@link RequestExecutionFactory} with the specified {@link SerializationFormat} and {@code path}.
*/
public static ClientBuilder builder(SerializationFormat serializationFormat,
RequestExecutionFactory executionFactory,
String path) {
return new ClientBuilder(serializationFormat, executionFactory, path);
return new ClientBuilder(serializationFormat, executionFactory, requireNonNull(path, "path"));
}

/**
Expand Down
18 changes: 17 additions & 1 deletion core/src/main/java/com/linecorp/armeria/client/WebClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ static WebClient of(String protocol, EndpointGroup endpointGroup) {
}

/**
* Returns a new {@link WebClient} that connects to the specified {@link EndpointGroup} with
* Returns a new {@link WebClient} that connects to the specified {@link RequestExecutionFactory} with
* the specified {@code protocol} using the default {@link ClientFactory} and the default
* {@link ClientOptions}.
*
Expand Down Expand Up @@ -160,6 +160,22 @@ static WebClient of(SessionProtocol protocol, EndpointGroup endpointGroup, Strin
return builder(protocol, endpointGroup, path).build();
}

/**
* Returns a new {@link WebClient} that connects to the specified {@link RequestExecutionFactory} with
* the specified {@link SessionProtocol} and {@code path} using the default {@link ClientFactory} and
* the default {@link ClientOptions}.
*
* @param executionFactory the server {@link RequestExecutionFactory}
* @param path the path to the endpoint
*
* @throws IllegalArgumentException if the {@code protocol} is not one of the values in
* {@link SessionProtocol#httpValues()} or
* {@link SessionProtocol#httpsValues()}.
*/
static WebClient of(RequestExecutionFactory executionFactory, String path) {
return builder(executionFactory, path).build();
}

/**
* Returns a new {@link WebClientBuilder} created without a base {@link URI}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import com.linecorp.armeria.client.endpoint.EndpointGroup;
import com.linecorp.armeria.client.redirect.RedirectConfig;
import com.linecorp.armeria.common.RequestId;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.SuccessFunction;
import com.linecorp.armeria.common.annotation.Nullable;
Expand Down Expand Up @@ -74,7 +73,7 @@ public final class WebClientBuilder extends AbstractWebClientBuilder {
* in {@link SessionProtocol}
*/
WebClientBuilder(RequestExecutionFactory executionFactory, @Nullable String path) {
super(SerializationFormat.NONE, executionFactory, path);
super(executionFactory, path);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

package com.linecorp.armeria.internal.client;

import static com.google.common.base.Preconditions.checkArgument;
import static com.linecorp.armeria.common.SessionProtocol.httpAndHttpsValues;

import com.google.common.base.MoreObjects;

import com.linecorp.armeria.client.ClientOptions;
import com.linecorp.armeria.client.RequestExecution;
import com.linecorp.armeria.client.RequestExecutionFactory;
Expand All @@ -33,6 +38,8 @@ public final class EndpointGroupExecutionFactory implements RequestExecutionFact
private final EndpointGroup endpointGroup;

public EndpointGroupExecutionFactory(SessionProtocol sessionProtocol, EndpointGroup endpointGroup) {
checkArgument(httpAndHttpsValues().contains(sessionProtocol),
"sessionProtocol: '%s' (expected: one of '%s'", sessionProtocol, httpAndHttpsValues());
this.sessionProtocol = sessionProtocol;
this.endpointGroup = endpointGroup;
}
Expand All @@ -55,4 +62,12 @@ public EndpointGroup endpointGroup() {
public SessionProtocol sessionProtocol() {
return sessionProtocol;
}

@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("sessionProtocol", sessionProtocol)
.add("endpointGroup", endpointGroup)
.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import com.google.common.collect.Lists;

Expand All @@ -33,6 +37,7 @@
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.RequestTarget;
import com.linecorp.armeria.common.RpcRequest;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.server.ServerBuilder;
Expand All @@ -43,7 +48,10 @@
import io.netty.channel.EventLoop;
import io.netty.util.concurrent.EventExecutor;

class CustomExecutionFactoryTest {
class ExecutionFactoryTest {

private static final ClientOptionValue<Long> CUSTOM_CLIENT_OPTION =
ClientOptions.MAX_RESPONSE_LENGTH.newValue(Long.MAX_VALUE);

@RegisterExtension
static ServerExtension server = new ServerExtension() {
Expand All @@ -54,21 +62,42 @@ protected void configure(ServerBuilder sb) throws Exception {
assert channel != null;
return HttpResponse.of(channel.id().asShortText());
});
sb.service("/webClient", (ctx, req) -> HttpResponse.of("/webClient"));
sb.service("/prefix/webClient", (ctx, req) -> HttpResponse.of("/prefix/webClient"));
}
};

@RegisterExtension
static EventLoopGroupExtension eventLoopGroup = new EventLoopGroupExtension(4);

private static RequestExecutionFactory executionFactory() {
return new RequestExecutionFactory() {
@Override
public RequestExecution prepare(HttpRequest httpRequest,
@Nullable RpcRequest rpcRequest, RequestTarget requestTarget,
RequestOptions requestOptions, ClientOptions clientOptions) {
final ClientRequestContext ctx =
ClientRequestContext
.builder(httpRequest, rpcRequest, requestTarget)
.requestOptions(requestOptions)
.endpointGroup(server.httpEndpoint())
.options(clientOptions)
.sessionProtocol(SessionProtocol.HTTP)
.build();
return RequestExecution.of(ctx, server.httpEndpoint());
}
};
}

@Test
void specifyEventLoops() {
final AtomicInteger counter = new AtomicInteger();
final ArrayList<EventExecutor> eventExecutors = Lists.newArrayList(eventLoopGroup.get().iterator());
final WebClient client = WebClient.of(new RequestExecutionFactory() {
@Override
public RequestExecution prepare(HttpRequest httpRequest,
@Nullable RpcRequest rpcRequest, RequestTarget requestTarget,
RequestOptions requestOptions, ClientOptions clientOptions) {
@Nullable RpcRequest rpcRequest, RequestTarget requestTarget,
RequestOptions requestOptions, ClientOptions clientOptions) {
final ClientRequestContext ctx =
ClientRequestContext
.builder(httpRequest, rpcRequest, requestTarget)
Expand All @@ -89,4 +118,27 @@ public RequestExecution prepare(HttpRequest httpRequest,
}
assertThat(channelIds).hasSize(4);
}

private static Stream<Arguments> webClientCompatArgs() {
final RequestExecutionFactory executionFactory = executionFactory();

return Stream.of(
Arguments.of(WebClient.of(executionFactory), "/webClient"),
Arguments.of(Clients.newDerivedClient(WebClient.of(executionFactory),
CUSTOM_CLIENT_OPTION), "/webClient"),
Arguments.of(WebClient.of(executionFactory, "/prefix"), "/prefix/webClient"),
Arguments.of(Clients.builder(SerializationFormat.NONE, executionFactory)
.build(WebClient.class), "/webClient"),
Arguments.of(Clients.builder(SerializationFormat.NONE, executionFactory, "/prefix")
.build(WebClient.class), "/prefix/webClient")
);
}

@ParameterizedTest
@MethodSource("webClientCompatArgs")
void webClientCompat(WebClient webClient, String expected) {
final AggregatedHttpResponse res = webClient.blocking().get("/webClient");
assertThat(res.status().code()).isEqualTo(200);
assertThat(res.contentUtf8()).isEqualTo(expected);
}
}

0 comments on commit b58d2cd

Please sign in to comment.