diff --git a/pygmt/src/grdtrack.py b/pygmt/src/grdtrack.py index d9b005883b9..066d7464ebe 100644 --- a/pygmt/src/grdtrack.py +++ b/pygmt/src/grdtrack.py @@ -1,16 +1,13 @@ """ grdtrack - Sample grids at specified (x,y) locations. """ +import warnings + +import numpy as np import pandas as pd from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import ( - GMTTempFile, - build_arg_string, - fmt_docstring, - kwargs_to_strings, - use_alias, -) +from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias __doctest_skip__ = ["grdtrack"] @@ -43,7 +40,9 @@ w="wrap", ) @kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma") -def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs): +def grdtrack( + grid, points=None, output_type="pandas", outfile=None, newcolname=None, **kwargs +): r""" Sample grids at specified (x,y) locations. @@ -292,29 +291,44 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs): if hasattr(points, "columns") and newcolname is None: raise GMTInvalidInput("Please pass in a str to 'newcolname'") - with GMTTempFile(suffix=".csv") as tmpfile: - with Session() as lib: - with lib.virtualfile_from_data( - check_kind="raster", data=grid - ) as grdfile, lib.virtualfile_from_data( - check_kind="vector", data=points, required_data=False - ) as csvfile: - kwargs["G"] = grdfile - if outfile is None: # Output to tmpfile if outfile is not set - outfile = tmpfile.name - lib.call_module( - module="grdtrack", - args=build_arg_string(kwargs, infile=csvfile, outfile=outfile), - ) + if output_type not in ["numpy", "pandas", "file"]: + raise GMTInvalidInput( + "Must specify 'output_type' either as 'numpy', 'pandas' or 'file'." + ) + + if outfile is not None and output_type != "file": + msg = ( + f"Changing 'output_type' from '{output_type}' to 'file' " + "since 'outfile' parameter is set. Please use output_type='file' " + "to silence this warning." + ) + warnings.warn(message=msg, category=RuntimeWarning, stacklevel=2) + output_type = "file" + elif outfile is None and output_type == "file": + raise GMTInvalidInput("Must specify 'outfile' for ASCII output.") + + if isinstance(points, pd.DataFrame): + column_names = points.columns.to_list() + [newcolname] + else: + column_names = None - # Read temporary csv output to a pandas table - if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame - try: - column_names = points.columns.to_list() + [newcolname] - result = pd.read_csv(tmpfile.name, sep="\t", names=column_names) - except AttributeError: # 'str' object has no attribute 'columns' - result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") - elif outfile != tmpfile.name: # return None if outfile set, output in outfile - result = None + with Session() as lib: + with lib.virtualfile_from_data( + check_kind="raster", data=grid + ) as ingrid, lib.virtualfile_from_data( + check_kind="vector", data=points, required_data=False + ) as infile, lib.virtualfile_to_data( + kind="dataset", fname=outfile + ) as outvfile: + kwargs["G"] = ingrid + lib.call_module( + module="grdtrack", + args=build_arg_string(kwargs, infile=infile, outfile=outvfile), + ) - return result + if output_type == "file": + return None + vectors = lib.read_virtualfile(outvfile, kind="dataset").contents.to_vectors() + if output_type == "numpy": + return np.array(vectors).T + return pd.DataFrame(np.array(vectors).T, columns=column_names)