Skip to content

Commit

Permalink
CoW: Track references in unstack if there is no copy (#57487)
Browse files Browse the repository at this point in the history
* CoW: Track references in unstack if there is no copy

* Update

* Update

* Update

---------

Co-authored-by: Matthew Roeschke <[email protected]>
  • Loading branch information
phofl and mroeschke authored Apr 1, 2024
1 parent 95f911d commit 07f6c4d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
26 changes: 18 additions & 8 deletions pandas/core/reshape/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
factorize,
unique,
)
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.arrays.categorical import factorize_from_iterable
from pandas.core.construction import ensure_wrapped_if_datetimelike
from pandas.core.frame import DataFrame
Expand Down Expand Up @@ -231,20 +232,31 @@ def arange_result(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.bool_]]:
return new_values, mask.any(0)
# TODO: in all tests we have mask.any(0).all(); can we rely on that?

# NOTE(review): this is a flattened diff hunk - indentation and +/- markers
# were lost. The first `def` line below appears to be the removed signature
# and the second its replacement, which receives the caller's object (`obj`)
# instead of its raw values so the result can register references to it.
def get_result(self, values, value_columns, fill_value) -> DataFrame:
def get_result(self, obj, value_columns, fill_value) -> DataFrame:
values = obj._values
# Promote 1-D input to a single-column 2-D array.
if values.ndim == 1:
values = values[:, np.newaxis]

if value_columns is None and values.shape[1] != 1: # pragma: no cover
raise ValueError("must pass column labels for multi-column data")

# NOTE(review): the next two lines look like the old/new versions of one
# statement. The new one stores the output under `new_values`, leaving
# `values` bound to the input so both can be compared further down.
values, _ = self.get_new_values(values, fill_value)
new_values, _ = self.get_new_values(values, fill_value)
columns = self.get_new_columns(value_columns)
index = self.new_index

# NOTE(review): old form (direct `return self.constructor(values, ...)`)
# followed by the new form, which passes copy=False and keeps the frame in
# `result` so references can be added before returning.
return self.constructor(
values, index=index, columns=columns, dtype=values.dtype
result = self.constructor(
new_values, index=index, columns=columns, dtype=new_values.dtype, copy=False
)
# Compare the `.base` of the input and output arrays to decide whether the
# output still refers to the input's memory rather than a fresh copy.
if isinstance(values, np.ndarray):
base, new_base = values.base, new_values.base
elif isinstance(values, NDArrayBackedExtensionArray):
# Same comparison, on the ndarray wrapped by the extension array.
base, new_base = values._ndarray.base, new_values._ndarray.base
else:
# Two distinct sentinels: `base is new_base` is false, so no references
# are added for any other array type.
base, new_base = 1, 2 # type: ignore[assignment]
# NOTE(review): `.base` is None for an array that owns its data, so this
# check is also true when both bases are None - confirm that case is
# unreachable here or that the extra reference is harmless.
if base is new_base:
# We can only get here if one of the dimensions is size 1
result._mgr.add_references(obj._mgr)
return result

def get_new_values(self, values, fill_value=None):
if values.ndim == 1:
Expand Down Expand Up @@ -532,9 +544,7 @@ def unstack(
unstacker = _Unstacker(
obj.index, level=level, constructor=obj._constructor_expanddim, sort=sort
)
return unstacker.get_result(
obj._values, value_columns=None, fill_value=fill_value
)
return unstacker.get_result(obj, value_columns=None, fill_value=fill_value)


def _unstack_frame(
Expand All @@ -550,7 +560,7 @@ def _unstack_frame(
return obj._constructor_from_mgr(mgr, axes=mgr.axes)
else:
return unstacker.get_result(
obj._values, value_columns=obj.columns, fill_value=fill_value
obj, value_columns=obj.columns, fill_value=fill_value
)


Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/reshape/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2703,3 +2703,16 @@ def test_pivot_table_with_margins_and_numeric_column_names(self):
index=Index(["a", "b", "All"], name=0),
)
tm.assert_frame_equal(result, expected)

# Regression test for unstack under Copy-on-Write (see the commit title):
# the result may share memory with the source frame only when m == 1, and
# writing to the result must never change the source frame.
@pytest.mark.parametrize("m", [1, 10])
def test_unstack_shares_memory(self, m):
# GH#56633
levels = np.arange(m)
# Two-level index built from the product of `levels` with itself (m * m rows).
index = MultiIndex.from_product([levels] * 2)
# One row per index entry, 100 columns.
values = np.arange(m * m * 100).reshape(m * m, 100)
df = DataFrame(values, index, np.arange(100))
# Snapshot used to detect any mutation of `df` after the write below.
df_orig = df.copy()
result = df.unstack(sort=False)
# Memory is expected to be shared exactly when m == 1.
assert np.shares_memory(df._values, result._values) is (m == 1)
# Write into the result; the source frame must stay equal to the snapshot.
result.iloc[0, 0] = -1
tm.assert_frame_equal(df, df_orig)

0 comments on commit 07f6c4d

Please sign in to comment.