Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing None explicitly to pygmt functions Part 1 #1857

Merged
merged 16 commits into from
Apr 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pygmt/src/grd2cpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ def grd2cpt(grid, **kwargs):
``categorical=True``.
{V}
"""
if "W" in kwargs and "Ww" in kwargs:
if kwargs.get("W") is not None and kwargs.get("Ww") is not None:
raise GMTInvalidInput("Set only categorical or cyclic to True, not both.")
with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with file_context as infile:
if "H" not in kwargs: # if no output is set
if kwargs.get("H") is None: # if no output is set
arg_str = build_arg_string(kwargs, infile=infile)
if "H" in kwargs: # if output is set
else: # if output is set
outfile, kwargs["H"] = kwargs["H"], True
if not outfile or not isinstance(outfile, str):
raise GMTInvalidInput("'output' should be a proper file name.")
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/grd2xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def grd2xyz(grid, output_type="pandas", outfile=None, **kwargs):
elif outfile is None and output_type == "file":
raise GMTInvalidInput("Must specify 'outfile' for ASCII output.")

if "o" in kwargs and output_type == "pandas":
if kwargs.get("o") is not None and output_type == "pandas":
raise GMTInvalidInput(
"If 'outcols' is specified, 'output_type' must be either 'numpy'"
"or 'file'."
Expand Down
7 changes: 3 additions & 4 deletions pygmt/src/grdgradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def grdgradient(grid, **kwargs):
>>> new_grid = pygmt.grdgradient(grid=grid, azimuth=10)
"""
with GMTTempFile(suffix=".nc") as tmpfile:
if "Q" in kwargs and "N" not in kwargs:
if kwargs.get("Q") is not None and kwargs.get("N") is None:
raise GMTInvalidInput("""Must specify normalize if tiles is specified.""")
if not args_in_kwargs(args=["A", "D", "E"], kwargs=kwargs):
raise GMTInvalidInput(
Expand All @@ -174,9 +174,8 @@ def grdgradient(grid, **kwargs):
with Session() as lib:
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with file_context as infile:
if "G" not in kwargs: # if outgrid is unset, output to tempfile
kwargs.update({"G": tmpfile.name})
outgrid = kwargs["G"]
if (outgrid := kwargs.get("G")) is None:
kwargs["G"] = outgrid = tmpfile.name # output to tmpfile
lib.call_module("grdgradient", build_arg_string(kwargs, infile=infile))

return load_dataarray(outgrid) if outgrid == tmpfile.name else None
2 changes: 1 addition & 1 deletion pygmt/src/grdimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def grdimage(self, grid, **kwargs):
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with contextlib.ExitStack() as stack:
# shading using an xr.DataArray
if "I" in kwargs and data_kind(kwargs["I"]) == "grid":
if kwargs.get("I") is not None and data_kind(kwargs["I"]) == "grid":
shading_context = lib.virtualfile_from_grid(kwargs["I"])
kwargs["I"] = stack.enter_context(shading_context)

Expand Down
3 changes: 2 additions & 1 deletion pygmt/src/grdview.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def grdview(self, grid, **kwargs):
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)

with contextlib.ExitStack() as stack:
if "G" in kwargs: # deal with kwargs["G"] if drapegrid is xr.DataArray
if kwargs.get("G") is not None:
# deal with kwargs["G"] if drapegrid is xr.DataArray
drapegrid = kwargs["G"]
if data_kind(drapegrid) in ("file", "grid"):
if data_kind(drapegrid) == "grid":
Expand Down
6 changes: 3 additions & 3 deletions pygmt/src/makecpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ def makecpt(**kwargs):
``categorical=True``.
"""
with Session() as lib:
if "W" in kwargs and "Ww" in kwargs:
if kwargs.get("W") is not None and kwargs.get("Ww") is not None:
raise GMTInvalidInput("Set only categorical or cyclic to True, not both.")
if "H" not in kwargs: # if no output is set
if kwargs.get("H") is None: # if no output is set
arg_str = build_arg_string(kwargs)
elif "H" in kwargs: # if output is set
else: # if output is set
outfile, kwargs["H"] = kwargs.pop("H"), True
if not outfile or not isinstance(outfile, str):
raise GMTInvalidInput("'output' should be a proper file name.")
Expand Down
10 changes: 5 additions & 5 deletions pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,15 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
kind = data_kind(data, x, y)

extra_arrays = []
if "S" in kwargs and kwargs["S"][0] in "vV" and direction is not None:
if kwargs.get("S") is not None and kwargs["S"][0] in "vV" and direction is not None:
extra_arrays.extend(direction)
elif (
"S" not in kwargs
kwargs.get("S") is None
and kind == "geojson"
and data.geom_type.isin(["Point", "MultiPoint"]).all()
): # checking if the geometry of a geoDataFrame is Point or MultiPoint
kwargs["S"] = "s0.2c"
elif "S" not in kwargs and kind == "file" and data.endswith(".gmt"):
elif kwargs.get("S") is None and kind == "file" and data.endswith(".gmt"):
# checking that the data is a file path to set default style
try:
with open(which(data), mode="r", encoding="utf8") as file:
Expand All @@ -236,7 +236,7 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
kwargs["S"] = "s0.2c"
except FileNotFoundError:
pass
if "G" in kwargs and is_nonstr_iter(kwargs["G"]):
if kwargs.get("G") is not None and is_nonstr_iter(kwargs["G"]):
if kind != "vectors":
raise GMTInvalidInput(
"Can't use arrays for color if data is matrix or file."
Expand All @@ -251,7 +251,7 @@ def plot(self, data=None, x=None, y=None, size=None, direction=None, **kwargs):
extra_arrays.append(size)

for flag in ["I", "t"]:
if flag in kwargs and is_nonstr_iter(kwargs[flag]):
if kwargs.get(flag) is not None and is_nonstr_iter(kwargs[flag]):
if kind != "vectors":
raise GMTInvalidInput(
f"Can't use arrays for {plot.aliases[flag]} if data is matrix or file."
Expand Down
10 changes: 5 additions & 5 deletions pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,15 @@ def plot3d(
kind = data_kind(data, x, y, z)

extra_arrays = []
if "S" in kwargs and kwargs["S"][0] in "vV" and direction is not None:
if kwargs.get("S") is not None and kwargs["S"][0] in "vV" and direction is not None:
extra_arrays.extend(direction)
elif (
"S" not in kwargs
kwargs.get("S") is None
and kind == "geojson"
and data.geom_type.isin(["Point", "MultiPoint"]).all()
): # checking if the geometry of a geoDataFrame is Point or MultiPoint
kwargs["S"] = "u0.2c"
elif "S" not in kwargs and kind == "file" and data.endswith(".gmt"):
elif kwargs.get("S") is None and kind == "file" and data.endswith(".gmt"):
# checking that the data is a file path to set default style
try:
with open(which(data), mode="r", encoding="utf8") as file:
Expand All @@ -206,7 +206,7 @@ def plot3d(
kwargs["S"] = "u0.2c"
except FileNotFoundError:
pass
if "G" in kwargs and is_nonstr_iter(kwargs["G"]):
if kwargs.get("G") is not None and is_nonstr_iter(kwargs["G"]):
if kind != "vectors":
raise GMTInvalidInput(
"Can't use arrays for color if data is matrix or file."
Expand All @@ -221,7 +221,7 @@ def plot3d(
extra_arrays.append(size)

for flag in ["I", "t"]:
if flag in kwargs and is_nonstr_iter(kwargs[flag]):
if kwargs.get(flag) is not None and is_nonstr_iter(kwargs[flag]):
if kind != "vectors":
raise GMTInvalidInput(
f"Can't use arrays for {plot3d.aliases[flag]} if data is matrix or file."
Expand Down
10 changes: 5 additions & 5 deletions pygmt/src/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
by ``outfile``)
"""

if "C" not in kwargs:
if kwargs.get("C") is None:
raise GMTInvalidInput("The `center` parameter must be specified.")
if "G" not in kwargs and data is None:
if kwargs.get("G") is None and data is None:
raise GMTInvalidInput(
"The `data` parameter must be specified unless `generate` is used."
)
if "G" in kwargs and "F" in kwargs:
if kwargs.get("G") is not None and kwargs.get("F") is not None:
raise GMTInvalidInput(
"The `convention` parameter is not allowed with `generate`."
)
Expand All @@ -225,7 +225,7 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs):
if outfile is None: # Output to tmpfile if outfile is not set
outfile = tmpfile.name
with Session() as lib:
if "G" not in kwargs:
if kwargs.get("G") is None:
# Choose how data will be passed into the module
table_context = lib.virtualfile_from_data(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
Expand All @@ -240,7 +240,7 @@ def project(data=None, x=None, y=None, z=None, outfile=None, **kwargs):

# if user did not set outfile, return pd.DataFrame
if outfile == tmpfile.name:
if "G" in kwargs:
if kwargs.get("G") is not None:
column_names = list("rsp")
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
else:
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/solar.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def solar(self, terminator="d", terminator_datetime=None, **kwargs):
"""

kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
if "T" in kwargs:
if kwargs.get("T") is not None:
raise GMTInvalidInput(
"Use 'terminator' and 'terminator_datetime' instead of 'T'."
)
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def text_(
extra_arrays = []
# If an array of transparency is given, GMT will read it from
# the last numerical column per data record.
if "t" in kwargs and is_nonstr_iter(kwargs["t"]):
if kwargs.get("t") is not None and is_nonstr_iter(kwargs["t"]):
extra_arrays.append(kwargs["t"])
kwargs["t"] = ""

Expand Down
8 changes: 6 additions & 2 deletions pygmt/src/velo.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,12 @@ def velo(self, data=None, **kwargs):
"""
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't really related to the PR, but it's something that I found confusing while working on part 2. Why is kwargs passed to figure._preprocess? It doesn't look like it is used at all.

Copy link
Member Author

@weiji14 weiji14 Apr 1, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question, this ensures that velo is plotted to the right instance of the figure (e.g. if someone does fig1, fig2, etc). See also #1072 (comment). The _preprocess code is here:

pygmt/pygmt/figure.py

Lines 109 to 115 in 0aa04d7

def _preprocess(self, **kwargs):
"""
Call the ``figure`` module before each plotting command to ensure we're
plotting to this particular figure.
"""
self._activate_figure()
return kwargs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to me why we call the figure method to ensure the correct figure is activated in the session. I still don't really understand why we pass and return kwargs to/from _preprocess, or even why we call _preprocess at all rather than calling _activate_figure directly. I pushed a new branch to demonstrate the point of confusion, in which I replaced all the kwargs = self._preprocess(**kwargs) with self._activate_figure which seems more readable, a tiny bit faster, and still seems to work fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was added in #148, I'm not sure exactly why, but maybe @leouieda was thinking that there might be some other common function that needs to be called every time a plotting function happens.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was setup for any other preprocessing of kwargs that may be needed or other figure plotting methods. At the time, I put this in place to allow things like passing pygmt.Figure(projection=pygmt.Mercator(), width="25c") and then inserting the projection code into the first plot command.

It could also be useful for any bookkeeping operation that is needed, like activating the figure, activating a subplot, etc.


if "S" not in kwargs or ("S" in kwargs and not isinstance(kwargs["S"], str)):
raise GMTInvalidInput("Spec is a required argument and has to be a string.")
if kwargs.get("S") is None or (
kwargs.get("S") is not None and not isinstance(kwargs["S"], str)
):
raise GMTInvalidInput(
"The parameter `spec` is required and has to be a string."
)

if isinstance(data, np.ndarray) and not pd.api.types.is_numeric_dtype(data):
raise GMTInvalidInput(
Expand Down
2 changes: 1 addition & 1 deletion pygmt/tests/test_grd2xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_grd2xyz_format(grid):
np.testing.assert_allclose(orig_val, xyz_val)
xyz_array = grd2xyz(grid=grid, output_type="numpy")
assert isinstance(xyz_array, np.ndarray)
xyz_df = grd2xyz(grid=grid, output_type="pandas")
xyz_df = grd2xyz(grid=grid, output_type="pandas", outcols=None)
assert isinstance(xyz_df, pd.DataFrame)
assert list(xyz_df.columns) == ["lon", "lat", "z"]

Expand Down
7 changes: 6 additions & 1 deletion pygmt/tests/test_grdgradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,13 @@ def test_grdgradient_no_outgrid(grid, expected_grid):
"""
Test the azimuth and direction parameters for grdgradient with no set
outgrid.

This is a regression test for
https://github.com/GenericMappingTools/pygmt/issues/1807.
"""
result = grdgradient(grid=grid, azimuth=10, region=[-53, -49, -20, -17])
result = grdgradient(
grid=grid, azimuth=10, region=[-53, -49, -20, -17], outgrid=None
)
# check information of the output grid
assert isinstance(result, xr.DataArray)
assert result.gmt.gtype == 1 # Geographic grid
Expand Down
15 changes: 15 additions & 0 deletions pygmt/tests/test_grdimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ def test_grdimage_file():
return fig


@pytest.mark.mpl_image_compare(filename="test_grdimage_slice.png")
@pytest.mark.parametrize("shading", [None, False])
def test_grdimage_default_no_shading(grid, shading):
"""
Plot an image with no shading.

This is a regression test for
https://github.com/GenericMappingTools/pygmt/issues/1852
"""
grid_ = grid.sel(lat=slice(-30, 30))
fig = Figure()
fig.grdimage(grid_, cmap="earth", projection="M6i", shading=shading)
return fig


@check_figures_equal()
@pytest.mark.parametrize(
"shading",
Expand Down
14 changes: 14 additions & 0 deletions pygmt/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,20 @@ def test_text_varying_transparency():
return fig


@pytest.mark.mpl_image_compare(filename="test_text_input_single_filename.png")
@pytest.mark.parametrize("transparency", [None, False, 0])
def test_text_no_transparency(transparency):
"""
Add text with no transparency set.

This is a regression test for
https://github.com/GenericMappingTools/pygmt/issues/1852.
"""
fig = Figure()
fig.text(region=[10, 70, -5, 10], textfiles=POINTS_DATA, transparency=transparency)
return fig


@pytest.mark.mpl_image_compare
def test_text_nonstr_text():
"""
Expand Down