From d913c860c0bb35468738b3d789a916ed09f225a4 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 2 Dec 2024 17:18:25 +0800 Subject: [PATCH 01/23] Add pygmt.read to read a dataset/grid/image into pandas.DataFrame/xarray.DataArray --- doc/api/index.rst | 1 + pygmt/__init__.py | 1 + pygmt/datasets/load_remote_dataset.py | 14 +---- pygmt/datasets/samples.py | 6 +- pygmt/helpers/testing.py | 11 ++-- pygmt/src/__init__.py | 1 + pygmt/src/read.py | 86 +++++++++++++++++++++++++++ 7 files changed, 99 insertions(+), 21 deletions(-) create mode 100644 pygmt/src/read.py diff --git a/doc/api/index.rst b/doc/api/index.rst index 01fac7d7a89..4b720a12353 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -172,6 +172,7 @@ Input/output :toctree: generated load_dataarray + read GMT Defaults ------------ diff --git a/pygmt/__init__.py b/pygmt/__init__.py index f6d1040851f..dba5cf38d1c 100644 --- a/pygmt/__init__.py +++ b/pygmt/__init__.py @@ -54,6 +54,7 @@ makecpt, nearneighbor, project, + read, select, sph2grd, sphdistance, diff --git a/pygmt/datasets/load_remote_dataset.py b/pygmt/datasets/load_remote_dataset.py index 168a93583b2..ca5db7d818e 100644 --- a/pygmt/datasets/load_remote_dataset.py +++ b/pygmt/datasets/load_remote_dataset.py @@ -6,10 +6,9 @@ from typing import Any, Literal, NamedTuple import xarray as xr -from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import build_arg_list, kwargs_to_strings -from pygmt.src import which +from pygmt.helpers import kwargs_to_strings +from pygmt.src import read, which class Resolution(NamedTuple): @@ -443,14 +442,7 @@ def _load_remote_dataset( fname = f"@{prefix}_{resolution}_{reg}" kind = "image" if name in {"earth_day", "earth_night"} else "grid" - kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[kind]} - with Session() as lib: - with lib.virtualfile_out(kind=kind) as voutgrd: - lib.call_module( - module="read", - args=[fname, voutgrd, *build_arg_list(kwdict)], - ) - grid = lib.virtualfile_to_raster(kind=kind, outgrid=None, vfname=voutgrd) + grid = read(fname, kind=kind, region=region) # Full path to the grid if not tiled grids. source = which(fname, download="a") if not resinfo.tiled else None diff --git a/pygmt/datasets/samples.py b/pygmt/datasets/samples.py index 3739ee630b8..594eae26a54 100644 --- a/pygmt/datasets/samples.py +++ b/pygmt/datasets/samples.py @@ -8,8 +8,7 @@ import pandas as pd import xarray as xr from pygmt.exceptions import GMTInvalidInput -from pygmt.io import load_dataarray -from pygmt.src import which +from pygmt.src import read, which def _load_japan_quakes() -> pd.DataFrame: @@ -203,8 +202,7 @@ def _load_earth_relief_holes() -> xr.DataArray: The Earth relief grid. Coordinates are latitude and longitude in degrees. Relief is in meters. """ - fname = which("@earth_relief_20m_holes.grd", download="c") - return load_dataarray(fname, engine="netcdf4") + return read("@earth_relief_20m_holes.grd", kind="grid") class GMTSampleData(NamedTuple): diff --git a/pygmt/helpers/testing.py b/pygmt/helpers/testing.py index 29dfd08df19..990ce5aa418 100644 --- a/pygmt/helpers/testing.py +++ b/pygmt/helpers/testing.py @@ -7,9 +7,9 @@ import string from pathlib import Path +import xarray as xr from pygmt.exceptions import GMTImageComparisonFailure -from pygmt.io import load_dataarray -from pygmt.src import which +from pygmt.src import read def check_figures_equal(*, extensions=("png",), tol=0.0, result_dir="result_images"): @@ -144,17 +144,16 @@ def wrapper(*args, ext="png", request=None, **kwargs): return decorator -def load_static_earth_relief(): +def load_static_earth_relief() -> xr.DataArray: """ Load the static_earth_relief file for internal testing. Returns ------- - data : xarray.DataArray + data A grid of Earth relief for internal tests. """ - fname = which("@static_earth_relief.nc", download="c") - return load_dataarray(fname) + return read("@static_earth_relief.nc", kind="grid") # type: ignore[return-value] def skip_if_no(package): diff --git a/pygmt/src/__init__.py b/pygmt/src/__init__.py index e4db7321963..01cea172417 100644 --- a/pygmt/src/__init__.py +++ b/pygmt/src/__init__.py @@ -41,6 +41,7 @@ from pygmt.src.plot3d import plot3d from pygmt.src.project import project from pygmt.src.psconvert import psconvert +from pygmt.src.read import read from pygmt.src.rose import rose from pygmt.src.select import select from pygmt.src.shift_origin import shift_origin diff --git a/pygmt/src/read.py b/pygmt/src/read.py new file mode 100644 index 00000000000..aca92dbf5e3 --- /dev/null +++ b/pygmt/src/read.py @@ -0,0 +1,86 @@ +""" +Read data from files +""" + +from typing import Literal + +import pandas as pd +import xarray as xr +from pygmt.clib import Session +from pygmt.helpers import build_arg_list, fmt_docstring, kwargs_to_strings, use_alias + + +@fmt_docstring +@use_alias(R="region") +@kwargs_to_strings(R="sequence") +def read( + file, + kind: Literal["dataset", "grid", "image"], + **kwargs, +) -> pd.DataFrame | xr.DataArray: + """ + Read a dataset, grid, or image from a file and return the appropriate object. + + For datasets, it returns a :class:`pandas.DataFrame`. For grids and images, it + returns a :class:`xarray.DataArray`. + + Parameters + ---------- + file + The file name to read. + kind + The kind of data to read. Valid values are ``"dataset"``, ``"grid"``, and + ``"image"``. + {region} + + For datasets, the following keyword arguments are supported: + + column_names + A list of column names. + header + Row number containing column names. ``header=None`` means not to parse the + column names from table header. Ignored if the row number is larger than the + number of headers in the table. + dtype + Data type. Can be a single type for all columns or a dictionary mapping column + names to types. + index_col + Column to set as index. + + Returns + ------- + Return type depends on the ``kind`` argument: + + - ``"dataset"``: :class:`pandas.DataFrame` + - ``"grid"`` or ``"image"``: :class:`xarray.DataArray` + + + Examples + -------- + + Read a dataset into a :class:`pandas.DataFrame` object: + + >>> from pygmt import read + >>> df = read("@hotspots.txt", kind="dataset") + >>> type(df) + + + Read a grid into an :class:`xarray.DataArray` object: + >>> dataarray = read("@earth_relief_01d", kind="grid") + >>> type(dataarray) + + """ + kwdict = { + "R": kwargs.get("R"), + "T": {"dataset": "d", "grid": "g", "image": "i"}[kind], + } + + with Session() as lib: + with lib.virtualfile_out(kind=kind) as voutfile: + lib.call_module("read", args=[file, voutfile, *build_arg_list(kwdict)]) + + match kind: + case "dataset": + return lib.virtualfile_to_dataset(vfname=voutfile, **kwargs) + case "grid" | "image": + return lib.virtualfile_to_raster(vfname=voutfile, kind=kind) From f456bf81412f1ee8d9037e61196005c8700f1cf3 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Thu, 5 Dec 2024 23:17:52 +0800 Subject: [PATCH 02/23] Set GMT accessor --- pygmt/src/read.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pygmt/src/read.py b/pygmt/src/read.py index aca92dbf5e3..bc528df1db9 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -83,4 +83,6 @@ def read( case "dataset": return lib.virtualfile_to_dataset(vfname=voutfile, **kwargs) case "grid" | "image": - return lib.virtualfile_to_raster(vfname=voutfile, kind=kind) + raster = lib.virtualfile_to_raster(vfname=voutfile, kind=kind) + _ = raster.gmt # Load GMTDataArray accessor information + return raster From c3cbb6ed3e2453a8cc630635e2b1521195bbde6f Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Thu, 5 Dec 2024 23:34:53 +0800 Subject: [PATCH 03/23] Need to set 'source' encoding to make GMT accessor work --- pygmt/src/read.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pygmt/src/read.py b/pygmt/src/read.py index bc528df1db9..e7ec650bad2 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -8,13 +8,14 @@ import xarray as xr from pygmt.clib import Session from pygmt.helpers import build_arg_list, fmt_docstring, kwargs_to_strings, use_alias +from pygmt.src.which import which @fmt_docstring @use_alias(R="region") @kwargs_to_strings(R="sequence") def read( - file, + file: str, kind: Literal["dataset", "grid", "image"], **kwargs, ) -> pd.DataFrame | xr.DataArray: @@ -84,5 +85,6 @@ def read( return lib.virtualfile_to_dataset(vfname=voutfile, **kwargs) case "grid" | "image": raster = lib.virtualfile_to_raster(vfname=voutfile, kind=kind) - _ = raster.gmt # Load GMTDataArray accessor information + raster.encoding["source"] = which(fname=file) # Add "source" encoding + _ = raster.gmt # Load GMTDataArray accessor information return raster From 1dd97c6add374c149e749226c326f2c4265ff07b Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Thu, 5 Dec 2024 23:49:39 +0800 Subject: [PATCH 04/23] Fix the source encoding --- pygmt/src/read.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pygmt/src/read.py b/pygmt/src/read.py index e7ec650bad2..e7bee0cb9b9 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -85,6 +85,10 @@ def read( return lib.virtualfile_to_dataset(vfname=voutfile, **kwargs) case "grid" | "image": raster = lib.virtualfile_to_raster(vfname=voutfile, kind=kind) - raster.encoding["source"] = which(fname=file) # Add "source" encoding + # Add "source" encoding + source = which(fname=file) + raster.encoding["source"] = ( + source if isinstance(source, str) else source[0] + ) _ = raster.gmt # Load GMTDataArray accessor information return raster From 7790ea3fd7e2c6c6e418b6db4aa1b2f7918cf536 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Fri, 6 Dec 2024 00:15:23 +0800 Subject: [PATCH 05/23] No need to set the source encoding in load_remote_dataset.py --- pygmt/datasets/load_remote_dataset.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pygmt/datasets/load_remote_dataset.py b/pygmt/datasets/load_remote_dataset.py index ca5db7d818e..c721d344f85 100644 --- a/pygmt/datasets/load_remote_dataset.py +++ b/pygmt/datasets/load_remote_dataset.py @@ -8,7 +8,7 @@ import xarray as xr from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import kwargs_to_strings -from pygmt.src import read, which +from pygmt.src import read class Resolution(NamedTuple): @@ -444,12 +444,6 @@ def _load_remote_dataset( kind = "image" if name in {"earth_day", "earth_night"} else "grid" grid = read(fname, kind=kind, region=region) - # Full path to the grid if not tiled grids. - source = which(fname, download="a") if not resinfo.tiled else None - # Manually add source to xarray.DataArray encoding to make the GMT accessors work. - if source: - grid.encoding["source"] = source - # Add some metadata to the grid grid.attrs["description"] = dataset.description if dataset.units: From e588008b13bdde835ca1b8de87917c83f829526b Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Fri, 6 Dec 2024 23:29:40 +0800 Subject: [PATCH 06/23] Revert changes in pygmt/datasets/load_remote_dataset.py --- pygmt/datasets/load_remote_dataset.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pygmt/datasets/load_remote_dataset.py b/pygmt/datasets/load_remote_dataset.py index c721d344f85..168a93583b2 100644 --- a/pygmt/datasets/load_remote_dataset.py +++ b/pygmt/datasets/load_remote_dataset.py @@ -6,9 +6,10 @@ from typing import Any, Literal, NamedTuple import xarray as xr +from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import kwargs_to_strings -from pygmt.src import read +from pygmt.helpers import build_arg_list, kwargs_to_strings +from pygmt.src import which class Resolution(NamedTuple): @@ -442,7 +443,20 @@ def _load_remote_dataset( fname = f"@{prefix}_{resolution}_{reg}" kind = "image" if name in {"earth_day", "earth_night"} else "grid" - grid = read(fname, kind=kind, region=region) + kwdict = {"R": region, "T": {"grid": "g", "image": "i"}[kind]} + with Session() as lib: + with lib.virtualfile_out(kind=kind) as voutgrd: + lib.call_module( + module="read", + args=[fname, voutgrd, *build_arg_list(kwdict)], + ) + grid = lib.virtualfile_to_raster(kind=kind, outgrid=None, vfname=voutgrd) + + # Full path to the grid if not tiled grids. + source = which(fname, download="a") if not resinfo.tiled else None + # Manually add source to xarray.DataArray encoding to make the GMT accessors work. + if source: + grid.encoding["source"] = source # Add some metadata to the grid grid.attrs["description"] = dataset.description From 40d12eeeca4df13be7f92002652590a65426c08d Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Fri, 6 Dec 2024 23:45:03 +0800 Subject: [PATCH 07/23] Improve docstring in pygmt/helpers/testing.py --- pygmt/helpers/testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/helpers/testing.py b/pygmt/helpers/testing.py index 990ce5aa418..8be5f4ee78e 100644 --- a/pygmt/helpers/testing.py +++ b/pygmt/helpers/testing.py @@ -146,7 +146,7 @@ def wrapper(*args, ext="png", request=None, **kwargs): def load_static_earth_relief() -> xr.DataArray: """ - Load the static_earth_relief file for internal testing. + Load the static_earth_relief.nc file for internal testing. Returns ------- From fa1021de45cdbfc92caa1da22b5d82e0ecbb91e5 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Fri, 6 Dec 2024 23:47:45 +0800 Subject: [PATCH 08/23] Improve docstrinbgs --- pygmt/src/read.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pygmt/src/read.py b/pygmt/src/read.py index e7bee0cb9b9..e20de46ba6a 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -34,7 +34,7 @@ def read( ``"image"``. {region} - For datasets, the following keyword arguments are supported: + For datasets, the following keyword arguments are supported: column_names A list of column names. @@ -50,15 +50,15 @@ def read( Returns ------- - Return type depends on the ``kind`` argument: + data + Return type depends on the ``kind`` argument: - - ``"dataset"``: :class:`pandas.DataFrame` - - ``"grid"`` or ``"image"``: :class:`xarray.DataArray` + - ``"dataset"``: :class:`pandas.DataFrame` + - ``"grid"`` or ``"image"``: :class:`xarray.DataArray` Examples -------- - Read a dataset into a :class:`pandas.DataFrame` object: >>> from pygmt import read @@ -67,6 +67,7 @@ def read( Read a grid into an :class:`xarray.DataArray` object: + >>> dataarray = read("@earth_relief_01d", kind="grid") >>> type(dataarray) From c378225a3515ee5b5df53c0a00cc543ad6d7e16a Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Sun, 8 Dec 2024 17:22:16 +0800 Subject: [PATCH 09/23] Get rid of decorators --- pygmt/src/read.py | 53 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/pygmt/src/read.py b/pygmt/src/read.py index e20de46ba6a..39922fb389d 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -1,29 +1,36 @@ """ -Read data from files +Read a file into an appropriate object. """ -from typing import Literal +from collections.abc import Mapping, Sequence +from typing import Any, Literal import pandas as pd import xarray as xr from pygmt.clib import Session -from pygmt.helpers import build_arg_list, fmt_docstring, kwargs_to_strings, use_alias +from pygmt.helpers import ( + build_arg_list, + fmt_docstring, + is_nonstr_iter, +) from pygmt.src.which import which @fmt_docstring -@use_alias(R="region") -@kwargs_to_strings(R="sequence") def read( file: str, kind: Literal["dataset", "grid", "image"], - **kwargs, + region: Sequence[float] | str | None = None, + header: int | None = None, + column_names: pd.Index | None = None, + dtype: type | Mapping[Any, type] | None = None, + index_col: str | int | None = None, ) -> pd.DataFrame | xr.DataArray: """ Read a dataset, grid, or image from a file and return the appropriate object. - For datasets, it returns a :class:`pandas.DataFrame`. For grids and images, it - returns a :class:`xarray.DataArray`. + The returned object is a :class:`pandas.DataFrame` for datasets, and + :class:`xarray.DataArray` for grids and images. Parameters ---------- @@ -32,10 +39,8 @@ def read( kind The kind of data to read. Valid values are ``"dataset"``, ``"grid"``, and ``"image"``. - {region} - - For datasets, the following keyword arguments are supported: - + region + The region of interest. Only data within this region will be read. column_names A list of column names. header @@ -43,8 +48,8 @@ def read( column names from table header. Ignored if the row number is larger than the number of headers in the table. dtype - Data type. Can be a single type for all columns or a dictionary mapping column - names to types. + Data type. Can be a single type for all columns or a dictionary mapping + column names to types. index_col Column to set as index. @@ -72,8 +77,18 @@ def read( >>> type(dataarray) """ + + if kind != "dataset" and any( + v is not None for v in [column_names, header, dtype, index_col] + ): + msg = ( + "Only the 'dataset' kind supports the 'column_names', 'header', " + "'dtype', and 'index_col' arguments." + ) + raise ValueError(msg) + kwdict = { - "R": kwargs.get("R"), + "R": "/".join(f"{v}" for v in region) if is_nonstr_iter(region) else region, "T": {"dataset": "d", "grid": "g", "image": "i"}[kind], } @@ -83,7 +98,13 @@ def read( match kind: case "dataset": - return lib.virtualfile_to_dataset(vfname=voutfile, **kwargs) + return lib.virtualfile_to_dataset( + vfname=voutfile, + column_names=column_names, + header=header, + dtype=dtype, + index_col=index_col, + ) case "grid" | "image": raster = lib.virtualfile_to_raster(vfname=voutfile, kind=kind) # Add "source" encoding From 7b749e0c45c433f2b423ebc1e6096d902553b046 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Sun, 8 Dec 2024 17:23:01 +0800 Subject: [PATCH 10/23] Improve comment --- pygmt/src/read.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pygmt/src/read.py b/pygmt/src/read.py index 39922fb389d..0fa67852bdb 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -32,6 +32,9 @@ def read( The returned object is a :class:`pandas.DataFrame` for datasets, and :class:`xarray.DataArray` for grids and images. + For datasets, keyword arguments ``column_names``, ``header``, ``dtype``, and + ``index_col`` are supported. + Parameters ---------- file From 8befa58ad1b3511316063aa91bfe139803626460 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Sun, 8 Dec 2024 17:23:39 +0800 Subject: [PATCH 11/23] Get rid of the fmt_docstring alias --- pygmt/src/read.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pygmt/src/read.py b/pygmt/src/read.py index 0fa67852bdb..eea030cb2e6 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -8,15 +8,10 @@ import pandas as pd import xarray as xr from pygmt.clib import Session -from pygmt.helpers import ( - build_arg_list, - fmt_docstring, - is_nonstr_iter, -) +from pygmt.helpers import build_arg_list, is_nonstr_iter from pygmt.src.which import which -@fmt_docstring def read( file: str, kind: Literal["dataset", "grid", "image"], From a7587528c07445adc721ec75f74edaa98d3e0ead Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 12:06:46 +0800 Subject: [PATCH 12/23] Fix type hints issue with overload --- pygmt/src/read.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/pygmt/src/read.py b/pygmt/src/read.py index eea030cb2e6..29daa6582e2 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -3,7 +3,7 @@ """ from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any, Literal, overload import pandas as pd import xarray as xr @@ -12,15 +12,29 @@ from pygmt.src.which import which +@overload def read( file: str, - kind: Literal["dataset", "grid", "image"], + kind: Literal["dataset"], region: Sequence[float] | str | None = None, header: int | None = None, column_names: pd.Index | None = None, dtype: type | Mapping[Any, type] | None = None, index_col: str | int | None = None, -) -> pd.DataFrame | xr.DataArray: +) -> pd.DataFrame: ... + + +@overload +def read( + file: str, + kind: Literal["grid", "image"], + region: Sequence[float] | str | None = None, +) -> xr.DataArray: ... + + +def read( + file, kind, region, header=None, column_names=None, dtype=None, index_col=None +): """ Read a dataset, grid, or image from a file and return the appropriate object. From 9d66cf4eec3ad1339ded8708d7f74782f94c3b41 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 12:10:24 +0800 Subject: [PATCH 13/23] Remove the type ignore flag --- pygmt/helpers/testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/helpers/testing.py b/pygmt/helpers/testing.py index 8be5f4ee78e..f3d8d1654f5 100644 --- a/pygmt/helpers/testing.py +++ b/pygmt/helpers/testing.py @@ -153,7 +153,7 @@ def load_static_earth_relief() -> xr.DataArray: data A grid of Earth relief for internal tests. """ - return read("@static_earth_relief.nc", kind="grid") # type: ignore[return-value] + return read("@static_earth_relief.nc", kind="grid") def skip_if_no(package): From a05383a28aafd33509c44a30e538803b390bd745 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 12:30:23 +0800 Subject: [PATCH 14/23] region defaults to None --- pygmt/src/read.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/src/read.py b/pygmt/src/read.py index 29daa6582e2..e8b27a42e19 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -33,7 +33,7 @@ def read( def read( - file, kind, region, header=None, column_names=None, dtype=None, index_col=None + file, kind, region=None, header=None, column_names=None, dtype=None, index_col=None ): """ Read a dataset, grid, or image from a file and return the appropriate object. From 7851ced26cf59113f90f36b20db45515087fa2e2 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 17:46:20 +0800 Subject: [PATCH 15/23] Improve type hints and add tests --- pygmt/datasets/samples.py | 2 +- pygmt/helpers/testing.py | 2 +- pygmt/src/read.py | 28 +++++++++------------------- pygmt/tests/test_read.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 21 deletions(-) create mode 100644 pygmt/tests/test_read.py diff --git a/pygmt/datasets/samples.py b/pygmt/datasets/samples.py index 594eae26a54..e05eef59c37 100644 --- a/pygmt/datasets/samples.py +++ b/pygmt/datasets/samples.py @@ -202,7 +202,7 @@ def _load_earth_relief_holes() -> xr.DataArray: The Earth relief grid. Coordinates are latitude and longitude in degrees. Relief is in meters. """ - return read("@earth_relief_20m_holes.grd", kind="grid") + return read("@earth_relief_20m_holes.grd", kind="grid") # type: ignore[return-value] class GMTSampleData(NamedTuple): diff --git a/pygmt/helpers/testing.py b/pygmt/helpers/testing.py index f3d8d1654f5..8be5f4ee78e 100644 --- a/pygmt/helpers/testing.py +++ b/pygmt/helpers/testing.py @@ -153,7 +153,7 @@ def load_static_earth_relief() -> xr.DataArray: data A grid of Earth relief for internal tests. """ - return read("@static_earth_relief.nc", kind="grid") + return read("@static_earth_relief.nc", kind="grid") # type: ignore[return-value] def skip_if_no(package): diff --git a/pygmt/src/read.py b/pygmt/src/read.py index e8b27a42e19..6a70ad4ac9e 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -3,7 +3,8 @@ """ from collections.abc import Mapping, Sequence -from typing import Any, Literal, overload +from pathlib import PurePath +from typing import Any, Literal import pandas as pd import xarray as xr @@ -12,29 +13,15 @@ from pygmt.src.which import which -@overload def read( - file: str, - kind: Literal["dataset"], + file: str | PurePath, + kind: Literal["dataset", "grid", "image"], region: Sequence[float] | str | None = None, header: int | None = None, column_names: pd.Index | None = None, dtype: type | Mapping[Any, type] | None = None, index_col: str | int | None = None, -) -> pd.DataFrame: ... - - -@overload -def read( - file: str, - kind: Literal["grid", "image"], - region: Sequence[float] | str | None = None, -) -> xr.DataArray: ... - - -def read( - file, kind, region=None, header=None, column_names=None, dtype=None, index_col=None -): +) -> pd.DataFrame | xr.DataArray: """ Read a dataset, grid, or image from a file and return the appropriate object. @@ -89,6 +76,9 @@ def read( >>> type(dataarray) """ + if kind not in {"dataset", "grid", "image"}: + msg = f"Invalid kind {kind}: must be one of 'dataset', 'grid', or 'image'." + raise ValueError(msg) if kind != "dataset" and any( v is not None for v in [column_names, header, dtype, index_col] @@ -100,7 +90,7 @@ def read( raise ValueError(msg) kwdict = { - "R": "/".join(f"{v}" for v in region) if is_nonstr_iter(region) else region, + "R": "/".join(f"{v}" for v in region) if is_nonstr_iter(region) else region, # type: ignore[union-attr] "T": {"dataset": "d", "grid": "g", "image": "i"}[kind], } diff --git a/pygmt/tests/test_read.py b/pygmt/tests/test_read.py new file mode 100644 index 00000000000..79e1cd63c3c --- /dev/null +++ b/pygmt/tests/test_read.py @@ -0,0 +1,28 @@ +""" +Test the read function. +""" + +import pytest +from pygmt import read + + +def test_read_invalid_kind(): + """ + Test that an invalid kind raises a ValueError. + """ + with pytest.raises(ValueError, match="Invalid kind"): + read("file.cpt", kind="cpt") + + +def test_read_invalid_arguments(): + """ + Test that invalid arguments raise a ValueError for non-'dataset' kind. + """ + with pytest.raises(ValueError, match="Only the 'dataset' kind supports"): + read("file.nc", kind="grid", column_names="foo") + + with pytest.raises(ValueError, match="Only the 'dataset' kind supports"): + read("file.nc", kind="grid", header=1) + + with pytest.raises(ValueError, match="Only the 'dataset' kind supports"): + read("file.nc", kind="grid", dtype="float") From 084b87a17c31b5b9a35b00bba5457cc06dbbfe24 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 17:54:43 +0800 Subject: [PATCH 16/23] Improve the checking of return value of which --- pygmt/src/read.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/src/read.py b/pygmt/src/read.py index 6a70ad4ac9e..1a2f0089288 100644 --- a/pygmt/src/read.py +++ b/pygmt/src/read.py @@ -112,7 +112,7 @@ def read( # Add "source" encoding source = which(fname=file) raster.encoding["source"] = ( - source if isinstance(source, str) else source[0] + source[0] if isinstance(source, list) else source ) _ = raster.gmt # Load GMTDataArray accessor information return raster From b21997c155c41f5bf6d0b2553049e92e0f78c9c0 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 17:58:14 +0800 Subject: [PATCH 17/23] Use the read funciton in pygmt/tests/test_datatypes_dataset.py --- pygmt/tests/test_datatypes_dataset.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/pygmt/tests/test_datatypes_dataset.py b/pygmt/tests/test_datatypes_dataset.py index 56f18143035..5dd12124edc 100644 --- a/pygmt/tests/test_datatypes_dataset.py +++ b/pygmt/tests/test_datatypes_dataset.py @@ -6,8 +6,7 @@ import pandas as pd import pytest -from pygmt import which -from pygmt.clib import Session +from pygmt import read, which from pygmt.helpers import GMTTempFile @@ -44,11 +43,7 @@ def dataframe_from_gmt(fname, **kwargs): """ Read tabular data as pandas.DataFrame using GMT virtual file. """ - with Session() as lib: - with lib.virtualfile_out(kind="dataset") as vouttbl: - lib.call_module("read", [fname, vouttbl, "-Td"]) - df = lib.virtualfile_to_dataset(vfname=vouttbl, **kwargs) - return df + return read(fname, kind="dataset", **kwargs) @pytest.mark.benchmark From a81231717ec5e63d1d8635cd1a62b8e708c55266 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 18:18:32 +0800 Subject: [PATCH 18/23] Use the read function instead of the load_dataarray method --- pygmt/tests/test_dimfilter.py | 4 ++-- pygmt/tests/test_grdclip.py | 4 ++-- pygmt/tests/test_grdcut.py | 4 ++-- pygmt/tests/test_grdfill.py | 4 ++-- pygmt/tests/test_grdfilter.py | 4 ++-- pygmt/tests/test_grdgradient.py | 4 ++-- pygmt/tests/test_grdhisteq.py | 4 ++-- pygmt/tests/test_grdlandmask.py | 4 ++-- pygmt/tests/test_grdproject.py | 4 ++-- pygmt/tests/test_grdsample.py | 4 ++-- pygmt/tests/test_xyz2grd.py | 4 ++-- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/pygmt/tests/test_dimfilter.py b/pygmt/tests/test_dimfilter.py index 9e998ac9980..37dcd9f6a6f 100644 --- a/pygmt/tests/test_dimfilter.py +++ b/pygmt/tests/test_dimfilter.py @@ -6,7 +6,7 @@ import pytest import xarray as xr -from pygmt import dimfilter, load_dataarray +from pygmt import dimfilter, read from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -56,7 +56,7 @@ def test_dimfilter_outgrid(grid, expected_grid): ) assert result is None # return value is None assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") xr.testing.assert_allclose(a=temp_grid, b=expected_grid) diff --git a/pygmt/tests/test_grdclip.py b/pygmt/tests/test_grdclip.py index a0f2e4a8d7c..d63090f2ae6 100644 --- a/pygmt/tests/test_grdclip.py +++ b/pygmt/tests/test_grdclip.py @@ -6,7 +6,7 @@ import pytest import xarray as xr -from pygmt import grdclip, load_dataarray +from pygmt import grdclip, read from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -49,7 +49,7 @@ def test_grdclip_outgrid(grid, expected_grid): ) assert result is None # return value is None assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") assert temp_grid.dims == ("lat", "lon") assert temp_grid.gmt.gtype == 1 # Geographic grid assert temp_grid.gmt.registration == 1 # Pixel registration diff --git a/pygmt/tests/test_grdcut.py b/pygmt/tests/test_grdcut.py index dbf5dd21f49..c017d2a726d 100644 --- a/pygmt/tests/test_grdcut.py +++ b/pygmt/tests/test_grdcut.py @@ -5,7 +5,7 @@ import numpy as np import pytest import xarray as xr -from pygmt import grdcut, load_dataarray +from pygmt import grdcut, read from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -50,7 +50,7 @@ def test_grdcut_dataarray_in_file_out(grid, expected_grid, region): with GMTTempFile(suffix=".nc") as tmpfile: result = grdcut(grid, outgrid=tmpfile.name, region=region) assert result is None # grdcut returns None if output to a file - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") xr.testing.assert_allclose(a=temp_grid, b=expected_grid) diff --git a/pygmt/tests/test_grdfill.py b/pygmt/tests/test_grdfill.py index f7c4730b744..36167ee21bf 100644 --- a/pygmt/tests/test_grdfill.py +++ b/pygmt/tests/test_grdfill.py @@ -7,7 +7,7 @@ import numpy as np import pytest import xarray as xr -from pygmt import grdfill, load_dataarray +from pygmt import grdfill, read from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -109,7 +109,7 @@ def test_grdfill_file_out(grid, expected_grid): result = grdfill(grid=grid, mode="c20", outgrid=tmpfile.name) assert result is None # return value is None assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") xr.testing.assert_allclose(a=temp_grid, b=expected_grid) diff --git a/pygmt/tests/test_grdfilter.py b/pygmt/tests/test_grdfilter.py index 5cbe3574767..e3fefb56a62 100644 --- a/pygmt/tests/test_grdfilter.py +++ b/pygmt/tests/test_grdfilter.py @@ -7,7 +7,7 @@ import numpy as np import pytest import xarray as xr -from pygmt import grdfilter, load_dataarray +from pygmt import grdfilter, read from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -70,7 +70,7 @@ def test_grdfilter_dataarray_in_file_out(grid, expected_grid): ) assert result is None # return value is None assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") xr.testing.assert_allclose(a=temp_grid, b=expected_grid) diff --git a/pygmt/tests/test_grdgradient.py b/pygmt/tests/test_grdgradient.py index dc082d50d90..844b21418a5 100644 --- a/pygmt/tests/test_grdgradient.py +++ b/pygmt/tests/test_grdgradient.py @@ -6,7 +6,7 @@ import pytest import xarray as xr -from pygmt import grdgradient, load_dataarray +from pygmt import grdgradient, read from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -49,7 +49,7 @@ def test_grdgradient_outgrid(grid, expected_grid): ) assert result is None # return value is None assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") xr.testing.assert_allclose(a=temp_grid, b=expected_grid) diff --git a/pygmt/tests/test_grdhisteq.py b/pygmt/tests/test_grdhisteq.py index 3c5a3df2d8d..d2899606164 100644 --- a/pygmt/tests/test_grdhisteq.py +++ b/pygmt/tests/test_grdhisteq.py @@ -8,7 +8,7 @@ import pandas as pd import pytest import xarray as xr -from pygmt import grdhisteq, load_dataarray +from pygmt import grdhisteq, read from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -66,7 +66,7 @@ def test_equalize_grid_outgrid_file(grid, expected_grid, region): ) assert result is None # return value is None assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") xr.testing.assert_allclose(a=temp_grid, b=expected_grid) diff --git a/pygmt/tests/test_grdlandmask.py b/pygmt/tests/test_grdlandmask.py index ae51ba2eda4..a47432a9368 100644 --- a/pygmt/tests/test_grdlandmask.py +++ b/pygmt/tests/test_grdlandmask.py @@ -6,7 +6,7 @@ import pytest import xarray as xr -from pygmt import grdlandmask, load_dataarray +from pygmt import grdlandmask, read from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile @@ -41,7 +41,7 @@ def test_grdlandmask_outgrid(expected_grid): result = grdlandmask(outgrid=tmpfile.name, spacing=1, region=[125, 130, 30, 35]) assert result is None # return value is None assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") xr.testing.assert_allclose(a=temp_grid, b=expected_grid) diff --git a/pygmt/tests/test_grdproject.py b/pygmt/tests/test_grdproject.py index 644ac311c54..33f85d3d6bb 100644 --- a/pygmt/tests/test_grdproject.py +++ b/pygmt/tests/test_grdproject.py @@ -6,7 +6,7 @@ import pytest import xarray as xr -from pygmt import grdproject, load_dataarray +from pygmt import grdproject, read from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -55,7 +55,7 @@ def test_grdproject_file_out(grid, expected_grid): ) assert result is None # return value is None assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") xr.testing.assert_allclose(a=temp_grid, b=expected_grid) diff --git a/pygmt/tests/test_grdsample.py b/pygmt/tests/test_grdsample.py index 4c9e64139c3..5d17b0e9de9 100644 --- a/pygmt/tests/test_grdsample.py +++ b/pygmt/tests/test_grdsample.py @@ -6,7 +6,7 @@ import pytest import xarray as xr -from pygmt import grdsample, load_dataarray +from pygmt import grdsample, read from pygmt.helpers import GMTTempFile from pygmt.helpers.testing import load_static_earth_relief @@ -66,7 +66,7 @@ def test_grdsample_file_out(grid, expected_grid, region, spacing): ) assert result is None # return value is None assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") xr.testing.assert_allclose(a=temp_grid, b=expected_grid) diff --git a/pygmt/tests/test_xyz2grd.py b/pygmt/tests/test_xyz2grd.py index 56b2d1167e2..10c42bf0dfd 100644 --- a/pygmt/tests/test_xyz2grd.py +++ b/pygmt/tests/test_xyz2grd.py @@ -7,7 +7,7 @@ import numpy as np import pytest import xarray as xr -from pygmt import load_dataarray, xyz2grd +from pygmt import read, xyz2grd from pygmt.datasets import load_sample_data from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile @@ -66,7 +66,7 @@ def test_xyz2grd_input_array_file_out(ship_data, expected_grid): ) assert result is None # return value is None assert Path(tmpfile.name).stat().st_size > 0 - temp_grid = load_dataarray(tmpfile.name) + temp_grid = read(tmpfile.name, kind="grid") xr.testing.assert_allclose(a=temp_grid, b=expected_grid) From 1f0f1583ee1eaa48476e88546b3ee7a4b89d62aa Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 18:32:02 +0800 Subject: [PATCH 19/23] Add one test to make sure that read and load_dataarray returns the same DataArray --- pygmt/tests/test_read.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/pygmt/tests/test_read.py b/pygmt/tests/test_read.py index 79e1cd63c3c..083f177f61c 100644 --- a/pygmt/tests/test_read.py +++ b/pygmt/tests/test_read.py @@ -2,8 +2,26 @@ Test the read function. """ +import importlib.util + +import numpy as np import pytest -from pygmt import read +import xarray as xr +from pygmt import read, which + +_HAS_NETCDF4 = bool(importlib.util.find_spec("netCDF4")) + + +@pytest.mark.skipif(not _HAS_NETCDF4, reason="netCDF4 is not installed.") +def test_read_grid(): + """ + Test that reading a grid returns an xr.DataArray and the grid is the same as the one + loaded via xarray.load_dataarray. + """ + grid = read("@static_earth_relief.nc", kind="grid") + assert isinstance(grid, xr.DataArray) + expected_grid = xr.load_dataarray(which("@static_earth_relief.nc", download="a")) + assert np.allclose(grid, expected_grid) def test_read_invalid_kind(): From 957c7ebf00c13ff488301b420e9dab487536ee5e Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 18:32:22 +0800 Subject: [PATCH 20/23] Simplify pygmt/tests/test_clib_read_data.py with read --- pygmt/tests/test_clib_read_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pygmt/tests/test_clib_read_data.py b/pygmt/tests/test_clib_read_data.py index 09c4e1f6fe9..c77e4002aa1 100644 --- a/pygmt/tests/test_clib_read_data.py +++ b/pygmt/tests/test_clib_read_data.py @@ -11,8 +11,8 @@ from pygmt.clib import Session from pygmt.exceptions import GMTCLibError from pygmt.helpers import GMTTempFile -from pygmt.io import load_dataarray from pygmt.src import which +from pygmt.tests.helpers import load_static_earth_relief try: import rioxarray @@ -27,7 +27,7 @@ def fixture_expected_xrgrid(): """ The expected xr.DataArray object for the static_earth_relief.nc file. """ - return load_dataarray(which("@static_earth_relief.nc")) + return load_static_earth_relief() @pytest.fixture(scope="module", name="expected_xrimage") From 6aef3caf73b67bc8bb1f0d8eec754f00c42682b3 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 18:42:20 +0800 Subject: [PATCH 21/23] Fix a typo --- pygmt/tests/test_clib_read_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/tests/test_clib_read_data.py b/pygmt/tests/test_clib_read_data.py index c77e4002aa1..c58b1fd9d98 100644 --- a/pygmt/tests/test_clib_read_data.py +++ b/pygmt/tests/test_clib_read_data.py @@ -11,8 +11,8 @@ from pygmt.clib import Session from pygmt.exceptions import GMTCLibError from pygmt.helpers import GMTTempFile +from pygmt.helpers.testing import load_static_earth_relief from pygmt.src import which -from pygmt.tests.helpers import load_static_earth_relief try: import rioxarray From 72afbfe0b6f3a41a68e64ddb1ee628d764093489 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 19:27:54 +0800 Subject: [PATCH 22/23] Replace xr.open_dataarray with read --- pygmt/tests/test_clib_put_matrix.py | 23 +++++++++++------------ pygmt/tests/test_nearneighbor.py | 11 ++++++----- pygmt/tests/test_surface.py | 6 +++--- pygmt/tests/test_triangulate.py | 12 ++++++------ 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/pygmt/tests/test_clib_put_matrix.py b/pygmt/tests/test_clib_put_matrix.py index cf1a71f99ec..b2e2b205947 100644 --- a/pygmt/tests/test_clib_put_matrix.py +++ b/pygmt/tests/test_clib_put_matrix.py @@ -5,8 +5,7 @@ import numpy as np import numpy.testing as npt import pytest -import xarray as xr -from pygmt import clib +from pygmt import clib, read from pygmt.clib.session import DTYPES_NUMERIC from pygmt.exceptions import GMTCLibError from pygmt.helpers import GMTTempFile @@ -101,7 +100,7 @@ def test_put_matrix_grid(dtypes): newdata = tmp_file.loadtxt(dtype=dtype) npt.assert_allclose(newdata, data) - # Save the data to a netCDF grid and check that xarray can load it + # Save the data to a netCDF grid and check it can be read again. with GMTTempFile(suffix=".nc") as tmp_grid: lib.write_data( "GMT_IS_MATRIX", @@ -111,12 +110,12 @@ def test_put_matrix_grid(dtypes): tmp_grid.name, grid, ) - with xr.open_dataarray(tmp_grid.name) as dataarray: - assert dataarray.shape == shape - npt.assert_allclose(dataarray.data, np.flipud(data)) - npt.assert_allclose( - dataarray.coords["x"].actual_range, np.array(wesn[0:2]) - ) - npt.assert_allclose( - dataarray.coords["y"].actual_range, np.array(wesn[2:4]) - ) + dataarray = read(tmp_grid.name, kind="grid") + assert dataarray.shape == shape + npt.assert_allclose(dataarray.data, np.flipud(data)) + npt.assert_allclose( + dataarray.coords["x"].actual_range, np.array(wesn[0:2]) + ) + npt.assert_allclose( + dataarray.coords["y"].actual_range, np.array(wesn[2:4]) + ) diff --git a/pygmt/tests/test_nearneighbor.py b/pygmt/tests/test_nearneighbor.py index 9d08aa0e8d1..f95e9ac1758 100644 --- a/pygmt/tests/test_nearneighbor.py +++ b/pygmt/tests/test_nearneighbor.py @@ -8,7 +8,7 @@ import numpy.testing as npt import pytest import xarray as xr -from pygmt import nearneighbor +from pygmt import nearneighbor, read from pygmt.datasets import load_sample_data from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile @@ -81,7 +81,8 @@ def test_nearneighbor_with_outgrid_param(ship_data): ) assert output is None # check that output is None since outgrid is set assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - with xr.open_dataarray(tmpfile.name) as grid: - assert isinstance(grid, xr.DataArray) # ensure netCDF grid loads ok - assert grid.shape == (121, 121) - npt.assert_allclose(grid.mean(), -2378.2385) + + grid = read(tmpfile.name, kind="grid") + assert isinstance(grid, xr.DataArray) # ensure netCDF grid loads ok + assert grid.shape == (121, 121) + npt.assert_allclose(grid.mean(), -2378.2385) diff --git a/pygmt/tests/test_surface.py b/pygmt/tests/test_surface.py index e8ec0cf3445..51358b07377 100644 --- a/pygmt/tests/test_surface.py +++ b/pygmt/tests/test_surface.py @@ -7,7 +7,7 @@ import pandas as pd import pytest import xarray as xr -from pygmt import surface, which +from pygmt import read, surface, which from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile @@ -144,5 +144,5 @@ def test_surface_with_outgrid_param(data, region, spacing, expected_grid): ) assert output is None # check that output is None since outgrid is set assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - with xr.open_dataarray(tmpfile.name) as grid: - check_values(grid, expected_grid) + grid = read(tmpfile.name, kind="grid") + check_values(grid, expected_grid) diff --git a/pygmt/tests/test_triangulate.py b/pygmt/tests/test_triangulate.py index f0a47c6e4ae..6af967dd1c4 100644 --- a/pygmt/tests/test_triangulate.py +++ b/pygmt/tests/test_triangulate.py @@ -8,7 +8,7 @@ import pandas as pd import pytest import xarray as xr -from pygmt import triangulate, which +from pygmt import read, triangulate, which from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile @@ -155,8 +155,8 @@ def test_regular_grid_with_outgrid_param(dataframe, expected_grid): ) assert output is None # check that output is None since outgrid is set assert Path(tmpfile.name).stat().st_size > 0 # check that outgrid exists - with xr.open_dataarray(tmpfile.name) as grid: - assert isinstance(grid, xr.DataArray) - assert grid.gmt.registration == 0 # Gridline registration - assert grid.gmt.gtype == 0 # Cartesian type - xr.testing.assert_allclose(a=grid, b=expected_grid) + grid = read(tmpfile.name, kind="grid") + assert isinstance(grid, xr.DataArray) + assert grid.gmt.registration == 0 # Gridline registration + assert grid.gmt.gtype == 0 # Cartesian type + xr.testing.assert_allclose(a=grid, b=expected_grid) From 03de9b7a7909584f7125bc69cd1c1c63caa4c9d0 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 9 Dec 2024 22:17:54 +0800 Subject: [PATCH 23/23] Fix a typo --- pygmt/tests/test_read.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pygmt/tests/test_read.py b/pygmt/tests/test_read.py index 083f177f61c..9b338d2b9ef 100644 --- a/pygmt/tests/test_read.py +++ b/pygmt/tests/test_read.py @@ -2,7 +2,7 @@ Test the read function. """ -import importlib.util +import importlib import numpy as np import pytest