Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

grdtrack: Fix the bug when profile is given #1867

Merged
merged 20 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 58 additions & 15 deletions pygmt/src/grdtrack.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""
grdtrack - Sample grids at specified (x,y) locations.
"""
import warnings
weiji14 marked this conversation as resolved.
Show resolved Hide resolved

import pandas as pd
import xarray as xr
weiji14 marked this conversation as resolved.
Show resolved Hide resolved
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
Expand All @@ -11,6 +14,7 @@
kwargs_to_strings,
use_alias,
)
from pygmt.src.which import which

__doctest_skip__ = ["grdtrack"]

Expand Down Expand Up @@ -43,7 +47,7 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
r"""
Sample grids at specified (x,y) locations.

Expand All @@ -67,14 +71,14 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):

Parameters
----------
points : str or {table-like}
Pass in either a file name to an ASCII data table, a 2D
{table-classes}.

grid : xarray.DataArray or str
Gridded array from which to sample values from, or a filename (netcdf
format).

points : str or {table-like}
Pass in either a file name to an ASCII data table, a 2D
{table-classes}.

newcolname : str
Required if ``points`` is a :class:`pandas.DataFrame`. The name for the
new column in the track :class:`pandas.DataFrame` table where the
Expand Down Expand Up @@ -283,26 +287,65 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
... points=points, grid=grid, newcolname="bathymetry"
... )
"""
# pylint: disable=too-many-branches
if points is not None and kwargs.get("E") is not None:
raise GMTInvalidInput("Can't set both 'points' and 'profile'.")

if points is None and kwargs.get("E") is None:
raise GMTInvalidInput("Must give 'points' or set 'profile'.")
seisman marked this conversation as resolved.
Show resolved Hide resolved

if hasattr(points, "columns") and newcolname is None:
raise GMTInvalidInput("Please pass in a str to 'newcolname'")

# Backward compatibility with old parameter order "points, grid".
# deprecated_version="0.7.0", remove_version="v0.9.0"
is_a_grid = True
if not isinstance(grid, (xr.DataArray, str)):
is_a_grid = False
elif isinstance(grid, str):
try:
xr.open_dataarray(which(grid, download="a"), engine="netcdf4").close()
is_a_grid = True
except (ValueError, OSError):
is_a_grid = False
if not is_a_grid:
msg = (
"Positional parameters 'points, grid' of pygmt.grdtrack() has changed "
"to 'grid, points=None' since v0.7.0. It's likely that you're NOT "
"passing a valid grid as the first positional argument or "
"are passing an invalid grid to the 'grid' parameter. "
"Please check the order of arguments with the latest documentation. "
"This warning will be removed in v0.9.0."
)
grid, points = points, grid
warnings.warn(msg, category=FutureWarning, stacklevel=1)

with GMTTempFile(suffix=".csv") as tmpfile:
with Session() as lib:
# Choose how data will be passed into the module
table_context = lib.virtualfile_from_data(check_kind="vector", data=points)
# Store the xarray.DataArray grid in virtualfile
grid_context = lib.virtualfile_from_data(check_kind="raster", data=grid)

# Run grdtrack on the temporary (csv) points table
# and (netcdf) grid virtualfile
with table_context as csvfile:
with grid_context as grdfile:
kwargs.update({"G": grdfile})
if outfile is None: # Output to tmpfile if outfile is not set
outfile = tmpfile.name
with grid_context as grdfile:
kwargs.update({"G": grdfile})
if outfile is None: # Output to tmpfile if outfile is not set
outfile = tmpfile.name

if points is not None:
# Choose how data will be passed into the module
table_context = lib.virtualfile_from_data(
check_kind="vector", data=points
)
with table_context as csvfile:
lib.call_module(
module="grdtrack",
args=build_arg_string(
kwargs, infile=csvfile, outfile=outfile
),
)
else:
lib.call_module(
module="grdtrack",
args=build_arg_string(kwargs, infile=csvfile, outfile=outfile),
args=build_arg_string(kwargs, outfile=outfile),
)

# Read temporary csv output to a pandas table
Expand Down
54 changes: 54 additions & 0 deletions pygmt/tests/test_grdtrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,29 @@ def test_grdtrack_input_csvfile_and_ncfile_to_dataframe(expected_array):
npt.assert_allclose(np.array(output), expected_array)


def test_grdtrack_profile(dataarray):
"""
Run grdtrack by passing a profile.
"""
output = grdtrack(grid=dataarray, profile="-51/-17/-54/-19")
assert isinstance(output, pd.DataFrame)
npt.assert_allclose(
np.array(output),
np.array(
[
[-51.0, -17.0, 669.671875],
[-51.42430204, -17.28838525, 847.40745877],
[-51.85009439, -17.57598444, 885.30534844],
[-52.27733766, -17.86273467, 829.85423488],
[-52.70599151, -18.14857333, 776.83702212],
[-53.13601473, -18.43343819, 631.07867839],
[-53.56736521, -18.7172675, 504.28037216],
[-54.0, -19.0, 486.10351562],
]
),
)


def test_grdtrack_wrong_kind_of_points_input(dataarray, dataframe):
"""
Run grdtrack using points input that is not a pandas.DataFrame (matrix) or
Expand Down Expand Up @@ -137,3 +160,34 @@ def test_grdtrack_without_outfile_setting(dataarray, dataframe):
"""
with pytest.raises(GMTInvalidInput):
grdtrack(points=dataframe, grid=dataarray)


def test_grdtrack_no_points_and_profile(dataarray):
"""
Run grdtrack but don't set 'points' and 'profile'.
"""
with pytest.raises(GMTInvalidInput):
grdtrack(grid=dataarray)


def test_grdtrack_set_points_and_profile(dataarray, dataframe):
"""
Run grdtrack but set both 'points' and 'profile'.
"""
with pytest.raises(GMTInvalidInput):
grdtrack(grid=dataarray, points=dataframe, profile="BL/TR")


def test_grdtrack_old_parameter_order(dataframe, dataarray, expected_array):
"""
Run grdtrack with the old parameter order 'points, grid'.

This test should be removed in v0.9.0.
"""
for points in (POINTS_DATA, dataframe):
for grid in ("@static_earth_relief.nc", dataarray):
with pytest.warns(expected_warning=FutureWarning) as record:
output = grdtrack(points, grid)
assert len(record) == 1
assert isinstance(output, pd.DataFrame)
npt.assert_allclose(np.array(output), expected_array)