Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test inside pytest_runtest_call hook #36

Merged
merged 4 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 118 additions & 74 deletions pytest_arraydiff/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,37 @@ def pytest_configure(config):
reference_dir=reference_dir,
generate_dir=generate_dir,
default_format=default_format))
else:
config.pluginmanager.register(ArrayInterceptor(config))


def generate_test_name(item):
"""
Generate a unique name for this test.
"""
if item.cls is not None:
name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
else:
name = f"{item.module.__name__}.{item.name}"
return name


def wrap_array_interceptor(plugin, item):
"""
Intercept and store arrays returned by test functions.
"""
# Only intercept array on marked array tests
if item.get_closest_marker('array_compare') is not None:

# Use the full test name as a key to ensure correct array is being retrieved
test_name = generate_test_name(item)

def array_interceptor(store, obj):
def wrapper(*args, **kwargs):
store.return_value[test_name] = obj(*args, **kwargs)
return wrapper

item.obj = array_interceptor(plugin, item.obj)


class ArrayComparison(object):
Expand All @@ -230,12 +261,15 @@ def __init__(self, config, reference_dir=None, generate_dir=None, default_format
self.reference_dir = reference_dir
self.generate_dir = generate_dir
self.default_format = default_format
self.return_value = {}

def pytest_runtest_setup(self, item):
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item):

compare = item.get_closest_marker('array_compare')

if compare is None:
yield
return

file_format = compare.kwargs.get('file_format', self.default_format)
Expand All @@ -255,85 +289,95 @@ def pytest_runtest_setup(self, item):

write_kwargs = compare.kwargs.get('write_kwargs', {})

original = item.function
reference_dir = compare.kwargs.get('reference_dir', None)
if reference_dir is None:
if self.reference_dir is None:
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
else:
reference_dir = self.reference_dir
else:
if not reference_dir.startswith(('http://', 'https://')):
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)

@wraps(item.function)
def item_function_wrapper(*args, **kwargs):
baseline_remote = reference_dir.startswith('http')

reference_dir = compare.kwargs.get('reference_dir', None)
if reference_dir is None:
if self.reference_dir is None:
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
else:
reference_dir = self.reference_dir
else:
if not reference_dir.startswith(('http://', 'https://')):
reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)

baseline_remote = reference_dir.startswith('http')

# Run test and get figure object
import inspect
if inspect.ismethod(original): # method
array = original(*args[1:], **kwargs)
else: # function
array = original(*args, **kwargs)

# Find test name to use as plot name
filename = compare.kwargs.get('filename', None)
if filename is None:
if single_reference:
filename = original.__name__ + '.' + extension
else:
filename = item.name + '.' + extension
filename = filename.replace('[', '_').replace(']', '_')
filename = filename.replace('_.' + extension, '.' + extension)

# What we do now depends on whether we are generating the reference
# files or simply running the test.
if self.generate_dir is None:

# Save the figure
result_dir = tempfile.mkdtemp()
test_array = os.path.abspath(os.path.join(result_dir, filename))

FORMATS[file_format].write(test_array, array, **write_kwargs)

# Find path to baseline array
if baseline_remote:
baseline_file_ref = _download_file(reference_dir + filename)
else:
baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))

if not os.path.exists(baseline_file_ref):
raise Exception("""File not found for comparison test
Generated file:
\t{test}
This is expected for new tests.""".format(
test=test_array))

# setuptools may put the baseline arrays in non-accessible places,
# copy to our tmpdir to be sure to keep them in case of failure
baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
shutil.copyfile(baseline_file_ref, baseline_file)

identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)

if identical:
shutil.rmtree(result_dir)
else:
raise Exception(msg)
# Run test and get array object
wrap_array_interceptor(self, item)
yield
test_name = generate_test_name(item)
if test_name not in self.return_value:
# Test function did not complete successfully
return
array = self.return_value[test_name]

# Find test name to use as plot name
filename = compare.kwargs.get('filename', None)
if filename is None:
filename = item.name + '.' + extension
if not single_reference:
filename = filename.replace('[', '_').replace(']', '_')
filename = filename.replace('_.' + extension, '.' + extension)

# What we do now depends on whether we are generating the reference
# files or simply running the test.
if self.generate_dir is None:

# Save the figure
result_dir = tempfile.mkdtemp()
test_array = os.path.abspath(os.path.join(result_dir, filename))

FORMATS[file_format].write(test_array, array, **write_kwargs)

# Find path to baseline array
if baseline_remote:
baseline_file_ref = _download_file(reference_dir + filename)
else:
baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))

if not os.path.exists(baseline_file_ref):
raise Exception("""File not found for comparison test
Generated file:
\t{test}
This is expected for new tests.""".format(
test=test_array))

if not os.path.exists(self.generate_dir):
os.makedirs(self.generate_dir)
# setuptools may put the baseline arrays in non-accessible places,
# copy to our tmpdir to be sure to keep them in case of failure
baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
shutil.copyfile(baseline_file_ref, baseline_file)

FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)
identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)

pytest.skip("Skipping test, since generating data")
if identical:
shutil.rmtree(result_dir)
else:
raise Exception(msg)

if item.cls is not None:
setattr(item.cls, item.function.__name__, item_function_wrapper)
else:
item.obj = item_function_wrapper

if not os.path.exists(self.generate_dir):
os.makedirs(self.generate_dir)

FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)

pytest.skip("Skipping test, since generating data")


class ArrayInterceptor:
"""
This is used in place of ArrayComparison when the array comparison option is not used,
to make sure that we still intercept arrays returned by tests.
"""

def __init__(self, config):
self.config = config
self.return_value = {}

@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item):

if item.get_closest_marker('array_compare') is not None:
wrap_array_interceptor(self, item)

yield
return
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ testpaths = tests
xfail_strict = true
markers =
array_compare: for functions using array comparison
filterwarnings =
error
# Can be removed when min Python is >=3.8
ignore:distutils Version classes are deprecated
ConorMacBride marked this conversation as resolved.
Show resolved Hide resolved

[flake8]
max-line-length = 150