diff --git a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterV2.java b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterV2.java
index 8debea15f9..1d001d5bbb 100644
--- a/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterV2.java
+++ b/google-cloud-bigquerystorage/src/main/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterV2.java
@@ -40,8 +40,6 @@
*
*
TODO: Attach schema.
*
- *
TODO: Add inflight control.
- *
*
TODO: Attach traceId.
*
*
TODO: Support batching.
@@ -53,12 +51,35 @@ public class StreamWriterV2 implements AutoCloseable {
private Lock lock;
private Condition hasMessageInWaitingQueue;
+ private Condition inflightReduced;
/*
* The identifier of stream to write to.
*/
private final String streamName;
+ /*
+ * Max allowed inflight requests in the stream. Method append is blocked at this.
+ */
+ private final long maxInflightRequests;
+
+ /*
+ * Max allowed inflight bytes in the stream. Method append is blocked at this.
+ */
+ private final long maxInflightBytes;
+
+ /*
+ * Tracks current inflight requests in the stream.
+ */
+ @GuardedBy("lock")
+ private long inflightRequests = 0;
+
+ /*
+ * Tracks current inflight bytes in the stream.
+ */
+ @GuardedBy("lock")
+ private long inflightBytes = 0;
+
/*
* Indicates whether user has called Close() or not.
*/
@@ -101,7 +122,10 @@ public static long getApiMaxRequestBytes() {
private StreamWriterV2(Builder builder) {
this.lock = new ReentrantLock();
this.hasMessageInWaitingQueue = lock.newCondition();
+ this.inflightReduced = lock.newCondition();
this.streamName = builder.streamName;
+ this.maxInflightRequests = builder.maxInflightRequest;
+ this.maxInflightBytes = builder.maxInflightBytes;
this.waitingRequestQueue = new LinkedList();
this.inflightRequestQueue = new LinkedList();
this.streamConnection =
@@ -186,14 +210,38 @@ public ApiFuture append(AppendRowsRequest message) {
"Stream is closed due to " + connectionFinalStatus.toString())));
return requestWrapper.appendResult;
}
+
+ ++this.inflightRequests;
+ this.inflightBytes += requestWrapper.messageSize;
waitingRequestQueue.addLast(requestWrapper);
hasMessageInWaitingQueue.signal();
+ maybeWaitForInflightQuota();
return requestWrapper.appendResult;
} finally {
this.lock.unlock();
}
}
+ @GuardedBy("lock")
+ private void maybeWaitForInflightQuota() {
+ while (this.inflightRequests >= this.maxInflightRequests
+ || this.inflightBytes >= this.maxInflightBytes) {
+ try {
+ inflightReduced.await(100, TimeUnit.MILLISECONDS);
+ } catch (InterruptedException e) {
+ log.warning(
+ "Interrupted while waiting for inflight quota. Stream: "
+ + streamName
+ + " Error: "
+ + e.toString());
+ throw new StatusRuntimeException(
+ Status.fromCode(Code.CANCELLED)
+ .withCause(e)
+ .withDescription("Interrupted while waiting for quota."));
+ }
+ }
+ }
+
/** Close the stream writer. Shut down all resources. */
@Override
public void close() {
@@ -303,7 +351,7 @@ private void cleanupInflightRequests() {
try {
finalStatus = this.connectionFinalStatus;
while (!this.inflightRequestQueue.isEmpty()) {
- localQueue.addLast(this.inflightRequestQueue.pollFirst());
+ localQueue.addLast(pollInflightRequestQueue());
}
} finally {
this.lock.unlock();
@@ -322,7 +370,7 @@ private void requestCallback(AppendRowsResponse response) {
AppendRequestAndResponse requestWrapper;
this.lock.lock();
try {
- requestWrapper = this.inflightRequestQueue.pollFirst();
+ requestWrapper = pollInflightRequestQueue();
} finally {
this.lock.unlock();
}
@@ -343,6 +391,15 @@ private void doneCallback(Throwable finalStatus) {
}
}
+ @GuardedBy("lock")
+ private AppendRequestAndResponse pollInflightRequestQueue() {
+ AppendRequestAndResponse requestWrapper = this.inflightRequestQueue.pollFirst();
+ --this.inflightRequests;
+ this.inflightBytes -= requestWrapper.messageSize;
+ this.inflightReduced.signal();
+ return requestWrapper;
+ }
+
/** Constructs a new {@link StreamWriterV2.Builder} using the given stream and client. */
public static StreamWriterV2.Builder newBuilder(String streamName, BigQueryWriteClient client) {
return new StreamWriterV2.Builder(streamName, client);
@@ -351,15 +408,33 @@ public static StreamWriterV2.Builder newBuilder(String streamName, BigQueryWrite
/** A builder of {@link StreamWriterV2}s. */
public static final class Builder {
+ private static final long DEFAULT_MAX_INFLIGHT_REQUESTS = 1000L;
+
+ private static final long DEFAULT_MAX_INFLIGHT_BYTES = 100 * 1024 * 1024; // 100Mb.
+
private String streamName;
private BigQueryWriteClient client;
+ private long maxInflightRequest = DEFAULT_MAX_INFLIGHT_REQUESTS;
+
+ private long maxInflightBytes = DEFAULT_MAX_INFLIGHT_BYTES;
+
private Builder(String streamName, BigQueryWriteClient client) {
this.streamName = Preconditions.checkNotNull(streamName);
this.client = Preconditions.checkNotNull(client);
}
+ public Builder setMaxInflightRequests(long value) {
+ this.maxInflightRequest = value;
+ return this;
+ }
+
+ public Builder setMaxInflightBytes(long value) {
+ this.maxInflightBytes = value;
+ return this;
+ }
+
/** Builds the {@code StreamWriterV2}. */
public StreamWriterV2 build() {
return new StreamWriterV2(this);
diff --git a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterV2Test.java b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterV2Test.java
index 4d6fba9dcd..bb82e79435 100644
--- a/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterV2Test.java
+++ b/google-cloud-bigquerystorage/src/test/java/com/google/cloud/bigquery/storage/v1beta2/StreamWriterV2Test.java
@@ -142,6 +142,22 @@ public void run() throws Throwable {
});
}
+ private void verifyAppendIsBlocked(final StreamWriterV2 writer) throws Exception {
+ Thread appendThread =
+ new Thread(
+ new Runnable() {
+ @Override
+ public void run() {
+ sendTestMessage(writer, new String[] {"A"});
+ }
+ });
+ // Start a separate thread to append and verify that it is still alive after 2 seoncds.
+ appendThread.start();
+ TimeUnit.SECONDS.sleep(2);
+ assertTrue(appendThread.isAlive());
+ appendThread.interrupt();
+ }
+
@Test
public void testAppendSuccess() throws Exception {
StreamWriterV2 writer = getTestStreamWriterV2();
@@ -291,6 +307,69 @@ public void serverCloseWhileRequestsInflight() throws Exception {
writer.close();
}
+ @Test
+ public void testZeroMaxInflightRequests() throws Exception {
+ StreamWriterV2 writer =
+ StreamWriterV2.newBuilder(TEST_STREAM, client).setMaxInflightRequests(0).build();
+ testBigQueryWrite.addResponse(createAppendResponse(0));
+ verifyAppendIsBlocked(writer);
+ writer.close();
+ }
+
+ @Test
+ public void testZeroMaxInflightBytes() throws Exception {
+ StreamWriterV2 writer =
+ StreamWriterV2.newBuilder(TEST_STREAM, client).setMaxInflightBytes(0).build();
+ testBigQueryWrite.addResponse(createAppendResponse(0));
+ verifyAppendIsBlocked(writer);
+ writer.close();
+ }
+
+ @Test
+ public void testOneMaxInflightRequests() throws Exception {
+ StreamWriterV2 writer =
+ StreamWriterV2.newBuilder(TEST_STREAM, client).setMaxInflightRequests(1).build();
+ // Server will sleep 1 second before every response.
+ testBigQueryWrite.setResponseSleep(Duration.ofSeconds(1));
+ testBigQueryWrite.addResponse(createAppendResponse(0));
+
+ long appendStartTimeMs = System.currentTimeMillis();
+ ApiFuture appendFuture1 = sendTestMessage(writer, new String[] {"A"});
+ long appendElapsedMs = System.currentTimeMillis() - appendStartTimeMs;
+ assertTrue(appendElapsedMs >= 1000);
+ assertEquals(0, appendFuture1.get().getAppendResult().getOffset().getValue());
+ writer.close();
+ }
+
+ @Test
+ public void testAppendsWithTinyMaxInflightBytes() throws Exception {
+ StreamWriterV2 writer =
+ StreamWriterV2.newBuilder(TEST_STREAM, client).setMaxInflightBytes(1).build();
+ // Server will sleep 100ms before every response.
+ testBigQueryWrite.setResponseSleep(Duration.ofMillis(100));
+ long appendCount = 10;
+ for (int i = 0; i < appendCount; i++) {
+ testBigQueryWrite.addResponse(createAppendResponse(i));
+ }
+
+ List> futures = new ArrayList<>();
+ long appendStartTimeMs = System.currentTimeMillis();
+ for (int i = 0; i < appendCount; i++) {
+ futures.add(writer.append(createAppendRequest(new String[] {String.valueOf(i)}, i)));
+ }
+ long appendElapsedMs = System.currentTimeMillis() - appendStartTimeMs;
+ assertTrue(appendElapsedMs >= 1000);
+
+ for (int i = 0; i < appendCount; i++) {
+ assertEquals(i, futures.get(i).get().getAppendResult().getOffset().getValue());
+ }
+ assertEquals(appendCount, testBigQueryWrite.getAppendRequests().size());
+ for (int i = 0; i < appendCount; i++) {
+ assertEquals(i, testBigQueryWrite.getAppendRequests().get(i).getOffset().getValue());
+ }
+ writer.close();
+ }
+
@Test
public void testMessageTooLarge() {
StreamWriterV2 writer = getTestStreamWriterV2();