Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Figure.savefig: Add a new test for the show parameter and simplify existing tests. #3568

Merged
merged 6 commits into from
Nov 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 30 additions & 66 deletions pygmt/tests/test_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pygmt.helpers import GMTTempFile

_HAS_IPYTHON = bool(importlib.util.find_spec("IPython"))
_HAS_RIOXARRAY = bool(importlib.util.find_spec("rioxarray"))


def test_figure_region():
Expand Down Expand Up @@ -100,7 +101,7 @@ def test_figure_savefig_geotiff():
geofname = Path("test_figure_savefig_geotiff.tiff")
fig.savefig(geofname)
assert geofname.exists()
# The .pgw should not exist
# The .pgw file should not exist
assert not geofname.with_suffix(".pgw").exists()

# Save as TIFF
Expand All @@ -109,7 +110,7 @@ def test_figure_savefig_geotiff():
assert fname.exists()

# Check if a TIFF is georeferenced or not
try:
if _HAS_RIOXARRAY:
import rioxarray
from rasterio.errors import NotGeoreferencedWarning
from rasterio.transform import Affine
Expand Down Expand Up @@ -147,8 +148,6 @@ def test_figure_savefig_geotiff():
a=1.0, b=0.0, c=0.0, d=0.0, e=1.0, f=0.0
)
assert len(record) == 1
except ImportError:
pass
geofname.unlink()
fname.unlink()

Expand All @@ -170,9 +169,7 @@ def test_figure_savefig_unknown_extension():
"""
fig = Figure()
fig.basemap(region="10/70/-300/800", projection="X3i/5i", frame="af")
prefix = "test_figure_savefig_unknown_extension"
fmt = "test"
fname = f"{prefix}.{fmt}"
fname = "test_figure_savefig_unknown_extension.test"
with pytest.raises(GMTInvalidInput, match="Unknown extension '.test'."):
fig.savefig(fname)

Expand Down Expand Up @@ -223,69 +220,23 @@ def test_figure_savefig():
"""
Check if the arguments being passed to psconvert are correct.
"""
kwargs_saved = []

def mock_psconvert(*args, **kwargs): # noqa: ARG001
"""
Just record the arguments.
"""
kwargs_saved.append(kwargs)

fig = Figure()
fig.psconvert = mock_psconvert

prefix = "test_figure_savefig"

fname = f"{prefix}.png"
fig.savefig(fname)
assert kwargs_saved[-1] == {
"prefix": prefix,
"fmt": "g",
"crop": True,
"Qt": 2,
"Qg": 2,
}

fname = f"{prefix}.pdf"
fig.savefig(fname)
assert kwargs_saved[-1] == {
"prefix": prefix,
"fmt": "f",
"crop": True,
"Qt": 2,
"Qg": 2,
common_kwargs = {"prefix": prefix, "crop": True, "Qt": 2, "Qg": 2}
expected_kwargs = {
"png": {"fmt": "g", **common_kwargs},
"pdf": {"fmt": "f", **common_kwargs},
"eps": {"fmt": "e", **common_kwargs},
"kml": {"fmt": "g", "W": "+k", **common_kwargs},
}

fname = f"{prefix}.png"
fig.savefig(fname, transparent=True)
assert kwargs_saved[-1] == {
"prefix": prefix,
"fmt": "G",
"crop": True,
"Qt": 2,
"Qg": 2,
}

fname = f"{prefix}.eps"
fig.savefig(fname)
assert kwargs_saved[-1] == {
"prefix": prefix,
"fmt": "e",
"crop": True,
"Qt": 2,
"Qg": 2,
}
with patch.object(Figure, "psconvert") as mock_psconvert:
fig = Figure()
for fmt, expected in expected_kwargs.items():
fig.savefig(f"{prefix}.{fmt}")
mock_psconvert.assert_called_with(**expected)

fname = f"{prefix}.kml"
fig.savefig(fname)
assert kwargs_saved[-1] == {
"prefix": prefix,
"fmt": "g",
"crop": True,
"Qt": 2,
"Qg": 2,
"W": "+k",
}
fig.savefig(f"{prefix}.png", transparent=True)
mock_psconvert.assert_called_with(fmt="G", **common_kwargs)


def test_figure_savefig_worldfile():
Expand All @@ -309,6 +260,19 @@ def test_figure_savefig_worldfile():
fig.savefig(fname=imgfile.name, worldfile=True)


def test_figure_savefig_show():
"""
Check if the external viewer is launched when the show parameter is specified.
"""
fig = Figure()
fig.basemap(region=[0, 1, 0, 1], projection="X1c/1c", frame=True)
prefix = "test_figure_savefig_show"
with patch("pygmt.figure.launch_external_viewer") as mock_viewer:
with GMTTempFile(prefix=prefix, suffix=".png") as imgfile:
fig.savefig(imgfile.name, show=True)
assert mock_viewer.call_count == 1


@pytest.mark.skipif(not _HAS_IPYTHON, reason="run when IPython is installed")
def test_figure_show():
"""
Expand Down