Skip to content

Commit

Permalink
fix: split connection pool based on credential
Browse files Browse the repository at this point in the history
  • Loading branch information
GaoleMeng committed Feb 1, 2024
1 parent 4498247 commit 8d661a9
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.auth.Credentials;
import com.google.auto.value.AutoOneOf;
import com.google.auto.value.AutoValue;
import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.MissingValueInterpretation;
Expand All @@ -46,6 +47,7 @@
import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.annotation.Nullable;

/**
* A BigQuery Stream Writer that can be used to write data into BigQuery Table.
Expand Down Expand Up @@ -134,8 +136,11 @@ public static long getApiMaxRequestBytes() {
abstract static class ConnectionPoolKey {
abstract String location();

public static ConnectionPoolKey create(String location) {
return new AutoValue_StreamWriter_ConnectionPoolKey(location);
@Nullable
abstract Credentials credentials();

public static ConnectionPoolKey create(String location, @Nullable Credentials credentials) {
return new AutoValue_StreamWriter_ConnectionPoolKey(location, credentials);
}
}

Expand Down Expand Up @@ -273,14 +278,17 @@ private StreamWriter(Builder builder) throws IOException {
}
}
this.location = location;
CredentialsProvider credentialsProvider = client.getSettings().getCredentialsProvider();
// Assume the connection in the same pool share the same client and trace id.
// The first StreamWriter for a new stub will create the pool for the other
// streams in the same region, meaning the per StreamWriter settings are no
// longer working unless all streams share the same set of settings
this.singleConnectionOrConnectionPool =
SingleConnectionOrConnectionPool.ofConnectionPool(
connectionPoolMap.computeIfAbsent(
ConnectionPoolKey.create(location),
ConnectionPoolKey.create(
location,
credentialsProvider != null ? credentialsProvider.getCredentials() : null),
(key) -> {
return new ConnectionWorkerPool(
builder.maxInflightRequest,
Expand Down Expand Up @@ -581,6 +589,11 @@ ConnectionWorkerPool getTestOnlyConnectionWorkerPool() {
return connectionWorkerPool;
}

@VisibleForTesting
Map<ConnectionPoolKey, ConnectionWorkerPool> getTestOnlyConnectionPoolMap() {
return connectionPoolMap;
}

// A method to clear the static connectio pool to avoid making pool visible to other tests.
@VisibleForTesting
static void clearConnectionPool() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import com.google.api.core.ApiFutureCallback;
import com.google.api.core.ApiFutures;
import com.google.api.gax.batching.FlowController;
import com.google.api.gax.core.CredentialsProvider;
import com.google.api.gax.core.FixedCredentialsProvider;
import com.google.api.gax.core.GoogleCredentialsProvider;
import com.google.api.gax.core.InstantiatingExecutorProvider;
import com.google.api.gax.core.NoCredentialsProvider;
Expand All @@ -38,6 +40,7 @@
import com.google.api.gax.rpc.InvalidArgumentException;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.api.gax.rpc.UnknownException;
import com.google.auth.oauth2.UserCredentials;
import com.google.cloud.bigquery.storage.test.Test.FooType;
import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.MissingValueInterpretation;
import com.google.cloud.bigquery.storage.v1.ConnectionWorkerPool.Settings;
Expand Down Expand Up @@ -924,6 +927,109 @@ public void testProtoSchemaPiping_multiplexingCase() throws Exception {
writer2.close();
}

@Test
public void testFixedCredentialProvider_nullProvider() throws Exception {
// Use the shared connection mode.
ConnectionWorkerPool.setOptions(
Settings.builder().setMinConnectionsPerRegion(1).setMaxConnectionsPerRegion(1).build());
ProtoSchema schema1 = createProtoSchema("Schema1");
ProtoSchema schema2 = createProtoSchema("Schema2");
CredentialsProvider credentialsProvider1 = FixedCredentialsProvider.create(null);
CredentialsProvider credentialsProvider2 = FixedCredentialsProvider.create(null);
StreamWriter writer1 =
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setWriterSchema(schema1)
.setLocation("US")
.setEnableConnectionPool(true)
.setMaxInflightRequests(1)
.setCredentialsProvider(credentialsProvider1)
.build();
StreamWriter writer2 =
StreamWriter.newBuilder(TEST_STREAM_2, client)
.setWriterSchema(schema2)
.setMaxInflightRequests(1)
.setEnableConnectionPool(true)
.setCredentialsProvider(credentialsProvider2)
.setLocation("US")
.build();
// Null credential provided belong to the same connection pool.
assertEquals(writer1.getTestOnlyConnectionPoolMap().size(), 1);
}

@Test
public void testFixedCredentialProvider_twoCredentialsSplitPool() throws Exception {
// Use the shared connection mode.
ConnectionWorkerPool.setOptions(
Settings.builder().setMinConnectionsPerRegion(1).setMaxConnectionsPerRegion(1).build());
ProtoSchema schema1 = createProtoSchema("Schema1");
ProtoSchema schema2 = createProtoSchema("Schema2");
UserCredentials userCredentials1 =
UserCredentials.newBuilder()
.setClientId("CLIENT_ID_1")
.setClientSecret("CLIENT_SECRET_1")
.setRefreshToken("REFRESH_TOKEN_1")
.build();
CredentialsProvider credentialsProvider1 = FixedCredentialsProvider.create(userCredentials1);
UserCredentials userCredentials2 =
UserCredentials.newBuilder()
.setClientId("CLIENT_ID_2")
.setClientSecret("CLIENT_SECRET_2")
.setRefreshToken("REFRESH_TOKEN_2")
.build();
CredentialsProvider credentialsProvider2 = FixedCredentialsProvider.create(userCredentials2);
StreamWriter writer1 =
StreamWriter.newBuilder(TEST_STREAM_1)
.setWriterSchema(schema1)
.setLocation("US")
.setEnableConnectionPool(true)
.setMaxInflightRequests(1)
.setCredentialsProvider(credentialsProvider1)
.build();
StreamWriter writer2 =
StreamWriter.newBuilder(TEST_STREAM_2)
.setWriterSchema(schema2)
.setMaxInflightRequests(1)
.setEnableConnectionPool(true)
.setLocation("US")
.setCredentialsProvider(credentialsProvider2)
.build();
assertEquals(writer1.getTestOnlyConnectionPoolMap().size(), 2);
}

@Test
public void testFixedCredentialProvider_twoProviderSameCredentialSharePool() throws Exception {
// Use the shared connection mode.
ConnectionWorkerPool.setOptions(
Settings.builder().setMinConnectionsPerRegion(1).setMaxConnectionsPerRegion(1).build());
ProtoSchema schema1 = createProtoSchema("Schema1");
ProtoSchema schema2 = createProtoSchema("Schema2");
UserCredentials userCredentials =
UserCredentials.newBuilder()
.setClientId("CLIENT_ID_1")
.setClientSecret("CLIENT_SECRET_1")
.setRefreshToken("REFRESH_TOKEN_1")
.build();
CredentialsProvider credentialsProvider1 = FixedCredentialsProvider.create(userCredentials);
CredentialsProvider credentialsProvider2 = FixedCredentialsProvider.create(userCredentials);
StreamWriter writer1 =
StreamWriter.newBuilder(TEST_STREAM_1)
.setWriterSchema(schema1)
.setLocation("US")
.setEnableConnectionPool(true)
.setMaxInflightRequests(1)
.setCredentialsProvider(credentialsProvider1)
.build();
StreamWriter writer2 =
StreamWriter.newBuilder(TEST_STREAM_2)
.setWriterSchema(schema2)
.setMaxInflightRequests(1)
.setEnableConnectionPool(true)
.setLocation("US")
.setCredentialsProvider(credentialsProvider2)
.build();
assertEquals(writer1.getTestOnlyConnectionPoolMap().size(), 1);
}

@Test
public void testDefaultValueInterpretation_multiplexingCase() throws Exception {
// Use the shared connection mode.
Expand Down

0 comments on commit 8d661a9

Please sign in to comment.