From 217510ece601c198580fb96e7afd207ee8ad9890 Mon Sep 17 00:00:00 2001 From: Gaole Meng Date: Tue, 30 Jan 2024 16:10:57 -0800 Subject: [PATCH] fix: split connection pool based on credential --- .../bigquery/storage/v1/StreamWriter.java | 19 +++- .../bigquery/storage/v1/StreamWriterTest.java | 106 ++++++++++++++++++ 2 files changed, 122 insertions(+), 3 deletions(-) diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java index 35ebfb316c..558c9bf58e 100644 --- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java +++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1/StreamWriter.java @@ -21,6 +21,7 @@ import com.google.api.gax.core.ExecutorProvider; import com.google.api.gax.retrying.RetrySettings; import com.google.api.gax.rpc.TransportChannelProvider; +import com.google.auth.Credentials; import com.google.auto.value.AutoOneOf; import com.google.auto.value.AutoValue; import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.MissingValueInterpretation; @@ -46,6 +47,7 @@ import java.util.logging.Logger; import java.util.regex.Matcher; import java.util.regex.Pattern; +import javax.annotation.Nullable; /** * A BigQuery Stream Writer that can be used to write data into BigQuery Table. @@ -134,8 +136,11 @@ public static long getApiMaxRequestBytes() { abstract static class ConnectionPoolKey { abstract String location(); - public static ConnectionPoolKey create(String location) { - return new AutoValue_StreamWriter_ConnectionPoolKey(location); + @Nullable + abstract Credentials credentials(); + + public static ConnectionPoolKey create(String location, @Nullable Credentials credentials) { + return new AutoValue_StreamWriter_ConnectionPoolKey(location, credentials); } } @@ -273,6 +278,7 @@ private StreamWriter(Builder builder) throws IOException { } } this.location = location; + CredentialsProvider credentialsProvider = client.getSettings().getCredentialsProvider(); // Assume the connection in the same pool share the same client and trace id. // The first StreamWriter for a new stub will create the pool for the other // streams in the same region, meaning the per StreamWriter settings are no @@ -280,7 +286,9 @@ private StreamWriter(Builder builder) throws IOException { this.singleConnectionOrConnectionPool = SingleConnectionOrConnectionPool.ofConnectionPool( connectionPoolMap.computeIfAbsent( - ConnectionPoolKey.create(location), + ConnectionPoolKey.create( + location, + credentialsProvider != null ? credentialsProvider.getCredentials() : null), (key) -> { return new ConnectionWorkerPool( builder.maxInflightRequest, @@ -581,6 +589,11 @@ ConnectionWorkerPool getTestOnlyConnectionWorkerPool() { return connectionWorkerPool; } + @VisibleForTesting + Map getTestOnlyConnectionPoolMap() { + return connectionPoolMap; + } + // A method to clear the static connectio pool to avoid making pool visible to other tests. @VisibleForTesting static void clearConnectionPool() { diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java index 5de9045037..ce51601394 100644 --- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java +++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1/StreamWriterTest.java @@ -26,6 +26,8 @@ import com.google.api.core.ApiFutureCallback; import com.google.api.core.ApiFutures; import com.google.api.gax.batching.FlowController; +import com.google.api.gax.core.CredentialsProvider; +import com.google.api.gax.core.FixedCredentialsProvider; import com.google.api.gax.core.GoogleCredentialsProvider; import com.google.api.gax.core.InstantiatingExecutorProvider; import com.google.api.gax.core.NoCredentialsProvider; @@ -38,6 +40,7 @@ import com.google.api.gax.rpc.InvalidArgumentException; import com.google.api.gax.rpc.StatusCode.Code; import com.google.api.gax.rpc.UnknownException; +import com.google.auth.oauth2.UserCredentials; import com.google.cloud.bigquery.storage.test.Test.FooType; import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.MissingValueInterpretation; import com.google.cloud.bigquery.storage.v1.ConnectionWorkerPool.Settings; @@ -924,6 +927,109 @@ public void testProtoSchemaPiping_multiplexingCase() throws Exception { writer2.close(); } + @Test + public void testFixedCredentialProvider_nullProvider() throws Exception { + // Use the shared connection mode. + ConnectionWorkerPool.setOptions( + Settings.builder().setMinConnectionsPerRegion(1).setMaxConnectionsPerRegion(1).build()); + ProtoSchema schema1 = createProtoSchema("Schema1"); + ProtoSchema schema2 = createProtoSchema("Schema2"); + CredentialsProvider credentialsProvider1 = FixedCredentialsProvider.create(null); + CredentialsProvider credentialsProvider2 = FixedCredentialsProvider.create(null); + StreamWriter writer1 = + StreamWriter.newBuilder(TEST_STREAM_1, client) + .setWriterSchema(schema1) + .setLocation("US") + .setEnableConnectionPool(true) + .setMaxInflightRequests(1) + .setCredentialsProvider(credentialsProvider1) + .build(); + StreamWriter writer2 = + StreamWriter.newBuilder(TEST_STREAM_2, client) + .setWriterSchema(schema2) + .setMaxInflightRequests(1) + .setEnableConnectionPool(true) + .setCredentialsProvider(credentialsProvider2) + .setLocation("US") + .build(); + // Null credential provided belong to the same connection pool. + assertEquals(writer1.getTestOnlyConnectionPoolMap().size(), 1); + } + + @Test + public void testFixedCredentialProvider_twoCredentialsSplitPool() throws Exception { + // Use the shared connection mode. + ConnectionWorkerPool.setOptions( + Settings.builder().setMinConnectionsPerRegion(1).setMaxConnectionsPerRegion(1).build()); + ProtoSchema schema1 = createProtoSchema("Schema1"); + ProtoSchema schema2 = createProtoSchema("Schema2"); + UserCredentials userCredentials1 = + UserCredentials.newBuilder() + .setClientId("CLIENT_ID_1") + .setClientSecret("CLIENT_SECRET_1") + .setRefreshToken("REFRESH_TOKEN_1") + .build(); + CredentialsProvider credentialsProvider1 = FixedCredentialsProvider.create(userCredentials1); + UserCredentials userCredentials2 = + UserCredentials.newBuilder() + .setClientId("CLIENT_ID_2") + .setClientSecret("CLIENT_SECRET_2") + .setRefreshToken("REFRESH_TOKEN_2") + .build(); + CredentialsProvider credentialsProvider2 = FixedCredentialsProvider.create(userCredentials2); + StreamWriter writer1 = + StreamWriter.newBuilder(TEST_STREAM_1) + .setWriterSchema(schema1) + .setLocation("US") + .setEnableConnectionPool(true) + .setMaxInflightRequests(1) + .setCredentialsProvider(credentialsProvider1) + .build(); + StreamWriter writer2 = + StreamWriter.newBuilder(TEST_STREAM_2) + .setWriterSchema(schema2) + .setMaxInflightRequests(1) + .setEnableConnectionPool(true) + .setLocation("US") + .setCredentialsProvider(credentialsProvider2) + .build(); + assertEquals(writer1.getTestOnlyConnectionPoolMap().size(), 2); + } + + @Test + public void testFixedCredentialProvider_twoProviderSameCredentialSharePool() throws Exception { + // Use the shared connection mode. + ConnectionWorkerPool.setOptions( + Settings.builder().setMinConnectionsPerRegion(1).setMaxConnectionsPerRegion(1).build()); + ProtoSchema schema1 = createProtoSchema("Schema1"); + ProtoSchema schema2 = createProtoSchema("Schema2"); + UserCredentials userCredentials = + UserCredentials.newBuilder() + .setClientId("CLIENT_ID_1") + .setClientSecret("CLIENT_SECRET_1") + .setRefreshToken("REFRESH_TOKEN_1") + .build(); + CredentialsProvider credentialsProvider1 = FixedCredentialsProvider.create(userCredentials); + CredentialsProvider credentialsProvider2 = FixedCredentialsProvider.create(userCredentials); + StreamWriter writer1 = + StreamWriter.newBuilder(TEST_STREAM_1) + .setWriterSchema(schema1) + .setLocation("US") + .setEnableConnectionPool(true) + .setMaxInflightRequests(1) + .setCredentialsProvider(credentialsProvider1) + .build(); + StreamWriter writer2 = + StreamWriter.newBuilder(TEST_STREAM_2) + .setWriterSchema(schema2) + .setMaxInflightRequests(1) + .setEnableConnectionPool(true) + .setLocation("US") + .setCredentialsProvider(credentialsProvider2) + .build(); + assertEquals(writer1.getTestOnlyConnectionPoolMap().size(), 1); + } + @Test public void testDefaultValueInterpretation_multiplexingCase() throws Exception { // Use the shared connection mode.