diff --git a/pygmt/tests/test_figure.py b/pygmt/tests/test_figure.py index aadaaacfad3..8c1c03dcbe6 100644 --- a/pygmt/tests/test_figure.py +++ b/pygmt/tests/test_figure.py @@ -11,7 +11,7 @@ import pytest from pygmt import Figure, set_display from pygmt.exceptions import GMTError, GMTInvalidInput -from pygmt.figure import _get_default_display_method +from pygmt.figure import SHOW_CONFIG, _get_default_display_method from pygmt.helpers import GMTTempFile try: @@ -373,12 +373,28 @@ def test_figure_display_external(): fig.show(method="external") -def test_figure_set_display_invalid(): +class TestSetDisplay: """ - Test to check if an error is raised when an invalid method is passed to set_display. + Test the pygmt.set_display method. """ - with pytest.raises(GMTInvalidInput): - set_display(method="invalid") + + def test_set_display(self): + """ + Test pygmt.set_display. + """ + current_method = SHOW_CONFIG["method"] + for method in ("notebook", "external", "none"): + set_display(method=method) + assert SHOW_CONFIG["method"] == method + set_display(method=None) + assert SHOW_CONFIG["method"] == current_method + + def test_invalid_method(self): + """ + Test if an error is raised when an invalid method is passed. + """ + with pytest.raises(GMTInvalidInput): + set_display(method="invalid") def test_figure_unsupported_xshift_yshift():