From 33facf558650e221341f37a4b186c6b0323ef427 Mon Sep 17 00:00:00 2001 From: ron-gal <125445217+ron-gal@users.noreply.github.com> Date: Mon, 15 Apr 2024 14:20:05 -0400 Subject: [PATCH] fix: Add paging to hbase client (#4166) * Added paging to hbase client * minor fixes * minor fixes * Fix tests * Fix tests * Fix lint * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * test * Brought back test and fixed page size handling * fixed test * add tests * minor refactor * minor refactor * add error handling tests * fix format * Add protection against OOM exceptions * Add protection against OOM exceptions * Remove useless assertion * handle setCaching properly * handle setCaching properly * handle setCaching properly * handle setCaching properly * handle setCaching properly * handle setCaching properly * remove useless code * Get page size directly from the paginator * fix lint * cancel serverStream when reaching memory limit * add test for low memory * add test for low memory * fix lint * update java-bigtable dependency * Fixed several PR comments * Fixed several PR comments * Fixed several PR comments * Moved to async API * Fixed several PR comments * Fixed several PR comments * Fixed several PR comments * Fixed several PR comments * Fixed several PR comments * Fixed several PR comments * Fixed according to PR comments * Fix wrong advance * Fixed according to PR * Fixed according to PR * Fixed according to PR * remove test * adjust tests according to PR * adjust tests according to PR * fix bug found on beam --- .../google/cloud/bigtable/hbase/TestScan.java | 36 +++- .../bigtable/hbase/AbstractBigtableTable.java | 19 +- .../hbase/wrappers/DataClientWrapper.java | 6 + .../wrappers/veneer/DataClientVeneerApi.java | 160 +++++++++++++- .../veneer/SharedDataClientWrapper.java | 5 + .../veneer/TestDataClientVeneerApi.java | 196 +++++++++++++++++- 6 files changed, 413 insertions(+), 9 deletions(-) diff --git a/bigtable-client-core-parent/bigtable-hbase-integration-tests-common/src/test/java/com/google/cloud/bigtable/hbase/TestScan.java b/bigtable-client-core-parent/bigtable-hbase-integration-tests-common/src/test/java/com/google/cloud/bigtable/hbase/TestScan.java index 3ca91c866a..4160b8d95c 100644 --- a/bigtable-client-core-parent/bigtable-hbase-integration-tests-common/src/test/java/com/google/cloud/bigtable/hbase/TestScan.java +++ b/bigtable-client-core-parent/bigtable-hbase-integration-tests-common/src/test/java/com/google/cloud/bigtable/hbase/TestScan.java @@ -177,9 +177,33 @@ public void testGetScannerNoQualifiers() throws IOException { } @Test - public void test100ResultsInScanner() throws IOException { + public void testManyResultsInScanner_lessThanPageSize() throws IOException { + testManyResultsInScanner(95, true); + } + + @Test + public void testManyResultsInScanner_equalToPageSize() throws IOException { + testManyResultsInScanner(100, true); + } + + @Test + public void testManyResultsInScanner_greaterThanPageSize() throws IOException { + testManyResultsInScanner(105, true); + } + + @Test + public void testManyResultsInScanner_greaterThanTwoPageSizes() throws IOException { + testManyResultsInScanner(205, true); + } + + @Test + public void testManyResultsInScanner_onePageSizeNoPagination() throws IOException { + testManyResultsInScanner(100, false); + } + + private void testManyResultsInScanner(int rowsToWrite, boolean withPagination) + throws IOException { String prefix = "scan_row_"; - int rowsToWrite = 100; // Initialize variables Table table = getDefaultTable(); @@ -208,9 +232,13 @@ public void test100ResultsInScanner() throws IOException { Scan scan = new Scan(); scan.withStartRow(rowKeys[0]) - .withStopRow(rowFollowing(rowKeys[rowsToWrite - 1])) + .withStopRow(rowFollowingSameLength(rowKeys[rowsToWrite - 1])) .addFamily(COLUMN_FAMILY); + if (withPagination) { + scan = scan.setCaching(100); + } + try (ResultScanner resultScanner = table.getScanner(scan)) { for (int rowIndex = 0; rowIndex < rowsToWrite; rowIndex++) { Result result = resultScanner.next(); @@ -275,7 +303,7 @@ public void testScanDelete() throws IOException { Scan scan = new Scan(); scan.withStartRow(rowKeys[0]) - .withStopRow(rowFollowing(rowKeys[rowsToWrite - 1])) + .withStopRow(rowFollowingSameLength(rowKeys[rowsToWrite - 1])) .addFamily(COLUMN_FAMILY); int deleteCount = 0; try (ResultScanner resultScanner = table.getScanner(scan)) { diff --git a/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/AbstractBigtableTable.java b/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/AbstractBigtableTable.java index 2f16587f34..2d13be7f17 100644 --- a/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/AbstractBigtableTable.java +++ b/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/AbstractBigtableTable.java @@ -18,6 +18,7 @@ import com.google.api.core.InternalApi; import com.google.cloud.bigtable.data.v2.models.ConditionalRowMutation; import com.google.cloud.bigtable.data.v2.models.Filters; +import com.google.cloud.bigtable.data.v2.models.Query; import com.google.cloud.bigtable.data.v2.models.ReadModifyWriteRow; import com.google.cloud.bigtable.data.v2.models.RowMutation; import com.google.cloud.bigtable.hbase.adapters.Adapters; @@ -92,6 +93,14 @@ public abstract class AbstractBigtableTable implements Table { private static final Tracer TRACER = Tracing.getTracer(); + private static final int MIN_BYTE_BUFFER_SIZE = 100 * 1024 * 1024; + private static final double DEFAULT_BYTE_LIMIT_PERCENTAGE = .1; + private static final long DEFAULT_MAX_SEGMENT_SIZE = + (long) + Math.max( + MIN_BYTE_BUFFER_SIZE, + (Runtime.getRuntime().totalMemory() * DEFAULT_BYTE_LIMIT_PERCENTAGE)); + private static class TableMetrics { Timer putTimer = BigtableClientMetrics.timer(MetricLevel.Info, "table.put.latency"); Timer getTimer = BigtableClientMetrics.timer(MetricLevel.Info, "table.get.latency"); @@ -295,8 +304,14 @@ public ResultScanner getScanner(final Scan scan) throws IOException { LOG.trace("getScanner(Scan)"); Span span = TRACER.spanBuilder("BigtableTable.scan").startSpan(); try (Scope scope = TRACER.withSpan(span)) { - - final ResultScanner scanner = clientWrapper.readRows(hbaseAdapter.adapt(scan)); + ResultScanner scanner; + if (scan.getCaching() == -1) { + scanner = clientWrapper.readRows(hbaseAdapter.adapt(scan)); + } else { + Query.QueryPaginator paginator = + hbaseAdapter.adapt(scan).createPaginator(scan.getCaching()); + scanner = clientWrapper.readRows(paginator, DEFAULT_MAX_SEGMENT_SIZE); + } if (hasWhileMatchFilter(scan.getFilter())) { return Adapters.BIGTABLE_WHILE_MATCH_RESULT_RESULT_SCAN_ADAPTER.adapt(scanner, span); } diff --git a/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/DataClientWrapper.java b/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/DataClientWrapper.java index d845aea974..db246162e7 100644 --- a/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/DataClientWrapper.java +++ b/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/DataClientWrapper.java @@ -91,4 +91,10 @@ ApiFuture readRowAsync( @Override void close() throws IOException; + + /** + * Perform a scan over {@link Result}s, in key order, using a paginator. maxSegmentByteSize is + * used for testing purposes only. + */ + ResultScanner readRows(Query.QueryPaginator paginator, long maxSegmentByteSize); } diff --git a/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/veneer/DataClientVeneerApi.java b/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/veneer/DataClientVeneerApi.java index 506d6ccbad..ae1735a19e 100644 --- a/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/veneer/DataClientVeneerApi.java +++ b/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/veneer/DataClientVeneerApi.java @@ -21,6 +21,7 @@ import com.google.api.core.InternalApi; import com.google.api.gax.grpc.GrpcCallContext; import com.google.api.gax.rpc.ApiCallContext; +import com.google.api.gax.rpc.ResponseObserver; import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.StateCheckingResponseObserver; import com.google.api.gax.rpc.StreamController; @@ -43,12 +44,18 @@ import com.google.cloud.bigtable.metrics.Timer; import com.google.cloud.bigtable.metrics.Timer.Context; import com.google.common.util.concurrent.MoreExecutors; +import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.ByteString; import io.grpc.CallOptions; import io.grpc.Deadline; import io.grpc.stub.StreamObserver; +import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.Queue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; import org.apache.hadoop.hbase.client.AbstractClientScanner; @@ -134,6 +141,12 @@ public Result apply(Row row) { MoreExecutors.directExecutor()); } + @Override + public ResultScanner readRows(Query.QueryPaginator paginator, long maxSegmentByteSize) { + return new PaginatedRowResultScanner( + paginator, delegate, maxSegmentByteSize, this.createScanCallContext()); + } + @Override public ResultScanner readRows(Query request) { return new RowResultScanner( @@ -228,6 +241,151 @@ protected void onCompleteImpl() { } } + /** + * wraps {@link ServerStream} onto HBase {@link ResultScanner}. {@link PaginatedRowResultScanner} + * gets a paginator and a {@link Query.QueryPaginator} used to get a {@link ServerStream}<{@link + * Result}> using said paginator to iterate over pages of rows. The {@link Query.QueryPaginator} + * pageSize property indicates the size of each page in every API call. A cache of a maximum size + * of 1.1*pageSize and a minimum of 0.1*pageSize is held at all times. In order to avoid OOM + * exceptions, there is a limit for the total byte size held in cache. + */ + static class PaginatedRowResultScanner extends AbstractClientScanner { + // Percentage of max number of rows allowed in the buffer + private static final double WATERMARK_PERCENTAGE = .1; + private static final RowResultAdapter RESULT_ADAPTER = new RowResultAdapter(); + + private final Meter scannerResultMeter = + BigtableClientMetrics.meter(BigtableClientMetrics.MetricLevel.Info, "scanner.results"); + private final Timer scannerResultTimer = + BigtableClientMetrics.timer( + BigtableClientMetrics.MetricLevel.Debug, "scanner.results.latency"); + + private ByteString lastSeenRowKey = ByteString.EMPTY; + private Boolean hasMore = true; + private final Queue buffer; + private final Query.QueryPaginator paginator; + private final int refillSegmentWaterMark; + + private final BigtableDataClient dataClient; + + private final long maxSegmentByteSize; + + private long currentByteSize = 0; + + private @Nullable Future> future; + private GrpcCallContext scanCallContext; + + PaginatedRowResultScanner( + Query.QueryPaginator paginator, + BigtableDataClient dataClient, + long maxSegmentByteSize, + GrpcCallContext scanCallContext) { + this.maxSegmentByteSize = maxSegmentByteSize; + + this.paginator = paginator; + this.dataClient = dataClient; + this.buffer = new ArrayDeque<>(); + this.refillSegmentWaterMark = + (int) Math.max(1, paginator.getPageSize() * WATERMARK_PERCENTAGE); + this.scanCallContext = scanCallContext; + this.future = fetchNextSegment(); + } + + @Override + public Result next() { + try (Context ignored = scannerResultTimer.time()) { + if (this.future != null && this.future.isDone()) { + this.consumeReadRowsFuture(); + } + if (this.buffer.size() < this.refillSegmentWaterMark && this.future == null && hasMore) { + future = fetchNextSegment(); + } + if (this.buffer.isEmpty() && this.future != null) { + this.consumeReadRowsFuture(); + } + Result result = this.buffer.poll(); + if (result != null) { + scannerResultMeter.mark(); + currentByteSize -= Result.getTotalSizeOfCells(result); + } + return result; + } + } + + @Override + public void close() { + if (this.future != null) { + this.future.cancel(true); + } + } + + public boolean renewLease() { + return true; + } + + private Future> fetchNextSegment() { + SettableFuture> resultsFuture = SettableFuture.create(); + + dataClient + .readRowsCallable(RESULT_ADAPTER) + .call( + paginator.getNextQuery(), + new ResponseObserver() { + private StreamController controller; + List results = new ArrayList(); + + @Override + public void onStart(StreamController controller) { + this.controller = controller; + } + + @Override + public void onResponse(Result result) { + // calculate size of the response + currentByteSize += Result.getTotalSizeOfCells(result); + results.add(result); + if (result != null && result.rawCells() != null) { + lastSeenRowKey = RESULT_ADAPTER.getKey(result); + } + + if (currentByteSize > maxSegmentByteSize) { + controller.cancel(); + return; + } + } + + @Override + public void onError(Throwable t) { + if (currentByteSize > maxSegmentByteSize) { + onComplete(); + } else { + resultsFuture.setException(t); + } + } + + @Override + public void onComplete() { + resultsFuture.set(results); + } + }, + this.scanCallContext); + return resultsFuture; + } + + private void consumeReadRowsFuture() { + try { + List results = this.future.get(); + this.buffer.addAll(results); + this.hasMore = this.paginator.advance(this.lastSeenRowKey); + this.future = null; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (ExecutionException e) { + // Do nothing. + } + } + } + /** wraps {@link ServerStream} onto HBase {@link ResultScanner}. */ private static class RowResultScanner extends AbstractClientScanner { @@ -264,7 +422,7 @@ public void close() { } public boolean renewLease() { - throw new UnsupportedOperationException("renewLease"); + return true; } } } diff --git a/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/veneer/SharedDataClientWrapper.java b/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/veneer/SharedDataClientWrapper.java index 0a48767b03..df8b168006 100644 --- a/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/veneer/SharedDataClientWrapper.java +++ b/bigtable-client-core-parent/bigtable-hbase/src/main/java/com/google/cloud/bigtable/hbase/wrappers/veneer/SharedDataClientWrapper.java @@ -108,4 +108,9 @@ public void close() throws IOException { delegate.close(); owner.release(key); } + + @Override + public ResultScanner readRows(Query.QueryPaginator paginator, long maxSegmentByteSize) { + return delegate.readRows(paginator, maxSegmentByteSize); + } } diff --git a/bigtable-client-core-parent/bigtable-hbase/src/test/java/com/google/cloud/bigtable/hbase/wrappers/veneer/TestDataClientVeneerApi.java b/bigtable-client-core-parent/bigtable-hbase/src/test/java/com/google/cloud/bigtable/hbase/wrappers/veneer/TestDataClientVeneerApi.java index fa47ed9325..02260e1d13 100644 --- a/bigtable-client-core-parent/bigtable-hbase/src/test/java/com/google/cloud/bigtable/hbase/wrappers/veneer/TestDataClientVeneerApi.java +++ b/bigtable-client-core-parent/bigtable-hbase/src/test/java/com/google/cloud/bigtable/hbase/wrappers/veneer/TestDataClientVeneerApi.java @@ -17,6 +17,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; @@ -27,15 +28,19 @@ import com.google.api.core.ApiFutures; import com.google.api.gax.batching.Batcher; import com.google.api.gax.grpc.GrpcCallContext; +import com.google.api.gax.grpc.GrpcStatusCode; +import com.google.api.gax.rpc.InternalException; import com.google.api.gax.rpc.ResponseObserver; import com.google.api.gax.rpc.ServerStream; import com.google.api.gax.rpc.ServerStreamingCallable; +import com.google.api.gax.rpc.StreamController; import com.google.api.gax.rpc.UnaryCallable; import com.google.cloud.bigtable.data.v2.BigtableDataClient; import com.google.cloud.bigtable.data.v2.models.ConditionalRowMutation; import com.google.cloud.bigtable.data.v2.models.Filters.Filter; import com.google.cloud.bigtable.data.v2.models.KeyOffset; import com.google.cloud.bigtable.data.v2.models.Query; +import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange; import com.google.cloud.bigtable.data.v2.models.ReadModifyWriteRow; import com.google.cloud.bigtable.data.v2.models.Row; import com.google.cloud.bigtable.data.v2.models.RowCell; @@ -45,11 +50,14 @@ import com.google.cloud.bigtable.hbase.wrappers.BulkReadWrapper; import com.google.cloud.bigtable.hbase.wrappers.veneer.BigtableHBaseVeneerSettings.ClientOperationTimeouts; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import com.google.protobuf.ByteString; +import io.grpc.Status.Code; import io.grpc.stub.StreamObserver; import java.io.IOException; import java.util.Iterator; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.hadoop.hbase.Cell; import org.apache.hadoop.hbase.client.Result; import org.apache.hadoop.hbase.client.ResultScanner; @@ -74,6 +82,8 @@ public class TestDataClientVeneerApi { private static final String TABLE_ID = "fake-table"; private static final ByteString ROW_KEY = ByteString.copyFromUtf8("row-key"); + private AtomicBoolean cancelled = new AtomicBoolean(false); + private static final Row MODEL_ROW = Row.create( ROW_KEY, @@ -194,6 +204,41 @@ public void testReadRowAsync() throws Exception { .futureCall(Mockito.eq(expectedRequest), Mockito.any(GrpcCallContext.class)); } + @Test + public void testReadRows_Errors() throws IOException { + Query query = Query.create(TABLE_ID).rowKey(ROW_KEY); + when(mockDataClient.readRowsCallable(Mockito.any(RowResultAdapter.class))) + .thenReturn(mockStreamingCallable); + when(mockStreamingCallable.call(Mockito.any(Query.class), Mockito.any(GrpcCallContext.class))) + .thenReturn(serverStream); + when(serverStream.iterator()) + .thenReturn( + new Iterator() { + @Override + public boolean hasNext() { + return true; + } + + @Override + public Result next() { + throw new InternalException( + "fake error", null, GrpcStatusCode.of(Code.INTERNAL), false); + } + }) + .thenReturn(ImmutableList.of().iterator()); + + assertThrows(Exception.class, () -> dataClientWrapper.readRows(query).next()); + + ResultScanner noRowsResultScanner = dataClientWrapper.readRows(query); + assertNull(noRowsResultScanner.next()); + noRowsResultScanner.close(); + + verify(mockDataClient, times(2)).readRowsCallable(Mockito.any()); + verify(serverStream, times(2)).iterator(); + verify(mockStreamingCallable, times(2)) + .call(Mockito.any(Query.class), Mockito.any(GrpcCallContext.class)); + } + @Test public void testReadRows() throws IOException { Query query = Query.create(TABLE_ID).rowKey(ROW_KEY); @@ -225,9 +270,156 @@ public void testReadRows() throws IOException { .call(Mockito.eq(query), Mockito.any(GrpcCallContext.class)); } + private static Result createRow(String key) { + return Result.create( + ImmutableList.of( + new com.google.cloud.bigtable.hbase.adapters.read.RowCell( + Bytes.toBytes(key), + Bytes.toBytes("cf"), + Bytes.toBytes("q"), + 10L, + Bytes.toBytes("value"), + ImmutableList.of("label")))); + } + @Test - public void testReadRowsCancel() throws IOException { + public void testReadPaginatedRows() throws IOException { + Query query = Query.create(TABLE_ID).range("a", "z"); + when(mockDataClient.readRowsCallable(Mockito.any())) + .thenReturn(mockStreamingCallable); + + // First Page + doAnswer( + (args) -> { + ResponseObserver observer = args.getArgument(1); + observer.onResponse(createRow("a")); + observer.onResponse(createRow("b")); + observer.onComplete(); + return null; + }) + .when(mockStreamingCallable) + .call( + Mockito.eq(Query.create(TABLE_ID).range("a", "z").limit(2)), + Mockito.any(), + Mockito.any()); + // 2nd Page + doAnswer( + (args) -> { + ResponseObserver observer = args.getArgument(1); + observer.onResponse(createRow("c")); + observer.onResponse(createRow("d")); + observer.onComplete(); + return null; + }) + .when(mockStreamingCallable) + .call( + Mockito.eq( + Query.create(TABLE_ID) + .range(ByteStringRange.unbounded().startOpen("b").endOpen("z")) + .limit(2)), + Mockito.any(), + Mockito.any()); + + // 3rd Page + doAnswer( + (args) -> { + ResponseObserver observer = args.getArgument(1); + observer.onResponse(createRow("e")); + observer.onComplete(); + return null; + }) + .when(mockStreamingCallable) + .call( + Mockito.eq( + Query.create(TABLE_ID) + .range(ByteStringRange.unbounded().startOpen("d").endOpen("z")) + .limit(2)), + Mockito.any(), + Mockito.any()); + + // 3rd Page + doAnswer( + (args) -> { + ResponseObserver observer = args.getArgument(1); + observer.onComplete(); + return null; + }) + .when(mockStreamingCallable) + .call( + Mockito.eq( + Query.create(TABLE_ID) + .range(ByteStringRange.unbounded().startOpen("e").endOpen("z")) + .limit(2)), + Mockito.any(), + Mockito.any()); + + ResultScanner resultScanner = dataClientWrapper.readRows(query.createPaginator(2), 1000); + + assertResult(createRow("a"), resultScanner.next()); + assertResult(createRow("b"), resultScanner.next()); + assertResult(createRow("c"), resultScanner.next()); + assertResult(createRow("d"), resultScanner.next()); + assertResult(createRow("e"), resultScanner.next()); + assertNull(resultScanner.next()); + } + + @Test + public void testReadRowsLowMemory() throws IOException { + Query query = Query.create(TABLE_ID); + when(mockDataClient.readRowsCallable(Mockito.any(RowResultAdapter.class))) + .thenReturn(mockStreamingCallable); + + StreamController mockController = Mockito.mock(StreamController.class); + doAnswer( + invocation -> { + cancelled.set(true); + return null; + }) + .when(mockController) + .cancel(); + + // Generate + doAnswer( + (Answer) + invocation -> { + ResponseObserver observer = invocation.getArgument(1); + observer.onStart(mockController); + + for (int i = 0; i < 1000 && !cancelled.get(); i++) { + observer.onResponse(createRow(String.format("row%010d", i))); + Thread.sleep(10); + } + observer.onComplete(); + return null; + }) + .doAnswer( + (Answer) + invocation -> { + ResponseObserver observer = invocation.getArgument(1); + observer.onComplete(); + return null; + }) + .when(mockStreamingCallable) + .call( + Mockito.any(Query.class), + Mockito.any(ResponseObserver.class), + Mockito.any(GrpcCallContext.class)); + + ResultScanner resultScanner = dataClientWrapper.readRows(query.createPaginator(100), 3); + // Consume the stream + Lists.newArrayList(resultScanner); + + verify(mockStreamingCallable, times(2)) + .call( + Mockito.any(Query.class), + Mockito.any(ResponseObserver.class), + Mockito.any(GrpcCallContext.class)); + assertTrue(cancelled.get()); + } + + @Test + public void testReadRowsCancel() throws IOException { Query query = Query.create(TABLE_ID).rowKey(ROW_KEY); when(mockDataClient.readRowsCallable(Mockito.any())) .thenReturn(mockStreamingCallable) @@ -247,7 +439,7 @@ public void testReadRowsCancel() throws IOException { doNothing().when(serverStream).cancel(); resultScanner.close(); - // make sure that the scanner doesn't iteract with the iterator on close + // make sure that the scanner doesn't interact with the iterator on close verify(serverStream).cancel(); verify(mockIter, times(1)).hasNext(); verify(mockIter, times(1)).next();