diff --git a/google/cloud/ndb/context.py b/google/cloud/ndb/context.py index fdfe0ccb..f8dff2d0 100644 --- a/google/cloud/ndb/context.py +++ b/google/cloud/ndb/context.py @@ -28,30 +28,36 @@ from google.cloud.ndb import key as key_module -def _generate_context_ids(): - """Generate a sequence of context ids. +class _ContextIds: + """Iterator which generates a sequence of context ids. Useful for debugging complicated interactions among concurrent processes and threads. - The return value is a generator for strings that include the machine's "node", - acquired via `uuid.getnode()`, the current process id, and a sequence number which - increases monotonically starting from one in each process. The combination of all - three is sufficient to uniquely identify the context in which a particular piece of - code is being run. Each context, as it is created, is assigned the next id in this - sequence. The context id is used by `utils.logging_debug` to grant insight into - where a debug logging statement is coming from in a cloud evironment. - - Returns: - Generator[str]: Sequence of context ids. + Each value in the sequence is a string that include the machine's "node", acquired + via `uuid.getnode()`, the current process id, and a sequence number which increases + monotonically starting from one in each process. The combination of all three is + sufficient to uniquely identify the context in which a particular piece of code is + being run. Each context, as it is created, is assigned the next id in this sequence. + The context id is used by `utils.logging_debug` to grant insight into where a debug + logging statement is coming from in a cloud evironment. """ - prefix = "{}-{}-".format(uuid.getnode(), os.getpid()) - for sequence_number in itertools.count(1): # pragma NO BRANCH - # pragma is required because this loop never exits (infinite sequence) - yield prefix + str(sequence_number) + + def __init__(self): + self.prefix = "{}-{}-".format(uuid.getnode(), os.getpid()) + self.counter = itertools.count(1) + self.lock = threading.Lock() + + def __next__(self): + with self.lock: + sequence_number = next(self.counter) + + return self.prefix + str(sequence_number) + + next = __next__ # Python 2.7 -_context_ids = _generate_context_ids() +_context_ids = _ContextIds() try: # pragma: NO PY2 COVER diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index fda9e60e..c5441b1a 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +import threading try: from unittest import mock @@ -75,6 +76,38 @@ def test_constructor_defaults(self): assert context.batches == {} assert context.transaction is None + node1, pid1, sequence_no1 = context.id.split("-") + node2, pid2, sequence_no2 = context_module.Context("client").id.split("-") + assert node1 == node2 + assert pid1 == pid2 + assert int(sequence_no2) - int(sequence_no1) == 1 + + def test_constructuor_concurrent_instantiation(self): + """Regression test for #716 + + This test non-deterministically tests a potential concurrency issue. Before the + bug this is a test for was fixed, it failed most of the time. + + https://github.com/googleapis/python-ndb/issues/715 + """ + errors = [] + + def make_some(): + try: + for _ in range(10000): + context_module.Context("client") + except Exception as error: # pragma: NO COVER + errors.append(error) + + thread1 = threading.Thread(target=make_some) + thread2 = threading.Thread(target=make_some) + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + assert not errors + def test_constructor_overrides(self): context = context_module.Context( client="client",