Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in recv()'s usage of self.call. #5935

Merged
merged 3 commits into from
Sep 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 57 additions & 19 deletions pubsub/google/cloud/pubsub_v1/subscriber/_protocol/bidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,12 @@ def _on_call_done(self, future):
# Unlike the base class, we only execute the callbacks on a terminal
# error, not for errors that we can recover from. Note that grpc's
# "future" here is also a grpc.RpcError.
if not self._should_recover(future):
self._finalize(future)
else:
_LOGGER.debug('Re-opening stream from gRPC callback.')
self._reopen()
with self._operational_lock:
if not self._should_recover(future):
self._finalize(future)
else:
_LOGGER.debug('Re-opening stream from gRPC callback.')
self._reopen()

def _reopen(self):
with self._operational_lock:
Expand All @@ -361,6 +362,7 @@ def _reopen(self):
# If re-opening or re-calling the method fails for any reason,
# consider it a terminal error and finalize the stream.
except Exception as exc:
_LOGGER.debug('Failed to re-open stream due to %s', exc)
self._finalize(exc)
raise

Expand All @@ -385,23 +387,60 @@ def _recoverable(self, method, *args, **kwargs):
return method(*args, **kwargs)

except Exception as exc:
_LOGGER.debug('Call to retryable %r caused %s.', method, exc)
if not self._should_recover(exc):
self.close()
_LOGGER.debug('Not retrying %r due to %s.', method, exc)
self._finalize(exc)
raise exc
with self._operational_lock:
_LOGGER.debug(
'Call to retryable %r caused %s.', method, exc)

if not self._should_recover(exc):
self.close()
_LOGGER.debug(
'Not retrying %r due to %s.', method, exc)
self._finalize(exc)
raise exc

_LOGGER.debug(
'Re-opening stream from retryable %r.', method)
self._reopen()

def _send(self, request):
# Grab a reference to the RPC call. Because another thread (notably
# the gRPC error thread) can modify self.call (by invoking reopen),
# we should ensure our reference can not change underneath us.
# If self.call is modified (such as replaced with a new RPC call) then
# this will use the "old" RPC, which should result in the same
# exception passed into gRPC's error handler being raised here, which
# will be handled by the usual error handling in retryable.
with self._operational_lock:
call = self.call

if call is None:
raise ValueError(
'Can not send() on an RPC that has never been open()ed.')

_LOGGER.debug('Re-opening stream from retryable %r.', method)
self._reopen()
# Don't use self.is_active(), as ResumableBidiRpc will overload it
# to mean something semantically different.
if call.is_active():
self._request_queue.put(request)
pass
else:
# calling next should cause the call to raise.
next(call)

def send(self, request):
return self._recoverable(
super(ResumableBidiRpc, self).send, request)
return self._recoverable(self._send, request)

def _recv(self):
with self._operational_lock:
call = self.call

if call is None:
raise ValueError(
'Can not recv() on an RPC that has never been open()ed.')
tseaver marked this conversation as resolved.
Show resolved Hide resolved

return next(call)

def recv(self):
return self._recoverable(
super(ResumableBidiRpc, self).recv)
return self._recoverable(self._recv)

@property
def is_active(self):
Expand Down Expand Up @@ -506,8 +545,7 @@ def _thread_main(self):

else:
_LOGGER.error(
'The bidirectional RPC unexpectedly exited. This is a truly '
'exceptional case. Please file a bug with your logs.')
'The bidirectional RPC exited.')

_LOGGER.info('%s exiting', _BIDIRECTIONAL_CONSUMER_NAME)

Expand Down
44 changes: 18 additions & 26 deletions pubsub/tests/unit/pubsub_v1/subscriber/test_bidi.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,41 +373,21 @@ def test_recv_recover(self):
assert bidi_rpc.call == call_2
assert bidi_rpc.is_active is True

def test_recv_recover_race_condition(self):
# This test checks the race condition where two threads recv() and
# encounter an error and must re-open the stream. Only one thread
# should succeed in doing so.
error = ValueError()
call_1 = CallStub([error, error])
call_2 = CallStub([1, 2])
def test_recv_recover_already_recovered(self):
call_1 = CallStub([])
call_2 = CallStub([])
start_rpc = mock.create_autospec(
grpc.StreamStreamMultiCallable,
instance=True,
side_effect=[call_1, call_2])
recovered_event = threading.Event()

def second_thread_main():
assert bidi_rpc.recv() == 2

second_thread = threading.Thread(target=second_thread_main)

def should_recover(exception):
assert exception == error
if threading.current_thread() == second_thread:
recovered_event.wait()
return True

bidi_rpc = bidi.ResumableBidiRpc(start_rpc, should_recover)
bidi_rpc = bidi.ResumableBidiRpc(start_rpc, lambda _: True)

bidi_rpc.open()
second_thread.start()

assert bidi_rpc.recv() == 1
recovered_event.set()
bidi_rpc._reopen()

assert bidi_rpc.call == call_2
assert bidi_rpc.call is call_1
assert bidi_rpc.is_active is True
second_thread.join()

def test_recv_failure(self):
error = ValueError()
Expand Down Expand Up @@ -456,6 +436,18 @@ def test_reopen_failure_on_rpc_restart(self):
assert bidi_rpc.is_active is False
callback.assert_called_once_with(error2)

def test_send_not_open(self):
bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)

with pytest.raises(ValueError):
bidi_rpc.send(mock.sentinel.request)

def test_recv_not_open(self):
bidi_rpc = bidi.ResumableBidiRpc(None, lambda _: False)

with pytest.raises(ValueError):
bidi_rpc.recv()

def test_finalize_idempotent(self):
error1 = ValueError('1')
error2 = ValueError('2')
Expand Down