Skip to content

Commit

Permalink
refactor: add more equivalents to pyarrow dtypes
Browse files Browse the repository at this point in the history
Signed-off-by: Ajith Aravind <[email protected]>
  • Loading branch information
aaravind100 committed May 10, 2024
1 parent f51e0af commit f06ef84
Showing 1 changed file with 77 additions and 11 deletions.
88 changes: 77 additions & 11 deletions pandera/engines/pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,38 +1591,68 @@ def __str__(self) -> str:

if PYARROW_INSTALLED and PANDAS_2_0_0_PLUS:

@Engine.register_dtype(equivalents=["bool[pyarrow]", pyarrow.bool_])
@Engine.register_dtype(
equivalents=[
"bool[pyarrow]",
pyarrow.bool_,
pd.ArrowDtype(pyarrow.bool_()),
]
)
@immutable
class ArrowBool(BOOL):
"""Semantic representation of a :class:`pyarrow.bool_`."""

type = pd.ArrowDtype(pyarrow.bool_())

@Engine.register_dtype(equivalents=["int64[pyarrow]", pyarrow.int64])
@Engine.register_dtype(
equivalents=[
"int64[pyarrow]",
pyarrow.int64,
pd.ArrowDtype(pyarrow.int64()),
]
)
@immutable
class ArrowInt64(DataType, dtypes.Int):
"""Semantic representation of a :class:`pyarrow.int64`."""

type = pd.ArrowDtype(pyarrow.int64())
bit_width: int = 64

@Engine.register_dtype(equivalents=["int32[pyarrow]", pyarrow.int32])
@Engine.register_dtype(
equivalents=[
"int32[pyarrow]",
pyarrow.int32,
pd.ArrowDtype(pyarrow.int32()),
]
)
@immutable
class ArrowInt32(ArrowInt64):
"""Semantic representation of a :class:`pyarrow.int32`."""

type = pd.ArrowDtype(pyarrow.int32())
bit_width: int = 32

@Engine.register_dtype(equivalents=["int16[pyarrow]", pyarrow.int16])
@Engine.register_dtype(
equivalents=[
"int16[pyarrow]",
pyarrow.int16,
pd.ArrowDtype(pyarrow.int16()),
]
)
@immutable
class ArrowInt16(ArrowInt32):
"""Semantic representation of a :class:`pyarrow.int16`."""

type = pd.ArrowDtype(pyarrow.int16())
bit_width: int = 16

@Engine.register_dtype(equivalents=["int8[pyarrow]", pyarrow.int8])
@Engine.register_dtype(
equivalents=[
"int8[pyarrow]",
pyarrow.int8,
pd.ArrowDtype(pyarrow.int8()),
]
)
@immutable
class ArrowInt8(ArrowInt16):
"""Semantic representation of a :class:`pyarrow.int8`."""
Expand All @@ -1637,47 +1667,83 @@ class ArrowString(DataType, dtypes.String):

type = pd.ArrowDtype(pyarrow.string())

@Engine.register_dtype(equivalents=["uint64[pyarrow]", pyarrow.uint64])
@Engine.register_dtype(
equivalents=[
"uint64[pyarrow]",
pyarrow.uint64,
pd.ArrowDtype(pyarrow.uint64()),
]
)
@immutable
class ArrowUInt64(DataType, dtypes.UInt):
"""Semantic representation of a :class:`pyarrow.uint64`."""

type = pd.ArrowDtype(pyarrow.uint64())
bit_width: int = 64

@Engine.register_dtype(equivalents=["uint32[pyarrow]", pyarrow.uint32])
@Engine.register_dtype(
equivalents=[
"uint32[pyarrow]",
pyarrow.uint32,
pd.ArrowDtype(pyarrow.uint32()),
]
)
@immutable
class ArrowUInt32(ArrowUInt64):
"""Semantic representation of a :class:`pyarrow.uint32`."""

type = pd.ArrowDtype(pyarrow.uint32())
bit_width: int = 32

@Engine.register_dtype(equivalents=["uint16[pyarrow]", pyarrow.uint16])
@Engine.register_dtype(
equivalents=[
"uint16[pyarrow]",
pyarrow.uint16,
pd.ArrowDtype(pyarrow.uint16()),
]
)
@immutable
class ArrowUInt16(ArrowUInt32):
"""Semantic representation of a :class:`pyarrow.uint16`."""

type = pd.ArrowDtype(pyarrow.uint16())
bit_width: int = 16

@Engine.register_dtype(equivalents=["uint8[pyarrow]", pyarrow.uint8])
@Engine.register_dtype(
equivalents=[
"uint8[pyarrow]",
pyarrow.uint8,
pd.ArrowDtype(pyarrow.uint8()),
]
)
@immutable
class ArrowUInt8(ArrowUInt16):
"""Semantic representation of a :class:`pyarrow.uint8`."""

type = pd.ArrowDtype(pyarrow.uint8())
bit_width: int = 8

@Engine.register_dtype(equivalents=["double[pyarrow]", pyarrow.float64])
@Engine.register_dtype(
equivalents=[
"double[pyarrow]",
pyarrow.float64,
pd.ArrowDtype(pyarrow.float64()),
]
)
@immutable
class ArrowFloat64(DataType, dtypes.Float):
"""Semantic representation of a :class:`pyarrow.float64`."""

type = pd.ArrowDtype(pyarrow.float64())
bit_width: int = 64

@Engine.register_dtype(equivalents=["float[pyarrow]", pyarrow.float32])
@Engine.register_dtype(
equivalents=[
"float[pyarrow]",
pyarrow.float32,
pd.ArrowDtype(pyarrow.float32()),
]
)
@immutable
class ArrowFloat32(ArrowFloat64):
"""Semantic representation of a :class:`pyarrow.float32`."""
Expand Down

0 comments on commit f06ef84

Please sign in to comment.