From fa3e0e709daa5bb4289cc5a96834e0f4df2e33d0 Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Fri, 15 Mar 2024 21:20:01 +0800 Subject: [PATCH] pygmt.grdhisteq.compute_bins: Refactor to store output in virtual files instead of temporary files (#3109) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Yvonne Fröhlich <94163266+yvonnefroehlich@users.noreply.github.com> --- pygmt/src/grdhisteq.py | 76 +++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 42 deletions(-) diff --git a/pygmt/src/grdhisteq.py b/pygmt/src/grdhisteq.py index 0e2c8c9ea60..880368f1a7c 100644 --- a/pygmt/src/grdhisteq.py +++ b/pygmt/src/grdhisteq.py @@ -2,6 +2,8 @@ grdhisteq - Perform histogram equalization for a grid. """ +from typing import Literal + import numpy as np import pandas as pd from pygmt.clib import Session @@ -135,7 +137,6 @@ def equalize_grid(grid, **kwargs): @fmt_docstring @use_alias( C="divisions", - D="outfile", R="region", N="gaussian", Q="quadratic", @@ -143,7 +144,12 @@ def equalize_grid(grid, **kwargs): h="header", ) @kwargs_to_strings(R="sequence") - def compute_bins(grid, output_type="pandas", **kwargs): + def compute_bins( + grid, + output_type: Literal["pandas", "numpy", "file"] = "pandas", + outfile: str | None = None, + **kwargs, + ) -> pd.DataFrame | np.ndarray | None: r""" Perform histogram equalization for a grid. @@ -168,16 +174,8 @@ def compute_bins(grid, output_type="pandas", **kwargs): Parameters ---------- {grid} - outfile : str or bool or None - The name of the output ASCII file to store the results of the - histogram equalization in. - output_type : str - Determine the format the xyz data will be returned in [Default is - ``pandas``]: - - - ``numpy`` - :class:`numpy.ndarray` - - ``pandas``- :class:`pandas.DataFrame` - - ``file`` - ASCII file (requires ``outfile``) + {output_type} + {outfile} divisions : int Set the number of divisions of the data range. quadratic : bool @@ -188,13 +186,13 @@ def compute_bins(grid, output_type="pandas", **kwargs): Returns ------- - ret : pandas.DataFrame or numpy.ndarray or None + ret Return type depends on ``outfile`` and ``output_type``: - - None if ``outfile`` is set (output will be stored in file set by + - ``None`` if ``outfile`` is set (output will be stored in file set by ``outfile``) - - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if - ``outfile`` is not set (depends on ``output_type``) + - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not + set (depends on ``output_type``) Example ------- @@ -225,39 +223,33 @@ def compute_bins(grid, output_type="pandas", **kwargs): This method does a weighted histogram equalization for geographic grids to account for node area varying with latitude. """ - outfile = kwargs.get("D") output_type = validate_output_table_type(output_type, outfile=outfile) if kwargs.get("h") is not None and output_type != "file": raise GMTInvalidInput("'header' is only allowed with output_type='file'.") - with GMTTempFile(suffix=".txt") as tmpfile: - with Session() as lib: - with lib.virtualfile_in(check_kind="raster", data=grid) as vingrd: - if outfile is None: - kwargs["D"] = outfile = tmpfile.name # output to tmpfile - lib.call_module( - module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd) - ) + with Session() as lib: + with ( + lib.virtualfile_in(check_kind="raster", data=grid) as vingrd, + lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl, + ): + kwargs["D"] = vouttbl # -D for output file name + lib.call_module( + module="grdhisteq", args=build_arg_string(kwargs, infile=vingrd) + ) - if outfile == tmpfile.name: - # if user did not set outfile, return pd.DataFrame - result = pd.read_csv( - filepath_or_buffer=outfile, - sep="\t", - header=None, - names=["start", "stop", "bin_id"], - dtype={ + result = lib.virtualfile_to_dataset( + output_type=output_type, + vfname=vouttbl, + column_names=["start", "stop", "bin_id"], + ) + if output_type == "pandas": + result = result.astype( + { "start": np.float32, "stop": np.float32, "bin_id": np.uint32, - }, + } ) - elif outfile != tmpfile.name: - # return None if outfile set, output in outfile - return None - - if output_type == "numpy": - return result.to_numpy() - - return result.set_index("bin_id") + return result.set_index("bin_id") + return result