diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 6234c96435..de610e1387 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -86,13 +86,18 @@ def _restart_on_unavailable( ) request.transaction = transaction_selector + iterator = None - with trace_call( - trace_name, session, attributes, observability_options=observability_options - ): - iterator = method(request=request) while True: try: + if iterator is None: + with trace_call( + trace_name, + session, + attributes, + observability_options=observability_options, + ): + iterator = method(request=request) for item in iterator: item_buffer.append(item) # Setting the transaction id because the transaction begin was inlined for first rpc. diff --git a/tests/mockserver_tests/mock_server_test_base.py b/tests/mockserver_tests/mock_server_test_base.py index 12c98bc51b..b332c88d7c 100644 --- a/tests/mockserver_tests/mock_server_test_base.py +++ b/tests/mockserver_tests/mock_server_test_base.py @@ -57,6 +57,27 @@ def aborted_status() -> _Status: return status +# Creates an UNAVAILABLE status with the smallest possible retry delay. +def unavailable_status() -> _Status: + error = status_pb2.Status( + code=code_pb2.UNAVAILABLE, + message="Service unavailable.", + ) + retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1)) + status = _Status( + code=code_to_grpc_status_code(error.code), + details=error.message, + trailing_metadata=( + ("grpc-status-details-bin", error.SerializeToString()), + ( + "google.rpc.retryinfo-bin", + retry_info.SerializeToString(), + ), + ), + ) + return status + + def add_error(method: str, error: status_pb2.Status): MockServerTestBase.spanner_service.mock_spanner.add_error(method, error) diff --git a/tests/mockserver_tests/test_basics.py b/tests/mockserver_tests/test_basics.py index ed0906cb9b..d34065a6ff 100644 --- a/tests/mockserver_tests/test_basics.py +++ b/tests/mockserver_tests/test_basics.py @@ -21,11 +21,14 @@ BeginTransactionRequest, TransactionOptions, ) +from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer from tests.mockserver_tests.mock_server_test_base import ( MockServerTestBase, add_select1_result, add_update_count, + add_error, + unavailable_status, ) @@ -85,3 +88,22 @@ def test_dbapi_partitioned_dml(self): self.assertEqual( TransactionOptions(dict(partitioned_dml={})), begin_request.options ) + + def test_execute_streaming_sql_unavailable(self): + add_select1_result() + # Add an UNAVAILABLE error that is returned the first time the + # ExecuteStreamingSql RPC is called. + add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status()) + with self.database.snapshot() as snapshot: + results = snapshot.execute_sql("select 1") + result_list = [] + for row in results: + result_list.append(row) + self.assertEqual(1, row[0]) + self.assertEqual(1, len(result_list)) + requests = self.spanner_service.requests + self.assertEqual(3, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + # The ExecuteStreamingSql call should be retried. + self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))