Skip to content

Commit

Permalink
fix: prevent overwrite of cache lock value
Browse files Browse the repository at this point in the history
Fixes #651
  • Loading branch information
Chris Rossi committed Jun 7, 2021
1 parent 5d7f163 commit 0f3d73a
Show file tree
Hide file tree
Showing 6 changed files with 459 additions and 44 deletions.
107 changes: 83 additions & 24 deletions google/cloud/ndb/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,32 @@ def __init__(self, options):
self.todo = {}
self.futures = {}

def done_callback(self, cache_call):
"""Process results of call to global cache.
If there is an exception for the cache call, distribute that to waiting
futures, otherwise examine the result of the cache call. If the result is
:data:`None`, simply set the result to :data:`None` for all waiting futures.
Otherwise, if the result is a `dict`, use that to propagate results for
individual keys to waiting figures.
"""
exception = cache_call.exception()
if exception:
for future in self.futures.values():
future.set_exception(exception)
return

result = cache_call.result()
if result:
for key, future in self.futures.items():
key_result = result.get(key, None)
if isinstance(key_result, Exception):
future.set_exception(key_result)
else:
future.set_result(key_result)
else:
for future in self.futures.values():
future.set_result(None)

def add(self, key, value):
"""Add a key, value pair to store in the cache.
Expand Down Expand Up @@ -335,36 +361,66 @@ def add(self, key, value):
self.futures[key] = future
return future

def done_callback(self, cache_call):
"""Process results of call to global cache.
def make_call(self):
"""Call :method:`GlobalCache.set`."""
return _global_cache().set(self.todo, expires=self.expires)

If there is an exception for the cache call, distribute that to waiting
futures, otherwise examine the result of the cache call. If the result is
:data:`None`, simply set the result to :data:`None` for all waiting futures.
Otherwise, if the result is a `dict`, use that to propagate results for
individual keys to waiting figures.
def future_info(self, key, value):
"""Generate info string for Future."""
return "GlobalCache.set_if_not_exists({}, {})".format(key, value)


@_handle_transient_errors()
def global_set_if_not_exists(key, value, expires=None, read=False):
"""Store entity in the global cache if key is not already present.
Args:
key (bytes): The key to save.
value (bytes): The entity to save.
expires (Optional[float]): Number of seconds until value expires.
read (bool): Indicates if being set in a read (lookup) context.
Returns:
tasklets.Future: Eventual result will be a ``bool`` value which will be
:data:`True` if a new value was set for the key, or :data:`False` if a value
was already set for the key.
"""
options = {}
if expires:
options = {"expires": expires}

batch = _batch.get_batch(_GlobalCacheSetIfNotExistsBatch, options)
return batch.add(key, value)


class _GlobalCacheSetIfNotExistsBatch(_GlobalCacheSetBatch):
"""Batch for global cache set_if_not_exists requests. """

def add(self, key, value):
"""Add a key, value pair to store in the cache.
Arguments:
key (bytes): The key to store in the cache.
value (bytes): The value to store in the cache.
Returns:
tasklets.Future: Eventual result will be a ``bool`` value which will be
:data:`True` if a new value was set for the key, or :data:`False` if a
value was already set for the key.
"""
exception = cache_call.exception()
if exception:
for future in self.futures.values():
future.set_exception(exception)
return
if key in self.todo:
future = tasklets.Future()
future.set_result(False)
return future

result = cache_call.result()
if result:
for key, future in self.futures.items():
key_result = result.get(key, None)
if isinstance(key_result, Exception):
future.set_exception(key_result)
else:
future.set_result(key_result)
else:
for future in self.futures.values():
future.set_result(None)
future = tasklets.Future(info=self.future_info(key, value))
self.todo[key] = value
self.futures[key] = future
return future

def make_call(self):
"""Call :method:`GlobalCache.set`."""
return _global_cache().set(self.todo, expires=self.expires)
return _global_cache().set_if_not_exists(self.todo, expires=self.expires)

def future_info(self, key, value):
"""Generate info string for Future."""
Expand Down Expand Up @@ -523,6 +579,9 @@ def global_lock(key, read=False):
Returns:
tasklets.Future: Eventual result will be ``None``.
"""
if read:
return global_set_if_not_exists(key, _LOCKED, expires=_LOCK_TIME, read=read)

return global_set(key, _LOCKED, expires=_LOCK_TIME, read=read)


Expand Down
11 changes: 9 additions & 2 deletions google/cloud/ndb/_datastore_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,15 @@ def lookup(key, options):
entity_pb.MergeFromString(result)

elif use_datastore:
yield _cache.global_lock(cache_key, read=True)
yield _cache.global_watch(cache_key)
lock_acquired = yield _cache.global_lock(cache_key, read=True)
if lock_acquired:
yield _cache.global_watch(cache_key)

else:
# Another thread locked or wrote to this key after the call to
# _cache.global_get above. Behave as though the key was locked by
# another thread and don't attempt to write our value below
key_locked = True

if entity_pb is _NOT_FOUND and use_datastore:
batch = _batch.get_batch(_LookupBatch, options)
Expand Down
50 changes: 50 additions & 0 deletions google/cloud/ndb/global_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,23 @@ def set(self, items, expires=None):
"""
raise NotImplementedError

@abc.abstractmethod
def set_if_not_exists(self, items, expires=None):
"""Stores entities in the cache if and only if keys are not already set.
Arguments:
items (Dict[bytes, Union[bytes, None]]): Mapping of keys to
serialized entities.
expires (Optional[float]): Number of seconds until value expires.
Returns:
Dict[bytes, bool]: A `dict` mapping to boolean value wich will be
:data:`True` if that key was set with a new value, and :data:`False`
otherwise.
"""
raise NotImplementedError

@abc.abstractmethod
def delete(self, keys):
"""Remove entities from the cache.
Expand Down Expand Up @@ -217,6 +234,18 @@ def set(self, items, expires=None):
for key, value in items.items():
self.cache[key] = (value, expires) # Supposedly threadsafe

def set_if_not_exists(self, items, expires=None):
"""Implements :meth:`GlobalCache.set_if_not_exists`."""
if expires:
expires = time.time() + expires

results = {}
for key, value in items.items():
set_value = (value, expires)
results[key] = self.cache.setdefault(key, set_value) is set_value

return results

def delete(self, keys):
"""Implements :meth:`GlobalCache.delete`."""
for key in keys:
Expand Down Expand Up @@ -355,6 +384,16 @@ def set(self, items, expires=None):
for key in items.keys():
self.redis.expire(key, expires)

def set_if_not_exists(self, items, expires=None):
"""Implements :meth:`GlobalCache.set_if_not_exists`."""
results = {}
for key, value in items.items():
results[key] = key_was_set = self.redis.setnx(key, value)
if key_was_set and expires:
self.redis.expire(key, expires)

return results

def delete(self, keys):
"""Implements :meth:`GlobalCache.delete`."""
self.redis.delete(*keys)
Expand Down Expand Up @@ -599,6 +638,17 @@ def set(self, items, expires=None):
)
return {key: MemcacheCache.KeyNotSet(key) for key in unset_keys}

def set_if_not_exists(self, items, expires=None):
"""Implements :meth:`GlobalCache.set_if_not_exists`."""
expires = expires if expires else 0
results = {}
for key, value in items.items():
results[key] = self.client.add(
self._key(key), value, expire=expires, noreply=False
)

return results

def delete(self, keys):
"""Implements :meth:`GlobalCache.delete`."""
keys = [self._key(key) for key in keys]
Expand Down
Loading

0 comments on commit 0f3d73a

Please sign in to comment.