Skip to content

Commit

Permalink
fix: detect cache write failure for MemcacheCache (#665)
Browse files Browse the repository at this point in the history
Fixes #656
  • Loading branch information
Chris Rossi authored Jun 7, 2021
1 parent 73020ed commit 5d7f163
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 9 deletions.
45 changes: 43 additions & 2 deletions google/cloud/ndb/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ class _GlobalCacheSetBatch(_GlobalCacheBatch):
def __init__(self, options):
self.expires = options.get("expires")
self.todo = {}
self.futures = []
self.futures = {}

def add(self, key, value):
"""Add a key, value pair to store in the cache.
Expand All @@ -316,11 +316,52 @@ def add(self, key, value):
Returns:
tasklets.Future: Eventual result will be ``None``.
"""
future = self.futures.get(key)
if future:
if self.todo[key] != value:
# I don't think this is likely to happen. I'd like to know about it if
# it does because that might indicate a bad software design.
future = tasklets.Future()
future.set_exception(
RuntimeError(
"Key has already been set in this batch: {}".format(key)
)
)

return future

future = tasklets.Future(info=self.future_info(key, value))
self.todo[key] = value
self.futures.append(future)
self.futures[key] = future
return future

def done_callback(self, cache_call):
"""Process results of call to global cache.
If there is an exception for the cache call, distribute that to waiting
futures, otherwise examine the result of the cache call. If the result is
:data:`None`, simply set the result to :data:`None` for all waiting futures.
Otherwise, if the result is a `dict`, use that to propagate results for
individual keys to waiting figures.
"""
exception = cache_call.exception()
if exception:
for future in self.futures.values():
future.set_exception(exception)
return

result = cache_call.result()
if result:
for key, future in self.futures.items():
key_result = result.get(key, None)
if isinstance(key_result, Exception):
future.set_exception(key_result)
else:
future.set_result(key_result)
else:
for future in self.futures.values():
future.set_result(None)

def make_call(self):
"""Call :method:`GlobalCache.set`."""
return _global_cache().set(self.todo, expires=self.expires)
Expand Down
38 changes: 36 additions & 2 deletions google/cloud/ndb/global_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import threading
import time
import uuid
import warnings

import pymemcache
import redis as redis_module
Expand Down Expand Up @@ -106,6 +107,12 @@ def set(self, items, expires=None):
items (Dict[bytes, Union[bytes, None]]): Mapping of keys to
serialized entities.
expires (Optional[float]): Number of seconds until value expires.
Returns:
Optional[Dict[bytes, Any]]: May return :data:`None`, or a `dict` mapping
keys to arbitrary results. If the result for a key is an instance of
`Exception`, the result will be raised as an exception in that key's
future.
"""
raise NotImplementedError

Expand Down Expand Up @@ -446,9 +453,22 @@ class MemcacheCache(GlobalCache):
errors in the cache layer. Default: :data:`True`.
"""

class KeyNotSet(Exception):
def __init__(self, key):
self.key = key
super(MemcacheCache.KeyNotSet, self).__init__(
"SET operation failed in memcache for key: {}".format(key)
)

def __eq__(self, other):
if isinstance(other, type(self)):
return self.key == other.key
return NotImplemented

transient_errors = (
IOError,
ConnectionError,
KeyNotSet,
pymemcache.exceptions.MemcacheServerError,
pymemcache.exceptions.MemcacheUnexpectedCloseError,
)
Expand Down Expand Up @@ -561,9 +581,23 @@ def get(self, keys):

def set(self, items, expires=None):
"""Implements :meth:`GlobalCache.set`."""
items = {self._key(key): value for key, value in items.items()}
expires = expires if expires else 0
self.client.set_many(items, expire=expires)
orig_items = items
items = {}
orig_keys = {}
for orig_key, value in orig_items.items():
key = self._key(orig_key)
orig_keys[key] = orig_key
items[key] = value

unset_keys = self.client.set_many(items, expire=expires, noreply=False)
if unset_keys:
unset_keys = [orig_keys[key] for key in unset_keys]
warnings.warn(
"Keys failed to set in memcache: {}".format(unset_keys),
RuntimeWarning,
)
return {key: MemcacheCache.KeyNotSet(key) for key in unset_keys}

def delete(self, keys):
"""Implements :meth:`GlobalCache.delete`."""
Expand Down
63 changes: 58 additions & 5 deletions tests/unit/test__cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,17 @@ def test_with_expires(_batch, _global_cache):


class Test_GlobalCacheSetBatch:
@staticmethod
def test_add_duplicate_key_and_value():
batch = _cache._GlobalCacheSetBatch({})
future1 = batch.add(b"foo", b"one")
future2 = batch.add(b"foo", b"one")
assert future1 is future2

@staticmethod
def test_add_and_idle_and_done_callbacks(in_context):
cache = mock.Mock()
cache = mock.Mock(spec=("set",))
cache.set.return_value = []

batch = _cache._GlobalCacheSetBatch({})
future1 = batch.add(b"foo", b"one")
Expand All @@ -363,9 +371,29 @@ def test_add_and_idle_and_done_callbacks(in_context):
assert future1.result() is None
assert future2.result() is None

@staticmethod
def test_add_and_idle_and_done_callbacks_with_duplicate_keys(in_context):
cache = mock.Mock(spec=("set",))
cache.set.return_value = []

batch = _cache._GlobalCacheSetBatch({})
future1 = batch.add(b"foo", b"one")
future2 = batch.add(b"foo", b"two")

assert batch.expires is None

with in_context.new(global_cache=cache).use():
batch.idle_callback()

cache.set.assert_called_once_with({b"foo": b"one"}, expires=None)
assert future1.result() is None
with pytest.raises(RuntimeError):
future2.result()

@staticmethod
def test_add_and_idle_and_done_callbacks_with_expires(in_context):
cache = mock.Mock()
cache = mock.Mock(spec=("set",))
cache.set.return_value = []

batch = _cache._GlobalCacheSetBatch({"expires": 5})
future1 = batch.add(b"foo", b"one")
Expand All @@ -383,7 +411,8 @@ def test_add_and_idle_and_done_callbacks_with_expires(in_context):
@staticmethod
def test_add_and_idle_and_done_callbacks_w_error(in_context):
error = Exception("spurious error")
cache = mock.Mock()
cache = mock.Mock(spec=("set",))
cache.set.return_value = []
cache.set.return_value = tasklets.Future()
cache.set.return_value.set_exception(error)

Expand All @@ -400,6 +429,28 @@ def test_add_and_idle_and_done_callbacks_w_error(in_context):
assert future1.exception() is error
assert future2.exception() is error

@staticmethod
def test_done_callbacks_with_results(in_context):
class SpeciousError(Exception):
pass

cache_call = _future_result(
{
b"foo": "this is a result",
b"bar": SpeciousError("this is also a kind of result"),
}
)

batch = _cache._GlobalCacheSetBatch({})
future1 = batch.add(b"foo", b"one")
future2 = batch.add(b"bar", b"two")

batch.done_callback(cache_call)

assert future1.result() == "this is a result"
with pytest.raises(SpeciousError):
assert future2.result()


@pytest.mark.usefixtures("in_context")
@mock.patch("google.cloud.ndb._cache._global_cache")
Expand Down Expand Up @@ -552,7 +603,8 @@ def test_with_expires(_batch, _global_cache):
class Test_GlobalCacheCompareAndSwapBatch:
@staticmethod
def test_add_and_idle_and_done_callbacks(in_context):
cache = mock.Mock()
cache = mock.Mock(spec=("compare_and_swap",))
cache.compare_and_swap.return_value = None

batch = _cache._GlobalCacheCompareAndSwapBatch({})
future1 = batch.add(b"foo", b"one")
Expand All @@ -571,7 +623,8 @@ def test_add_and_idle_and_done_callbacks(in_context):

@staticmethod
def test_add_and_idle_and_done_callbacks_with_expires(in_context):
cache = mock.Mock()
cache = mock.Mock(spec=("compare_and_swap",))
cache.compare_and_swap.return_value = None

batch = _cache._GlobalCacheCompareAndSwapBatch({"expires": 5})
future1 = batch.add(b"foo", b"one")
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/test_global_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ def test_get():
@staticmethod
def test_set():
client = mock.Mock(spec=("set_many",))
client.set_many.return_value = []
cache = global_cache.MemcacheCache(client)
key1 = cache._key(b"one")
key2 = cache._key(b"two")
Expand All @@ -464,11 +465,13 @@ def test_set():
key2: "shoe",
},
expire=0,
noreply=False,
)

@staticmethod
def test_set_w_expires():
client = mock.Mock(spec=("set_many",))
client.set_many.return_value = []
cache = global_cache.MemcacheCache(client)
key1 = cache._key(b"one")
key2 = cache._key(b"two")
Expand All @@ -485,8 +488,41 @@ def test_set_w_expires():
key2: "shoe",
},
expire=5,
noreply=False,
)

@staticmethod
def test_set_failed_key():
client = mock.Mock(spec=("set_many",))
cache = global_cache.MemcacheCache(client)
key1 = cache._key(b"one")
key2 = cache._key(b"two")
client.set_many.return_value = [key2]

unset = cache.set(
{
b"one": "bun",
b"two": "shoe",
}
)
assert unset == {b"two": global_cache.MemcacheCache.KeyNotSet(b"two")}

client.set_many.assert_called_once_with(
{
key1: "bun",
key2: "shoe",
},
expire=0,
noreply=False,
)

@staticmethod
def test_KeyNotSet():
unset = global_cache.MemcacheCache.KeyNotSet(b"foo")
assert unset == global_cache.MemcacheCache.KeyNotSet(b"foo")
assert not unset == global_cache.MemcacheCache.KeyNotSet(b"goo")
assert not unset == "hamburger"

@staticmethod
def test_delete():
client = mock.Mock(spec=("delete_many",))
Expand Down

0 comments on commit 5d7f163

Please sign in to comment.