diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9ccb25bc1c1..f4e54db00cf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -310,8 +310,38 @@ Leave a comment in the PR and we'll help you out. ### Testing plots -We use the [pytest-mpl](https://github.com/matplotlib/pytest-mpl) plug-in to test plot -generating code. +Writing an image-based test is only slightly more difficult than a simple test. +The main consideration is that you must specify the "baseline" or reference +image, and compare it with a "generated" or test image. This is handled using +the *decorator* functions `@check_figures_equal` and +`@pytest.mark.mpl_image_compare` whose usage are further described below. + +#### Using check_figures_equal + +This approach draws the same figure using two different methods (the reference +method and the tested method), and checks that both of them are the same. +It takes two `pygmt.Figure` objects ('fig_ref' and 'fig_test'), generates a png +image, and checks for the Root Mean Square (RMS) error between the two. +Here's an example: + +```python +@check_figures_equal() +def test_my_plotting_case(fig_ref, fig_test): + "Test that my plotting function works" + fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo") + fig_test.grdimage(grid, projection="W120/15c", cmap="geo") +``` + +Note: This is the recommended way to test plots whenever possible, such as when +we want to compare a reference GMT plot created from NetCDF files with one +generated by PyGMT that passes through several layers of virtualfile machinery. +Using this method will help save space in the git repository by not having to +store baseline images as with the other method below. + +#### Using mpl_image_compare + +This method uses the [pytest-mpl](https://github.com/matplotlib/pytest-mpl) +plug-in to test plot generating code. Every time the tests are run, `pytest-mpl` compares the generated plots with known correct ones stored in `pygmt/tests/baseline`. If your test created a `pygmt.Figure` object, you can test it by adding a *decorator* and diff --git a/pygmt/exceptions.py b/pygmt/exceptions.py index d5b2c9584ef..6b72b8cb919 100644 --- a/pygmt/exceptions.py +++ b/pygmt/exceptions.py @@ -44,3 +44,9 @@ class GMTVersionError(GMTError): """ Raised when an incompatible version of GMT is being used. """ + + +class GMTImageComparisonFailure(AssertionError): + """ + Raised when a comparison between two images fails. + """ diff --git a/pygmt/helpers/testing.py b/pygmt/helpers/testing.py new file mode 100644 index 00000000000..889e7f61efd --- /dev/null +++ b/pygmt/helpers/testing.py @@ -0,0 +1,113 @@ +""" +Helper functions for testing. +""" + +import inspect +import os + +from matplotlib.testing.compare import compare_images + +from ..exceptions import GMTImageComparisonFailure +from ..figure import Figure + + +def check_figures_equal(*, tol=0.0, result_dir="result_images"): + """ + Decorator for test cases that generate and compare two figures. + + The decorated function must take two arguments, *fig_ref* and *fig_test*, + and draw the reference and test images on them. After the function + returns, the figures are saved and compared. + + This decorator is practically identical to matplotlib's check_figures_equal + function, but adapted for PyGMT figures. See also the original code at + https://matplotlib.org/3.3.1/api/testing_api.html# + matplotlib.testing.decorators.check_figures_equal + + Parameters + ---------- + tol : float + The RMS threshold above which the test is considered failed. + result_dir : str + The directory where the figures will be stored. + + Examples + -------- + + >>> import pytest + >>> import shutil + + >>> @check_figures_equal(result_dir="tmp_result_images") + ... def test_check_figures_equal(fig_ref, fig_test): + ... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True) + ... fig_test.basemap(projection="X5c", region=[0, 5, 0, 5], frame="af") + >>> test_check_figures_equal() + >>> assert len(os.listdir("tmp_result_images")) == 0 + >>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass + + >>> @check_figures_equal(result_dir="tmp_result_images") + ... def test_check_figures_unequal(fig_ref, fig_test): + ... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True) + ... fig_test.basemap(projection="X5c", region=[0, 3, 0, 3], frame=True) + >>> with pytest.raises(GMTImageComparisonFailure): + ... test_check_figures_unequal() + >>> for suffix in ["", "-expected", "-failed-diff"]: + ... assert os.path.exists( + ... os.path.join( + ... "tmp_result_images", + ... f"test_check_figures_unequal{suffix}.png", + ... ) + ... ) + >>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass + """ + + def decorator(func): + + os.makedirs(result_dir, exist_ok=True) + old_sig = inspect.signature(func) + + def wrapper(*args, **kwargs): + try: + fig_ref = Figure() + fig_test = Figure() + func(*args, fig_ref=fig_ref, fig_test=fig_test, **kwargs) + ref_image_path = os.path.join( + result_dir, func.__name__ + "-expected.png" + ) + test_image_path = os.path.join(result_dir, func.__name__ + ".png") + fig_ref.savefig(ref_image_path) + fig_test.savefig(test_image_path) + + # Code below is adapted for PyGMT, and is originally based on + # matplotlib.testing.decorators._raise_on_image_difference + err = compare_images( + expected=ref_image_path, + actual=test_image_path, + tol=tol, + in_decorator=True, + ) + if err is None: # Images are the same + os.remove(ref_image_path) + os.remove(test_image_path) + else: # Images are not the same + for key in ["actual", "expected", "diff"]: + err[key] = os.path.relpath(err[key]) + raise GMTImageComparisonFailure( + "images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s " + % err + ) + finally: + del fig_ref + del fig_test + + parameters = [ + param + for param in old_sig.parameters.values() + if param.name not in {"fig_test", "fig_ref"} + ] + new_sig = old_sig.replace(parameters=parameters) + wrapper.__signature__ = new_sig + + return wrapper + + return decorator diff --git a/pygmt/tests/test_grdimage.py b/pygmt/tests/test_grdimage.py index 45be76f6d01..37b5fca822a 100644 --- a/pygmt/tests/test_grdimage.py +++ b/pygmt/tests/test_grdimage.py @@ -2,12 +2,13 @@ Test Figure.grdimage """ import numpy as np -import xarray as xr import pytest +import xarray as xr from .. import Figure -from ..exceptions import GMTInvalidInput from ..datasets import load_earth_relief +from ..exceptions import GMTInvalidInput +from ..helpers.testing import check_figures_equal @pytest.fixture(scope="module", name="grid") @@ -93,3 +94,12 @@ def test_grdimage_over_dateline(xrgrid): xrgrid.gmt.gtype = 1 # geographic coordinate system fig.grdimage(grid=xrgrid, region="g", projection="A0/0/1c", V="i") return fig + + +@check_figures_equal() +def test_grdimage_central_longitude(grid, fig_ref, fig_test): + """ + Test that plotting a grid centred at different longitudes/meridians work. + """ + fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo") + fig_test.grdimage(grid, projection="W120/15c", cmap="geo")