From 424d1bd58db7ba0e40d18cc64b5d6c103d44f8ca Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 1 Jun 2021 12:02:52 -0400 Subject: [PATCH 01/13] fix: prevent overwrite of cache lock value Fixes #651 --- google/cloud/ndb/_cache.py | 107 +++++++++++---- google/cloud/ndb/_datastore_api.py | 11 +- google/cloud/ndb/global_cache.py | 50 +++++++ tests/unit/test__cache.py | 211 ++++++++++++++++++++++++++--- tests/unit/test__datastore_api.py | 22 +++ tests/unit/test_global_cache.py | 102 ++++++++++++++ 6 files changed, 459 insertions(+), 44 deletions(-) diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index ebf51030..83c7b75f 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -306,6 +306,32 @@ def __init__(self, options): self.todo = {} self.futures = {} + def done_callback(self, cache_call): + """Process results of call to global cache. + If there is an exception for the cache call, distribute that to waiting + futures, otherwise examine the result of the cache call. If the result is + :data:`None`, simply set the result to :data:`None` for all waiting futures. + Otherwise, if the result is a `dict`, use that to propagate results for + individual keys to waiting figures. + """ + exception = cache_call.exception() + if exception: + for future in self.futures.values(): + future.set_exception(exception) + return + + result = cache_call.result() + if result: + for key, future in self.futures.items(): + key_result = result.get(key, None) + if isinstance(key_result, Exception): + future.set_exception(key_result) + else: + future.set_result(key_result) + else: + for future in self.futures.values(): + future.set_result(None) + def add(self, key, value): """Add a key, value pair to store in the cache. @@ -335,36 +361,66 @@ def add(self, key, value): self.futures[key] = future return future - def done_callback(self, cache_call): - """Process results of call to global cache. + def make_call(self): + """Call :method:`GlobalCache.set`.""" + return _global_cache().set(self.todo, expires=self.expires) - If there is an exception for the cache call, distribute that to waiting - futures, otherwise examine the result of the cache call. If the result is - :data:`None`, simply set the result to :data:`None` for all waiting futures. - Otherwise, if the result is a `dict`, use that to propagate results for - individual keys to waiting figures. + def future_info(self, key, value): + """Generate info string for Future.""" + return "GlobalCache.set_if_not_exists({}, {})".format(key, value) + + +@_handle_transient_errors() +def global_set_if_not_exists(key, value, expires=None, read=False): + """Store entity in the global cache if key is not already present. + + Args: + key (bytes): The key to save. + value (bytes): The entity to save. + expires (Optional[float]): Number of seconds until value expires. + read (bool): Indicates if being set in a read (lookup) context. + + Returns: + tasklets.Future: Eventual result will be a ``bool`` value which will be + :data:`True` if a new value was set for the key, or :data:`False` if a value + was already set for the key. + """ + options = {} + if expires: + options = {"expires": expires} + + batch = _batch.get_batch(_GlobalCacheSetIfNotExistsBatch, options) + return batch.add(key, value) + + +class _GlobalCacheSetIfNotExistsBatch(_GlobalCacheSetBatch): + """Batch for global cache set_if_not_exists requests. """ + + def add(self, key, value): + """Add a key, value pair to store in the cache. + + Arguments: + key (bytes): The key to store in the cache. + value (bytes): The value to store in the cache. + + Returns: + tasklets.Future: Eventual result will be a ``bool`` value which will be + :data:`True` if a new value was set for the key, or :data:`False` if a + value was already set for the key. """ - exception = cache_call.exception() - if exception: - for future in self.futures.values(): - future.set_exception(exception) - return + if key in self.todo: + future = tasklets.Future() + future.set_result(False) + return future - result = cache_call.result() - if result: - for key, future in self.futures.items(): - key_result = result.get(key, None) - if isinstance(key_result, Exception): - future.set_exception(key_result) - else: - future.set_result(key_result) - else: - for future in self.futures.values(): - future.set_result(None) + future = tasklets.Future(info=self.future_info(key, value)) + self.todo[key] = value + self.futures[key] = future + return future def make_call(self): """Call :method:`GlobalCache.set`.""" - return _global_cache().set(self.todo, expires=self.expires) + return _global_cache().set_if_not_exists(self.todo, expires=self.expires) def future_info(self, key, value): """Generate info string for Future.""" @@ -523,6 +579,9 @@ def global_lock(key, read=False): Returns: tasklets.Future: Eventual result will be ``None``. """ + if read: + return global_set_if_not_exists(key, _LOCKED, expires=_LOCK_TIME, read=read) + return global_set(key, _LOCKED, expires=_LOCK_TIME, read=read) diff --git a/google/cloud/ndb/_datastore_api.py b/google/cloud/ndb/_datastore_api.py index f7a247a9..aec514f1 100644 --- a/google/cloud/ndb/_datastore_api.py +++ b/google/cloud/ndb/_datastore_api.py @@ -146,8 +146,15 @@ def lookup(key, options): entity_pb.MergeFromString(result) elif use_datastore: - yield _cache.global_lock(cache_key, read=True) - yield _cache.global_watch(cache_key) + lock_acquired = yield _cache.global_lock(cache_key, read=True) + if lock_acquired: + yield _cache.global_watch(cache_key) + + else: + # Another thread locked or wrote to this key after the call to + # _cache.global_get above. Behave as though the key was locked by + # another thread and don't attempt to write our value below + key_locked = True if entity_pb is _NOT_FOUND and use_datastore: batch = _batch.get_batch(_LookupBatch, options) diff --git a/google/cloud/ndb/global_cache.py b/google/cloud/ndb/global_cache.py index 8d39a60a..ca70d3b4 100644 --- a/google/cloud/ndb/global_cache.py +++ b/google/cloud/ndb/global_cache.py @@ -116,6 +116,23 @@ def set(self, items, expires=None): """ raise NotImplementedError + @abc.abstractmethod + def set_if_not_exists(self, items, expires=None): + """Stores entities in the cache if and only if keys are not already set. + + Arguments: + items (Dict[bytes, Union[bytes, None]]): Mapping of keys to + serialized entities. + expires (Optional[float]): Number of seconds until value expires. + + + Returns: + Dict[bytes, bool]: A `dict` mapping to boolean value wich will be + :data:`True` if that key was set with a new value, and :data:`False` + otherwise. + """ + raise NotImplementedError + @abc.abstractmethod def delete(self, keys): """Remove entities from the cache. @@ -217,6 +234,18 @@ def set(self, items, expires=None): for key, value in items.items(): self.cache[key] = (value, expires) # Supposedly threadsafe + def set_if_not_exists(self, items, expires=None): + """Implements :meth:`GlobalCache.set_if_not_exists`.""" + if expires: + expires = time.time() + expires + + results = {} + for key, value in items.items(): + set_value = (value, expires) + results[key] = self.cache.setdefault(key, set_value) is set_value + + return results + def delete(self, keys): """Implements :meth:`GlobalCache.delete`.""" for key in keys: @@ -355,6 +384,16 @@ def set(self, items, expires=None): for key in items.keys(): self.redis.expire(key, expires) + def set_if_not_exists(self, items, expires=None): + """Implements :meth:`GlobalCache.set_if_not_exists`.""" + results = {} + for key, value in items.items(): + results[key] = key_was_set = self.redis.setnx(key, value) + if key_was_set and expires: + self.redis.expire(key, expires) + + return results + def delete(self, keys): """Implements :meth:`GlobalCache.delete`.""" self.redis.delete(*keys) @@ -599,6 +638,17 @@ def set(self, items, expires=None): ) return {key: MemcacheCache.KeyNotSet(key) for key in unset_keys} + def set_if_not_exists(self, items, expires=None): + """Implements :meth:`GlobalCache.set_if_not_exists`.""" + expires = expires if expires else 0 + results = {} + for key, value in items.items(): + results[key] = self.client.add( + self._key(key), value, expire=expires, noreply=False + ) + + return results + def delete(self, keys): """Implements :meth:`GlobalCache.delete`.""" keys = [self._key(key) for key in keys] diff --git a/tests/unit/test__cache.py b/tests/unit/test__cache.py index d835f8c3..2658f470 100644 --- a/tests/unit/test__cache.py +++ b/tests/unit/test__cache.py @@ -412,7 +412,6 @@ def test_add_and_idle_and_done_callbacks_with_expires(in_context): def test_add_and_idle_and_done_callbacks_w_error(in_context): error = Exception("spurious error") cache = mock.Mock(spec=("set",)) - cache.set.return_value = [] cache.set.return_value = tasklets.Future() cache.set.return_value.set_exception(error) @@ -452,6 +451,160 @@ class SpeciousError(Exception): assert future2.result() +@pytest.mark.usefixtures("in_context") +class Test_global_set_if_not_exists: + @staticmethod + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_without_expires(_batch, _global_cache): + batch = _batch.get_batch.return_value + future = _future_result("hi mom!") + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + assert _cache.global_set_if_not_exists(b"key", b"value").result() == "hi mom!" + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetIfNotExistsBatch, {} + ) + batch.add.assert_called_once_with(b"key", b"value") + + @staticmethod + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_with_expires(_batch, _global_cache): + batch = _batch.get_batch.return_value + future = _future_result("hi mom!") + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + assert ( + _cache.global_set_if_not_exists(b"key", b"value", expires=123).result() + == "hi mom!" + ) + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetIfNotExistsBatch, {"expires": 123} + ) + batch.add.assert_called_once_with(b"key", b"value") + + +class Test_GlobalCacheSetIfNotExistsBatch: + @staticmethod + def test_add_duplicate_key_and_value(): + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"foo", b"one") + assert not future1.done() + assert future2.result() is False + + @staticmethod + def test_add_and_idle_and_done_callbacks(in_context): + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = [] + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + assert batch.expires is None + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with( + {b"foo": b"one", b"bar": b"two"}, expires=None + ) + assert future1.result() is None + assert future2.result() is None + + @staticmethod + def test_add_and_idle_and_done_callbacks_with_duplicate_keys(in_context): + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = {b"foo": True} + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"foo", b"two") + + assert batch.expires is None + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with({b"foo": b"one"}, expires=None) + assert future1.result() is True + assert future2.result() is False + + @staticmethod + def test_add_and_idle_and_done_callbacks_with_expires(in_context): + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = [] + + batch = _cache._GlobalCacheSetIfNotExistsBatch({"expires": 5}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + assert batch.expires == 5 + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with( + {b"foo": b"one", b"bar": b"two"}, expires=5 + ) + assert future1.result() is None + assert future2.result() is None + + @staticmethod + def test_add_and_idle_and_done_callbacks_w_error(in_context): + error = Exception("spurious error") + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = tasklets.Future() + cache.set_if_not_exists.return_value.set_exception(error) + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with( + {b"foo": b"one", b"bar": b"two"}, expires=None + ) + assert future1.exception() is error + assert future2.exception() is error + + @staticmethod + def test_done_callbacks_with_results(in_context): + class SpeciousError(Exception): + pass + + cache_call = _future_result( + { + b"foo": "this is a result", + b"bar": SpeciousError("this is also a kind of result"), + } + ) + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + batch.done_callback(cache_call) + + assert future1.result() == "this is a result" + with pytest.raises(SpeciousError): + assert future2.result() + + @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._cache._global_cache") @mock.patch("google.cloud.ndb._cache._batch") @@ -642,24 +795,46 @@ def test_add_and_idle_and_done_callbacks_with_expires(in_context): assert future2.result() is None -@pytest.mark.usefixtures("in_context") -@mock.patch("google.cloud.ndb._cache._global_cache") -@mock.patch("google.cloud.ndb._cache._batch") -def test_global_lock(_batch, _global_cache): - batch = _batch.get_batch.return_value - future = _future_result("hi mom!") - batch.add.return_value = future - _global_cache.return_value = mock.Mock( - transient_errors=(), - strict_write=False, - spec=("transient_errors", "strict_write"), - ) +class Test_global_lock: + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_global_lock_read(_batch, _global_cache): + batch = _batch.get_batch.return_value + future = _future_result("hi mom!") + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_read=False, + spec=("transient_errors", "strict_read"), + ) - assert _cache.global_lock(b"key").result() == "hi mom!" - _batch.get_batch.assert_called_once_with( - _cache._GlobalCacheSetBatch, {"expires": _cache._LOCK_TIME} - ) - batch.add.assert_called_once_with(b"key", _cache._LOCKED) + assert _cache.global_lock(b"key", read=True).result() == "hi mom!" + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetIfNotExistsBatch, {"expires": _cache._LOCK_TIME} + ) + batch.add.assert_called_once_with(b"key", _cache._LOCKED) + + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_global_lock_write(_batch, _global_cache): + batch = _batch.get_batch.return_value + future = _future_result("hi mom!") + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + assert _cache.global_lock(b"key").result() == "hi mom!" + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetBatch, {"expires": _cache._LOCK_TIME} + ) + batch.add.assert_called_once_with(b"key", _cache._LOCKED) def test_is_locked_value(): diff --git a/tests/unit/test__datastore_api.py b/tests/unit/test__datastore_api.py index 8e81fe15..6ee132b1 100644 --- a/tests/unit/test__datastore_api.py +++ b/tests/unit/test__datastore_api.py @@ -270,6 +270,28 @@ class SomeKind(model.Model): assert global_cache.get([cache_key]) == [cache_value] + @staticmethod + @mock.patch("google.cloud.ndb._datastore_api._LookupBatch") + def test_cache_miss_lock_not_acquired(_LookupBatch, global_cache): + class SomeKind(model.Model): + pass + + key = key_module.Key("SomeKind", 1) + cache_key = _cache.global_cache_key(key._key) + + entity = SomeKind(key=key) + entity_pb = model._entity_to_protobuf(entity) + + batch = _LookupBatch.return_value + batch.add.return_value = future_result(entity_pb) + + global_cache.set_if_not_exists = mock.Mock(return_value={cache_key: False}) + + future = _api.lookup(key._key, _options.ReadOptions()) + assert future.result() == entity_pb + + assert global_cache.get([cache_key]) == [None] + @staticmethod @mock.patch("google.cloud.ndb._datastore_api._LookupBatch") def test_cache_miss_no_datastore(_LookupBatch, global_cache): diff --git a/tests/unit/test_global_cache.py b/tests/unit/test_global_cache.py index 0a724a23..b5577828 100644 --- a/tests/unit/test_global_cache.py +++ b/tests/unit/test_global_cache.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections + try: from unittest import mock except ImportError: # pragma: NO PY3 COVER @@ -32,6 +34,9 @@ def get(self, keys): def set(self, items, expires=None): return super(MockImpl, self).set(items, expires=expires) + def set_if_not_exists(self, items, expires=None): + return super(MockImpl, self).set_if_not_exists(items, expires=expires) + def delete(self, keys): return super(MockImpl, self).delete(keys) @@ -59,6 +64,11 @@ def test_set(self): with pytest.raises(NotImplementedError): cache.set({b"foo": "bar"}) + def test_set_if_not_exists(self): + cache = self.make_one() + with pytest.raises(NotImplementedError): + cache.set_if_not_exists({b"foo": "bar"}) + def test_delete(self): cache = self.make_one() with pytest.raises(NotImplementedError): @@ -123,6 +133,37 @@ def test_set_get_delete_w_expires(time): result = cache.get([b"two", b"three", b"one"]) assert result == [None, None, None] + @staticmethod + def test_set_if_not_exists(): + cache = global_cache._InProcessGlobalCache() + result = cache.set_if_not_exists({b"one": b"foo", b"two": b"bar"}) + assert result == {b"one": True, b"two": True} + + result = cache.set_if_not_exists({b"two": b"bar", b"three": b"baz"}) + assert result == {b"two": False, b"three": True} + + result = cache.get([b"two", b"three", b"one"]) + assert result == [b"bar", b"baz", b"foo"] + + @staticmethod + @mock.patch("google.cloud.ndb.global_cache.time") + def test_set_if_not_exists_w_expires(time): + time.time.return_value = 0 + + cache = global_cache._InProcessGlobalCache() + result = cache.set_if_not_exists({b"one": b"foo", b"two": b"bar"}, expires=5) + assert result == {b"one": True, b"two": True} + + result = cache.set_if_not_exists({b"two": b"bar", b"three": b"baz"}, expires=5) + assert result == {b"two": False, b"three": True} + + result = cache.get([b"two", b"three", b"one"]) + assert result == [b"bar", b"baz", b"foo"] + + time.time.return_value = 10 + result = cache.get([b"two", b"three", b"one"]) + assert result == [None, None, None] + @staticmethod def test_watch_compare_and_swap(): cache = global_cache._InProcessGlobalCache() @@ -234,6 +275,37 @@ def mock_expire(key, expires): redis.mset.assert_called_once_with(cache_items) assert expired == {"a": 32, "b": 32} + @staticmethod + def test_set_if_not_exists(): + redis = mock.Mock(spec=("setnx",)) + redis.setnx.side_effect = (True, False) + cache_items = collections.OrderedDict([("a", "foo"), ("b", "bar")]) + cache = global_cache.RedisCache(redis) + results = cache.set_if_not_exists(cache_items) + assert results == {"a": True, "b": False} + redis.setnx.assert_has_calls( + [ + mock.call("a", "foo"), + mock.call("b", "bar"), + ] + ) + + @staticmethod + def test_set_if_not_exists_w_expires(): + redis = mock.Mock(spec=("setnx", "expire")) + redis.setnx.side_effect = (True, False) + cache_items = collections.OrderedDict([("a", "foo"), ("b", "bar")]) + cache = global_cache.RedisCache(redis) + results = cache.set_if_not_exists(cache_items, expires=123) + assert results == {"a": True, "b": False} + redis.setnx.assert_has_calls( + [ + mock.call("a", "foo"), + mock.call("b", "bar"), + ] + ) + redis.expire.assert_called_once_with("a", 123) + @staticmethod def test_delete(): redis = mock.Mock(spec=("delete",)) @@ -468,6 +540,36 @@ def test_set(): noreply=False, ) + @staticmethod + def test_set_if_not_exists(): + client = mock.Mock(spec=("add",)) + client.add.side_effect = (True, False) + cache_items = collections.OrderedDict([(b"a", b"foo"), (b"b", b"bar")]) + cache = global_cache.MemcacheCache(client) + results = cache.set_if_not_exists(cache_items) + assert results == {b"a": True, b"b": False} + client.add.assert_has_calls( + [ + mock.call(cache._key(b"a"), b"foo", expire=0, noreply=False), + mock.call(cache._key(b"b"), b"bar", expire=0, noreply=False), + ] + ) + + @staticmethod + def test_set_if_not_exists_w_expires(): + client = mock.Mock(spec=("add",)) + client.add.side_effect = (True, False) + cache_items = collections.OrderedDict([(b"a", b"foo"), (b"b", b"bar")]) + cache = global_cache.MemcacheCache(client) + results = cache.set_if_not_exists(cache_items, expires=123) + assert results == {b"a": True, b"b": False} + client.add.assert_has_calls( + [ + mock.call(cache._key(b"a"), b"foo", expire=123, noreply=False), + mock.call(cache._key(b"b"), b"bar", expire=123, noreply=False), + ] + ) + @staticmethod def test_set_w_expires(): client = mock.Mock(spec=("set_many",)) From 7ade7dcbaad023415fd1f8928c9114a1630bf360 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 8 Jun 2021 10:30:10 -0400 Subject: [PATCH 02/13] Lint --- google/cloud/ndb/_cache.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index 83c7b75f..f8e37ee8 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -308,11 +308,12 @@ def __init__(self, options): def done_callback(self, cache_call): """Process results of call to global cache. + If there is an exception for the cache call, distribute that to waiting futures, otherwise examine the result of the cache call. If the result is :data:`None`, simply set the result to :data:`None` for all waiting futures. Otherwise, if the result is a `dict`, use that to propagate results for - individual keys to waiting figures. + individual keys to waiting futures. """ exception = cache_call.exception() if exception: From d6a83a99165cc7a187ca9f4c4d680725bd7340d0 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 8 Jun 2021 15:43:08 -0400 Subject: [PATCH 03/13] Consider transient errors during to be the same as if the key aleady exists. --- google/cloud/ndb/_cache.py | 21 +++++++++++++++------ tests/unit/test__cache.py | 22 ++++++++++++++++++++++ tests/unit/test__datastore_api.py | 4 +++- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index f8e37ee8..c3119428 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -368,10 +368,10 @@ def make_call(self): def future_info(self, key, value): """Generate info string for Future.""" - return "GlobalCache.set_if_not_exists({}, {})".format(key, value) + return "GlobalCache.set({}, {})".format(key, value) -@_handle_transient_errors() +@tasklets.tasklet def global_set_if_not_exists(key, value, expires=None, read=False): """Store entity in the global cache if key is not already present. @@ -384,14 +384,21 @@ def global_set_if_not_exists(key, value, expires=None, read=False): Returns: tasklets.Future: Eventual result will be a ``bool`` value which will be :data:`True` if a new value was set for the key, or :data:`False` if a value - was already set for the key. + was already set for the key or if a transient error occurred while + attempting to set the key. """ options = {} if expires: options = {"expires": expires} + cache = _global_cache() batch = _batch.get_batch(_GlobalCacheSetIfNotExistsBatch, options) - return batch.add(key, value) + try: + success = yield batch.add(key, value) + except cache.transient_errors: + success = False + + raise tasklets.Return(success) class _GlobalCacheSetIfNotExistsBatch(_GlobalCacheSetBatch): @@ -425,7 +432,7 @@ def make_call(self): def future_info(self, key, value): """Generate info string for Future.""" - return "GlobalCache.set({}, {})".format(key, value) + return "GlobalCache.set_if_not_exists({}, {})".format(key, value) @_handle_transient_errors() @@ -578,7 +585,9 @@ def global_lock(key, read=False): read (bool): Indicates if being called as part of a read (lookup) operation. Returns: - tasklets.Future: Eventual result will be ``None``. + tasklets.Future: Eventual result will either be :data:`None`, or a boolean value + indicating whether the lock was successfully acquired. The result will + always be a boolean when `read` is :data:`True`. """ if read: return global_set_if_not_exists(key, _LOCKED, expires=_LOCK_TIME, read=read) diff --git a/tests/unit/test__cache.py b/tests/unit/test__cache.py index 2658f470..8a905d61 100644 --- a/tests/unit/test__cache.py +++ b/tests/unit/test__cache.py @@ -472,6 +472,28 @@ def test_without_expires(_batch, _global_cache): ) batch.add.assert_called_once_with(b"key", b"value") + @staticmethod + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_transientError(_batch, _global_cache): + class TransientError(Exception): + pass + + batch = _batch.get_batch.return_value + future = _future_exception(TransientError("oops, mom!")) + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(TransientError,), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + assert _cache.global_set_if_not_exists(b"key", b"value").result() is False + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetIfNotExistsBatch, {} + ) + batch.add.assert_called_once_with(b"key", b"value") + @staticmethod @mock.patch("google.cloud.ndb._cache._global_cache") @mock.patch("google.cloud.ndb._cache._batch") diff --git a/tests/unit/test__datastore_api.py b/tests/unit/test__datastore_api.py index 6ee132b1..93cc319a 100644 --- a/tests/unit/test__datastore_api.py +++ b/tests/unit/test__datastore_api.py @@ -285,7 +285,9 @@ class SomeKind(model.Model): batch = _LookupBatch.return_value batch.add.return_value = future_result(entity_pb) - global_cache.set_if_not_exists = mock.Mock(return_value={cache_key: False}) + global_cache.set_if_not_exists = mock.Mock( + return_value=future_result({cache_key: False}) + ) future = _api.lookup(key._key, _options.ReadOptions()) assert future.result() == entity_pb From c6979d74305489bd61a150a48f07681d71dfd815 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 15 Jun 2021 11:40:14 -0400 Subject: [PATCH 04/13] Fairly big refactor. --- google/cloud/ndb/_cache.py | 152 ++++++++++++++++----- google/cloud/ndb/_datastore_api.py | 21 ++- google/cloud/ndb/global_cache.py | 105 +++++++------- tests/system/test_crud.py | 12 +- tests/unit/test__cache.py | 211 ++++++++++++++++++++++++----- tests/unit/test__datastore_api.py | 6 +- tests/unit/test_global_cache.py | 180 +++++++++++++----------- 7 files changed, 478 insertions(+), 209 deletions(-) diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index c3119428..e77275f8 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -14,6 +14,7 @@ import functools import itertools +import uuid import warnings from google.api_core import retry as core_retry @@ -22,7 +23,8 @@ from google.cloud.ndb import context as context_module from google.cloud.ndb import tasklets -_LOCKED = b"0" +_LOCKED_FOR_READ = b"0" +_LOCKED_FOR_WRITE = b"00" _LOCK_TIME = 32 _PREFIX = b"NDB30" @@ -200,8 +202,7 @@ def wrapper(key, *args, **kwargs): return wrap -@_handle_transient_errors(read=True) -def global_get(key): +def _global_get(key): """Get entity from global cache. Args: @@ -215,6 +216,9 @@ def global_get(key): return batch.add(key) +global_get = _handle_transient_errors(read=True)(_global_get) + + class _GlobalCacheGetBatch(_GlobalCacheBatch): """Batch for global cache get requests. @@ -372,14 +376,13 @@ def future_info(self, key, value): @tasklets.tasklet -def global_set_if_not_exists(key, value, expires=None, read=False): +def global_set_if_not_exists(key, value, expires=None): """Store entity in the global cache if key is not already present. Args: key (bytes): The key to save. value (bytes): The entity to save. expires (Optional[float]): Number of seconds until value expires. - read (bool): Indicates if being set in a read (lookup) context. Returns: tasklets.Future: Eventual result will be a ``bool`` value which will be @@ -435,8 +438,7 @@ def future_info(self, key, value): return "GlobalCache.set_if_not_exists({}, {})".format(key, value) -@_handle_transient_errors() -def global_delete(key): +def _global_delete(key): """Delete an entity from the global cache. Args: @@ -449,6 +451,9 @@ def global_delete(key): return batch.add(key) +global_delete = _handle_transient_errors()(_global_delete) + + class _GlobalCacheDeleteBatch(_GlobalCacheBatch): """Batch for global cache delete requests.""" @@ -479,8 +484,7 @@ def future_info(self, key): return "GlobalCache.delete({})".format(key) -@_handle_transient_errors(read=True) -def global_watch(key): +def _global_watch(key, value): """Start optimistic transaction with global cache. A future call to :func:`global_compare_and_swap` will only set the value @@ -492,24 +496,23 @@ def global_watch(key): Returns: tasklets.Future: Eventual result will be ``None``. """ - batch = _batch.get_batch(_GlobalCacheWatchBatch) - return batch.add(key) + batch = _batch.get_batch(_GlobalCacheWatchBatch, {}) + return batch.add(key, value) -class _GlobalCacheWatchBatch(_GlobalCacheDeleteBatch): - """Batch for global cache watch requests. """ +global_watch = _handle_transient_errors(read=True)(_global_watch) - def __init__(self, ignore_options): - self.keys = [] - self.futures = [] + +class _GlobalCacheWatchBatch(_GlobalCacheSetBatch): + """Batch for global cache watch requests. """ def make_call(self): """Call :method:`GlobalCache.watch`.""" - return _global_cache().watch(self.keys) + return _global_cache().watch(self.todo) - def future_info(self, key): + def future_info(self, key, value): """Generate info string for Future.""" - return "GlobalCache.watch({})".format(key) + return "GlobalCache.watch({}, {})".format(key, value) @_handle_transient_errors() @@ -526,11 +529,11 @@ def global_unwatch(key): Returns: tasklets.Future: Eventual result will be ``None``. """ - batch = _batch.get_batch(_GlobalCacheUnwatchBatch) + batch = _batch.get_batch(_GlobalCacheUnwatchBatch, {}) return batch.add(key) -class _GlobalCacheUnwatchBatch(_GlobalCacheWatchBatch): +class _GlobalCacheUnwatchBatch(_GlobalCacheDeleteBatch): """Batch for global cache unwatch requests. """ def make_call(self): @@ -542,8 +545,7 @@ def future_info(self, key): return "GlobalCache.unwatch({})".format(key) -@_handle_transient_errors(read=True) -def global_compare_and_swap(key, value, expires=None): +def _global_compare_and_swap(key, value, expires=None): """Like :func:`global_set` but using an optimistic transaction. Value will only be set for the given key if the value in the cache hasn't @@ -565,6 +567,9 @@ def global_compare_and_swap(key, value, expires=None): return batch.add(key, value) +global_compare_and_swap = _handle_transient_errors(read=True)(_global_compare_and_swap) + + class _GlobalCacheCompareAndSwapBatch(_GlobalCacheSetBatch): """Batch for global cache compare and swap requests. """ @@ -577,22 +582,100 @@ def future_info(self, key, value): return "GlobalCache.compare_and_swap({}, {})".format(key, value) -def global_lock(key, read=False): - """Lock a key by setting a special value. +@tasklets.tasklet +def global_lock_for_read(key): + """Lock a key for a read (lookup) operation by setting a special value. + + Lock may be preempted by a parallel write (put) operation. Args: key (bytes): The key to lock. - read (bool): Indicates if being called as part of a read (lookup) operation. Returns: - tasklets.Future: Eventual result will either be :data:`None`, or a boolean value - indicating whether the lock was successfully acquired. The result will - always be a boolean when `read` is :data:`True`. + tasklets.Future: Eventual result will be lock value (``bytes``) written to + Datastore for the given key, or :data:`None` if the lock was not acquired. """ - if read: - return global_set_if_not_exists(key, _LOCKED, expires=_LOCK_TIME, read=read) + lock_acquired = yield global_set_if_not_exists( + key, _LOCKED_FOR_READ, expires=_LOCK_TIME + ) + if lock_acquired: + raise tasklets.Return(_LOCKED_FOR_READ) + + +@_handle_transient_errors() +@tasklets.tasklet +def global_lock_for_write(key): + """Lock a key for a write (put) operation, by setting or updating a special value. - return global_set(key, _LOCKED, expires=_LOCK_TIME, read=read) + There can be multiple write locks for a given key. Key will only be released when + all write locks have been released. + + Args: + key (bytes): The key to lock. + + Returns: + tasklets.Future: Eventual result will be a lock value to be used later with + :func:`global_unlock`. + """ + lock = "." + str(uuid.uuid4()) + lock = lock.encode("ascii") + + def new_value(old_value): + if old_value and old_value.startswith(_LOCKED_FOR_WRITE): + return old_value + lock + + return _LOCKED_FOR_WRITE + lock + + yield _update_key(key, new_value) + + raise tasklets.Return(lock) + + +@tasklets.tasklet +def global_unlock_for_write(key, lock): + """Remove a lock for key by updating or removing a lock value. + + The lock represented by the ``lock`` argument will be released. If no other locks + remain, the key will be deleted. + + Args: + key (bytes): The key to lock. + lock (bytes): The return value from the call :func:`global_lock` which acquired + the lock. + + Returns: + tasklets.Future: Eventual result will be :data:`None`. + """ + + def new_value(old_value): + return old_value.replace(lock, b"") + + cache = _global_cache() + try: + yield _update_key(key, new_value) + except cache.transient_errors: + # Worst case scenario, lock sticks around for longer than we'd like + pass + + +@tasklets.tasklet +def _update_key(key, new_value): + success = False + + while not success: + old_value = yield _global_get(key) + value = new_value(old_value) + if value == _LOCKED_FOR_WRITE: + # No more locks for this key, we can delete + yield _global_delete(key) + break + + if old_value: + yield _global_watch(key, old_value) + success = yield _global_compare_and_swap(key, value) + + else: + success = yield global_set_if_not_exists(key, value) def is_locked_value(value): @@ -601,7 +684,10 @@ def is_locked_value(value): Returns: bool: Whether the value is the special reserved value for key lock. """ - return value == _LOCKED + if value: + return value == _LOCKED_FOR_READ or value.startswith(_LOCKED_FOR_WRITE) + + return False def global_cache_key(key): diff --git a/google/cloud/ndb/_datastore_api.py b/google/cloud/ndb/_datastore_api.py index aec514f1..4abfa514 100644 --- a/google/cloud/ndb/_datastore_api.py +++ b/google/cloud/ndb/_datastore_api.py @@ -146,9 +146,9 @@ def lookup(key, options): entity_pb.MergeFromString(result) elif use_datastore: - lock_acquired = yield _cache.global_lock(cache_key, read=True) - if lock_acquired: - yield _cache.global_watch(cache_key) + lock = yield _cache.global_lock_for_read(cache_key) + if lock: + yield _cache.global_watch(cache_key, lock) else: # Another thread locked or wrote to this key after the call to @@ -366,11 +366,12 @@ def put(entity, options): if not use_datastore and entity.key.is_partial: raise TypeError("Can't store partial keys when use_datastore is False") + lock = None entity_pb = helpers.entity_to_protobuf(entity) cache_key = _cache.global_cache_key(entity.key) if use_global_cache and not entity.key.is_partial: if use_datastore: - yield _cache.global_lock(cache_key) + lock = yield _cache.global_lock_for_write(cache_key) else: expires = context._global_cache_timeout(entity.key, options) cache_value = entity_pb.SerializeToString() @@ -389,11 +390,12 @@ def put(entity, options): else: key = None - if use_global_cache: + if lock: if transaction: + # ??? context.global_cache_flush_keys.add(cache_key) else: - yield _cache.global_delete(cache_key) + yield _cache.global_unlock_for_write(cache_key, lock) raise tasklets.Return(key) @@ -423,7 +425,7 @@ def delete(key, options): if use_datastore: if use_global_cache: - yield _cache.global_lock(cache_key) + lock = yield _cache.global_lock_for_write(cache_key) if transaction: batch = _get_commit_batch(transaction, options) @@ -434,7 +436,12 @@ def delete(key, options): if use_global_cache: if transaction: + # ??? context.global_cache_flush_keys.add(cache_key) + + elif use_datastore: + yield _cache.global_unlock_for_write(cache_key, lock) + else: yield _cache.global_delete(cache_key) diff --git a/google/cloud/ndb/global_cache.py b/google/cloud/ndb/global_cache.py index ca70d3b4..b2c1e5d2 100644 --- a/google/cloud/ndb/global_cache.py +++ b/google/cloud/ndb/global_cache.py @@ -143,14 +143,15 @@ def delete(self, keys): raise NotImplementedError @abc.abstractmethod - def watch(self, keys): - """Begin an optimistic transaction for the given keys. + def watch(self, items): + """Begin an optimistic transaction for the given items. A future call to :meth:`compare_and_swap` will only set values for keys - whose values haven't changed since the call to this method. + whose values haven't changed since the call to this method. Values are used to + check that the watched value matches the expected value for a given key. Arguments: - keys (List[bytes]): The keys to watch. + items (Dict[bytes, bytes]): The items to watch. """ raise NotImplementedError @@ -178,6 +179,10 @@ def compare_and_swap(self, items, expires=None): items (Dict[bytes, Union[bytes, None]]): Mapping of keys to serialized entities. expires (Optional[float]): Number of seconds until value expires. + + Returns: + Dict[bytes, bool]: A mapping of key to result. A key will have a result of + :data:`True` if it was changed successfully. """ raise NotImplementedError @@ -251,10 +256,10 @@ def delete(self, keys): for key in keys: self.cache.pop(key, None) # Threadsafe? - def watch(self, keys): + def watch(self, items): """Implements :meth:`GlobalCache.watch`.""" - for key in keys: - self._watch_keys[key] = self.cache.get(key) + for key, value in items.items(): + self._watch_keys[key] = value def unwatch(self, keys): """Implements :meth:`GlobalCache.unwatch`.""" @@ -266,20 +271,22 @@ def compare_and_swap(self, items, expires=None): if expires: expires = time.time() + expires + results = {key: False for key in items.keys()} for key, new_value in items.items(): watch_value = self._watch_keys.get(key) current_value = self.cache.get(key) + current_value = current_value[0] if current_value else current_value if watch_value == current_value: self.cache[key] = (new_value, expires) + results[key] = True + + return results def clear(self): """Implements :meth:`GlobalCache.clear`.""" self.cache.clear() -_Pipeline = collections.namedtuple("_Pipeline", ("pipe", "id")) - - class RedisCache(GlobalCache): """Redis implementation of the :class:`GlobalCache`. @@ -398,48 +405,41 @@ def delete(self, keys): """Implements :meth:`GlobalCache.delete`.""" self.redis.delete(*keys) - def watch(self, keys): + def watch(self, items): """Implements :meth:`GlobalCache.watch`.""" - pipe = self.redis.pipeline() - pipe.watch(*keys) - holder = _Pipeline(pipe, str(uuid.uuid4())) - for key in keys: - self.pipes[key] = holder + for key, value in items.items(): + pipe = self.redis.pipeline() + pipe.watch(key) + if pipe.get(key) == value: + self.pipes[key] = pipe + else: + pipe.reset() def unwatch(self, keys): """Implements :meth:`GlobalCache.watch`.""" for key in keys: - holder = self.pipes.pop(key, None) - if holder: - holder.pipe.reset() + pipe = self.pipes.pop(key, None) + if pipe: + pipe.reset() def compare_and_swap(self, items, expires=None): """Implements :meth:`GlobalCache.compare_and_swap`.""" - pipes = {} - mappings = {} - remove_keys = [] + results = {key: False for key in items.keys()} - # get associated pipes + pipes = self.pipes for key, value in items.items(): - remove_keys.append(key) - if key not in self.pipes: + pipe = pipes.pop(key, None) + if pipe is None: continue - pipe = self.pipes[key] - pipes[pipe.id] = pipe - mapping = mappings.setdefault(pipe.id, {}) - mapping[key] = value - - # execute transaction for each pipes - for pipe_id, mapping in mappings.items(): - pipe = pipes[pipe_id].pipe try: pipe.multi() - pipe.mset(mapping) if expires: - for key in mapping.keys(): - pipe.expire(key, expires) + pipe.setex(key, value, expires) + else: + pipe.set(key, value) pipe.execute() + results[key] = True except redis_module.exceptions.WatchError: pass @@ -447,14 +447,7 @@ def compare_and_swap(self, items, expires=None): finally: pipe.reset() - # get keys associated to pipes but not updated - for key, pipe in self.pipes.items(): - if pipe.id in pipes: - remove_keys.append(key) - - # remove keys - for key in remove_keys: - self.pipes.pop(key, None) + return results def clear(self): """Implements :meth:`GlobalCache.clear`.""" @@ -654,12 +647,19 @@ def delete(self, keys): keys = [self._key(key) for key in keys] self.client.delete_many(keys) - def watch(self, keys): + def watch(self, items): """Implements :meth:`GlobalCache.watch`.""" - keys = [self._key(key) for key in keys] caskeys = self.caskeys + keys = [] + prev_values = {} + for key, prev_value in items.items(): + key = self._key(key) + keys.append(key) + prev_values[key] = prev_value + for key, (value, caskey) in self.client.gets_many(keys).items(): - caskeys[key] = caskey + if prev_values[key] == value: + caskeys[key] = caskey def unwatch(self, keys): """Implements :meth:`GlobalCache.unwatch`.""" @@ -671,14 +671,19 @@ def unwatch(self, keys): def compare_and_swap(self, items, expires=None): """Implements :meth:`GlobalCache.compare_and_swap`.""" caskeys = self.caskeys - for key, value in items.items(): - key = self._key(key) + results = {} + for orig_key, value in items.items(): + key = self._key(orig_key) caskey = caskeys.pop(key, None) if caskey is None: continue expires = expires if expires else 0 - self.client.cas(key, value, caskey, expire=expires) + results[orig_key] = bool( + self.client.cas(key, value, caskey, expire=expires, noreply=False) + ) + + return results def clear(self): """Implements :meth:`GlobalCache.clear`.""" diff --git a/tests/system/test_crud.py b/tests/system/test_crud.py index 945e55d4..95fcba8c 100644 --- a/tests/system/test_crud.py +++ b/tests/system/test_crud.py @@ -110,8 +110,10 @@ class SomeKind(ndb.Model): cache_key = _cache.global_cache_key(key._key) assert cache_key in cache_dict - patch = mock.patch("google.cloud.ndb._datastore_api._LookupBatch.add") - patch.side_effect = Exception("Shouldn't call this") + patch = mock.patch( + "google.cloud.ndb._datastore_api._LookupBatch.add", + mock.Mock(side_effect=Exception("Shouldn't call this")), + ) with patch: entity = key.get() assert isinstance(entity, SomeKind) @@ -587,8 +589,6 @@ class SomeKind(ndb.Model): entity.foo = 43 entity.put() - # This is py27 behavior. I can see a case being made for caching the - # entity on write rather than waiting for a subsequent lookup. assert cache_key not in cache_dict @@ -613,8 +613,6 @@ class SomeKind(ndb.Model): entity.foo = 43 entity.put() - # This is py27 behavior. I can see a case being made for caching the - # entity on write rather than waiting for a subsequent lookup. assert redis_context.global_cache.redis.get(cache_key) is None @@ -640,8 +638,6 @@ class SomeKind(ndb.Model): entity.foo = 43 entity.put() - # This is py27 behavior. I can see a case being made for caching the - # entity on write rather than waiting for a subsequent lookup. assert memcache_context.global_cache.client.get(cache_key) is None diff --git a/tests/unit/test__cache.py b/tests/unit/test__cache.py index 8a905d61..b46b8df9 100644 --- a/tests/unit/test__cache.py +++ b/tests/unit/test__cache.py @@ -675,24 +675,28 @@ def test_global_watch(_batch, _global_cache): spec=("transient_errors", "strict_read"), ) - assert _cache.global_watch(b"key").result() == "hi mom!" - _batch.get_batch.assert_called_once_with(_cache._GlobalCacheWatchBatch) - batch.add.assert_called_once_with(b"key") + assert _cache.global_watch(b"key", b"value").result() == "hi mom!" + _batch.get_batch.assert_called_once_with(_cache._GlobalCacheWatchBatch, {}) + batch.add.assert_called_once_with(b"key", b"value") +@pytest.mark.usefixtures("in_context") class Test_GlobalCacheWatchBatch: @staticmethod def test_add_and_idle_and_done_callbacks(in_context): - cache = mock.Mock() + cache = mock.Mock(spec=("watch",)) + cache.watch.return_value = None batch = _cache._GlobalCacheWatchBatch({}) - future1 = batch.add(b"foo") - future2 = batch.add(b"bar") + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + assert batch.expires is None with in_context.new(global_cache=cache).use(): batch.idle_callback() - cache.watch.assert_called_once_with([b"foo", b"bar"]) + cache.watch.assert_called_once_with({b"foo": b"one", b"bar": b"two"}) assert future1.result() is None assert future2.result() is None @@ -711,7 +715,7 @@ def test_global_unwatch(_batch, _global_cache): ) assert _cache.global_unwatch(b"key").result() == "hi mom!" - _batch.get_batch.assert_called_once_with(_cache._GlobalCacheUnwatchBatch) + _batch.get_batch.assert_called_once_with(_cache._GlobalCacheUnwatchBatch, {}) batch.add.assert_called_once_with(b"key") @@ -817,51 +821,196 @@ def test_add_and_idle_and_done_callbacks_with_expires(in_context): assert future2.result() is None -class Test_global_lock: +class Test_global_lock_for_read: + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") + def test_lock_acquired(global_set_if_not_exists): + global_set_if_not_exists.return_value = _future_result(True) + assert _cache.global_lock_for_read(b"key").result() == _cache._LOCKED_FOR_READ + + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") + def test_lock_not_acquired(global_set_if_not_exists): + global_set_if_not_exists.return_value = _future_result(False) + assert _cache.global_lock_for_read(b"key").result() is None + + +class Test_global_lock_for_write: @staticmethod @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache.uuid") + @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") + @mock.patch("google.cloud.ndb._cache._global_get") @mock.patch("google.cloud.ndb._cache._global_cache") - @mock.patch("google.cloud.ndb._cache._batch") - def test_global_lock_read(_batch, _global_cache): - batch = _batch.get_batch.return_value - future = _future_result("hi mom!") - batch.add.return_value = future + def test_first_time(_global_cache, _global_get, global_set_if_not_exists, uuid): + uuid.uuid4.return_value = "arandomuuid" + _global_cache.return_value = mock.Mock( transient_errors=(), - strict_read=False, - spec=("transient_errors", "strict_read"), + strict_write=False, + spec=("transient_errors", "strict_write"), ) - assert _cache.global_lock(b"key", read=True).result() == "hi mom!" - _batch.get_batch.assert_called_once_with( - _cache._GlobalCacheSetIfNotExistsBatch, {"expires": _cache._LOCK_TIME} + lock_value = _cache._LOCKED_FOR_WRITE + b".arandomuuid" + _global_get.return_value = _future_result(None) + global_set_if_not_exists.return_value = _future_result(True) + + assert _cache.global_lock_for_write(b"key").result() == b".arandomuuid" + _global_get.assert_called_once_with(b"key") + global_set_if_not_exists.assert_called_once_with(b"key", lock_value) + + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache.uuid") + @mock.patch("google.cloud.ndb._cache._global_compare_and_swap") + @mock.patch("google.cloud.ndb._cache._global_watch") + @mock.patch("google.cloud.ndb._cache._global_get") + @mock.patch("google.cloud.ndb._cache._global_cache") + def test_not_first_time_fail_once( + _global_cache, _global_get, _global_watch, _global_compare_and_swap, uuid + ): + uuid.uuid4.return_value = "arandomuuid" + + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), ) - batch.add.assert_called_once_with(b"key", _cache._LOCKED) + old_lock_value = _cache._LOCKED_FOR_WRITE + b".whatevs" + new_lock_value = old_lock_value + b".arandomuuid" + _global_get.return_value = _future_result(old_lock_value) + _global_watch.return_value = _future_result(None) + _global_compare_and_swap.side_effect = ( + _future_result(False), + _future_result(True), + ) + assert _cache.global_lock_for_write(b"key").result() == b".arandomuuid" + _global_get.assert_has_calls( + [ + mock.call(b"key"), + mock.call(b"key"), + ] + ) + _global_watch.assert_has_calls( + [ + mock.call(b"key", old_lock_value), + mock.call(b"key", old_lock_value), + ] + ) + _global_compare_and_swap.assert_has_calls( + [ + mock.call(b"key", new_lock_value), + mock.call(b"key", new_lock_value), + ] + ) + + +class Test_global_unlock_for_write: @staticmethod @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache.uuid") + @mock.patch("google.cloud.ndb._cache._global_delete") + @mock.patch("google.cloud.ndb._cache._global_get") @mock.patch("google.cloud.ndb._cache._global_cache") - @mock.patch("google.cloud.ndb._cache._batch") - def test_global_lock_write(_batch, _global_cache): - batch = _batch.get_batch.return_value - future = _future_result("hi mom!") - batch.add.return_value = future + def test_last_time(_global_cache, _global_get, _global_delete, uuid): + lock = b".arandomuuid" + _global_cache.return_value = mock.Mock( transient_errors=(), strict_write=False, spec=("transient_errors", "strict_write"), ) - assert _cache.global_lock(b"key").result() == "hi mom!" - _batch.get_batch.assert_called_once_with( - _cache._GlobalCacheSetBatch, {"expires": _cache._LOCK_TIME} + lock_value = _cache._LOCKED_FOR_WRITE + lock + _global_get.return_value = _future_result(lock_value) + _global_delete.return_value = _future_result(None) + + assert _cache.global_unlock_for_write(b"key", lock).result() is None + _global_get.assert_called_once_with(b"key") + _global_delete.assert_called_once_with(b"key") + + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache.uuid") + @mock.patch("google.cloud.ndb._cache._global_delete") + @mock.patch("google.cloud.ndb._cache._global_get") + @mock.patch("google.cloud.ndb._cache._global_cache") + def test_transient_error(_global_cache, _global_get, _global_delete, uuid): + class TransientError(Exception): + pass + + lock = b".arandomuuid" + + _global_cache.return_value = mock.Mock( + transient_errors=(TransientError,), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + lock_value = _cache._LOCKED_FOR_WRITE + lock + _global_get.return_value = _future_result(lock_value) + _global_delete.return_value = _future_exception(TransientError()) + + assert _cache.global_unlock_for_write(b"key", lock).result() is None + _global_get.assert_called_once_with(b"key") + _global_delete.assert_called_once_with(b"key") + + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._cache.uuid") + @mock.patch("google.cloud.ndb._cache._global_compare_and_swap") + @mock.patch("google.cloud.ndb._cache._global_watch") + @mock.patch("google.cloud.ndb._cache._global_get") + @mock.patch("google.cloud.ndb._cache._global_cache") + def test_not_last_time_fail_once( + _global_cache, _global_get, _global_watch, _global_compare_and_swap, uuid + ): + lock = b".arandomuuid" + + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + new_lock_value = _cache._LOCKED_FOR_WRITE + b".whatevs" + old_lock_value = new_lock_value + lock + _global_get.return_value = _future_result(old_lock_value) + _global_watch.return_value = _future_result(None) + _global_compare_and_swap.side_effect = ( + _future_result(False), + _future_result(True), + ) + + assert _cache.global_unlock_for_write(b"key", lock).result() is None + _global_get.assert_has_calls( + [ + mock.call(b"key"), + mock.call(b"key"), + ] + ) + _global_watch.assert_has_calls( + [ + mock.call(b"key", old_lock_value), + mock.call(b"key", old_lock_value), + ] + ) + _global_compare_and_swap.assert_has_calls( + [ + mock.call(b"key", new_lock_value), + mock.call(b"key", new_lock_value), + ] ) - batch.add.assert_called_once_with(b"key", _cache._LOCKED) def test_is_locked_value(): - assert _cache.is_locked_value(_cache._LOCKED) - assert not _cache.is_locked_value("new db, who dis?") + assert _cache.is_locked_value(_cache._LOCKED_FOR_READ) + assert _cache.is_locked_value(_cache._LOCKED_FOR_WRITE + b"whatever") + assert not _cache.is_locked_value(b"new db, who dis?") + assert not _cache.is_locked_value(None) def test_global_cache_key(): diff --git a/tests/unit/test__datastore_api.py b/tests/unit/test__datastore_api.py index 93cc319a..166eb665 100644 --- a/tests/unit/test__datastore_api.py +++ b/tests/unit/test__datastore_api.py @@ -344,7 +344,7 @@ class SomeKind(model.Model): entity = SomeKind(key=key) entity_pb = model._entity_to_protobuf(entity) - global_cache.set({cache_key: _cache._LOCKED}) + global_cache.set({cache_key: _cache._LOCKED_FOR_READ}) batch = _LookupBatch.return_value batch.add.return_value = future_result(entity_pb) @@ -352,7 +352,7 @@ class SomeKind(model.Model): future = _api.lookup(key._key, _options.ReadOptions()) assert future.result() == entity_pb - assert global_cache.get([cache_key]) == [_cache._LOCKED] + assert global_cache.get([cache_key]) == [_cache._LOCKED_FOR_READ] @staticmethod @mock.patch("google.cloud.ndb._datastore_api._LookupBatch") @@ -369,7 +369,7 @@ class SomeKind(model.Model): future = _api.lookup(key._key, _options.ReadOptions()) assert future.result() is _api._NOT_FOUND - assert global_cache.get([cache_key]) == [_cache._LOCKED] + assert global_cache.get([cache_key]) == [_cache._LOCKED_FOR_READ] assert len(global_cache._watch_keys) == 0 diff --git a/tests/unit/test_global_cache.py b/tests/unit/test_global_cache.py index b5577828..38310b97 100644 --- a/tests/unit/test_global_cache.py +++ b/tests/unit/test_global_cache.py @@ -167,7 +167,10 @@ def test_set_if_not_exists_w_expires(time): @staticmethod def test_watch_compare_and_swap(): cache = global_cache._InProcessGlobalCache() - result = cache.watch([b"one", b"two", b"three"]) + cache.cache[b"one"] = (b"food", None) + cache.cache[b"two"] = (b"bard", None) + cache.cache[b"three"] = (b"bazz", None) + result = cache.watch({b"one": b"food", b"two": b"bard", b"three": b"bazd"}) assert result is None cache.cache[b"two"] = (b"hamburgers", None) @@ -175,10 +178,10 @@ def test_watch_compare_and_swap(): result = cache.compare_and_swap( {b"one": b"foo", b"two": b"bar", b"three": b"baz"} ) - assert result is None + assert result == {b"one": True, b"two": False, b"three": False} result = cache.get([b"one", b"two", b"three"]) - assert result == [b"foo", b"hamburgers", b"baz"] + assert result == [b"foo", b"hamburgers", b"bazz"] @staticmethod @mock.patch("google.cloud.ndb.global_cache.time") @@ -186,7 +189,10 @@ def test_watch_compare_and_swap_with_expires(time): time.time.return_value = 0 cache = global_cache._InProcessGlobalCache() - result = cache.watch([b"one", b"two", b"three"]) + cache.cache[b"one"] = (b"food", None) + cache.cache[b"two"] = (b"bard", None) + cache.cache[b"three"] = (b"bazz", None) + result = cache.watch({b"one": b"food", b"two": b"bard", b"three": b"bazd"}) assert result is None cache.cache[b"two"] = (b"hamburgers", None) @@ -194,20 +200,20 @@ def test_watch_compare_and_swap_with_expires(time): result = cache.compare_and_swap( {b"one": b"foo", b"two": b"bar", b"three": b"baz"}, expires=5 ) - assert result is None + assert result == {b"one": True, b"two": False, b"three": False} result = cache.get([b"one", b"two", b"three"]) - assert result == [b"foo", b"hamburgers", b"baz"] + assert result == [b"foo", b"hamburgers", b"bazz"] time.time.return_value = 10 result = cache.get([b"one", b"two", b"three"]) - assert result == [None, b"hamburgers", None] + assert result == [None, b"hamburgers", b"bazz"] @staticmethod def test_watch_unwatch(): cache = global_cache._InProcessGlobalCache() - result = cache.watch([b"one", b"two", b"three"]) + result = cache.watch({b"one": "foo", b"two": "bar", b"three": "baz"}) assert result is None result = cache.unwatch([b"one", b"two", b"three"]) @@ -315,107 +321,119 @@ def test_delete(): redis.delete.assert_called_once_with(*cache_keys) @staticmethod - @mock.patch("google.cloud.ndb.global_cache.uuid") - def test_watch(uuid): - uuid.uuid4.return_value = "abc123" - redis = mock.Mock(pipeline=mock.Mock(spec=("watch",)), spec=("pipeline",)) + def test_watch(): + def mock_redis_get(key): + if key == "foo": + return "moo" + + return "nope" + + redis = mock.Mock( + pipeline=mock.Mock(spec=("watch", "get", "reset")), spec=("pipeline",) + ) pipe = redis.pipeline.return_value - keys = ["foo", "bar"] + pipe.get.side_effect = mock_redis_get + items = {"foo": "moo", "bar": "car"} cache = global_cache.RedisCache(redis) - cache.watch(keys) + cache.watch(items) - pipe.watch.assert_called_once_with("foo", "bar") - assert cache.pipes == { - "foo": global_cache._Pipeline(pipe, "abc123"), - "bar": global_cache._Pipeline(pipe, "abc123"), - } + pipe.watch.assert_has_calls( + [ + mock.call("foo"), + mock.call("bar"), + ], + any_order=True, + ) + + pipe.get.assert_has_calls( + [ + mock.call("foo"), + mock.call("bar"), + ], + any_order=True, + ) + + assert cache.pipes == {"foo": pipe} @staticmethod def test_unwatch(): redis = mock.Mock(spec=()) cache = global_cache.RedisCache(redis) - pipe1 = mock.Mock(spec=("reset",)) - pipe2 = mock.Mock(spec=("reset",)) + pipe = mock.Mock(spec=("reset",)) cache._pipes.pipes = { - "ay": global_cache._Pipeline(pipe1, "abc123"), - "be": global_cache._Pipeline(pipe1, "abc123"), - "see": global_cache._Pipeline(pipe2, "def456"), - "dee": global_cache._Pipeline(pipe2, "def456"), - "whatevs": global_cache._Pipeline(None, "himom!"), + "ay": pipe, + "be": pipe, + "see": pipe, + "dee": pipe, + "whatevs": "himom!", } cache.unwatch(["ay", "be", "see", "dee", "nuffin"]) - assert cache.pipes == {"whatevs": global_cache._Pipeline(None, "himom!")} + assert cache.pipes == {"whatevs": "himom!"} + pipe.reset.assert_has_calls([mock.call()] * 4) @staticmethod def test_compare_and_swap(): redis = mock.Mock(spec=()) cache = global_cache.RedisCache(redis) - pipe1 = mock.Mock(spec=("multi", "mset", "execute", "reset")) - pipe2 = mock.Mock(spec=("multi", "mset", "execute", "reset")) + pipe1 = mock.Mock(spec=("multi", "set", "execute", "reset")) + pipe2 = mock.Mock(spec=("multi", "set", "execute", "reset")) + pipe2.execute.side_effect = redis_module.exceptions.WatchError cache._pipes.pipes = { - "ay": global_cache._Pipeline(pipe1, "abc123"), - "be": global_cache._Pipeline(pipe1, "abc123"), - "see": global_cache._Pipeline(pipe2, "def456"), - "dee": global_cache._Pipeline(pipe2, "def456"), - "whatevs": global_cache._Pipeline(None, "himom!"), + "foo": pipe1, + "bar": pipe2, } - pipe2.execute.side_effect = redis_module.exceptions.WatchError - items = {"ay": "foo", "be": "bar", "see": "baz", "wut": "huh?"} - cache.compare_and_swap(items) + result = cache.compare_and_swap( + { + "foo": "moo", + "bar": "car", + "baz": "maz", + } + ) + assert result == {"foo": True, "bar": False, "baz": False} pipe1.multi.assert_called_once_with() - pipe2.multi.assert_called_once_with() - pipe1.mset.assert_called_once_with({"ay": "foo", "be": "bar"}) - pipe2.mset.assert_called_once_with({"see": "baz"}) + pipe1.set.assert_called_once_with("foo", "moo") pipe1.execute.assert_called_once_with() - pipe2.execute.assert_called_once_with() pipe1.reset.assert_called_once_with() - pipe2.reset.assert_called_once_with() - assert cache.pipes == {"whatevs": global_cache._Pipeline(None, "himom!")} + pipe2.multi.assert_called_once_with() + pipe2.set.assert_called_once_with("bar", "car") + pipe2.execute.assert_called_once_with() + pipe2.reset.assert_called_once_with() @staticmethod def test_compare_and_swap_w_expires(): - expired = {} - - def mock_expire(key, expires): - expired[key] = expires - redis = mock.Mock(spec=()) cache = global_cache.RedisCache(redis) - pipe1 = mock.Mock( - expire=mock_expire, - spec=("multi", "mset", "execute", "expire", "reset"), - ) - pipe2 = mock.Mock( - expire=mock_expire, - spec=("multi", "mset", "execute", "expire", "reset"), - ) + pipe1 = mock.Mock(spec=("multi", "setex", "execute", "reset")) + pipe2 = mock.Mock(spec=("multi", "setex", "execute", "reset")) + pipe2.execute.side_effect = redis_module.exceptions.WatchError cache._pipes.pipes = { - "ay": global_cache._Pipeline(pipe1, "abc123"), - "be": global_cache._Pipeline(pipe1, "abc123"), - "see": global_cache._Pipeline(pipe2, "def456"), - "dee": global_cache._Pipeline(pipe2, "def456"), - "whatevs": global_cache._Pipeline(None, "himom!"), + "foo": pipe1, + "bar": pipe2, } - pipe2.execute.side_effect = redis_module.exceptions.WatchError - items = {"ay": "foo", "be": "bar", "see": "baz", "wut": "huh?"} - cache.compare_and_swap(items, expires=32) + result = cache.compare_and_swap( + { + "foo": "moo", + "bar": "car", + "baz": "maz", + }, + expires=5, + ) + assert result == {"foo": True, "bar": False, "baz": False} pipe1.multi.assert_called_once_with() - pipe2.multi.assert_called_once_with() - pipe1.mset.assert_called_once_with({"ay": "foo", "be": "bar"}) - pipe2.mset.assert_called_once_with({"see": "baz"}) + pipe1.setex.assert_called_once_with("foo", "moo", 5) pipe1.execute.assert_called_once_with() - pipe2.execute.assert_called_once_with() pipe1.reset.assert_called_once_with() - pipe2.reset.assert_called_once_with() - assert cache.pipes == {"whatevs": global_cache._Pipeline(None, "himom!")} - assert expired == {"ay": 32, "be": 32, "see": 32} + pipe2.multi.assert_called_once_with() + pipe2.setex.assert_called_once_with("bar", "car", 5) + pipe2.execute.assert_called_once_with() + pipe2.reset.assert_called_once_with() @staticmethod def test_clear(): @@ -644,11 +662,17 @@ def test_watch(): key1: ("bun", b"0"), key2: ("shoe", b"1"), } - cache.watch((b"one", b"two")) + cache.watch( + collections.OrderedDict( + ( + (b"one", "bun"), + (b"two", "shot"), + ) + ) + ) client.gets_many.assert_called_once_with([key1, key2]) assert cache.caskeys == { key1: b"0", - key2: b"1", } @staticmethod @@ -669,14 +693,15 @@ def test_compare_and_swap(): key2 = cache._key(b"two") cache.caskeys[key2] = b"5" cache.caskeys["whatevs"] = b"6" - cache.compare_and_swap( + result = cache.compare_and_swap( { b"one": "bun", b"two": "shoe", } ) - client.cas.assert_called_once_with(key2, "shoe", b"5", expire=0) + assert result == {b"two": True} + client.cas.assert_called_once_with(key2, "shoe", b"5", expire=0, noreply=False) assert cache.caskeys == {"whatevs": b"6"} @staticmethod @@ -686,7 +711,7 @@ def test_compare_and_swap_and_expires(): key2 = cache._key(b"two") cache.caskeys[key2] = b"5" cache.caskeys["whatevs"] = b"6" - cache.compare_and_swap( + result = cache.compare_and_swap( { b"one": "bun", b"two": "shoe", @@ -694,7 +719,8 @@ def test_compare_and_swap_and_expires(): expires=5, ) - client.cas.assert_called_once_with(key2, "shoe", b"5", expire=5) + assert result == {b"two": True} + client.cas.assert_called_once_with(key2, "shoe", b"5", expire=5, noreply=False) assert cache.caskeys == {"whatevs": b"6"} @staticmethod From 40c5bb5513d84683a81ffc7294178047fa76ad74 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 15 Jun 2021 12:34:24 -0400 Subject: [PATCH 05/13] lint --- google/cloud/ndb/global_cache.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/google/cloud/ndb/global_cache.py b/google/cloud/ndb/global_cache.py index b2c1e5d2..6fe4e6a8 100644 --- a/google/cloud/ndb/global_cache.py +++ b/google/cloud/ndb/global_cache.py @@ -16,14 +16,12 @@ import abc import base64 -import collections import hashlib import os import pymemcache.exceptions import redis.exceptions import threading import time -import uuid import warnings import pymemcache From 365a50e8d33bae1dde9b5171aed17f070016aadf Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 15 Jun 2021 14:33:20 -0400 Subject: [PATCH 06/13] Clean up cache after transactions. --- google/cloud/ndb/_datastore_api.py | 15 +++++--- google/cloud/ndb/_transaction.py | 55 +++++++++++++++++------------- google/cloud/ndb/context.py | 26 ++++++++++++++ tests/unit/test__datastore_api.py | 24 +++++++++---- tests/unit/test__transaction.py | 42 +++++++++++++++++++++++ tests/unit/test_context.py | 15 ++++++++ 6 files changed, 143 insertions(+), 34 deletions(-) diff --git a/google/cloud/ndb/_datastore_api.py b/google/cloud/ndb/_datastore_api.py index 4abfa514..b08ebb9d 100644 --- a/google/cloud/ndb/_datastore_api.py +++ b/google/cloud/ndb/_datastore_api.py @@ -392,8 +392,12 @@ def put(entity, options): if lock: if transaction: - # ??? - context.global_cache_flush_keys.add(cache_key) + + def callback(): + _cache.global_unlock_for_write(cache_key, lock).result() + + context.call_on_transaction_complete(callback) + else: yield _cache.global_unlock_for_write(cache_key, lock) @@ -436,8 +440,11 @@ def delete(key, options): if use_global_cache: if transaction: - # ??? - context.global_cache_flush_keys.add(cache_key) + + def callback(): + _cache.global_unlock_for_write(cache_key, lock).result() + + context.call_on_transaction_complete(callback) elif use_datastore: yield _cache.global_unlock_for_write(cache_key, lock) diff --git a/google/cloud/ndb/_transaction.py b/google/cloud/ndb/_transaction.py index 1932fefa..d731de6d 100644 --- a/google/cloud/ndb/_transaction.py +++ b/google/cloud/ndb/_transaction.py @@ -259,9 +259,11 @@ def _transaction_async(context, callback, read_only=False): utils.logging_debug(log, "Transaction Id: {}", transaction_id) on_commit_callbacks = [] + transaction_complete_callbacks = [] tx_context = context.new( transaction=transaction_id, on_commit_callbacks=on_commit_callbacks, + transaction_complete_callbacks=transaction_complete_callbacks, batches=None, commit_batches=None, cache=None, @@ -285,30 +287,35 @@ def run_inner_loop(inner_context): tx_context.global_cache_flush_keys = flush_keys = set() with tx_context.use(): try: - # Run the callback - result = callback() - if isinstance(result, tasklets.Future): - result = yield result - - # Make sure we've run everything we can run before calling commit - _datastore_api.prepare_to_commit(transaction_id) - tx_context.eventloop.run() - - # Commit the transaction - yield _datastore_api.commit(transaction_id, retries=0) - - # Rollback if there is an error - except Exception as e: # noqa: E722 - tx_context.cache.clear() - yield _datastore_api.rollback(transaction_id) - raise e - - # Flush keys of entities written during the transaction from the global cache - if flush_keys: - yield [_cache.global_delete(key) for key in flush_keys] - - for callback in on_commit_callbacks: - callback() + try: + # Run the callback + result = callback() + if isinstance(result, tasklets.Future): + result = yield result + + # Make sure we've run everything we can run before calling commit + _datastore_api.prepare_to_commit(transaction_id) + tx_context.eventloop.run() + + # Commit the transaction + yield _datastore_api.commit(transaction_id, retries=0) + + # Rollback if there is an error + except Exception as e: # noqa: E722 + tx_context.cache.clear() + yield _datastore_api.rollback(transaction_id) + raise e + + # Flush keys of entities written during the transaction from the global cache + if flush_keys: + yield [_cache.global_delete(key) for key in flush_keys] + + for callback in on_commit_callbacks: + callback() + + finally: + for callback in transaction_complete_callbacks: + callback() raise tasklets.Return(result) diff --git a/google/cloud/ndb/context.py b/google/cloud/ndb/context.py index c4e67567..de2f5849 100644 --- a/google/cloud/ndb/context.py +++ b/google/cloud/ndb/context.py @@ -208,6 +208,7 @@ def policy(key): "cache", "global_cache", "on_commit_callbacks", + "transaction_complete_callbacks", "legacy_data", ], ) @@ -246,6 +247,7 @@ def __new__( global_cache_timeout_policy=None, datastore_policy=None, on_commit_callbacks=None, + transaction_complete_callbacks=None, legacy_data=True, retry=None, rpc_time=None, @@ -280,6 +282,7 @@ def __new__( cache=new_cache, global_cache=global_cache, on_commit_callbacks=on_commit_callbacks, + transaction_complete_callbacks=transaction_complete_callbacks, legacy_data=legacy_data, ) @@ -565,6 +568,29 @@ def call_on_commit(self, callback): else: callback() + def call_on_transaction_complete(self, callback): + """Call a callback upon completion of a transaction. + + If not in a transaction, the callback is called immediately. + + In a transaction, multiple callbacks may be registered and will be called once + the transaction completes, in the order in which they were registered. Callbacks + are called regardless of whether transaction is committed or rolled back. + + If the callback raises an exception, it bubbles up normally. This means: If the + callback is called immediately, any exception it raises will bubble up + immediately. If the call is postponed until commit, remaining callbacks will be + skipped and the exception will bubble up through the transaction() call. + (However, the transaction is already committed or rolled back at that point.) + + Args: + callback (Callable): The callback function. + """ + if self.in_transaction(): + self.transaction_complete_callbacks.append(callback) + else: + callback() + def in_transaction(self): """Get whether a transaction is currently active. diff --git a/tests/unit/test__datastore_api.py b/tests/unit/test__datastore_api.py index 166eb665..5bc51062 100644 --- a/tests/unit/test__datastore_api.py +++ b/tests/unit/test__datastore_api.py @@ -740,8 +740,10 @@ class SomeKind(model.Model): pass context = context_module.get_context() - with context.new(transaction=b"abc123").use() as in_context: - in_context.global_cache_flush_keys = set() + callbacks = [] + with context.new( + transaction=b"abc123", transaction_complete_callbacks=callbacks + ).use(): key = key_module.Key("SomeKind", 1) cache_key = _cache.global_cache_key(key._key) @@ -752,7 +754,11 @@ class SomeKind(model.Model): future = _api.put(model._entity_to_ds_entity(entity), _options.Options()) assert future.result() is None - assert in_context.global_cache_flush_keys == {cache_key} + assert cache_key in global_cache.cache # lock + for callback in callbacks: + callback() + + assert cache_key not in global_cache.cache # unlocked by callback @staticmethod @mock.patch("google.cloud.ndb._datastore_api._NonTransactionalCommitBatch") @@ -867,8 +873,10 @@ def test_cache_enabled(Batch, global_cache): @mock.patch("google.cloud.ndb._datastore_api._NonTransactionalCommitBatch") def test_w_transaction(Batch, global_cache): context = context_module.get_context() - with context.new(transaction=b"abc123").use() as in_context: - in_context.global_cache_flush_keys = set() + callbacks = [] + with context.new( + transaction=b"abc123", transaction_complete_callbacks=callbacks + ).use(): key = key_module.Key("SomeKind", 1) cache_key = _cache.global_cache_key(key._key) @@ -878,7 +886,11 @@ def test_w_transaction(Batch, global_cache): future = _api.delete(key._key, _options.Options()) assert future.result() is None - assert in_context.global_cache_flush_keys == {cache_key} + assert cache_key in global_cache.cache # lock + for callback in callbacks: + callback() + + assert cache_key not in global_cache.cache # lock removed by callback @staticmethod @mock.patch("google.cloud.ndb._datastore_api._NonTransactionalCommitBatch") diff --git a/tests/unit/test__transaction.py b/tests/unit/test__transaction.py index 8f48a206..85daa37a 100644 --- a/tests/unit/test__transaction.py +++ b/tests/unit/test__transaction.py @@ -91,11 +91,13 @@ class Test_transaction_async: def test_success(_datastore_api): context_module.get_context().cache["foo"] = "bar" on_commit_callback = mock.Mock() + transaction_complete_callback = mock.Mock() def callback(): context = context_module.get_context() assert not context.cache context.call_on_commit(on_commit_callback) + context.call_on_transaction_complete(transaction_complete_callback) return "I tried, momma." begin_future = tasklets.Future("begin transaction") @@ -114,6 +116,46 @@ def callback(): assert future.result() == "I tried, momma." on_commit_callback.assert_called_once_with() + transaction_complete_callback.assert_called_once_with() + + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._datastore_api") + def test_failure(_datastore_api): + class SpuriousError(Exception): + pass + + context_module.get_context().cache["foo"] = "bar" + on_commit_callback = mock.Mock() + transaction_complete_callback = mock.Mock() + + def callback(): + context = context_module.get_context() + assert not context.cache + context.call_on_commit(on_commit_callback) + context.call_on_transaction_complete(transaction_complete_callback) + raise SpuriousError() + + begin_future = tasklets.Future("begin transaction") + _datastore_api.begin_transaction.return_value = begin_future + + rollback_future = tasklets.Future("rollback transaction") + _datastore_api.rollback.return_value = rollback_future + + future = _transaction.transaction_async(callback) + + _datastore_api.begin_transaction.assert_called_once_with(False, retries=0) + begin_future.set_result(b"tx123") + + _datastore_api.commit.assert_not_called() + _datastore_api.rollback.assert_called_once_with(b"tx123") + rollback_future.set_result(None) + + with pytest.raises(SpuriousError): + future.result() + + on_commit_callback.assert_not_called() + transaction_complete_callback.assert_called_once_with() @staticmethod def test_success_join(in_context): diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 0222b7cb..3a35ddf5 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -336,6 +336,21 @@ def test_call_on_commit_with_transaction(self): context.call_on_commit(callback) assert context.on_commit_callbacks == ["himom!"] + def test_call_on_transaction_complete(self): + context = self._make_one() + callback = mock.Mock() + context.call_on_transaction_complete(callback) + callback.assert_called_once_with() + + def test_call_on_transaction_complete_with_transaction(self): + callbacks = [] + callback = "himom!" + context = self._make_one( + transaction=b"tx123", transaction_complete_callbacks=callbacks + ) + context.call_on_transaction_complete(callback) + assert context.transaction_complete_callbacks == ["himom!"] + def test_in_transaction(self): context = self._make_one() assert context.in_transaction() is False From b8e146adf52d99dd41e50d1f1dbb6568c5312094 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 15 Jun 2021 14:35:35 -0400 Subject: [PATCH 07/13] Remove unused --- google/cloud/ndb/_transaction.py | 5 ----- google/cloud/ndb/context.py | 3 --- tests/unit/test__transaction.py | 29 ----------------------------- tests/unit/test_context.py | 5 ----- 4 files changed, 42 deletions(-) diff --git a/google/cloud/ndb/_transaction.py b/google/cloud/ndb/_transaction.py index d731de6d..7b932daa 100644 --- a/google/cloud/ndb/_transaction.py +++ b/google/cloud/ndb/_transaction.py @@ -284,7 +284,6 @@ def run_inner_loop(inner_context): context.eventloop.add_idle(run_inner_loop, tx_context) - tx_context.global_cache_flush_keys = flush_keys = set() with tx_context.use(): try: try: @@ -306,10 +305,6 @@ def run_inner_loop(inner_context): yield _datastore_api.rollback(transaction_id) raise e - # Flush keys of entities written during the transaction from the global cache - if flush_keys: - yield [_cache.global_delete(key) for key in flush_keys] - for callback in on_commit_callbacks: callback() diff --git a/google/cloud/ndb/context.py b/google/cloud/ndb/context.py index de2f5849..91520609 100644 --- a/google/cloud/ndb/context.py +++ b/google/cloud/ndb/context.py @@ -242,7 +242,6 @@ def __new__( cache=None, cache_policy=None, global_cache=None, - global_cache_flush_keys=None, global_cache_policy=None, global_cache_timeout_policy=None, datastore_policy=None, @@ -292,8 +291,6 @@ def __new__( context.set_datastore_policy(datastore_policy) context.set_retry_state(retry) - context.global_cache_flush_keys = global_cache_flush_keys - return context def new(self, **kwargs): diff --git a/tests/unit/test__transaction.py b/tests/unit/test__transaction.py index 85daa37a..2ebdedf4 100644 --- a/tests/unit/test__transaction.py +++ b/tests/unit/test__transaction.py @@ -449,35 +449,6 @@ def callback(): assert future.result() == "I tried, momma." - @staticmethod - @pytest.mark.usefixtures("in_context") - @mock.patch("google.cloud.ndb._cache") - @mock.patch("google.cloud.ndb._datastore_api") - def test_success_flush_keys(_datastore_api, _cache): - def callback(): - context = context_module.get_context() - context.global_cache_flush_keys.add(b"abc123") - return "I tried, momma." - - _cache.global_delete.return_value = utils.future_result(None) - - begin_future = tasklets.Future("begin transaction") - _datastore_api.begin_transaction.return_value = begin_future - - commit_future = tasklets.Future("commit transaction") - _datastore_api.commit.return_value = commit_future - - future = _transaction.transaction_async(callback, retries=0) - - _datastore_api.begin_transaction.assert_called_once_with(False, retries=0) - begin_future.set_result(b"tx123") - - _datastore_api.commit.assert_called_once_with(b"tx123", retries=0) - commit_future.set_result(None) - - assert future.result() == "I tried, momma." - _cache.global_delete.assert_called_once_with(b"abc123") - @staticmethod @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._datastore_api") diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 3a35ddf5..fda9e60e 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -93,11 +93,6 @@ def test_new_transaction(self): assert new_context.transaction == "tx123" assert context.transaction is None - def test_new_global_cache_flush_keys(self): - context = self._make_one(global_cache_flush_keys={"hi", "mom!"}) - new_context = context.new() - assert new_context.global_cache_flush_keys == {"hi", "mom!"} - def test_new_with_cache(self): context = self._make_one() context.cache["foo"] = "bar" From 92c021736dfeaa92c6039e292f8faa4c804ddc47 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 15 Jun 2021 15:10:54 -0400 Subject: [PATCH 08/13] lint --- google/cloud/ndb/_transaction.py | 1 - tests/unit/test__transaction.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/google/cloud/ndb/_transaction.py b/google/cloud/ndb/_transaction.py index 7b932daa..aaf1a6f9 100644 --- a/google/cloud/ndb/_transaction.py +++ b/google/cloud/ndb/_transaction.py @@ -250,7 +250,6 @@ def transaction_async_( @tasklets.tasklet def _transaction_async(context, callback, read_only=False): # Avoid circular import in Python 2.7 - from google.cloud.ndb import _cache from google.cloud.ndb import _datastore_api # Start the transaction diff --git a/tests/unit/test__transaction.py b/tests/unit/test__transaction.py index 2ebdedf4..3d00efbc 100644 --- a/tests/unit/test__transaction.py +++ b/tests/unit/test__transaction.py @@ -28,8 +28,6 @@ from google.cloud.ndb import tasklets from google.cloud.ndb import _transaction -from . import utils - class Test_in_transaction: @staticmethod From cba8ff369b5e9c74f20bcbdf8db72612d9f1e755 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Thu, 17 Jun 2021 11:37:11 -0400 Subject: [PATCH 09/13] Coverage --- google/cloud/ndb/_transaction.py | 2 +- tests/unit/test__transaction.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/google/cloud/ndb/_transaction.py b/google/cloud/ndb/_transaction.py index aaf1a6f9..f07d752c 100644 --- a/google/cloud/ndb/_transaction.py +++ b/google/cloud/ndb/_transaction.py @@ -311,7 +311,7 @@ def run_inner_loop(inner_context): for callback in transaction_complete_callbacks: callback() - raise tasklets.Return(result) + raise tasklets.Return(result) def transactional( diff --git a/tests/unit/test__transaction.py b/tests/unit/test__transaction.py index 3d00efbc..08926ce7 100644 --- a/tests/unit/test__transaction.py +++ b/tests/unit/test__transaction.py @@ -88,6 +88,33 @@ class Test_transaction_async: @mock.patch("google.cloud.ndb._datastore_api") def test_success(_datastore_api): context_module.get_context().cache["foo"] = "bar" + + def callback(): + context = context_module.get_context() + assert not context.cache + return "I tried, momma." + + begin_future = tasklets.Future("begin transaction") + _datastore_api.begin_transaction.return_value = begin_future + + commit_future = tasklets.Future("commit transaction") + _datastore_api.commit.return_value = commit_future + + future = _transaction.transaction_async(callback) + + _datastore_api.begin_transaction.assert_called_once_with(False, retries=0) + begin_future.set_result(b"tx123") + + _datastore_api.commit.assert_called_once_with(b"tx123", retries=0) + commit_future.set_result(None) + + assert future.result() == "I tried, momma." + + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._datastore_api") + def test_success_w_callbacks(_datastore_api): + context_module.get_context().cache["foo"] = "bar" on_commit_callback = mock.Mock() transaction_complete_callback = mock.Mock() @@ -119,7 +146,7 @@ def callback(): @staticmethod @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._datastore_api") - def test_failure(_datastore_api): + def test_failure_w_callbacks(_datastore_api): class SpuriousError(Exception): pass From 86393ebddeda3a8ebf6954e9735ab4e50e8ac5ea Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Wed, 14 Jul 2021 13:43:04 -0400 Subject: [PATCH 10/13] Set expiration for write locks. --- google/cloud/ndb/_cache.py | 4 ++-- google/cloud/ndb/global_cache.py | 2 +- tests/unit/test__cache.py | 10 +++++----- tests/unit/test_global_cache.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index e77275f8..226b62b1 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -672,10 +672,10 @@ def _update_key(key, new_value): if old_value: yield _global_watch(key, old_value) - success = yield _global_compare_and_swap(key, value) + success = yield _global_compare_and_swap(key, value, expires=_LOCK_TIME) else: - success = yield global_set_if_not_exists(key, value) + success = yield global_set_if_not_exists(key, value, expires=_LOCK_TIME) def is_locked_value(value): diff --git a/google/cloud/ndb/global_cache.py b/google/cloud/ndb/global_cache.py index 6fe4e6a8..906a1294 100644 --- a/google/cloud/ndb/global_cache.py +++ b/google/cloud/ndb/global_cache.py @@ -433,7 +433,7 @@ def compare_and_swap(self, items, expires=None): try: pipe.multi() if expires: - pipe.setex(key, value, expires) + pipe.setex(key, expires, value) else: pipe.set(key, value) pipe.execute() diff --git a/tests/unit/test__cache.py b/tests/unit/test__cache.py index b46b8df9..26155cc3 100644 --- a/tests/unit/test__cache.py +++ b/tests/unit/test__cache.py @@ -859,7 +859,7 @@ def test_first_time(_global_cache, _global_get, global_set_if_not_exists, uuid): assert _cache.global_lock_for_write(b"key").result() == b".arandomuuid" _global_get.assert_called_once_with(b"key") - global_set_if_not_exists.assert_called_once_with(b"key", lock_value) + global_set_if_not_exists.assert_called_once_with(b"key", lock_value, expires=32) @staticmethod @pytest.mark.usefixtures("in_context") @@ -902,8 +902,8 @@ def test_not_first_time_fail_once( ) _global_compare_and_swap.assert_has_calls( [ - mock.call(b"key", new_lock_value), - mock.call(b"key", new_lock_value), + mock.call(b"key", new_lock_value, expires=32), + mock.call(b"key", new_lock_value, expires=32), ] ) @@ -1000,8 +1000,8 @@ def test_not_last_time_fail_once( ) _global_compare_and_swap.assert_has_calls( [ - mock.call(b"key", new_lock_value), - mock.call(b"key", new_lock_value), + mock.call(b"key", new_lock_value, expires=32), + mock.call(b"key", new_lock_value, expires=32), ] ) diff --git a/tests/unit/test_global_cache.py b/tests/unit/test_global_cache.py index 38310b97..d2a7b560 100644 --- a/tests/unit/test_global_cache.py +++ b/tests/unit/test_global_cache.py @@ -426,12 +426,12 @@ def test_compare_and_swap_w_expires(): assert result == {"foo": True, "bar": False, "baz": False} pipe1.multi.assert_called_once_with() - pipe1.setex.assert_called_once_with("foo", "moo", 5) + pipe1.setex.assert_called_once_with("foo", 5, "moo") pipe1.execute.assert_called_once_with() pipe1.reset.assert_called_once_with() pipe2.multi.assert_called_once_with() - pipe2.setex.assert_called_once_with("bar", "car", 5) + pipe2.setex.assert_called_once_with("bar", 5, "car") pipe2.execute.assert_called_once_with() pipe2.reset.assert_called_once_with() From 628cd089644b141ae099e70160783495a789acc9 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Fri, 16 Jul 2021 13:21:31 -0400 Subject: [PATCH 11/13] Address nits. --- tests/unit/test__cache.py | 12 ++++-------- tests/unit/test__transaction.py | 6 ++++++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/unit/test__cache.py b/tests/unit/test__cache.py index 26155cc3..0c48fe24 100644 --- a/tests/unit/test__cache.py +++ b/tests/unit/test__cache.py @@ -529,7 +529,7 @@ def test_add_duplicate_key_and_value(): @staticmethod def test_add_and_idle_and_done_callbacks(in_context): cache = mock.Mock(spec=("set_if_not_exists",)) - cache.set_if_not_exists.return_value = [] + cache.set_if_not_exists.return_value = {} batch = _cache._GlobalCacheSetIfNotExistsBatch({}) future1 = batch.add(b"foo", b"one") @@ -821,25 +821,24 @@ def test_add_and_idle_and_done_callbacks_with_expires(in_context): assert future2.result() is None +@pytest.mark.usefixtures("in_context") class Test_global_lock_for_read: @staticmethod - @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") def test_lock_acquired(global_set_if_not_exists): global_set_if_not_exists.return_value = _future_result(True) assert _cache.global_lock_for_read(b"key").result() == _cache._LOCKED_FOR_READ @staticmethod - @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") def test_lock_not_acquired(global_set_if_not_exists): global_set_if_not_exists.return_value = _future_result(False) assert _cache.global_lock_for_read(b"key").result() is None +@pytest.mark.usefixtures("in_context") class Test_global_lock_for_write: @staticmethod - @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._cache.uuid") @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") @mock.patch("google.cloud.ndb._cache._global_get") @@ -862,7 +861,6 @@ def test_first_time(_global_cache, _global_get, global_set_if_not_exists, uuid): global_set_if_not_exists.assert_called_once_with(b"key", lock_value, expires=32) @staticmethod - @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._cache.uuid") @mock.patch("google.cloud.ndb._cache._global_compare_and_swap") @mock.patch("google.cloud.ndb._cache._global_watch") @@ -908,9 +906,9 @@ def test_not_first_time_fail_once( ) +@pytest.mark.usefixtures("in_context") class Test_global_unlock_for_write: @staticmethod - @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._cache.uuid") @mock.patch("google.cloud.ndb._cache._global_delete") @mock.patch("google.cloud.ndb._cache._global_get") @@ -933,7 +931,6 @@ def test_last_time(_global_cache, _global_get, _global_delete, uuid): _global_delete.assert_called_once_with(b"key") @staticmethod - @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._cache.uuid") @mock.patch("google.cloud.ndb._cache._global_delete") @mock.patch("google.cloud.ndb._cache._global_get") @@ -959,7 +956,6 @@ class TransientError(Exception): _global_delete.assert_called_once_with(b"key") @staticmethod - @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._cache.uuid") @mock.patch("google.cloud.ndb._cache._global_compare_and_swap") @mock.patch("google.cloud.ndb._cache._global_watch") diff --git a/tests/unit/test__transaction.py b/tests/unit/test__transaction.py index 08926ce7..435b9840 100644 --- a/tests/unit/test__transaction.py +++ b/tests/unit/test__transaction.py @@ -90,8 +90,11 @@ def test_success(_datastore_api): context_module.get_context().cache["foo"] = "bar" def callback(): + # The transaction uses its own in-memory cache, which should be empty in + # the transaction context and not include the key set above. context = context_module.get_context() assert not context.cache + return "I tried, momma." begin_future = tasklets.Future("begin transaction") @@ -119,8 +122,11 @@ def test_success_w_callbacks(_datastore_api): transaction_complete_callback = mock.Mock() def callback(): + # The transaction uses its own in-memory cache, which should be empty in + # the transaction context and not include the key set above. context = context_module.get_context() assert not context.cache + context.call_on_commit(on_commit_callback) context.call_on_transaction_complete(transaction_complete_callback) return "I tried, momma." From fa67d2c158bcf1880c0b3b4e6718a325cc7ff2d4 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Mon, 19 Jul 2021 15:16:11 -0400 Subject: [PATCH 12/13] Better test name? --- tests/unit/test__datastore_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test__datastore_api.py b/tests/unit/test__datastore_api.py index 5bc51062..435e15dd 100644 --- a/tests/unit/test__datastore_api.py +++ b/tests/unit/test__datastore_api.py @@ -272,7 +272,9 @@ class SomeKind(model.Model): @staticmethod @mock.patch("google.cloud.ndb._datastore_api._LookupBatch") - def test_cache_miss_lock_not_acquired(_LookupBatch, global_cache): + def test_cache_miss_followed_by_lock_acquisition_failure( + _LookupBatch, global_cache + ): class SomeKind(model.Model): pass From b77c2b919a31bf5a0eb0a7974d78d434b6a4baa0 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 20 Jul 2021 13:15:04 -0400 Subject: [PATCH 13/13] Make read locks unique. --- google/cloud/ndb/_cache.py | 11 +++++------ tests/system/test_crud.py | 18 +++++++++++------- tests/unit/test__cache.py | 6 +++++- tests/unit/test__datastore_api.py | 2 +- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index 226b62b1..b611f8e9 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -23,7 +23,7 @@ from google.cloud.ndb import context as context_module from google.cloud.ndb import tasklets -_LOCKED_FOR_READ = b"0" +_LOCKED_FOR_READ = b"0-" _LOCKED_FOR_WRITE = b"00" _LOCK_TIME = 32 _PREFIX = b"NDB30" @@ -595,11 +595,10 @@ def global_lock_for_read(key): tasklets.Future: Eventual result will be lock value (``bytes``) written to Datastore for the given key, or :data:`None` if the lock was not acquired. """ - lock_acquired = yield global_set_if_not_exists( - key, _LOCKED_FOR_READ, expires=_LOCK_TIME - ) + lock = _LOCKED_FOR_READ + str(uuid.uuid4()).encode("ascii") + lock_acquired = yield global_set_if_not_exists(key, lock, expires=_LOCK_TIME) if lock_acquired: - raise tasklets.Return(_LOCKED_FOR_READ) + raise tasklets.Return(lock) @_handle_transient_errors() @@ -685,7 +684,7 @@ def is_locked_value(value): bool: Whether the value is the special reserved value for key lock. """ if value: - return value == _LOCKED_FOR_READ or value.startswith(_LOCKED_FOR_WRITE) + return value.startswith(_LOCKED_FOR_READ) or value.startswith(_LOCKED_FOR_WRITE) return False diff --git a/tests/system/test_crud.py b/tests/system/test_crud.py index 95fcba8c..d99a92e7 100644 --- a/tests/system/test_crud.py +++ b/tests/system/test_crud.py @@ -142,8 +142,10 @@ class SomeKind(ndb.Model): cache_key = _cache.global_cache_key(key._key) assert redis_context.global_cache.redis.get(cache_key) is not None - patch = mock.patch("google.cloud.ndb._datastore_api._LookupBatch.add") - patch.side_effect = Exception("Shouldn't call this") + patch = mock.patch( + "google.cloud.ndb._datastore_api._LookupBatch.add", + mock.Mock(side_effect=Exception("Shouldn't call this")), + ) with patch: entity = key.get() assert isinstance(entity, SomeKind) @@ -173,8 +175,10 @@ class SomeKind(ndb.Model): cache_key = global_cache_module.MemcacheCache._key(cache_key) assert memcache_context.global_cache.client.get(cache_key) is not None - patch = mock.patch("google.cloud.ndb._datastore_api._LookupBatch.add") - patch.side_effect = Exception("Shouldn't call this") + patch = mock.patch( + "google.cloud.ndb._datastore_api._LookupBatch.add", + mock.Mock(side_effect=Exception("Shouldn't call this")), + ) with patch: entity = key.get() assert isinstance(entity, SomeKind) @@ -779,7 +783,7 @@ class SomeKind(ndb.Model): # This is py27 behavior. Not entirely sold on leaving _LOCKED value for # Datastore misses. assert key.get() is None - assert cache_dict[cache_key][0] == b"0" + assert cache_dict[cache_key][0].startswith(b"0-") @pytest.mark.skipif(not USE_REDIS_CACHE, reason="Redis is not configured") @@ -802,7 +806,7 @@ class SomeKind(ndb.Model): # This is py27 behavior. Not entirely sold on leaving _LOCKED value for # Datastore misses. assert key.get() is None - assert redis_context.global_cache.redis.get(cache_key) == b"0" + assert redis_context.global_cache.redis.get(cache_key).startswith(b"0-") @pytest.mark.skipif(not USE_MEMCACHE, reason="Memcache is not configured") @@ -826,7 +830,7 @@ class SomeKind(ndb.Model): # This is py27 behavior. Not entirely sold on leaving _LOCKED value for # Datastore misses. assert key.get() is None - assert memcache_context.global_cache.client.get(cache_key) == b"0" + assert memcache_context.global_cache.client.get(cache_key).startswith(b"0-") @pytest.mark.usefixtures("client_context") diff --git a/tests/unit/test__cache.py b/tests/unit/test__cache.py index 0c48fe24..7eb3f519 100644 --- a/tests/unit/test__cache.py +++ b/tests/unit/test__cache.py @@ -827,7 +827,11 @@ class Test_global_lock_for_read: @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") def test_lock_acquired(global_set_if_not_exists): global_set_if_not_exists.return_value = _future_result(True) - assert _cache.global_lock_for_read(b"key").result() == _cache._LOCKED_FOR_READ + assert ( + _cache.global_lock_for_read(b"key") + .result() + .startswith(_cache._LOCKED_FOR_READ) + ) @staticmethod @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") diff --git a/tests/unit/test__datastore_api.py b/tests/unit/test__datastore_api.py index 435e15dd..d2e3a0ee 100644 --- a/tests/unit/test__datastore_api.py +++ b/tests/unit/test__datastore_api.py @@ -371,7 +371,7 @@ class SomeKind(model.Model): future = _api.lookup(key._key, _options.ReadOptions()) assert future.result() is _api._NOT_FOUND - assert global_cache.get([cache_key]) == [_cache._LOCKED_FOR_READ] + assert global_cache.get([cache_key])[0].startswith(_cache._LOCKED_FOR_READ) assert len(global_cache._watch_keys) == 0