From d69e2facec1adbd680742a2c5fe37cda73f6290c Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 17 Sep 2019 10:14:07 -0400 Subject: [PATCH] Implement ``Future.cancel()`` This brings our implementation of ``Future`` in parity with the ``Future`` interface defined in the Python 3 standard library, and makes it possible to cancel asynchronous ``grpc`` calls from NDB. --- google/cloud/ndb/_remote.py | 38 ++++++++++++++++-- google/cloud/ndb/exceptions.py | 9 +++++ google/cloud/ndb/tasklets.py | 40 +++++++++++++++---- tests/system/test_query.py | 25 ++++++++++++ tests/unit/test__remote.py | 40 +++++++++++++++++-- tests/unit/test_tasklets.py | 70 ++++++++++++++++++++++++++++++++-- 6 files changed, 205 insertions(+), 17 deletions(-) diff --git a/google/cloud/ndb/_remote.py b/google/cloud/ndb/_remote.py index fea024a5..0b7f9083 100644 --- a/google/cloud/ndb/_remote.py +++ b/google/cloud/ndb/_remote.py @@ -16,6 +16,9 @@ # In its own module to avoid circular import between _datastore_api and # tasklets modules. +import grpc + +from google.cloud.ndb import exceptions class RemoteCall: @@ -36,18 +39,47 @@ class RemoteCall: def __init__(self, future, info): self.future = future self.info = info + self._callbacks = [] + + future.add_done_callback(self._finish) def __repr__(self): return self.info def exception(self): """Calls :meth:`grpc.Future.exception` on attr:`future`.""" - return self.future.exception() + # GRPC will actually raise FutureCancelledError. + # We'll translate that to our own Cancelled exception and *return* it, + # which is far more polite for a method that *returns exceptions*. + try: + return self.future.exception() + except grpc.FutureCancelledError: + return exceptions.Cancelled() def result(self): """Calls :meth:`grpc.Future.result` on attr:`future`.""" return self.future.result() def add_done_callback(self, callback): - """Calls :meth:`grpc.Future.add_done_callback` on attr:`future`.""" - return self.future.add_done_callback(callback) + """Add a callback function to be run upon task completion. Will run + immediately if task has already finished. + + Args: + callback (Callable): The function to execute. + """ + if self.future.done(): + callback(self) + else: + self._callbacks.append(callback) + + def cancel(self): + """Calls :meth:`grpc.Future.cancel` on attr:`cancel`.""" + return self.future.cancel() + + def _finish(self, rpc): + """Called when remote future is finished. + + Used to call our own done callbacks. + """ + for callback in self._callbacks: + callback(self) diff --git a/google/cloud/ndb/exceptions.py b/google/cloud/ndb/exceptions.py index b0920779..a5073ddf 100644 --- a/google/cloud/ndb/exceptions.py +++ b/google/cloud/ndb/exceptions.py @@ -112,3 +112,12 @@ class NoLongerImplementedError(NotImplementedError): def __init__(self): super(NoLongerImplementedError, self).__init__("No longer implemented") + + +class Cancelled(Error): + """An operation has been cancelled by user request. + + Raised when trying to get a result from a future that has been cancelled by + a call to ``Future.cancel`` (possibly on a future that depends on this + future). + """ diff --git a/google/cloud/ndb/tasklets.py b/google/cloud/ndb/tasklets.py index 9493c2a7..5bea3d83 100644 --- a/google/cloud/ndb/tasklets.py +++ b/google/cloud/ndb/tasklets.py @@ -58,6 +58,7 @@ def main(): from google.cloud.ndb import context as context_module from google.cloud.ndb import _eventloop +from google.cloud.ndb import exceptions from google.cloud.ndb import _remote __all__ = [ @@ -232,20 +233,26 @@ def add_done_callback(self, callback): self._callbacks.append(callback) def cancel(self): - """Cancel the task for this future. + """Attempt to cancel the task for this future. - Raises: - NotImplementedError: Always, not supported. + If the task has already completed, this call will do nothing. + Otherwise, this will attempt to cancel whatever task this future is + waiting on. There is no specific guarantee the underlying task will be + cancelled. """ - raise NotImplementedError + if not self.done(): + self.set_exception(exceptions.Cancelled()) def cancelled(self): - """Get whether task for this future has been canceled. + """Get whether the task for this future has been cancelled. Returns: - :data:`False`: Always. + :data:`True`: If this future's task has been cancelled, otherwise + :data:`False`. """ - return False + return self._exception is not None and isinstance( + self._exception, exceptions.Cancelled + ) @staticmethod def wait_any(futures): @@ -278,6 +285,7 @@ def __init__(self, generator, context, info="Unknown"): super(_TaskletFuture, self).__init__(info=info) self.generator = generator self.context = context + self.waiting_on = None def _advance_tasklet(self, send_value=None, error=None): """Advance a tasklet one step by sending in a value or error.""" @@ -324,6 +332,8 @@ def done_callback(yielded): # in Legacy) directly. Doing so, it has been found, can lead to # exceeding the maximum recursion depth. Queing it up to run on the # event loop avoids this issue by keeping the call stack shallow. + self.waiting_on = None + error = yielded.exception() if error: _eventloop.call_soon(self._advance_tasklet, error=error) @@ -332,19 +342,30 @@ def done_callback(yielded): if isinstance(yielded, Future): yielded.add_done_callback(done_callback) + self.waiting_on = yielded elif isinstance(yielded, _remote.RemoteCall): _eventloop.queue_rpc(yielded, done_callback) + self.waiting_on = yielded elif isinstance(yielded, (list, tuple)): future = _MultiFuture(yielded) future.add_done_callback(done_callback) + self.waiting_on = future else: raise RuntimeError( "A tasklet yielded an illegal value: {!r}".format(yielded) ) + def cancel(self): + """Overrides :meth:`Future.cancel`.""" + if self.waiting_on: + self.waiting_on.cancel() + + else: + super(_TaskletFuture, self).cancel() + def _get_return_value(stop): """Inspect `StopIteration` instance for return value of tasklet. @@ -399,6 +420,11 @@ def _dependency_done(self, dependency): result = tuple((future.result() for future in self._dependencies)) self.set_result(result) + def cancel(self): + """Overrides :meth:`Future.cancel`.""" + for dependency in self._dependencies: + dependency.cancel() + def tasklet(wrapped): """ diff --git a/tests/system/test_query.py b/tests/system/test_query.py index 2242ea13..7438bb19 100644 --- a/tests/system/test_query.py +++ b/tests/system/test_query.py @@ -96,6 +96,31 @@ def make_entities(): assert [entity.foo for entity in results][:5] == [0, 1, 2, 3, 4] +@pytest.mark.usefixtures("client_context") +def test_fetch_and_immediately_cancel(dispose_of): + # Make a lot of entities so the query call won't complete before we get to + # call cancel. + n_entities = 500 + + class SomeKind(ndb.Model): + foo = ndb.IntegerProperty() + + @ndb.toplevel + def make_entities(): + entities = [SomeKind(foo=i) for i in range(n_entities)] + keys = yield [entity.put_async() for entity in entities] + raise ndb.Return(keys) + + for key in make_entities(): + dispose_of(key._key) + + query = SomeKind.query() + future = query.fetch_async() + future.cancel() + with pytest.raises(ndb.exceptions.Cancelled): + future.result() + + @pytest.mark.usefixtures("client_context") def test_ancestor_query(ds_entity): root_id = test_utils.system.unique_resource_id() diff --git a/tests/unit/test__remote.py b/tests/unit/test__remote.py index 9f5c5838..0c0bf19e 100644 --- a/tests/unit/test__remote.py +++ b/tests/unit/test__remote.py @@ -14,6 +14,10 @@ from unittest import mock +import grpc +import pytest + +from google.cloud.ndb import exceptions from google.cloud.ndb import _remote from google.cloud.ndb import tasklets @@ -21,13 +25,15 @@ class TestRemoteCall: @staticmethod def test_constructor(): - call = _remote.RemoteCall("future", "info") - assert call.future == "future" + future = tasklets.Future() + call = _remote.RemoteCall(future, "info") + assert call.future is future assert call.info == "info" @staticmethod def test_repr(): - call = _remote.RemoteCall(None, "a remote call") + future = tasklets.Future() + call = _remote.RemoteCall(future, "a remote call") assert repr(call) == "a remote call" @staticmethod @@ -38,6 +44,14 @@ def test_exception(): call = _remote.RemoteCall(future, "testing") assert call.exception() is error + @staticmethod + def test_exception_FutureCancelledError(): + error = grpc.FutureCancelledError() + future = tasklets.Future() + future.exception = mock.Mock(side_effect=error) + call = _remote.RemoteCall(future, "testing") + assert isinstance(call.exception(), exceptions.Cancelled) + @staticmethod def test_result(): future = tasklets.Future() @@ -52,4 +66,22 @@ def test_add_done_callback(): callback = mock.Mock(spec=()) call.add_done_callback(callback) future.set_result(None) - callback.assert_called_once_with(future) + callback.assert_called_once_with(call) + + @staticmethod + def test_add_done_callback_already_done(): + future = tasklets.Future() + future.set_result(None) + call = _remote.RemoteCall(future, "testing") + callback = mock.Mock(spec=()) + call.add_done_callback(callback) + callback.assert_called_once_with(call) + + @staticmethod + def test_cancel(): + future = tasklets.Future() + call = _remote.RemoteCall(future, "testing") + call.cancel() + assert future.cancelled() + with pytest.raises(exceptions.Cancelled): + call.result() diff --git a/tests/unit/test_tasklets.py b/tests/unit/test_tasklets.py index c2ff12c8..cda4b50f 100644 --- a/tests/unit/test_tasklets.py +++ b/tests/unit/test_tasklets.py @@ -20,6 +20,7 @@ from google.cloud.ndb import context as context_module from google.cloud.ndb import _eventloop +from google.cloud.ndb import exceptions from google.cloud.ndb import _remote from google.cloud.ndb import tasklets @@ -188,10 +189,38 @@ def side_effects(future): assert _eventloop.run1.call_count == 3 @staticmethod + @pytest.mark.usefixtures("in_context") def test_cancel(): - future = tasklets.Future() - with pytest.raises(NotImplementedError): - future.cancel() + # Integration test. Actually test that a cancel propagates properly. + rpc = tasklets.Future("Fake RPC") + wrapped_rpc = _remote.RemoteCall(rpc, "Wrapped Fake RPC") + + @tasklets.tasklet + def inner_tasklet(): + yield wrapped_rpc + + @tasklets.tasklet + def outer_tasklet(): + yield inner_tasklet() + + future = outer_tasklet() + assert not future.cancelled() + future.cancel() + assert rpc.cancelled() + + with pytest.raises(exceptions.Cancelled): + future.result() + + assert future.cancelled() + + @staticmethod + @pytest.mark.usefixtures("in_context") + def test_cancel_already_done(): + future = tasklets.Future("testing") + future.set_result(42) + future.cancel() # noop + assert not future.cancelled() + assert future.result() == 42 @staticmethod def test_cancelled(): @@ -358,6 +387,31 @@ def generator_function(dependencies): assert future.result() == 11 assert future.context is in_context + @staticmethod + def test_cancel_not_waiting(in_context): + dependency = tasklets.Future() + future = tasklets._TaskletFuture(None, in_context) + future.cancel() + + assert not dependency.cancelled() + with pytest.raises(exceptions.Cancelled): + future.result() + + @staticmethod + def test_cancel_waiting_on_dependency(in_context): + def generator_function(dependency): + yield dependency + + dependency = tasklets.Future() + generator = generator_function(dependency) + future = tasklets._TaskletFuture(generator, in_context) + future._advance_tasklet() + future.cancel() + + assert dependency.cancelled() + with pytest.raises(exceptions.Cancelled): + future.result() + class Test_MultiFuture: @staticmethod @@ -388,6 +442,16 @@ def test_error(): with pytest.raises(Exception): future.result() + @staticmethod + def test_cancel(): + dependencies = (tasklets.Future(), tasklets.Future()) + future = tasklets._MultiFuture(dependencies) + future.cancel() + assert dependencies[0].cancelled() + assert dependencies[1].cancelled() + with pytest.raises(exceptions.Cancelled): + future.result() + class Test__get_return_value: @staticmethod