Skip to content

Commit

Permalink
fix: correctly set resume token when restarting streams (#314)
Browse files Browse the repository at this point in the history
* fix: correctly set resume token for restarting streams

* style: fix lint

* docs: update docstring

* test: fix assertion

Co-authored-by: larkee <[email protected]>
  • Loading branch information
larkee and larkee authored Apr 26, 2021
1 parent 772aa3c commit 0fcfc23
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 48 deletions.
6 changes: 3 additions & 3 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,11 +518,11 @@ def execute_pdml():
param_types=param_types,
query_options=query_options,
)
restart = functools.partial(
api.execute_streaming_sql, request=request, metadata=metadata,
method = functools.partial(
api.execute_streaming_sql, metadata=metadata,
)

iterator = _restart_on_unavailable(restart)
iterator = _restart_on_unavailable(method, request)

result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials
Expand Down
26 changes: 19 additions & 7 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,21 @@
)


def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=None):
def _restart_on_unavailable(
method, request, trace_name=None, session=None, attributes=None
):
"""Restart iteration after :exc:`.ServiceUnavailable`.
:type restart: callable
:param restart: curried function returning iterator
:type method: callable
:param method: function returning iterator
:type request: proto
:param request: request proto to call the method with
"""
resume_token = b""
item_buffer = []
with trace_call(trace_name, session, attributes):
iterator = restart()
iterator = method(request=request)
while True:
try:
for item in iterator:
Expand All @@ -61,7 +66,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N
except ServiceUnavailable:
del item_buffer[:]
with trace_call(trace_name, session, attributes):
iterator = restart(resume_token=resume_token)
request.resume_token = resume_token
iterator = method(request=request)
continue
except InternalServerError as exc:
resumable_error = any(
Expand All @@ -72,7 +78,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N
raise
del item_buffer[:]
with trace_call(trace_name, session, attributes):
iterator = restart(resume_token=resume_token)
request.resume_token = resume_token
iterator = method(request=request)
continue

if len(item_buffer) == 0:
Expand Down Expand Up @@ -189,7 +196,11 @@ def read(

trace_attributes = {"table_id": table, "columns": columns}
iterator = _restart_on_unavailable(
restart, "CloudSpanner.ReadOnlyTransaction", self._session, trace_attributes
restart,
request,
"CloudSpanner.ReadOnlyTransaction",
self._session,
trace_attributes,
)

self._read_request_count += 1
Expand Down Expand Up @@ -302,6 +313,7 @@ def execute_sql(
trace_attributes = {"db.statement": sql}
iterator = _restart_on_unavailable(
restart,
request,
"CloudSpanner.ReadWriteTransaction",
self._session,
trace_attributes,
Expand Down
92 changes: 54 additions & 38 deletions tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@


class Test_restart_on_unavailable(OpenTelemetryBase):
def _call_fut(self, restart, span_name=None, session=None, attributes=None):
def _call_fut(
self, restart, request, span_name=None, session=None, attributes=None
):
from google.cloud.spanner_v1.snapshot import _restart_on_unavailable

return _restart_on_unavailable(restart, span_name, session, attributes)
return _restart_on_unavailable(restart, request, span_name, session, attributes)

def _make_item(self, value, resume_token=b""):
return mock.Mock(
Expand All @@ -59,18 +61,21 @@ def _make_item(self, value, resume_token=b""):

def test_iteration_w_empty_raw(self):
raw = _MockIterator()
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), [])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_non_empty_raw(self):
ITEMS = (self._make_item(0), self._make_item(1))
raw = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
restart.assert_called_once_with()
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_w_resume_tken(self):
Expand All @@ -81,10 +86,11 @@ def test_iteration_w_raw_w_resume_tken(self):
self._make_item(3),
)
raw = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
restart.assert_called_once_with()
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_raising_unavailable_no_token(self):
Expand All @@ -97,10 +103,12 @@ def test_iteration_w_raw_raising_unavailable_no_token(self):
)
before = _MockIterator(fail_after=True, error=ServiceUnavailable("testing"))
after = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")])
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, b"")
self.assertNoSpans()

def test_iteration_w_raw_raising_retryable_internal_error_no_token(self):
Expand All @@ -118,10 +126,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self):
),
)
after = _MockIterator(*ITEMS)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(ITEMS))
self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")])
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, b"")
self.assertNoSpans()

def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self):
Expand All @@ -134,11 +144,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self):
)
before = _MockIterator(fail_after=True, error=InternalServerError("testing"))
after = _MockIterator(*ITEMS)
request = mock.Mock(spec=["resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
with self.assertRaises(InternalServerError):
list(resumable)
self.assertEqual(restart.mock_calls, [mock.call()])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_raising_unavailable(self):
Expand All @@ -151,12 +162,12 @@ def test_iteration_w_raw_raising_unavailable(self):
*(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing")
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + LAST))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_retryable_internal_error(self):
Expand All @@ -173,12 +184,12 @@ def test_iteration_w_raw_raising_retryable_internal_error(self):
)
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + LAST))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_non_retryable_internal_error(self):
Expand All @@ -191,11 +202,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self):
*(FIRST + SECOND), fail_after=True, error=InternalServerError("testing")
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
with self.assertRaises(InternalServerError):
list(resumable)
self.assertEqual(restart.mock_calls, [mock.call()])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_raw_raising_unavailable_after_token(self):
Expand All @@ -207,12 +219,12 @@ def test_iteration_w_raw_raising_unavailable_after_token(self):
*FIRST, fail_after=True, error=ServiceUnavailable("testing")
)
after = _MockIterator(*SECOND)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + SECOND))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_retryable_internal_error_after_token(self):
Expand All @@ -228,12 +240,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self):
)
)
after = _MockIterator(*SECOND)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
self.assertEqual(list(resumable), list(FIRST + SECOND))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)
self.assertNoSpans()

def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self):
Expand All @@ -245,19 +257,23 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self):
*FIRST, fail_after=True, error=InternalServerError("testing")
)
after = _MockIterator(*SECOND)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
resumable = self._call_fut(restart)
resumable = self._call_fut(restart, request)
with self.assertRaises(InternalServerError):
list(resumable)
self.assertEqual(restart.mock_calls, [mock.call()])
restart.assert_called_once_with(request=request)
self.assertNoSpans()

def test_iteration_w_span_creation(self):
name = "TestSpan"
extra_atts = {"test_att": 1}
raw = _MockIterator()
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], return_value=raw)
resumable = self._call_fut(restart, name, _Session(_Database()), extra_atts)
resumable = self._call_fut(
restart, request, name, _Session(_Database()), extra_atts
)
self.assertEqual(list(resumable), [])
self.assertSpanAttributes(name, attributes=dict(BASE_ATTRIBUTES, test_att=1))

Expand All @@ -272,13 +288,13 @@ def test_iteration_w_multiple_span_creation(self):
*(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing")
)
after = _MockIterator(*LAST)
request = mock.Mock(test="test", spec=["test", "resume_token"])
restart = mock.Mock(spec=[], side_effect=[before, after])
name = "TestSpan"
resumable = self._call_fut(restart, name, _Session(_Database()))
resumable = self._call_fut(restart, request, name, _Session(_Database()))
self.assertEqual(list(resumable), list(FIRST + LAST))
self.assertEqual(
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
)
self.assertEqual(len(restart.mock_calls), 2)
self.assertEqual(request.resume_token, RESUME_TOKEN)

span_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(span_list), 2)
Expand Down

0 comments on commit 0fcfc23

Please sign in to comment.