From 0d24a3508b467236717ce5f54c4bbf41ffe95456 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Wed, 13 Mar 2024 21:53:25 +0800 Subject: [PATCH] pygmt.triangulate.delaunay_triples: Add 'output_type' parameter for output in pandas/numpy/file formats --- pygmt/src/triangulate.py | 59 ++++++++++++--------------------- pygmt/tests/test_triangulate.py | 22 ------------ 2 files changed, 22 insertions(+), 59 deletions(-) diff --git a/pygmt/src/triangulate.py b/pygmt/src/triangulate.py index e73ab92fe5e..ba93800ece9 100644 --- a/pygmt/src/triangulate.py +++ b/pygmt/src/triangulate.py @@ -3,6 +3,9 @@ Cartesian data. """ +from typing import Literal + +import numpy as np import pandas as pd from pygmt.clib import Session from pygmt.helpers import ( @@ -172,10 +175,10 @@ def delaunay_triples( y=None, z=None, *, - output_type="pandas", - outfile=None, + output_type: Literal["pandas", "numpy", "file"] = "pandas", + outfile: str | None = None, **kwargs, - ): + ) -> pd.DataFrame | np.ndarray | None: """ Delaunay triangle based gridding of Cartesian data. @@ -204,16 +207,8 @@ def delaunay_triples( {table-classes}. {projection} {region} - outfile : str or None - The name of the output ASCII file to store the results of the - histogram equalization in. - output_type : str - Determine the format the xyz data will be returned in [Default is - ``pandas``]: - - - ``numpy`` - :class:`numpy.ndarray` - - ``pandas``- :class:`pandas.DataFrame` - - ``file`` - ASCII file (requires ``outfile``) + {output_type} + {outfile} {verbose} {binary} {nodata} @@ -226,13 +221,13 @@ def delaunay_triples( Returns ------- - ret : pandas.DataFrame or numpy.ndarray or None + ret Return type depends on ``outfile`` and ``output_type``: - None if ``outfile`` is set (output will be stored in file set by ``outfile``) - - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if - ``outfile`` is not set (depends on ``output_type``) + - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not + set (depends on ``output_type``) Note ---- @@ -243,25 +238,15 @@ def delaunay_triples( """ output_type = validate_output_table_type(output_type, outfile) - with GMTTempFile(suffix=".txt") as tmpfile: - with Session() as lib: - with lib.virtualfile_in( + with Session() as lib: + with ( + lib.virtualfile_in( check_kind="vector", data=data, x=x, y=y, z=z, required_z=False - ) as vintbl: - if outfile is None: - outfile = tmpfile.name - lib.call_module( - module="triangulate", - args=build_arg_string(kwargs, infile=vintbl, outfile=outfile), - ) - - if outfile == tmpfile.name: - # if user did not set outfile, return pd.DataFrame - result = pd.read_csv(outfile, sep="\t", header=None) - elif outfile != tmpfile.name: - # return None if outfile set, output in outfile - result = None - - if output_type == "numpy": - result = result.to_numpy() - return result + ) as vintbl, + lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, + ): + lib.call_module( + module="triangulate", + args=build_arg_string(kwargs, infile=vintbl, outfile=vouttbl), + ) + return lib.virtualfile_to_dataset(output_type=output_type, vfile=vouttbl) diff --git a/pygmt/tests/test_triangulate.py b/pygmt/tests/test_triangulate.py index 154bc82b09f..c4fbf371721 100644 --- a/pygmt/tests/test_triangulate.py +++ b/pygmt/tests/test_triangulate.py @@ -106,28 +106,6 @@ def test_delaunay_triples_ndarray_output(dataframe, expected_dataframe): np.testing.assert_allclose(actual=output, desired=expected_dataframe.to_numpy()) -def test_delaunay_triples_outfile(dataframe, expected_dataframe): - """ - Test triangulate.delaunay_triples with ``outfile``. - """ - with GMTTempFile(suffix=".txt") as tmpfile: - with pytest.warns(RuntimeWarning) as record: - result = triangulate.delaunay_triples(data=dataframe, outfile=tmpfile.name) - assert len(record) == 1 # check that only one warning was raised - assert result is None # return value is None - assert Path(tmpfile.name).stat().st_size > 0 - temp_df = pd.read_csv(filepath_or_buffer=tmpfile.name, sep="\t", header=None) - pd.testing.assert_frame_equal(left=temp_df, right=expected_dataframe) - - -def test_delaunay_triples_invalid_format(dataframe): - """ - Test that triangulate.delaunay_triples fails with incorrect format. - """ - with pytest.raises(GMTInvalidInput): - triangulate.delaunay_triples(data=dataframe, output_type=1) - - @pytest.mark.benchmark def test_regular_grid_no_outgrid(dataframe, expected_grid): """