From 0f3d73ac86c734715f6ab0e4490207b249b506dd Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 1 Jun 2021 12:02:52 -0400 Subject: [PATCH] fix: prevent overwrite of cache lock value Fixes #651 --- google/cloud/ndb/_cache.py | 107 +++++++++++---- google/cloud/ndb/_datastore_api.py | 11 +- google/cloud/ndb/global_cache.py | 50 +++++++ tests/unit/test__cache.py | 211 ++++++++++++++++++++++++++--- tests/unit/test__datastore_api.py | 22 +++ tests/unit/test_global_cache.py | 102 ++++++++++++++ 6 files changed, 459 insertions(+), 44 deletions(-) diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index ebf51030..83c7b75f 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -306,6 +306,32 @@ def __init__(self, options): self.todo = {} self.futures = {} + def done_callback(self, cache_call): + """Process results of call to global cache. + If there is an exception for the cache call, distribute that to waiting + futures, otherwise examine the result of the cache call. If the result is + :data:`None`, simply set the result to :data:`None` for all waiting futures. + Otherwise, if the result is a `dict`, use that to propagate results for + individual keys to waiting futures. + """ + exception = cache_call.exception() + if exception: + for future in self.futures.values(): + future.set_exception(exception) + return + + result = cache_call.result() + if result: + for key, future in self.futures.items(): + key_result = result.get(key, None) + if isinstance(key_result, Exception): + future.set_exception(key_result) + else: + future.set_result(key_result) + else: + for future in self.futures.values(): + future.set_result(None) + def add(self, key, value): """Add a key, value pair to store in the cache. @@ -335,36 +361,66 @@ def add(self, key, value): self.futures[key] = future return future - def done_callback(self, cache_call): - """Process results of call to global cache. 
+ def make_call(self): + """Call :method:`GlobalCache.set`.""" + return _global_cache().set(self.todo, expires=self.expires) - If there is an exception for the cache call, distribute that to waiting - futures, otherwise examine the result of the cache call. If the result is - :data:`None`, simply set the result to :data:`None` for all waiting futures. - Otherwise, if the result is a `dict`, use that to propagate results for - individual keys to waiting figures. + def future_info(self, key, value): + """Generate info string for Future.""" + return "GlobalCache.set_if_not_exists({}, {})".format(key, value) + + +@_handle_transient_errors() +def global_set_if_not_exists(key, value, expires=None, read=False): + """Store entity in the global cache if key is not already present. + + Args: + key (bytes): The key to save. + value (bytes): The entity to save. + expires (Optional[float]): Number of seconds until value expires. + read (bool): Indicates if being set in a read (lookup) context. + + Returns: + tasklets.Future: Eventual result will be a ``bool`` value which will be + :data:`True` if a new value was set for the key, or :data:`False` if a value + was already set for the key. + """ + options = {} + if expires: + options = {"expires": expires} + + batch = _batch.get_batch(_GlobalCacheSetIfNotExistsBatch, options) + return batch.add(key, value) + + +class _GlobalCacheSetIfNotExistsBatch(_GlobalCacheSetBatch): + """Batch for global cache set_if_not_exists requests. """ + + def add(self, key, value): + """Add a key, value pair to store in the cache. + + Arguments: + key (bytes): The key to store in the cache. + value (bytes): The value to store in the cache. + + Returns: + tasklets.Future: Eventual result will be a ``bool`` value which will be + :data:`True` if a new value was set for the key, or :data:`False` if a + value was already set for the key. 
""" - exception = cache_call.exception() - if exception: - for future in self.futures.values(): - future.set_exception(exception) - return + if key in self.todo: + future = tasklets.Future() + future.set_result(False) + return future - result = cache_call.result() - if result: - for key, future in self.futures.items(): - key_result = result.get(key, None) - if isinstance(key_result, Exception): - future.set_exception(key_result) - else: - future.set_result(key_result) - else: - for future in self.futures.values(): - future.set_result(None) + future = tasklets.Future(info=self.future_info(key, value)) + self.todo[key] = value + self.futures[key] = future + return future def make_call(self): """Call :method:`GlobalCache.set`.""" - return _global_cache().set(self.todo, expires=self.expires) + return _global_cache().set_if_not_exists(self.todo, expires=self.expires) def future_info(self, key, value): """Generate info string for Future.""" @@ -523,6 +579,9 @@ def global_lock(key, read=False): Returns: tasklets.Future: Eventual result will be ``None``. """ + if read: + return global_set_if_not_exists(key, _LOCKED, expires=_LOCK_TIME, read=read) + return global_set(key, _LOCKED, expires=_LOCK_TIME, read=read) diff --git a/google/cloud/ndb/_datastore_api.py b/google/cloud/ndb/_datastore_api.py index f7a247a9..aec514f1 100644 --- a/google/cloud/ndb/_datastore_api.py +++ b/google/cloud/ndb/_datastore_api.py @@ -146,8 +146,15 @@ def lookup(key, options): entity_pb.MergeFromString(result) elif use_datastore: - yield _cache.global_lock(cache_key, read=True) - yield _cache.global_watch(cache_key) + lock_acquired = yield _cache.global_lock(cache_key, read=True) + if lock_acquired: + yield _cache.global_watch(cache_key) + + else: + # Another thread locked or wrote to this key after the call to + # _cache.global_get above. 
Behave as though the key was locked by + # another thread and don't attempt to write our value below + key_locked = True if entity_pb is _NOT_FOUND and use_datastore: batch = _batch.get_batch(_LookupBatch, options) diff --git a/google/cloud/ndb/global_cache.py b/google/cloud/ndb/global_cache.py index 8d39a60a..ca70d3b4 100644 --- a/google/cloud/ndb/global_cache.py +++ b/google/cloud/ndb/global_cache.py @@ -116,6 +116,23 @@ def set(self, items, expires=None): """ raise NotImplementedError + @abc.abstractmethod + def set_if_not_exists(self, items, expires=None): + """Stores entities in the cache if and only if keys are not already set. + + Arguments: + items (Dict[bytes, Union[bytes, None]]): Mapping of keys to + serialized entities. + expires (Optional[float]): Number of seconds until value expires. + + + Returns: + Dict[bytes, bool]: A `dict` mapping to boolean value which will be + :data:`True` if that key was set with a new value, and :data:`False` + otherwise. + """ + raise NotImplementedError + @abc.abstractmethod def delete(self, keys): """Remove entities from the cache. 
@@ -217,6 +234,18 @@ def set(self, items, expires=None): for key, value in items.items(): self.cache[key] = (value, expires) # Supposedly threadsafe + def set_if_not_exists(self, items, expires=None): + """Implements :meth:`GlobalCache.set_if_not_exists`.""" + if expires: + expires = time.time() + expires + + results = {} + for key, value in items.items(): + set_value = (value, expires) + results[key] = self.cache.setdefault(key, set_value) is set_value + + return results + def delete(self, keys): """Implements :meth:`GlobalCache.delete`.""" for key in keys: @@ -355,6 +384,16 @@ def set(self, items, expires=None): for key in items.keys(): self.redis.expire(key, expires) + def set_if_not_exists(self, items, expires=None): + """Implements :meth:`GlobalCache.set_if_not_exists`.""" + results = {} + for key, value in items.items(): + results[key] = key_was_set = self.redis.setnx(key, value) + if key_was_set and expires: + self.redis.expire(key, expires) + + return results + def delete(self, keys): """Implements :meth:`GlobalCache.delete`.""" self.redis.delete(*keys) @@ -599,6 +638,17 @@ def set(self, items, expires=None): ) return {key: MemcacheCache.KeyNotSet(key) for key in unset_keys} + def set_if_not_exists(self, items, expires=None): + """Implements :meth:`GlobalCache.set_if_not_exists`.""" + expires = expires if expires else 0 + results = {} + for key, value in items.items(): + results[key] = self.client.add( + self._key(key), value, expire=expires, noreply=False + ) + + return results + def delete(self, keys): """Implements :meth:`GlobalCache.delete`.""" keys = [self._key(key) for key in keys] diff --git a/tests/unit/test__cache.py b/tests/unit/test__cache.py index d835f8c3..2658f470 100644 --- a/tests/unit/test__cache.py +++ b/tests/unit/test__cache.py @@ -412,7 +412,6 @@ def test_add_and_idle_and_done_callbacks_with_expires(in_context): def test_add_and_idle_and_done_callbacks_w_error(in_context): error = Exception("spurious error") cache = 
mock.Mock(spec=("set",)) - cache.set.return_value = [] cache.set.return_value = tasklets.Future() cache.set.return_value.set_exception(error) @@ -452,6 +451,160 @@ class SpeciousError(Exception): assert future2.result() +@pytest.mark.usefixtures("in_context") +class Test_global_set_if_not_exists: + @staticmethod + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_without_expires(_batch, _global_cache): + batch = _batch.get_batch.return_value + future = _future_result("hi mom!") + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + assert _cache.global_set_if_not_exists(b"key", b"value").result() == "hi mom!" + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetIfNotExistsBatch, {} + ) + batch.add.assert_called_once_with(b"key", b"value") + + @staticmethod + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_with_expires(_batch, _global_cache): + batch = _batch.get_batch.return_value + future = _future_result("hi mom!") + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + assert ( + _cache.global_set_if_not_exists(b"key", b"value", expires=123).result() + == "hi mom!" 
+ ) + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetIfNotExistsBatch, {"expires": 123} + ) + batch.add.assert_called_once_with(b"key", b"value") + + +class Test_GlobalCacheSetIfNotExistsBatch: + @staticmethod + def test_add_duplicate_key_and_value(): + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"foo", b"one") + assert not future1.done() + assert future2.result() is False + + @staticmethod + def test_add_and_idle_and_done_callbacks(in_context): + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = [] + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + assert batch.expires is None + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with( + {b"foo": b"one", b"bar": b"two"}, expires=None + ) + assert future1.result() is None + assert future2.result() is None + + @staticmethod + def test_add_and_idle_and_done_callbacks_with_duplicate_keys(in_context): + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = {b"foo": True} + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"foo", b"two") + + assert batch.expires is None + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with({b"foo": b"one"}, expires=None) + assert future1.result() is True + assert future2.result() is False + + @staticmethod + def test_add_and_idle_and_done_callbacks_with_expires(in_context): + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = [] + + batch = _cache._GlobalCacheSetIfNotExistsBatch({"expires": 5}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + assert batch.expires == 5 + + with 
in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with( + {b"foo": b"one", b"bar": b"two"}, expires=5 + ) + assert future1.result() is None + assert future2.result() is None + + @staticmethod + def test_add_and_idle_and_done_callbacks_w_error(in_context): + error = Exception("spurious error") + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = tasklets.Future() + cache.set_if_not_exists.return_value.set_exception(error) + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with( + {b"foo": b"one", b"bar": b"two"}, expires=None + ) + assert future1.exception() is error + assert future2.exception() is error + + @staticmethod + def test_done_callbacks_with_results(in_context): + class SpeciousError(Exception): + pass + + cache_call = _future_result( + { + b"foo": "this is a result", + b"bar": SpeciousError("this is also a kind of result"), + } + ) + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + batch.done_callback(cache_call) + + assert future1.result() == "this is a result" + with pytest.raises(SpeciousError): + assert future2.result() + + @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._cache._global_cache") @mock.patch("google.cloud.ndb._cache._batch") @@ -642,24 +795,46 @@ def test_add_and_idle_and_done_callbacks_with_expires(in_context): assert future2.result() is None -@pytest.mark.usefixtures("in_context") -@mock.patch("google.cloud.ndb._cache._global_cache") -@mock.patch("google.cloud.ndb._cache._batch") -def test_global_lock(_batch, _global_cache): - batch = _batch.get_batch.return_value - future = _future_result("hi mom!") - batch.add.return_value = future - 
_global_cache.return_value = mock.Mock( - transient_errors=(), - strict_write=False, - spec=("transient_errors", "strict_write"), - ) +class Test_global_lock: + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_global_lock_read(_batch, _global_cache): + batch = _batch.get_batch.return_value + future = _future_result("hi mom!") + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_read=False, + spec=("transient_errors", "strict_read"), + ) - assert _cache.global_lock(b"key").result() == "hi mom!" - _batch.get_batch.assert_called_once_with( - _cache._GlobalCacheSetBatch, {"expires": _cache._LOCK_TIME} - ) - batch.add.assert_called_once_with(b"key", _cache._LOCKED) + assert _cache.global_lock(b"key", read=True).result() == "hi mom!" + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetIfNotExistsBatch, {"expires": _cache._LOCK_TIME} + ) + batch.add.assert_called_once_with(b"key", _cache._LOCKED) + + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_global_lock_write(_batch, _global_cache): + batch = _batch.get_batch.return_value + future = _future_result("hi mom!") + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + assert _cache.global_lock(b"key").result() == "hi mom!" 
+ _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetBatch, {"expires": _cache._LOCK_TIME} + ) + batch.add.assert_called_once_with(b"key", _cache._LOCKED) def test_is_locked_value(): diff --git a/tests/unit/test__datastore_api.py b/tests/unit/test__datastore_api.py index 8e81fe15..6ee132b1 100644 --- a/tests/unit/test__datastore_api.py +++ b/tests/unit/test__datastore_api.py @@ -270,6 +270,28 @@ class SomeKind(model.Model): assert global_cache.get([cache_key]) == [cache_value] + @staticmethod + @mock.patch("google.cloud.ndb._datastore_api._LookupBatch") + def test_cache_miss_lock_not_acquired(_LookupBatch, global_cache): + class SomeKind(model.Model): + pass + + key = key_module.Key("SomeKind", 1) + cache_key = _cache.global_cache_key(key._key) + + entity = SomeKind(key=key) + entity_pb = model._entity_to_protobuf(entity) + + batch = _LookupBatch.return_value + batch.add.return_value = future_result(entity_pb) + + global_cache.set_if_not_exists = mock.Mock(return_value={cache_key: False}) + + future = _api.lookup(key._key, _options.ReadOptions()) + assert future.result() == entity_pb + + assert global_cache.get([cache_key]) == [None] + @staticmethod @mock.patch("google.cloud.ndb._datastore_api._LookupBatch") def test_cache_miss_no_datastore(_LookupBatch, global_cache): diff --git a/tests/unit/test_global_cache.py b/tests/unit/test_global_cache.py index 0a724a23..b5577828 100644 --- a/tests/unit/test_global_cache.py +++ b/tests/unit/test_global_cache.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import collections + try: from unittest import mock except ImportError: # pragma: NO PY3 COVER @@ -32,6 +34,9 @@ def get(self, keys): def set(self, items, expires=None): return super(MockImpl, self).set(items, expires=expires) + def set_if_not_exists(self, items, expires=None): + return super(MockImpl, self).set_if_not_exists(items, expires=expires) + def delete(self, keys): return super(MockImpl, self).delete(keys) @@ -59,6 +64,11 @@ def test_set(self): with pytest.raises(NotImplementedError): cache.set({b"foo": "bar"}) + def test_set_if_not_exists(self): + cache = self.make_one() + with pytest.raises(NotImplementedError): + cache.set_if_not_exists({b"foo": "bar"}) + def test_delete(self): cache = self.make_one() with pytest.raises(NotImplementedError): @@ -123,6 +133,37 @@ def test_set_get_delete_w_expires(time): result = cache.get([b"two", b"three", b"one"]) assert result == [None, None, None] + @staticmethod + def test_set_if_not_exists(): + cache = global_cache._InProcessGlobalCache() + result = cache.set_if_not_exists({b"one": b"foo", b"two": b"bar"}) + assert result == {b"one": True, b"two": True} + + result = cache.set_if_not_exists({b"two": b"bar", b"three": b"baz"}) + assert result == {b"two": False, b"three": True} + + result = cache.get([b"two", b"three", b"one"]) + assert result == [b"bar", b"baz", b"foo"] + + @staticmethod + @mock.patch("google.cloud.ndb.global_cache.time") + def test_set_if_not_exists_w_expires(time): + time.time.return_value = 0 + + cache = global_cache._InProcessGlobalCache() + result = cache.set_if_not_exists({b"one": b"foo", b"two": b"bar"}, expires=5) + assert result == {b"one": True, b"two": True} + + result = cache.set_if_not_exists({b"two": b"bar", b"three": b"baz"}, expires=5) + assert result == {b"two": False, b"three": True} + + result = cache.get([b"two", b"three", b"one"]) + assert result == [b"bar", b"baz", b"foo"] + + time.time.return_value = 10 + result = cache.get([b"two", b"three", b"one"]) + assert result == 
[None, None, None] + @staticmethod def test_watch_compare_and_swap(): cache = global_cache._InProcessGlobalCache() @@ -234,6 +275,37 @@ def mock_expire(key, expires): redis.mset.assert_called_once_with(cache_items) assert expired == {"a": 32, "b": 32} + @staticmethod + def test_set_if_not_exists(): + redis = mock.Mock(spec=("setnx",)) + redis.setnx.side_effect = (True, False) + cache_items = collections.OrderedDict([("a", "foo"), ("b", "bar")]) + cache = global_cache.RedisCache(redis) + results = cache.set_if_not_exists(cache_items) + assert results == {"a": True, "b": False} + redis.setnx.assert_has_calls( + [ + mock.call("a", "foo"), + mock.call("b", "bar"), + ] + ) + + @staticmethod + def test_set_if_not_exists_w_expires(): + redis = mock.Mock(spec=("setnx", "expire")) + redis.setnx.side_effect = (True, False) + cache_items = collections.OrderedDict([("a", "foo"), ("b", "bar")]) + cache = global_cache.RedisCache(redis) + results = cache.set_if_not_exists(cache_items, expires=123) + assert results == {"a": True, "b": False} + redis.setnx.assert_has_calls( + [ + mock.call("a", "foo"), + mock.call("b", "bar"), + ] + ) + redis.expire.assert_called_once_with("a", 123) + @staticmethod def test_delete(): redis = mock.Mock(spec=("delete",)) @@ -468,6 +540,36 @@ def test_set(): noreply=False, ) + @staticmethod + def test_set_if_not_exists(): + client = mock.Mock(spec=("add",)) + client.add.side_effect = (True, False) + cache_items = collections.OrderedDict([(b"a", b"foo"), (b"b", b"bar")]) + cache = global_cache.MemcacheCache(client) + results = cache.set_if_not_exists(cache_items) + assert results == {b"a": True, b"b": False} + client.add.assert_has_calls( + [ + mock.call(cache._key(b"a"), b"foo", expire=0, noreply=False), + mock.call(cache._key(b"b"), b"bar", expire=0, noreply=False), + ] + ) + + @staticmethod + def test_set_if_not_exists_w_expires(): + client = mock.Mock(spec=("add",)) + client.add.side_effect = (True, False) + cache_items = 
collections.OrderedDict([(b"a", b"foo"), (b"b", b"bar")]) + cache = global_cache.MemcacheCache(client) + results = cache.set_if_not_exists(cache_items, expires=123) + assert results == {b"a": True, b"b": False} + client.add.assert_has_calls( + [ + mock.call(cache._key(b"a"), b"foo", expire=123, noreply=False), + mock.call(cache._key(b"b"), b"bar", expire=123, noreply=False), + ] + ) + @staticmethod def test_set_w_expires(): client = mock.Mock(spec=("set_many",))