Skip to content

Commit

Permalink
Implement Future.cancel()
Browse files Browse the repository at this point in the history
This brings our implementation of ``Future`` in parity with the
``Future`` interface defined in the Python 3 standard library, and makes
it possible to cancel asynchronous ``grpc`` calls from NDB.
  • Loading branch information
Chris Rossi committed Sep 17, 2019
1 parent 58dcd4a commit 26be739
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 17 deletions.
38 changes: 35 additions & 3 deletions google/cloud/ndb/_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

# In its own module to avoid circular import between _datastore_api and
# tasklets modules.
import grpc

from google.cloud.ndb import exceptions


class RemoteCall:
Expand All @@ -36,18 +39,47 @@ class RemoteCall:
def __init__(self, future, info):
self.future = future
self.info = info
self._callbacks = []

future.add_done_callback(self._finish)

def __repr__(self):
return self.info

def exception(self):
"""Calls :meth:`grpc.Future.exception` on attr:`future`."""
return self.future.exception()
# GRPC will actually raise FutureCancelledError.
# We'll translate that to our own Cancelled exception and *return* it,
# which is far more polite for a method that *returns exceptions*.
try:
return self.future.exception()
except grpc.FutureCancelledError:
return exceptions.Cancelled()

def result(self):
"""Calls :meth:`grpc.Future.result` on attr:`future`."""
return self.future.result()

def add_done_callback(self, callback):
"""Calls :meth:`grpc.Future.add_done_callback` on attr:`future`."""
return self.future.add_done_callback(callback)
"""Add a callback function to be run upon task completion. Will run
immediately if task has already finished.
Args:
callback (Callable): The function to execute.
"""
if self.future.done():
callback(self)
else:
self._callbacks.append(callback)

def cancel(self):
"""Calls :meth:`grpc.Future.cancel` on attr:`cancel`."""
return self.future.cancel()

def _finish(self, rpc):
"""Called when remote future is finished.
Used to call our own done callbacks.
"""
for callback in self._callbacks:
callback(self)
9 changes: 9 additions & 0 deletions google/cloud/ndb/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,12 @@ class NoLongerImplementedError(NotImplementedError):

def __init__(self):
super(NoLongerImplementedError, self).__init__("No longer implemented")


class Cancelled(Error):
"""An operation has been cancelled by user request.
Raised when trying to get a result from a future that has been cancelled by
a call to ``Future.cancel`` (possibly on a future that depends on this
future).
"""
40 changes: 33 additions & 7 deletions google/cloud/ndb/tasklets.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def main():

from google.cloud.ndb import context as context_module
from google.cloud.ndb import _eventloop
from google.cloud.ndb import exceptions
from google.cloud.ndb import _remote

__all__ = [
Expand Down Expand Up @@ -232,20 +233,26 @@ def add_done_callback(self, callback):
self._callbacks.append(callback)

def cancel(self):
"""Cancel the task for this future.
"""Attempt to cancel the task for this future.
Raises:
NotImplementedError: Always, not supported.
If the task has already completed, this call will do nothing.
Otherwise, this will attempt to cancel whatever task this future is
waiting on. There is no specific guarantee the underlying task will be
cancelled.
"""
raise NotImplementedError
if not self.done():
self.set_exception(exceptions.Cancelled())

def cancelled(self):
"""Get whether task for this future has been canceled.
"""Get whether the task for this future has been cancelled.
Returns:
:data:`False`: Always.
:data:`True`: If this future's task has been cancelled, otherwise
:data:`False`.
"""
return False
return self._exception is not None and isinstance(
self._exception, exceptions.Cancelled
)

@staticmethod
def wait_any(futures):
Expand Down Expand Up @@ -278,6 +285,7 @@ def __init__(self, generator, context, info="Unknown"):
super(_TaskletFuture, self).__init__(info=info)
self.generator = generator
self.context = context
self.waiting_on = None

def _advance_tasklet(self, send_value=None, error=None):
"""Advance a tasklet one step by sending in a value or error."""
Expand Down Expand Up @@ -324,6 +332,8 @@ def done_callback(yielded):
# in Legacy) directly. Doing so, it has been found, can lead to
# exceeding the maximum recursion depth. Queing it up to run on the
# event loop avoids this issue by keeping the call stack shallow.
self.waiting_on = None

error = yielded.exception()
if error:
_eventloop.call_soon(self._advance_tasklet, error=error)
Expand All @@ -332,19 +342,30 @@ def done_callback(yielded):

if isinstance(yielded, Future):
yielded.add_done_callback(done_callback)
self.waiting_on = yielded

elif isinstance(yielded, _remote.RemoteCall):
_eventloop.queue_rpc(yielded, done_callback)
self.waiting_on = yielded

elif isinstance(yielded, (list, tuple)):
future = _MultiFuture(yielded)
future.add_done_callback(done_callback)
self.waiting_on = future

else:
raise RuntimeError(
"A tasklet yielded an illegal value: {!r}".format(yielded)
)

def cancel(self):
"""Overrides :meth:`Future.cancel`."""
if self.waiting_on:
self.waiting_on.cancel()

else:
super(_TaskletFuture, self).cancel()


def _get_return_value(stop):
"""Inspect `StopIteration` instance for return value of tasklet.
Expand Down Expand Up @@ -399,6 +420,11 @@ def _dependency_done(self, dependency):
result = tuple((future.result() for future in self._dependencies))
self.set_result(result)

def cancel(self):
"""Overrides :meth:`Future.cancel`."""
for dependency in self._dependencies:
dependency.cancel()


def tasklet(wrapped):
"""
Expand Down
25 changes: 25 additions & 0 deletions tests/system/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,31 @@ def make_entities():
assert [entity.foo for entity in results][:5] == [0, 1, 2, 3, 4]


@pytest.mark.usefixtures("client_context")
def test_fetch_and_immediately_cancel(dispose_of):
# Make a lot of entities so the query call won't complete before we get to
# call cancel.
n_entities = 500

class SomeKind(ndb.Model):
foo = ndb.IntegerProperty()

@ndb.toplevel
def make_entities():
entities = [SomeKind(foo=i) for i in range(n_entities)]
keys = yield [entity.put_async() for entity in entities]
raise ndb.Return(keys)

for key in make_entities():
dispose_of(key._key)

query = SomeKind.query()
future = query.fetch_async()
future.cancel()
with pytest.raises(ndb.exceptions.Cancelled):
future.result()


@pytest.mark.usefixtures("client_context")
def test_ancestor_query(ds_entity):
root_id = test_utils.system.unique_resource_id()
Expand Down
40 changes: 36 additions & 4 deletions tests/unit/test__remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,26 @@

from unittest import mock

import grpc
import pytest

from google.cloud.ndb import exceptions
from google.cloud.ndb import _remote
from google.cloud.ndb import tasklets


class TestRemoteCall:
@staticmethod
def test_constructor():
call = _remote.RemoteCall("future", "info")
assert call.future == "future"
future = tasklets.Future()
call = _remote.RemoteCall(future, "info")
assert call.future is future
assert call.info == "info"

@staticmethod
def test_repr():
call = _remote.RemoteCall(None, "a remote call")
future = tasklets.Future()
call = _remote.RemoteCall(future, "a remote call")
assert repr(call) == "a remote call"

@staticmethod
Expand All @@ -38,6 +44,14 @@ def test_exception():
call = _remote.RemoteCall(future, "testing")
assert call.exception() is error

@staticmethod
def test_exception_FutureCancelledError():
error = grpc.FutureCancelledError()
future = tasklets.Future()
future.exception = mock.Mock(side_effect=error)
call = _remote.RemoteCall(future, "testing")
assert isinstance(call.exception(), exceptions.Cancelled)

@staticmethod
def test_result():
future = tasklets.Future()
Expand All @@ -52,4 +66,22 @@ def test_add_done_callback():
callback = mock.Mock(spec=())
call.add_done_callback(callback)
future.set_result(None)
callback.assert_called_once_with(future)
callback.assert_called_once_with(call)

@staticmethod
def test_add_done_callback_already_done():
future = tasklets.Future()
future.set_result(None)
call = _remote.RemoteCall(future, "testing")
callback = mock.Mock(spec=())
call.add_done_callback(callback)
callback.assert_called_once_with(call)

@staticmethod
def test_cancel():
future = tasklets.Future()
call = _remote.RemoteCall(future, "testing")
call.cancel()
assert future.cancelled()
with pytest.raises(exceptions.Cancelled):
call.result()
70 changes: 67 additions & 3 deletions tests/unit/test_tasklets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from google.cloud.ndb import context as context_module
from google.cloud.ndb import _eventloop
from google.cloud.ndb import exceptions
from google.cloud.ndb import _remote
from google.cloud.ndb import tasklets

Expand Down Expand Up @@ -188,10 +189,38 @@ def side_effects(future):
assert _eventloop.run1.call_count == 3

@staticmethod
@pytest.mark.usefixtures("in_context")
def test_cancel():
future = tasklets.Future()
with pytest.raises(NotImplementedError):
future.cancel()
# Integration test. Actually test that a cancel propagates properly.
rpc = tasklets.Future("Fake RPC")
wrapped_rpc = _remote.RemoteCall(rpc, "Wrapped Fake RPC")

@tasklets.tasklet
def inner_tasklet():
yield wrapped_rpc

@tasklets.tasklet
def outer_tasklet():
yield inner_tasklet()

future = outer_tasklet()
assert not future.cancelled()
future.cancel()
assert rpc.cancelled()

with pytest.raises(exceptions.Cancelled):
future.result()

assert future.cancelled()

@staticmethod
@pytest.mark.usefixtures("in_context")
def test_cancel_already_done():
future = tasklets.Future("testing")
future.set_result(42)
future.cancel() # noop
assert not future.cancelled()
assert future.result() == 42

@staticmethod
def test_cancelled():
Expand Down Expand Up @@ -358,6 +387,31 @@ def generator_function(dependencies):
assert future.result() == 11
assert future.context is in_context

@staticmethod
def test_cancel_not_waiting(in_context):
dependency = tasklets.Future()
future = tasklets._TaskletFuture(None, in_context)
future.cancel()

assert not dependency.cancelled()
with pytest.raises(exceptions.Cancelled):
future.result()

@staticmethod
def test_cancel_waiting_on_dependency(in_context):
def generator_function(dependency):
yield dependency

dependency = tasklets.Future()
generator = generator_function(dependency)
future = tasklets._TaskletFuture(generator, in_context)
future._advance_tasklet()
future.cancel()

assert dependency.cancelled()
with pytest.raises(exceptions.Cancelled):
future.result()


class Test_MultiFuture:
@staticmethod
Expand Down Expand Up @@ -388,6 +442,16 @@ def test_error():
with pytest.raises(Exception):
future.result()

@staticmethod
def test_cancel():
dependencies = (tasklets.Future(), tasklets.Future())
future = tasklets._MultiFuture(dependencies)
future.cancel()
assert dependencies[0].cancelled()
assert dependencies[1].cancelled()
with pytest.raises(exceptions.Cancelled):
future.result()


class Test__get_return_value:
@staticmethod
Expand Down

0 comments on commit 26be739

Please sign in to comment.