Skip to content

Commit

Permalink
Backport PR #51966 on branch 2.0.x (CoW: __array__ not recognizing ea…
Browse files Browse the repository at this point in the history
… dtypes) (#52358)

Backport PR #51966: CoW: __array__ not recognizing ea dtypes

Co-authored-by: Patrick Hoefler <[email protected]>
  • Loading branch information
meeseeksmachine and phofl authored Apr 2, 2023
1 parent af8be1d commit 7c56e05
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 6 deletions.
16 changes: 12 additions & 4 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
validate_inclusive,
)

from pandas.core.dtypes.astype import astype_is_view
from pandas.core.dtypes.common import (
ensure_object,
ensure_platform_int,
Expand Down Expand Up @@ -1995,10 +1996,17 @@ def empty(self) -> bool_t:
def __array__(self, dtype: npt.DTypeLike | None = None) -> np.ndarray:
values = self._values
arr = np.asarray(values, dtype=dtype)
if arr is values and using_copy_on_write():
# TODO(CoW) also properly handle extension dtypes
arr = arr.view()
arr.flags.writeable = False
if (
astype_is_view(values.dtype, arr.dtype)
and using_copy_on_write()
and self._mgr.is_single_block
):
# Check if both conversions can be done without a copy
if astype_is_view(self.dtypes.iloc[0], values.dtype) and astype_is_view(
values.dtype, arr.dtype
):
arr = arr.view()
arr.flags.writeable = False
return arr

@final
Expand Down
3 changes: 1 addition & 2 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,8 +920,7 @@ def __array__(self, dtype: npt.DTypeLike | None = None) -> np.ndarray:
"""
values = self._values
arr = np.asarray(values, dtype=dtype)
if arr is values and using_copy_on_write():
# TODO(CoW) also properly handle extension dtypes
if using_copy_on_write() and astype_is_view(values.dtype, arr.dtype):
arr = arr.view()
arr.flags.writeable = False
return arr
Expand Down
64 changes: 64 additions & 0 deletions pandas/tests/copy_view/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pandas import (
DataFrame,
Series,
date_range,
)
import pandas._testing as tm
from pandas.tests.copy_view.util import get_array
Expand Down Expand Up @@ -119,3 +120,66 @@ def test_ravel_read_only(using_copy_on_write, order):
if using_copy_on_write:
assert arr.flags.writeable is False
assert np.shares_memory(get_array(ser), arr)


def test_series_array_ea_dtypes(using_copy_on_write):
ser = Series([1, 2, 3], dtype="Int64")
arr = np.asarray(ser, dtype="int64")
assert np.shares_memory(arr, get_array(ser))
if using_copy_on_write:
assert arr.flags.writeable is False
else:
assert arr.flags.writeable is True

arr = np.asarray(ser)
assert not np.shares_memory(arr, get_array(ser))
assert arr.flags.writeable is True


def test_dataframe_array_ea_dtypes(using_copy_on_write):
df = DataFrame({"a": [1, 2, 3]}, dtype="Int64")
arr = np.asarray(df, dtype="int64")
# TODO: This should be able to share memory, but we are roundtripping
# through object
assert not np.shares_memory(arr, get_array(df, "a"))
assert arr.flags.writeable is True

arr = np.asarray(df)
if using_copy_on_write:
# TODO(CoW): This should be True
assert arr.flags.writeable is False
else:
assert arr.flags.writeable is True


def test_dataframe_array_string_dtype(using_copy_on_write, using_array_manager):
df = DataFrame({"a": ["a", "b"]}, dtype="string")
arr = np.asarray(df)
if not using_array_manager:
assert np.shares_memory(arr, get_array(df, "a"))
if using_copy_on_write:
assert arr.flags.writeable is False
else:
assert arr.flags.writeable is True


def test_dataframe_multiple_numpy_dtypes():
df = DataFrame({"a": [1, 2, 3], "b": 1.5})
arr = np.asarray(df)
assert not np.shares_memory(arr, get_array(df, "a"))
assert arr.flags.writeable is True


def test_values_is_ea(using_copy_on_write):
df = DataFrame({"a": date_range("2012-01-01", periods=3)})
arr = np.asarray(df)
if using_copy_on_write:
assert arr.flags.writeable is False
else:
assert arr.flags.writeable is True


def test_empty_dataframe():
df = DataFrame()
arr = np.asarray(df)
assert arr.flags.writeable is True

0 comments on commit 7c56e05

Please sign in to comment.