diff --git a/pytest_arraydiff/plugin.py b/pytest_arraydiff/plugin.py
index da78dfb..bb7c07e 100755
--- a/pytest_arraydiff/plugin.py
+++ b/pytest_arraydiff/plugin.py
@@ -221,6 +221,37 @@ def pytest_configure(config):
                                                        reference_dir=reference_dir,
                                                        generate_dir=generate_dir,
                                                        default_format=default_format))
+    else:
+        config.pluginmanager.register(ArrayInterceptor(config))
+
+
+def generate_test_name(item):
+    """
+    Generate a unique name for this test.
+    """
+    if item.cls is not None:
+        name = f"{item.module.__name__}.{item.cls.__name__}.{item.name}"
+    else:
+        name = f"{item.module.__name__}.{item.name}"
+    return name
+
+
+def wrap_array_interceptor(plugin, item):
+    """
+    Intercept and store arrays returned by test functions.
+    """
+    # Only intercept array on marked array tests
+    if item.get_closest_marker('array_compare') is not None:
+
+        # Use the full test name as a key to ensure correct array is being retrieved
+        test_name = generate_test_name(item)
+
+        def array_interceptor(store, obj):
+            def wrapper(*args, **kwargs):
+                store.return_value[test_name] = obj(*args, **kwargs)
+            return wrapper
+
+        item.obj = array_interceptor(plugin, item.obj)
 
 
 class ArrayComparison(object):
@@ -230,12 +261,15 @@ def __init__(self, config, reference_dir=None, generate_dir=None, default_format
         self.reference_dir = reference_dir
         self.generate_dir = generate_dir
         self.default_format = default_format
+        self.return_value = {}
 
-    def pytest_runtest_setup(self, item):
+    @pytest.hookimpl(hookwrapper=True)
+    def pytest_runtest_call(self, item):
 
         compare = item.get_closest_marker('array_compare')
 
         if compare is None:
+            yield
             return
 
         file_format = compare.kwargs.get('file_format', self.default_format)
@@ -255,85 +289,95 @@ def pytest_runtest_setup(self, item):
 
         write_kwargs = compare.kwargs.get('write_kwargs', {})
 
-        original = item.function
+        reference_dir = compare.kwargs.get('reference_dir', None)
+        if reference_dir is None:
+            if self.reference_dir is None:
+                reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
+            else:
+                reference_dir = self.reference_dir
+        else:
+            if not reference_dir.startswith(('http://', 'https://')):
+                reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)
 
-        @wraps(item.function)
-        def item_function_wrapper(*args, **kwargs):
+        baseline_remote = reference_dir.startswith('http')
 
-            reference_dir = compare.kwargs.get('reference_dir', None)
-            if reference_dir is None:
-                if self.reference_dir is None:
-                    reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
-                else:
-                    reference_dir = self.reference_dir
-            else:
-                if not reference_dir.startswith(('http://', 'https://')):
-                    reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)
-
-            baseline_remote = reference_dir.startswith('http')
-
-            # Run test and get figure object
-            import inspect
-            if inspect.ismethod(original):  # method
-                array = original(*args[1:], **kwargs)
-            else:  # function
-                array = original(*args, **kwargs)
-
-            # Find test name to use as plot name
-            filename = compare.kwargs.get('filename', None)
-            if filename is None:
-                if single_reference:
-                    filename = original.__name__ + '.' + extension
-                else:
-                    filename = item.name + '.' + extension
-                    filename = filename.replace('[', '_').replace(']', '_')
-                    filename = filename.replace('_.' + extension, '.' + extension)
-
-            # What we do now depends on whether we are generating the reference
-            # files or simply running the test.
-            if self.generate_dir is None:
-
-                # Save the figure
-                result_dir = tempfile.mkdtemp()
-                test_array = os.path.abspath(os.path.join(result_dir, filename))
-
-                FORMATS[file_format].write(test_array, array, **write_kwargs)
-
-                # Find path to baseline array
-                if baseline_remote:
-                    baseline_file_ref = _download_file(reference_dir + filename)
-                else:
-                    baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))
-
-                if not os.path.exists(baseline_file_ref):
-                    raise Exception("""File not found for comparison test
-                                    Generated file:
-                                    \t{test}
-                                    This is expected for new tests.""".format(
-                        test=test_array))
-
-                # setuptools may put the baseline arrays in non-accessible places,
-                # copy to our tmpdir to be sure to keep them in case of failure
-                baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
-                shutil.copyfile(baseline_file_ref, baseline_file)
-
-                identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)
-
-                if identical:
-                    shutil.rmtree(result_dir)
-                else:
-                    raise Exception(msg)
+        # Run test and get array object
+        wrap_array_interceptor(self, item)
+        yield
+        test_name = generate_test_name(item)
+        if test_name not in self.return_value:
+            # Test function did not complete successfully
+            return
+        array = self.return_value[test_name]
+
+        # Find test name to use as plot name
+        filename = compare.kwargs.get('filename', None)
+        if filename is None:
+            filename = item.name + '.' + extension
+            if not single_reference:
+                filename = filename.replace('[', '_').replace(']', '_')
+                filename = filename.replace('_.' + extension, '.' + extension)
+
+        # What we do now depends on whether we are generating the reference
+        # files or simply running the test.
+        if self.generate_dir is None:
+
+            # Save the figure
+            result_dir = tempfile.mkdtemp()
+            test_array = os.path.abspath(os.path.join(result_dir, filename))
+            FORMATS[file_format].write(test_array, array, **write_kwargs)
+
+            # Find path to baseline array
+            if baseline_remote:
+                baseline_file_ref = _download_file(reference_dir + filename)
             else:
+                baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))
+
+            if not os.path.exists(baseline_file_ref):
+                raise Exception("""File not found for comparison test
+                                Generated file:
+                                \t{test}
+                                This is expected for new tests.""".format(
+                    test=test_array))
 
-                if not os.path.exists(self.generate_dir):
-                    os.makedirs(self.generate_dir)
+            # setuptools may put the baseline arrays in non-accessible places,
+            # copy to our tmpdir to be sure to keep them in case of failure
+            baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
+            shutil.copyfile(baseline_file_ref, baseline_file)
 
-                FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)
+            identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)
 
-                pytest.skip("Skipping test, since generating data")
+            if identical:
+                shutil.rmtree(result_dir)
+            else:
+                raise Exception(msg)
 
-        if item.cls is not None:
-            setattr(item.cls, item.function.__name__, item_function_wrapper)
         else:
-            item.obj = item_function_wrapper
+
+            if not os.path.exists(self.generate_dir):
+                os.makedirs(self.generate_dir)
+
+            FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)
+
+            pytest.skip("Skipping test, since generating data")
+
+
+class ArrayInterceptor:
+    """
+    This is used in place of ArrayComparison when the array comparison option is not used,
+    to make sure that we still intercept arrays returned by tests.
+    """
+
+    def __init__(self, config):
+        self.config = config
+        self.return_value = {}
+
+    @pytest.hookimpl(hookwrapper=True)
+    def pytest_runtest_call(self, item):
+
+        if item.get_closest_marker('array_compare') is not None:
+            wrap_array_interceptor(self, item)
+
+        yield
+        return
diff --git a/setup.cfg b/setup.cfg
index 674af9c..5bc3a65 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -50,6 +50,10 @@ testpaths = tests
 xfail_strict = true
 markers =
     array_compare: for functions using array comparison
+filterwarnings =
+    error
+    # Can be removed when min Python is >=3.8
+    ignore:distutils Version classes are deprecated
 
 [flake8]
 max-line-length = 150
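
For context, a minimal sketch (not part of the patch) of the kind of test this change affects: a function marked with array_compare whose return value is now captured by wrap_array_interceptor during pytest_runtest_call, stored in return_value under the key built by generate_test_name, and compared against the reference file after the hookwrapper's yield. The test name, array contents, and tolerance below are illustrative assumptions, not taken from the repository.

    import numpy as np
    import pytest


    @pytest.mark.array_compare(file_format='text', rtol=1e-6)
    def test_ramp():
        # Illustrative only: the wrapped item.obj stores this return value in
        # the plugin's return_value dict, keyed by "<module>.<test name>";
        # the comparison happens after the hookwrapper resumes past yield.
        return np.arange(12, dtype=float).reshape(3, 4)

Running pytest with the plugin's existing --arraydiff option exercises the comparison branch above (self.generate_dir is None), while --arraydiff-generate-path takes the branch that writes new reference files and skips the test.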