From 92dc36a19262090fab93b1a163fd87514bc69326 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 4 Mar 2024 19:38:32 +0800 Subject: [PATCH] pygmt.filter1d: Improve performance by getting rid of temporary files --- pygmt/src/filter1d.py | 69 ++++++++++++++++-------------------- pygmt/tests/test_filter1d.py | 55 ---------------------------- 2 files changed, 30 insertions(+), 94 deletions(-) diff --git a/pygmt/src/filter1d.py b/pygmt/src/filter1d.py index 79163e2b0dd..cd07561b258 100644 --- a/pygmt/src/filter1d.py +++ b/pygmt/src/filter1d.py @@ -2,11 +2,11 @@ filter1d - Time domain filtering of 1-D data tables """ -import pandas as pd +from typing import Literal + from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( - GMTTempFile, build_arg_string, fmt_docstring, use_alias, @@ -20,7 +20,12 @@ F="filter_type", N="time_col", ) -def filter1d(data, output_type="pandas", outfile=None, **kwargs): +def filter1d( + data, + output_type: Literal["pandas", "numpy", "file"] = "pandas", + outfile: str | None = None, + **kwargs, +): r""" Time domain filtering of 1-D data tables. @@ -38,6 +43,14 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs): Parameters ---------- + output_type + Desired output type of the result data. + - ``pandas`` will return a :class:`pandas.DataFrame` object. + - ``numpy`` will return a :class:`numpy.ndarray` object. + - ``file`` will save the result to the file given by the ``outfile`` parameter. + outfile + File name for saving the result data. Required if ``output_type`` is ``"file"``. + If specified, ``output_type`` will be forced to be ``"file"``. filter_type : str **type**\ *width*\ [**+h**]. Set the filter **type**. Choose among convolution and non-convolution @@ -91,48 +104,26 @@ def filter1d(data, output_type="pandas", outfile=None, **kwargs): left-most column is 0, while the right-most is (*n_cols* - 1) [Default is ``0``]. - output_type : str - Determine the format the xyz data will be returned in [Default is - ``pandas``]: - - - ``numpy`` - :class:`numpy.ndarray` - - ``pandas``- :class:`pandas.DataFrame` - - ``file`` - ASCII file (requires ``outfile``) - outfile : str - The file name for the output ASCII file. - Returns ------- - ret : pandas.DataFrame or numpy.ndarray or None + ret Return type depends on ``outfile`` and ``output_type``: - - - None if ``outfile`` is set (output will be stored in file set by - ``outfile``) - - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is - not set (depends on ``output_type`` [Default is - :class:`pandas.DataFrame`]) + - None if ``outfile`` is set (output will be stored in file set by ``outfile``) + - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set + (depends on ``output_type``) """ if kwargs.get("F") is None: raise GMTInvalidInput("Pass a required argument to 'filter_type'.") output_type = validate_output_table_type(output_type, outfile=outfile) - with GMTTempFile() as tmpfile: - with Session() as lib: - with lib.virtualfile_in(check_kind="vector", data=data) as vintbl: - if outfile is None: - outfile = tmpfile.name - lib.call_module( - module="filter1d", - args=build_arg_string(kwargs, infile=vintbl, outfile=outfile), - ) - - # Read temporary csv output to a pandas table - if outfile == tmpfile.name: # if user did not set outfile, return pd.DataFrame - result = pd.read_csv(tmpfile.name, sep="\t", header=None, comment=">") - elif outfile != tmpfile.name: # return None if outfile set, output in outfile - result = None - - if output_type == "numpy": - result = result.to_numpy() - return result + with Session() as lib: + with ( + lib.virtualfile_in(check_kind="vector", data=data) as vintbl, + lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, + ): + lib.call_module( + module="filter1d", + args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl), + ) + return lib.virtualfile_to_dataset(output_type=output_type, vfname=vouttbl) diff --git a/pygmt/tests/test_filter1d.py b/pygmt/tests/test_filter1d.py index c98c39f7ac4..9e59c361469 100644 --- a/pygmt/tests/test_filter1d.py +++ b/pygmt/tests/test_filter1d.py @@ -2,15 +2,11 @@ Test pygmt.filter1d. """ -from pathlib import Path - import numpy as np import pandas as pd import pytest from pygmt import filter1d from pygmt.datasets import load_sample_data -from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import GMTTempFile @pytest.fixture(scope="module", name="data") @@ -29,57 +25,6 @@ def test_filter1d_no_outfile(data): assert result.shape == (671, 2) -def test_filter1d_file_output(data): - """ - Test that filter1d returns a file output when it is specified. - """ - with GMTTempFile(suffix=".txt") as tmpfile: - result = filter1d( - data=data, filter_type="g5", outfile=tmpfile.name, output_type="file" - ) - assert result is None # return value is None - assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists - - -def test_filter1d_invalid_format(data): - """ - Test that filter1d fails with an incorrect format for output_type. - """ - with pytest.raises(GMTInvalidInput): - filter1d(data=data, filter_type="g5", output_type="a") - - -def test_filter1d_no_filter(data): - """ - Test that filter1d fails with an argument is missing for filter. - """ - with pytest.raises(GMTInvalidInput): - filter1d(data=data) - - -def test_filter1d_no_outfile_specified(data): - """ - Test that filter1d fails when outpput_type is set to 'file' but no output file name - is specified. - """ - with pytest.raises(GMTInvalidInput): - filter1d(data=data, filter_type="g5", output_type="file") - - -def test_filter1d_outfile_incorrect_output_type(data): - """ - Test that filter1d raises a warning when an outfile filename is set but the - output_type is not set to 'file'. - """ - with pytest.warns(RuntimeWarning): - with GMTTempFile(suffix=".txt") as tmpfile: - result = filter1d( - data=data, filter_type="g5", outfile=tmpfile.name, output_type="numpy" - ) - assert result is None # return value is None - assert Path(tmpfile.name).stat().st_size > 0 # check that outfile exists - - @pytest.mark.benchmark def test_filter1d_format(data): """