diff --git a/pygmt/src/filter1d.py b/pygmt/src/filter1d.py index 79163e2b0dd..4bd0cdf8344 100644 --- a/pygmt/src/filter1d.py +++ b/pygmt/src/filter1d.py @@ -2,11 +2,13 @@ filter1d - Time domain filtering of 1-D data tables """ +from typing import Literal + import pandas as pd +import xarray as xr from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( - GMTTempFile, build_arg_string, fmt_docstring, use_alias, @@ -20,7 +22,12 @@ F="filter_type", N="time_col", ) -def filter1d(data, output_type="pandas", outfile=None, **kwargs): +def filter1d( + data, + output_type: Literal["pandas", "numpy", "file"] = "pandas", + outfile: str | None = None, + **kwargs, +) -> pd.DataFrame | xr.DataArray | None: r""" Time domain filtering of 1-D data tables. @@ -38,6 +45,8 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs): Parameters ---------- + {output_type} + {outfile} filter_type : str **type**\ *width*\ [**+h**]. Set the filter **type**. Choose among convolution and non-convolution @@ -91,48 +100,27 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs): left-most column is 0, while the right-most is (*n_cols* - 1) [Default is ``0``]. - output_type : str - Determine the format the xyz data will be returned in [Default is - ``pandas``]: - - - ``numpy`` - :class:`numpy.ndarray` - - ``pandas``- :class:`pandas.DataFrame` - - ``file`` - ASCII file (requires ``outfile``) - outfile : str - The file name for the output ASCII file. - Returns ------- - ret : pandas.DataFrame or numpy.ndarray or None + ret Return type depends on ``outfile`` and ``output_type``: - - None if ``outfile`` is set (output will be stored in file set by - ``outfile``) - - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is - not set (depends on ``output_type`` [Default is - :class:`pandas.DataFrame`]) + - None if ``outfile`` is set (output will be stored in file set by ``outfile``) + - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set + (depends on ``output_type``) """ if kwargs.get("F") is None: raise GMTInvalidInput("Pass a required argument to 'filter_type'.") output_type = validate_output_table_type(output_type, outfile=outfile) - with GMTTempFile() as tmpfile: - with Session() as lib: - with lib.virtualfile_in(check_kind="vector", data=data) as vintbl: - if outfile is None: - outfile = tmpfile.name - lib.call_module( - module="filter1d", - args=build_arg_string(kwargs, infile=vintbl, outfile=outfile), - ) - - # Read temporary csv output to a pandas table - if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame - result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") - elif outfile != tmpfile.name: # return None if outfile set, output in outfile - result = None - - if output_type == "numpy": - result = result.to_numpy() - return result + with Session() as lib: + with ( + lib.virtualfile_in(check_kind="vector", data=data) as vintbl, + lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, + ): + lib.call_module( + module="filter1d", + args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl), + ) + return lib.virtualfile_to_dataset(output_type=output_type, vfname=vouttbl) diff --git a/pygmt/tests/test_filter1d.py b/pygmt/tests/test_filter1d.py index c98c39f7ac4..9fa6f1c50db 100644 --- a/pygmt/tests/test_filter1d.py +++ b/pygmt/tests/test_filter1d.py @@ -2,15 +2,10 @@ Test pygmt.filter1d. """ -from pathlib import Path - -import numpy as np import pandas as pd import pytest from pygmt import filter1d from pygmt.datasets import load_sample_data -from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import GMTTempFile @pytest.fixture(scope="module", name="data") @@ -21,76 +16,11 @@ def fixture_data(): return load_sample_data(name="maunaloa_co2") -def test_filter1d_no_outfile(data): +@pytest.mark.benchmark +def test_filter1d(data): """ - Test filter1d with no set outfile. + Test the basic functionality of filter1d. """ result = filter1d(data=data, filter_type="g5") + assert isinstance(result, pd.DataFrame) assert result.shape == (671, 2) - - -def test_filter1d_file_output(data): - """ - Test that filter1d returns a file output when it is specified. - """ - with GMTTempFile(suffix=".txt") as tmpfile: - result = filter1d( - data=data, filter_type="g5", outfile=tmpfile.name, output_type="file" - ) - assert result is None # return value is None - assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists - - -def test_filter1d_invalid_format(data): - """ - Test that filter1d fails with an incorrect format for output_type. - """ - with pytest.raises(GMTInvalidInput): - filter1d(data=data, filter_type="g5", output_type="a") - - -def test_filter1d_no_filter(data): - """ - Test that filter1d fails with an argument is missing for filter. - """ - with pytest.raises(GMTInvalidInput): - filter1d(data=data) - - -def test_filter1d_no_outfile_specified(data): - """ - Test that filter1d fails when outpput_type is set to 'file' but no output file name - is specified. - """ - with pytest.raises(GMTInvalidInput): - filter1d(data=data, filter_type="g5", output_type="file") - - -def test_filter1d_outfile_incorrect_output_type(data): - """ - Test that filter1d raises a warning when an outfile filename is set but the - output_type is not set to 'file'. - """ - with pytest.warns(RuntimeWarning): - with GMTTempFile(suffix=".txt") as tmpfile: - result = filter1d( - data=data, filter_type="g5", outfile=tmpfile.name, output_type="numpy" - ) - assert result is None # return value is None - assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists - - -@pytest.mark.benchmark -def test_filter1d_format(data): - """ - Test that correct formats are returned. - """ - time_series_default = filter1d(data=data, filter_type="g5") - assert isinstance(time_series_default, pd.DataFrame) - assert time_series_default.shape == (671, 2) - time_series_array = filter1d(data=data, filter_type="g5", output_type="numpy") - assert isinstance(time_series_array, np.ndarray) - assert time_series_array.shape == (671, 2) - time_series_df = filter1d(data=data, filter_type="g5", output_type="pandas") - assert isinstance(time_series_df, pd.DataFrame) - assert time_series_df.shape == (671, 2)