diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/THttpServiceBuilder.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/THttpServiceBuilder.java index f5e27a5935d..39f3acf156d 100644 --- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/THttpServiceBuilder.java +++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/THttpServiceBuilder.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.Executors; import java.util.function.BiFunction; import java.util.function.Function; @@ -84,6 +85,7 @@ public final class THttpServiceBuilder { // -1 means to use the default request length of the Server. private int maxRequestStringLength = -1; private int maxRequestContainerLength = -1; + private boolean useBlockingTaskExecutor; THttpServiceBuilder() {} @@ -190,6 +192,17 @@ public THttpServiceBuilder maxRequestContainerLength(int maxRequestContainerLeng return this; } + /** + * Sets whether the service executes service methods using the blocking executor. By default, service + * methods are executed directly on the event loop for implementing fully asynchronous services. If your + * service uses blocking logic, you should either execute such logic in a separate thread using something + * like {@link Executors#newCachedThreadPool()} or enable this setting. + */ + public THttpServiceBuilder useBlockingTaskExecutor(boolean useBlockingTaskExecutor) { + this.useBlockingTaskExecutor = useBlockingTaskExecutor; + return this; + } + /** * Sets the {@link BiFunction} that returns an {@link RpcResponse} using the given {@link Throwable} * and {@link ServiceRequestContext}. @@ -225,10 +238,11 @@ private RpcService decorate(RpcService service) { * Builds a new instance of {@link THttpService}. */ public THttpService build() { - @SuppressWarnings("UnstableApiUsage") final Map> implementations = Multimaps.asMap(implementationsBuilder.build()); - - final ThriftCallService tcs = ThriftCallService.of(implementations); + final ThriftCallService tcs = new ThriftCallServiceBuilder() + .addServices(implementations) + .useBlockingTaskExecutor(useBlockingTaskExecutor) + .build(); return build0(tcs); } @@ -244,7 +258,9 @@ private THttpService build0(RpcService tcs) { builder.add(defaultSerializationFormat); builder.addAll(otherSerializationFormats); - return new THttpService(decorate(tcs), defaultSerializationFormat, builder.build(), - maxRequestStringLength, maxRequestContainerLength, exceptionHandler); + return new THttpService( + decorate(tcs), defaultSerializationFormat, builder.build(), + maxRequestStringLength, maxRequestContainerLength, exceptionHandler + ); } } diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallService.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallService.java index b3cd8d1c1dc..76241c09f17 100644 --- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallService.java +++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallService.java @@ -16,7 +16,7 @@ package com.linecorp.armeria.server.thrift; -import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; import java.util.List; @@ -31,14 +31,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; - import com.linecorp.armeria.common.CompletableRpcResponse; import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.RpcRequest; import com.linecorp.armeria.common.RpcResponse; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.internal.common.thrift.ThriftFunction; import com.linecorp.armeria.server.RpcService; import com.linecorp.armeria.server.ServiceRequestContext; @@ -70,7 +68,7 @@ public void onError(Exception e) { */ public static ThriftCallService of(Object implementation) { requireNonNull(implementation, "implementation"); - return new ThriftCallService(ImmutableMap.of("", ImmutableList.of(implementation))); + return builder().addService(implementation).build(); } /** @@ -82,19 +80,27 @@ public static ThriftCallService of(Object implementation) { */ public static ThriftCallService of(Map> implementations) { requireNonNull(implementations, "implementations"); - return new ThriftCallService(implementations); + checkArgument(!implementations.isEmpty(), "implementations is empty"); + + return builder().addServices(implementations).build(); + } + + /** + * Creates a new instance of {@link ThriftCallServiceBuilder} which can build + * an instance of {@link ThriftCallService} fluently. + */ + @UnstableApi + public static ThriftCallServiceBuilder builder() { + return new ThriftCallServiceBuilder(); } private final Map entries; - private ThriftCallService(Map> implementations) { - requireNonNull(implementations, "implementations"); - if (implementations.isEmpty()) { - throw new IllegalArgumentException("empty implementations"); - } + private final boolean useBlockingTaskExecutor; - entries = implementations.entrySet().stream().collect( - toImmutableMap(Map.Entry::getKey, ThriftServiceEntry::new)); + ThriftCallService(Map entries, boolean useBlockingTaskExecutor) { + this.entries = entries; + this.useBlockingTaskExecutor = useBlockingTaskExecutor; } /** @@ -140,14 +146,24 @@ public RpcResponse serve(ServiceRequestContext ctx, RpcRequest call) throws Exce TApplicationException.UNKNOWN_METHOD, "unknown method: " + call.method())); } - private static void invoke( + private void invoke( ServiceRequestContext ctx, Object impl, ThriftFunction func, List args, CompletableRpcResponse reply) { try { final TBase tArgs = func.newArgs(args); if (func.isAsync()) { - invokeAsynchronously(impl, func, tArgs, reply); + if (useBlockingTaskExecutor) { + ctx.blockingTaskExecutor().execute(() -> { + try { + invokeAsynchronously(impl, func, tArgs, reply); + } catch (Throwable t) { + reply.completeExceptionally(t); + } + }); + } else { + invokeAsynchronously(impl, func, tArgs, reply); + } } else { invokeSynchronously(ctx, impl, func, tArgs, reply); } diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilder.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilder.java new file mode 100644 index 00000000000..ff2df8be448 --- /dev/null +++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilder.java @@ -0,0 +1,153 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.thrift; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +import java.util.Map; +import java.util.concurrent.Executors; + +import com.google.common.collect.ImmutableListMultimap; + +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * A fluent builder to build an instance of {@link ThriftCallService}. + * + *

Example

+ *
{@code
+ * ThriftCallService service = ThriftCallService
+ *                 .builder()
+ *                 .addService(defaultServiceImpl) // Adds a service
+ *                 .addService("foo", fooServiceImpl) // Adds a service with a key
+ *                 .addService("foobar", fooServiceImpl)  // Adds multiple services to the same key
+ *                 .addService("foobar", barServiceImpl)
+ *                  // Adds multiple services at once
+ *                 .addServices("foobarOnce", fooServiceImpl, barServiceImpl)
+ *                  // Adds multiple services by list
+ *                 .addServices("foobarList", ImmutableList.of(fooServiceImpl, barServiceImpl))
+ *                  // Adds multiple services by map
+ *                 .addServices(ImmutableMap.of("fooMap", fooServiceImpl, "barMap", barServiceImpl))
+ *                  // Adds multiple services by map
+ *                 .addServices(ImmutableMap.of("fooIterableMap",
+ *                                              ImmutableList.of(fooServiceImpl, barServiceImpl)))
+ *                 .build();
+ * }
+ * + * @see ThriftCallService + */ +@UnstableApi +public final class ThriftCallServiceBuilder { + private final ImmutableListMultimap.Builder servicesBuilder = + ImmutableListMultimap.builder(); + + private boolean useBlockingTaskExecutor; + + ThriftCallServiceBuilder() {} + + /** + * Adds a service for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addService(Object service) { + requireNonNull(service, "service"); + servicesBuilder.put("", service); + return this; + } + + /** + * Adds a service with a key for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addService(String key, Object service) { + requireNonNull(key, "key"); + requireNonNull(service, "service"); + servicesBuilder.put(key, service); + return this; + } + + /** + * Adds a service for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addServices(Object... services) { + requireNonNull(services, "services"); + checkArgument(services.length != 0, "service should not be empty"); + servicesBuilder.putAll("", services); + return this; + } + + /** + * Adds a service with a key for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addServices(String key, Object... services) { + requireNonNull(key, "key"); + requireNonNull(services, "service"); + checkArgument(services.length != 0, "service should not be empty"); + servicesBuilder.putAll(key, services); + return this; + } + + /** + * Adds services with key by iterable for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addServices(String key, Iterable services) { + requireNonNull(key, "key"); + requireNonNull(services, "services"); + checkArgument(services.iterator().hasNext(), "service should not be empty"); + servicesBuilder.putAll(key, services); + return this; + } + + /** + * Adds multiple services by map for {@link ThriftServiceEntry}. + */ + public ThriftCallServiceBuilder addServices(Map services) { + requireNonNull(services, "services"); + checkArgument(!services.isEmpty(), "service should not be empty"); + + services.forEach((k, v) -> { + if (v instanceof Iterable) { + servicesBuilder.putAll(k, (Iterable) v); + } else { + servicesBuilder.put(k, v); + } + }); + return this; + } + + /** + * Sets whether the service executes service methods using the blocking executor. By default, service + * methods are executed directly on the event loop for implementing fully asynchronous services. If your + * service uses blocking logic, you should either execute such logic in a separate thread using something + * like {@link Executors#newCachedThreadPool()} or enable this setting. + */ + public ThriftCallServiceBuilder useBlockingTaskExecutor(boolean useBlockingTaskExecutor) { + this.useBlockingTaskExecutor = useBlockingTaskExecutor; + return this; + } + + /** + * Builds a new instance of {@link ThriftCallService}. + */ + public ThriftCallService build() { + return new ThriftCallService( + servicesBuilder.build().asMap().entrySet().stream().collect( + toImmutableMap(Map.Entry::getKey, ThriftServiceEntry::new)), + useBlockingTaskExecutor + ); + } +} diff --git a/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/THttpServiceBlockingTest.java b/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/THttpServiceBlockingTest.java new file mode 100644 index 00000000000..5c53eca17f7 --- /dev/null +++ b/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/THttpServiceBlockingTest.java @@ -0,0 +1,127 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.thrift; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.thrift.TException; +import org.apache.thrift.async.AsyncMethodCallback; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.linecorp.armeria.client.thrift.ThriftClients; +import com.linecorp.armeria.common.util.ThreadFactories; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import testing.thrift.main.HelloService; + +class THttpServiceBlockingTest { + private static final AtomicReference currentThreadName = new AtomicReference<>(""); + + private static final String BLOCKING_EXECUTOR_PREFIX = "blocking-test"; + private static final ScheduledExecutorService executor = + new ScheduledThreadPoolExecutor(1, + ThreadFactories.newThreadFactory(BLOCKING_EXECUTOR_PREFIX, true)); + + @BeforeEach + void clearDetector() { + currentThreadName.set(""); + } + + @AfterAll + public static void shutdownExecutor() { + executor.shutdown(); + } + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) throws Exception { + + sb.service("/", + THttpService.builder().addService(new HelloServiceAsyncImpl()).build()); + sb.service("/blocking", THttpService.builder() + .useBlockingTaskExecutor(true) + .addService(new HelloServiceAsyncImpl()) + .build()); + sb.service("/blocking-iface", + THttpService.builder().addService(new HelloServiceImpl()).build()); + + sb.blockingTaskExecutor(executor, true); + } + }; + + @Test + void nonBlocking() throws Exception { + final HelloService.Iface client = ThriftClients.newClient(server.httpUri(), HelloService.Iface.class); + + final String message = "nonBlockingTest"; + final String response = client.hello(message); + + assertThat(response).isEqualTo(message); + assertThat(currentThreadName.get().startsWith(BLOCKING_EXECUTOR_PREFIX)).isFalse(); + } + + @Test + void blocking() throws Exception { + final HelloService.Iface client = + ThriftClients.builder(server.httpUri()) + .path("/blocking") + .build(HelloService.Iface.class); + final String message = "blockingTest"; + final String response = client.hello(message); + + assertThat(response).isEqualTo(message); + assertThat(currentThreadName.get().startsWith(BLOCKING_EXECUTOR_PREFIX)).isTrue(); + } + + @Test + void blockingIface() throws Exception { + final HelloService.Iface client = + ThriftClients.builder(server.httpUri()) + .path("/blocking-iface") + .build(HelloService.Iface.class); + final String message = "blockingTest"; + final String response = client.hello(message); + + assertThat(response).isEqualTo(message); + assertThat(currentThreadName.get().startsWith(BLOCKING_EXECUTOR_PREFIX)).isTrue(); + } + + static class HelloServiceAsyncImpl implements HelloService.AsyncIface { + @Override + public void hello(String name, AsyncMethodCallback resultHandler) throws TException { + currentThreadName.set(Thread.currentThread().getName()); + resultHandler.onComplete(name); + } + } + + static class HelloServiceImpl implements HelloService.Iface { + @Override + public String hello(String name) throws TException { + currentThreadName.set(Thread.currentThread().getName()); + return name; + } + } +} diff --git a/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilderTest.java b/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilderTest.java new file mode 100644 index 00000000000..419629ebfe4 --- /dev/null +++ b/thrift/thrift0.13/src/test/java/com/linecorp/armeria/server/thrift/ThriftCallServiceBuilderTest.java @@ -0,0 +1,98 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server.thrift; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import testing.thrift.main.FooService.AsyncIface; + +/** + * Test for {@link ThriftCallServiceBuilder}. + */ +class ThriftCallServiceBuilderTest { + @Test + void nullAndEmptyCases() { + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addService(null) + ); + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addService("", null) + ); + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addServices("", (Object) null) + ); + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addServices("", null, null) + ); + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addServices("", null, null, null) + ); + assertThrows(NullPointerException.class, () -> + ThriftCallService.builder().addServices("", (Iterable) null) + ); + assertThrows(IllegalArgumentException.class, () -> + ThriftCallService.builder().addServices("", new ArrayList<>()) + ); + } + + @Test + void testBuilder() { + final AsyncIface defaultServiceImpl = mock(AsyncIface.class); + final AsyncIface fooServiceImpl = mock(AsyncIface.class); + final AsyncIface barServiceImpl = mock(AsyncIface.class); + final ThriftCallService service = ThriftCallService + .builder() + .addService(defaultServiceImpl) + .addService("foo", fooServiceImpl) + .addService("foobar", fooServiceImpl) + .addService("foobar", barServiceImpl) + .addServices("foobarOnce", fooServiceImpl, barServiceImpl) + .addServices("foobarList", ImmutableList.of(fooServiceImpl, barServiceImpl)) + .addServices(ImmutableMap.of("fooMap", fooServiceImpl, "barMap", barServiceImpl)) + .addServices(ImmutableMap.of("fooIterableMap", + ImmutableList.of(fooServiceImpl, barServiceImpl))) + .build(); + final Map> actualEntries = + service.entries().entrySet().stream() + .collect(ImmutableMap.toImmutableMap( + Map.Entry::getKey, + e -> ImmutableList.copyOf(e.getValue().implementations))); + + final Map> expectedEntries = ImmutableMap.of( + "", ImmutableList.of(defaultServiceImpl), + "foo", ImmutableList.of(fooServiceImpl), + "foobar", ImmutableList.of(fooServiceImpl, barServiceImpl), + "foobarOnce", ImmutableList.of(fooServiceImpl, barServiceImpl), + "foobarList", ImmutableList.of(fooServiceImpl, barServiceImpl), + "fooMap", ImmutableList.of(fooServiceImpl), + "barMap", ImmutableList.of(barServiceImpl), + "fooIterableMap", ImmutableList.of(fooServiceImpl, barServiceImpl)); + + assertThat(actualEntries).isEqualTo(expectedEntries); + } +}