Skip to content

Commit

Permalink
Figure.savefig: Add a new test for the show parameter and simplify ex…
Browse files Browse the repository at this point in the history
…isting tests. (#3568)

* Get rid of try-except in test_figure_savefig_geotiff
* Simplify test_figure_savefig_unknown_extension
* Simplify test_figure_savefig with unittest.mock
* Add one more test for Figure.savefig(show=True)
  • Loading branch information
seisman authored Nov 5, 2024
1 parent f41974d commit 3ad94d9
Showing 1 changed file with 30 additions and 66 deletions.
96 changes: 30 additions & 66 deletions pygmt/tests/test_figure.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
from pygmt.helpers import GMTTempFile

_HAS_IPYTHON = bool(importlib.util.find_spec("IPython"))
_HAS_RIOXARRAY = bool(importlib.util.find_spec("rioxarray"))


def test_figure_region():
@@ -100,7 +101,7 @@ def test_figure_savefig_geotiff():
geofname = Path("test_figure_savefig_geotiff.tiff")
fig.savefig(geofname)
assert geofname.exists()
# The .pgw should not exist
# The .pgw file should not exist
assert not geofname.with_suffix(".pgw").exists()

# Save as TIFF
@@ -109,7 +110,7 @@ def test_figure_savefig_geotiff():
assert fname.exists()

# Check if a TIFF is georeferenced or not
try:
if _HAS_RIOXARRAY:
import rioxarray
from rasterio.errors import NotGeoreferencedWarning
from rasterio.transform import Affine
@@ -147,8 +148,6 @@ def test_figure_savefig_geotiff():
a=1.0, b=0.0, c=0.0, d=0.0, e=1.0, f=0.0
)
assert len(record) == 1
except ImportError:
pass
geofname.unlink()
fname.unlink()

@@ -170,9 +169,7 @@ def test_figure_savefig_unknown_extension():
"""
fig = Figure()
fig.basemap(region="10/70/-300/800", projection="X3i/5i", frame="af")
prefix = "test_figure_savefig_unknown_extension"
fmt = "test"
fname = f"{prefix}.{fmt}"
fname = "test_figure_savefig_unknown_extension.test"
with pytest.raises(GMTInvalidInput, match="Unknown extension '.test'."):
fig.savefig(fname)

@@ -223,69 +220,23 @@ def test_figure_savefig():
"""
Check if the arguments being passed to psconvert are correct.
"""
kwargs_saved = []

def mock_psconvert(*args, **kwargs): # noqa: ARG001
"""
Just record the arguments.
"""
kwargs_saved.append(kwargs)

fig = Figure()
fig.psconvert = mock_psconvert

prefix = "test_figure_savefig"

fname = f"{prefix}.png"
fig.savefig(fname)
assert kwargs_saved[-1] == {
"prefix": prefix,
"fmt": "g",
"crop": True,
"Qt": 2,
"Qg": 2,
}

fname = f"{prefix}.pdf"
fig.savefig(fname)
assert kwargs_saved[-1] == {
"prefix": prefix,
"fmt": "f",
"crop": True,
"Qt": 2,
"Qg": 2,
common_kwargs = {"prefix": prefix, "crop": True, "Qt": 2, "Qg": 2}
expected_kwargs = {
"png": {"fmt": "g", **common_kwargs},
"pdf": {"fmt": "f", **common_kwargs},
"eps": {"fmt": "e", **common_kwargs},
"kml": {"fmt": "g", "W": "+k", **common_kwargs},
}

fname = f"{prefix}.png"
fig.savefig(fname, transparent=True)
assert kwargs_saved[-1] == {
"prefix": prefix,
"fmt": "G",
"crop": True,
"Qt": 2,
"Qg": 2,
}

fname = f"{prefix}.eps"
fig.savefig(fname)
assert kwargs_saved[-1] == {
"prefix": prefix,
"fmt": "e",
"crop": True,
"Qt": 2,
"Qg": 2,
}
with patch.object(Figure, "psconvert") as mock_psconvert:
fig = Figure()
for fmt, expected in expected_kwargs.items():
fig.savefig(f"{prefix}.{fmt}")
mock_psconvert.assert_called_with(**expected)

fname = f"{prefix}.kml"
fig.savefig(fname)
assert kwargs_saved[-1] == {
"prefix": prefix,
"fmt": "g",
"crop": True,
"Qt": 2,
"Qg": 2,
"W": "+k",
}
fig.savefig(f"{prefix}.png", transparent=True)
mock_psconvert.assert_called_with(fmt="G", **common_kwargs)


def test_figure_savefig_worldfile():
@@ -309,6 +260,19 @@ def test_figure_savefig_worldfile():
fig.savefig(fname=imgfile.name, worldfile=True)


def test_figure_savefig_show():
"""
Check if the external viewer is launched when the show parameter is specified.
"""
fig = Figure()
fig.basemap(region=[0, 1, 0, 1], projection="X1c/1c", frame=True)
prefix = "test_figure_savefig_show"
with patch("pygmt.figure.launch_external_viewer") as mock_viewer:
with GMTTempFile(prefix=prefix, suffix=".png") as imgfile:
fig.savefig(imgfile.name, show=True)
assert mock_viewer.call_count == 1


@pytest.mark.skipif(not _HAS_IPYTHON, reason="run when IPython is installed")
def test_figure_show():
"""

0 comments on commit 3ad94d9

Please sign in to comment.