Skip to content

Commit

Permalink
pygmt.grdtrack: Support consistent table-like outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Oct 10, 2023
1 parent 8546048 commit 452ce5e
Showing 1 changed file with 48 additions and 32 deletions.
80 changes: 48 additions & 32 deletions pygmt/src/grdtrack.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
"""
grdtrack - Sample grids at specified (x,y) locations.
"""
import warnings

import numpy as np
import pandas as pd
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
fmt_docstring,
kwargs_to_strings,
use_alias,
)
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias

__doctest_skip__ = ["grdtrack"]

Expand Down Expand Up @@ -43,7 +40,9 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
def grdtrack(
grid, points=None, output_type="pandas", outfile=None, newcolname=None, **kwargs
):
r"""
Sample grids at specified (x,y) locations.
Expand Down Expand Up @@ -291,30 +290,47 @@ def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):

if hasattr(points, "columns") and newcolname is None:
raise GMTInvalidInput("Please pass in a str to 'newcolname'")

if output_type not in ["numpy", "pandas", "file"]:
raise GMTInvalidInput(
"Must specify 'output_type' either as 'numpy', 'pandas' or 'file'."
)

if outfile is not None and output_type != "file":
msg = (
f"Changing 'output_type' from '{output_type}' to 'file' "
"since 'outfile' parameter is set. Please use output_type='file' "
"to silence this warning."
)
warnings.warn(message=msg, category=RuntimeWarning, stacklevel=2)
output_type = "file"
elif outfile is None and output_type == "file":
raise GMTInvalidInput("Must specify 'outfile' for ASCII output.")


with Session() as lib:
with lib.virtualfile_from_data(
check_kind="raster", data=grid
) as ingrid, lib.virtualfile_from_data(
check_kind="vector", data=points, required_data=False
) as infile, lib.virtualfile_to_gmtdataset() as outvfile:
kwargs["G"] = ingrid
lib.call_module(
module="grdtrack",
args=build_arg_string(kwargs, infile=infile, outfile=outvfile),
)

if output_type == "file":
lib.call_module("write", f"{outvfile} {outfile} -Td")
return None

with GMTTempFile(suffix=".csv") as tmpfile:
with Session() as lib:
with lib.virtualfile_from_data(
check_kind="raster", data=grid
) as grdfile, lib.virtualfile_from_data(
check_kind="vector", data=points, required_data=False
) as csvfile:
kwargs["G"] = grdfile
if outfile is None: # Output to tmpfile if outfile is not set
outfile = tmpfile.name
lib.call_module(
module="grdtrack",
args=build_arg_string(kwargs, infile=csvfile, outfile=outfile),
)
vectors = lib.gmtdataset_to_vectors(outvfile)
if output_type == "numpy":
return np.array(vectors).T

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
try:
column_names = points.columns.to_list() + [newcolname]
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
except AttributeError: # 'str' object has no attribute 'columns'
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
result = None
if isinstance(points, pd.DataFrame):
column_names = points.columns.to_list() + [newcolname]
else:
column_names = None

return result
return pd.DataFrame(np.array(vectors).T, columns=column_names)

0 comments on commit 452ce5e

Please sign in to comment.