From 147dac91644a59a0085bf655e9743413505b9629 Mon Sep 17 00:00:00 2001 From: Wei Ji Date: Wed, 2 Sep 2020 13:14:20 +1200 Subject: [PATCH 1/2] Allow pandas.DataFrame table inputs into pygmt.info Also renamed 'fname' argument to 'table' since `info` supports both file name inputs and pandas.DataFrame tables now. --- pygmt/modules.py | 30 ++++++++++++++++++++---------- pygmt/tests/test_info.py | 33 ++++++++++++++++++++++++--------- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/pygmt/modules.py b/pygmt/modules.py index 7339241002a..22abc34f0da 100644 --- a/pygmt/modules.py +++ b/pygmt/modules.py @@ -55,7 +55,7 @@ def grdinfo(grid, **kwargs): @fmt_docstring @use_alias(C="per_column", I="spacing", T="nearest_multiple") -def info(fname, **kwargs): +def info(table, **kwargs): """ Get information about data tables. @@ -74,8 +74,8 @@ def info(fname, **kwargs): Parameters ---------- - fname : str - The file name of the input data table file. + table : pandas.DataFrame or str + Either a pandas dataframe or a file name to an ASCII data table. per_column : bool Report the min/max values per column in separate columns. spacing : str @@ -88,14 +88,24 @@ def info(fname, **kwargs): Report the min/max of the first (0'th) column to the nearest multiple of dz and output this as the string *-Tzmin/zmax/dz*. 
""" - if not isinstance(fname, str): - raise GMTInvalidInput("'info' only accepts file names.") + kind = data_kind(table) + with Session() as lib: + if kind == "file": + file_context = dummy_context(table) + elif kind == "matrix": + if not hasattr(table, "values"): + raise GMTInvalidInput(f"Unrecognized data type: {type(table)}") + file_context = lib.virtualfile_from_matrix(table.values) + else: + raise GMTInvalidInput(f"Unrecognized data type: {type(table)}") - with GMTTempFile() as tmpfile: - arg_str = " ".join([fname, build_arg_string(kwargs), "->" + tmpfile.name]) - with Session() as lib: - lib.call_module("info", arg_str) - return tmpfile.read() + with GMTTempFile() as tmpfile: + with file_context as fname: + arg_str = " ".join( + [fname, build_arg_string(kwargs), "->" + tmpfile.name] + ) + lib.call_module("info", arg_str) + return tmpfile.read() @fmt_docstring diff --git a/pygmt/tests/test_info.py b/pygmt/tests/test_info.py index 3e9da3abf81..755dad3b9ff 100644 --- a/pygmt/tests/test_info.py +++ b/pygmt/tests/test_info.py @@ -4,7 +4,9 @@ import os import numpy as np +import pandas as pd import pytest +import xarray as xr from .. 
import info from ..exceptions import GMTInvalidInput @@ -14,8 +16,8 @@ def test_info(): - "Make sure info works" - output = info(fname=POINTS_DATA) + "Make sure info works on file name inputs" + output = info(table=POINTS_DATA) expected_output = ( f"{POINTS_DATA}: N = 20 " "<11.5309/61.7074> " @@ -25,33 +27,46 @@ def test_info(): assert output == expected_output +def test_info_dataframe(): + "Make sure info works on pandas.DataFrame inputs" + table = pd.read_csv(POINTS_DATA, sep=" ", header=None) + output = info(table=table) + expected_output = ( + ": N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n" + ) + assert output == expected_output + + def test_info_per_column(): "Make sure the per_column option works" - output = info(fname=POINTS_DATA, per_column=True) + output = info(table=POINTS_DATA, per_column=True) assert output == "11.5309 61.7074 -2.9289 7.8648 0.1412 0.9338\n" def test_info_spacing(): "Make sure the spacing option works" - output = info(fname=POINTS_DATA, spacing=0.1) + output = info(table=POINTS_DATA, spacing=0.1) assert output == "-R11.5/61.8/-3/7.9\n" def test_info_per_column_spacing(): "Make sure the per_column and spacing options work together" - output = info(fname=POINTS_DATA, per_column=True, spacing=0.1) + output = info(table=POINTS_DATA, per_column=True, spacing=0.1) assert output == "11.5 61.8 -3 7.9 0.1412 0.9338\n" def test_info_nearest_multiple(): "Make sure the nearest_multiple option works" - output = info(fname=POINTS_DATA, nearest_multiple=0.1) + output = info(table=POINTS_DATA, nearest_multiple=0.1) assert output == "-T11.5/61.8/0.1\n" def test_info_fails(): - "Make sure info raises an exception if not given a file name" + """ + Make sure info raises an exception if not given a file name or pandas + DataFrame + """ with pytest.raises(GMTInvalidInput): - info(fname=21) + info(table=xr.DataArray(21)) with pytest.raises(GMTInvalidInput): - info(fname=np.arange(20)) + info(table=np.arange(20)) From 
e264d56a7d27448438a5d8ac968e359911cdd4b4 Mon Sep 17 00:00:00 2001 From: Wei Ji Date: Mon, 7 Sep 2020 16:20:38 +1200 Subject: [PATCH 2/2] Handle 1D and 2D numpy.ndarray inputs to pygmt info --- pygmt/modules.py | 13 ++++++++----- pygmt/tests/test_info.py | 23 +++++++++++++++++++---- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/pygmt/modules.py b/pygmt/modules.py index 22abc34f0da..94512ec325d 100644 --- a/pygmt/modules.py +++ b/pygmt/modules.py @@ -1,6 +1,7 @@ """ Non-plot GMT modules. """ +import numpy as np import xarray as xr from .clib import Session @@ -74,8 +75,9 @@ def info(table, **kwargs): Parameters ---------- - table : pandas.DataFrame or str - Either a pandas dataframe or a file name to an ASCII data table. + table : pandas.DataFrame or np.ndarray or str + Either a pandas dataframe, a 1D/2D numpy.ndarray or a file name to an + ASCII data table. per_column : bool Report the min/max values per column in separate columns. spacing : str @@ -93,9 +95,10 @@ def info(table, **kwargs): if kind == "file": file_context = dummy_context(table) elif kind == "matrix": - if not hasattr(table, "values"): - raise GMTInvalidInput(f"Unrecognized data type: {type(table)}") - file_context = lib.virtualfile_from_matrix(table.values) + _table = np.asanyarray(table) + if _table.ndim == 1: # 1D arrays need to be 2D and transposed + _table = np.transpose(np.atleast_2d(_table)) + file_context = lib.virtualfile_from_matrix(_table) else: raise GMTInvalidInput(f"Unrecognized data type: {type(table)}") diff --git a/pygmt/tests/test_info.py b/pygmt/tests/test_info.py index 755dad3b9ff..b7eadc53649 100644 --- a/pygmt/tests/test_info.py +++ b/pygmt/tests/test_info.py @@ -37,6 +37,23 @@ def test_info_dataframe(): assert output == expected_output +def test_info_2d_array(): + "Make sure info works on 2D numpy.ndarray inputs" + table = np.loadtxt(POINTS_DATA) + output = info(table=table) + expected_output = ( + ": N = 20 <11.5309/61.7074> <-2.9289/7.8648> <0.1412/0.9338>\n" 
+ ) + assert output == expected_output + + +def test_info_1d_array(): + "Make sure info works on 1D numpy.ndarray inputs" + output = info(table=np.arange(20)) + expected_output = ": N = 20 <0/19>\n" + assert output == expected_output + + def test_info_per_column(): "Make sure the per_column option works" output = info(table=POINTS_DATA, per_column=True) @@ -63,10 +80,8 @@ def test_info_nearest_multiple(): def test_info_fails(): """ - Make sure info raises an exception if not given a file name or pandas - DataFrame + Make sure info raises an exception if not given either a file name, pandas + DataFrame, or numpy ndarray """ with pytest.raises(GMTInvalidInput): info(table=xr.DataArray(21)) - with pytest.raises(GMTInvalidInput): - info(table=np.arange(20))