diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index ebf51030..b611f8e9 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -14,6 +14,7 @@ import functools import itertools +import uuid import warnings from google.api_core import retry as core_retry @@ -22,7 +23,8 @@ from google.cloud.ndb import context as context_module from google.cloud.ndb import tasklets -_LOCKED = b"0" +_LOCKED_FOR_READ = b"0-" +_LOCKED_FOR_WRITE = b"00" _LOCK_TIME = 32 _PREFIX = b"NDB30" @@ -200,8 +202,7 @@ def wrapper(key, *args, **kwargs): return wrap -@_handle_transient_errors(read=True) -def global_get(key): +def _global_get(key): """Get entity from global cache. Args: @@ -215,6 +216,9 @@ def global_get(key): return batch.add(key) +global_get = _handle_transient_errors(read=True)(_global_get) + + class _GlobalCacheGetBatch(_GlobalCacheBatch): """Batch for global cache get requests. @@ -306,6 +310,33 @@ def __init__(self, options): self.todo = {} self.futures = {} + def done_callback(self, cache_call): + """Process results of call to global cache. + + If there is an exception for the cache call, distribute that to waiting + futures, otherwise examine the result of the cache call. If the result is + :data:`None`, simply set the result to :data:`None` for all waiting futures. + Otherwise, if the result is a `dict`, use that to propagate results for + individual keys to waiting futures. + """ + exception = cache_call.exception() + if exception: + for future in self.futures.values(): + future.set_exception(exception) + return + + result = cache_call.result() + if result: + for key, future in self.futures.items(): + key_result = result.get(key, None) + if isinstance(key_result, Exception): + future.set_exception(key_result) + else: + future.set_result(key_result) + else: + for future in self.futures.values(): + future.set_result(None) + def add(self, key, value): """Add a key, value pair to store in the cache. @@ -335,44 +366,79 @@ def add(self, key, value): self.futures[key] = future return future - def done_callback(self, cache_call): - """Process results of call to global cache. + def make_call(self): + """Call :method:`GlobalCache.set`.""" + return _global_cache().set(self.todo, expires=self.expires) - If there is an exception for the cache call, distribute that to waiting - futures, otherwise examine the result of the cache call. If the result is - :data:`None`, simply set the result to :data:`None` for all waiting futures. - Otherwise, if the result is a `dict`, use that to propagate results for - individual keys to waiting figures. + def future_info(self, key, value): + """Generate info string for Future.""" + return "GlobalCache.set({}, {})".format(key, value) + + +@tasklets.tasklet +def global_set_if_not_exists(key, value, expires=None): + """Store entity in the global cache if key is not already present. + + Args: + key (bytes): The key to save. + value (bytes): The entity to save. + expires (Optional[float]): Number of seconds until value expires. + + Returns: + tasklets.Future: Eventual result will be a ``bool`` value which will be + :data:`True` if a new value was set for the key, or :data:`False` if a value + was already set for the key or if a transient error occurred while + attempting to set the key. + """ + options = {} + if expires: + options = {"expires": expires} + + cache = _global_cache() + batch = _batch.get_batch(_GlobalCacheSetIfNotExistsBatch, options) + try: + success = yield batch.add(key, value) + except cache.transient_errors: + success = False + + raise tasklets.Return(success) + + +class _GlobalCacheSetIfNotExistsBatch(_GlobalCacheSetBatch): + """Batch for global cache set_if_not_exists requests. """ + + def add(self, key, value): + """Add a key, value pair to store in the cache. + + Arguments: + key (bytes): The key to store in the cache. + value (bytes): The value to store in the cache. + + Returns: + tasklets.Future: Eventual result will be a ``bool`` value which will be + :data:`True` if a new value was set for the key, or :data:`False` if a + value was already set for the key. """ - exception = cache_call.exception() - if exception: - for future in self.futures.values(): - future.set_exception(exception) - return + if key in self.todo: + future = tasklets.Future() + future.set_result(False) + return future - result = cache_call.result() - if result: - for key, future in self.futures.items(): - key_result = result.get(key, None) - if isinstance(key_result, Exception): - future.set_exception(key_result) - else: - future.set_result(key_result) - else: - for future in self.futures.values(): - future.set_result(None) + future = tasklets.Future(info=self.future_info(key, value)) + self.todo[key] = value + self.futures[key] = future + return future def make_call(self): """Call :method:`GlobalCache.set`.""" - return _global_cache().set(self.todo, expires=self.expires) + return _global_cache().set_if_not_exists(self.todo, expires=self.expires) def future_info(self, key, value): """Generate info string for Future.""" - return "GlobalCache.set({}, {})".format(key, value) + return "GlobalCache.set_if_not_exists({}, {})".format(key, value) -@_handle_transient_errors() -def global_delete(key): +def _global_delete(key): """Delete an entity from the global cache. Args: @@ -385,6 +451,9 @@ def global_delete(key): return batch.add(key) +global_delete = _handle_transient_errors()(_global_delete) + + class _GlobalCacheDeleteBatch(_GlobalCacheBatch): """Batch for global cache delete requests.""" @@ -415,8 +484,7 @@ def future_info(self, key): return "GlobalCache.delete({})".format(key) -@_handle_transient_errors(read=True) -def global_watch(key): +def _global_watch(key, value): """Start optimistic transaction with global cache. A future call to :func:`global_compare_and_swap` will only set the value @@ -428,24 +496,23 @@ def global_watch(key): Returns: tasklets.Future: Eventual result will be ``None``. """ - batch = _batch.get_batch(_GlobalCacheWatchBatch) - return batch.add(key) + batch = _batch.get_batch(_GlobalCacheWatchBatch, {}) + return batch.add(key, value) -class _GlobalCacheWatchBatch(_GlobalCacheDeleteBatch): - """Batch for global cache watch requests. """ +global_watch = _handle_transient_errors(read=True)(_global_watch) - def __init__(self, ignore_options): - self.keys = [] - self.futures = [] + +class _GlobalCacheWatchBatch(_GlobalCacheSetBatch): + """Batch for global cache watch requests. """ def make_call(self): """Call :method:`GlobalCache.watch`.""" - return _global_cache().watch(self.keys) + return _global_cache().watch(self.todo) - def future_info(self, key): + def future_info(self, key, value): """Generate info string for Future.""" - return "GlobalCache.watch({})".format(key) + return "GlobalCache.watch({}, {})".format(key, value) @_handle_transient_errors() @@ -462,11 +529,11 @@ def global_unwatch(key): Returns: tasklets.Future: Eventual result will be ``None``. """ - batch = _batch.get_batch(_GlobalCacheUnwatchBatch) + batch = _batch.get_batch(_GlobalCacheUnwatchBatch, {}) return batch.add(key) -class _GlobalCacheUnwatchBatch(_GlobalCacheWatchBatch): +class _GlobalCacheUnwatchBatch(_GlobalCacheDeleteBatch): """Batch for global cache unwatch requests. """ def make_call(self): @@ -478,8 +545,7 @@ def future_info(self, key): return "GlobalCache.unwatch({})".format(key) -@_handle_transient_errors(read=True) -def global_compare_and_swap(key, value, expires=None): +def _global_compare_and_swap(key, value, expires=None): """Like :func:`global_set` but using an optimistic transaction. Value will only be set for the given key if the value in the cache hasn't @@ -501,6 +567,9 @@ def global_compare_and_swap(key, value, expires=None): return batch.add(key, value) +global_compare_and_swap = _handle_transient_errors(read=True)(_global_compare_and_swap) + + class _GlobalCacheCompareAndSwapBatch(_GlobalCacheSetBatch): """Batch for global cache compare and swap requests. """ @@ -513,17 +582,99 @@ def future_info(self, key, value): return "GlobalCache.compare_and_swap({}, {})".format(key, value) -def global_lock(key, read=False): - """Lock a key by setting a special value. +@tasklets.tasklet +def global_lock_for_read(key): + """Lock a key for a read (lookup) operation by setting a special value. + + Lock may be preempted by a parallel write (put) operation. Args: key (bytes): The key to lock. - read (bool): Indicates if being called as part of a read (lookup) operation. Returns: - tasklets.Future: Eventual result will be ``None``. + tasklets.Future: Eventual result will be lock value (``bytes``) written to + Datastore for the given key, or :data:`None` if the lock was not acquired. + """ + lock = _LOCKED_FOR_READ + str(uuid.uuid4()).encode("ascii") + lock_acquired = yield global_set_if_not_exists(key, lock, expires=_LOCK_TIME) + if lock_acquired: + raise tasklets.Return(lock) + + +@_handle_transient_errors() +@tasklets.tasklet +def global_lock_for_write(key): + """Lock a key for a write (put) operation, by setting or updating a special value. + + There can be multiple write locks for a given key. Key will only be released when + all write locks have been released. + + Args: + key (bytes): The key to lock. + + Returns: + tasklets.Future: Eventual result will be a lock value to be used later with + :func:`global_unlock`. + """ + lock = "." + str(uuid.uuid4()) + lock = lock.encode("ascii") + + def new_value(old_value): + if old_value and old_value.startswith(_LOCKED_FOR_WRITE): + return old_value + lock + + return _LOCKED_FOR_WRITE + lock + + yield _update_key(key, new_value) + + raise tasklets.Return(lock) + + +@tasklets.tasklet +def global_unlock_for_write(key, lock): + """Remove a lock for key by updating or removing a lock value. + + The lock represented by the ``lock`` argument will be released. If no other locks + remain, the key will be deleted. + + Args: + key (bytes): The key to lock. + lock (bytes): The return value from the call :func:`global_lock` which acquired + the lock. + + Returns: + tasklets.Future: Eventual result will be :data:`None`. """ - return global_set(key, _LOCKED, expires=_LOCK_TIME, read=read) + + def new_value(old_value): + return old_value.replace(lock, b"") + + cache = _global_cache() + try: + yield _update_key(key, new_value) + except cache.transient_errors: + # Worst case scenario, lock sticks around for longer than we'd like + pass + + +@tasklets.tasklet +def _update_key(key, new_value): + success = False + + while not success: + old_value = yield _global_get(key) + value = new_value(old_value) + if value == _LOCKED_FOR_WRITE: + # No more locks for this key, we can delete + yield _global_delete(key) + break + + if old_value: + yield _global_watch(key, old_value) + success = yield _global_compare_and_swap(key, value, expires=_LOCK_TIME) + + else: + success = yield global_set_if_not_exists(key, value, expires=_LOCK_TIME) def is_locked_value(value): @@ -532,7 +683,10 @@ def is_locked_value(value): Returns: bool: Whether the value is the special reserved value for key lock. """ - return value == _LOCKED + if value: + return value.startswith(_LOCKED_FOR_READ) or value.startswith(_LOCKED_FOR_WRITE) + + return False def global_cache_key(key): diff --git a/google/cloud/ndb/_datastore_api.py b/google/cloud/ndb/_datastore_api.py index f7a247a9..b08ebb9d 100644 --- a/google/cloud/ndb/_datastore_api.py +++ b/google/cloud/ndb/_datastore_api.py @@ -146,8 +146,15 @@ def lookup(key, options): entity_pb.MergeFromString(result) elif use_datastore: - yield _cache.global_lock(cache_key, read=True) - yield _cache.global_watch(cache_key) + lock = yield _cache.global_lock_for_read(cache_key) + if lock: + yield _cache.global_watch(cache_key, lock) + + else: + # Another thread locked or wrote to this key after the call to + # _cache.global_get above. Behave as though the key was locked by + # another thread and don't attempt to write our value below + key_locked = True if entity_pb is _NOT_FOUND and use_datastore: batch = _batch.get_batch(_LookupBatch, options) @@ -359,11 +366,12 @@ def put(entity, options): if not use_datastore and entity.key.is_partial: raise TypeError("Can't store partial keys when use_datastore is False") + lock = None entity_pb = helpers.entity_to_protobuf(entity) cache_key = _cache.global_cache_key(entity.key) if use_global_cache and not entity.key.is_partial: if use_datastore: - yield _cache.global_lock(cache_key) + lock = yield _cache.global_lock_for_write(cache_key) else: expires = context._global_cache_timeout(entity.key, options) cache_value = entity_pb.SerializeToString() @@ -382,11 +390,16 @@ def put(entity, options): else: key = None - if use_global_cache: + if lock: if transaction: - context.global_cache_flush_keys.add(cache_key) + + def callback(): + _cache.global_unlock_for_write(cache_key, lock).result() + + context.call_on_transaction_complete(callback) + else: - yield _cache.global_delete(cache_key) + yield _cache.global_unlock_for_write(cache_key, lock) raise tasklets.Return(key) @@ -416,7 +429,7 @@ def delete(key, options): if use_datastore: if use_global_cache: - yield _cache.global_lock(cache_key) + lock = yield _cache.global_lock_for_write(cache_key) if transaction: batch = _get_commit_batch(transaction, options) @@ -427,7 +440,15 @@ def delete(key, options): if use_global_cache: if transaction: - context.global_cache_flush_keys.add(cache_key) + + def callback(): + _cache.global_unlock_for_write(cache_key, lock).result() + + context.call_on_transaction_complete(callback) + + elif use_datastore: + yield _cache.global_unlock_for_write(cache_key, lock) + else: yield _cache.global_delete(cache_key) diff --git a/google/cloud/ndb/_transaction.py b/google/cloud/ndb/_transaction.py index 1932fefa..f07d752c 100644 --- a/google/cloud/ndb/_transaction.py +++ b/google/cloud/ndb/_transaction.py @@ -250,7 +250,6 @@ def transaction_async_( @tasklets.tasklet def _transaction_async(context, callback, read_only=False): # Avoid circular import in Python 2.7 - from google.cloud.ndb import _cache from google.cloud.ndb import _datastore_api # Start the transaction @@ -259,9 +258,11 @@ def _transaction_async(context, callback, read_only=False): utils.logging_debug(log, "Transaction Id: {}", transaction_id) on_commit_callbacks = [] + transaction_complete_callbacks = [] tx_context = context.new( transaction=transaction_id, on_commit_callbacks=on_commit_callbacks, + transaction_complete_callbacks=transaction_complete_callbacks, batches=None, commit_batches=None, cache=None, @@ -282,35 +283,35 @@ def run_inner_loop(inner_context): context.eventloop.add_idle(run_inner_loop, tx_context) - tx_context.global_cache_flush_keys = flush_keys = set() with tx_context.use(): try: - # Run the callback - result = callback() - if isinstance(result, tasklets.Future): - result = yield result + try: + # Run the callback + result = callback() + if isinstance(result, tasklets.Future): + result = yield result - # Make sure we've run everything we can run before calling commit - _datastore_api.prepare_to_commit(transaction_id) - tx_context.eventloop.run() + # Make sure we've run everything we can run before calling commit + _datastore_api.prepare_to_commit(transaction_id) + tx_context.eventloop.run() - # Commit the transaction - yield _datastore_api.commit(transaction_id, retries=0) + # Commit the transaction + yield _datastore_api.commit(transaction_id, retries=0) - # Rollback if there is an error - except Exception as e: # noqa: E722 - tx_context.cache.clear() - yield _datastore_api.rollback(transaction_id) - raise e + # Rollback if there is an error + except Exception as e: # noqa: E722 + tx_context.cache.clear() + yield _datastore_api.rollback(transaction_id) + raise e - # Flush keys of entities written during the transaction from the global cache - if flush_keys: - yield [_cache.global_delete(key) for key in flush_keys] + for callback in on_commit_callbacks: + callback() - for callback in on_commit_callbacks: - callback() + finally: + for callback in transaction_complete_callbacks: + callback() - raise tasklets.Return(result) + raise tasklets.Return(result) def transactional( diff --git a/google/cloud/ndb/context.py b/google/cloud/ndb/context.py index c4e67567..91520609 100644 --- a/google/cloud/ndb/context.py +++ b/google/cloud/ndb/context.py @@ -208,6 +208,7 @@ def policy(key): "cache", "global_cache", "on_commit_callbacks", + "transaction_complete_callbacks", "legacy_data", ], ) @@ -241,11 +242,11 @@ def __new__( cache=None, cache_policy=None, global_cache=None, - global_cache_flush_keys=None, global_cache_policy=None, global_cache_timeout_policy=None, datastore_policy=None, on_commit_callbacks=None, + transaction_complete_callbacks=None, legacy_data=True, retry=None, rpc_time=None, @@ -280,6 +281,7 @@ def __new__( cache=new_cache, global_cache=global_cache, on_commit_callbacks=on_commit_callbacks, + transaction_complete_callbacks=transaction_complete_callbacks, legacy_data=legacy_data, ) @@ -289,8 +291,6 @@ def __new__( context.set_datastore_policy(datastore_policy) context.set_retry_state(retry) - context.global_cache_flush_keys = global_cache_flush_keys - return context def new(self, **kwargs): @@ -565,6 +565,29 @@ def call_on_commit(self, callback): else: callback() + def call_on_transaction_complete(self, callback): + """Call a callback upon completion of a transaction. + + If not in a transaction, the callback is called immediately. + + In a transaction, multiple callbacks may be registered and will be called once + the transaction completes, in the order in which they were registered. Callbacks + are called regardless of whether transaction is committed or rolled back. + + If the callback raises an exception, it bubbles up normally. This means: If the + callback is called immediately, any exception it raises will bubble up + immediately. If the call is postponed until commit, remaining callbacks will be + skipped and the exception will bubble up through the transaction() call. + (However, the transaction is already committed or rolled back at that point.) + + Args: + callback (Callable): The callback function. + """ + if self.in_transaction(): + self.transaction_complete_callbacks.append(callback) + else: + callback() + def in_transaction(self): """Get whether a transaction is currently active. diff --git a/google/cloud/ndb/global_cache.py b/google/cloud/ndb/global_cache.py index 8d39a60a..906a1294 100644 --- a/google/cloud/ndb/global_cache.py +++ b/google/cloud/ndb/global_cache.py @@ -16,14 +16,12 @@ import abc import base64 -import collections import hashlib import os import pymemcache.exceptions import redis.exceptions import threading import time -import uuid import warnings import pymemcache @@ -116,6 +114,23 @@ def set(self, items, expires=None): """ raise NotImplementedError + @abc.abstractmethod + def set_if_not_exists(self, items, expires=None): + """Stores entities in the cache if and only if keys are not already set. + + Arguments: + items (Dict[bytes, Union[bytes, None]]): Mapping of keys to + serialized entities. + expires (Optional[float]): Number of seconds until value expires. + + + Returns: + Dict[bytes, bool]: A `dict` mapping to boolean value wich will be + :data:`True` if that key was set with a new value, and :data:`False` + otherwise. + """ + raise NotImplementedError + @abc.abstractmethod def delete(self, keys): """Remove entities from the cache. @@ -126,14 +141,15 @@ def delete(self, keys): raise NotImplementedError @abc.abstractmethod - def watch(self, keys): - """Begin an optimistic transaction for the given keys. + def watch(self, items): + """Begin an optimistic transaction for the given items. A future call to :meth:`compare_and_swap` will only set values for keys - whose values haven't changed since the call to this method. + whose values haven't changed since the call to this method. Values are used to + check that the watched value matches the expected value for a given key. Arguments: - keys (List[bytes]): The keys to watch. + items (Dict[bytes, bytes]): The items to watch. """ raise NotImplementedError @@ -161,6 +177,10 @@ def compare_and_swap(self, items, expires=None): items (Dict[bytes, Union[bytes, None]]): Mapping of keys to serialized entities. expires (Optional[float]): Number of seconds until value expires. + + Returns: + Dict[bytes, bool]: A mapping of key to result. A key will have a result of + :data:`True` if it was changed successfully. """ raise NotImplementedError @@ -217,15 +237,27 @@ def set(self, items, expires=None): for key, value in items.items(): self.cache[key] = (value, expires) # Supposedly threadsafe + def set_if_not_exists(self, items, expires=None): + """Implements :meth:`GlobalCache.set_if_not_exists`.""" + if expires: + expires = time.time() + expires + + results = {} + for key, value in items.items(): + set_value = (value, expires) + results[key] = self.cache.setdefault(key, set_value) is set_value + + return results + def delete(self, keys): """Implements :meth:`GlobalCache.delete`.""" for key in keys: self.cache.pop(key, None) # Threadsafe? - def watch(self, keys): + def watch(self, items): """Implements :meth:`GlobalCache.watch`.""" - for key in keys: - self._watch_keys[key] = self.cache.get(key) + for key, value in items.items(): + self._watch_keys[key] = value def unwatch(self, keys): """Implements :meth:`GlobalCache.unwatch`.""" @@ -237,20 +269,22 @@ def compare_and_swap(self, items, expires=None): if expires: expires = time.time() + expires + results = {key: False for key in items.keys()} for key, new_value in items.items(): watch_value = self._watch_keys.get(key) current_value = self.cache.get(key) + current_value = current_value[0] if current_value else current_value if watch_value == current_value: self.cache[key] = (new_value, expires) + results[key] = True + + return results def clear(self): """Implements :meth:`GlobalCache.clear`.""" self.cache.clear() -_Pipeline = collections.namedtuple("_Pipeline", ("pipe", "id")) - - class RedisCache(GlobalCache): """Redis implementation of the :class:`GlobalCache`. @@ -355,52 +389,55 @@ def set(self, items, expires=None): for key in items.keys(): self.redis.expire(key, expires) + def set_if_not_exists(self, items, expires=None): + """Implements :meth:`GlobalCache.set_if_not_exists`.""" + results = {} + for key, value in items.items(): + results[key] = key_was_set = self.redis.setnx(key, value) + if key_was_set and expires: + self.redis.expire(key, expires) + + return results + def delete(self, keys): """Implements :meth:`GlobalCache.delete`.""" self.redis.delete(*keys) - def watch(self, keys): + def watch(self, items): """Implements :meth:`GlobalCache.watch`.""" - pipe = self.redis.pipeline() - pipe.watch(*keys) - holder = _Pipeline(pipe, str(uuid.uuid4())) - for key in keys: - self.pipes[key] = holder + for key, value in items.items(): + pipe = self.redis.pipeline() + pipe.watch(key) + if pipe.get(key) == value: + self.pipes[key] = pipe + else: + pipe.reset() def unwatch(self, keys): """Implements :meth:`GlobalCache.watch`.""" for key in keys: - holder = self.pipes.pop(key, None) - if holder: - holder.pipe.reset() + pipe = self.pipes.pop(key, None) + if pipe: + pipe.reset() def compare_and_swap(self, items, expires=None): """Implements :meth:`GlobalCache.compare_and_swap`.""" - pipes = {} - mappings = {} - remove_keys = [] + results = {key: False for key in items.keys()} - # get associated pipes + pipes = self.pipes for key, value in items.items(): - remove_keys.append(key) - if key not in self.pipes: + pipe = pipes.pop(key, None) + if pipe is None: continue - pipe = self.pipes[key] - pipes[pipe.id] = pipe - mapping = mappings.setdefault(pipe.id, {}) - mapping[key] = value - - # execute transaction for each pipes - for pipe_id, mapping in mappings.items(): - pipe = pipes[pipe_id].pipe try: pipe.multi() - pipe.mset(mapping) if expires: - for key in mapping.keys(): - pipe.expire(key, expires) + pipe.setex(key, expires, value) + else: + pipe.set(key, value) pipe.execute() + results[key] = True except redis_module.exceptions.WatchError: pass @@ -408,14 +445,7 @@ def compare_and_swap(self, items, expires=None): finally: pipe.reset() - # get keys associated to pipes but not updated - for key, pipe in self.pipes.items(): - if pipe.id in pipes: - remove_keys.append(key) - - # remove keys - for key in remove_keys: - self.pipes.pop(key, None) + return results def clear(self): """Implements :meth:`GlobalCache.clear`.""" @@ -599,17 +629,35 @@ def set(self, items, expires=None): ) return {key: MemcacheCache.KeyNotSet(key) for key in unset_keys} + def set_if_not_exists(self, items, expires=None): + """Implements :meth:`GlobalCache.set_if_not_exists`.""" + expires = expires if expires else 0 + results = {} + for key, value in items.items(): + results[key] = self.client.add( + self._key(key), value, expire=expires, noreply=False + ) + + return results + def delete(self, keys): """Implements :meth:`GlobalCache.delete`.""" keys = [self._key(key) for key in keys] self.client.delete_many(keys) - def watch(self, keys): + def watch(self, items): """Implements :meth:`GlobalCache.watch`.""" - keys = [self._key(key) for key in keys] caskeys = self.caskeys + keys = [] + prev_values = {} + for key, prev_value in items.items(): + key = self._key(key) + keys.append(key) + prev_values[key] = prev_value + for key, (value, caskey) in self.client.gets_many(keys).items(): - caskeys[key] = caskey + if prev_values[key] == value: + caskeys[key] = caskey def unwatch(self, keys): """Implements :meth:`GlobalCache.unwatch`.""" @@ -621,14 +669,19 @@ def unwatch(self, keys): def compare_and_swap(self, items, expires=None): """Implements :meth:`GlobalCache.compare_and_swap`.""" caskeys = self.caskeys - for key, value in items.items(): - key = self._key(key) + results = {} + for orig_key, value in items.items(): + key = self._key(orig_key) caskey = caskeys.pop(key, None) if caskey is None: continue expires = expires if expires else 0 - self.client.cas(key, value, caskey, expire=expires) + results[orig_key] = bool( + self.client.cas(key, value, caskey, expire=expires, noreply=False) + ) + + return results def clear(self): """Implements :meth:`GlobalCache.clear`.""" diff --git a/tests/system/test_crud.py b/tests/system/test_crud.py index 40123b4b..17ff6e20 100644 --- a/tests/system/test_crud.py +++ b/tests/system/test_crud.py @@ -110,8 +110,10 @@ class SomeKind(ndb.Model): cache_key = _cache.global_cache_key(key._key) assert cache_key in cache_dict - patch = mock.patch("google.cloud.ndb._datastore_api._LookupBatch.add") - patch.side_effect = Exception("Shouldn't call this") + patch = mock.patch( + "google.cloud.ndb._datastore_api._LookupBatch.add", + mock.Mock(side_effect=Exception("Shouldn't call this")), + ) with patch: entity = key.get() assert isinstance(entity, SomeKind) @@ -140,8 +142,10 @@ class SomeKind(ndb.Model): cache_key = _cache.global_cache_key(key._key) assert redis_context.global_cache.redis.get(cache_key) is not None - patch = mock.patch("google.cloud.ndb._datastore_api._LookupBatch.add") - patch.side_effect = Exception("Shouldn't call this") + patch = mock.patch( + "google.cloud.ndb._datastore_api._LookupBatch.add", + mock.Mock(side_effect=Exception("Shouldn't call this")), + ) with patch: entity = key.get() assert isinstance(entity, SomeKind) @@ -171,8 +175,10 @@ class SomeKind(ndb.Model): cache_key = global_cache_module.MemcacheCache._key(cache_key) assert memcache_context.global_cache.client.get(cache_key) is not None - patch = mock.patch("google.cloud.ndb._datastore_api._LookupBatch.add") - patch.side_effect = Exception("Shouldn't call this") + patch = mock.patch( + "google.cloud.ndb._datastore_api._LookupBatch.add", + mock.Mock(side_effect=Exception("Shouldn't call this")), + ) with patch: entity = key.get() assert isinstance(entity, SomeKind) @@ -587,8 +593,6 @@ class SomeKind(ndb.Model): entity.foo = 43 entity.put() - # This is py27 behavior. I can see a case being made for caching the - # entity on write rather than waiting for a subsequent lookup. assert cache_key not in cache_dict @@ -613,8 +617,6 @@ class SomeKind(ndb.Model): entity.foo = 43 entity.put() - # This is py27 behavior. I can see a case being made for caching the - # entity on write rather than waiting for a subsequent lookup. assert redis_context.global_cache.redis.get(cache_key) is None @@ -640,8 +642,6 @@ class SomeKind(ndb.Model): entity.foo = 43 entity.put() - # This is py27 behavior. I can see a case being made for caching the - # entity on write rather than waiting for a subsequent lookup. assert memcache_context.global_cache.client.get(cache_key) is None @@ -783,7 +783,7 @@ class SomeKind(ndb.Model): # This is py27 behavior. Not entirely sold on leaving _LOCKED value for # Datastore misses. assert key.get() is None - assert cache_dict[cache_key][0] == b"0" + assert cache_dict[cache_key][0].startswith(b"0-") @pytest.mark.skipif(not USE_REDIS_CACHE, reason="Redis is not configured") @@ -806,7 +806,7 @@ class SomeKind(ndb.Model): # This is py27 behavior. Not entirely sold on leaving _LOCKED value for # Datastore misses. assert key.get() is None - assert redis_context.global_cache.redis.get(cache_key) == b"0" + assert redis_context.global_cache.redis.get(cache_key).startswith(b"0-") @pytest.mark.skipif(not USE_MEMCACHE, reason="Memcache is not configured") @@ -830,7 +830,7 @@ class SomeKind(ndb.Model): # This is py27 behavior. Not entirely sold on leaving _LOCKED value for # Datastore misses. assert key.get() is None - assert memcache_context.global_cache.client.get(cache_key) == b"0" + assert memcache_context.global_cache.client.get(cache_key).startswith(b"0-") @pytest.mark.usefixtures("client_context") diff --git a/tests/unit/test__cache.py b/tests/unit/test__cache.py index d835f8c3..7eb3f519 100644 --- a/tests/unit/test__cache.py +++ b/tests/unit/test__cache.py @@ -412,7 +412,6 @@ def test_add_and_idle_and_done_callbacks_with_expires(in_context): def test_add_and_idle_and_done_callbacks_w_error(in_context): error = Exception("spurious error") cache = mock.Mock(spec=("set",)) - cache.set.return_value = [] cache.set.return_value = tasklets.Future() cache.set.return_value.set_exception(error) @@ -452,6 +451,182 @@ class SpeciousError(Exception): assert future2.result() +@pytest.mark.usefixtures("in_context") +class Test_global_set_if_not_exists: + @staticmethod + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_without_expires(_batch, _global_cache): + batch = _batch.get_batch.return_value + future = _future_result("hi mom!") + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + assert _cache.global_set_if_not_exists(b"key", b"value").result() == "hi mom!" + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetIfNotExistsBatch, {} + ) + batch.add.assert_called_once_with(b"key", b"value") + + @staticmethod + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_transientError(_batch, _global_cache): + class TransientError(Exception): + pass + + batch = _batch.get_batch.return_value + future = _future_exception(TransientError("oops, mom!")) + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(TransientError,), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + assert _cache.global_set_if_not_exists(b"key", b"value").result() is False + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetIfNotExistsBatch, {} + ) + batch.add.assert_called_once_with(b"key", b"value") + + @staticmethod + @mock.patch("google.cloud.ndb._cache._global_cache") + @mock.patch("google.cloud.ndb._cache._batch") + def test_with_expires(_batch, _global_cache): + batch = _batch.get_batch.return_value + future = _future_result("hi mom!") + batch.add.return_value = future + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + assert ( + _cache.global_set_if_not_exists(b"key", b"value", expires=123).result() + == "hi mom!" + ) + _batch.get_batch.assert_called_once_with( + _cache._GlobalCacheSetIfNotExistsBatch, {"expires": 123} + ) + batch.add.assert_called_once_with(b"key", b"value") + + +class Test_GlobalCacheSetIfNotExistsBatch: + @staticmethod + def test_add_duplicate_key_and_value(): + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"foo", b"one") + assert not future1.done() + assert future2.result() is False + + @staticmethod + def test_add_and_idle_and_done_callbacks(in_context): + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = {} + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + assert batch.expires is None + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with( + {b"foo": b"one", b"bar": b"two"}, expires=None + ) + assert future1.result() is None + assert future2.result() is None + + @staticmethod + def test_add_and_idle_and_done_callbacks_with_duplicate_keys(in_context): + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = {b"foo": True} + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"foo", b"two") + + assert batch.expires is None + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with({b"foo": b"one"}, expires=None) + assert future1.result() is True + assert future2.result() is False + + @staticmethod + def test_add_and_idle_and_done_callbacks_with_expires(in_context): + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = [] + + batch = _cache._GlobalCacheSetIfNotExistsBatch({"expires": 5}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + assert batch.expires == 5 + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with( + {b"foo": b"one", b"bar": b"two"}, expires=5 + ) + assert future1.result() is None + assert future2.result() is None + + @staticmethod + def test_add_and_idle_and_done_callbacks_w_error(in_context): + error = Exception("spurious error") + cache = mock.Mock(spec=("set_if_not_exists",)) + cache.set_if_not_exists.return_value = tasklets.Future() + cache.set_if_not_exists.return_value.set_exception(error) + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + with in_context.new(global_cache=cache).use(): + batch.idle_callback() + + cache.set_if_not_exists.assert_called_once_with( + {b"foo": b"one", b"bar": b"two"}, expires=None + ) + assert future1.exception() is error + assert future2.exception() is error + + @staticmethod + def test_done_callbacks_with_results(in_context): + class SpeciousError(Exception): + pass + + cache_call = _future_result( + { + b"foo": "this is a result", + b"bar": SpeciousError("this is also a kind of result"), + } + ) + + batch = _cache._GlobalCacheSetIfNotExistsBatch({}) + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + batch.done_callback(cache_call) + + assert future1.result() == "this is a result" + with pytest.raises(SpeciousError): + assert future2.result() + + @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._cache._global_cache") @mock.patch("google.cloud.ndb._cache._batch") @@ -500,24 +675,28 @@ def test_global_watch(_batch, _global_cache): spec=("transient_errors", "strict_read"), ) - assert _cache.global_watch(b"key").result() == "hi mom!" - _batch.get_batch.assert_called_once_with(_cache._GlobalCacheWatchBatch) - batch.add.assert_called_once_with(b"key") + assert _cache.global_watch(b"key", b"value").result() == "hi mom!" + _batch.get_batch.assert_called_once_with(_cache._GlobalCacheWatchBatch, {}) + batch.add.assert_called_once_with(b"key", b"value") +@pytest.mark.usefixtures("in_context") class Test_GlobalCacheWatchBatch: @staticmethod def test_add_and_idle_and_done_callbacks(in_context): - cache = mock.Mock() + cache = mock.Mock(spec=("watch",)) + cache.watch.return_value = None batch = _cache._GlobalCacheWatchBatch({}) - future1 = batch.add(b"foo") - future2 = batch.add(b"bar") + future1 = batch.add(b"foo", b"one") + future2 = batch.add(b"bar", b"two") + + assert batch.expires is None with in_context.new(global_cache=cache).use(): batch.idle_callback() - cache.watch.assert_called_once_with([b"foo", b"bar"]) + cache.watch.assert_called_once_with({b"foo": b"one", b"bar": b"two"}) assert future1.result() is None assert future2.result() is None @@ -536,7 +715,7 @@ def test_global_unwatch(_batch, _global_cache): ) assert _cache.global_unwatch(b"key").result() == "hi mom!" - _batch.get_batch.assert_called_once_with(_cache._GlobalCacheUnwatchBatch) + _batch.get_batch.assert_called_once_with(_cache._GlobalCacheUnwatchBatch, {}) batch.add.assert_called_once_with(b"key") @@ -643,28 +822,195 @@ def test_add_and_idle_and_done_callbacks_with_expires(in_context): @pytest.mark.usefixtures("in_context") -@mock.patch("google.cloud.ndb._cache._global_cache") -@mock.patch("google.cloud.ndb._cache._batch") -def test_global_lock(_batch, _global_cache): - batch = _batch.get_batch.return_value - future = _future_result("hi mom!") - batch.add.return_value = future - _global_cache.return_value = mock.Mock( - transient_errors=(), - strict_write=False, - spec=("transient_errors", "strict_write"), - ) +class Test_global_lock_for_read: + @staticmethod + @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") + def test_lock_acquired(global_set_if_not_exists): + global_set_if_not_exists.return_value = _future_result(True) + assert ( + _cache.global_lock_for_read(b"key") + .result() + .startswith(_cache._LOCKED_FOR_READ) + ) - assert _cache.global_lock(b"key").result() == "hi mom!" - _batch.get_batch.assert_called_once_with( - _cache._GlobalCacheSetBatch, {"expires": _cache._LOCK_TIME} - ) - batch.add.assert_called_once_with(b"key", _cache._LOCKED) + @staticmethod + @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") + def test_lock_not_acquired(global_set_if_not_exists): + global_set_if_not_exists.return_value = _future_result(False) + assert _cache.global_lock_for_read(b"key").result() is None + + +@pytest.mark.usefixtures("in_context") +class Test_global_lock_for_write: + @staticmethod + @mock.patch("google.cloud.ndb._cache.uuid") + @mock.patch("google.cloud.ndb._cache.global_set_if_not_exists") + @mock.patch("google.cloud.ndb._cache._global_get") + @mock.patch("google.cloud.ndb._cache._global_cache") + def test_first_time(_global_cache, _global_get, global_set_if_not_exists, uuid): + uuid.uuid4.return_value = "arandomuuid" + + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + lock_value = _cache._LOCKED_FOR_WRITE + b".arandomuuid" + _global_get.return_value = _future_result(None) + global_set_if_not_exists.return_value = _future_result(True) + + assert _cache.global_lock_for_write(b"key").result() == b".arandomuuid" + _global_get.assert_called_once_with(b"key") + global_set_if_not_exists.assert_called_once_with(b"key", lock_value, expires=32) + + @staticmethod + @mock.patch("google.cloud.ndb._cache.uuid") + @mock.patch("google.cloud.ndb._cache._global_compare_and_swap") + @mock.patch("google.cloud.ndb._cache._global_watch") + @mock.patch("google.cloud.ndb._cache._global_get") + @mock.patch("google.cloud.ndb._cache._global_cache") + def test_not_first_time_fail_once( + _global_cache, _global_get, _global_watch, _global_compare_and_swap, uuid + ): + uuid.uuid4.return_value = "arandomuuid" + + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + old_lock_value = _cache._LOCKED_FOR_WRITE + b".whatevs" + new_lock_value = old_lock_value + b".arandomuuid" + _global_get.return_value = _future_result(old_lock_value) + _global_watch.return_value = _future_result(None) + _global_compare_and_swap.side_effect = ( + _future_result(False), + _future_result(True), + ) + assert _cache.global_lock_for_write(b"key").result() == b".arandomuuid" + _global_get.assert_has_calls( + [ + mock.call(b"key"), + mock.call(b"key"), + ] + ) + _global_watch.assert_has_calls( + [ + mock.call(b"key", old_lock_value), + mock.call(b"key", old_lock_value), + ] + ) + _global_compare_and_swap.assert_has_calls( + [ + mock.call(b"key", new_lock_value, expires=32), + mock.call(b"key", new_lock_value, expires=32), + ] + ) + + +@pytest.mark.usefixtures("in_context") +class Test_global_unlock_for_write: + @staticmethod + @mock.patch("google.cloud.ndb._cache.uuid") + @mock.patch("google.cloud.ndb._cache._global_delete") + @mock.patch("google.cloud.ndb._cache._global_get") + @mock.patch("google.cloud.ndb._cache._global_cache") + def test_last_time(_global_cache, _global_get, _global_delete, uuid): + lock = b".arandomuuid" + + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + lock_value = _cache._LOCKED_FOR_WRITE + lock + _global_get.return_value = _future_result(lock_value) + _global_delete.return_value = _future_result(None) + + assert _cache.global_unlock_for_write(b"key", lock).result() is None + _global_get.assert_called_once_with(b"key") + _global_delete.assert_called_once_with(b"key") + + @staticmethod + @mock.patch("google.cloud.ndb._cache.uuid") + @mock.patch("google.cloud.ndb._cache._global_delete") + @mock.patch("google.cloud.ndb._cache._global_get") + @mock.patch("google.cloud.ndb._cache._global_cache") + def test_transient_error(_global_cache, _global_get, _global_delete, uuid): + class TransientError(Exception): + pass + + lock = b".arandomuuid" + + _global_cache.return_value = mock.Mock( + transient_errors=(TransientError,), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + lock_value = _cache._LOCKED_FOR_WRITE + lock + _global_get.return_value = _future_result(lock_value) + _global_delete.return_value = _future_exception(TransientError()) + + assert _cache.global_unlock_for_write(b"key", lock).result() is None + _global_get.assert_called_once_with(b"key") + _global_delete.assert_called_once_with(b"key") + + @staticmethod + @mock.patch("google.cloud.ndb._cache.uuid") + @mock.patch("google.cloud.ndb._cache._global_compare_and_swap") + @mock.patch("google.cloud.ndb._cache._global_watch") + @mock.patch("google.cloud.ndb._cache._global_get") + @mock.patch("google.cloud.ndb._cache._global_cache") + def test_not_last_time_fail_once( + _global_cache, _global_get, _global_watch, _global_compare_and_swap, uuid + ): + lock = b".arandomuuid" + + _global_cache.return_value = mock.Mock( + transient_errors=(), + strict_write=False, + spec=("transient_errors", "strict_write"), + ) + + new_lock_value = _cache._LOCKED_FOR_WRITE + b".whatevs" + old_lock_value = new_lock_value + lock + _global_get.return_value = _future_result(old_lock_value) + _global_watch.return_value = _future_result(None) + _global_compare_and_swap.side_effect = ( + _future_result(False), + _future_result(True), + ) + + assert _cache.global_unlock_for_write(b"key", lock).result() is None + _global_get.assert_has_calls( + [ + mock.call(b"key"), + mock.call(b"key"), + ] + ) + _global_watch.assert_has_calls( + [ + mock.call(b"key", old_lock_value), + mock.call(b"key", old_lock_value), + ] + ) + _global_compare_and_swap.assert_has_calls( + [ + mock.call(b"key", new_lock_value, expires=32), + mock.call(b"key", new_lock_value, expires=32), + ] + ) def test_is_locked_value(): - assert _cache.is_locked_value(_cache._LOCKED) - assert not _cache.is_locked_value("new db, who dis?") + assert _cache.is_locked_value(_cache._LOCKED_FOR_READ) + assert _cache.is_locked_value(_cache._LOCKED_FOR_WRITE + b"whatever") + assert not _cache.is_locked_value(b"new db, who dis?") + assert not _cache.is_locked_value(None) def test_global_cache_key(): diff --git a/tests/unit/test__datastore_api.py b/tests/unit/test__datastore_api.py index 8e81fe15..d2e3a0ee 100644 --- a/tests/unit/test__datastore_api.py +++ b/tests/unit/test__datastore_api.py @@ -270,6 +270,32 @@ class SomeKind(model.Model): assert global_cache.get([cache_key]) == [cache_value] + @staticmethod + @mock.patch("google.cloud.ndb._datastore_api._LookupBatch") + def test_cache_miss_followed_by_lock_acquisition_failure( + _LookupBatch, global_cache + ): + class SomeKind(model.Model): + pass + + key = key_module.Key("SomeKind", 1) + cache_key = _cache.global_cache_key(key._key) + + entity = SomeKind(key=key) + entity_pb = model._entity_to_protobuf(entity) + + batch = _LookupBatch.return_value + batch.add.return_value = future_result(entity_pb) + + global_cache.set_if_not_exists = mock.Mock( + return_value=future_result({cache_key: False}) + ) + + future = _api.lookup(key._key, _options.ReadOptions()) + assert future.result() == entity_pb + + assert global_cache.get([cache_key]) == [None] + @staticmethod @mock.patch("google.cloud.ndb._datastore_api._LookupBatch") def test_cache_miss_no_datastore(_LookupBatch, global_cache): @@ -320,7 +346,7 @@ class SomeKind(model.Model): entity = SomeKind(key=key) entity_pb = model._entity_to_protobuf(entity) - global_cache.set({cache_key: _cache._LOCKED}) + global_cache.set({cache_key: _cache._LOCKED_FOR_READ}) batch = _LookupBatch.return_value batch.add.return_value = future_result(entity_pb) @@ -328,7 +354,7 @@ class SomeKind(model.Model): future = _api.lookup(key._key, _options.ReadOptions()) assert future.result() == entity_pb - assert global_cache.get([cache_key]) == [_cache._LOCKED] + assert global_cache.get([cache_key]) == [_cache._LOCKED_FOR_READ] @staticmethod @mock.patch("google.cloud.ndb._datastore_api._LookupBatch") @@ -345,7 +371,7 @@ class SomeKind(model.Model): future = _api.lookup(key._key, _options.ReadOptions()) assert future.result() is _api._NOT_FOUND - assert global_cache.get([cache_key]) == [_cache._LOCKED] + assert global_cache.get([cache_key])[0].startswith(_cache._LOCKED_FOR_READ) assert len(global_cache._watch_keys) == 0 @@ -716,8 +742,10 @@ class SomeKind(model.Model): pass context = context_module.get_context() - with context.new(transaction=b"abc123").use() as in_context: - in_context.global_cache_flush_keys = set() + callbacks = [] + with context.new( + transaction=b"abc123", transaction_complete_callbacks=callbacks + ).use(): key = key_module.Key("SomeKind", 1) cache_key = _cache.global_cache_key(key._key) @@ -728,7 +756,11 @@ class SomeKind(model.Model): future = _api.put(model._entity_to_ds_entity(entity), _options.Options()) assert future.result() is None - assert in_context.global_cache_flush_keys == {cache_key} + assert cache_key in global_cache.cache # lock + for callback in callbacks: + callback() + + assert cache_key not in global_cache.cache # unlocked by callback @staticmethod @mock.patch("google.cloud.ndb._datastore_api._NonTransactionalCommitBatch") @@ -843,8 +875,10 @@ def test_cache_enabled(Batch, global_cache): @mock.patch("google.cloud.ndb._datastore_api._NonTransactionalCommitBatch") def test_w_transaction(Batch, global_cache): context = context_module.get_context() - with context.new(transaction=b"abc123").use() as in_context: - in_context.global_cache_flush_keys = set() + callbacks = [] + with context.new( + transaction=b"abc123", transaction_complete_callbacks=callbacks + ).use(): key = key_module.Key("SomeKind", 1) cache_key = _cache.global_cache_key(key._key) @@ -854,7 +888,11 @@ def test_w_transaction(Batch, global_cache): future = _api.delete(key._key, _options.Options()) assert future.result() is None - assert in_context.global_cache_flush_keys == {cache_key} + assert cache_key in global_cache.cache # lock + for callback in callbacks: + callback() + + assert cache_key not in global_cache.cache # lock removed by callback @staticmethod @mock.patch("google.cloud.ndb._datastore_api._NonTransactionalCommitBatch") diff --git a/tests/unit/test__transaction.py b/tests/unit/test__transaction.py index 8f48a206..435b9840 100644 --- a/tests/unit/test__transaction.py +++ b/tests/unit/test__transaction.py @@ -28,8 +28,6 @@ from google.cloud.ndb import tasklets from google.cloud.ndb import _transaction -from . import utils - class Test_in_transaction: @staticmethod @@ -90,12 +88,47 @@ class Test_transaction_async: @mock.patch("google.cloud.ndb._datastore_api") def test_success(_datastore_api): context_module.get_context().cache["foo"] = "bar" + + def callback(): + # The transaction uses its own in-memory cache, which should be empty in + # the transaction context and not include the key set above. + context = context_module.get_context() + assert not context.cache + + return "I tried, momma." + + begin_future = tasklets.Future("begin transaction") + _datastore_api.begin_transaction.return_value = begin_future + + commit_future = tasklets.Future("commit transaction") + _datastore_api.commit.return_value = commit_future + + future = _transaction.transaction_async(callback) + + _datastore_api.begin_transaction.assert_called_once_with(False, retries=0) + begin_future.set_result(b"tx123") + + _datastore_api.commit.assert_called_once_with(b"tx123", retries=0) + commit_future.set_result(None) + + assert future.result() == "I tried, momma." + + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._datastore_api") + def test_success_w_callbacks(_datastore_api): + context_module.get_context().cache["foo"] = "bar" on_commit_callback = mock.Mock() + transaction_complete_callback = mock.Mock() def callback(): + # The transaction uses its own in-memory cache, which should be empty in + # the transaction context and not include the key set above. context = context_module.get_context() assert not context.cache + context.call_on_commit(on_commit_callback) + context.call_on_transaction_complete(transaction_complete_callback) return "I tried, momma." begin_future = tasklets.Future("begin transaction") @@ -114,6 +147,46 @@ def callback(): assert future.result() == "I tried, momma." on_commit_callback.assert_called_once_with() + transaction_complete_callback.assert_called_once_with() + + @staticmethod + @pytest.mark.usefixtures("in_context") + @mock.patch("google.cloud.ndb._datastore_api") + def test_failure_w_callbacks(_datastore_api): + class SpuriousError(Exception): + pass + + context_module.get_context().cache["foo"] = "bar" + on_commit_callback = mock.Mock() + transaction_complete_callback = mock.Mock() + + def callback(): + context = context_module.get_context() + assert not context.cache + context.call_on_commit(on_commit_callback) + context.call_on_transaction_complete(transaction_complete_callback) + raise SpuriousError() + + begin_future = tasklets.Future("begin transaction") + _datastore_api.begin_transaction.return_value = begin_future + + rollback_future = tasklets.Future("rollback transaction") + _datastore_api.rollback.return_value = rollback_future + + future = _transaction.transaction_async(callback) + + _datastore_api.begin_transaction.assert_called_once_with(False, retries=0) + begin_future.set_result(b"tx123") + + _datastore_api.commit.assert_not_called() + _datastore_api.rollback.assert_called_once_with(b"tx123") + rollback_future.set_result(None) + + with pytest.raises(SpuriousError): + future.result() + + on_commit_callback.assert_not_called() + transaction_complete_callback.assert_called_once_with() @staticmethod def test_success_join(in_context): @@ -407,35 +480,6 @@ def callback(): assert future.result() == "I tried, momma." - @staticmethod - @pytest.mark.usefixtures("in_context") - @mock.patch("google.cloud.ndb._cache") - @mock.patch("google.cloud.ndb._datastore_api") - def test_success_flush_keys(_datastore_api, _cache): - def callback(): - context = context_module.get_context() - context.global_cache_flush_keys.add(b"abc123") - return "I tried, momma." - - _cache.global_delete.return_value = utils.future_result(None) - - begin_future = tasklets.Future("begin transaction") - _datastore_api.begin_transaction.return_value = begin_future - - commit_future = tasklets.Future("commit transaction") - _datastore_api.commit.return_value = commit_future - - future = _transaction.transaction_async(callback, retries=0) - - _datastore_api.begin_transaction.assert_called_once_with(False, retries=0) - begin_future.set_result(b"tx123") - - _datastore_api.commit.assert_called_once_with(b"tx123", retries=0) - commit_future.set_result(None) - - assert future.result() == "I tried, momma." - _cache.global_delete.assert_called_once_with(b"abc123") - @staticmethod @pytest.mark.usefixtures("in_context") @mock.patch("google.cloud.ndb._datastore_api") diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 0222b7cb..fda9e60e 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -93,11 +93,6 @@ def test_new_transaction(self): assert new_context.transaction == "tx123" assert context.transaction is None - def test_new_global_cache_flush_keys(self): - context = self._make_one(global_cache_flush_keys={"hi", "mom!"}) - new_context = context.new() - assert new_context.global_cache_flush_keys == {"hi", "mom!"} - def test_new_with_cache(self): context = self._make_one() context.cache["foo"] = "bar" @@ -336,6 +331,21 @@ def test_call_on_commit_with_transaction(self): context.call_on_commit(callback) assert context.on_commit_callbacks == ["himom!"] + def test_call_on_transaction_complete(self): + context = self._make_one() + callback = mock.Mock() + context.call_on_transaction_complete(callback) + callback.assert_called_once_with() + + def test_call_on_transaction_complete_with_transaction(self): + callbacks = [] + callback = "himom!" + context = self._make_one( + transaction=b"tx123", transaction_complete_callbacks=callbacks + ) + context.call_on_transaction_complete(callback) + assert context.transaction_complete_callbacks == ["himom!"] + def test_in_transaction(self): context = self._make_one() assert context.in_transaction() is False diff --git a/tests/unit/test_global_cache.py b/tests/unit/test_global_cache.py index 0a724a23..d2a7b560 100644 --- a/tests/unit/test_global_cache.py +++ b/tests/unit/test_global_cache.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections + try: from unittest import mock except ImportError: # pragma: NO PY3 COVER @@ -32,6 +34,9 @@ def get(self, keys): def set(self, items, expires=None): return super(MockImpl, self).set(items, expires=expires) + def set_if_not_exists(self, items, expires=None): + return super(MockImpl, self).set_if_not_exists(items, expires=expires) + def delete(self, keys): return super(MockImpl, self).delete(keys) @@ -59,6 +64,11 @@ def test_set(self): with pytest.raises(NotImplementedError): cache.set({b"foo": "bar"}) + def test_set_if_not_exists(self): + cache = self.make_one() + with pytest.raises(NotImplementedError): + cache.set_if_not_exists({b"foo": "bar"}) + def test_delete(self): cache = self.make_one() with pytest.raises(NotImplementedError): @@ -123,10 +133,44 @@ def test_set_get_delete_w_expires(time): result = cache.get([b"two", b"three", b"one"]) assert result == [None, None, None] + @staticmethod + def test_set_if_not_exists(): + cache = global_cache._InProcessGlobalCache() + result = cache.set_if_not_exists({b"one": b"foo", b"two": b"bar"}) + assert result == {b"one": True, b"two": True} + + result = cache.set_if_not_exists({b"two": b"bar", b"three": b"baz"}) + assert result == {b"two": False, b"three": True} + + result = cache.get([b"two", b"three", b"one"]) + assert result == [b"bar", b"baz", b"foo"] + + @staticmethod + @mock.patch("google.cloud.ndb.global_cache.time") + def test_set_if_not_exists_w_expires(time): + time.time.return_value = 0 + + cache = global_cache._InProcessGlobalCache() + result = cache.set_if_not_exists({b"one": b"foo", b"two": b"bar"}, expires=5) + assert result == {b"one": True, b"two": True} + + result = cache.set_if_not_exists({b"two": b"bar", b"three": b"baz"}, expires=5) + assert result == {b"two": False, b"three": True} + + result = cache.get([b"two", b"three", b"one"]) + assert result == [b"bar", b"baz", b"foo"] + + time.time.return_value = 10 + result = cache.get([b"two", b"three", b"one"]) + assert result == [None, None, None] + @staticmethod def test_watch_compare_and_swap(): cache = global_cache._InProcessGlobalCache() - result = cache.watch([b"one", b"two", b"three"]) + cache.cache[b"one"] = (b"food", None) + cache.cache[b"two"] = (b"bard", None) + cache.cache[b"three"] = (b"bazz", None) + result = cache.watch({b"one": b"food", b"two": b"bard", b"three": b"bazd"}) assert result is None cache.cache[b"two"] = (b"hamburgers", None) @@ -134,10 +178,10 @@ def test_watch_compare_and_swap(): result = cache.compare_and_swap( {b"one": b"foo", b"two": b"bar", b"three": b"baz"} ) - assert result is None + assert result == {b"one": True, b"two": False, b"three": False} result = cache.get([b"one", b"two", b"three"]) - assert result == [b"foo", b"hamburgers", b"baz"] + assert result == [b"foo", b"hamburgers", b"bazz"] @staticmethod @mock.patch("google.cloud.ndb.global_cache.time") @@ -145,7 +189,10 @@ def test_watch_compare_and_swap_with_expires(time): time.time.return_value = 0 cache = global_cache._InProcessGlobalCache() - result = cache.watch([b"one", b"two", b"three"]) + cache.cache[b"one"] = (b"food", None) + cache.cache[b"two"] = (b"bard", None) + cache.cache[b"three"] = (b"bazz", None) + result = cache.watch({b"one": b"food", b"two": b"bard", b"three": b"bazd"}) assert result is None cache.cache[b"two"] = (b"hamburgers", None) @@ -153,20 +200,20 @@ def test_watch_compare_and_swap_with_expires(time): result = cache.compare_and_swap( {b"one": b"foo", b"two": b"bar", b"three": b"baz"}, expires=5 ) - assert result is None + assert result == {b"one": True, b"two": False, b"three": False} result = cache.get([b"one", b"two", b"three"]) - assert result == [b"foo", b"hamburgers", b"baz"] + assert result == [b"foo", b"hamburgers", b"bazz"] time.time.return_value = 10 result = cache.get([b"one", b"two", b"three"]) - assert result == [None, b"hamburgers", None] + assert result == [None, b"hamburgers", b"bazz"] @staticmethod def test_watch_unwatch(): cache = global_cache._InProcessGlobalCache() - result = cache.watch([b"one", b"two", b"three"]) + result = cache.watch({b"one": "foo", b"two": "bar", b"three": "baz"}) assert result is None result = cache.unwatch([b"one", b"two", b"three"]) @@ -234,6 +281,37 @@ def mock_expire(key, expires): redis.mset.assert_called_once_with(cache_items) assert expired == {"a": 32, "b": 32} + @staticmethod + def test_set_if_not_exists(): + redis = mock.Mock(spec=("setnx",)) + redis.setnx.side_effect = (True, False) + cache_items = collections.OrderedDict([("a", "foo"), ("b", "bar")]) + cache = global_cache.RedisCache(redis) + results = cache.set_if_not_exists(cache_items) + assert results == {"a": True, "b": False} + redis.setnx.assert_has_calls( + [ + mock.call("a", "foo"), + mock.call("b", "bar"), + ] + ) + + @staticmethod + def test_set_if_not_exists_w_expires(): + redis = mock.Mock(spec=("setnx", "expire")) + redis.setnx.side_effect = (True, False) + cache_items = collections.OrderedDict([("a", "foo"), ("b", "bar")]) + cache = global_cache.RedisCache(redis) + results = cache.set_if_not_exists(cache_items, expires=123) + assert results == {"a": True, "b": False} + redis.setnx.assert_has_calls( + [ + mock.call("a", "foo"), + mock.call("b", "bar"), + ] + ) + redis.expire.assert_called_once_with("a", 123) + @staticmethod def test_delete(): redis = mock.Mock(spec=("delete",)) @@ -243,107 +321,119 @@ def test_delete(): redis.delete.assert_called_once_with(*cache_keys) @staticmethod - @mock.patch("google.cloud.ndb.global_cache.uuid") - def test_watch(uuid): - uuid.uuid4.return_value = "abc123" - redis = mock.Mock(pipeline=mock.Mock(spec=("watch",)), spec=("pipeline",)) + def test_watch(): + def mock_redis_get(key): + if key == "foo": + return "moo" + + return "nope" + + redis = mock.Mock( + pipeline=mock.Mock(spec=("watch", "get", "reset")), spec=("pipeline",) + ) pipe = redis.pipeline.return_value - keys = ["foo", "bar"] + pipe.get.side_effect = mock_redis_get + items = {"foo": "moo", "bar": "car"} cache = global_cache.RedisCache(redis) - cache.watch(keys) + cache.watch(items) + + pipe.watch.assert_has_calls( + [ + mock.call("foo"), + mock.call("bar"), + ], + any_order=True, + ) - pipe.watch.assert_called_once_with("foo", "bar") - assert cache.pipes == { - "foo": global_cache._Pipeline(pipe, "abc123"), - "bar": global_cache._Pipeline(pipe, "abc123"), - } + pipe.get.assert_has_calls( + [ + mock.call("foo"), + mock.call("bar"), + ], + any_order=True, + ) + + assert cache.pipes == {"foo": pipe} @staticmethod def test_unwatch(): redis = mock.Mock(spec=()) cache = global_cache.RedisCache(redis) - pipe1 = mock.Mock(spec=("reset",)) - pipe2 = mock.Mock(spec=("reset",)) + pipe = mock.Mock(spec=("reset",)) cache._pipes.pipes = { - "ay": global_cache._Pipeline(pipe1, "abc123"), - "be": global_cache._Pipeline(pipe1, "abc123"), - "see": global_cache._Pipeline(pipe2, "def456"), - "dee": global_cache._Pipeline(pipe2, "def456"), - "whatevs": global_cache._Pipeline(None, "himom!"), + "ay": pipe, + "be": pipe, + "see": pipe, + "dee": pipe, + "whatevs": "himom!", } cache.unwatch(["ay", "be", "see", "dee", "nuffin"]) - assert cache.pipes == {"whatevs": global_cache._Pipeline(None, "himom!")} + assert cache.pipes == {"whatevs": "himom!"} + pipe.reset.assert_has_calls([mock.call()] * 4) @staticmethod def test_compare_and_swap(): redis = mock.Mock(spec=()) cache = global_cache.RedisCache(redis) - pipe1 = mock.Mock(spec=("multi", "mset", "execute", "reset")) - pipe2 = mock.Mock(spec=("multi", "mset", "execute", "reset")) + pipe1 = mock.Mock(spec=("multi", "set", "execute", "reset")) + pipe2 = mock.Mock(spec=("multi", "set", "execute", "reset")) + pipe2.execute.side_effect = redis_module.exceptions.WatchError cache._pipes.pipes = { - "ay": global_cache._Pipeline(pipe1, "abc123"), - "be": global_cache._Pipeline(pipe1, "abc123"), - "see": global_cache._Pipeline(pipe2, "def456"), - "dee": global_cache._Pipeline(pipe2, "def456"), - "whatevs": global_cache._Pipeline(None, "himom!"), + "foo": pipe1, + "bar": pipe2, } - pipe2.execute.side_effect = redis_module.exceptions.WatchError - items = {"ay": "foo", "be": "bar", "see": "baz", "wut": "huh?"} - cache.compare_and_swap(items) + result = cache.compare_and_swap( + { + "foo": "moo", + "bar": "car", + "baz": "maz", + } + ) + assert result == {"foo": True, "bar": False, "baz": False} pipe1.multi.assert_called_once_with() - pipe2.multi.assert_called_once_with() - pipe1.mset.assert_called_once_with({"ay": "foo", "be": "bar"}) - pipe2.mset.assert_called_once_with({"see": "baz"}) + pipe1.set.assert_called_once_with("foo", "moo") pipe1.execute.assert_called_once_with() - pipe2.execute.assert_called_once_with() pipe1.reset.assert_called_once_with() - pipe2.reset.assert_called_once_with() - assert cache.pipes == {"whatevs": global_cache._Pipeline(None, "himom!")} + pipe2.multi.assert_called_once_with() + pipe2.set.assert_called_once_with("bar", "car") + pipe2.execute.assert_called_once_with() + pipe2.reset.assert_called_once_with() @staticmethod def test_compare_and_swap_w_expires(): - expired = {} - - def mock_expire(key, expires): - expired[key] = expires - redis = mock.Mock(spec=()) cache = global_cache.RedisCache(redis) - pipe1 = mock.Mock( - expire=mock_expire, - spec=("multi", "mset", "execute", "expire", "reset"), - ) - pipe2 = mock.Mock( - expire=mock_expire, - spec=("multi", "mset", "execute", "expire", "reset"), - ) + pipe1 = mock.Mock(spec=("multi", "setex", "execute", "reset")) + pipe2 = mock.Mock(spec=("multi", "setex", "execute", "reset")) + pipe2.execute.side_effect = redis_module.exceptions.WatchError cache._pipes.pipes = { - "ay": global_cache._Pipeline(pipe1, "abc123"), - "be": global_cache._Pipeline(pipe1, "abc123"), - "see": global_cache._Pipeline(pipe2, "def456"), - "dee": global_cache._Pipeline(pipe2, "def456"), - "whatevs": global_cache._Pipeline(None, "himom!"), + "foo": pipe1, + "bar": pipe2, } - pipe2.execute.side_effect = redis_module.exceptions.WatchError - items = {"ay": "foo", "be": "bar", "see": "baz", "wut": "huh?"} - cache.compare_and_swap(items, expires=32) + result = cache.compare_and_swap( + { + "foo": "moo", + "bar": "car", + "baz": "maz", + }, + expires=5, + ) + assert result == {"foo": True, "bar": False, "baz": False} pipe1.multi.assert_called_once_with() - pipe2.multi.assert_called_once_with() - pipe1.mset.assert_called_once_with({"ay": "foo", "be": "bar"}) - pipe2.mset.assert_called_once_with({"see": "baz"}) + pipe1.setex.assert_called_once_with("foo", 5, "moo") pipe1.execute.assert_called_once_with() - pipe2.execute.assert_called_once_with() pipe1.reset.assert_called_once_with() - pipe2.reset.assert_called_once_with() - assert cache.pipes == {"whatevs": global_cache._Pipeline(None, "himom!")} - assert expired == {"ay": 32, "be": 32, "see": 32} + pipe2.multi.assert_called_once_with() + pipe2.setex.assert_called_once_with("bar", 5, "car") + pipe2.execute.assert_called_once_with() + pipe2.reset.assert_called_once_with() @staticmethod def test_clear(): @@ -468,6 +558,36 @@ def test_set(): noreply=False, ) + @staticmethod + def test_set_if_not_exists(): + client = mock.Mock(spec=("add",)) + client.add.side_effect = (True, False) + cache_items = collections.OrderedDict([(b"a", b"foo"), (b"b", b"bar")]) + cache = global_cache.MemcacheCache(client) + results = cache.set_if_not_exists(cache_items) + assert results == {b"a": True, b"b": False} + client.add.assert_has_calls( + [ + mock.call(cache._key(b"a"), b"foo", expire=0, noreply=False), + mock.call(cache._key(b"b"), b"bar", expire=0, noreply=False), + ] + ) + + @staticmethod + def test_set_if_not_exists_w_expires(): + client = mock.Mock(spec=("add",)) + client.add.side_effect = (True, False) + cache_items = collections.OrderedDict([(b"a", b"foo"), (b"b", b"bar")]) + cache = global_cache.MemcacheCache(client) + results = cache.set_if_not_exists(cache_items, expires=123) + assert results == {b"a": True, b"b": False} + client.add.assert_has_calls( + [ + mock.call(cache._key(b"a"), b"foo", expire=123, noreply=False), + mock.call(cache._key(b"b"), b"bar", expire=123, noreply=False), + ] + ) + @staticmethod def test_set_w_expires(): client = mock.Mock(spec=("set_many",)) @@ -542,11 +662,17 @@ def test_watch(): key1: ("bun", b"0"), key2: ("shoe", b"1"), } - cache.watch((b"one", b"two")) + cache.watch( + collections.OrderedDict( + ( + (b"one", "bun"), + (b"two", "shot"), + ) + ) + ) client.gets_many.assert_called_once_with([key1, key2]) assert cache.caskeys == { key1: b"0", - key2: b"1", } @staticmethod @@ -567,14 +693,15 @@ def test_compare_and_swap(): key2 = cache._key(b"two") cache.caskeys[key2] = b"5" cache.caskeys["whatevs"] = b"6" - cache.compare_and_swap( + result = cache.compare_and_swap( { b"one": "bun", b"two": "shoe", } ) - client.cas.assert_called_once_with(key2, "shoe", b"5", expire=0) + assert result == {b"two": True} + client.cas.assert_called_once_with(key2, "shoe", b"5", expire=0, noreply=False) assert cache.caskeys == {"whatevs": b"6"} @staticmethod @@ -584,7 +711,7 @@ def test_compare_and_swap_and_expires(): key2 = cache._key(b"two") cache.caskeys[key2] = b"5" cache.caskeys["whatevs"] = b"6" - cache.compare_and_swap( + result = cache.compare_and_swap( { b"one": "bun", b"two": "shoe", @@ -592,7 +719,8 @@ def test_compare_and_swap_and_expires(): expires=5, ) - client.cas.assert_called_once_with(key2, "shoe", b"5", expire=5) + assert result == {b"two": True} + client.cas.assert_called_once_with(key2, "shoe", b"5", expire=5, noreply=False) assert cache.caskeys == {"whatevs": b"6"} @staticmethod