Skip to content

Commit

Permalink
ENH: Add ArrowDype and .array.ArrowExtensionArray to top level (panda…
Browse files Browse the repository at this point in the history
…s-dev#47818)

* ENH: Add ArrowDype and .array.ArrowExtensionDtype to top level

* ensure string[pyarrow] dispatches to StringDtype for now

* type ignores

* Address availability of Pyarrow

* Address typing
  • Loading branch information
mroeschke authored and noatamir committed Nov 9, 2022
1 parent efd906d commit 01259af
Show file tree
Hide file tree
Showing 14 changed files with 149 additions and 127 deletions.
2 changes: 2 additions & 0 deletions pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

from pandas.core.api import (
# dtype
ArrowDtype,
Int8Dtype,
Int16Dtype,
Int32Dtype,
Expand Down Expand Up @@ -308,6 +309,7 @@ def __getattr__(name):
# Pandas is not (yet) a py.typed library: the public API is determined
# based on the documentation.
__all__ = [
"ArrowDtype",
"BooleanDtype",
"Categorical",
"CategoricalDtype",
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
value_counts,
)
from pandas.core.arrays import Categorical
from pandas.core.arrays.arrow import ArrowDtype
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.floating import (
Float32Dtype,
Expand Down Expand Up @@ -85,6 +86,7 @@

__all__ = [
"array",
"ArrowDtype",
"bdate_range",
"BooleanDtype",
"Categorical",
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pandas.core.arrays.arrow import ArrowExtensionArray
from pandas.core.arrays.base import (
ExtensionArray,
ExtensionOpsMixin,
Expand All @@ -21,6 +22,7 @@
from pandas.core.arrays.timedeltas import TimedeltaArray

__all__ = [
"ArrowExtensionArray",
"ExtensionArray",
"ExtensionOpsMixin",
"ExtensionScalarOpsMixin",
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/arrow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pandas.core.arrays.arrow.array import ArrowExtensionArray
from pandas.core.arrays.arrow.dtype import ArrowDtype

__all__ = ["ArrowExtensionArray"]
__all__ = ["ArrowDtype", "ArrowExtensionArray"]
111 changes: 0 additions & 111 deletions pandas/core/arrays/arrow/_arrow_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
from __future__ import annotations

import inspect
import json
import warnings

import numpy as np
import pyarrow

from pandas._typing import IntervalInclusiveType
from pandas.errors import PerformanceWarning
from pandas.util._decorators import deprecate_kwarg
from pandas.util._exceptions import find_stack_level

from pandas.core.arrays.interval import VALID_INCLUSIVE


def fallback_performancewarning(version: str | None = None) -> None:
"""
Expand Down Expand Up @@ -67,109 +62,3 @@ def pyarrow_array_to_numpy_and_mask(
else:
mask = np.ones(len(arr), dtype=bool)
return data, mask


class ArrowPeriodType(pyarrow.ExtensionType):
def __init__(self, freq) -> None:
# attributes need to be set first before calling
# super init (as that calls serialize)
self._freq = freq
pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period")

@property
def freq(self):
return self._freq

def __arrow_ext_serialize__(self) -> bytes:
metadata = {"freq": self.freq}
return json.dumps(metadata).encode()

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowPeriodType:
metadata = json.loads(serialized.decode())
return ArrowPeriodType(metadata["freq"])

def __eq__(self, other):
if isinstance(other, pyarrow.BaseExtensionType):
return type(self) == type(other) and self.freq == other.freq
else:
return NotImplemented

def __hash__(self) -> int:
return hash((str(self), self.freq))

def to_pandas_dtype(self):
import pandas as pd

return pd.PeriodDtype(freq=self.freq)


# register the type with a dummy instance
_period_type = ArrowPeriodType("D")
pyarrow.register_extension_type(_period_type)


class ArrowIntervalType(pyarrow.ExtensionType):
@deprecate_kwarg(old_arg_name="closed", new_arg_name="inclusive")
def __init__(self, subtype, inclusive: IntervalInclusiveType) -> None:
# attributes need to be set first before calling
# super init (as that calls serialize)
assert inclusive in VALID_INCLUSIVE
self._inclusive: IntervalInclusiveType = inclusive
if not isinstance(subtype, pyarrow.DataType):
subtype = pyarrow.type_for_alias(str(subtype))
self._subtype = subtype

storage_type = pyarrow.struct([("left", subtype), ("right", subtype)])
pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval")

@property
def subtype(self):
return self._subtype

@property
def inclusive(self) -> IntervalInclusiveType:
return self._inclusive

@property
def closed(self) -> IntervalInclusiveType:
warnings.warn(
"Attribute `closed` is deprecated in favor of `inclusive`.",
FutureWarning,
stacklevel=find_stack_level(inspect.currentframe()),
)
return self._inclusive

def __arrow_ext_serialize__(self) -> bytes:
metadata = {"subtype": str(self.subtype), "inclusive": self.inclusive}
return json.dumps(metadata).encode()

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowIntervalType:
metadata = json.loads(serialized.decode())
subtype = pyarrow.type_for_alias(metadata["subtype"])
inclusive = metadata["inclusive"]
return ArrowIntervalType(subtype, inclusive)

def __eq__(self, other):
if isinstance(other, pyarrow.BaseExtensionType):
return (
type(self) == type(other)
and self.subtype == other.subtype
and self.inclusive == other.inclusive
)
else:
return NotImplemented

def __hash__(self) -> int:
return hash((str(self), str(self.subtype), self.inclusive))

def to_pandas_dtype(self):
import pandas as pd

return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.inclusive)


# register the type with a dummy instance
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
pyarrow.register_extension_type(_interval_type)
10 changes: 9 additions & 1 deletion pandas/core/arrays/arrow/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@
import re

import numpy as np
import pyarrow as pa

from pandas._typing import DtypeObj
from pandas.compat import pa_version_under1p01
from pandas.util._decorators import cache_readonly

from pandas.core.dtypes.base import (
StorageExtensionDtype,
register_extension_dtype,
)

if not pa_version_under1p01:
import pyarrow as pa


@register_extension_dtype
class ArrowDtype(StorageExtensionDtype):
Expand All @@ -25,6 +28,8 @@ class ArrowDtype(StorageExtensionDtype):

def __init__(self, pyarrow_dtype: pa.DataType) -> None:
super().__init__("pyarrow")
if pa_version_under1p01:
raise ImportError("pyarrow>=1.0.1 is required for ArrowDtype")
if not isinstance(pyarrow_dtype, pa.DataType):
raise ValueError(
f"pyarrow_dtype ({pyarrow_dtype}) must be an instance "
Expand Down Expand Up @@ -93,6 +98,9 @@ def construct_from_string(cls, string: str) -> ArrowDtype:
)
if not string.endswith("[pyarrow]"):
raise TypeError(f"'{string}' must end with '[pyarrow]'")
if string == "string[pyarrow]":
# Ensure Registry.find skips ArrowDtype to use StringDtype instead
raise TypeError("string[pyarrow] should be constructed by StringDtype")
base_type = string.split("[pyarrow]")[0]
try:
pa_dtype = pa.type_for_alias(base_type)
Expand Down
118 changes: 118 additions & 0 deletions pandas/core/arrays/arrow/extension_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

import json
import warnings

import pyarrow

from pandas._typing import IntervalInclusiveType
from pandas.util._decorators import deprecate_kwarg
from pandas.util._exceptions import find_stack_level

from pandas.core.arrays.interval import VALID_INCLUSIVE


class ArrowPeriodType(pyarrow.ExtensionType):
def __init__(self, freq) -> None:
# attributes need to be set first before calling
# super init (as that calls serialize)
self._freq = freq
pyarrow.ExtensionType.__init__(self, pyarrow.int64(), "pandas.period")

@property
def freq(self):
return self._freq

def __arrow_ext_serialize__(self) -> bytes:
metadata = {"freq": self.freq}
return json.dumps(metadata).encode()

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowPeriodType:
metadata = json.loads(serialized.decode())
return ArrowPeriodType(metadata["freq"])

def __eq__(self, other):
if isinstance(other, pyarrow.BaseExtensionType):
return type(self) == type(other) and self.freq == other.freq
else:
return NotImplemented

def __hash__(self) -> int:
return hash((str(self), self.freq))

def to_pandas_dtype(self):
import pandas as pd

return pd.PeriodDtype(freq=self.freq)


# register the type with a dummy instance
_period_type = ArrowPeriodType("D")
pyarrow.register_extension_type(_period_type)


class ArrowIntervalType(pyarrow.ExtensionType):
@deprecate_kwarg(old_arg_name="closed", new_arg_name="inclusive")
def __init__(self, subtype, inclusive: IntervalInclusiveType) -> None:
# attributes need to be set first before calling
# super init (as that calls serialize)
assert inclusive in VALID_INCLUSIVE
self._inclusive: IntervalInclusiveType = inclusive
if not isinstance(subtype, pyarrow.DataType):
subtype = pyarrow.type_for_alias(str(subtype))
self._subtype = subtype

storage_type = pyarrow.struct([("left", subtype), ("right", subtype)])
pyarrow.ExtensionType.__init__(self, storage_type, "pandas.interval")

@property
def subtype(self):
return self._subtype

@property
def inclusive(self) -> IntervalInclusiveType:
return self._inclusive

@property
def closed(self) -> IntervalInclusiveType:
warnings.warn(
"Attribute `closed` is deprecated in favor of `inclusive`.",
FutureWarning,
stacklevel=find_stack_level(),
)
return self._inclusive

def __arrow_ext_serialize__(self) -> bytes:
metadata = {"subtype": str(self.subtype), "inclusive": self.inclusive}
return json.dumps(metadata).encode()

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized) -> ArrowIntervalType:
metadata = json.loads(serialized.decode())
subtype = pyarrow.type_for_alias(metadata["subtype"])
inclusive = metadata["inclusive"]
return ArrowIntervalType(subtype, inclusive)

def __eq__(self, other):
if isinstance(other, pyarrow.BaseExtensionType):
return (
type(self) == type(other)
and self.subtype == other.subtype
and self.inclusive == other.inclusive
)
else:
return NotImplemented

def __hash__(self) -> int:
return hash((str(self), str(self.subtype), self.inclusive))

def to_pandas_dtype(self):
import pandas as pd

return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.inclusive)


# register the type with a dummy instance
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
pyarrow.register_extension_type(_interval_type)
2 changes: 1 addition & 1 deletion pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,7 +1554,7 @@ def __arrow_array__(self, type=None):
"""
import pyarrow

from pandas.core.arrays.arrow._arrow_utils import ArrowIntervalType
from pandas.core.arrays.arrow.extension_types import ArrowIntervalType

try:
subtype = pyarrow.from_numpy_dtype(self.dtype.subtype)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def __arrow_array__(self, type=None):
"""
import pyarrow

from pandas.core.arrays.arrow._arrow_utils import ArrowPeriodType
from pandas.core.arrays.arrow.extension_types import ArrowPeriodType

if type is not None:
if pyarrow.types.is_integer(type):
Expand Down
1 change: 0 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMi

def __init__(self, values) -> None:
super().__init__(values)
# TODO: Migrate to ArrowDtype instead
self._dtype = StringDtype(storage="pyarrow")

if not pa.types.is_string(self._data.type):
Expand Down
2 changes: 1 addition & 1 deletion pandas/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(self) -> None:
import pyarrow.parquet

# import utils to register the pyarrow extension types
import pandas.core.arrays.arrow._arrow_utils # pyright: ignore # noqa:F401
import pandas.core.arrays.arrow.extension_types # pyright: ignore # noqa:F401

self.api = pyarrow

Expand Down
1 change: 1 addition & 0 deletions pandas/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class TestPDApi(Base):

# top-level classes
classes = [
"ArrowDtype",
"Categorical",
"CategoricalIndex",
"DataFrame",
Expand Down
Loading

0 comments on commit 01259af

Please sign in to comment.