Skip to content

Commit

Permalink
Use spawn in _compat_test.py to avoid fork problems (#6374)
Browse files Browse the repository at this point in the history
Review: @dstrain115
  • Loading branch information
maffoo authored Dec 5, 2023
1 parent 30b6c39 commit 7578110
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 43 deletions.
83 changes: 45 additions & 38 deletions cirq-core/cirq/_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_wrap_module():


def test_deprecate_attributes_assert_attributes_in_sys_modules():
subprocess_context(_test_deprecate_attributes_assert_attributes_in_sys_modules)()
run_in_subprocess(_test_deprecate_attributes_assert_attributes_in_sys_modules)


def _test_deprecate_attributes_assert_attributes_in_sys_modules():
Expand Down Expand Up @@ -635,42 +635,49 @@ def _type_repr_in_deprecated_module():
] + _deprecation_origin


def _trace_unhandled_exceptions(*args, queue: 'multiprocessing.Queue', func: Callable, **kwargs):
def _trace_unhandled_exceptions(*args, queue: 'multiprocessing.Queue', func: Callable):
try:
func(*args, **kwargs)
func(*args)
queue.put(None)
except BaseException as ex:
msg = str(ex)
queue.put((type(ex).__name__, msg, traceback.format_exc()))


def subprocess_context(test_func):
"""Ensures that sys.modules changes in subprocesses won't impact the parent process."""
def run_in_subprocess(test_func, *args):
"""Run a function in a subprocess.
This ensures that sys.modules changes in subprocesses won't impact the parent process.
Args:
test_func: The function to be run in a subprocess.
*args: Positional args to pass to the function.
"""

assert callable(test_func), (
"subprocess_context expects a function. Did you call the function instead of passing "
"run_in_subprocess expects a function. Did you call the function instead of passing "
"it to this method?"
)

ctx = multiprocessing.get_context('spawn' if os.name == 'nt' else 'fork')

exception = ctx.Queue()
# Use spawn to ensure subprocesses are isolated.
# See https://github.com/quantumlib/Cirq/issues/6373
ctx = multiprocessing.get_context('spawn')

def isolated_func(*args, **kwargs):
kwargs['queue'] = exception
kwargs['func'] = test_func
p = ctx.Process(target=_trace_unhandled_exceptions, args=args, kwargs=kwargs)
p.start()
p.join()
result = exception.get()
if result: # pragma: no cover
ex_type, msg, ex_trace = result
if ex_type == "Skipped":
warnings.warn(f"Skipping: {ex_type}: {msg}\n{ex_trace}")
pytest.skip(f'{ex_type}: {msg}\n{ex_trace}')
else:
pytest.fail(f'{ex_type}: {msg}\n{ex_trace}')
queue = ctx.Queue()

return isolated_func
p = ctx.Process(
target=_trace_unhandled_exceptions, args=args, kwargs={'queue': queue, 'func': test_func}
)
p.start()
p.join()
result = queue.get()
if result: # pragma: no cover
ex_type, msg, ex_trace = result
if ex_type == "Skipped":
warnings.warn(f"Skipping: {ex_type}: {msg}\n{ex_trace}")
pytest.skip(f'{ex_type}: {msg}\n{ex_trace}')
else:
pytest.fail(f'{ex_type}: {msg}\n{ex_trace}')


@mock.patch.dict(os.environ, {"CIRQ_FORCE_DEDUPE_MODULE_DEPRECATION": "1"})
Expand Down Expand Up @@ -698,7 +705,7 @@ def isolated_func(*args, **kwargs):
],
)
def test_deprecated_module(outdated_method, deprecation_messages):
subprocess_context(_test_deprecated_module_inner)(outdated_method, deprecation_messages)
run_in_subprocess(_test_deprecated_module_inner, outdated_method, deprecation_messages)


def _test_deprecated_module_inner(outdated_method, deprecation_messages):
Expand Down Expand Up @@ -736,7 +743,7 @@ def test_same_name_submodule_earlier_in_subtree():
cirq.ops.engine.calibration packages. The wrong resolution resulted in false circular
imports!
"""
subprocess_context(_test_same_name_submodule_earlier_in_subtree_inner)()
run_in_subprocess(_test_same_name_submodule_earlier_in_subtree_inner)


def _test_same_name_submodule_earlier_in_subtree_inner():
Expand All @@ -748,7 +755,7 @@ def _test_same_name_submodule_earlier_in_subtree_inner():
def test_metadata_search_path():
# to cater for metadata path finders
# https://docs.python.org/3/library/importlib.metadata.html#extending-the-search-algorithm
subprocess_context(_test_metadata_search_path_inner)()
run_in_subprocess(_test_metadata_search_path_inner)


def _test_metadata_search_path_inner(): # pragma: no cover
Expand All @@ -760,7 +767,7 @@ def _test_metadata_search_path_inner(): # pragma: no cover


def test_metadata_distributions_after_deprecated_submodule():
subprocess_context(_test_metadata_distributions_after_deprecated_submodule)()
run_in_subprocess(_test_metadata_distributions_after_deprecated_submodule)


def _test_metadata_distributions_after_deprecated_submodule():
Expand All @@ -779,7 +786,7 @@ def _test_metadata_distributions_after_deprecated_submodule():


def test_parent_spec_after_deprecated_submodule():
subprocess_context(_test_parent_spec_after_deprecated_submodule)()
run_in_subprocess(_test_parent_spec_after_deprecated_submodule)


def _test_parent_spec_after_deprecated_submodule():
Expand All @@ -791,7 +798,7 @@ def _test_parent_spec_after_deprecated_submodule():
def test_type_repr_in_new_module():
# to cater for metadata path finders
# https://docs.python.org/3/library/importlib.metadata.html#extending-the-search-algorithm
subprocess_context(_test_type_repr_in_new_module_inner)()
run_in_subprocess(_test_type_repr_in_new_module_inner)


def _test_type_repr_in_new_module_inner():
Expand Down Expand Up @@ -849,19 +856,19 @@ def _test_broken_module_3_inner():


def test_deprecated_module_error_handling_1():
subprocess_context(_test_broken_module_1_inner)()
run_in_subprocess(_test_broken_module_1_inner)


def test_deprecated_module_error_handling_2():
subprocess_context(_test_broken_module_2_inner)()
run_in_subprocess(_test_broken_module_2_inner)


def test_deprecated_module_error_handling_3():
subprocess_context(_test_broken_module_3_inner)()
run_in_subprocess(_test_broken_module_3_inner)


def test_new_module_is_top_level():
subprocess_context(_test_new_module_is_top_level_inner)()
run_in_subprocess(_test_new_module_is_top_level_inner)


def _test_new_module_is_top_level_inner():
Expand All @@ -877,7 +884,7 @@ def _test_new_module_is_top_level_inner():


def test_import_deprecated_with_no_attribute():
subprocess_context(_test_import_deprecated_with_no_attribute_inner)()
run_in_subprocess(_test_import_deprecated_with_no_attribute_inner)


def _test_import_deprecated_with_no_attribute_inner():
Expand Down Expand Up @@ -970,23 +977,23 @@ def module_repr(self, module: ModuleType) -> str:

def test_subprocess_test_failure():
with pytest.raises(Failed, match='ValueError.*this fails'):
subprocess_context(_test_subprocess_test_failure_inner)()
run_in_subprocess(_test_subprocess_test_failure_inner)


def _test_subprocess_test_failure_inner():
raise ValueError('this fails')


def test_dir_is_still_valid():
subprocess_context(_dir_is_still_valid_inner)()
run_in_subprocess(_dir_is_still_valid_inner)


def _dir_is_still_valid_inner():
"""to ensure that create_attribute=True keeps the dir(module) intact"""

import cirq.testing._compat_test_data as mod

for m in ['fake_a', 'info', 'module_a', 'sys']:
for m in ['fake_a', 'logging', 'module_a']:
assert m in dir(mod)


Expand Down
6 changes: 3 additions & 3 deletions cirq-core/cirq/testing/_compat_test_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
See cirq/_compat_test.py for the tests.
This module contains example deprecations for modules.
"""
import sys
from logging import info
import logging

from cirq import _compat

info("init:compat_test_data")
logging.info("init:compat_test_data")

# simulates a rename of a child module
# fake_a -> module_a
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/testing/_compat_test_data/module_a/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=wrong-or-nonexistent-copyright-notice
"""module_a for module deprecation tests"""

from logging import info
import logging

from cirq.testing._compat_test_data.module_a import module_b

Expand All @@ -11,4 +11,4 @@

MODULE_A_ATTRIBUTE = "module_a"

info("init:module_a")
logging.info("init:module_a")

0 comments on commit 7578110

Please sign in to comment.