Skip to content

Commit

Permalink
Merge pull request #36 from ConorMacBride/update-pytest-integration
Browse files Browse the repository at this point in the history
Test inside `pytest_runtest_call` hook
  • Loading branch information
astrofrog authored Sep 11, 2023
2 parents 1af2acd + 9ffd246 commit e006ea1
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 74 deletions.
192 changes: 118 additions & 74 deletions pytest_arraydiff/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,37 @@ def pytest_configure(config):
reference_dir=reference_dir,
generate_dir=generate_dir,
default_format=default_format))
else:
config.pluginmanager.register(ArrayInterceptor(config))


def generate_test_name(item):
"""
Generate a unique name for this test.
"""
if item.cls is not None:
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
else:
name = f"{item.module.__name__}.{item.name}"
return name


def wrap_array_interceptor(plugin, item):
"""
Intercept and store arrays returned by test functions.
"""
# Only intercept array on marked array tests
if item.get_closest_marker('array_compare') is not None:

# Use the full test name as a key to ensure correct array is being retrieved
test_name = generate_test_name(item)

def array_interceptor(store, obj):
def wrapper(*args, **kwargs):
store.return_value[test_name] = obj(*args, **kwargs)
return wrapper

item.obj = array_interceptor(plugin, item.obj)


class ArrayComparison(object):
Expand All @@ -230,12 +261,15 @@ def __init__(self, config, reference_dir=None, generate_dir=None, default_format
self.reference_dir = reference_dir
self.generate_dir = generate_dir
self.default_format = default_format
self.return_value = {}

def pytest_runtest_setup(self, item):
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item):

compare = item.get_closest_marker('array_compare')

if compare is None:
yield
return

file_format = compare.kwargs.get('file_format', self.default_format)
Expand All @@ -255,85 +289,95 @@ def pytest_runtest_setup(self, item):

write_kwargs = compare.kwargs.get('write_kwargs', {})

original = item.function
reference_dir = compare.kwargs.get('reference_dir', None)
if reference_dir is None:
if self.reference_dir is None:
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
else:
reference_dir = self.reference_dir
else:
if not reference_dir.startswith(('http://', 'https://')):
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)

@wraps(item.function)
def item_function_wrapper(*args, **kwargs):
baseline_remote = reference_dir.startswith('http')

reference_dir = compare.kwargs.get('reference_dir', None)
if reference_dir is None:
if self.reference_dir is None:
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
else:
reference_dir = self.reference_dir
else:
if not reference_dir.startswith(('http://', 'https://')):
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)

baseline_remote = reference_dir.startswith('http')

# Run test and get figure object
import inspect
if inspect.ismethod(original): # method
array = original(*args[1:], **kwargs)
else: # function
array = original(*args, **kwargs)

# Find test name to use as plot name
filename = compare.kwargs.get('filename', None)
if filename is None:
if single_reference:
filename = original.__name__ + '.' + extension
else:
filename = item.name + '.' + extension
filename = filename.replace('[', '_').replace(']', '_')
filename = filename.replace('_.' + extension, '.' + extension)

# What we do now depends on whether we are generating the reference
# files or simply running the test.
if self.generate_dir is None:

# Save the figure
result_dir = tempfile.mkdtemp()
test_array = os.path.abspath(os.path.join(result_dir, filename))

FORMATS[file_format].write(test_array, array, **write_kwargs)

# Find path to baseline array
if baseline_remote:
baseline_file_ref = _download_file(reference_dir + filename)
else:
baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))

if not os.path.exists(baseline_file_ref):
raise Exception("""File not found for comparison test
Generated file:
\t{test}
This is expected for new tests.""".format(
test=test_array))

# setuptools may put the baseline arrays in non-accessible places,
# copy to our tmpdir to be sure to keep them in case of failure
baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
shutil.copyfile(baseline_file_ref, baseline_file)

identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)

if identical:
shutil.rmtree(result_dir)
else:
raise Exception(msg)
# Run test and get array object
wrap_array_interceptor(self, item)
yield
test_name = generate_test_name(item)
if test_name not in self.return_value:
# Test function did not complete successfully
return
array = self.return_value[test_name]

# Find test name to use as plot name
filename = compare.kwargs.get('filename', None)
if filename is None:
filename = item.name + '.' + extension
if not single_reference:
filename = filename.replace('[', '_').replace(']', '_')
filename = filename.replace('_.' + extension, '.' + extension)

# What we do now depends on whether we are generating the reference
# files or simply running the test.
if self.generate_dir is None:

# Save the figure
result_dir = tempfile.mkdtemp()
test_array = os.path.abspath(os.path.join(result_dir, filename))

FORMATS[file_format].write(test_array, array, **write_kwargs)

# Find path to baseline array
if baseline_remote:
baseline_file_ref = _download_file(reference_dir + filename)
else:
baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))

if not os.path.exists(baseline_file_ref):
raise Exception("""File not found for comparison test
Generated file:
\t{test}
This is expected for new tests.""".format(
test=test_array))

if not os.path.exists(self.generate_dir):
os.makedirs(self.generate_dir)
# setuptools may put the baseline arrays in non-accessible places,
# copy to our tmpdir to be sure to keep them in case of failure
baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
shutil.copyfile(baseline_file_ref, baseline_file)

FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)
identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)

pytest.skip("Skipping test, since generating data")
if identical:
shutil.rmtree(result_dir)
else:
raise Exception(msg)

if item.cls is not None:
setattr(item.cls, item.function.__name__, item_function_wrapper)
else:
item.obj = item_function_wrapper

if not os.path.exists(self.generate_dir):
os.makedirs(self.generate_dir)

FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)

pytest.skip("Skipping test, since generating data")


class ArrayInterceptor:
"""
This is used in place of ArrayComparison when the array comparison option is not used,
to make sure that we still intercept arrays returned by tests.
"""

def __init__(self, config):
self.config = config
self.return_value = {}

@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item):

if item.get_closest_marker('array_compare') is not None:
wrap_array_interceptor(self, item)

yield
return
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ testpaths = tests
xfail_strict = true
markers =
array_compare: for functions using array comparison
filterwarnings =
error
# Can be removed when min Python is >=3.8
ignore:distutils Version classes are deprecated

[flake8]
max-line-length = 150

0 comments on commit e006ea1

Please sign in to comment.