Skip to content

Commit

Permalink
Add new parameter 'required_cols' and remove the parameter 'required_z'
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Aug 4, 2024
1 parent 9e78da0 commit a64046a
Show file tree
Hide file tree
Showing 12 changed files with 53 additions and 39 deletions.
10 changes: 5 additions & 5 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,8 +1605,8 @@ def virtualfile_in( # noqa: PLR0912
x=None,
y=None,
z=None,
required_z=False,
required_data=True,
required_cols: int = 2,
):
"""
Store any data inside a virtual file.
Expand All @@ -1626,11 +1626,11 @@ def virtualfile_in( # noqa: PLR0912
data input.
x/y/z : 1-D arrays or None
x, y, and z columns as numpy arrays.
required_z : bool
State whether the 'z' column is required.
required_data : bool
Set to True when 'data' is required, or False when dealing with
optional virtual files. [Default is True].
required_cols
Number of required columns.
Returns
-------
Expand Down Expand Up @@ -1664,8 +1664,8 @@ def virtualfile_in( # noqa: PLR0912
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
required_cols=required_cols,
kind=kind,
)

Expand Down Expand Up @@ -1775,8 +1775,8 @@ def virtualfile_from_data(
x=x,
y=y,
z=z,
required_z=required_z,
required_data=required_data,
required_cols=3 if required_z else 2,
)

@contextlib.contextmanager
Expand Down
60 changes: 36 additions & 24 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


def _validate_data_input(
data=None, x=None, y=None, z=None, required_z=False, required_data=True, kind=None
data=None, x=None, y=None, z=None, required_data=True, required_cols=2, kind=None
):
"""
Check if the combination of data/x/y/z is valid.
Expand All @@ -44,29 +44,29 @@ def _validate_data_input(
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide both x and y.
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_z=True)
>>> _validate_data_input(x=[1, 2, 3], y=[4, 5, 6], required_cols=3)
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: Must provide x, y, and z.
>>> import numpy as np
>>> import pandas as pd
>>> import xarray as xr
>>> data = np.arange(8).reshape((4, 2))
>>> _validate_data_input(data=data, required_z=True, kind="matrix")
>>> _validate_data_input(data=data, required_cols=3, kind="matrix")
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(
... data=pd.DataFrame(data, columns=["x", "y"]),
... required_z=True,
... required_cols=3,
... kind="matrix",
... )
Traceback (most recent call last):
...
pygmt.exceptions.GMTInvalidInput: data must provide x, y, and z columns.
>>> _validate_data_input(
... data=xr.Dataset(pd.DataFrame(data, columns=["x", "y"])),
... required_z=True,
... required_cols=3,
... kind="matrix",
... )
Traceback (most recent call last):
Expand Down Expand Up @@ -94,26 +94,38 @@ def _validate_data_input(
GMTInvalidInput
If the data input is not valid.
"""
if data is None: # data is None
if x is None and y is None: # both x and y are None
if required_data: # data is not optional
if kind is None:
kind = data_kind(data, required=required_data)

if data is not None and any(v is not None for v in (x, y, z)):
raise GMTInvalidInput("Too much data. Use either data or x/y/z.")

match kind:
case "none":
if x is None and y is None: # both x and y are None
raise GMTInvalidInput("No input data provided.")
elif x is None or y is None: # either x or y is None
raise GMTInvalidInput("Must provide both x and y.")
if required_z and z is None: # both x and y are not None, now check z
raise GMTInvalidInput("Must provide x, y, and z.")
else: # data is not None
if x is not None or y is not None or z is not None:
raise GMTInvalidInput("Too much data. Use either data or x/y/z.")
# For 'matrix' kind, check if data has the required z column
if kind == "matrix" and required_z:
if hasattr(data, "shape"): # np.ndarray or pd.DataFrame
if len(data.shape) == 1 and data.shape[0] < 3:
raise GMTInvalidInput("data must provide x, y, and z columns.")
if len(data.shape) > 1 and data.shape[1] < 3:
raise GMTInvalidInput("data must provide x, y, and z columns.")
if hasattr(data, "data_vars") and len(data.data_vars) < 3: # xr.Dataset
raise GMTInvalidInput("data must provide x, y, and z columns.")
if x is None or y is None: # either x or y is None
raise GMTInvalidInput("Must provide both x and y.")
if required_cols >= 3 and z is None:
# both x and y are not None, now check z
raise GMTInvalidInput("Must provide x, y, and z.")
case "matrix": # 2-D numpy.ndarray
if (actual_cols := data.shape[1]) < required_cols:
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
raise GMTInvalidInput(msg)
case "vectors":
if hasattr(data, "items") and not hasattr(data, "to_frame"):
# Dict, pd.DataFrame, xr.Dataset
arrays = [array for _, array in data.items()]
if (actual_cols := len(arrays)) < required_cols:
msg = f"data needs {required_cols} columns but {actual_cols} column(s) are given."
raise GMTInvalidInput(msg)

# Loop over columns to make sure they're not None
for idx, array in enumerate(arrays[:required_cols]):
if array is None:
msg = f"data needs {required_cols} columns but the {idx} column is None."
raise GMTInvalidInput(msg)


def _check_encoding(
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/blockm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _blockm(
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
):
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def contour(self, data=None, x=None, y=None, z=None, **kwargs):

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl:
lib.call_module(
module="contour", args=build_arg_list(kwargs, infile=vintbl)
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/nearneighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def nearneighbor(
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl,
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
):
Expand Down
3 changes: 2 additions & 1 deletion pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def plot( # noqa: PLR0912
kind = data_kind(data)
if kind == "none": # Vectors input
data = {"x": x, "y": y}
x, y = None, None
# Parameters for vector styles
if (
kwargs.get("S") is not None
Expand Down Expand Up @@ -255,5 +256,5 @@ def plot( # noqa: PLR0912
pass

with Session() as lib:
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
with lib.virtualfile_in(check_kind="vector", data=data, x=x, y=y) as vintbl:
lib.call_module(module="plot", args=build_arg_list(kwargs, infile=vintbl))
3 changes: 2 additions & 1 deletion pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def plot3d( # noqa: PLR0912
kind = data_kind(data)
if kind == "none": # Vectors input
data = {"x": x, "y": y, "z": z}
x, y, z = None, None, None
# Parameters for vector styles
if (
kwargs.get("S") is not None
Expand Down Expand Up @@ -231,6 +232,6 @@ def plot3d( # noqa: PLR0912

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector", data=data, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl:
lib.call_module(module="plot3d", args=build_arg_list(kwargs, infile=vintbl))
2 changes: 1 addition & 1 deletion pygmt/src/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def project(
x=x,
y=y,
z=z,
required_z=False,
required_cols=2,
required_data=False,
) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def surface(data=None, x=None, y=None, z=None, outgrid: str | None = None, **kwa
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl,
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
):
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/triangulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def regular_grid(
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=False
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=2
) as vintbl,
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
):
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/wiggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,6 @@ def wiggle(

with Session() as lib:
with lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl:
lib.call_module(module="wiggle", args=build_arg_list(kwargs, infile=vintbl))
2 changes: 1 addition & 1 deletion pygmt/src/xyz2grd.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def xyz2grd(data=None, x=None, y=None, z=None, outgrid: str | None = None, **kwa
with Session() as lib:
with (
lib.virtualfile_in(
check_kind="vector", data=data, x=x, y=y, z=z, required_z=True
check_kind="vector", data=data, x=x, y=y, z=z, required_cols=3
) as vintbl,
lib.virtualfile_out(kind="grid", fname=outgrid) as voutgrd,
):
Expand Down

0 comments on commit a64046a

Please sign in to comment.