diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 1b8b5483a28..1a77659ac75 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -34,6 +34,7 @@ GMTVersionError, ) from pygmt.helpers import ( + _validate_data_input, data_kind, tempfile_from_geojson, tempfile_from_image, @@ -1684,8 +1685,15 @@ def virtualfile_in( # noqa: PLR0912 ... print(fout.read().strip()) : N = 3 <7/9> <4/6> <1/3> """ - kind = data_kind( - data, x, y, z, required_z=required_z, required_data=required_data + kind = data_kind(data, required=required_data) + _validate_data_input( + data=data, + x=x, + y=y, + z=z, + required_z=required_z, + required_data=required_data, + kind=kind, ) if check_kind: diff --git a/pygmt/helpers/__init__.py b/pygmt/helpers/__init__.py index 128b1e31a18..862abbbdd64 100644 --- a/pygmt/helpers/__init__.py +++ b/pygmt/helpers/__init__.py @@ -15,6 +15,7 @@ unique_name, ) from pygmt.helpers.utils import ( + _validate_data_input, args_in_kwargs, build_arg_list, build_arg_string, diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index dd202d2e840..2e981266575 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -12,7 +12,7 @@ import warnings import webbrowser from collections.abc import Iterable, Sequence -from typing import Any +from typing import Any, Literal import xarray as xr from pygmt.encodings import charset @@ -79,6 +79,10 @@ def _validate_data_input( Traceback (most recent call last): ... pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z. + >>> _validate_data_input(data="infile", x=[1, 2, 3], y=[4, 5, 6]) + Traceback (most recent call last): + ... + pygmt.exceptions.GMTInvalidInput: Too much data. Use either data or x/y/z. >>> _validate_data_input(data="infile", z=[7, 8, 9]) Traceback (most recent call last): ... @@ -111,21 +115,21 @@ def _validate_data_input( raise GMTInvalidInput("data must provide x, y, and z columns.") -def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data=True): +def data_kind( + data: Any = None, required: bool = True +) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]: """ - Check what kind of data is provided to a module. + Check the kind of data that is provided to a module. - Possible types: + The ``data`` argument can be in any type, but only following types are supported: - * a file name provided as 'data' - * a pathlib.PurePath object provided as 'data' - * an xarray.DataArray object provided as 'data' - * a 2-D matrix provided as 'data' - * 1-D arrays x and y (and z, optionally) - * an optional argument (None, bool, int or float) provided as 'data' - - Arguments should be ``None`` if not used. If doesn't fit any of these - categories (or fits more than one), will raise an exception. + - a string or a :class:`pathlib.PurePath` object or a sequence of them, representing + a file name or a list of file names + - a 2-D or 3-D :class:`xarray.DataArray` object + - a 2-D matrix + - None, bool, int or float type representing an optional arguments + - a geo-like Python object that implements ``__geo_interface__`` (e.g., + geopandas.GeoDataFrame or shapely.geometry) Parameters ---------- @@ -133,55 +137,47 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data Pass in either a file name or :class:`pathlib.Path` to an ASCII data table, an :class:`xarray.DataArray`, a 1-D/2-D {table-classes} or an option argument. - x/y : 1-D arrays or None - x and y columns as numpy arrays. - z : 1-D array or None - z column as numpy array. To be used optionally when x and y are given. - required_z : bool - State whether the 'z' column is required. - required_data : bool + required Set to True when 'data' is required, or False when dealing with optional virtual files. [Default is True]. Returns ------- - kind : str - One of ``'arg'``, ``'file'``, ``'grid'``, ``image``, ``'geojson'``, - ``'matrix'``, or ``'vectors'``. + kind + The data kind. Examples -------- - >>> import numpy as np >>> import xarray as xr >>> import pathlib - >>> data_kind(data=None, x=np.array([1, 2, 3]), y=np.array([4, 5, 6])) + >>> data_kind(data=None) 'vectors' - >>> data_kind(data=np.arange(10).reshape((5, 2)), x=None, y=None) + >>> data_kind(data=np.arange(10).reshape((5, 2))) 'matrix' - >>> data_kind(data="my-data-file.txt", x=None, y=None) + >>> data_kind(data="my-data-file.txt") 'file' - >>> data_kind(data=pathlib.Path("my-data-file.txt"), x=None, y=None) + >>> data_kind(data=pathlib.Path("my-data-file.txt")) 'file' - >>> data_kind(data=None, x=None, y=None, required_data=False) + >>> data_kind(data=None, required=False) 'arg' - >>> data_kind(data=2.0, x=None, y=None, required_data=False) + >>> data_kind(data=2.0, required=False) 'arg' - >>> data_kind(data=True, x=None, y=None, required_data=False) + >>> data_kind(data=True, required=False) 'arg' >>> data_kind(data=xr.DataArray(np.random.rand(4, 3))) 'grid' >>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5))) 'image' """ - # determine the data kind + kind: Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"] if isinstance(data, str | pathlib.PurePath) or ( isinstance(data, list | tuple) and all(isinstance(_file, str | pathlib.PurePath) for _file in data) ): # One or more files kind = "file" - elif isinstance(data, bool | int | float) or (data is None and not required_data): + elif isinstance(data, bool | int | float) or (data is None and not required): kind = "arg" elif isinstance(data, xr.DataArray): kind = "image" if len(data.dims) == 3 else "grid" @@ -193,15 +189,6 @@ def data_kind(data=None, x=None, y=None, z=None, required_z=False, required_data kind = "matrix" else: kind = "vectors" - _validate_data_input( - data=data, - x=x, - y=y, - z=z, - required_z=required_z, - required_data=required_data, - kind=kind, - ) return kind diff --git a/pygmt/src/plot.py b/pygmt/src/plot.py index 43b26232871..e66f08438e5 100644 --- a/pygmt/src/plot.py +++ b/pygmt/src/plot.py @@ -208,7 +208,7 @@ def plot( # noqa: PLR0912 """ kwargs = self._preprocess(**kwargs) - kind = data_kind(data, x, y) + kind = data_kind(data) extra_arrays = [] if kind == "vectors": # Add more columns for vectors input # Parameters for vector styles diff --git a/pygmt/src/plot3d.py b/pygmt/src/plot3d.py index 65d87761d5c..c86e5e259f1 100644 --- a/pygmt/src/plot3d.py +++ b/pygmt/src/plot3d.py @@ -183,7 +183,7 @@ def plot3d( # noqa: PLR0912 """ kwargs = self._preprocess(**kwargs) - kind = data_kind(data, x, y, z) + kind = data_kind(data) extra_arrays = [] if kind == "vectors": # Add more columns for vectors input diff --git a/pygmt/src/text.py b/pygmt/src/text.py index 04abf12ea3b..484f885997a 100644 --- a/pygmt/src/text.py +++ b/pygmt/src/text.py @@ -180,11 +180,11 @@ def text_( # noqa: PLR0912 # Ensure inputs are either textfiles, x/y/text, or position/text if position is None: - if (x is not None or y is not None) and textfiles is not None: + if any(v is not None for v in (x, y, text)) and textfiles is not None: raise GMTInvalidInput( "Provide either position only, or x/y pairs, or textfiles." ) - kind = data_kind(textfiles, x, y, text) + kind = data_kind(textfiles) if kind == "vectors" and text is None: raise GMTInvalidInput("Must provide text with x/y pairs") else: diff --git a/pygmt/tests/test_helpers.py b/pygmt/tests/test_helpers.py index ea966753535..98cc4c16d25 100644 --- a/pygmt/tests/test_helpers.py +++ b/pygmt/tests/test_helpers.py @@ -4,7 +4,6 @@ from pathlib import Path -import numpy as np import pytest import xarray as xr from pygmt import Figure @@ -13,7 +12,6 @@ GMTTempFile, args_in_kwargs, build_arg_list, - data_kind, kwargs_to_strings, unique_name, ) @@ -33,25 +31,6 @@ def test_load_static_earth_relief(): assert isinstance(data, xr.DataArray) -@pytest.mark.parametrize( - ("data", "x", "y"), - [ - (None, None, None), - ("data.txt", np.array([1, 2]), np.array([4, 5])), - ("data.txt", np.array([1, 2]), None), - ("data.txt", None, np.array([4, 5])), - (None, np.array([1, 2]), None), - (None, None, np.array([4, 5])), - ], -) -def test_data_kind_fails(data, x, y): - """ - Make sure data_kind raises exceptions when it should. - """ - with pytest.raises(GMTInvalidInput): - data_kind(data=data, x=x, y=y) - - def test_unique_name(): """ Make sure the names are really unique.