diff --git a/doc/api/index.rst b/doc/api/index.rst index 3fcabeedd25..69685060086 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -82,6 +82,7 @@ Operations on tabular data: blockmean blockmedian blockmode + nearneighbor surface Operations on grids: diff --git a/pygmt/__init__.py b/pygmt/__init__.py index 656c730aefc..7a1a1079f37 100644 --- a/pygmt/__init__.py +++ b/pygmt/__init__.py @@ -48,6 +48,7 @@ grdtrack, info, makecpt, + nearneighbor, sphdistance, surface, which, diff --git a/pygmt/src/__init__.py b/pygmt/src/__init__.py index 09ce03fe886..a3dded681eb 100644 --- a/pygmt/src/__init__.py +++ b/pygmt/src/__init__.py @@ -32,6 +32,7 @@ from pygmt.src.logo import logo from pygmt.src.makecpt import makecpt from pygmt.src.meca import meca +from pygmt.src.nearneighbor import nearneighbor from pygmt.src.plot import plot from pygmt.src.plot3d import plot3d from pygmt.src.rose import rose diff --git a/pygmt/src/nearneighbor.py b/pygmt/src/nearneighbor.py new file mode 100644 index 00000000000..496c9bef87a --- /dev/null +++ b/pygmt/src/nearneighbor.py @@ -0,0 +1,148 @@ +""" +nearneighbor - Grid table data using a "Nearest neighbor" algorithm +""" + +from pygmt.clib import Session +from pygmt.helpers import ( + GMTTempFile, + build_arg_string, + fmt_docstring, + kwargs_to_strings, + use_alias, +) +from pygmt.io import load_dataarray + + +@fmt_docstring +@use_alias( + E="empty", + G="outgrid", + I="spacing", + N="sectors", + R="region", + S="search_radius", + V="verbose", + a="aspatial", + b="binary", + d="nodata", + e="find", + f="coltypes", + g="gap", + h="header", + i="incols", + r="registration", + w="wrap", +) +@kwargs_to_strings(R="sequence", i="sequence_comma") +def nearneighbor(data=None, x=None, y=None, z=None, **kwargs): + r""" + Grid table data using a "Nearest neighbor" algorithm + + **nearneighbor** reads arbitrarily located (*x,y,z*\ [,\ *w*]) triples + [quadruplets] and uses a nearest neighbor algorithm to assign a weighted + average value to each node that has one or more data points within a search + radius centered on the node with adequate coverage across a subset of the + chosen sectors. The node value is computed as a weighted mean of the + nearest point from each sector inside the search radius. The weighting + function and the averaging used is given by: + + .. math:: + w(r_i) = \frac{{w_i}}{{1 + d(r_i) ^ 2}}, + \quad d(r) = \frac {{3r}}{{R}}, + \quad \bar{{z}} = \frac{{\sum_i^n w(r_i) z_i}}{{\sum_i^n w(r_i)}} + + where :math:`n` is the number of data points that satisfy the selection + criteria and :math:`r_i` is the distance from the node to the *i*'th data + point. If no data weights are supplied then :math:`w_i = 1`. + + .. figure:: https://docs.generic-mapping-tools.org/dev/_images/GMT_nearneighbor.png # noqa: W505 + :width: 300 px + :align: center + + Search geometry includes the search radius (R) which limits the points + considered and the number of sectors (here 4), which restricts how + points inside the search radius contribute to the value at the node. + Only the closest point in each sector (red circles) contribute to the + weighted estimate. + + Takes a matrix, xyz triples, or a file name as input. + + Must provide either ``data`` or ``x``, ``y``, and ``z``. + + Full option list at :gmt-docs:`nearneighbor.html` + + {aliases} + + Parameters + ---------- + data : str or {table-like} + Pass in (x, y, z) or (longitude, latitude, elevation) values by + providing a file name to an ASCII data table, a 2D + {table-classes}. + x/y/z : 1d arrays + Arrays of x and y coordinates and values z of the data points. + + {I} + + {R} + + search_radius : str + Sets the search radius that determines which data points are considered + close to a node. + + outgrid : str + Optional. The file name for the output netcdf file with extension .nc + to store the grid in. + + empty : str + Optional. Set the value assigned to empty nodes. Defaults to NaN. + + sectors : str + *sectors*\ [**+m**\ *min_sectors*]\|\ **n**. + Optional. The circular search area centered on each node is divided + into *sectors* sectors. Average values will only be computed if there + is *at least* one value inside each of at least *min_sectors* of the + sectors for a given node. Nodes that fail this test are assigned the + value NaN (but see ``empty``). If **+m** is omitted then *min_sectors* + is set to be at least 50% of *sectors* (i.e., rounded up to next + integer) [Default is a quadrant search with 100% coverage, i.e., + *sectors* = *min_sectors* = 4]. Note that only the nearest value per + sector enters into the averaging; the more distant points are ignored. + Alternatively, use ``sectors="n"`` to call GDAL's nearest neighbor + algorithm instead. + + {V} + {a} + {b} + {d} + {e} + {f} + {g} + {h} + {i} + {r} + {w} + + Returns + ------- + ret: xarray.DataArray or None + Return type depends on whether the ``outgrid`` parameter is set: + + - :class:`xarray.DataArray`: if ``outgrid`` is not set + - None if ``outgrid`` is set (grid output will be stored in file set by + ``outgrid``) + """ + with GMTTempFile(suffix=".nc") as tmpfile: + with Session() as lib: + # Choose how data will be passed into the module + table_context = lib.virtualfile_from_data( + check_kind="vector", data=data, x=x, y=y, z=z, required_z=True + ) + with table_context as infile: + if "G" not in kwargs.keys(): # if outgrid is unset, output to tmpfile + kwargs.update({"G": tmpfile.name}) + outgrid = kwargs["G"] + arg_str = " ".join([infile, build_arg_string(kwargs)]) + lib.call_module(module="nearneighbor", args=arg_str) + + return load_dataarray(outgrid) if outgrid == tmpfile.name else None diff --git a/pygmt/tests/test_nearneighbor.py b/pygmt/tests/test_nearneighbor.py new file mode 100644 index 00000000000..5722aa5e71a --- /dev/null +++ b/pygmt/tests/test_nearneighbor.py @@ -0,0 +1,86 @@ +""" +Tests for nearneighbor. +""" +import os + +import numpy as np +import numpy.testing as npt +import pytest +import xarray as xr +from pygmt import nearneighbor +from pygmt.datasets import load_sample_bathymetry +from pygmt.exceptions import GMTInvalidInput +from pygmt.helpers import GMTTempFile, data_kind + + +@pytest.fixture(scope="module", name="ship_data") +def fixture_ship_data(): + """ + Load the data from the sample bathymetry dataset. + """ + return load_sample_bathymetry() + + +@pytest.mark.parametrize("array_func", [np.array, xr.Dataset]) +def test_nearneighbor_input_data(array_func, ship_data): + """ + Run nearneighbor by passing in a numpy.array or xarray.Dataset. + """ + data = array_func(ship_data) + output = nearneighbor( + data=data, spacing="5m", region=[245, 255, 20, 30], search_radius="10m" + ) + assert isinstance(output, xr.DataArray) + assert output.gmt.registration == 0 # Gridline registration + assert output.gmt.gtype == 1 # Geographic type + assert output.shape == (121, 121) + npt.assert_allclose(output.mean(), -2378.2385) + + +def test_nearneighbor_input_xyz(ship_data): + """ + Run nearneighbor by passing in x, y, z numpy.ndarrays individually. + """ + output = nearneighbor( + x=ship_data.longitude, + y=ship_data.latitude, + z=ship_data.bathymetry, + spacing="5m", + region=[245, 255, 20, 30], + search_radius="10m", + ) + assert isinstance(output, xr.DataArray) + assert output.shape == (121, 121) + npt.assert_allclose(output.mean(), -2378.2385) + + +def test_nearneighbor_wrong_kind_of_input(ship_data): + """ + Run nearneighbor using grid input that is not file/matrix/vectors. + """ + data = ship_data.bathymetry.to_xarray() # convert pandas.Series to xarray.DataArray + assert data_kind(data) == "grid" + with pytest.raises(GMTInvalidInput): + nearneighbor( + data=data, spacing="5m", region=[245, 255, 20, 30], search_radius="10m" + ) + + +def test_nearneighbor_with_outgrid_param(ship_data): + """ + Run nearneighbor with the 'outgrid' parameter. + """ + with GMTTempFile() as tmpfile: + output = nearneighbor( + data=ship_data, + spacing="5m", + region=[245, 255, 20, 30], + outgrid=tmpfile.name, + search_radius="10m", + ) + assert output is None # check that output is None since outgrid is set + assert os.path.exists(path=tmpfile.name) # check that outgrid exists at path + with xr.open_dataarray(tmpfile.name) as grid: + assert isinstance(grid, xr.DataArray) # ensure netcdf grid loads ok + assert grid.shape == (121, 121) + npt.assert_allclose(grid.mean(), -2378.2385)