diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 5bf759151..5971a9894 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -14,7 +14,7 @@ # from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Sequence, TYPE_CHECKING import asyncio from dataclasses import dataclass import functools @@ -66,6 +66,7 @@ def __init__( mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, + retryable_exceptions: Sequence[type[Exception]] = (), ): """ Args: @@ -96,8 +97,7 @@ def __init__( # create predicate for determining which errors are retryable self.is_retryable = retries.if_exception_type( # RPC level errors - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, + *retryable_exceptions, # Entry level errors bt_exceptions._MutateRowsIncomplete, ) diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 90cc7e87c..ad1f7b84d 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -15,7 +15,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable +from typing import ( + TYPE_CHECKING, + AsyncGenerator, + AsyncIterable, + Awaitable, + Sequence, +) from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB @@ -74,6 +80,7 @@ def __init__( table: "TableAsync", operation_timeout: float, attempt_timeout: float, + retryable_exceptions: Sequence[type[Exception]] = (), ): self.attempt_timeout_gen = _attempt_timeout_generator( attempt_timeout, operation_timeout @@ -88,11 +95,7 @@ def __init__( else: self.request = query._to_pb(table) self.table = table - self._predicate = retries.if_exception_type( - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - core_exceptions.Aborted, - ) + self._predicate = retries.if_exception_type(*retryable_exceptions) self._metadata = _make_metadata( table.table_name, table.app_profile_id, diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index ab8cc48f8..a79ead7f8 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -21,6 +21,7 @@ AsyncIterable, Optional, Set, + Sequence, TYPE_CHECKING, ) @@ -45,7 +46,9 @@ from google.api_core.exceptions import GoogleAPICallError from google.cloud.environment_vars import BIGTABLE_EMULATOR # type: ignore from google.api_core import retry_async as retries -from google.api_core import exceptions as core_exceptions +from google.api_core.exceptions import DeadlineExceeded +from google.api_core.exceptions import ServiceUnavailable +from google.api_core.exceptions import Aborted from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync import google.auth.credentials @@ -64,6 +67,7 @@ from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _convert_retry_deadline from google.cloud.bigtable.data._helpers import _validate_timeouts +from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync @@ -366,21 +370,10 @@ async def _remove_instance_registration( except KeyError: return False - def get_table( - self, - instance_id: str, - table_id: str, - app_profile_id: str | None = None, - *, - default_read_rows_operation_timeout: float = 600, - default_read_rows_attempt_timeout: float | None = None, - default_mutate_rows_operation_timeout: float = 600, - default_mutate_rows_attempt_timeout: float | None = None, - default_operation_timeout: float = 60, - default_attempt_timeout: float | None = None, - ) -> TableAsync: + def get_table(self, instance_id: str, table_id: str, *args, **kwargs) -> TableAsync: """ - Returns a table instance for making data API requests + Returns a table instance for making data API requests. All arguments are passed + directly to the TableAsync constructor. Args: instance_id: The Bigtable instance ID to associate with this client. @@ -402,15 +395,17 @@ def get_table( seconds. If not set, defaults to 60 seconds default_attempt_timeout: The default timeout for all other individual rpc requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) """ - return TableAsync( - self, - instance_id, - table_id, - app_profile_id, - default_operation_timeout=default_operation_timeout, - default_attempt_timeout=default_attempt_timeout, - ) + return TableAsync(self, instance_id, table_id, *args, **kwargs) async def __aenter__(self): self._start_background_channel_refresh() @@ -442,6 +437,19 @@ def __init__( default_mutate_rows_attempt_timeout: float | None = 60, default_operation_timeout: float = 60, default_attempt_timeout: float | None = 20, + default_read_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + Aborted, + ), + default_mutate_rows_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), + default_retryable_errors: Sequence[type[Exception]] = ( + DeadlineExceeded, + ServiceUnavailable, + ), ): """ Initialize a Table instance @@ -468,9 +476,20 @@ def __init__( seconds. If not set, defaults to 60 seconds default_attempt_timeout: The default timeout for all other individual rpc requests, in seconds. If not set, defaults to 20 seconds + default_read_rows_retryable_errors: a list of errors that will be retried + if encountered during read_rows and related operations. + Defaults to 4 (DeadlineExceeded), 14 (ServiceUnavailable), and 10 (Aborted) + default_mutate_rows_retryable_errors: a list of errors that will be retried + if encountered during mutate_rows and related operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) + default_retryable_errors: a list of errors that will be retried if + encountered during all other operations. + Defaults to 4 (DeadlineExceeded) and 14 (ServiceUnavailable) Raises: - RuntimeError if called outside of an async context (no running event loop) """ + # NOTE: any changes to the signature of this method should also be reflected + # in client.get_table() # validate timeouts _validate_timeouts( default_operation_timeout, default_attempt_timeout, allow_none=True @@ -506,6 +525,14 @@ def __init__( ) self.default_mutate_rows_attempt_timeout = default_mutate_rows_attempt_timeout + self.default_read_rows_retryable_errors = ( + default_read_rows_retryable_errors or () + ) + self.default_mutate_rows_retryable_errors = ( + default_mutate_rows_retryable_errors or () + ) + self.default_retryable_errors = default_retryable_errors or () + # raises RuntimeError if called outside of an async context (no running event loop) try: self._register_instance_task = asyncio.create_task( @@ -522,12 +549,15 @@ async def read_rows_stream( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> AsyncIterable[Row]: """ Read a set of rows from the table, based on the specified query. Returns an iterator to asynchronously stream back row data. - Failed requests within operation_timeout will be retried. + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. Args: - query: contains details about which rows to return @@ -539,6 +569,8 @@ async def read_rows_stream( a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_read_rows_attempt_timeout. If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors Returns: - an asynchronous iterator that yields rows returned by the query Raises: @@ -551,12 +583,14 @@ async def read_rows_stream( operation_timeout, attempt_timeout = _get_timeouts( operation_timeout, attempt_timeout, self ) + retryable_excs = _get_retryable_errors(retryable_errors, self) row_merger = _ReadRowsOperationAsync( query, self, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, + retryable_exceptions=retryable_excs, ) return row_merger.start_operation() @@ -566,13 +600,16 @@ async def read_rows( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: """ Read a set of rows from the table, based on the specified query. Retruns results as a list of Row objects when the request is complete. For streamed results, use read_rows_stream. - Failed requests within operation_timeout will be retried. + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. Args: - query: contains details about which rows to return @@ -584,6 +621,10 @@ async def read_rows( a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_read_rows_attempt_timeout. If None, defaults to operation_timeout. + If None, defaults to the Table's default_read_rows_attempt_timeout, + or the operation_timeout if that is also None. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. Returns: - a list of Rows returned by the query Raises: @@ -596,6 +637,7 @@ async def read_rows( query, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, ) return [row async for row in row_generator] @@ -606,11 +648,14 @@ async def read_row( row_filter: RowFilter | None = None, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> Row | None: """ Read a single row from the table, based on the specified key. - Failed requests within operation_timeout will be retried. + Failed requests within operation_timeout will be retried based on the + retryable_errors list until operation_timeout is reached. Args: - query: contains details about which rows to return @@ -622,6 +667,8 @@ async def read_row( a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_read_rows_attempt_timeout. If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. Returns: - a Row object if the row exists, otherwise None Raises: @@ -637,6 +684,7 @@ async def read_row( query, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, ) if len(results) == 0: return None @@ -648,6 +696,8 @@ async def read_rows_sharded( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> list[Row]: """ Runs a sharded query in parallel, then return the results in a single list. @@ -672,6 +722,8 @@ async def read_rows_sharded( a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_read_rows_attempt_timeout. If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. Raises: - ShardedReadRowsExceptionGroup: if any of the queries failed - ValueError: if the query_list is empty @@ -701,6 +753,7 @@ async def read_rows_sharded( query, operation_timeout=batch_operation_timeout, attempt_timeout=min(attempt_timeout, batch_operation_timeout), + retryable_errors=retryable_errors, ) for query in batch ] @@ -729,10 +782,13 @@ async def row_exists( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, ) -> bool: """ Return a boolean indicating whether the specified row exists in the table. uses the filters: chain(limit cells per row = 1, strip value) + Args: - row_key: the key of the row to check - operation_timeout: the time budget for the entire operation, in seconds. @@ -743,6 +799,8 @@ async def row_exists( a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_read_rows_attempt_timeout. If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_read_rows_retryable_errors. Returns: - a bool indicating whether the row exists Raises: @@ -762,6 +820,7 @@ async def row_exists( query, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, + retryable_errors=retryable_errors, ) return len(results) > 0 @@ -770,6 +829,8 @@ async def sample_row_keys( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ) -> RowKeySamples: """ Return a set of RowKeySamples that delimit contiguous sections of the table of @@ -791,6 +852,8 @@ async def sample_row_keys( a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_attempt_timeout. If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_retryable_errors. Returns: - a set of RowKeySamples the delimit contiguous sections of the table Raises: @@ -807,10 +870,8 @@ async def sample_row_keys( attempt_timeout, operation_timeout ) # prepare retryable - predicate = retries.if_exception_type( - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, - ) + retryable_excs = _get_retryable_errors(retryable_errors, self) + predicate = retries.if_exception_type(*retryable_excs) transient_errors = [] def on_error_fn(exc): @@ -856,6 +917,8 @@ def mutations_batcher( flow_control_max_bytes: int = 100 * _MB_SIZE, batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ) -> MutationsBatcherAsync: """ Returns a new mutations batcher instance. @@ -876,6 +939,8 @@ def mutations_batcher( - batch_attempt_timeout: timeout for each individual request, in seconds. Defaults to the Table's default_mutate_rows_attempt_timeout. If None, defaults to batch_operation_timeout. + - batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. Returns: - a MutationsBatcherAsync context manager that can batch requests """ @@ -888,6 +953,7 @@ def mutations_batcher( flow_control_max_bytes=flow_control_max_bytes, batch_operation_timeout=batch_operation_timeout, batch_attempt_timeout=batch_attempt_timeout, + batch_retryable_errors=batch_retryable_errors, ) async def mutate_row( @@ -897,6 +963,8 @@ async def mutate_row( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.DEFAULT, ): """ Mutates a row atomically. @@ -918,6 +986,9 @@ async def mutate_row( a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_attempt_timeout. If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Only idempotent mutations will be retried. Defaults to the Table's + default_retryable_errors. Raises: - DeadlineExceeded: raised after operation timeout will be chained with a RetryExceptionGroup containing all @@ -937,8 +1008,7 @@ async def mutate_row( if all(mutation.is_idempotent() for mutation in mutations_list): # mutations are all idempotent and safe to retry predicate = retries.if_exception_type( - core_exceptions.DeadlineExceeded, - core_exceptions.ServiceUnavailable, + *_get_retryable_errors(retryable_errors, self) ) else: # mutations should not be retried @@ -982,6 +1052,8 @@ async def bulk_mutate_rows( *, operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ): """ Applies mutations for multiple rows in a single batched request. @@ -1007,6 +1079,8 @@ async def bulk_mutate_rows( a DeadlineExceeded exception, and a retry will be attempted. Defaults to the Table's default_mutate_rows_attempt_timeout. If None, defaults to operation_timeout. + - retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors Raises: - MutationsExceptionGroup if one or more mutations fails Contains details about any failed entries in .exceptions @@ -1015,6 +1089,7 @@ async def bulk_mutate_rows( operation_timeout, attempt_timeout = _get_timeouts( operation_timeout, attempt_timeout, self ) + retryable_excs = _get_retryable_errors(retryable_errors, self) operation = _MutateRowsOperationAsync( self.client._gapic_client, @@ -1022,6 +1097,7 @@ async def bulk_mutate_rows( mutation_entries, operation_timeout, attempt_timeout, + retryable_exceptions=retryable_excs, ) await operation.start() diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 91d2b11e1..b2da30040 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -14,7 +14,7 @@ # from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING import asyncio import atexit import warnings @@ -23,6 +23,7 @@ from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import TABLE_DEFAULT @@ -192,6 +193,8 @@ def __init__( flow_control_max_bytes: int = 100 * _MB_SIZE, batch_operation_timeout: float | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, batch_attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, + batch_retryable_errors: Sequence[type[Exception]] + | TABLE_DEFAULT = TABLE_DEFAULT.MUTATE_ROWS, ): """ Args: @@ -208,10 +211,16 @@ def __init__( - batch_attempt_timeout: timeout for each individual request, in seconds. If TABLE_DEFAULT, defaults to the Table's default_mutate_rows_attempt_timeout. If None, defaults to batch_operation_timeout. + - batch_retryable_errors: a list of errors that will be retried if encountered. + Defaults to the Table's default_mutate_rows_retryable_errors. """ self._operation_timeout, self._attempt_timeout = _get_timeouts( batch_operation_timeout, batch_attempt_timeout, table ) + self._retryable_errors: list[type[Exception]] = _get_retryable_errors( + batch_retryable_errors, table + ) + self.closed: bool = False self._table = table self._staged_entries: list[RowMutationEntry] = [] @@ -349,6 +358,7 @@ async def _execute_mutate_rows( batch, operation_timeout=self._operation_timeout, attempt_timeout=self._attempt_timeout, + retryable_exceptions=self._retryable_errors, ) await operation.start() except MutationsExceptionGroup as e: diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 1d56926ff..96ea1d1ce 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -11,9 +11,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # +""" +Helper functions used in various places in the library. +""" from __future__ import annotations -from typing import Callable, List, Tuple, Any +from typing import Callable, Sequence, List, Tuple, Any, TYPE_CHECKING import time import enum from collections import namedtuple @@ -22,6 +25,10 @@ from google.api_core import exceptions as core_exceptions from google.cloud.bigtable.data.exceptions import RetryExceptionGroup +if TYPE_CHECKING: + import grpc + from google.cloud.bigtable.data import TableAsync + """ Helper functions used in various places in the library. """ @@ -142,7 +149,9 @@ def wrapper(*args, **kwargs): def _get_timeouts( - operation: float | TABLE_DEFAULT, attempt: float | None | TABLE_DEFAULT, table + operation: float | TABLE_DEFAULT, + attempt: float | None | TABLE_DEFAULT, + table: "TableAsync", ) -> tuple[float, float]: """ Convert passed in timeout values to floats, using table defaults if necessary. @@ -209,3 +218,21 @@ def _validate_timeouts( elif attempt_timeout is not None: if attempt_timeout <= 0: raise ValueError("attempt_timeout must be greater than 0") + + +def _get_retryable_errors( + call_codes: Sequence["grpc.StatusCode" | int | type[Exception]] | TABLE_DEFAULT, + table: "TableAsync", +) -> list[type[Exception]]: + # load table defaults if necessary + if call_codes == TABLE_DEFAULT.DEFAULT: + call_codes = table.default_retryable_errors + elif call_codes == TABLE_DEFAULT.READ_ROWS: + call_codes = table.default_read_rows_retryable_errors + elif call_codes == TABLE_DEFAULT.MUTATE_ROWS: + call_codes = table.default_mutate_rows_retryable_errors + + return [ + e if isinstance(e, type) else type(core_exceptions.from_grpc_status(e, "")) + for e in call_codes + ] diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index 89a153af2..d41929518 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++ b/tests/unit/data/_async/test__mutate_rows.py @@ -46,9 +46,10 @@ def _make_one(self, *args, **kwargs): if not args: kwargs["gapic_client"] = kwargs.pop("gapic_client", mock.Mock()) kwargs["table"] = kwargs.pop("table", AsyncMock()) - kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) + kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) + kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) return self._target_class()(*args, **kwargs) async def _mock_stream(self, mutation_list, error_dict): @@ -78,15 +79,21 @@ def test_ctor(self): from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete from google.api_core.exceptions import DeadlineExceeded - from google.api_core.exceptions import ServiceUnavailable + from google.api_core.exceptions import Aborted client = mock.Mock() table = mock.Mock() entries = [_make_mutation(), _make_mutation()] operation_timeout = 0.05 attempt_timeout = 0.01 + retryable_exceptions = () instance = self._make_one( - client, table, entries, operation_timeout, attempt_timeout + client, + table, + entries, + operation_timeout, + attempt_timeout, + retryable_exceptions, ) # running gapic_fn should trigger a client call assert client.mutate_rows.call_count == 0 @@ -110,8 +117,8 @@ def test_ctor(self): assert next(instance.timeout_generator) == attempt_timeout # ensure predicate is set assert instance.is_retryable is not None - assert instance.is_retryable(DeadlineExceeded("")) is True - assert instance.is_retryable(ServiceUnavailable("")) is True + assert instance.is_retryable(DeadlineExceeded("")) is False + assert instance.is_retryable(Aborted("")) is False assert instance.is_retryable(_MutateRowsIncomplete("")) is True assert instance.is_retryable(RuntimeError("")) is False assert instance.remaining_indices == list(range(len(entries))) @@ -232,7 +239,7 @@ async def test_mutate_rows_exception(self, exc_type): @pytest.mark.parametrize( "exc_type", - [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + [core_exceptions.DeadlineExceeded, RuntimeError], ) @pytest.mark.asyncio async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): @@ -256,7 +263,12 @@ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): ) as attempt_mock: attempt_mock.side_effect = [expected_cause] * num_retries + [None] instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + client, + table, + entries, + operation_timeout, + operation_timeout, + retryable_exceptions=(exc_type,), ) await instance.start() assert attempt_mock.call_count == num_retries + 1 diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index 7718246fc..54bbb6158 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -26,6 +26,8 @@ from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery from google.api_core import exceptions as core_exceptions from google.cloud.bigtable.data.exceptions import InvalidChunk +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data import TABLE_DEFAULT from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule from google.cloud.bigtable.data.read_modify_write_rules import AppendValueRule @@ -841,6 +843,39 @@ async def test_get_table(self): assert client._instance_owners[instance_key] == {id(table)} await client.close() + @pytest.mark.asyncio + async def test_get_table_arg_passthrough(self): + """ + All arguments passed in get_table should be sent to constructor + """ + async with self._make_one(project="project-id") as client: + with mock.patch( + "google.cloud.bigtable.data._async.client.TableAsync.__init__", + ) as mock_constructor: + mock_constructor.return_value = None + assert not client._active_instances + expected_table_id = "table-id" + expected_instance_id = "instance-id" + expected_app_profile_id = "app-profile-id" + expected_args = (1, "test", {"test": 2}) + expected_kwargs = {"hello": "world", "test": 2} + + client.get_table( + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + mock_constructor.assert_called_once_with( + client, + expected_instance_id, + expected_table_id, + expected_app_profile_id, + *expected_args, + **expected_kwargs, + ) + @pytest.mark.asyncio async def test_get_table_context_manager(self): from google.cloud.bigtable.data._async.client import TableAsync @@ -1099,6 +1134,173 @@ def test_table_ctor_sync(self): TableAsync(client, "instance-id", "table-id") assert e.match("TableAsync must be created within an async event loop context.") + @pytest.mark.asyncio + # iterate over all retryable rpcs + @pytest.mark.parametrize( + "fn_name,fn_args,retry_fn_path,extra_retryables", + [ + ( + "read_rows_stream", + (ReadRowsQuery(),), + "google.cloud.bigtable.data._async._read_rows.retry_target_stream", + (), + ), + ( + "read_rows", + (ReadRowsQuery(),), + "google.cloud.bigtable.data._async._read_rows.retry_target_stream", + (), + ), + ( + "read_row", + (b"row_key",), + "google.cloud.bigtable.data._async._read_rows.retry_target_stream", + (), + ), + ( + "read_rows_sharded", + ([ReadRowsQuery()],), + "google.cloud.bigtable.data._async._read_rows.retry_target_stream", + (), + ), + ( + "row_exists", + (b"row_key",), + "google.cloud.bigtable.data._async._read_rows.retry_target_stream", + (), + ), + ("sample_row_keys", (), "google.api_core.retry_async.retry_target", ()), + ( + "mutate_row", + (b"row_key", []), + "google.api_core.retry_async.retry_target", + (), + ), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mock.Mock()])],), + "google.api_core.retry_async.retry_target", + (_MutateRowsIncomplete,), + ), + ], + ) + # test different inputs for retryable exceptions + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + async def test_customizable_retryable_errors( + self, + input_retryables, + expected_retryables, + fn_name, + fn_args, + retry_fn_path, + extra_retryables, + ): + """ + Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer. + """ + from google.cloud.bigtable.data import BigtableDataClientAsync + + with mock.patch( + "google.api_core.retry_async.if_exception_type" + ) as predicate_builder_mock: + with mock.patch(retry_fn_path) as retry_fn_mock: + async with BigtableDataClientAsync() as client: + table = client.get_table("instance-id", "table-id") + expected_predicate = lambda a: a in expected_retryables # noqa + predicate_builder_mock.return_value = expected_predicate + retry_fn_mock.side_effect = RuntimeError("stop early") + with pytest.raises(Exception): + # we expect an exception from attempting to call the mock + test_fn = table.__getattribute__(fn_name) + await test_fn(*fn_args, retryable_errors=input_retryables) + # passed in errors should be used to build the predicate + predicate_builder_mock.assert_called_once_with( + *expected_retryables, *extra_retryables + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + # output of if_exception_type should be sent in to retry constructor + assert retry_call_args[1] is expected_predicate + + @pytest.mark.parametrize( + "fn_name,fn_args,gapic_fn", + [ + ("read_rows_stream", (ReadRowsQuery(),), "read_rows"), + ("read_rows", (ReadRowsQuery(),), "read_rows"), + ("read_row", (b"row_key",), "read_rows"), + ("read_rows_sharded", ([ReadRowsQuery()],), "read_rows"), + ("row_exists", (b"row_key",), "read_rows"), + ("sample_row_keys", (), "sample_row_keys"), + ("mutate_row", (b"row_key", []), "mutate_row"), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mock.Mock()])],), + "mutate_rows", + ), + ("check_and_mutate_row", (b"row_key", None), "check_and_mutate_row"), + ( + "read_modify_write_row", + (b"row_key", mock.Mock()), + "read_modify_write_row", + ), + ], + ) + @pytest.mark.parametrize("include_app_profile", [True, False]) + @pytest.mark.asyncio + async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn): + """check that all requests attach proper metadata headers""" + from google.cloud.bigtable.data import TableAsync + from google.cloud.bigtable.data import BigtableDataClientAsync + + profile = "profile" if include_app_profile else None + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}", mock.AsyncMock() + ) as gapic_mock: + gapic_mock.side_effect = RuntimeError("stop early") + async with BigtableDataClientAsync() as client: + table = TableAsync(client, "instance-id", "table-id", profile) + try: + test_fn = table.__getattribute__(fn_name) + maybe_stream = await test_fn(*fn_args) + [i async for i in maybe_stream] + except Exception: + # we expect an exception from attempting to call the mock + pass + kwargs = gapic_mock.call_args_list[0].kwargs + metadata = kwargs["metadata"] + goog_metadata = None + for key, value in metadata: + if key == "x-goog-request-params": + goog_metadata = value + assert goog_metadata is not None, "x-goog-request-params not found" + assert "table_name=" + table.table_name in goog_metadata + if include_app_profile: + assert "app_profile_id=profile" in goog_metadata + else: + assert "app_profile_id=" not in goog_metadata + class TestReadRows: """ @@ -1608,28 +1810,6 @@ async def test_row_exists(self, return_value, expected_result): assert query.limit == 1 assert query.filter._to_dict() == expected_filter - @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio - async def test_read_rows_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - async with self._make_table(app_profile_id=profile) as table: - read_rows = table.client._gapic_client.read_rows - read_rows.return_value = self._make_gapic_stream([]) - await table.read_rows(ReadRowsQuery()) - kwargs = read_rows.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata - class TestReadRowsSharded: def _make_client(self, *args, **kwargs): @@ -1735,30 +1915,6 @@ async def mock_call(*args, **kwargs): # if run in sequence, we would expect this to take 1 second assert call_time < 0.2 - @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio - async def test_read_rows_sharded_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - async with self._make_client() as client: - async with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "read_rows", AsyncMock() - ) as read_rows: - await table.read_rows_sharded([ReadRowsQuery()]) - kwargs = read_rows.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata - @pytest.mark.asyncio async def test_read_rows_sharded_batching(self): """ @@ -1875,7 +2031,10 @@ async def test_sample_row_keys_default_timeout(self): expected_timeout = 99 async with self._make_client() as client: async with client.get_table( - "i", "t", default_operation_timeout=expected_timeout + "i", + "t", + default_operation_timeout=expected_timeout, + default_attempt_timeout=expected_timeout, ) as table: with mock.patch.object( table.client._gapic_client, "sample_row_keys", AsyncMock() @@ -1914,30 +2073,6 @@ async def test_sample_row_keys_gapic_params(self): assert kwargs["metadata"] is not None assert kwargs["retry"] is None - @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio - async def test_sample_row_keys_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - async with self._make_client() as client: - async with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "sample_row_keys", AsyncMock() - ) as read_rows: - await table.sample_row_keys() - kwargs = read_rows.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata - @pytest.mark.parametrize( "retryable_exception", [ @@ -2525,39 +2660,6 @@ async def test_bulk_mutate_error_index(self): assert isinstance(cause.exceptions[1], DeadlineExceeded) assert isinstance(cause.exceptions[2], FailedPrecondition) - @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio - async def test_bulk_mutate_row_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - async with self._make_client() as client: - async with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "mutate_rows", AsyncMock() - ) as mutate_rows: - mutate_rows.side_effect = core_exceptions.Aborted("mock") - mutation = mock.Mock() - mutation.size.return_value = 1 - entry = mock.Mock() - entry.mutations = [mutation] - try: - await table.bulk_mutate_rows([entry]) - except Exception: - # exception used to end early - pass - kwargs = mutate_rows.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata - class TestCheckAndMutateRow: def _make_client(self, *args, **kwargs): @@ -2727,30 +2829,6 @@ async def test_check_and_mutate_mutations_parsing(self): mutation._to_pb.call_count == 1 for mutation in mutations[:5] ) - @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio - async def test_check_and_mutate_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - async with self._make_client() as client: - async with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "check_and_mutate_row", AsyncMock() - ) as mock_gapic: - await table.check_and_mutate_row(b"key", mock.Mock()) - kwargs = mock_gapic.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata - class TestReadModifyWriteRow: def _make_client(self, *args, **kwargs): @@ -2882,27 +2960,3 @@ async def test_read_modify_write_row_building(self): await table.read_modify_write_row("key", mock.Mock()) assert constructor_mock.call_count == 1 constructor_mock.assert_called_once_with(mock_response.row) - - @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio - async def test_read_modify_write_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - async with self._make_client() as client: - async with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row", AsyncMock() - ) as mock_gapic: - await table.read_modify_write_row("key", mock.Mock()) - kwargs = mock_gapic.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index f95b53271..17bd8d420 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -14,6 +14,9 @@ import pytest import asyncio +import google.api_core.exceptions as core_exceptions +from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete +from google.cloud.bigtable.data import TABLE_DEFAULT # try/except added for compatibility with python < 3.8 try: @@ -286,10 +289,17 @@ def _get_target_class(self): return MutationsBatcherAsync def _make_one(self, table=None, **kwargs): + from google.api_core.exceptions import DeadlineExceeded + from google.api_core.exceptions import ServiceUnavailable + if table is None: table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 table.default_mutate_rows_attempt_timeout = 10 + table.default_mutate_rows_retryable_errors = ( + DeadlineExceeded, + ServiceUnavailable, + ) return self._get_target_class()(table, **kwargs) @@ -302,6 +312,7 @@ async def test_ctor_defaults(self, flush_timer_mock): table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = [Exception] async with self._make_one(table) as instance: assert instance._table == table assert instance.closed is False @@ -323,6 +334,9 @@ async def test_ctor_defaults(self, flush_timer_mock): assert ( instance._attempt_timeout == table.default_mutate_rows_attempt_timeout ) + assert ( + instance._retryable_errors == table.default_mutate_rows_retryable_errors + ) await asyncio.sleep(0) assert flush_timer_mock.call_count == 1 assert flush_timer_mock.call_args[0][0] == 5 @@ -343,6 +357,7 @@ async def test_ctor_explicit(self, flush_timer_mock): flow_control_max_bytes = 12 operation_timeout = 11 attempt_timeout = 2 + retryable_errors = [Exception] async with self._make_one( table, flush_interval=flush_interval, @@ -352,6 +367,7 @@ async def test_ctor_explicit(self, flush_timer_mock): flow_control_max_bytes=flow_control_max_bytes, batch_operation_timeout=operation_timeout, batch_attempt_timeout=attempt_timeout, + batch_retryable_errors=retryable_errors, ) as instance: assert instance._table == table assert instance.closed is False @@ -371,6 +387,7 @@ async def test_ctor_explicit(self, flush_timer_mock): assert instance._entries_processed_since_last_raise == 0 assert instance._operation_timeout == operation_timeout assert instance._attempt_timeout == attempt_timeout + assert instance._retryable_errors == retryable_errors await asyncio.sleep(0) assert flush_timer_mock.call_count == 1 assert flush_timer_mock.call_args[0][0] == flush_interval @@ -386,6 +403,7 @@ async def test_ctor_no_flush_limits(self, flush_timer_mock): table = mock.Mock() table.default_mutate_rows_operation_timeout = 10 table.default_mutate_rows_attempt_timeout = 8 + table.default_mutate_rows_retryable_errors = () flush_interval = None flush_limit_count = None flush_limit_bytes = None @@ -442,7 +460,7 @@ def test_default_argument_consistency(self): batcher_init_signature.pop("table") # both should have same number of arguments assert len(get_batcher_signature.keys()) == len(batcher_init_signature.keys()) - assert len(get_batcher_signature) == 7 # update if expected params change + assert len(get_batcher_signature) == 8 # update if expected params change # both should have same argument names assert set(get_batcher_signature.keys()) == set(batcher_init_signature.keys()) # both should have same default values @@ -882,6 +900,7 @@ async def test__execute_mutate_rows(self, mutate_rows): table.app_profile_id = "test-app-profile" table.default_mutate_rows_operation_timeout = 17 table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () async with self._make_one(table) as instance: batch = [_make_mutation()] result = await instance._execute_mutate_rows(batch) @@ -911,6 +930,7 @@ async def test__execute_mutate_rows_returns_errors(self, mutate_rows): table = mock.Mock() table.default_mutate_rows_operation_timeout = 17 table.default_mutate_rows_attempt_timeout = 13 + table.default_mutate_rows_retryable_errors = () async with self._make_one(table) as instance: batch = [_make_mutation()] result = await instance._execute_mutate_rows(batch) @@ -1102,3 +1122,63 @@ def test__add_exceptions(self, limit, in_e, start_e, end_e): # then, the newest slots should be filled with the last items of the input list for i in range(1, newest_list_diff + 1): assert mock_batcher._newest_exceptions[-i] == input_list[-i] + + @pytest.mark.asyncio + # test different inputs for retryable exceptions + @pytest.mark.parametrize( + "input_retryables,expected_retryables", + [ + ( + TABLE_DEFAULT.READ_ROWS, + [ + core_exceptions.DeadlineExceeded, + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + ], + ), + ( + TABLE_DEFAULT.DEFAULT, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + [core_exceptions.DeadlineExceeded, core_exceptions.ServiceUnavailable], + ), + ([], []), + ([4], [core_exceptions.DeadlineExceeded]), + ], + ) + async def test_customizable_retryable_errors( + self, input_retryables, expected_retryables + ): + """ + Test that retryable functions support user-configurable arguments, and that the configured retryables are passed + down to the gapic layer. + """ + from google.cloud.bigtable.data._async.client import TableAsync + + with mock.patch( + "google.api_core.retry_async.if_exception_type" + ) as predicate_builder_mock: + with mock.patch( + "google.api_core.retry_async.retry_target" + ) as retry_fn_mock: + table = None + with mock.patch("asyncio.create_task"): + table = TableAsync(mock.Mock(), "instance", "table") + async with self._make_one( + table, batch_retryable_errors=input_retryables + ) as instance: + assert instance._retryable_errors == expected_retryables + expected_predicate = lambda a: a in expected_retryables # noqa + predicate_builder_mock.return_value = expected_predicate + retry_fn_mock.side_effect = RuntimeError("stop early") + mutation = _make_mutation(count=1, size=1) + await instance._execute_mutate_rows([mutation]) + # passed in errors should be used to build the predicate + predicate_builder_mock.assert_called_once_with( + *expected_retryables, _MutateRowsIncomplete + ) + retry_call_args = retry_fn_mock.call_args_list[0].args + # output of if_exception_type should be sent in to retry constructor + assert retry_call_args[1] is expected_predicate diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 6c11fa86a..b9c1dc2bb 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -13,6 +13,8 @@ # import pytest +import grpc +from google.api_core import exceptions as core_exceptions import google.cloud.bigtable.data._helpers as _helpers from google.cloud.bigtable.data._helpers import TABLE_DEFAULT import google.cloud.bigtable.data.exceptions as bigtable_exceptions @@ -264,3 +266,49 @@ def test_get_timeouts_invalid(self, input_times, input_table): setattr(fake_table, f"default_{key}_timeout", input_table[key]) with pytest.raises(ValueError): _helpers._get_timeouts(input_times[0], input_times[1], fake_table) + + +class TestGetRetryableErrors: + @pytest.mark.parametrize( + "input_codes,input_table,expected", + [ + ((), {}, []), + ((Exception,), {}, [Exception]), + (TABLE_DEFAULT.DEFAULT, {"default": [Exception]}, [Exception]), + ( + TABLE_DEFAULT.READ_ROWS, + {"default_read_rows": (RuntimeError, ValueError)}, + [RuntimeError, ValueError], + ), + ( + TABLE_DEFAULT.MUTATE_ROWS, + {"default_mutate_rows": (ValueError,)}, + [ValueError], + ), + ((4,), {}, [core_exceptions.DeadlineExceeded]), + ( + [grpc.StatusCode.DEADLINE_EXCEEDED], + {}, + [core_exceptions.DeadlineExceeded], + ), + ( + (14, grpc.StatusCode.ABORTED, RuntimeError), + {}, + [ + core_exceptions.ServiceUnavailable, + core_exceptions.Aborted, + RuntimeError, + ], + ), + ], + ) + def test_get_retryable_errors(self, input_codes, input_table, expected): + """ + test input/output mappings for a variety of valid inputs + """ + fake_table = mock.Mock() + for key in input_table.keys(): + # set the default fields in our fake table mock + setattr(fake_table, f"{key}_retryable_errors", input_table[key]) + result = _helpers._get_retryable_errors(input_codes, fake_table) + assert result == expected