Skip to content

Commit

Permalink
Enable ascii file input for grdtrack
Browse files Browse the repository at this point in the history
Enable passing ASCII file inputs (csv, txt, etc.) to grdtrack's 'points' parameter instead of only a pandas.DataFrame. File input requires the new 'outfile' parameter to be set. The type of 'points' input determines how the track is returned: pd.DataFrame in, pd.DataFrame out; filename in, output written to the 'outfile' filename (the function itself returns None). Extra unit tests were created to test the various new input combinations and associated outputs.
  • Loading branch information
weiji14 committed May 27, 2019
1 parent 3776214 commit a37fb0a
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 20 deletions.
1 change: 0 additions & 1 deletion pygmt/datasets/tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Functions to load sample data from the GMT tutorials.
"""
import pandas as pd
import xarray as xr

from .. import which

Expand Down
45 changes: 27 additions & 18 deletions pygmt/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
GMT modules for Sampling of 1-D and 2-D Data
"""
import pandas as pd
import xarray as xr

from .clib import Session
from .helpers import (
Expand All @@ -16,9 +15,7 @@


@fmt_docstring
def grdtrack(
points: pd.DataFrame, grid: xr.DataArray, newcolname: str = None, **kwargs
):
def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
"""
Sample grids at specified (x,y) locations.
Expand All @@ -33,33 +30,40 @@ def grdtrack(
Parameters
----------
points: pandas.DataFrame
Table with (x, y) or (lon, lat) values in the first two columns. More columns
may be present.
points: pandas.DataFrame or file (csv, txt, etc)
Either a table with (x, y) or (lon, lat) values in the first two columns,
or a data file name. More columns may be present.
grid: xarray.DataArray or file (netcdf)
Gridded array from which to sample values.
newcolname: str
Name for the new column in the table where the sampled values will be placed.
Required if 'points' is a pandas.DataFrame. The name for the new column in the
track pandas.DataFrame table where the sampled values will be placed.
outfile: str
Required if 'points' is a file. The file name for the output ASCII file.
Returns
-------
track: pandas.DataFrame
Table with (x, y, ..., newcolname) or (lon, lat, ..., newcolname) values.
track: pandas.DataFrame or None
Return type depends on whether the outfile parameter is set:
- pandas.DataFrame table with (x, y, ..., newcolname) if outfile is not set
- None if outfile is set (track output will be stored in outfile)
"""

try:
assert isinstance(newcolname, str)
except AssertionError:
raise GMTInvalidInput("Please pass in a str to 'newcolname'")

with GMTTempFile(suffix=".csv") as tmpfile:
with Session() as lib:
# Store the pandas.DataFrame points table in virtualfile
if data_kind(points) == "matrix":
if newcolname is None:
raise GMTInvalidInput("Please pass in a str to 'newcolname'")
table_context = lib.virtualfile_from_matrix(points.values)
elif data_kind(points) == "file":
if outfile is None:
raise GMTInvalidInput("Please pass in a str to 'outfile'")
table_context = dummy_context(points)
else:
raise GMTInvalidInput(f"Unrecognized data type {type(points)}")

Expand All @@ -75,13 +79,18 @@ def grdtrack(
with table_context as csvfile:
with grid_context as grdfile:
kwargs.update({"G": grdfile})
if outfile is None: # Output to tmpfile if outfile is not set
outfile = tmpfile.name
arg_str = " ".join(
[csvfile, build_arg_string(kwargs), "->" + tmpfile.name]
[csvfile, build_arg_string(kwargs), "->" + outfile]
)
lib.call_module(module="grdtrack", args=arg_str)

# Read temporary csv output to a pandas table
column_names = points.columns.to_list() + [newcolname]
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
column_names = points.columns.to_list() + [newcolname]
result = pd.read_csv(tmpfile.name, sep="\t", names=column_names)
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
result = None

return result
57 changes: 56 additions & 1 deletion pygmt/tests/test_grdtrack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests for grdtrack
"""
import os

import pandas as pd
import pytest
Expand All @@ -11,6 +12,9 @@
from ..exceptions import GMTInvalidInput
from ..helpers import data_kind

TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
TEMP_TRACK = os.path.join(TEST_DATA_DIR, "tmp_track.txt")


def test_grdtrack_input_dataframe_and_dataarray():
"""
Expand All @@ -27,6 +31,26 @@ def test_grdtrack_input_dataframe_and_dataarray():
return output


def test_grdtrack_input_csvfile_and_dataarray():
    """
    Run grdtrack by passing in a csvfile and xarray.DataArray as inputs

    The track is written to TEMP_TRACK, so the call itself must return None.
    The output file is read back to check its first row, and is removed
    afterwards whether or not the assertions pass.
    """
    csvfile = which("@ridge.txt", download="c")
    dataarray = load_earth_relief().sel(lat=slice(-49, -42), lon=slice(-118, -107))

    try:
        output = grdtrack(points=csvfile, grid=dataarray, outfile=TEMP_TRACK)
        assert output is None  # check that output is None since outfile is set
        assert os.path.exists(path=TEMP_TRACK)  # check that outfile exists at path

        # comment=">" makes pandas ignore any lines starting with ">"
        track = pd.read_csv(TEMP_TRACK, sep="\t", header=None, comment=">")
        assert track.iloc[0].to_list() == [-110.9536, -42.2489, -2823.96637605]
    finally:
        # Only remove the file if it was actually created. An unconditional
        # os.remove would raise FileNotFoundError when grdtrack fails before
        # writing anything, masking the real error in the test report.
        if os.path.exists(TEMP_TRACK):
            os.remove(TEMP_TRACK)

    return output


def test_grdtrack_input_dataframe_and_ncfile():
"""
Run grdtrack by passing in a pandas.DataFrame and netcdf file as inputs
Expand All @@ -42,9 +66,29 @@ def test_grdtrack_input_dataframe_and_ncfile():
return output


def test_grdtrack_input_csvfile_and_ncfile():
    """
    Run grdtrack by passing in a csvfile and netcdf file as inputs

    The track is written to TEMP_TRACK, so the call itself must return None.
    The output file is read back to check its first row, and is removed
    afterwards whether or not the assertions pass.
    """
    csvfile = which("@ridge.txt", download="c")
    ncfile = which("@earth_relief_60m", download="c")

    try:
        output = grdtrack(points=csvfile, grid=ncfile, outfile=TEMP_TRACK)
        assert output is None  # check that output is None since outfile is set
        assert os.path.exists(path=TEMP_TRACK)  # check that outfile exists at path

        # comment=">" makes pandas ignore any lines starting with ">"
        track = pd.read_csv(TEMP_TRACK, sep="\t", header=None, comment=">")
        assert track.iloc[0].to_list() == [-32.2971, 37.4118, -1697.87197487]
    finally:
        # Only remove the file if it was actually created. An unconditional
        # os.remove would raise FileNotFoundError when grdtrack fails before
        # writing anything, masking the real error in the test report.
        if os.path.exists(TEMP_TRACK):
            os.remove(TEMP_TRACK)

    return output


def test_grdtrack_wrong_kind_of_points_input():
"""
Run grdtrack using points input that is not a pandas.DataFrame (matrix)
Run grdtrack using points input that is not a pandas.DataFrame (matrix) or file
"""
dataframe = load_ocean_ridge_points()
invalid_points = dataframe.longitude.to_xarray()
Expand Down Expand Up @@ -77,3 +121,14 @@ def test_grdtrack_without_newcolname_setting():

with pytest.raises(GMTInvalidInput):
grdtrack(points=dataframe, grid=dataarray)


def test_grdtrack_without_outfile_setting():
    """
    Check that grdtrack raises GMTInvalidInput when 'points' is a file
    but no 'outfile' argument is supplied
    """
    ridge_points = which("@ridge.txt", download="c")
    relief = load_earth_relief()
    relief_subset = relief.sel(lat=slice(-49, -42), lon=slice(-118, -107))

    # No outfile keyword on purpose: the call is expected to be rejected
    with pytest.raises(GMTInvalidInput):
        grdtrack(points=ridge_points, grid=relief_subset)

0 comments on commit a37fb0a

Please sign in to comment.