diff --git a/pandera/__init__.py b/pandera/__init__.py index 8f786ad65..029484d91 100644 --- a/pandera/__init__.py +++ b/pandera/__init__.py @@ -62,6 +62,7 @@ INT64, PANDAS_1_2_0_PLUS, PANDAS_1_3_0_PLUS, + PANDAS_2_0_0_PLUS, STRING, UINT8, UINT16, @@ -136,7 +137,9 @@ "INT16", "INT32", "INT64", + "PANDAS_1_2_0_PLUS", "PANDAS_1_3_0_PLUS", + "PANDAS_2_0_0_PLUS", "STRING", "UINT8", "UINT16", diff --git a/pandera/engines/pandas_engine.py b/pandera/engines/pandas_engine.py index e9473ae31..ed2db9edd 100644 --- a/pandera/engines/pandas_engine.py +++ b/pandera/engines/pandas_engine.py @@ -75,6 +75,7 @@ PANDAS_1_2_0_PLUS = pandas_version().release >= (1, 2, 0) PANDAS_1_3_0_PLUS = pandas_version().release >= (1, 3, 0) +PANDAS_2_0_0_PLUS = pandas_version().release >= (2, 0, 0) # register different TypedDict type depending on python version @@ -101,6 +102,16 @@ def is_extension_dtype( ) +def is_pyarrow_dtype( + pd_dtype: PandasDataType, +) -> Union[bool, Iterable[bool]]: + """Check if a value is a pandas pyarrow type or instance of one.""" + if not PYARROW_INSTALLED: + raise TypeError("pyarrow must be installed to use pyarrow dtypes.") + + return isinstance(pd_dtype, pd.ArrowDtype) + + @immutable(init=True) class DataType(dtypes.DataType): """Base `DataType` for boxing Pandas data types.""" @@ -220,6 +231,8 @@ def dtype(cls, data_type: Any) -> dtypes.DataType: "Usage Tip: Use an instance or a string " "representation." ) from None + elif is_pyarrow_dtype(data_type): + np_or_pd_dtype = data_type.pyarrow_dtype else: # let pandas transform any acceptable value # into a numpy or pandas dtype. @@ -1570,3 +1583,251 @@ def __init__( # pylint:disable=super-init-not-called def __str__(self) -> str: return str(NamedTuple.__name__) + + +############################################################################### +# pyarrow types +############################################################################### + +if PYARROW_INSTALLED and PANDAS_2_0_0_PLUS: + + @Engine.register_dtype( + equivalents=[ + "bool[pyarrow]", + pyarrow.bool_, + pd.ArrowDtype(pyarrow.bool_()), + ] + ) + @immutable + class ArrowBool(BOOL): + """Semantic representation of a :class:`pyarrow.bool_`.""" + + type = pd.ArrowDtype(pyarrow.bool_()) + + @Engine.register_dtype( + equivalents=[ + "int64[pyarrow]", + pyarrow.int64, + pd.ArrowDtype(pyarrow.int64()), + ] + ) + @immutable + class ArrowInt64(DataType, dtypes.Int): + """Semantic representation of a :class:`pyarrow.int64`.""" + + type = pd.ArrowDtype(pyarrow.int64()) + bit_width: int = 64 + + @Engine.register_dtype( + equivalents=[ + "int32[pyarrow]", + pyarrow.int32, + pd.ArrowDtype(pyarrow.int32()), + ] + ) + @immutable + class ArrowInt32(ArrowInt64): + """Semantic representation of a :class:`pyarrow.int32`.""" + + type = pd.ArrowDtype(pyarrow.int32()) + bit_width: int = 32 + + @Engine.register_dtype( + equivalents=[ + "int16[pyarrow]", + pyarrow.int16, + pd.ArrowDtype(pyarrow.int16()), + ] + ) + @immutable + class ArrowInt16(ArrowInt32): + """Semantic representation of a :class:`pyarrow.int16`.""" + + type = pd.ArrowDtype(pyarrow.int16()) + bit_width: int = 16 + + @Engine.register_dtype( + equivalents=[ + "int8[pyarrow]", + pyarrow.int8, + pd.ArrowDtype(pyarrow.int8()), + ] + ) + @immutable + class ArrowInt8(ArrowInt16): + """Semantic representation of a :class:`pyarrow.int8`.""" + + type = pd.ArrowDtype(pyarrow.int8()) + bit_width: int = 8 + + @Engine.register_dtype(equivalents=[pyarrow.string]) + @immutable + class ArrowString(DataType, dtypes.String): + """Semantic representation of a :class:`pyarrow.string`.""" + + type = pd.ArrowDtype(pyarrow.string()) + + @Engine.register_dtype( + equivalents=[ + "uint64[pyarrow]", + pyarrow.uint64, + pd.ArrowDtype(pyarrow.uint64()), + ] + ) + @immutable + class ArrowUInt64(DataType, dtypes.UInt): + """Semantic representation of a :class:`pyarrow.uint64`.""" + + type = pd.ArrowDtype(pyarrow.uint64()) + bit_width: int = 64 + + @Engine.register_dtype( + equivalents=[ + "uint32[pyarrow]", + pyarrow.uint32, + pd.ArrowDtype(pyarrow.uint32()), + ] + ) + @immutable + class ArrowUInt32(ArrowUInt64): + """Semantic representation of a :class:`pyarrow.uint32`.""" + + type = pd.ArrowDtype(pyarrow.uint32()) + bit_width: int = 32 + + @Engine.register_dtype( + equivalents=[ + "uint16[pyarrow]", + pyarrow.uint16, + pd.ArrowDtype(pyarrow.uint16()), + ] + ) + @immutable + class ArrowUInt16(ArrowUInt32): + """Semantic representation of a :class:`pyarrow.uint16`.""" + + type = pd.ArrowDtype(pyarrow.uint16()) + bit_width: int = 16 + + @Engine.register_dtype( + equivalents=[ + "uint8[pyarrow]", + pyarrow.uint8, + pd.ArrowDtype(pyarrow.uint8()), + ] + ) + @immutable + class ArrowUInt8(ArrowUInt16): + """Semantic representation of a :class:`pyarrow.uint8`.""" + + type = pd.ArrowDtype(pyarrow.uint8()) + bit_width: int = 8 + + @Engine.register_dtype( + equivalents=[ + "double[pyarrow]", + pyarrow.float64, + pd.ArrowDtype(pyarrow.float64()), + ] + ) + @immutable + class ArrowFloat64(DataType, dtypes.Float): + """Semantic representation of a :class:`pyarrow.float64`.""" + + type = pd.ArrowDtype(pyarrow.float64()) + bit_width: int = 64 + + @Engine.register_dtype( + equivalents=[ + "float[pyarrow]", + pyarrow.float32, + pd.ArrowDtype(pyarrow.float32()), + ] + ) + @immutable + class ArrowFloat32(ArrowFloat64): + """Semantic representation of a :class:`pyarrow.float32`.""" + + type = pd.ArrowDtype(pyarrow.float32()) + bit_width: int = 32 + + @Engine.register_dtype( + equivalents=[pyarrow.decimal128, pyarrow.Decimal128Type] + ) + @immutable(init=True) + class ArrowDecimal128(DataType, dtypes.Decimal): + """Semantic representation of a :class:`pyarrow.decimal128`.""" + + type: Optional[pd.ArrowDtype] = dataclasses.field( + default=None, init=False + ) + precision: int = 28 + scale: int = 0 + + def __post_init__(self) -> None: + type_ = pd.ArrowDtype( + pyarrow.decimal128(self.precision, self.scale) + ) + object.__setattr__(self, "type", type_) + + @classmethod + def from_parametrized_dtype( + cls, + pyarrow_dtype: pyarrow.Decimal128Type, + ): + return cls(precision=pyarrow_dtype.precision, scale=pyarrow_dtype.scale) # type: ignore + + @Engine.register_dtype( + equivalents=[pyarrow.timestamp, pyarrow.TimestampType] + ) + @immutable(init=True) + class ArrowTimestamp(DataType, dtypes.Timestamp): + """Semantic representation of a :class:`pyarrow.timestamp`.""" + + type: Optional[pd.ArrowDtype] = dataclasses.field( + default=None, init=False + ) + unit: Optional[str] = "ns" + tz: Optional[datetime.tzinfo] = None + + def __post_init__(self): + type_ = pd.ArrowDtype(pyarrow.timestamp(self.unit, self.tz)) + object.__setattr__(self, "type", type_) + + @classmethod + def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.TimestampType): + return cls(unit=pyarrow_dtype.unit, tz=pyarrow_dtype.tz) # type: ignore + + @Engine.register_dtype( + equivalents=[pyarrow.dictionary, pyarrow.DictionaryType] + ) + @immutable(init=True) + class ArrowDictionary(DataType, dtypes.Category): + """Semantic representation of a :class:`pyarrow.dictionary`.""" + + type: Optional[pd.ArrowDtype] = dataclasses.field( + default=None, init=False + ) + index_type: Optional[pyarrow.DataType] = pyarrow.int64() + value_type: Optional[pyarrow.DataType] = pyarrow.int64() + ordered: bool = False + + def __post_init__(self): + type_ = pd.ArrowDtype( + pyarrow.dictionary( + self.index_type, + self.value_type, + self.ordered, + ) + ) + object.__setattr__(self, "type", type_) + + @classmethod + def from_parametrized_dtype( + cls, pyarrow_dtype: pyarrow.DictionaryType + ): + return cls( + index_type=pyarrow_dtype.index_type, # type: ignore + value_type=pyarrow_dtype.value_type, # type: ignore + ordered=pyarrow_dtype.ordered, # type: ignore + ) diff --git a/tests/core/test_pandas_engine.py b/tests/core/test_pandas_engine.py index 3d6c93aa9..0add7b1c1 100644 --- a/tests/core/test_pandas_engine.py +++ b/tests/core/test_pandas_engine.py @@ -1,6 +1,7 @@ """Test pandas engine.""" from datetime import date +from typing import Any, Set import hypothesis import hypothesis.extra.pandas as pd_st @@ -14,9 +15,20 @@ from pandera.engines import pandas_engine from pandera.errors import ParserError +UNSUPPORTED_DTYPE_CLS: Set[Any] = set() + +# `string[pyarrow]` gets parsed to type `string` by pandas +if pandas_engine.PYARROW_INSTALLED and pandas_engine.PANDAS_2_0_0_PLUS: + UNSUPPORTED_DTYPE_CLS.add(pandas_engine.ArrowString) + @pytest.mark.parametrize( - "data_type", list(pandas_engine.Engine.get_registered_dtypes()) + "data_type", + [ + data_type + for data_type in pandas_engine.Engine.get_registered_dtypes() + if data_type not in UNSUPPORTED_DTYPE_CLS + ], ) def test_pandas_data_type(data_type): """Test numpy engine DataType base class.""" diff --git a/tests/strategies/test_strategies.py b/tests/strategies/test_strategies.py index 47abb5d9e..c42820fd8 100644 --- a/tests/strategies/test_strategies.py +++ b/tests/strategies/test_strategies.py @@ -45,6 +45,28 @@ pandas_engine.PythonNamedTuple, ] ) + +if pandas_engine.PYARROW_INSTALLED and pandas_engine.PANDAS_2_0_0_PLUS: + UNSUPPORTED_DTYPE_CLS.update( + [ + pandas_engine.ArrowBool, + pandas_engine.ArrowDecimal128, + pandas_engine.ArrowDictionary, + pandas_engine.ArrowFloat32, + pandas_engine.ArrowFloat64, + pandas_engine.ArrowInt8, + pandas_engine.ArrowInt16, + pandas_engine.ArrowInt32, + pandas_engine.ArrowInt64, + pandas_engine.ArrowString, + pandas_engine.ArrowTimestamp, + pandas_engine.ArrowUInt8, + pandas_engine.ArrowUInt16, + pandas_engine.ArrowUInt32, + pandas_engine.ArrowUInt64, + ] + ) + SUPPORTED_DTYPES = set() for data_type in pandas_engine.Engine.get_registered_dtypes(): if (