diff --git a/google/cloud/bigtable/batcher.py b/google/cloud/bigtable/batcher.py index 6b06ec060..a6eb806e9 100644 --- a/google/cloud/bigtable/batcher.py +++ b/google/cloud/bigtable/batcher.py @@ -192,6 +192,11 @@ class MutationsBatcher(object): :type flush_interval: float :param flush_interval: (Optional) The interval (in seconds) between asynchronous flush. Default is 1 second. + + :type batch_completed_callback: Callable[list:[`~google.rpc.status_pb2.Status`]] = None + :param batch_completed_callback: (Optional) A callable for handling responses + after the current batch is sent. The callable function expect a list of grpc + Status. """ def __init__( @@ -200,6 +205,7 @@ def __init__( flush_count=FLUSH_COUNT, max_row_bytes=MAX_MUTATION_SIZE, flush_interval=1, + batch_completed_callback=None, ): self._rows = _MutationsBatchQueue( max_mutation_bytes=max_row_bytes, flush_count=flush_count @@ -215,6 +221,7 @@ def __init__( ) self.futures_mapping = {} self.exceptions = queue.Queue() + self._user_batch_completed_callback = batch_completed_callback @property def flush_count(self): @@ -337,7 +344,8 @@ def _flush_async(self): batch_info = _BatchInfo() def _batch_completed_callback(self, future): - """Callback for when the mutation has finished. + """Callback for when the mutation has finished to clean up the current batch + and release items from the flow controller. Raise exceptions if there's any. Release the resources locked by the flow control and allow enqueued tasks to be run. @@ -357,6 +365,9 @@ def _flush_rows(self, rows_to_flush): if len(rows_to_flush) > 0: response = self.table.mutate_rows(rows_to_flush) + if self._user_batch_completed_callback: + self._user_batch_completed_callback(response) + for result in response: if result.code != 0: exc = from_grpc_status(result.code, result.message) diff --git a/tests/unit/test_batcher.py b/tests/unit/test_batcher.py index a238b2852..998748141 100644 --- a/tests/unit/test_batcher.py +++ b/tests/unit/test_batcher.py @@ -35,6 +35,27 @@ def test_mutation_batcher_constructor(): assert table is mutation_batcher.table +def test_mutation_batcher_w_user_callback(): + table = _Table(TABLE_NAME) + + def callback_fn(response): + callback_fn.count = len(response) + + with MutationsBatcher( + table, flush_count=1, batch_completed_callback=callback_fn + ) as mutation_batcher: + rows = [ + DirectRow(row_key=b"row_key"), + DirectRow(row_key=b"row_key_2"), + DirectRow(row_key=b"row_key_3"), + DirectRow(row_key=b"row_key_4"), + ] + + mutation_batcher.mutate_rows(rows) + + assert callback_fn.count == 4 + + def test_mutation_batcher_mutate_row(): table = _Table(TABLE_NAME) with MutationsBatcher(table=table) as mutation_batcher: