feat: Add header back to the client #2016

Merged 12 commits on Feb 28, 2023
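This change threads the stream location from StreamWriter through ConnectionWorkerPool into each ConnectionWorker, and builds the per-connection BigQueryWriteClient with an explicit x-goog-request-params routing header: write_location=<location> when a location is known (the multiplexed case), otherwise write_stream=<stream name>. The sketch below restates that header wiring outside the PR; the class and method names are illustrative, not part of the diff.

import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient;
import com.google.cloud.bigquery.storage.v1.BigQueryWriteSettings;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

// Illustrative sketch only; mirrors the header logic in ConnectionWorker's constructor below.
final class RoutingHeaderSketch {
  static BigQueryWriteClient createClientWithRoutingHeader(
      BigQueryWriteSettings clientSettings, String streamName, String location)
      throws IOException {
    // Start from the headers already configured on the caller's settings.
    Map<String, String> headers =
        new HashMap<>(clientSettings.toBuilder().getHeaderProvider().getHeaders());
    if (location == null || location.isEmpty()) {
      // Non-multiplexed connection: route requests by the single stream it serves.
      headers.put("x-goog-request-params", "write_stream=" + streamName);
    } else {
      // Multiplexed connection: route by location so the connection can serve several streams.
      headers.put("x-goog-request-params", "write_location=" + location);
    }
    BigQueryWriteSettings settingsWithHeader =
        clientSettings.toBuilder()
            .setHeaderProvider(FixedHeaderProvider.create(headers))
            .build();
    // The connection worker must create its client from these settings so the header is attached.
    return BigQueryWriteClient.create(settingsWithHeader);
  }
}

The append() path then rejects a StreamWriter whose location (or, for a non-multiplexed connection, stream name) does not match the connection, which the new ConnectionWorkerTest cases exercise.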
@@ -18,6 +18,7 @@
import com.google.api.core.ApiFuture;
import com.google.api.core.SettableApiFuture;
import com.google.api.gax.batching.FlowController;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.auto.value.AutoValue;
import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.ProtoData;
import com.google.cloud.bigquery.storage.v1.Exceptions.AppendSerializtionError;
@@ -77,6 +78,11 @@ class ConnectionWorker implements AutoCloseable {
*/
private String streamName;

/*
* The location of this connection.
*/
private String location = null;

/*
* The proto schema of rows to write. This schema can change during multiplexing.
*/
@@ -211,6 +217,7 @@ public static long getApiMaxRequestBytes() {

public ConnectionWorker(
String streamName,
String location,
ProtoSchema writerSchema,
long maxInflightRequests,
long maxInflightBytes,
@@ -223,6 +230,9 @@ public ConnectionWorker(
this.hasMessageInWaitingQueue = lock.newCondition();
this.inflightReduced = lock.newCondition();
this.streamName = streamName;
if (location != null && !location.isEmpty()) {
this.location = location;
}
this.maxRetryDuration = maxRetryDuration;
if (writerSchema == null) {
throw new StatusRuntimeException(
@@ -236,6 +246,18 @@ public ConnectionWorker(
this.waitingRequestQueue = new LinkedList<AppendRequestAndResponse>();
this.inflightRequestQueue = new LinkedList<AppendRequestAndResponse>();
// Always recreate a client for connection worker.
HashMap<String, String> newHeaders = new HashMap<>();
newHeaders.putAll(clientSettings.toBuilder().getHeaderProvider().getHeaders());
if (this.location == null) {
newHeaders.put("x-goog-request-params", "write_stream=" + this.streamName);
} else {
newHeaders.put("x-goog-request-params", "write_location=" + this.location);
}
BigQueryWriteSettings stubSettings =
clientSettings
.toBuilder()
.setHeaderProvider(FixedHeaderProvider.create(newHeaders))
.build();
this.client = BigQueryWriteClient.create(stubSettings);

this.appendThread =
@@ -297,6 +319,24 @@ public void run(Throwable finalStatus) {

/** Schedules the writing of rows at given offset. */
ApiFuture<AppendRowsResponse> append(StreamWriter streamWriter, ProtoRows rows, long offset) {
if (this.location != null && !this.location.equals(streamWriter.getLocation())) {
throw new StatusRuntimeException(
Status.fromCode(Code.INVALID_ARGUMENT)
.withDescription(
"StreamWriter with location "
+ streamWriter.getLocation()
+ " is scheduled to use a connection with location "
+ this.location));
} else if (this.location == null && !streamWriter.getStreamName().equals(this.streamName)) {
// Location is null implies this is non-multiplexed connection.
throw new StatusRuntimeException(
Status.fromCode(Code.INVALID_ARGUMENT)
.withDescription(
"StreamWriter with stream name "
+ streamWriter.getStreamName()
+ " is scheduled to use a connection with stream name "
+ this.streamName));
}
Preconditions.checkNotNull(streamWriter);
AppendRowsRequest.Builder requestBuilder = AppendRowsRequest.newBuilder();
requestBuilder.setProtoRows(
@@ -322,6 +362,10 @@ Boolean isUserClosed() {
}
}

String getWriteLocation() {
return this.location;
}

private ApiFuture<AppendRowsResponse> appendInternal(
StreamWriter streamWriter, AppendRowsRequest message) {
AppendRequestAndResponse requestWrapper = new AppendRequestAndResponse(message, streamWriter);
@@ -288,7 +288,8 @@ private ConnectionWorker createOrReuseConnectionWorker(
String streamReference = streamWriter.getStreamName();
if (connectionWorkerPool.size() < currentMaxConnectionCount) {
// Always create a new connection if we haven't reached current maximum.
return createConnectionWorker(streamWriter.getStreamName(), streamWriter.getProtoSchema());
return createConnectionWorker(
streamWriter.getStreamName(), streamWriter.getLocation(), streamWriter.getProtoSchema());
} else {
ConnectionWorker existingBestConnection =
pickBestLoadConnection(
@@ -304,7 +305,10 @@
if (currentMaxConnectionCount > settings.maxConnectionsPerRegion()) {
currentMaxConnectionCount = settings.maxConnectionsPerRegion();
}
return createConnectionWorker(streamWriter.getStreamName(), streamWriter.getProtoSchema());
return createConnectionWorker(
streamWriter.getStreamName(),
streamWriter.getLocation(),
streamWriter.getProtoSchema());
} else {
// Stick to the original connection if all the connections are overwhelmed.
if (existingConnectionWorker != null) {
@@ -359,15 +363,16 @@ static ConnectionWorker pickBestLoadConnection(
* a single stream reference. This is because createConnectionWorker(...) is called via
* computeIfAbsent(...) which is at most once per key.
*/
private ConnectionWorker createConnectionWorker(String streamName, ProtoSchema writeSchema)
throws IOException {
private ConnectionWorker createConnectionWorker(
String streamName, String location, ProtoSchema writeSchema) throws IOException {
if (enableTesting) {
// Though atomic integer is super lightweight, add extra if check in case adding future logic.
testValueCreateConnectionCount.getAndIncrement();
}
ConnectionWorker connectionWorker =
new ConnectionWorker(
streamName,
location,
writeSchema,
maxInflightRequests,
maxInflightBytes,
@@ -208,6 +208,7 @@ private StreamWriter(Builder builder) throws IOException {
SingleConnectionOrConnectionPool.ofSingleConnection(
new ConnectionWorker(
builder.streamName,
builder.location,
builder.writerSchema,
builder.maxInflightRequest,
builder.maxInflightBytes,
@@ -430,6 +430,7 @@ private StreamWriter getTestStreamWriter(String streamName) throws IOException {
return StreamWriter.newBuilder(streamName)
.setWriterSchema(createProtoSchema())
.setTraceId(TEST_TRACE_ID)
.setLocation("us")
.setCredentialsProvider(NoCredentialsProvider.create())
.setChannelProvider(serviceHelper.createChannelProvider())
.build();
@@ -39,13 +39,15 @@
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.logging.Logger;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class ConnectionWorkerTest {
private static final Logger log = Logger.getLogger(StreamWriter.class.getName());
private static final String TEST_STREAM_1 = "projects/p1/datasets/d1/tables/t1/streams/s1";
private static final String TEST_STREAM_2 = "projects/p2/datasets/d2/tables/t2/streams/s2";
private static final String TEST_TRACE_ID = "DATAFLOW:job_id";
@@ -84,10 +86,12 @@ public void testMultiplexedAppendSuccess() throws Exception {
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setWriterSchema(createProtoSchema("foo"))
.setLocation("us")
.build();
StreamWriter sw2 =
StreamWriter.newBuilder(TEST_STREAM_2, client)
.setWriterSchema(createProtoSchema("complicate"))
.setLocation("us")
.build();
// We do a pattern of:
// send to stream1, string1
@@ -205,11 +209,20 @@ public void testAppendInSameStream_switchSchema() throws Exception {
// send to stream1, schema1
// ...
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setLocation("us")
.setWriterSchema(schema1)
.build();
StreamWriter sw2 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema2).build();
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setLocation("us")
.setWriterSchema(schema2)
.build();
StreamWriter sw3 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema3).build();
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setLocation("us")
.setWriterSchema(schema3)
.build();
for (long i = 0; i < appendCount; i++) {
switch ((int) i % 4) {
case 0:
@@ -305,10 +318,14 @@ public void testAppendInSameStream_switchSchema() throws Exception {
public void testAppendButInflightQueueFull() throws Exception {
ProtoSchema schema1 = createProtoSchema("foo");
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setLocation("us")
.setWriterSchema(schema1)
.build();
ConnectionWorker connectionWorker =
new ConnectionWorker(
TEST_STREAM_1,
"us",
createProtoSchema("foo"),
6,
100000,
@@ -356,10 +373,14 @@ public void testAppendButInflightQueueFull() throws Exception {
public void testThrowExceptionWhileWithinAppendLoop() throws Exception {
ProtoSchema schema1 = createProtoSchema("foo");
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setLocation("us")
.setWriterSchema(schema1)
.build();
ConnectionWorker connectionWorker =
new ConnectionWorker(
TEST_STREAM_1,
"us",
createProtoSchema("foo"),
100000,
100000,
@@ -411,6 +432,69 @@ public void testThrowExceptionWhileWithinAppendLoop() throws Exception {
assertThat(ex.getCause()).hasMessageThat().contains("Any exception can happen.");
}

@Test
public void testLocationMismatch() throws Exception {
ProtoSchema schema1 = createProtoSchema("foo");
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client)
.setWriterSchema(schema1)
.setLocation("eu")
.build();
ConnectionWorker connectionWorker =
new ConnectionWorker(
TEST_STREAM_1,
"us",
createProtoSchema("foo"),
100000,
100000,
Duration.ofSeconds(100),
FlowController.LimitExceededBehavior.Block,
TEST_TRACE_ID,
client.getSettings());
StatusRuntimeException ex =
assertThrows(
StatusRuntimeException.class,
() ->
sendTestMessage(
connectionWorker,
sw1,
createFooProtoRows(new String[] {String.valueOf(0)}),
0));
assertEquals(
"INVALID_ARGUMENT: StreamWriter with location eu is scheduled to use a connection with location us",
ex.getMessage());
}

@Test
public void testStreamNameMismatch() throws Exception {
ProtoSchema schema1 = createProtoSchema("foo");
StreamWriter sw1 =
StreamWriter.newBuilder(TEST_STREAM_1, client).setWriterSchema(schema1).build();
ConnectionWorker connectionWorker =
new ConnectionWorker(
TEST_STREAM_2,
null,
createProtoSchema("foo"),
100000,
100000,
Duration.ofSeconds(100),
FlowController.LimitExceededBehavior.Block,
TEST_TRACE_ID,
client.getSettings());
StatusRuntimeException ex =
assertThrows(
StatusRuntimeException.class,
() ->
sendTestMessage(
connectionWorker,
sw1,
createFooProtoRows(new String[] {String.valueOf(0)}),
0));
assertEquals(
"INVALID_ARGUMENT: StreamWriter with stream name projects/p1/datasets/d1/tables/t1/streams/s1 is scheduled to use a connection with stream name projects/p2/datasets/d2/tables/t2/streams/s2",
ex.getMessage());
}

@Test
public void testExponentialBackoff() throws Exception {
assertThat(ConnectionWorker.calculateSleepTimeMilli(0)).isEqualTo(1);
@@ -440,6 +524,7 @@ private ConnectionWorker createConnectionWorker(
throws IOException {
return new ConnectionWorker(
streamName,
"us",
createProtoSchema("foo"),
maxRequests,
maxBytes,