Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pygmt.filter1d: Improve performance by storing output in virtual files #3085

Merged
merged 7 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 31 additions & 39 deletions pygmt/src/filter1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
filter1d - Time domain filtering of 1-D data tables
"""

import pandas as pd
from typing import Literal

from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
fmt_docstring,
use_alias,
Expand All @@ -20,7 +20,12 @@
F="filter_type",
N="time_col",
)
def filter1d(data, output_type="pandas", outfile=None, **kwargs):
def filter1d(
data,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
outfile: str | None = None,
**kwargs,
):
r"""
Time domain filtering of 1-D data tables.

Expand All @@ -38,6 +43,15 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs):

Parameters
----------
output_type
Desired output type of the result data.

- ``pandas`` will return a :class:`pandas.DataFrame` object.
- ``numpy`` will return a :class:`numpy.ndarray` object.
- ``file`` will save the result to the file given by the ``outfile`` parameter.
outfile
File name for saving the result data. Required if ``output_type`` is ``"file"``.
If specified, ``output_type`` will be forced to be ``"file"``.
filter_type : str
**type**\ *width*\ [**+h**].
Set the filter **type**. Choose among convolution and non-convolution
Expand Down Expand Up @@ -91,48 +105,26 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs):
left-most column is 0, while the right-most is (*n_cols* - 1)
[Default is ``0``].

output_type : str
Determine the format the xyz data will be returned in [Default is
``pandas``]:

- ``numpy`` - :class:`numpy.ndarray`
- ``pandas`` - :class:`pandas.DataFrame`
- ``file`` - ASCII file (requires ``outfile``)
outfile : str
The file name for the output ASCII file.

Returns
-------
ret : pandas.DataFrame or numpy.ndarray or None
ret
Return type depends on ``outfile`` and ``output_type``:

- None if ``outfile`` is set (output will be stored in file set by
``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is
not set (depends on ``output_type`` [Default is
:class:`pandas.DataFrame`])
- None if ``outfile`` is set (output will be stored in file set by ``outfile``)
- :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
(depends on ``output_type``)
"""
if kwargs.get("F") is None:
raise GMTInvalidInput("Pass a required argument to 'filter_type'.")

output_type = validate_output_table_type(output_type, outfile=outfile)

with GMTTempFile() as tmpfile:
with Session() as lib:
with lib.virtualfile_in(check_kind="vector", data=data) as vintbl:
if outfile is None:
outfile = tmpfile.name
lib.call_module(
module="filter1d",
args=build_arg_string(kwargs, infile=vintbl, outfile=outfile),
)

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame
result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">")
elif outfile != tmpfile.name: # return None if outfile set, output in outfile
result = None

if output_type == "numpy":
result = result.to_numpy()
return result
with Session() as lib:
with (
lib.virtualfile_in(check_kind="vector", data=data) as vintbl,
lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl,
):
lib.call_module(
module="filter1d",
args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl),
)
return lib.virtualfile_to_dataset(output_type=output_type, vfname=vouttbl)
78 changes: 4 additions & 74 deletions pygmt/tests/test_filter1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,10 @@
Test pygmt.filter1d.
"""

from pathlib import Path

import numpy as np
import pandas as pd
import pytest
from pygmt import filter1d
from pygmt.datasets import load_sample_data
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile


@pytest.fixture(scope="module", name="data")
Expand All @@ -21,76 +16,11 @@ def fixture_data():
return load_sample_data(name="maunaloa_co2")


def test_filter1d_no_outfile(data):
@pytest.mark.benchmark
def test_filter1d(data):
"""
Test filter1d with no set outfile.
Test the basic functionality of filter1d.
"""
result = filter1d(data=data, filter_type="g5")
assert isinstance(result, pd.DataFrame)
assert result.shape == (671, 2)


def test_filter1d_file_output(data):
"""
Test that filter1d returns a file output when it is specified.
"""
with GMTTempFile(suffix=".txt") as tmpfile:
result = filter1d(
data=data, filter_type="g5", outfile=tmpfile.name, output_type="file"
)
assert result is None # return value is None
assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists


def test_filter1d_invalid_format(data):
"""
Test that filter1d fails with an incorrect format for output_type.
"""
with pytest.raises(GMTInvalidInput):
filter1d(data=data, filter_type="g5", output_type="a")


def test_filter1d_no_filter(data):
"""
Test that filter1d fails when an argument is missing for filter.
"""
with pytest.raises(GMTInvalidInput):
filter1d(data=data)


def test_filter1d_no_outfile_specified(data):
"""
Test that filter1d fails when output_type is set to 'file' but no output file name
is specified.
"""
with pytest.raises(GMTInvalidInput):
filter1d(data=data, filter_type="g5", output_type="file")


def test_filter1d_outfile_incorrect_output_type(data):
"""
Test that filter1d raises a warning when an outfile filename is set but the
output_type is not set to 'file'.
"""
with pytest.warns(RuntimeWarning):
with GMTTempFile(suffix=".txt") as tmpfile:
result = filter1d(
data=data, filter_type="g5", outfile=tmpfile.name, output_type="numpy"
)
assert result is None # return value is None
assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists


@pytest.mark.benchmark
def test_filter1d_format(data):
"""
Test that correct formats are returned.
"""
time_series_default = filter1d(data=data, filter_type="g5")
assert isinstance(time_series_default, pd.DataFrame)
assert time_series_default.shape == (671, 2)
time_series_array = filter1d(data=data, filter_type="g5", output_type="numpy")
assert isinstance(time_series_array, np.ndarray)
assert time_series_array.shape == (671, 2)
time_series_df = filter1d(data=data, filter_type="g5", output_type="pandas")
assert isinstance(time_series_df, pd.DataFrame)
assert time_series_df.shape == (671, 2)
Loading