From 7031f009934a19349dc353cb06e03d00fe29c02e Mon Sep 17 00:00:00 2001
From: jrhee17 <guins_j@guins.org>
Date: Tue, 12 Nov 2024 11:53:06 +0900
Subject: [PATCH] add thrift test case

---
 .../com/linecorp/armeria/client/Clients.java  |   2 +-
 it/xds-clients/build.gradle                   |   2 +-
 .../armeria/xds/GrpcIntegrationTest.java      |   2 +-
 .../armeria/xds/ThriftIntegrationTest.java    | 122 ++++++++++++++++++
 it/xds-clients/src/test/thrift/main.thrift    |   6 +
 .../client/thrift/ThriftClientBuilder.java    |  15 ++-
 .../armeria/client/thrift/ThriftClients.java  |  10 ++
 .../client/thrift/DefaultTHttpClient.java     |   1 -
 .../client/thrift/THttpClientFactory.java     |   2 +-
 9 files changed, 150 insertions(+), 12 deletions(-)
 create mode 100644 it/xds-clients/src/test/java/com/linecorp/armeria/xds/ThriftIntegrationTest.java
 create mode 100644 it/xds-clients/src/test/thrift/main.thrift

diff --git a/core/src/main/java/com/linecorp/armeria/client/Clients.java b/core/src/main/java/com/linecorp/armeria/client/Clients.java
index 4f0a2828a7d..79dc8deaded 100644
--- a/core/src/main/java/com/linecorp/armeria/client/Clients.java
+++ b/core/src/main/java/com/linecorp/armeria/client/Clients.java
@@ -77,7 +77,7 @@ public static <T> T newClient(URI uri, Class<T> clientType) {
      * {@code scheme} using the default {@link ClientFactory}.
      *
      * @param scheme the {@link Scheme} represented as a {@link String}
-     * @param endpointGroup the server {@link EndpointGroup}
+     * @param executionPreparation the server {@link EndpointGroup}
      * @param clientType the type of the new client
      *
      * @throws IllegalArgumentException if the specified {@code scheme} is invalid or
diff --git a/it/xds-clients/build.gradle b/it/xds-clients/build.gradle
index e6ca3094856..3df2737069a 100644
--- a/it/xds-clients/build.gradle
+++ b/it/xds-clients/build.gradle
@@ -2,5 +2,5 @@ dependencies {
     implementation project(':xds')
     implementation project(':xds-testing-internal')
     implementation project(':grpc')
-    implementation project(':thrift0.17')
+    implementation project(':thrift0.13')
 }
diff --git a/it/xds-clients/src/test/java/com/linecorp/armeria/xds/GrpcIntegrationTest.java b/it/xds-clients/src/test/java/com/linecorp/armeria/xds/GrpcIntegrationTest.java
index 2e09651621e..02da01d7b64 100644
--- a/it/xds-clients/src/test/java/com/linecorp/armeria/xds/GrpcIntegrationTest.java
+++ b/it/xds-clients/src/test/java/com/linecorp/armeria/xds/GrpcIntegrationTest.java
@@ -116,7 +116,7 @@ void simpleClient() {
         try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap);
              XdsExecutionPreparation preparation = XdsExecutionPreparation.of("listener", xdsBootstrap)) {
             TestServiceBlockingStub stub = GrpcClients.newClient(preparation,
-                                                                       TestServiceBlockingStub.class);
+                                                                 TestServiceBlockingStub.class);
             assertThat(stub.hello(HelloRequest.getDefaultInstance()).getMessage()).isEqualTo("Hello");
 
             stub = Clients.newDerivedClient(stub, ClientOptions.RESPONSE_TIMEOUT_MILLIS.newValue(10L));
diff --git a/it/xds-clients/src/test/java/com/linecorp/armeria/xds/ThriftIntegrationTest.java b/it/xds-clients/src/test/java/com/linecorp/armeria/xds/ThriftIntegrationTest.java
new file mode 100644
index 00000000000..b681d9b41e2
--- /dev/null
+++ b/it/xds-clients/src/test/java/com/linecorp/armeria/xds/ThriftIntegrationTest.java
@@ -0,0 +1,122 @@
+/*
+ * Copyright 2024 LINE Corporation
+ *
+ * LINE Corporation licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ *   https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package com.linecorp.armeria.xds;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.net.URI;
+
+import org.apache.thrift.TException;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import com.google.common.collect.ImmutableList;
+
+import com.linecorp.armeria.client.ClientOptions;
+import com.linecorp.armeria.client.Clients;
+import com.linecorp.armeria.client.thrift.ThriftClients;
+import com.linecorp.armeria.common.thrift.ThriftSerializationFormats;
+import com.linecorp.armeria.server.ServerBuilder;
+import com.linecorp.armeria.server.grpc.GrpcService;
+import com.linecorp.armeria.server.thrift.THttpService;
+import com.linecorp.armeria.testing.junit5.server.ServerExtension;
+import com.linecorp.armeria.xds.client.endpoint.XdsExecutionPreparation;
+import com.linecorp.armeria.xds.internal.XdsTestResources;
+
+import io.envoyproxy.controlplane.cache.v3.SimpleCache;
+import io.envoyproxy.controlplane.cache.v3.Snapshot;
+import io.envoyproxy.controlplane.server.V3DiscoveryServer;
+import io.envoyproxy.envoy.config.bootstrap.v3.Bootstrap;
+import io.envoyproxy.envoy.config.cluster.v3.Cluster;
+import io.envoyproxy.envoy.config.core.v3.ConfigSource;
+import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment;
+import io.envoyproxy.envoy.config.listener.v3.Listener;
+import testing.thrift.TestService;
+import testing.thrift.TestService.Iface;
+
+class ThriftIntegrationTest {
+
+    public static final String BOOTSTRAP_CLUSTER_NAME = "bootstrap-cluster";
+    private static final SimpleCache<String> cache = new SimpleCache<>(node -> "GROUP");
+
+    @RegisterExtension
+    static final ServerExtension server = new ServerExtension() {
+        @Override
+        protected void configure(ServerBuilder sb) throws Exception {
+            sb.service("/thrift", THttpService.of(new TestService.Iface() {
+
+                @Override
+                public String sayHello(String name) throws TException {
+                    return "World";
+                }
+            }));
+        }
+    };
+
+    @RegisterExtension
+    static final ServerExtension controlPlaneServer = new ServerExtension() {
+        @Override
+        protected void configure(ServerBuilder sb) {
+            final V3DiscoveryServer v3DiscoveryServer = new V3DiscoveryServer(cache);
+            sb.service(GrpcService.builder()
+                                  .addService(v3DiscoveryServer.getAggregatedDiscoveryServiceImpl())
+                                  .addService(v3DiscoveryServer.getListenerDiscoveryServiceImpl())
+                                  .addService(v3DiscoveryServer.getClusterDiscoveryServiceImpl())
+                                  .addService(v3DiscoveryServer.getRouteDiscoveryServiceImpl())
+                                  .addService(v3DiscoveryServer.getEndpointDiscoveryServiceImpl())
+                                  .build());
+            sb.tlsSelfSigned();
+            sb.http(0);
+            sb.https(0);
+        }
+    };
+
+    @BeforeEach
+    void beforeEach() {
+        final ClusterLoadAssignment loadAssignment =
+                XdsTestResources.loadAssignment("cluster", server.httpUri());
+        final Cluster httpCluster = XdsTestResources.createStaticCluster("cluster", loadAssignment);
+        final Listener httpListener = XdsTestResources.staticResourceListener();
+        cache.setSnapshot(
+                "GROUP",
+                Snapshot.create(ImmutableList.of(httpCluster), ImmutableList.of(),
+                                ImmutableList.of(httpListener), ImmutableList.of(), ImmutableList.of(), "1"));
+    }
+
+    @Test
+    void basicCase() throws Exception {
+        final ConfigSource configSource = XdsTestResources.basicConfigSource(BOOTSTRAP_CLUSTER_NAME);
+        final URI uri = controlPlaneServer.httpUri();
+        final ClusterLoadAssignment loadAssignment =
+                XdsTestResources.loadAssignment(BOOTSTRAP_CLUSTER_NAME,
+                                                uri.getHost(), uri.getPort());
+        final Cluster bootstrapCluster =
+                XdsTestResources.createStaticCluster(BOOTSTRAP_CLUSTER_NAME, loadAssignment);
+        final Bootstrap bootstrap = XdsTestResources.bootstrap(configSource, bootstrapCluster);
+        try (XdsBootstrap xdsBootstrap = XdsBootstrap.of(bootstrap);
+             XdsExecutionPreparation preparation = XdsExecutionPreparation.of("listener", xdsBootstrap)) {
+            Iface iface = ThriftClients.builder(ThriftSerializationFormats.BINARY, preparation)
+                                       .path("/thrift")
+                                       .build(Iface.class);
+            assertThat(iface.sayHello("Hello, ")).isEqualTo("World");
+
+            iface = Clients.newDerivedClient(iface, ClientOptions.RESPONSE_TIMEOUT_MILLIS.newValue(10L));
+            assertThat(iface.sayHello("Hello, ")).isEqualTo("World");
+        }
+    }
+}
diff --git a/it/xds-clients/src/test/thrift/main.thrift b/it/xds-clients/src/test/thrift/main.thrift
new file mode 100644
index 00000000000..8ae5322a8f9
--- /dev/null
+++ b/it/xds-clients/src/test/thrift/main.thrift
@@ -0,0 +1,6 @@
+namespace java testing.thrift
+
+
+service TestService {
+    string sayHello(1:string name)
+}
diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClientBuilder.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClientBuilder.java
index 7591c0d71aa..7249d7aeebb 100644
--- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClientBuilder.java
+++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClientBuilder.java
@@ -41,6 +41,7 @@
 import com.linecorp.armeria.client.DecoratingHttpClientFunction;
 import com.linecorp.armeria.client.DecoratingRpcClientFunction;
 import com.linecorp.armeria.client.Endpoint;
+import com.linecorp.armeria.client.ExecutionPreparation;
 import com.linecorp.armeria.client.HttpClient;
 import com.linecorp.armeria.client.ResponseTimeoutMode;
 import com.linecorp.armeria.client.RpcClient;
@@ -66,7 +67,7 @@
 public final class ThriftClientBuilder extends AbstractClientOptionsBuilder {
 
     @Nullable
-    private final EndpointGroup endpointGroup;
+    private final ExecutionPreparation executionPreparation;
 
     @Nullable
     private URI uri;
@@ -77,19 +78,19 @@ public final class ThriftClientBuilder extends AbstractClientOptionsBuilder {
     ThriftClientBuilder(URI uri) {
         requireNonNull(uri, "uri");
         checkArgument(uri.getScheme() != null, "uri must have scheme: %s", uri);
-        endpointGroup = null;
+        executionPreparation = null;
         this.uri = uri;
         scheme = Scheme.parse(uri.getScheme());
         validateOrSetSerializationFormat();
     }
 
-    ThriftClientBuilder(Scheme scheme, EndpointGroup endpointGroup) {
+    ThriftClientBuilder(Scheme scheme, ExecutionPreparation executionPreparation) {
         requireNonNull(scheme, "scheme");
-        requireNonNull(endpointGroup, "endpointGroup");
+        requireNonNull(executionPreparation, "executionPreparation");
         uri = null;
         this.scheme = scheme;
         validateOrSetSerializationFormat();
-        this.endpointGroup = endpointGroup;
+        this.executionPreparation = executionPreparation;
     }
 
     private void validateOrSetSerializationFormat() {
@@ -195,8 +196,8 @@ public <T> T build(Class<T> clientType) {
             }
             client = factory.newClient(ClientBuilderParams.of(uri, clientType, options));
         } else {
-            assert endpointGroup != null;
-            client = factory.newClient(ClientBuilderParams.of(scheme, endpointGroup,
+            assert executionPreparation != null;
+            client = factory.newClient(ClientBuilderParams.of(scheme, executionPreparation,
                                                               path, clientType, options));
         }
 
diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClients.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClients.java
index bf3c1915d90..bab9b3367f2 100644
--- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClients.java
+++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/client/thrift/ThriftClients.java
@@ -21,6 +21,7 @@
 import java.net.URI;
 
 import com.linecorp.armeria.client.ClientFactory;
+import com.linecorp.armeria.client.ExecutionPreparation;
 import com.linecorp.armeria.client.endpoint.EndpointGroup;
 import com.linecorp.armeria.common.Scheme;
 import com.linecorp.armeria.common.SerializationFormat;
@@ -220,5 +221,14 @@ public static ThriftClientBuilder builder(Scheme scheme, EndpointGroup endpointG
         return new ThriftClientBuilder(scheme, endpointGroup);
     }
 
+    /**
+     * TBU.
+     */
+    public static ThriftClientBuilder builder(SerializationFormat serializationFormat, ExecutionPreparation executionPreparation) {
+        requireNonNull(serializationFormat, "serializationFormat");
+        requireNonNull(executionPreparation, "executionPreparation");
+        return new ThriftClientBuilder(Scheme.of(serializationFormat, SessionProtocol.UNDETERMINED), executionPreparation);
+    }
+
     private ThriftClients() {}
 }
diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java
index 2e7fac95a26..6dad45f4f6a 100644
--- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java
+++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/DefaultTHttpClient.java
@@ -81,7 +81,6 @@ private RpcResponse execute0(
         final RpcRequest call = RpcRequest.of(serviceType, method, args);
         final HttpRequest httpReq = HttpRequest.of(
                 RequestHeaders.builder(HttpMethod.POST, reqTarget.path())
-                              .scheme(scheme().sessionProtocol())
                               .contentType(scheme().serializationFormat().mediaType())
                               .build());
         final RequestParams reqParams = RequestParams.of(httpReq, call, UNARY_REQUEST_OPTIONS, reqTarget);
diff --git a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/THttpClientFactory.java b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/THttpClientFactory.java
index a6cfcafa548..e0b435d944c 100644
--- a/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/THttpClientFactory.java
+++ b/thrift/thrift0.13/src/main/java/com/linecorp/armeria/internal/client/thrift/THttpClientFactory.java
@@ -81,7 +81,7 @@ public Object newClient(ClientBuilderParams params) {
         // Create a THttpClient without path.
         final ClientBuilderParams delegateParams =
                 ClientBuilderParams.of(params.scheme(),
-                                       params.endpointGroup(),
+                                       params.executionPreparation(),
                                        "/", THttpClient.class,
                                        options);