diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index 13c16928..083bd5d0 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -34,6 +34,14 @@ log = logging.getLogger(__name__) +def _syncpoint_692(): + """A no-op function meant to be patched for testing. + + Should be replaced by `orchestrate.syncpoint` using `mock.patch` during testing to + orchestrate concurrent testing scenarios. + """ + + class ContextCache(dict): """A per-context in-memory entity cache. @@ -684,6 +692,7 @@ def _update_key(key, new_value): value = new_value(old_value) utils.logging_debug(log, "new value: {}", value) + _syncpoint_692() if old_value is not None: utils.logging_debug(log, "compare and swap") diff --git a/tests/conftest.py b/tests/conftest.py index 8c3775cd..1aabecd7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -88,22 +88,31 @@ def initialize_environment(request, environ): @pytest.fixture -def context(): - client = mock.Mock( - project="testing", - namespace=None, - spec=("project", "namespace"), - stub=mock.Mock(spec=()), - ) - context = context_module.Context( - client, - eventloop=TestingEventLoop(), - datastore_policy=True, - legacy_data=False, - ) +def context_factory(): + def context(**kwargs): + client = mock.Mock( + project="testing", + namespace=None, + spec=("project", "namespace"), + stub=mock.Mock(spec=()), + ) + context = context_module.Context( + client, + eventloop=TestingEventLoop(), + datastore_policy=True, + legacy_data=False, + **kwargs, + ) + return context + return context +@pytest.fixture +def context(context_factory): + return context_factory() + + @pytest.fixture def in_context(context): assert not context_module._state.context diff --git a/tests/unit/orchestrate.py b/tests/unit/orchestrate.py new file mode 100644 index 00000000..ed008775 --- /dev/null +++ b/tests/unit/orchestrate.py @@ -0,0 +1,356 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import math +import queue +import threading + + +def orchestrate(*tests): + """ + Orchestrate a deterministic concurrency test. + + Runs test functions in separate threads, with each thread taking turns running up + until predefined syncpoints in a deterministic order. All possible orderings are + tested. + + Most of the time, we try to use logic, best practices, and static analysis to insure + correct operation of concurrent code. Sometimes our powers of reasoning fail us and, + either through non-determistic stress testing or running code in production, a + concurrent bug is discovered. When this occurs, we'd like to have a regression test + to insure we've understood the problem and implemented a correct solution. + `orchestrate` provides a means of deterministically testing concurrent code so we + can write robust regression tests for complex concurrent scenarios. + + `orchestrate` runs each passed in test function in its own thread. Threads then + "take turns" running. Turns are defined by setting syncpoints in the code under + test. To do this, you'll write a no-op function and call it at the point where you'd + like your code to pause and give another thread a turn. In your test, then, use + `mock.patch` to replace your no-op function with :func:`syncpoint` in your test. + + For example, let's say you have the following code in production:: + + def hither_and_yon(destination): + hither(destination) + yon(destination) + + You've found there's a concurrency bug when two threads execute this code with the + same destination, and you think that by adding a syncpoint between the calls to + `hither` and `yon` you can reproduce the problem in a regression test. First you'd, + write a no-op function, include it in your production code, and call it in + `hither_and_yon`:: + + def _syncpoint_123(): + pass + + def hither_and_yon(destination): + hither(destination) + _syncpoint_123() + yon(destination) + + Now you can write a test to exercise `hither_and_yon` running in parallel:: + + from unittest import mock + from tests.unit import orchestrate + + from google.cloud.sales import travel + + @mock.patch("google.cloud.sales.travel._syncpoint_123", orchestrate.syncpoint) + def test_concurrent_hither_and_yon(): + + def test_hither_and_yon(): + assert something + travel.hither_and_yon("Raleigh") + assert something_else + + counts = orchestrate.orchestrate(test_hither_and_yon, test_hither_and_yon) + assert counts == (2, 2) + + What `orchestrate` will do now is take each of the two test functions passed in + (actually the same function, twice, in this case), run them serially, and count the + number of turns it takes to run each test to completion. In this example, it will + take two turns for each test: one turn to start the thread and execute up until the + syncpoint, and then another turn to execute from the syncpoint to the end of the + test. The number of turns will always be one greater than the number of syncpoints + encountered when executing the test. + + Once the counts have been taken, `orchestrate` will construct a test sequence that + represents the all the turns taken by the passed in tests, with each value in the + sequence representing the the index of the test whose turn it is in the sequence. In + this example, then, it would produce:: + + [0, 0, 1, 1] + + This represents the first test taking both of its turns, followed by the second test + taking both of its turns. At this point this scenario has already been tested, + because this is what was run to produce the counts and the initial test sequence. + Now `orchestrate` will run all of the remaining scenarios by finding all the + permutations of the test sequence and executing those, in turn:: + + [0, 1, 0, 1] + [0, 1, 1, 0] + [1, 0, 0, 1] + [1, 0, 1, 0] + [1, 1, 0, 0] + + You'll notice in our example that since both test functions are actually the same + function, that although it tested 6 scenarios there are effectively only really 3 + unique scenarios. For the time being, though, `orchestrate` doesn't attempt to + detect this condition or optimize for it. + + There are some performance considerations that should be taken into account when + writing tests. The number of unique test sequences grows quite quickly with the + number of turns taken by the functions under test. Our simple example with two + threads each taking two turns, only yielded 6 scenarios, but two threads each taking + 6 turns, for example, yields 924 scenarios. Add another six step thread and now you + have over 17 thousand scenarios. In general, use the least number of steps/threads + you can get away with and still expose the behavior you want to correct. + + For the same reason as above, if you have many concurrent tests, when writing a new + test, make sure you're not accidentally patching syncpoints intended for other + tests, as this will add steps to your tests. While it's not problematic from a + testing standpoint to have extra steps in your tests, it can use computing resources + unnecessarily. Using different no-op functions with different names for different + tests can help with this. + + As soon as any error or failure is detected, no more scenarios are run + and that error is propagated to the main thread. + + Args: + tests (Tuple[Callable]): Test functions to be run. These functions will not be + called with any arguments, so they must not have any required arguments. + + Returns: + Tuple[int]: A tuple of the count of the number turns for test passed in. Can be + used a sanity check in tests to make sure you understand what's actually + happening during a test. + """ + # Produce an initial test sequence. The fundamental question we're always trying to + # answer is "whose turn is it?" First we'll find out how many "turns" each test + # needs to complete when run serially and use that to construct a sequence of + # indexes. When a test's index appears in the sequence, it is that test's turn to + # run. We'll start by constructing a sequence that would run each test through to + # completion serially, one after the other. + test_sequence = [] + counts = [] + for index, test in enumerate(tests): + thread = _TestThread(test) + for count in itertools.count(1): # pragma: NO BRANCH + # Pragma is required because loop never finishes naturally. + thread.go() + if thread.finished: + break + + counts.append(count) + test_sequence += [index] * count + + # Now we can take that initial sequence and generate all of its permutations, + # running each one to try to uncover concurrency bugs + sequences = iter(_permutations(test_sequence)) + + # We already tested the first sequence getting our counts, so we can discard it + next(sequences) + + # Test each sequence + for test_sequence in sequences: + threads = [_TestThread(test) for test in tests] + try: + for index in test_sequence: + threads[index].go() + + # Its possible for number of turns to vary from one test run to the other, + # especially if there is some undiscovered concurrency bug. Go ahead and + # finish running each test to completion, if not already complete. + for thread in threads: + while not thread.finished: + thread.go() + + except Exception: + # If an exception occurs, we still need to let any threads that are still + # going finish up. Additional exceptions are silently ignored. + for thread in threads: + thread.finish() + raise + + return tuple(counts) + + +def syncpoint(): + """End a thread's "turn" at this point. + + This will generally be inserted by `mock.patch` to replace a no-op function in + production code. See documentation for :func:`orchestrate`. + """ + conductor = _local.conductor + conductor.notify() + conductor.standby() + + +_local = threading.local() + + +class _Conductor: + """Coordinate communication between main thread and a test thread. + + Two way communicaton is maintained between the main thread and a test thread using + two synchronized queues (`queue.Queue`) each with a size of one. + """ + + def __init__(self): + self._notify = queue.Queue(1) + self._go = queue.Queue(1) + + def notify(self): + """Called from test thread to let us know it's finished or is ready for its next + turn.""" + self._notify.put(None) + + def standby(self): + """Called from test thread in order to block until told to go.""" + self._go.get() + + def wait(self): + """Called from main thread to wait for test thread to either get to the + next syncpoint or finish.""" + self._notify.get() + + def go(self): + """Called from main thread to tell test thread to go.""" + self._go.put(None) + + +class _TestThread: + """A thread for a test function.""" + + thread = None + finished = False + error = None + + def __init__(self, test): + self.test = test + self.conductor = _Conductor() + + def _run(self): + _local.conductor = self.conductor + try: + self.test() + except Exception as error: + self.error = error + finally: + self.finished = True + self.conductor.notify() + + def go(self): + if self.finished: + return + + if self.thread is None: + self.thread = threading.Thread(target=self._run) + self.thread.start() + + else: + self.conductor.go() + + self.conductor.wait() + + if self.error: + raise self.error + + def finish(self): + while not self.finished: + try: + self.go() + except Exception: + pass + + +class _permutations: + """Generates a sequence of all permutations of `sequence`. + + Permutations are returned in lexicographic order using the "Generation in + lexicographic order" algorithm described in `the Wikipedia article on "Permutation" + `_. + + This implementation differs significantly from `itertools.permutations` in that the + value of individual elements is taken into account, thus eliminating redundant + orderings that would be produced by `itertools.permutations`. + + Args: + sequence (Sequence[Any]): Sequence must be finite and orderable. + + Returns: + Sequence[Sequence[Any]]: Set of all permutations of `sequence`. + """ + + def __init__(self, sequence): + self._start = tuple(sorted(sequence)) + + def __len__(self): + """Compute the number of permutations. + + Let the number of elements in a sequence N and the number of repetitions for + individual members of the sequence be n1, n2, ... nx. The number of unique + permutations is: N! / n1! / n2! / ... / nx!. + + For example, let `sequence` be [1, 2, 3, 1, 2, 3, 1, 2, 3]. The number of unique + permutations is: 9! / 3! / 3! / 3! = 1680. + + See: "Permutations of multisets" in `the Wikipedia article on "Permutation" + `_. + """ + repeats = [len(list(group)) for value, group in itertools.groupby(self._start)] + length = math.factorial(len(self._start)) + for repeat in repeats: + length /= math.factorial(repeat) + + return int(length) + + def __iter__(self): + """Iterate over permutations. + + See: "Generation in lexicographic order" algorithm described in `the Wikipedia + article on "Permutation" `_. + """ + current = list(self._start) + size = len(current) + + while True: + yield tuple(current) + + # 1. Find the largest index i such that a[i] < a[i + 1]. + for i in range(size - 2, -1, -1): + if current[i] < current[i + 1]: + break + + else: + # If no such index exists, the permutation is the last permutation. + return + + # 2. Find the largest index j greater than i such that a[i] < a[j]. + for j in range(size - 1, i, -1): + if current[i] < current[j]: + break + + else: # pragma: NO COVER + raise RuntimeError("Broken algorithm") + + # 3. Swap the value of a[i] with that of a[j]. + temp = current[i] + current[i] = current[j] + current[j] = temp + + # 4. Reverse the sequence from a[i + 1] up to and including the final + # element a[n]. + current = current[: i + 1] + list(reversed(current[i + 1 :])) diff --git a/tests/unit/test_concurrency.py b/tests/unit/test_concurrency.py index 742cbc09..098e6062 100644 --- a/tests/unit/test_concurrency.py +++ b/tests/unit/test_concurrency.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging try: @@ -21,131 +20,30 @@ import mock from google.cloud.ndb import _cache -from google.cloud.ndb import _eventloop from google.cloud.ndb import global_cache as global_cache_module from google.cloud.ndb import tasklets -from google.cloud.ndb import utils +from . import orchestrate log = logging.getLogger(__name__) -class Delay(object): - """A tasklet wrapper which delays the return of a tasklet. - - Used to orchestrate timing of events in async code to test particular scenarios - involving concurrency. Use with `mock.patch` to replace particular tasklets with - wrapped versions. When those tasklets are called, they will execute and then the - wrapper will hang on to the result until :meth:`Delay.advance()` is called, at which - time the tasklet's caller will receive the result. - - Args: - wrapped (tasklets.Tasklet): The tasklet to be delayed. - """ - - def __init__(self, wrapped): - self.wrapped = wrapped - self.info = "Delay {}".format(self.wrapped.__name__) - self._futures = collections.deque() - - @tasklets.tasklet - def __call__(self, *args, **kwargs): - future = tasklets.Future(self.info) - self._futures.append(future) - - result = yield self.wrapped(*args, **kwargs) - yield future - raise tasklets.Return(result) - - def advance(self): - """Allow a call to the wrapper to proceed. - - Calls are advanced in the order in which they were orignally made. - """ - self._futures.popleft().set_result(None) - - -def run_until(): - """Do all queued work on the event loop. - - This will allow any currently running tasklets to execute up to the point that they - hit a call to a tasklet that is delayed by :class:`Delay`. When this call is - finished, either all in progress tasklets will have been completed, or a call to - :class:`Delay.advance` will be required to move execution forward again. - """ - while _eventloop.run1(): - pass - - -def test_global_cache_concurrent_writes_692(in_context): +@mock.patch("google.cloud.ndb._cache._syncpoint_692", orchestrate.syncpoint) +def test_global_cache_concurrent_write_692(context_factory): """Regression test for #692 https://github.com/googleapis/python-ndb/issues/692 """ key = b"somekey" - @tasklets.tasklet - def run_test(): - lock1 = yield _cache.global_lock_for_write(key) - lock2, _ = yield ( - _cache.global_lock_for_write(key), - _cache.global_unlock_for_write(key, lock1), - ) - yield _cache.global_unlock_for_write(key, lock2) + @tasklets.synctasklet + def lock_unlock_key(): + lock = yield _cache.global_lock_for_write(key) + yield _cache.global_unlock_for_write(key, lock) - delay_global_get = Delay(_cache.global_get) - with mock.patch("google.cloud.ndb._cache._global_get", delay_global_get): + def run_test(): global_cache = global_cache_module._InProcessGlobalCache() - with in_context.new(global_cache=global_cache).use(): - future = run_test() - - # Run until the global_cache_get call in the first global_lock_for_write - # call - run_until() - utils.logging_debug(log, "zero") - - # Let the first global_cache_get call return and advance to the - # global_cache_get calls in the first call to global_unlock_for_write and - # second call to global_lock_for_write. They will have both gotten the same - # "old" value from the cache - delay_global_get.advance() - run_until() - utils.logging_debug(log, "one") - - # Let the global_cache_get call return in the second global_lock_for_write - # call. It should write a new lock value containing both locks. - delay_global_get.advance() - run_until() - utils.logging_debug(log, "two") - - # Let the global_cache_get call return in the first global_unlock_for_write - # call. Since its "old" cache value contained only the first lock, it might - # think it's done and delete the key, since as far as it's concerned, there - # are no more locks. This is the bug exposed by this test. - delay_global_get.advance() - run_until() - utils.logging_debug(log, "three") - - # Since we've fixed the bug now, what we expect it to do instead is attempt - # to write a new cache value that is a write lock value but contains no - # locks. This attempt will fail since the cache value was changed out from - # under it by the second global_lock_write call occurring in parallel. When - # this attempt fails it will call global_get again to get the new value - # containing both locks and recompute a value that only includes the second - # lock and write it. - delay_global_get.advance() - run_until() - utils.logging_debug(log, "four") - - # Now the last call to global_unlock_for_write will call global_get to get - # the current lock value with only one write lock, and then write an empty - # write lock. - delay_global_get.advance() - run_until() - utils.logging_debug(log, "five") - - # Make sure we can get to the end without raising an exception - future.result() + with context_factory(global_cache=global_cache).use(): + lock_unlock_key() - # Make sure the empty write lock registers as "not locked". - assert not _cache.is_locked_value(_cache.global_get(key).result()) + orchestrate.orchestrate(run_test, run_test) diff --git a/tests/unit/test_orchestrate.py b/tests/unit/test_orchestrate.py new file mode 100644 index 00000000..5a8ad621 --- /dev/null +++ b/tests/unit/test_orchestrate.py @@ -0,0 +1,219 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +import pytest + +from . import orchestrate + + +def test__permutations(): + sequence = [1, 2, 3, 1, 2, 3, 1, 2, 3] + permutations = orchestrate._permutations(sequence) + assert len(permutations) == 1680 + + result = list(permutations) + assert len(permutations) == len(result) # computed length matches reality + assert len(result) == len(set(result)) # no duplicates + assert result[0] == (1, 1, 1, 2, 2, 2, 3, 3, 3) + assert result[-1] == (3, 3, 3, 2, 2, 2, 1, 1, 1) + + assert list(orchestrate._permutations([1, 2, 3])) == [ + (1, 2, 3), + (1, 3, 2), + (2, 1, 3), + (2, 3, 1), + (3, 1, 2), + (3, 2, 1), + ] + + +class Test_orchestrate: + @staticmethod + def test_no_failures(): + test_calls = [] + + def make_test(name): + def test(): + test_calls.append(name) + orchestrate.syncpoint() + test_calls.append(name) + orchestrate.syncpoint() + test_calls.append(name) + + return test + + test1 = make_test("A") + test2 = make_test("B") + + permutations = orchestrate._permutations(["A", "B", "A", "B", "A", "B"]) + expected = list(itertools.chain(*permutations)) + + counts = orchestrate.orchestrate(test1, test2) + assert counts == (3, 3) + assert test_calls == expected + + @staticmethod + def test_syncpoints_decrease_after_initial_run(): + test_calls = [] + + def make_test(name): + syncpoints = [name] * 4 + + def test(): + test_calls.append(name) + if syncpoints: + orchestrate.syncpoint() + test_calls.append(syncpoints.pop()) + + return test + + test1 = make_test("A") + test2 = make_test("B") + + expected = [ + "A", + "A", + "B", + "B", + "A", + "B", + "A", + "B", + "A", + "B", + "B", + "A", + "B", + "A", + "A", + "B", + "B", + "A", + "B", + "A", + ] + + counts = orchestrate.orchestrate(test1, test2) + assert counts == (2, 2) + assert test_calls == expected + + @staticmethod + def test_syncpoints_increase_after_initial_run(): + test_calls = [] + + def make_test(name): + syncpoints = [None] * 4 + + def test(): + test_calls.append(name) + orchestrate.syncpoint() + test_calls.append(name) + + if syncpoints: + syncpoints.pop() + else: + orchestrate.syncpoint() + test_calls.append(name) + + return test + + test1 = make_test("A") + test2 = make_test("B") + + expected = [ + "A", + "A", + "B", + "B", + "A", + "B", + "A", + "B", + "A", + "B", + "B", + "A", + "B", + "A", + "A", + "B", + "B", + "A", + "B", + "A", + "A", + "B", + "B", + "B", + "A", + "A", + "A", + "B", + ] + + counts = orchestrate.orchestrate(test1, test2) + assert counts == (2, 2) + assert test_calls == expected + + @staticmethod + def test_failure(): + test_calls = [] + + def make_test(name): + syncpoints = [None] * 4 + + def test(): + test_calls.append(name) + orchestrate.syncpoint() + test_calls.append(name) + + if syncpoints: + syncpoints.pop() + else: + assert True is False + + return test + + test1 = make_test("A") + test2 = make_test("B") + + expected = [ + "A", + "A", + "B", + "B", + "A", + "B", + "A", + "B", + "A", + "B", + "B", + "A", + "B", + "A", + "A", + "B", + "B", + "A", + "B", + "A", + ] + + with pytest.raises(AssertionError): + orchestrate.orchestrate(test1, test2) + + assert test_calls == expected