Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize check figures equal to work with pytest.marks #600

Merged
merged 5 commits into from
Sep 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,8 @@ Here's an example:
@check_figures_equal()
def test_my_plotting_case():
"Test that my plotting function works"
fig_ref = Figure()
fig_ref, fig_test = Figure(), Figure()
fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
fig_test = Figure()
fig_test.grdimage(grid, projection="W120/15c", cmap="geo")
return fig_ref, fig_test
```
Expand Down
42 changes: 31 additions & 11 deletions pygmt/helpers/testing.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
"""
Helper functions for testing.
"""

import inspect
import os
import string

from matplotlib.testing.compare import compare_images

from ..exceptions import GMTImageComparisonFailure


def check_figures_equal(*, tol=0.0, result_dir="result_images"):
def check_figures_equal(*, extensions=("png",), tol=0.0, result_dir="result_images"):
"""
Decorator for test cases that generate and compare two figures.

The decorated function must take two arguments, *fig_ref* and *fig_test*,
and draw the reference and test images on them. After the function
returns, the figures are saved and compared.
The decorated function must return two arguments, *fig_ref* and *fig_test*,
these two figures will then be saved and compared against each other.

This decorator is practically identical to matplotlib's check_figures_equal
function, but adapted for PyGMT figures. See also the original code at
Expand All @@ -25,6 +23,8 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"):

Parameters
----------
extensions : list
The extensions to test. Default is ["png"].
tol : float
The RMS threshold above which the test is considered failed.
result_dir : str
Expand Down Expand Up @@ -66,19 +66,30 @@ def check_figures_equal(*, tol=0.0, result_dir="result_images"):
... )
>>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass
"""
# pylint: disable=invalid-name
ALLOWED_CHARS = set(string.digits + string.ascii_letters + "_-[]()")
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY

def decorator(func):
import pytest

os.makedirs(result_dir, exist_ok=True)
old_sig = inspect.signature(func)

def wrapper(*args, **kwargs):
@pytest.mark.parametrize("ext", extensions)
def wrapper(*args, ext="png", request=None, **kwargs):
if "ext" in old_sig.parameters:
kwargs["ext"] = ext
if "request" in old_sig.parameters:
kwargs["request"] = request
try:
file_name = "".join(c for c in request.node.name if c in ALLOWED_CHARS)
except AttributeError: # 'NoneType' object has no attribute 'node'
file_name = func.__name__
try:
fig_ref, fig_test = func(*args, **kwargs)
ref_image_path = os.path.join(
result_dir, func.__name__ + "-expected.png"
)
test_image_path = os.path.join(result_dir, func.__name__ + ".png")
ref_image_path = os.path.join(result_dir, f"{file_name}-expected.{ext}")
test_image_path = os.path.join(result_dir, f"{file_name}.{ext}")
fig_ref.savefig(ref_image_path)
fig_test.savefig(test_image_path)

Expand Down Expand Up @@ -109,9 +120,18 @@ def wrapper(*args, **kwargs):
for param in old_sig.parameters.values()
if param.name not in {"fig_test", "fig_ref"}
]
if "ext" not in old_sig.parameters:
parameters += [inspect.Parameter("ext", KEYWORD_ONLY)]
if "request" not in old_sig.parameters:
parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
new_sig = old_sig.replace(parameters=parameters)
wrapper.__signature__ = new_sig

# reach a bit into pytest internals to hoist the marks from
# our wrapped function
new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark
wrapper.pytestmark = new_marks

return wrapper

return decorator