diff --git a/pygmt/datasets/tutorial.py b/pygmt/datasets/tutorial.py index 10e394a213a..3d9188e0dfa 100644 --- a/pygmt/datasets/tutorial.py +++ b/pygmt/datasets/tutorial.py @@ -2,7 +2,6 @@ Functions to load sample data from the GMT tutorials. """ import pandas as pd -import xarray as xr from .. import which diff --git a/pygmt/sampling.py b/pygmt/sampling.py index 9c988385089..7df01eb2884 100644 --- a/pygmt/sampling.py +++ b/pygmt/sampling.py @@ -2,7 +2,6 @@ GMT modules for Sampling of 1-D and 2-D Data """ import pandas as pd -import xarray as xr from .clib import Session from .helpers import ( @@ -16,9 +15,7 @@ @fmt_docstring -def grdtrack( - points: pd.DataFrame, grid: xr.DataArray, newcolname: str = None, **kwargs -): +def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs): """ Sample grids at specified (x,y) locations. @@ -33,33 +30,40 @@ def grdtrack( Parameters ---------- - points: pandas.DataFrame - Table with (x, y) or (lon, lat) values in the first two columns. More columns - may be present. + points: pandas.DataFrame or file (csv, txt, etc) + Either a table with (x, y) or (lon, lat) values in the first two columns, + or a data file name. More columns may be present. grid: xarray.DataArray or file (netcdf) Gridded array from which to sample values from. newcolname: str - Name for the new column in the table where the sampled values will be placed. + Required if 'points' is a pandas.DataFrame. The name for the new column in the + track pandas.DataFrame table where the sampled values will be placed. + + outfile: str + Required if 'points' is a file. The file name for the output ASCII file. Returns ------- - track: pandas.DataFrame - Table with (x, y, ..., newcolname) or (lon, lat, ..., newcolname) values. 
+ track: pandas.DataFrame or None + Return type depends on whether the outfile parameter is set: + - pandas.DataFrame table with (x, y, ..., newcolname) if outfile is not set + - None if outfile is set (track output will be stored in outfile) """ - try: - assert isinstance(newcolname, str) - except AssertionError: - raise GMTInvalidInput("Please pass in a str to 'newcolname'") - with GMTTempFile(suffix=".csv") as tmpfile: with Session() as lib: # Store the pandas.DataFrame points table in virtualfile if data_kind(points) == "matrix": + if newcolname is None: + raise GMTInvalidInput("Please pass in a str to 'newcolname'") table_context = lib.virtualfile_from_matrix(points.values) + elif data_kind(points) == "file": + if outfile is None: + raise GMTInvalidInput("Please pass in a str to 'outfile'") + table_context = dummy_context(points) else: raise GMTInvalidInput(f"Unrecognized data type {type(points)}") @@ -75,13 +79,18 @@ def grdtrack( with table_context as csvfile: with grid_context as grdfile: kwargs.update({"G": grdfile}) + if outfile is None: # Output to tmpfile if outfile is not set + outfile = tmpfile.name arg_str = " ".join( - [csvfile, build_arg_string(kwargs), "->" + tmpfile.name] + [csvfile, build_arg_string(kwargs), "->" + outfile] ) lib.call_module(module="grdtrack", args=arg_str) # Read temporary csv output to a pandas table - column_names = points.columns.to_list() + [newcolname] - result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) + if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame + column_names = points.columns.to_list() + [newcolname] + result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) + elif outfile != tmpfile.name: # return None if outfile set, output in outfile + result = None return result diff --git a/pygmt/tests/test_grdtrack.py b/pygmt/tests/test_grdtrack.py index fbf3b972dce..4801ee0428f 100644 --- a/pygmt/tests/test_grdtrack.py +++ b/pygmt/tests/test_grdtrack.py @@ -1,6 +1,7 @@ 
""" Tests for grdtrack """ +import os import pandas as pd import pytest @@ -11,6 +12,9 @@ from ..exceptions import GMTInvalidInput from ..helpers import data_kind +TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data") +TEMP_TRACK = os.path.join(TEST_DATA_DIR, "tmp_track.txt") + def test_grdtrack_input_dataframe_and_dataarray(): """ @@ -27,6 +31,26 @@ def test_grdtrack_input_dataframe_and_dataarray(): return output +def test_grdtrack_input_csvfile_and_dataarray(): + """ + Run grdtrack by passing in a csvfile and xarray.DataArray as inputs + """ + csvfile = which("@ridge.txt", download="c") + dataarray = load_earth_relief().sel(lat=slice(-49, -42), lon=slice(-118, -107)) + + try: + output = grdtrack(points=csvfile, grid=dataarray, outfile=TEMP_TRACK) + assert output is None # check that output is None since outfile is set + assert os.path.exists(path=TEMP_TRACK) # check that outfile exists at path + + track = pd.read_csv(TEMP_TRACK, sep="\t", header=None, comment=">") + assert track.iloc[0].to_list() == [-110.9536, -42.2489, -2823.96637605] + finally: + os.remove(path=TEMP_TRACK) + + return output + + def test_grdtrack_input_dataframe_and_ncfile(): """ Run grdtrack by passing in a pandas.DataFrame and netcdf file as inputs @@ -42,9 +66,29 @@ def test_grdtrack_input_dataframe_and_ncfile(): return output +def test_grdtrack_input_csvfile_and_ncfile(): + """ + Run grdtrack by passing in a csvfile and netcdf file as inputs + """ + csvfile = which("@ridge.txt", download="c") + ncfile = which("@earth_relief_60m", download="c") + + try: + output = grdtrack(points=csvfile, grid=ncfile, outfile=TEMP_TRACK) + assert output is None # check that output is None since outfile is set + assert os.path.exists(path=TEMP_TRACK) # check that outfile exists at path + + track = pd.read_csv(TEMP_TRACK, sep="\t", header=None, comment=">") + assert track.iloc[0].to_list() == [-32.2971, 37.4118, -1697.87197487] + finally: + os.remove(path=TEMP_TRACK) + + return output + + def 
test_grdtrack_wrong_kind_of_points_input(): """ - Run grdtrack using points input that is not a pandas.DataFrame (matrix) + Run grdtrack using points input that is not a pandas.DataFrame (matrix) or file """ dataframe = load_ocean_ridge_points() invalid_points = dataframe.longitude.to_xarray() @@ -77,3 +121,14 @@ def test_grdtrack_without_newcolname_setting(): with pytest.raises(GMTInvalidInput): grdtrack(points=dataframe, grid=dataarray) + + +def test_grdtrack_without_outfile_setting(): + """ + Run grdtrack with a file as points input but without setting outfile + """ + csvfile = which("@ridge.txt", download="c") + dataarray = load_earth_relief().sel(lat=slice(-49, -42), lon=slice(-118, -107)) + + with pytest.raises(GMTInvalidInput): + grdtrack(points=csvfile, grid=dataarray)