Skip to content

Commit

Permalink
fix: fix bug with concurrent writes to global cache (#705)
Browse files Browse the repository at this point in the history
fix: fix bug with concurrent writes to global cache

Fixes #692
  • Loading branch information
Chris Rossi authored Aug 11, 2021
1 parent 60c293d commit bb7cadc
Show file tree
Hide file tree
Showing 8 changed files with 321 additions and 55 deletions.
40 changes: 30 additions & 10 deletions google/cloud/ndb/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import functools
import itertools
import logging
import uuid
import warnings

Expand All @@ -22,13 +23,15 @@
from google.cloud.ndb import _batch
from google.cloud.ndb import context as context_module
from google.cloud.ndb import tasklets
from google.cloud.ndb import utils

_LOCKED_FOR_READ = b"0-"
_LOCKED_FOR_WRITE = b"00"
_LOCK_TIME = 32
_PREFIX = b"NDB30"

warnings.filterwarnings("always", module=__name__)
log = logging.getLogger(__name__)


class ContextCache(dict):
Expand Down Expand Up @@ -583,20 +586,28 @@ def future_info(self, key, value):


@tasklets.tasklet
def global_lock_for_read(key):
def global_lock_for_read(key, prev_value):
"""Lock a key for a read (lookup) operation by setting a special value.
Lock may be preempted by a parallel write (put) operation.
Args:
key (bytes): The key to lock.
prev_value (bytes): The cache value previously read from the global cache.
Should be either :data:`None` or an empty bytes object if a key was written
recently.
Returns:
tasklets.Future: Eventual result will be lock value (``bytes``) written to
Datastore for the given key, or :data:`None` if the lock was not acquired.
"""
lock = _LOCKED_FOR_READ + str(uuid.uuid4()).encode("ascii")
lock_acquired = yield global_set_if_not_exists(key, lock, expires=_LOCK_TIME)
if prev_value is not None:
yield global_watch(key, prev_value)
lock_acquired = yield global_compare_and_swap(key, lock, expires=_LOCK_TIME)
else:
lock_acquired = yield global_set_if_not_exists(key, lock, expires=_LOCK_TIME)

if lock_acquired:
raise tasklets.Return(lock)

Expand All @@ -618,6 +629,7 @@ def global_lock_for_write(key):
"""
lock = "." + str(uuid.uuid4())
lock = lock.encode("ascii")
utils.logging_debug(log, "lock for write: {}", lock)

def new_value(old_value):
if old_value and old_value.startswith(_LOCKED_FOR_WRITE):
Expand All @@ -634,8 +646,7 @@ def new_value(old_value):
def global_unlock_for_write(key, lock):
"""Remove a lock for key by updating or removing a lock value.
The lock represented by the ``lock`` argument will be released. If no other locks
remain, the key will be deleted.
The lock represented by the ``lock`` argument will be released.
Args:
key (bytes): The key to lock.
Expand All @@ -645,9 +656,15 @@ def global_unlock_for_write(key, lock):
Returns:
tasklets.Future: Eventual result will be :data:`None`.
"""
utils.logging_debug(log, "unlock for write: {}", lock)

def new_value(old_value):
return old_value.replace(lock, b"")
assert lock in old_value, "attempt to remove lock that isn't present"
value = old_value.replace(lock, b"")
if value == _LOCKED_FOR_WRITE:
value = b""

return value

cache = _global_cache()
try:
Expand All @@ -663,19 +680,22 @@ def _update_key(key, new_value):

while not success:
old_value = yield _global_get(key)
utils.logging_debug(log, "old value: {}", old_value)

value = new_value(old_value)
if value == _LOCKED_FOR_WRITE:
# No more locks for this key, we can delete
yield _global_delete(key)
break
utils.logging_debug(log, "new value: {}", value)

if old_value:
if old_value is not None:
utils.logging_debug(log, "compare and swap")
yield _global_watch(key, old_value)
success = yield _global_compare_and_swap(key, value, expires=_LOCK_TIME)

else:
utils.logging_debug(log, "set if not exists")
success = yield global_set_if_not_exists(key, value, expires=_LOCK_TIME)

utils.logging_debug(log, "success: {}", success)


def is_locked_value(value):
"""Check if the given value is the special reserved value for key lock.
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/ndb/_datastore_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,12 @@ def lookup(key, options):
result = yield _cache.global_get(cache_key)
key_locked = _cache.is_locked_value(result)
if not key_locked:
if result is not None:
if result:
entity_pb = entity_pb2.Entity()
entity_pb.MergeFromString(result)

elif use_datastore:
lock = yield _cache.global_lock_for_read(cache_key)
lock = yield _cache.global_lock_for_read(cache_key, result)
if lock:
yield _cache.global_watch(cache_key, lock)

Expand Down
35 changes: 35 additions & 0 deletions google/cloud/ndb/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,43 @@

import collections
import contextlib
import itertools
import os
import six
import threading
import uuid

from google.cloud.ndb import _eventloop
from google.cloud.ndb import exceptions
from google.cloud.ndb import key as key_module


def _generate_context_ids():
    """Yield an unending sequence of unique context id strings.

    Useful for debugging complicated interactions among concurrent processes
    and threads.

    Each yielded id combines the machine's "node" (from ``uuid.getnode()``),
    the current process id, and a per-process sequence number that counts up
    from one. Together these three parts uniquely identify the context in
    which a particular piece of code is running. Each context, as it is
    created, takes the next id from this sequence; ``utils.logging_debug``
    includes the id in log lines to show where a debug statement originated
    in a cloud environment.

    Returns:
        Generator[str]: Sequence of context ids.
    """
    node = uuid.getnode()
    pid = os.getpid()
    counter = itertools.count(1)
    # pragma is required because this generator never exits (infinite sequence)
    while True:  # pragma NO BRANCH
        yield "{}-{}-{}".format(node, pid, next(counter))


_context_ids = _generate_context_ids()


try: # pragma: NO PY2 COVER
import contextvars

Expand Down Expand Up @@ -199,6 +228,7 @@ def policy(key):
_ContextTuple = collections.namedtuple(
"_ContextTuple",
[
"id",
"client",
"namespace",
"eventloop",
Expand Down Expand Up @@ -234,6 +264,7 @@ class _Context(_ContextTuple):
def __new__(
cls,
client,
id=None,
namespace=key_module.UNDEFINED,
eventloop=None,
batches=None,
Expand All @@ -255,6 +286,9 @@ def __new__(
# Prevent circular import in Python 2.7
from google.cloud.ndb import _cache

if id is None:
id = next(_context_ids)

if eventloop is None:
eventloop = _eventloop.EventLoop()

Expand All @@ -272,6 +306,7 @@ def __new__(

context = super(_Context, cls).__new__(
cls,
id=id,
client=client,
namespace=namespace,
eventloop=eventloop,
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/ndb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ def logging_debug(log, message, *args, **kwargs):
message = str(message)
if args or kwargs:
message = message.format(*args, **kwargs)

from google.cloud.ndb import context as context_module

context = context_module.get_context(False)
if context:
message = "{}: {}".format(context.id, message)

log.debug(message)


Expand Down
Loading

0 comments on commit bb7cadc

Please sign in to comment.