@pytest.fixture
def context_factory():
    """Return a callable that builds a fresh NDB context on each call.

    Keyword arguments are forwarded to the ``Context`` constructor so tests
    can override defaults (for example, ``context_factory(global_cache=...)``).
    """

    def make_context(**kwargs):
        # Minimal stand-in for a datastore client; spec keeps attribute
        # access honest in tests.
        client = mock.Mock(
            project="testing",
            namespace=None,
            spec=("project", "namespace"),
            stub=mock.Mock(spec=()),
        )
        return context_module.Context(
            client,
            eventloop=TestingEventLoop(),
            datastore_policy=True,
            legacy_data=False,
            **kwargs
        )

    return make_context


@pytest.fixture
def context(context_factory):
    """A default NDB context built with no overrides."""
    return context_factory()
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Deterministic orchestration of concurrency tests.

See :func:`orchestrate` for usage.
"""

import itertools
import math
import sys
import threading
import tokenize

try:
    import queue
except ImportError:  # pragma: NO PY3 COVER
    import Queue as queue


def orchestrate(*tests, **kwargs):
    """Orchestrate a deterministic concurrency test.

    Each test function runs in its own thread. Threads take turns: a turn
    runs from the previous stopping point up to the next *syncpoint*, or to
    the end of the test. Syncpoints are marked in the code under test by
    comments containing ``pragma: SYNCPOINT``; the source of every traced
    file is scanned for such comments.

    Each test is first run serially to count how many turns it needs (one
    more than the number of syncpoints it hits). That produces an initial
    schedule — for example ``[0, 0, 1, 1]`` for two tests of two turns each —
    and every distinct permutation of that schedule is then executed, so all
    interleavings of the tests' turns are exercised. The number of
    permutations grows combinatorially with the number of turns, so use the
    fewest steps/threads that still expose the behavior under test.

    A syncpoint may be named after the ``SYNCPOINT`` keyword, e.g.
    ``# pragma: SYNCPOINT hither and yon``. Pass ``name="hither and yon"``
    to use only those syncpoints; otherwise only unnamed syncpoints are
    used. Naming keeps unrelated tests from picking up each other's
    syncpoints and adding needless turns.

    As soon as any error or failure is detected, no more scenarios are run
    and that error is propagated to the main thread.

    One limitation is that this cannot be combined with ``coverage``, since
    both tools use ``sys.settrace`` — code run under ``orchestrate`` will
    not appear in a coverage report.

    Args:
        tests (Tuple[Callable]): Test functions to be run. They are called
            with no arguments, so they must not have required parameters.
        name (Optional[str]): Only use syncpoints with the given name. If
            omitted, only unnamed syncpoints are used.

    Returns:
        Tuple[int]: The number of turns each test took when run serially.
            Useful as a sanity check that the test does what you think.

    Raises:
        TypeError: If an unexpected keyword argument is passed.
    """
    # Emulate a keyword-only argument with a Python 2 compatible signature.
    name = kwargs.pop("name", None)
    if kwargs:
        raise TypeError(
            "Unexpected keyword arguments: {}".format(", ".join(kwargs.keys()))
        )

    # Run each test serially to completion, counting its turns. This builds
    # the initial schedule: all of test 0's turns, then test 1's, etc.
    test_sequence = []
    counts = []
    for index, test in enumerate(tests):
        thread = _TestThread(test, name)
        for count in itertools.count(1):  # pragma: NO BRANCH
            # Pragma is required because the loop never finishes naturally.
            thread.go()
            if thread.finished:
                break

        counts.append(count)
        test_sequence += [index] * count

    # Every permutation of the schedule is a distinct interleaving.
    sequences = iter(_permutations(test_sequence))

    # The first permutation is the sorted schedule we just ran to get the
    # counts, so skip it.
    next(sequences)

    for test_sequence in sequences:
        threads = [_TestThread(test, name) for test in tests]
        try:
            for index in test_sequence:
                threads[index].go()

            # Turn counts can vary from one run to the next, especially if
            # there is an undiscovered concurrency bug, so drive any
            # unfinished test to completion.
            for thread in threads:
                while not thread.finished:
                    thread.go()

        except Exception:
            # Let any still-running threads finish so they don't leak;
            # their errors are ignored in favor of the original one.
            for thread in threads:
                thread.finish()
            raise

    return tuple(counts)


_local = threading.local()


class _Conductor:
    """Coordinate communication between the main thread and a test thread.

    Two-way communication is maintained between the main thread and a test
    thread using two synchronized queues (`queue.Queue`), each with a size
    of one.
    """

    def __init__(self):
        self._notify = queue.Queue(1)
        self._go = queue.Queue(1)

    def notify(self):
        """Called from a test thread to signal it is finished or ready for
        its next turn."""
        self._notify.put(None)

    def standby(self):
        """Called from a test thread to block until told to go."""
        self._go.get()

    def wait(self):
        """Called from the main thread to wait for the test thread to reach
        the next syncpoint or finish."""
        self._notify.get()

    def go(self):
        """Called from the main thread to tell the test thread to go."""
        self._go.put(None)


_SYNCPOINTS = {}
"""Dict[str, Dict[Optional[str], Set[int]]]: Maps source filename to a dict
mapping syncpoint name (``None`` for unnamed syncpoints) to the set of line
numbers where syncpoints with that name occur in the source file.
"""


def _get_syncpoints(filename):
    """Find syncpoints in a source file.

    Tokenizes the source file looking for comments containing
    "pragma: SYNCPOINT" and populates `_SYNCPOINTS` with the syncpoint names
    and line numbers found.

    Args:
        filename (str): Path to the Python source file to scan.
    """
    _SYNCPOINTS[filename] = syncpoints = {}

    # Use tokenize to find pragma comments.
    with open(filename, "r") as pyfile:
        tokens = tokenize.generate_tokens(pyfile.readline)
        for tok_type, value, start, _end, _line in tokens:
            if tok_type == tokenize.COMMENT and "pragma: SYNCPOINT" in value:
                # Anything after the SYNCPOINT keyword is the optional name.
                name = value.split("SYNCPOINT", 1)[1].strip()
                if not name:
                    name = None

                lineno = start[0]
                syncpoints.setdefault(name, set()).add(lineno)


class _TestThread:
    """Runs a single test function in its own thread, one turn at a time.

    Attributes:
        thread (Optional[threading.Thread]): Created lazily on first `go`.
        finished (bool): Whether the test has run to completion.
        error (Optional[Exception]): Error raised by the test, if any.
        at_syncpoint (bool): Whether the most recently traced line was a
            syncpoint (sync on the next trace call).
    """

    thread = None
    finished = False
    error = None
    at_syncpoint = False

    def __init__(self, test, name):
        self.test = test
        self.name = name
        self.conductor = _Conductor()

    def _run(self):
        """Thread target: run the test under trace and report completion."""
        sys.settrace(self._trace)
        _local.conductor = self.conductor
        try:
            self.test()
        except Exception as error:
            self.error = error
        finally:
            self.finished = True
            self.conductor.notify()

    def _sync(self):
        # Tell the main thread we're finished for now, then wait for it to
        # tell us to go again.
        self.conductor.notify()
        self.conductor.standby()

    def _trace(self, frame, event, arg):
        """Trace function passed to `sys.settrace`.

        Handles frames during the test run, syncing at syncpoints when found.

        Returns:
            `None` if no more tracing is required for the current scope,
            `self._trace` if tracing should continue.
        """
        if self.at_syncpoint:
            # We hit a syncpoint on the previous call; its line has now
            # executed, so yield this thread's turn.
            self._sync()
            self.at_syncpoint = False

        filename = frame.f_globals.get("__file__")
        if not filename:
            # Can't trace code without a source file.
            return

        if filename.endswith(".pyc"):
            filename = filename[:-1]

        if filename not in _SYNCPOINTS:
            _get_syncpoints(filename)

        syncpoints = _SYNCPOINTS[filename].get(self.name)
        if not syncpoints:
            # This file has no matching syncpoints; stop tracing this scope.
            return

        if frame.f_lineno in syncpoints:
            # We've hit a syncpoint. Execute whatever line the syncpoint is
            # on, then sync the next time this gets called.
            self.at_syncpoint = True

        return self._trace

    def go(self):
        """Run the test for one turn, propagating any error it raised."""
        if self.finished:
            return

        if self.thread is None:
            self.thread = threading.Thread(target=self._run)
            self.thread.start()
        else:
            self.conductor.go()

        self.conductor.wait()

        if self.error:
            raise self.error

    def finish(self):
        """Drive the test to completion, ignoring any errors."""
        while not self.finished:
            try:
                self.go()
            except Exception:
                pass


class _permutations:
    """Iterable over all distinct permutations of `sequence`.

    Permutations are generated in lexicographic order using the "Generation
    in lexicographic order" algorithm described in the Wikipedia article on
    "Permutation" (https://en.wikipedia.org/wiki/Permutation).

    This differs from `itertools.permutations` in that the values of
    individual elements are taken into account, eliminating the redundant
    orderings `itertools.permutations` would produce for repeated elements.

    Args:
        sequence (Sequence[Any]): Must be finite and orderable.
    """

    def __init__(self, sequence):
        self._start = tuple(sorted(sequence))

    def __len__(self):
        """Compute the number of distinct permutations.

        For a sequence of N elements whose individual values repeat
        n1, n2, ..., nx times, the number of distinct permutations is
        N! / n1! / n2! / ... / nx! ("Permutations of multisets" in the
        Wikipedia article on "Permutation").

        For example, for [1, 2, 3, 1, 2, 3, 1, 2, 3] this is
        9! / 3! / 3! / 3! = 1680.
        """
        # Use integer floor division — exact here, since each repeat
        # factorial divides N! — instead of float division, which loses
        # precision for large sequences.
        length = math.factorial(len(self._start))
        for _value, group in itertools.groupby(self._start):
            length //= math.factorial(len(list(group)))

        return length

    def __iter__(self):
        """Iterate over permutations in lexicographic order.

        See: "Generation in lexicographic order" in the Wikipedia article on
        "Permutation" (https://en.wikipedia.org/wiki/Permutation).
        """
        current = list(self._start)
        size = len(current)

        while True:
            yield tuple(current)

            # 1. Find the largest index i such that a[i] < a[i + 1].
            for i in range(size - 2, -1, -1):
                if current[i] < current[i + 1]:
                    break
            else:
                # No such index exists: this was the last permutation.
                return

            # 2. Find the largest index j greater than i with a[i] < a[j].
            for j in range(size - 1, i, -1):
                if current[i] < current[j]:
                    break
            else:  # pragma: NO COVER
                raise RuntimeError("Broken algorithm")

            # 3. Swap the values of a[i] and a[j].
            current[i], current[j] = current[j], current[i]

            # 4. Reverse the sequence from a[i + 1] through the final element.
            current = current[: i + 1] + list(reversed(current[i + 1 :]))
def cache_factories():  # pragma: NO COVER
    """Yield a factory callable for each available global cache backend.

    The in-process cache is always available; the Redis and Memcache
    factories are yielded only when the corresponding service is configured
    in the environment.
    """
    yield global_cache_module._InProcessGlobalCache

    def redis_cache():
        return global_cache_module.RedisCache.from_environment()

    if os.environ.get("REDIS_CACHE_URL"):
        yield redis_cache

    def memcache_cache():
        return global_cache_module.MemcacheCache.from_environment()

    if os.environ.get("MEMCACHED_HOSTS"):
        # Yield the local factory, consistent with the Redis branch
        # (previously MemcacheCache.from_environment was yielded directly,
        # leaving memcache_cache defined but unused).
        yield memcache_cache


@pytest.mark.parametrize("cache_factory", cache_factories())
def test_global_cache_concurrent_write_692(cache_factory, context_factory):
    """Regression test for #692.

    Two threads concurrently lock, read, unlock, and re-read the same key;
    `orchestrate` exercises every interleaving of the "update key"
    syncpoints in `google.cloud.ndb._cache`.

    https://github.com/googleapis/python-ndb/issues/692
    """
    key = b"somekey"

    @tasklets.synctasklet
    def lock_unlock_key():  # pragma: NO COVER
        lock = yield _cache.global_lock_for_write(key)
        cache_value = yield _cache.global_get(key)
        assert lock in cache_value

        yield _cache.global_unlock_for_write(key, lock)
        cache_value = yield _cache.global_get(key)
        assert lock not in cache_value

    def run_test():  # pragma: NO COVER
        global_cache = cache_factory()
        with context_factory(global_cache=global_cache).use():
            lock_unlock_key()

    orchestrate.orchestrate(run_test, run_test, name="update key")
import itertools
import threading

try:
    from unittest import mock
except ImportError:  # pragma: NO PY3 COVER
    import mock

import pytest

from . import orchestrate


def test__permutations():
    multiset = [1, 2, 3] * 3
    permutations = orchestrate._permutations(multiset)
    assert len(permutations) == 1680

    generated = list(permutations)
    assert len(generated) == len(permutations)  # computed length matches reality
    assert len(set(generated)) == len(generated)  # no duplicates
    assert generated[0] == (1, 1, 1, 2, 2, 2, 3, 3, 3)
    assert generated[-1] == (3, 3, 3, 2, 2, 2, 1, 1, 1)

    assert list(orchestrate._permutations([1, 2, 3])) == [
        (1, 2, 3),
        (1, 3, 2),
        (2, 1, 3),
        (2, 3, 1),
        (3, 1, 2),
        (3, 2, 1),
    ]


class Test_orchestrate:
    @staticmethod
    def test_bad_keyword_argument():
        with pytest.raises(TypeError):
            orchestrate.orchestrate(None, None, what="for?")

    @staticmethod
    def test_no_failures():
        # Each test takes 3 turns (2 syncpoints), so the journal should be
        # every permutation of the two tests' turns, flattened.
        journal = []

        def make_test(label):
            def test():  # pragma: NO COVER
                journal.append(label)  # pragma: SYNCPOINT
                journal.append(label)  # pragma: SYNCPOINT
                journal.append(label)

            return test

        expected = list(itertools.chain(*orchestrate._permutations(["A", "B"] * 3)))

        counts = orchestrate.orchestrate(make_test("A"), make_test("B"))
        assert counts == (3, 3)
        assert journal == expected

    @staticmethod
    def test_named_syncpoints():
        # Only the two named syncpoints count; the trailing unnamed one is
        # ignored because a name is passed to orchestrate.
        journal = []

        def make_test(label):
            def test():  # pragma: NO COVER
                journal.append(label)  # pragma: SYNCPOINT test_named_syncpoints
                journal.append(label)  # pragma: SYNCPOINT test_named_syncpoints
                journal.append(label)  # pragma: SYNCPOINT

            return test

        expected = list(itertools.chain(*orchestrate._permutations(["A", "B"] * 3)))

        counts = orchestrate.orchestrate(
            make_test("A"), make_test("B"), name="test_named_syncpoints"
        )
        assert counts == (3, 3)
        assert journal == expected

    @staticmethod
    def test_syncpoints_decrease_after_initial_run():
        # The syncpoint line stops executing once the fuse list is spent,
        # so later runs take fewer turns than the serial counting run.
        journal = []

        def make_test(label):
            fuses = [label] * 4

            def test():  # pragma: NO COVER
                journal.append(label)
                if fuses:
                    fuses.pop()  # pragma: SYNCPOINT
                journal.append(label)

            return test

        expected = list("AABBABABABBABAABBABA")

        counts = orchestrate.orchestrate(make_test("A"), make_test("B"))
        assert counts == (2, 2)
        assert journal == expected

    @staticmethod
    def test_syncpoints_increase_after_initial_run():
        # Once the fuse list is spent, execution reaches a second syncpoint,
        # so later runs take more turns than the serial counting run.
        journal = []

        def do_nothing():  # pragma: NO COVER
            pass

        def make_test(label):
            fuses = [None] * 4

            def test():  # pragma: NO COVER
                journal.append(label)  # pragma: SYNCPOINT
                journal.append(label)

                if fuses:
                    fuses.pop()
                else:
                    do_nothing()  # pragma: SYNCPOINT
                journal.append(label)

            return test

        expected = list("AABBABABABBABAABBABAABBBAAAB")

        counts = orchestrate.orchestrate(make_test("A"), make_test("B"))
        assert counts == (2, 2)
        assert journal == expected

    @staticmethod
    def test_failure():
        # Once the fuse list is spent, the test raises; orchestrate must
        # propagate the failure after letting remaining threads finish.
        journal = []

        def make_test(label):
            fuses = [None] * 4

            def test():  # pragma: NO COVER
                journal.append(label)  # pragma: SYNCPOINT
                journal.append(label)

                if fuses:
                    fuses.pop()
                else:
                    assert True is False

            return test

        expected = list("AABBABABABBABAABBABA")

        with pytest.raises(AssertionError):
            orchestrate.orchestrate(make_test("A"), make_test("B"))

        assert journal == expected


def test__conductor():
    conductor = orchestrate._Conductor()
    events = []

    def run_in_test_thread():
        conductor.notify()
        events.append("test1")
        conductor.standby()
        events.append("test2")
        conductor.notify()
        conductor.standby()
        events.append("test3")
        conductor.notify()

    assert not events
    worker = threading.Thread(target=run_in_test_thread)

    worker.start()
    conductor.wait()
    assert events == ["test1"]

    conductor.go()
    conductor.wait()
    assert events == ["test1", "test2"]

    conductor.go()
    conductor.wait()
    assert events == ["test1", "test2", "test3"]


def test__get_syncpoints():  # pragma: SYNCPOINT test_get_syncpoints
    # Locate this function's own def line (the first line of this file that
    # contains the named pragma) and check the scanner finds exactly it.
    numbered = enumerate(open(__file__, "r"), start=1)
    for expected_lineno, text in numbered:  # pragma: NO BRANCH COVER
        if "# pragma: SYNCPOINT test_get_syncpoints" in text:
            break

    orchestrate._get_syncpoints(__file__)
    syncpoints = orchestrate._SYNCPOINTS[__file__]["test_get_syncpoints"]
    assert syncpoints == {expected_lineno}


class Test_TestThread:
    @staticmethod
    def test__sync():
        test_thread = orchestrate._TestThread(None, None)
        test_thread.conductor = mock.Mock()
        test_thread._sync()

        test_thread.conductor.notify.assert_called_once_with()
        test_thread.conductor.standby.assert_called_once_with()

    @staticmethod
    def test__trace_no_source_file():
        orchestrate._SYNCPOINTS.clear()
        frame = mock.Mock(f_globals={}, spec=("f_globals",))
        test_thread = orchestrate._TestThread(None, None)
        assert test_thread._trace(frame, None, None) is None
        assert not orchestrate._SYNCPOINTS

    @staticmethod
    def test__trace_this_source_file():
        orchestrate._SYNCPOINTS.clear()
        frame = mock.Mock(
            f_globals={"__file__": __file__},
            f_lineno=1,
            spec=("f_globals", "f_lineno"),
        )
        test_thread = orchestrate._TestThread(None, None)
        assert test_thread._trace(frame, None, None) == test_thread._trace
        assert __file__ in orchestrate._SYNCPOINTS

    @staticmethod
    def test__trace_reach_syncpoint():
        numbered = enumerate(open(__file__, "r"), start=1)
        for syncpoint_lineno, text in numbered:  # pragma: NO BRANCH COVER
            if "# pragma: SYNCPOINT test_get_syncpoints" in text:
                break

        orchestrate._SYNCPOINTS.clear()

        def make_frame(lineno):
            return mock.Mock(
                f_globals={"__file__": __file__},
                f_lineno=lineno,
                spec=("f_globals", "f_lineno"),
            )

        test_thread = orchestrate._TestThread(None, "test_get_syncpoints")
        test_thread._sync = mock.Mock()

        # Hitting the syncpoint line continues tracing without syncing yet.
        frame = make_frame(syncpoint_lineno)
        assert test_thread._trace(frame, None, None) == test_thread._trace
        test_thread._sync.assert_not_called()

        # The next traced line triggers the sync.
        frame = make_frame(syncpoint_lineno + 1)
        assert test_thread._trace(frame, None, None) == test_thread._trace
        test_thread._sync.assert_called_once_with()

    @staticmethod
    def test__trace_other_source_file_with_no_syncpoints():
        filename = orchestrate.__file__
        if filename.endswith(".pyc"):  # pragma: NO COVER
            filename = filename[:-1]

        orchestrate._SYNCPOINTS.clear()
        frame = mock.Mock(
            f_globals={"__file__": filename + "c"},
            f_lineno=1,
            spec=("f_globals", "f_lineno"),
        )
        test_thread = orchestrate._TestThread(None, None)
        assert test_thread._trace(frame, None, None) is None
        syncpoints = orchestrate._SYNCPOINTS[filename]
        assert not syncpoints