diff --git a/pygmt/tests/test_clib_put_matrix.py b/pygmt/tests/test_clib_put_matrix.py index ff85a7430e8..cf1a71f99ec 100644 --- a/pygmt/tests/test_clib_put_matrix.py +++ b/pygmt/tests/test_clib_put_matrix.py @@ -7,6 +7,7 @@ import pytest import xarray as xr from pygmt import clib +from pygmt.clib.session import DTYPES_NUMERIC from pygmt.exceptions import GMTCLibError from pygmt.helpers import GMTTempFile from pygmt.tests.test_clib import mock @@ -17,7 +18,7 @@ def fixture_dtypes(): """ List of supported numpy dtypes. """ - return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split() + return [dtype for dtype in DTYPES_NUMERIC if dtype != np.timedelta64] @pytest.mark.benchmark diff --git a/pygmt/tests/test_clib_put_vector.py b/pygmt/tests/test_clib_put_vector.py index d42ee9afbba..7e1dd0d9310 100644 --- a/pygmt/tests/test_clib_put_vector.py +++ b/pygmt/tests/test_clib_put_vector.py @@ -9,6 +9,7 @@ import numpy.testing as npt import pytest from pygmt import clib +from pygmt.clib.session import DTYPES_NUMERIC from pygmt.exceptions import GMTCLibError, GMTInvalidInput from pygmt.helpers import GMTTempFile @@ -18,7 +19,7 @@ def fixture_dtypes(): """ List of supported numpy dtypes. """ - return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split() + return [dtype for dtype in DTYPES_NUMERIC if dtype != np.timedelta64] @pytest.mark.benchmark diff --git a/pygmt/tests/test_clib_virtualfile_from_matrix.py b/pygmt/tests/test_clib_virtualfile_from_matrix.py index 6d1a42fa14b..6d4c855ac05 100644 --- a/pygmt/tests/test_clib_virtualfile_from_matrix.py +++ b/pygmt/tests/test_clib_virtualfile_from_matrix.py @@ -5,6 +5,7 @@ import numpy as np import pytest from pygmt import clib +from pygmt.clib.session import DTYPES_NUMERIC from pygmt.helpers import GMTTempFile @@ -13,7 +14,7 @@ def fixture_dtypes(): """ List of supported numpy dtypes. """ - return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split() + return [dtype for dtype in DTYPES_NUMERIC if dtype != np.timedelta64] @pytest.mark.benchmark diff --git a/pygmt/tests/test_clib_virtualfile_from_vectors.py b/pygmt/tests/test_clib_virtualfile_from_vectors.py index 18cb4053f37..041bc7a803c 100644 --- a/pygmt/tests/test_clib_virtualfile_from_vectors.py +++ b/pygmt/tests/test_clib_virtualfile_from_vectors.py @@ -8,6 +8,7 @@ import pandas as pd import pytest from pygmt import clib +from pygmt.clib.session import DTYPES_NUMERIC from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import GMTTempFile @@ -17,7 +18,7 @@ def fixture_dtypes(): """ List of supported numpy dtypes. """ - return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split() + return [dtype for dtype in DTYPES_NUMERIC if dtype != np.timedelta64] @pytest.fixture(scope="module", name="dtypes_pandas") @@ -26,10 +27,8 @@ def fixture_dtypes_pandas(dtypes): List of supported pandas dtypes. """ dtypes_pandas = dtypes.copy() - if find_spec("pyarrow") is not None: - dtypes_pandas.extend([f"{dtype}[pyarrow]" for dtype in dtypes_pandas]) - + dtypes_pandas.extend([f"{np.dtype(dtype).name}[pyarrow]" for dtype in dtypes]) return tuple(dtypes_pandas) diff --git a/pygmt/tests/test_clib_virtualfiles.py b/pygmt/tests/test_clib_virtualfiles.py index a05790894d0..a45a662de71 100644 --- a/pygmt/tests/test_clib_virtualfiles.py +++ b/pygmt/tests/test_clib_virtualfiles.py @@ -2,12 +2,12 @@ Test the Session.open_virtualfile method. """ -from importlib.util import find_spec from pathlib import Path import numpy as np import pytest from pygmt import clib +from pygmt.clib.session import DTYPES_NUMERIC from pygmt.exceptions import GMTCLibError, GMTInvalidInput from pygmt.helpers import GMTTempFile from pygmt.tests.test_clib import mock @@ -28,20 +28,7 @@ def fixture_dtypes(): """ List of supported numpy dtypes. """ - return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split() - - -@pytest.fixture(scope="module", name="dtypes_pandas") -def fixture_dtypes_pandas(dtypes): - """ - List of supported pandas dtypes. - """ - dtypes_pandas = dtypes.copy() - - if find_spec("pyarrow") is not None: - dtypes_pandas.extend([f"{dtype}[pyarrow]" for dtype in dtypes_pandas]) - - return tuple(dtypes_pandas) + return [dtype for dtype in DTYPES_NUMERIC if dtype != np.timedelta64] @pytest.mark.benchmark diff --git a/pygmt/tests/test_grdimage_image.py b/pygmt/tests/test_grdimage_image.py index e2ec8980d18..5314c3882a2 100644 --- a/pygmt/tests/test_grdimage_image.py +++ b/pygmt/tests/test_grdimage_image.py @@ -2,8 +2,10 @@ Test Figure.grdimage on 3-band RGB images. """ +import numpy as np import pytest from pygmt import Figure +from pygmt.clib.session import DTYPES_NUMERIC from pygmt.datasets import load_blue_marble rioxarray = pytest.importorskip("rioxarray") @@ -43,7 +45,7 @@ def test_grdimage_image_dataarray(xr_image): @pytest.mark.parametrize( "dtype", - ["int8", "uint16", "int16", "uint32", "int32", "float32", "float64"], + [dtype for dtype in DTYPES_NUMERIC if dtype not in {np.uint8, np.timedelta64}], ) def test_grdimage_image_dataarray_unsupported_dtype(dtype, xr_image): """