Skip to content

Commit

Permalink
pygmt.filter1d: Improve performance by storing output in virtual files (
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman authored Mar 13, 2024
1 parent 2f598c5 commit bf7b9a1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 111 deletions.
62 changes: 25 additions & 37 deletions pygmt/src/filter1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
filter1d - Time domain filtering of 1-D data tables
"""

from typing import Literal

import pandas as pd
import xarray as xr
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
fmt_docstring,
use_alias,
Expand All @@ -20,7 +22,12 @@
F="filter_type",
N="time_col",
)
def filter1d(data, output_type="pandas", outfile=None, **kwargs):
def filter1d(
data,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
outfile: str | None = None,
**kwargs,
) -> pd.DataFrame | xr.DataArray | None:
r"""
Time domain filtering of 1-D data tables.
Expand All @@ -38,6 +45,8 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs):
Parameters
----------
{output_type}
{outfile}
filter_type : str
**type**\ *width*\ [**+h**].
Set the filter **type**. Choose among convolution and non-convolution
Expand Down Expand Up @@ -91,48 +100,27 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs):
left-most column is 0, while the right-most is (*n_cols* - 1)
[Default is ``0``].
output_type : str
Determine the format the xyz data will be returned in [Default is
``pandas``]:
- ``numpy`` - :class:`numpy.ndarray`
- ``pandas`` - :class:`pandas.DataFrame`
- ``file`` - ASCII file (requires ``outfile``)
outfile : str
The file name for the output ASCII file.
Returns
-------
ret : pandas.DataFrame or numpy.ndarray or None
ret
Return type depends on ``outfile`` and ``output_type``:
- None if ``outfile`` is set (output will be stored in file set by
``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is
not set (depends on ``output_type`` [Default is
:class:`pandas.DataFrame`])
- None if ``outfile`` is set (output will be stored in file set by ``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
(depends on ``output_type``)
"""
if kwargs.get("F") is None:
raise GMTInvalidInput("Pass a required argument to 'filter_type'.")

output_type = validate_output_table_type(output_type, outfile=outfile)

with GMTTempFile() as tmpfile:
with Session() as lib:
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
if outfile is None:
outfile = tmpfile.name
lib.call_module(
module="filter1d",
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
)

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
result = None

if output_type == "numpy":
result = result.to_numpy()
return result
with Session() as lib:
with (
lib.virtualfile_in(check_kind="vector", data=data) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
):
lib.call_module(
module="filter1d",
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
)
return lib.virtualfile_to_dataset(output_type=output_type, vfname=vouttbl)
78 changes: 4 additions & 74 deletions pygmt/tests/test_filter1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,10 @@
Test pygmt.filter1d.
"""

from pathlib import Path

import numpy as np
import pandas as pd
import pytest
from pygmt import filter1d
from pygmt.datasets import load_sample_data
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile


@pytest.fixture(scope="module", name="data")
Expand All @@ -21,76 +16,11 @@ def fixture_data():
return load_sample_data(name="maunaloa_co2")


def test_filter1d_no_outfile(data):
@pytest.mark.benchmark
def test_filter1d(data):
"""
Test filter1d with no set outfile.
Test the basic functionality of filter1d.
"""
result = filter1d(data=data, filter_type="g5")
assert isinstance(result, pd.DataFrame)
assert result.shape == (671, 2)


def test_filter1d_file_output(data):
"""
Test that filter1d returns a file output when it is specified.
"""
with GMTTempFile(suffix=".txt") as tmpfile:
result = filter1d(
data=data, filter_type="g5", outfile=tmpfile.name, output_type="file"
)
assert result is None # return value is None
assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists


def test_filter1d_invalid_format(data):
"""
Test that filter1d fails with an incorrect format for output_type.
"""
with pytest.raises(GMTInvalidInput):
filter1d(data=data, filter_type="g5", output_type="a")


def test_filter1d_no_filter(data):
"""
Test that filter1d fails when the filter_type argument is missing.
"""
with pytest.raises(GMTInvalidInput):
filter1d(data=data)


def test_filter1d_no_outfile_specified(data):
"""
Test that filter1d fails when output_type is set to 'file' but no output file name
is specified.
"""
with pytest.raises(GMTInvalidInput):
filter1d(data=data, filter_type="g5", output_type="file")


def test_filter1d_outfile_incorrect_output_type(data):
"""
Test that filter1d raises a warning when an outfile filename is set but the
output_type is not set to 'file'.
"""
with pytest.warns(RuntimeWarning):
with GMTTempFile(suffix=".txt") as tmpfile:
result = filter1d(
data=data, filter_type="g5", outfile=tmpfile.name, output_type="numpy"
)
assert result is None # return value is None
assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists


@pytest.mark.benchmark
def test_filter1d_format(data):
"""
Test that correct formats are returned.
"""
time_series_default = filter1d(data=data, filter_type="g5")
assert isinstance(time_series_default, pd.DataFrame)
assert time_series_default.shape == (671, 2)
time_series_array = filter1d(data=data, filter_type="g5", output_type="numpy")
assert isinstance(time_series_array, np.ndarray)
assert time_series_array.shape == (671, 2)
time_series_df = filter1d(data=data, filter_type="g5", output_type="pandas")
assert isinstance(time_series_df, pd.DataFrame)
assert time_series_df.shape == (671, 2)

0 comments on commit bf7b9a1

Please sign in to comment.