From 9df91965c306ccd7fa7871e569afe9a0d82595b3 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Mon, 16 Aug 2021 14:55:33 -0400 Subject: [PATCH 1/7] test: refactor concurrency test using orchestrate Towards #691 --- google/cloud/ndb/_cache.py | 9 + tests/conftest.py | 35 ++-- tests/unit/orchestrate.py | 356 +++++++++++++++++++++++++++++++++ tests/unit/test_concurrency.py | 124 +----------- tests/unit/test_orchestrate.py | 219 ++++++++++++++++++++ 5 files changed, 617 insertions(+), 126 deletions(-) create mode 100644 tests/unit/orchestrate.py create mode 100644 tests/unit/test_orchestrate.py diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index 13c16928..083bd5d0 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -34,6 +34,14 @@ log = logging.getLogger(__name__) +def _syncpoint_692(): + """A no-op function meant to be patched for testing. + + Should be replaced by `orchestrate.syncpoint` using `mock.patch` during testing to + orchestrate concurrent testing scenarios. + """ + + class ContextCache(dict): """A per-context in-memory entity cache. 
@@ -684,6 +692,7 @@ def _update_key(key, new_value): value = new_value(old_value) utils.logging_debug(log, "new value: {}", value) + _syncpoint_692() if old_value is not None: utils.logging_debug(log, "compare and swap") diff --git a/tests/conftest.py b/tests/conftest.py index 8c3775cd..1aabecd7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -88,22 +88,31 @@ def initialize_environment(request, environ): @pytest.fixture -def context(): - client = mock.Mock( - project="testing", - namespace=None, - spec=("project", "namespace"), - stub=mock.Mock(spec=()), - ) - context = context_module.Context( - client, - eventloop=TestingEventLoop(), - datastore_policy=True, - legacy_data=False, - ) +def context_factory(): + def context(**kwargs): + client = mock.Mock( + project="testing", + namespace=None, + spec=("project", "namespace"), + stub=mock.Mock(spec=()), + ) + context = context_module.Context( + client, + eventloop=TestingEventLoop(), + datastore_policy=True, + legacy_data=False, + **kwargs, + ) + return context + return context +@pytest.fixture +def context(context_factory): + return context_factory() + + @pytest.fixture def in_context(context): assert not context_module._state.context diff --git a/tests/unit/orchestrate.py b/tests/unit/orchestrate.py new file mode 100644 index 00000000..ed008775 --- /dev/null +++ b/tests/unit/orchestrate.py @@ -0,0 +1,356 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import itertools +import math +import queue +import threading + + +def orchestrate(*tests): + """ + Orchestrate a deterministic concurrency test. + + Runs test functions in separate threads, with each thread taking turns running up + until predefined syncpoints in a deterministic order. All possible orderings are + tested. + + Most of the time, we try to use logic, best practices, and static analysis to ensure + correct operation of concurrent code. Sometimes our powers of reasoning fail us and, + either through non-deterministic stress testing or running code in production, a + concurrent bug is discovered. When this occurs, we'd like to have a regression test + to ensure we've understood the problem and implemented a correct solution. + `orchestrate` provides a means of deterministically testing concurrent code so we + can write robust regression tests for complex concurrent scenarios. + + `orchestrate` runs each passed in test function in its own thread. Threads then + "take turns" running. Turns are defined by setting syncpoints in the code under + test. To do this, you'll write a no-op function and call it at the point where you'd + like your code to pause and give another thread a turn. In your test, then, use + `mock.patch` to replace your no-op function with :func:`syncpoint` in your test. + + For example, let's say you have the following code in production:: + + def hither_and_yon(destination): + hither(destination) + yon(destination) + + You've found there's a concurrency bug when two threads execute this code with the + same destination, and you think that by adding a syncpoint between the calls to + `hither` and `yon` you can reproduce the problem in a regression test.
First you'd, + write a no-op function, include it in your production code, and call it in + `hither_and_yon`:: + + def _syncpoint_123(): + pass + + def hither_and_yon(destination): + hither(destination) + _syncpoint_123() + yon(destination) + + Now you can write a test to exercise `hither_and_yon` running in parallel:: + + from unittest import mock + from tests.unit import orchestrate + + from google.cloud.sales import travel + + @mock.patch("google.cloud.sales.travel._syncpoint_123", orchestrate.syncpoint) + def test_concurrent_hither_and_yon(): + + def test_hither_and_yon(): + assert something + travel.hither_and_yon("Raleigh") + assert something_else + + counts = orchestrate.orchestrate(test_hither_and_yon, test_hither_and_yon) + assert counts == (2, 2) + + What `orchestrate` will do now is take each of the two test functions passed in + (actually the same function, twice, in this case), run them serially, and count the + number of turns it takes to run each test to completion. In this example, it will + take two turns for each test: one turn to start the thread and execute up until the + syncpoint, and then another turn to execute from the syncpoint to the end of the + test. The number of turns will always be one greater than the number of syncpoints + encountered when executing the test. + + Once the counts have been taken, `orchestrate` will construct a test sequence that + represents the all the turns taken by the passed in tests, with each value in the + sequence representing the the index of the test whose turn it is in the sequence. In + this example, then, it would produce:: + + [0, 0, 1, 1] + + This represents the first test taking both of its turns, followed by the second test + taking both of its turns. At this point this scenario has already been tested, + because this is what was run to produce the counts and the initial test sequence. 
+ Now `orchestrate` will run all of the remaining scenarios by finding all the + permutations of the test sequence and executing those, in turn:: + + [0, 1, 0, 1] + [0, 1, 1, 0] + [1, 0, 0, 1] + [1, 0, 1, 0] + [1, 1, 0, 0] + + You'll notice in our example that since both test functions are actually the same + function, that although it tested 6 scenarios there are effectively only really 3 + unique scenarios. For the time being, though, `orchestrate` doesn't attempt to + detect this condition or optimize for it. + + There are some performance considerations that should be taken into account when + writing tests. The number of unique test sequences grows quite quickly with the + number of turns taken by the functions under test. Our simple example with two + threads each taking two turns, only yielded 6 scenarios, but two threads each taking + 6 turns, for example, yields 924 scenarios. Add another six step thread and now you + have over 17 thousand scenarios. In general, use the least number of steps/threads + you can get away with and still expose the behavior you want to correct. + + For the same reason as above, if you have many concurrent tests, when writing a new + test, make sure you're not accidentally patching syncpoints intended for other + tests, as this will add steps to your tests. While it's not problematic from a + testing standpoint to have extra steps in your tests, it can use computing resources + unnecessarily. Using different no-op functions with different names for different + tests can help with this. + + As soon as any error or failure is detected, no more scenarios are run + and that error is propagated to the main thread. + + Args: + tests (Tuple[Callable]): Test functions to be run. These functions will not be + called with any arguments, so they must not have any required arguments. + + Returns: + Tuple[int]: A tuple of the count of the number turns for test passed in. 
Can be + used a sanity check in tests to make sure you understand what's actually + happening during a test. + """ + # Produce an initial test sequence. The fundamental question we're always trying to + # answer is "whose turn is it?" First we'll find out how many "turns" each test + # needs to complete when run serially and use that to construct a sequence of + # indexes. When a test's index appears in the sequence, it is that test's turn to + # run. We'll start by constructing a sequence that would run each test through to + # completion serially, one after the other. + test_sequence = [] + counts = [] + for index, test in enumerate(tests): + thread = _TestThread(test) + for count in itertools.count(1): # pragma: NO BRANCH + # Pragma is required because loop never finishes naturally. + thread.go() + if thread.finished: + break + + counts.append(count) + test_sequence += [index] * count + + # Now we can take that initial sequence and generate all of its permutations, + # running each one to try to uncover concurrency bugs + sequences = iter(_permutations(test_sequence)) + + # We already tested the first sequence getting our counts, so we can discard it + next(sequences) + + # Test each sequence + for test_sequence in sequences: + threads = [_TestThread(test) for test in tests] + try: + for index in test_sequence: + threads[index].go() + + # Its possible for number of turns to vary from one test run to the other, + # especially if there is some undiscovered concurrency bug. Go ahead and + # finish running each test to completion, if not already complete. + for thread in threads: + while not thread.finished: + thread.go() + + except Exception: + # If an exception occurs, we still need to let any threads that are still + # going finish up. Additional exceptions are silently ignored. + for thread in threads: + thread.finish() + raise + + return tuple(counts) + + +def syncpoint(): + """End a thread's "turn" at this point. 
+ + This will generally be inserted by `mock.patch` to replace a no-op function in + production code. See documentation for :func:`orchestrate`. + """ + conductor = _local.conductor + conductor.notify() + conductor.standby() + + +_local = threading.local() + + +class _Conductor: + """Coordinate communication between main thread and a test thread. + + Two way communicaton is maintained between the main thread and a test thread using + two synchronized queues (`queue.Queue`) each with a size of one. + """ + + def __init__(self): + self._notify = queue.Queue(1) + self._go = queue.Queue(1) + + def notify(self): + """Called from test thread to let us know it's finished or is ready for its next + turn.""" + self._notify.put(None) + + def standby(self): + """Called from test thread in order to block until told to go.""" + self._go.get() + + def wait(self): + """Called from main thread to wait for test thread to either get to the + next syncpoint or finish.""" + self._notify.get() + + def go(self): + """Called from main thread to tell test thread to go.""" + self._go.put(None) + + +class _TestThread: + """A thread for a test function.""" + + thread = None + finished = False + error = None + + def __init__(self, test): + self.test = test + self.conductor = _Conductor() + + def _run(self): + _local.conductor = self.conductor + try: + self.test() + except Exception as error: + self.error = error + finally: + self.finished = True + self.conductor.notify() + + def go(self): + if self.finished: + return + + if self.thread is None: + self.thread = threading.Thread(target=self._run) + self.thread.start() + + else: + self.conductor.go() + + self.conductor.wait() + + if self.error: + raise self.error + + def finish(self): + while not self.finished: + try: + self.go() + except Exception: + pass + + +class _permutations: + """Generates a sequence of all permutations of `sequence`. 
+ + Permutations are returned in lexicographic order using the "Generation in + lexicographic order" algorithm described in `the Wikipedia article on "Permutation" + `_. + + This implementation differs significantly from `itertools.permutations` in that the + value of individual elements is taken into account, thus eliminating redundant + orderings that would be produced by `itertools.permutations`. + + Args: + sequence (Sequence[Any]): Sequence must be finite and orderable. + + Returns: + Sequence[Sequence[Any]]: Set of all permutations of `sequence`. + """ + + def __init__(self, sequence): + self._start = tuple(sorted(sequence)) + + def __len__(self): + """Compute the number of permutations. + + Let the number of elements in a sequence N and the number of repetitions for + individual members of the sequence be n1, n2, ... nx. The number of unique + permutations is: N! / n1! / n2! / ... / nx!. + + For example, let `sequence` be [1, 2, 3, 1, 2, 3, 1, 2, 3]. The number of unique + permutations is: 9! / 3! / 3! / 3! = 1680. + + See: "Permutations of multisets" in `the Wikipedia article on "Permutation" + `_. + """ + repeats = [len(list(group)) for value, group in itertools.groupby(self._start)] + length = math.factorial(len(self._start)) + for repeat in repeats: + length /= math.factorial(repeat) + + return int(length) + + def __iter__(self): + """Iterate over permutations. + + See: "Generation in lexicographic order" algorithm described in `the Wikipedia + article on "Permutation" `_. + """ + current = list(self._start) + size = len(current) + + while True: + yield tuple(current) + + # 1. Find the largest index i such that a[i] < a[i + 1]. + for i in range(size - 2, -1, -1): + if current[i] < current[i + 1]: + break + + else: + # If no such index exists, the permutation is the last permutation. + return + + # 2. Find the largest index j greater than i such that a[i] < a[j]. 
+ for j in range(size - 1, i, -1): + if current[i] < current[j]: + break + + else: # pragma: NO COVER + raise RuntimeError("Broken algorithm") + + # 3. Swap the value of a[i] with that of a[j]. + temp = current[i] + current[i] = current[j] + current[j] = temp + + # 4. Reverse the sequence from a[i + 1] up to and including the final + # element a[n]. + current = current[: i + 1] + list(reversed(current[i + 1 :])) diff --git a/tests/unit/test_concurrency.py b/tests/unit/test_concurrency.py index 742cbc09..098e6062 100644 --- a/tests/unit/test_concurrency.py +++ b/tests/unit/test_concurrency.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections import logging try: @@ -21,131 +20,30 @@ import mock from google.cloud.ndb import _cache -from google.cloud.ndb import _eventloop from google.cloud.ndb import global_cache as global_cache_module from google.cloud.ndb import tasklets -from google.cloud.ndb import utils +from . import orchestrate log = logging.getLogger(__name__) -class Delay(object): - """A tasklet wrapper which delays the return of a tasklet. - - Used to orchestrate timing of events in async code to test particular scenarios - involving concurrency. Use with `mock.patch` to replace particular tasklets with - wrapped versions. When those tasklets are called, they will execute and then the - wrapper will hang on to the result until :meth:`Delay.advance()` is called, at which - time the tasklet's caller will receive the result. - - Args: - wrapped (tasklets.Tasklet): The tasklet to be delayed. 
- """ - - def __init__(self, wrapped): - self.wrapped = wrapped - self.info = "Delay {}".format(self.wrapped.__name__) - self._futures = collections.deque() - - @tasklets.tasklet - def __call__(self, *args, **kwargs): - future = tasklets.Future(self.info) - self._futures.append(future) - - result = yield self.wrapped(*args, **kwargs) - yield future - raise tasklets.Return(result) - - def advance(self): - """Allow a call to the wrapper to proceed. - - Calls are advanced in the order in which they were orignally made. - """ - self._futures.popleft().set_result(None) - - -def run_until(): - """Do all queued work on the event loop. - - This will allow any currently running tasklets to execute up to the point that they - hit a call to a tasklet that is delayed by :class:`Delay`. When this call is - finished, either all in progress tasklets will have been completed, or a call to - :class:`Delay.advance` will be required to move execution forward again. - """ - while _eventloop.run1(): - pass - - -def test_global_cache_concurrent_writes_692(in_context): +@mock.patch("google.cloud.ndb._cache._syncpoint_692", orchestrate.syncpoint) +def test_global_cache_concurrent_write_692(context_factory): """Regression test for #692 https://github.com/googleapis/python-ndb/issues/692 """ key = b"somekey" - @tasklets.tasklet - def run_test(): - lock1 = yield _cache.global_lock_for_write(key) - lock2, _ = yield ( - _cache.global_lock_for_write(key), - _cache.global_unlock_for_write(key, lock1), - ) - yield _cache.global_unlock_for_write(key, lock2) + @tasklets.synctasklet + def lock_unlock_key(): + lock = yield _cache.global_lock_for_write(key) + yield _cache.global_unlock_for_write(key, lock) - delay_global_get = Delay(_cache.global_get) - with mock.patch("google.cloud.ndb._cache._global_get", delay_global_get): + def run_test(): global_cache = global_cache_module._InProcessGlobalCache() - with in_context.new(global_cache=global_cache).use(): - future = run_test() - - # Run until the 
global_cache_get call in the first global_lock_for_write - # call - run_until() - utils.logging_debug(log, "zero") - - # Let the first global_cache_get call return and advance to the - # global_cache_get calls in the first call to global_unlock_for_write and - # second call to global_lock_for_write. They will have both gotten the same - # "old" value from the cache - delay_global_get.advance() - run_until() - utils.logging_debug(log, "one") - - # Let the global_cache_get call return in the second global_lock_for_write - # call. It should write a new lock value containing both locks. - delay_global_get.advance() - run_until() - utils.logging_debug(log, "two") - - # Let the global_cache_get call return in the first global_unlock_for_write - # call. Since its "old" cache value contained only the first lock, it might - # think it's done and delete the key, since as far as it's concerned, there - # are no more locks. This is the bug exposed by this test. - delay_global_get.advance() - run_until() - utils.logging_debug(log, "three") - - # Since we've fixed the bug now, what we expect it to do instead is attempt - # to write a new cache value that is a write lock value but contains no - # locks. This attempt will fail since the cache value was changed out from - # under it by the second global_lock_write call occurring in parallel. When - # this attempt fails it will call global_get again to get the new value - # containing both locks and recompute a value that only includes the second - # lock and write it. - delay_global_get.advance() - run_until() - utils.logging_debug(log, "four") - - # Now the last call to global_unlock_for_write will call global_get to get - # the current lock value with only one write lock, and then write an empty - # write lock. 
- delay_global_get.advance() - run_until() - utils.logging_debug(log, "five") - - # Make sure we can get to the end without raising an exception - future.result() + with context_factory(global_cache=global_cache).use(): + lock_unlock_key() - # Make sure the empty write lock registers as "not locked". - assert not _cache.is_locked_value(_cache.global_get(key).result()) + orchestrate.orchestrate(run_test, run_test) diff --git a/tests/unit/test_orchestrate.py b/tests/unit/test_orchestrate.py new file mode 100644 index 00000000..5a8ad621 --- /dev/null +++ b/tests/unit/test_orchestrate.py @@ -0,0 +1,219 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +import pytest + +from . 
import orchestrate + + +def test__permutations(): + sequence = [1, 2, 3, 1, 2, 3, 1, 2, 3] + permutations = orchestrate._permutations(sequence) + assert len(permutations) == 1680 + + result = list(permutations) + assert len(permutations) == len(result) # computed length matches reality + assert len(result) == len(set(result)) # no duplicates + assert result[0] == (1, 1, 1, 2, 2, 2, 3, 3, 3) + assert result[-1] == (3, 3, 3, 2, 2, 2, 1, 1, 1) + + assert list(orchestrate._permutations([1, 2, 3])) == [ + (1, 2, 3), + (1, 3, 2), + (2, 1, 3), + (2, 3, 1), + (3, 1, 2), + (3, 2, 1), + ] + + +class Test_orchestrate: + @staticmethod + def test_no_failures(): + test_calls = [] + + def make_test(name): + def test(): + test_calls.append(name) + orchestrate.syncpoint() + test_calls.append(name) + orchestrate.syncpoint() + test_calls.append(name) + + return test + + test1 = make_test("A") + test2 = make_test("B") + + permutations = orchestrate._permutations(["A", "B", "A", "B", "A", "B"]) + expected = list(itertools.chain(*permutations)) + + counts = orchestrate.orchestrate(test1, test2) + assert counts == (3, 3) + assert test_calls == expected + + @staticmethod + def test_syncpoints_decrease_after_initial_run(): + test_calls = [] + + def make_test(name): + syncpoints = [name] * 4 + + def test(): + test_calls.append(name) + if syncpoints: + orchestrate.syncpoint() + test_calls.append(syncpoints.pop()) + + return test + + test1 = make_test("A") + test2 = make_test("B") + + expected = [ + "A", + "A", + "B", + "B", + "A", + "B", + "A", + "B", + "A", + "B", + "B", + "A", + "B", + "A", + "A", + "B", + "B", + "A", + "B", + "A", + ] + + counts = orchestrate.orchestrate(test1, test2) + assert counts == (2, 2) + assert test_calls == expected + + @staticmethod + def test_syncpoints_increase_after_initial_run(): + test_calls = [] + + def make_test(name): + syncpoints = [None] * 4 + + def test(): + test_calls.append(name) + orchestrate.syncpoint() + test_calls.append(name) + + if syncpoints: 
+ syncpoints.pop() + else: + orchestrate.syncpoint() + test_calls.append(name) + + return test + + test1 = make_test("A") + test2 = make_test("B") + + expected = [ + "A", + "A", + "B", + "B", + "A", + "B", + "A", + "B", + "A", + "B", + "B", + "A", + "B", + "A", + "A", + "B", + "B", + "A", + "B", + "A", + "A", + "B", + "B", + "B", + "A", + "A", + "A", + "B", + ] + + counts = orchestrate.orchestrate(test1, test2) + assert counts == (2, 2) + assert test_calls == expected + + @staticmethod + def test_failure(): + test_calls = [] + + def make_test(name): + syncpoints = [None] * 4 + + def test(): + test_calls.append(name) + orchestrate.syncpoint() + test_calls.append(name) + + if syncpoints: + syncpoints.pop() + else: + assert True is False + + return test + + test1 = make_test("A") + test2 = make_test("B") + + expected = [ + "A", + "A", + "B", + "B", + "A", + "B", + "A", + "B", + "A", + "B", + "B", + "A", + "B", + "A", + "A", + "B", + "B", + "A", + "B", + "A", + ] + + with pytest.raises(AssertionError): + orchestrate.orchestrate(test1, test2) + + assert test_calls == expected From e9e33fe2c083dd17e20e404c13f58939e0221731 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 17 Aug 2021 10:08:49 -0400 Subject: [PATCH 2/7] Better name/locality for no-op function. --- google/cloud/ndb/_cache.py | 20 +++++++++++--------- tests/unit/test_concurrency.py | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index 083bd5d0..6b13119a 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -34,14 +34,6 @@ log = logging.getLogger(__name__) -def _syncpoint_692(): - """A no-op function meant to be patched for testing. - - Should be replaced by `orchestrate.syncpoint` using `mock.patch` during testing to - orchestrate concurrent testing scenarios. - """ - - class ContextCache(dict): """A per-context in-memory entity cache. 
@@ -682,6 +674,16 @@ def new_value(old_value): pass +def _syncpoint_update_key(): + """A no-op function meant to be patched for testing. + + Should be replaced by `orchestrate.syncpoint` using `mock.patch` during testing to + orchestrate concurrent testing scenarios. + + See: `tests.unit.test_concurrency` + """ + + @tasklets.tasklet def _update_key(key, new_value): success = False @@ -692,7 +694,7 @@ def _update_key(key, new_value): value = new_value(old_value) utils.logging_debug(log, "new value: {}", value) - _syncpoint_692() + _syncpoint_update_key() if old_value is not None: utils.logging_debug(log, "compare and swap") diff --git a/tests/unit/test_concurrency.py b/tests/unit/test_concurrency.py index 098e6062..6a67696a 100644 --- a/tests/unit/test_concurrency.py +++ b/tests/unit/test_concurrency.py @@ -28,7 +28,7 @@ log = logging.getLogger(__name__) -@mock.patch("google.cloud.ndb._cache._syncpoint_692", orchestrate.syncpoint) +@mock.patch("google.cloud.ndb._cache._syncpoint_update_key", orchestrate.syncpoint) def test_global_cache_concurrent_write_692(context_factory): """Regression test for #692 From 5ec4bda3084f9609b7e4d183e69467fad11b3c58 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 17 Aug 2021 10:14:37 -0400 Subject: [PATCH 3/7] Only call syncpoint if __debug__ is True --- google/cloud/ndb/_cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py index 6b13119a..3fdccbf2 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -694,7 +694,9 @@ def _update_key(key, new_value): value = new_value(old_value) utils.logging_debug(log, "new value: {}", value) - _syncpoint_update_key() + + if __debug__: + _syncpoint_update_key() if old_value is not None: utils.logging_debug(log, "compare and swap") From 5cd8e7065c19b60b2af7dac1d5bdb36c0d8a8be7 Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 17 Aug 2021 10:24:56 -0400 Subject: [PATCH 4/7] Python 2.7 --- 
tests/conftest.py | 2 +- tests/unit/orchestrate.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1aabecd7..7c8f0a16 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -101,7 +101,7 @@ def context(**kwargs): eventloop=TestingEventLoop(), datastore_policy=True, legacy_data=False, - **kwargs, + **kwargs ) return context diff --git a/tests/unit/orchestrate.py b/tests/unit/orchestrate.py index ed008775..b4a4a647 100644 --- a/tests/unit/orchestrate.py +++ b/tests/unit/orchestrate.py @@ -14,9 +14,13 @@ import itertools import math -import queue import threading +try: + import queue +except ImportError: # pragma: NO PY3 COVER + import Queue as queue + def orchestrate(*tests): """ From 37c33f870f0098e4bd45eaa36f9611d948cdad3b Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 17 Aug 2021 11:18:58 -0400 Subject: [PATCH 5/7] Assert something. --- tests/unit/test_concurrency.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/unit/test_concurrency.py b/tests/unit/test_concurrency.py index 6a67696a..d7800316 100644 --- a/tests/unit/test_concurrency.py +++ b/tests/unit/test_concurrency.py @@ -39,7 +39,12 @@ def test_global_cache_concurrent_write_692(context_factory): @tasklets.synctasklet def lock_unlock_key(): lock = yield _cache.global_lock_for_write(key) + cache_value = yield _cache.global_get(key) + assert lock in cache_value + yield _cache.global_unlock_for_write(key, lock) + cache_value = yield _cache.global_get(key) + assert lock not in cache_value def run_test(): global_cache = global_cache_module._InProcessGlobalCache() From 3525e04b72cf8f13ccccd6e666d4c1db9cc62bfb Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Tue, 17 Aug 2021 11:33:33 -0400 Subject: [PATCH 6/7] Test the 'real' caches, too, if present. 
--- tests/unit/test_concurrency.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_concurrency.py b/tests/unit/test_concurrency.py index d7800316..639ecdf9 100644 --- a/tests/unit/test_concurrency.py +++ b/tests/unit/test_concurrency.py @@ -13,6 +13,9 @@ # limitations under the License. import logging +import os + +import pytest try: from unittest import mock @@ -28,8 +31,25 @@ log = logging.getLogger(__name__) +def cache_factories(): + yield global_cache_module._InProcessGlobalCache + + def redis_cache(): + return global_cache_module.RedisCache.from_environment() + + if os.environ.get("REDIS_CACHE_URL"): + yield redis_cache + + def memcache_cache(): + return global_cache_module.MemcacheCache.from_environment() + + if os.environ.get("MEMCACHED_HOSTS"): + yield global_cache_module.MemcacheCache.from_environment + + +@pytest.mark.parametrize("cache_factory", cache_factories()) @mock.patch("google.cloud.ndb._cache._syncpoint_update_key", orchestrate.syncpoint) -def test_global_cache_concurrent_write_692(context_factory): +def test_global_cache_concurrent_write_692(cache_factory, context_factory): """Regression test for #692 https://github.com/googleapis/python-ndb/issues/692 @@ -47,7 +67,7 @@ def lock_unlock_key(): assert lock not in cache_value def run_test(): - global_cache = global_cache_module._InProcessGlobalCache() + global_cache = cache_factory() with context_factory(global_cache=global_cache).use(): lock_unlock_key() From 19e207d3b1429747d3f95bf800d83f3511e2ea0c Mon Sep 17 00:00:00 2001 From: Chris Rossi Date: Mon, 23 Aug 2021 18:39:04 -0400 Subject: [PATCH 7/7] Use settrace --- google/cloud/ndb/_cache.py | 15 +-- tests/unit/orchestrate.py | 160 ++++++++++++++++++++++------ tests/unit/test_concurrency.py | 14 +-- tests/unit/test_orchestrate.py | 189 ++++++++++++++++++++++++++++++--- 4 files changed, 304 insertions(+), 74 deletions(-) diff --git a/google/cloud/ndb/_cache.py b/google/cloud/ndb/_cache.py 
index 3fdccbf2..09fe9840 100644 --- a/google/cloud/ndb/_cache.py +++ b/google/cloud/ndb/_cache.py @@ -674,16 +674,6 @@ def new_value(old_value): pass -def _syncpoint_update_key(): - """A no-op function meant to be patched for testing. - - Should be replaced by `orchestrate.syncpoint` using `mock.patch` during testing to - orchestrate concurrent testing scenarios. - - See: `tests.unit.test_concurrency` - """ - - @tasklets.tasklet def _update_key(key, new_value): success = False @@ -693,10 +683,7 @@ def _update_key(key, new_value): utils.logging_debug(log, "old value: {}", old_value) value = new_value(old_value) - utils.logging_debug(log, "new value: {}", value) - - if __debug__: - _syncpoint_update_key() + utils.logging_debug(log, "new value: {}", value) # pragma: SYNCPOINT update key if old_value is not None: utils.logging_debug(log, "compare and swap") diff --git a/tests/unit/orchestrate.py b/tests/unit/orchestrate.py index b4a4a647..5ac0c01a 100644 --- a/tests/unit/orchestrate.py +++ b/tests/unit/orchestrate.py @@ -14,7 +14,9 @@ import itertools import math +import sys import threading +import tokenize try: import queue @@ -22,7 +24,7 @@ import Queue as queue -def orchestrate(*tests): +def orchestrate(*tests, **kwargs): """ Orchestrate a deterministic concurrency test. @@ -40,9 +42,8 @@ def orchestrate(*tests): `orchestrate` runs each passed in test function in its own thread. Threads then "take turns" running. Turns are defined by setting syncpoints in the code under - test. To do this, you'll write a no-op function and call it at the point where you'd - like your code to pause and give another thread a turn. In your test, then, use - `mock.patch` to replace your no-op function with :func:`syncpoint` in your test. + test, using comment containing "pragma: SYNCPOINT". `orchestrate` will scan the code + under test and add syncpoints where it finds these comments. 
For example, let's say you have the following code in production:: @@ -52,19 +53,16 @@ def hither_and_yon(destination): You've found there's a concurrency bug when two threads execute this code with the same destination, and you think that by adding a syncpoint between the calls to - `hither` and `yon` you can reproduce the problem in a regression test. First you'd, - write a no-op function, include it in your production code, and call it in - `hither_and_yon`:: - - def _syncpoint_123(): - pass + `hither` and `yon` you can reproduce the problem in a regression test. First add a + comment with "pragma: SYNCPOINT" to the code under test:: def hither_and_yon(destination): - hither(destination) - _syncpoint_123() + hither(destination) # pragma: SYNCPOINT yon(destination) - Now you can write a test to exercise `hither_and_yon` running in parallel:: + When testing with orchestrate, there will now be a syncpoint, or a pause, after the + call to `hither` and before the call to `yon`. Now you can write a test to exercise + `hither_and_yon` running in parallel:: from unittest import mock from tests.unit import orchestrate @@ -91,8 +89,8 @@ def test_hither_and_yon(): encountered when executing the test. Once the counts have been taken, `orchestrate` will construct a test sequence that - represents the all the turns taken by the passed in tests, with each value in the - sequence representing the the index of the test whose turn it is in the sequence. In + represents all of the turns taken by the passed in tests, with each value in the + sequence representing the index of the test whose turn it is in the sequence. In this example, then, it would produce:: [0, 0, 1, 1] @@ -122,25 +120,49 @@ def test_hither_and_yon(): have over 17 thousand scenarios. In general, use the least number of steps/threads you can get away with and still expose the behavior you want to correct. 
- For the same reason as above, if you have many concurrent tests, when writing a new - test, make sure you're not accidentally patching syncpoints intended for other - tests, as this will add steps to your tests. While it's not problematic from a - testing standpoint to have extra steps in your tests, it can use computing resources - unnecessarily. Using different no-op functions with different names for different - tests can help with this. + For the same reason as above, it's recommended that if you have many concurrent + tests, you name your syncpoints so that you're not accidentally using + syncpoints intended for other tests, as this will add steps to your tests. While + it's not problematic from a testing standpoint to have extra steps in your tests, it + can use computing resources unnecessarily. A name can be added to any syncpoint + after the `SYNCPOINT` keyword in the pragma definition:: + + def hither_and_yon(destination): + hither(destination) # pragma: SYNCPOINT hither and yon + yon(destination) + + In your test, then, pass that name to `orchestrate` to cause it to use only + syncpoints with that name:: + + orchestrate.orchestrate( + test_hither_and_yon, test_hither_and_yon, name="hither and yon" + ) As soon as any error or failure is detected, no more scenarios are run and that error is propagated to the main thread. + One limitation of `orchestrate` is that it cannot really be used with `coverage`, + since both tools use `sys.settrace`. Any code that needs verifiable test coverage + should have additional tests that do not use `orchestrate`, since code that is run + under orchestrate will not show up in a coverage report generated by `coverage`. + Args: tests (Tuple[Callable]): Test functions to be run. These functions will not be called with any arguments, so they must not have any required arguments. + name (Optional[str]): Only use syncpoints with the given name. If omitted, only + unnamed syncpoints will be used.
Returns: Tuple[int]: A tuple of the count of the number turns for test passed in. Can be used a sanity check in tests to make sure you understand what's actually happening during a test. """ + name = kwargs.pop("name", None) + if kwargs: + raise TypeError( + "Unexpected keyword arguments: {}".format(", ".join(kwargs.keys())) + ) + # Produce an initial test sequence. The fundamental question we're always trying to # answer is "whose turn is it?" First we'll find out how many "turns" each test # needs to complete when run serially and use that to construct a sequence of @@ -150,7 +172,7 @@ def test_hither_and_yon(): test_sequence = [] counts = [] for index, test in enumerate(tests): - thread = _TestThread(test) + thread = _TestThread(test, name) for count in itertools.count(1): # pragma: NO BRANCH # Pragma is required because loop never finishes naturally. thread.go() @@ -169,7 +191,7 @@ def test_hither_and_yon(): # Test each sequence for test_sequence in sequences: - threads = [_TestThread(test) for test in tests] + threads = [_TestThread(test, name) for test in tests] try: for index in test_sequence: threads[index].go() @@ -191,17 +213,6 @@ def test_hither_and_yon(): return tuple(counts) -def syncpoint(): - """End a thread's "turn" at this point. - - This will generally be inserted by `mock.patch` to replace a no-op function in - production code. See documentation for :func:`orchestrate`. - """ - conductor = _local.conductor - conductor.notify() - conductor.standby() - - _local = threading.local() @@ -235,18 +246,53 @@ def go(self): self._go.put(None) +_SYNCPOINTS = {} +"""Dict[str, Dict[str, Set[int]]]: Dict mapping source filename to a dict mapping +syncpoint name to set of line numbers where syncpoints with that name occur in the +source file. +""" + + +def _get_syncpoints(filename): + """Find syncpoints in a source file.
+ + Does a simple tokenization of the source file, looking for comments with "pragma: + SYNCPOINT", and populates _SYNCPOINTS using the syncpoint name and line number in + the source file. + """ + _SYNCPOINTS[filename] = syncpoints = {} + + # Use tokenize to find pragma comments + with open(filename, "r") as pyfile: + tokens = tokenize.generate_tokens(pyfile.readline) + for type, value, start, end, line in tokens: + if type == tokenize.COMMENT and "pragma: SYNCPOINT" in value: + name = value.split("SYNCPOINT", 1)[1].strip() + if not name: + name = None + + if name not in syncpoints: + syncpoints[name] = set() + + lineno, column = start + syncpoints[name].add(lineno) + + class _TestThread: """A thread for a test function.""" thread = None finished = False error = None + at_syncpoint = False - def __init__(self, test): + def __init__(self, test, name): self.test = test + self.name = name self.conductor = _Conductor() def _run(self): + sys.settrace(self._trace) _local.conductor = self.conductor try: self.test() @@ -256,6 +302,50 @@ def _run(self): self.finished = True self.conductor.notify() + def _sync(self): + # Tell main thread we're finished, for now + self.conductor.notify() + + # Wait for the main thread to tell us to go again + self.conductor.standby() + + def _trace(self, frame, event, arg): + """Argument to `sys.settrace`. + + Handles frames during test run, syncing at syncpoints, when found. + + Returns: + `None` if no more tracing is required for the function call, `self._trace` + if tracing should continue. + """ + if self.at_syncpoint: + # We hit a syncpoint on the previous call, so now we sync. 
+ self._sync() + self.at_syncpoint = False + + filename = frame.f_globals.get("__file__") + if not filename: + # Can't trace code without a source file + return + + if filename.endswith(".pyc"): + filename = filename[:-1] + + if filename not in _SYNCPOINTS: + _get_syncpoints(filename) + + syncpoints = _SYNCPOINTS[filename].get(self.name) + if not syncpoints: + # This file doesn't contain syncpoints, don't continue to trace + return + + # We've hit a syncpoint. Execute whatever line the syncpoint is on and then + # sync next time this gets called. + if frame.f_lineno in syncpoints: + self.at_syncpoint = True + + return self._trace + def go(self): if self.finished: return diff --git a/tests/unit/test_concurrency.py b/tests/unit/test_concurrency.py index 639ecdf9..6e56b6a4 100644 --- a/tests/unit/test_concurrency.py +++ b/tests/unit/test_concurrency.py @@ -17,11 +17,6 @@ import pytest -try: - from unittest import mock -except ImportError: # pragma: NO PY3 COVER - import mock - from google.cloud.ndb import _cache from google.cloud.ndb import global_cache as global_cache_module from google.cloud.ndb import tasklets @@ -31,7 +26,7 @@ log = logging.getLogger(__name__) -def cache_factories(): +def cache_factories(): # pragma: NO COVER yield global_cache_module._InProcessGlobalCache def redis_cache(): @@ -48,7 +43,6 @@ def memcache_cache(): @pytest.mark.parametrize("cache_factory", cache_factories()) -@mock.patch("google.cloud.ndb._cache._syncpoint_update_key", orchestrate.syncpoint) def test_global_cache_concurrent_write_692(cache_factory, context_factory): """Regression test for #692 @@ -57,7 +51,7 @@ def test_global_cache_concurrent_write_692(cache_factory, context_factory): key = b"somekey" @tasklets.synctasklet - def lock_unlock_key(): + def lock_unlock_key(): # pragma: NO COVER lock = yield _cache.global_lock_for_write(key) cache_value = yield _cache.global_get(key) assert lock in cache_value @@ -66,9 +60,9 @@ def lock_unlock_key(): cache_value = yield 
_cache.global_get(key) assert lock not in cache_value - def run_test(): + def run_test(): # pragma: NO COVER global_cache = cache_factory() with context_factory(global_cache=global_cache).use(): lock_unlock_key() - orchestrate.orchestrate(run_test, run_test) + orchestrate.orchestrate(run_test, run_test, name="update key") diff --git a/tests/unit/test_orchestrate.py b/tests/unit/test_orchestrate.py index 5a8ad621..60fe57e6 100644 --- a/tests/unit/test_orchestrate.py +++ b/tests/unit/test_orchestrate.py @@ -13,6 +13,12 @@ # limitations under the License. import itertools +import threading + +try: + from unittest import mock +except ImportError: # pragma: NO PY3 COVER + import mock import pytest @@ -41,16 +47,19 @@ def test__permutations(): class Test_orchestrate: + @staticmethod + def test_bad_keyword_argument(): + with pytest.raises(TypeError): + orchestrate.orchestrate(None, None, what="for?") + @staticmethod def test_no_failures(): test_calls = [] def make_test(name): - def test(): - test_calls.append(name) - orchestrate.syncpoint() - test_calls.append(name) - orchestrate.syncpoint() + def test(): # pragma: NO COVER + test_calls.append(name) # pragma: SYNCPOINT + test_calls.append(name) # pragma: SYNCPOINT test_calls.append(name) return test @@ -65,6 +74,28 @@ def test(): assert counts == (3, 3) assert test_calls == expected + @staticmethod + def test_named_syncpoints(): + test_calls = [] + + def make_test(name): + def test(): # pragma: NO COVER + test_calls.append(name) # pragma: SYNCPOINT test_named_syncpoints + test_calls.append(name) # pragma: SYNCPOINT test_named_syncpoints + test_calls.append(name) # pragma: SYNCPOINT + + return test + + test1 = make_test("A") + test2 = make_test("B") + + permutations = orchestrate._permutations(["A", "B", "A", "B", "A", "B"]) + expected = list(itertools.chain(*permutations)) + + counts = orchestrate.orchestrate(test1, test2, name="test_named_syncpoints") + assert counts == (3, 3) + assert test_calls == expected + 
@staticmethod def test_syncpoints_decrease_after_initial_run(): test_calls = [] @@ -72,11 +103,11 @@ def test_syncpoints_decrease_after_initial_run(): def make_test(name): syncpoints = [name] * 4 - def test(): + def test(): # pragma: NO COVER test_calls.append(name) if syncpoints: - orchestrate.syncpoint() - test_calls.append(syncpoints.pop()) + syncpoints.pop() # pragma: SYNCPOINT + test_calls.append(name) return test @@ -114,18 +145,20 @@ def test(): def test_syncpoints_increase_after_initial_run(): test_calls = [] + def do_nothing(): # pragma: NO COVER + pass + def make_test(name): syncpoints = [None] * 4 - def test(): - test_calls.append(name) - orchestrate.syncpoint() + def test(): # pragma: NO COVER + test_calls.append(name) # pragma: SYNCPOINT test_calls.append(name) if syncpoints: syncpoints.pop() else: - orchestrate.syncpoint() + do_nothing() # pragma: SYNCPOINT test_calls.append(name) return test @@ -175,9 +208,8 @@ def test_failure(): def make_test(name): syncpoints = [None] * 4 - def test(): - test_calls.append(name) - orchestrate.syncpoint() + def test(): # pragma: NO COVER + test_calls.append(name) # pragma: SYNCPOINT test_calls.append(name) if syncpoints: @@ -217,3 +249,130 @@ def test(): orchestrate.orchestrate(test1, test2) assert test_calls == expected + + +def test__conductor(): + conductor = orchestrate._Conductor() + items = [] + + def run_in_test_thread(): + conductor.notify() + items.append("test1") + conductor.standby() + items.append("test2") + conductor.notify() + conductor.standby() + items.append("test3") + conductor.notify() + + assert not items + test_thread = threading.Thread(target=run_in_test_thread) + + test_thread.start() + conductor.wait() + assert items == ["test1"] + + conductor.go() + conductor.wait() + assert items == ["test1", "test2"] + + conductor.go() + conductor.wait() + assert items == ["test1", "test2", "test3"] + + +def test__get_syncpoints(): # pragma: SYNCPOINT test_get_syncpoints + lines = enumerate(open(__file__, 
"r"), start=1) + for expected_lineno, line in lines: # pragma: NO BRANCH COVER + if "# pragma: SYNCPOINT test_get_syncpoints" in line: + break + + orchestrate._get_syncpoints(__file__) + syncpoints = orchestrate._SYNCPOINTS[__file__]["test_get_syncpoints"] + assert syncpoints == {expected_lineno} + + +class Test_TestThread: + @staticmethod + def test__sync(): + test_thread = orchestrate._TestThread(None, None) + test_thread.conductor = mock.Mock() + test_thread._sync() + + test_thread.conductor.notify.assert_called_once_with() + test_thread.conductor.standby.assert_called_once_with() + + @staticmethod + def test__trace_no_source_file(): + orchestrate._SYNCPOINTS.clear() + frame = mock.Mock(f_globals={}, spec=("f_globals",)) + test_thread = orchestrate._TestThread(None, None) + assert test_thread._trace(frame, None, None) is None + assert not orchestrate._SYNCPOINTS + + @staticmethod + def test__trace_this_source_file(): + orchestrate._SYNCPOINTS.clear() + frame = mock.Mock( + f_globals={"__file__": __file__}, + f_lineno=1, + spec=( + "f_globals", + "f_lineno", + ), + ) + test_thread = orchestrate._TestThread(None, None) + assert test_thread._trace(frame, None, None) == test_thread._trace + assert __file__ in orchestrate._SYNCPOINTS + + @staticmethod + def test__trace_reach_syncpoint(): + lines = enumerate(open(__file__, "r"), start=1) + for syncpoint_lineno, line in lines: # pragma: NO BRANCH COVER + if "# pragma: SYNCPOINT test_get_syncpoints" in line: + break + + orchestrate._SYNCPOINTS.clear() + frame = mock.Mock( + f_globals={"__file__": __file__}, + f_lineno=syncpoint_lineno, + spec=( + "f_globals", + "f_lineno", + ), + ) + test_thread = orchestrate._TestThread(None, "test_get_syncpoints") + test_thread._sync = mock.Mock() + assert test_thread._trace(frame, None, None) == test_thread._trace + test_thread._sync.assert_not_called() + + frame = mock.Mock( + f_globals={"__file__": __file__}, + f_lineno=syncpoint_lineno + 1, + spec=( + "f_globals", + "f_lineno", + 
), + ) + assert test_thread._trace(frame, None, None) == test_thread._trace + test_thread._sync.assert_called_once_with() + + @staticmethod + def test__trace_other_source_file_with_no_syncpoints(): + filename = orchestrate.__file__ + if filename.endswith(".pyc"): # pragma: NO COVER + filename = filename[:-1] + + orchestrate._SYNCPOINTS.clear() + frame = mock.Mock( + f_globals={"__file__": filename + "c"}, + f_lineno=1, + spec=( + "f_globals", + "f_lineno", + ), + ) + test_thread = orchestrate._TestThread(None, None) + assert test_thread._trace(frame, None, None) is None + syncpoints = orchestrate._SYNCPOINTS[filename] + assert not syncpoints