Skip to content

Commit

Permalink
Session.virtualfile_to_dataset: Add 'header' parameter to parse colum…
Browse files Browse the repository at this point in the history
…n names from table header (#3117)
  • Loading branch information
seisman authored Apr 18, 2024
1 parent fd286fb commit 1746c04
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 8 deletions.
7 changes: 6 additions & 1 deletion pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,7 @@ def virtualfile_to_dataset(
self,
vfname: str,
output_type: Literal["pandas", "numpy", "file", "strings"] = "pandas",
header: int | None = None,
column_names: list[str] | None = None,
dtype: type | dict[str, type] | None = None,
index_col: str | int | None = None,
Expand All @@ -1831,6 +1832,10 @@ def virtualfile_to_dataset(
- ``"numpy"`` will return a :class:`numpy.ndarray` object.
- ``"file"`` means the result was saved to a file and will return ``None``.
- ``"strings"`` will return the trailing text only as an array of strings.
header
Row number containing column names for the :class:`pandas.DataFrame` output.
``header=None`` means not to parse the column names from table header.
Ignored if the row number is larger than the number of headers in the table.
column_names
The column names for the :class:`pandas.DataFrame` output.
dtype
Expand Down Expand Up @@ -1945,7 +1950,7 @@ def virtualfile_to_dataset(
return result.to_strings()

result = result.to_dataframe(
column_names=column_names, dtype=dtype, index_col=index_col
header=header, column_names=column_names, dtype=dtype, index_col=index_col
)
if output_type == "numpy": # numpy.ndarray output
return result.to_numpy()
Expand Down
24 changes: 19 additions & 5 deletions pygmt/datatypes/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801
>>> with GMTTempFile(suffix=".txt") as tmpfile:
... # Prepare the sample data file
... with Path(tmpfile.name).open(mode="w") as fp:
... print("# x y z name", file=fp)
... print(">", file=fp)
... print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
... print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
Expand All @@ -43,7 +44,8 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801
... print(ds.min[: ds.n_columns], ds.max[: ds.n_columns])
... # The table
... tbl = ds.table[0].contents
... print(tbl.n_columns, tbl.n_segments, tbl.n_records)
... print(tbl.n_columns, tbl.n_segments, tbl.n_records, tbl.n_headers)
... print(tbl.header[: tbl.n_headers])
... print(tbl.min[: tbl.n_columns], ds.max[: tbl.n_columns])
... for i in range(tbl.n_segments):
... seg = tbl.segment[i].contents
Expand All @@ -52,7 +54,8 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801
... print(seg.text[: seg.n_rows])
1 3 2
[1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
3 2 4
3 2 4 1
[b'x y z name']
[1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
[1.0, 4.0]
[2.0, 5.0]
Expand Down Expand Up @@ -169,6 +172,7 @@ def to_strings(self) -> np.ndarray[Any, np.dtype[np.str_]]:

def to_dataframe(
self,
header: int | None = None,
column_names: pd.Index | None = None,
dtype: type | Mapping[Any, type] | None = None,
index_col: str | int | None = None,
Expand All @@ -187,6 +191,10 @@ def to_dataframe(
----------
column_names
A list of column names.
header
Row number containing column names. ``header=None`` means not to parse the
column names from table header. Ignored if the row number is larger than the
number of headers in the table.
dtype
Data type. Can be a single type for all columns or a dictionary mapping
column names to types.
Expand All @@ -207,6 +215,7 @@ def to_dataframe(
>>> with GMTTempFile(suffix=".txt") as tmpfile:
... # prepare the sample data file
... with Path(tmpfile.name).open(mode="w") as fp:
... print("# col1 col2 col3 colstr", file=fp)
... print(">", file=fp)
... print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
... print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
Expand All @@ -218,12 +227,12 @@ def to_dataframe(
... lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
... ds = lib.read_virtualfile(vouttbl, kind="dataset")
... text = ds.contents.to_strings()
... df = ds.contents.to_dataframe()
... df = ds.contents.to_dataframe(header=0)
>>> text
array(['TEXT1 TEXT23', 'TEXT4 TEXT567', 'TEXT8 TEXT90',
'TEXT123 TEXT456789'], dtype='<U18')
>>> df
0 1 2 3
col1 col2 col3 colstr
0 1.0 2.0 3.0 TEXT1 TEXT23
1 4.0 5.0 6.0 TEXT4 TEXT567
2 7.0 8.0 9.0 TEXT8 TEXT90
Expand All @@ -248,14 +257,19 @@ def to_dataframe(
if len(textvector) != 0:
vectors.append(pd.Series(data=textvector, dtype=pd.StringDtype()))

if header is not None:
tbl = self.table[0].contents # Use the first table!
if header < tbl.n_headers:
column_names = tbl.header[header].decode().split()

if len(vectors) == 0:
# Return an empty DataFrame if no columns are found.
df = pd.DataFrame(columns=column_names)
else:
# Create a DataFrame object by concatenating multiple columns
df = pd.concat(objs=vectors, axis="columns")
if column_names is not None: # Assign column names
df.columns = column_names
df.columns = column_names[: df.shape[1]]
if dtype is not None: # Set dtype for the whole dataset or individual columns
df = df.astype(dtype)
if index_col is not None: # Use a specific column as index
Expand Down
61 changes: 59 additions & 2 deletions pygmt/tests/test_datatypes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ def dataframe_from_pandas(filepath_or_buffer, sep=r"\s+", comment="#", header=No
return df


def dataframe_from_gmt(fname):
def dataframe_from_gmt(fname, **kwargs):
"""
Read tabular data as pandas.DataFrame using GMT virtual file.
"""
with Session() as lib:
with lib.virtualfile_out(kind="dataset") as vouttbl:
lib.call_module("read", f"{fname} {vouttbl} -Td")
df = lib.virtualfile_to_dataset(vfname=vouttbl)
df = lib.virtualfile_to_dataset(vfname=vouttbl, **kwargs)
return df


Expand Down Expand Up @@ -84,6 +84,63 @@ def test_dataset_empty():
pd.testing.assert_frame_equal(df, expected_df)


def test_dataset_header():
"""
Test parsing column names from dataset header.
"""
with GMTTempFile(suffix=".txt") as tmpfile:
with Path(tmpfile.name).open(mode="w") as fp:
print("# lon lat z text", file=fp)
print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)

# Parse columne names from the first header line
df = dataframe_from_gmt(tmpfile.name, header=0)
assert df.columns.tolist() == ["lon", "lat", "z", "text"]
# pd.read_csv() can't parse the header line with a leading '#'.
# So, we need to skip the header line and manually set the column names.
expected_df = dataframe_from_pandas(tmpfile.name, header=None)
expected_df.columns = df.columns.tolist()
pd.testing.assert_frame_equal(df, expected_df)


def test_dataset_header_greater_than_nheaders():
"""
Test passing a header line number that is greater than the number of header lines.
"""
with GMTTempFile(suffix=".txt") as tmpfile:
with Path(tmpfile.name).open(mode="w") as fp:
print("# lon lat z text", file=fp)
print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)

# Parse column names from the second header line.
df = dataframe_from_gmt(tmpfile.name, header=1)
# There is only one header line, so the column names should be default.
assert df.columns.tolist() == [0, 1, 2, 3]
expected_df = dataframe_from_pandas(tmpfile.name, header=None)
pd.testing.assert_frame_equal(df, expected_df)


def test_dataset_header_too_many_names():
"""
Test passing a header line with more column names than the number of columns.
"""
with GMTTempFile(suffix=".txt") as tmpfile:
with Path(tmpfile.name).open(mode="w") as fp:
print("# lon lat z text1 text2", file=fp)
print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)

df = dataframe_from_gmt(tmpfile.name, header=0)
assert df.columns.tolist() == ["lon", "lat", "z", "text1"]
# pd.read_csv() can't parse the header line with a leading '#'.
# So, we need to skip the header line and manually set the column names.
expected_df = dataframe_from_pandas(tmpfile.name, header=None)
expected_df.columns = df.columns.tolist()
pd.testing.assert_frame_equal(df, expected_df)


def test_dataset_to_strings_with_none_values():
"""
Test that None values in the trailing text doesn't raise an exception.
Expand Down

0 comments on commit 1746c04

Please sign in to comment.