Skip to content

Commit

Permalink
Replace the list of support numpy dtypes with DTYPES_NUMERIC in tests (
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman authored Nov 4, 2024
1 parent 82b0c73 commit a0e4226
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 23 deletions.
3 changes: 2 additions & 1 deletion pygmt/tests/test_clib_put_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import xarray as xr
from pygmt import clib
from pygmt.clib.session import DTYPES_NUMERIC
from pygmt.exceptions import GMTCLibError
from pygmt.helpers import GMTTempFile
from pygmt.tests.test_clib import mock
Expand All @@ -17,7 +18,7 @@ def fixture_dtypes():
"""
List of supported numpy dtypes.
"""
return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split()
return [dtype for dtype in DTYPES_NUMERIC if dtype != np.timedelta64]


@pytest.mark.benchmark
Expand Down
3 changes: 2 additions & 1 deletion pygmt/tests/test_clib_put_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy.testing as npt
import pytest
from pygmt import clib
from pygmt.clib.session import DTYPES_NUMERIC
from pygmt.exceptions import GMTCLibError, GMTInvalidInput
from pygmt.helpers import GMTTempFile

Expand All @@ -18,7 +19,7 @@ def fixture_dtypes():
"""
List of supported numpy dtypes.
"""
return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split()
return [dtype for dtype in DTYPES_NUMERIC if dtype != np.timedelta64]


@pytest.mark.benchmark
Expand Down
3 changes: 2 additions & 1 deletion pygmt/tests/test_clib_virtualfile_from_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pytest
from pygmt import clib
from pygmt.clib.session import DTYPES_NUMERIC
from pygmt.helpers import GMTTempFile


Expand All @@ -13,7 +14,7 @@ def fixture_dtypes():
"""
List of supported numpy dtypes.
"""
return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split()
return [dtype for dtype in DTYPES_NUMERIC if dtype != np.timedelta64]


@pytest.mark.benchmark
Expand Down
7 changes: 3 additions & 4 deletions pygmt/tests/test_clib_virtualfile_from_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import pytest
from pygmt import clib
from pygmt.clib.session import DTYPES_NUMERIC
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import GMTTempFile

Expand All @@ -17,7 +18,7 @@ def fixture_dtypes():
"""
List of supported numpy dtypes.
"""
return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split()
return [dtype for dtype in DTYPES_NUMERIC if dtype != np.timedelta64]


@pytest.fixture(scope="module", name="dtypes_pandas")
Expand All @@ -26,10 +27,8 @@ def fixture_dtypes_pandas(dtypes):
List of supported pandas dtypes.
"""
dtypes_pandas = dtypes.copy()

if find_spec("pyarrow") is not None:
dtypes_pandas.extend([f"{dtype}[pyarrow]" for dtype in dtypes_pandas])

dtypes_pandas.extend([f"{np.dtype(dtype).name}[pyarrow]" for dtype in dtypes])
return tuple(dtypes_pandas)


Expand Down
17 changes: 2 additions & 15 deletions pygmt/tests/test_clib_virtualfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
Test the Session.open_virtualfile method.
"""

from importlib.util import find_spec
from pathlib import Path

import numpy as np
import pytest
from pygmt import clib
from pygmt.clib.session import DTYPES_NUMERIC
from pygmt.exceptions import GMTCLibError, GMTInvalidInput
from pygmt.helpers import GMTTempFile
from pygmt.tests.test_clib import mock
Expand All @@ -28,20 +28,7 @@ def fixture_dtypes():
"""
List of supported numpy dtypes.
"""
return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split()


@pytest.fixture(scope="module", name="dtypes_pandas")
def fixture_dtypes_pandas(dtypes):
"""
List of supported pandas dtypes.
"""
dtypes_pandas = dtypes.copy()

if find_spec("pyarrow") is not None:
dtypes_pandas.extend([f"{dtype}[pyarrow]" for dtype in dtypes_pandas])

return tuple(dtypes_pandas)
return [dtype for dtype in DTYPES_NUMERIC if dtype != np.timedelta64]


@pytest.mark.benchmark
Expand Down
4 changes: 3 additions & 1 deletion pygmt/tests/test_grdimage_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
Test Figure.grdimage on 3-band RGB images.
"""

import numpy as np
import pytest
from pygmt import Figure
from pygmt.clib.session import DTYPES_NUMERIC
from pygmt.datasets import load_blue_marble

rioxarray = pytest.importorskip("rioxarray")
Expand Down Expand Up @@ -43,7 +45,7 @@ def test_grdimage_image_dataarray(xr_image):

@pytest.mark.parametrize(
"dtype",
["int8", "uint16", "int16", "uint32", "int32", "float32", "float64"],
[dtype for dtype in DTYPES_NUMERIC if dtype not in {np.uint8, np.timedelta64}],
)
def test_grdimage_image_dataarray_unsupported_dtype(dtype, xr_image):
"""
Expand Down

0 comments on commit a0e4226

Please sign in to comment.