Skip to content

Commit

Permalink
BUG: CoW - correctly track references for chained operations (#48996)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick Hoefler <[email protected]>
  • Loading branch information
jorisvandenbossche and phofl authored Nov 1, 2022
1 parent b3bd5ad commit b858de0
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 19 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.5.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Fixed regressions

Bug fixes
~~~~~~~~~
-
- Bug in the Copy-on-Write implementation losing track of views in certain chained indexing cases (:issue:`48996`)
-

.. ---------------------------------------------------------------------------
Expand Down
12 changes: 9 additions & 3 deletions pandas/_libs/internals.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -676,8 +676,9 @@ cdef class BlockManager:
public bint _known_consolidated, _is_consolidated
public ndarray _blknos, _blklocs
public list refs
public object parent

def __cinit__(self, blocks=None, axes=None, refs=None, verify_integrity=True):
def __cinit__(self, blocks=None, axes=None, refs=None, parent=None, verify_integrity=True):
# None as defaults for unpickling GH#42345
if blocks is None:
# This adds 1-2 microseconds to DataFrame(np.array([]))
Expand All @@ -690,6 +691,7 @@ cdef class BlockManager:
self.blocks = blocks
self.axes = axes.copy() # copy to make sure we are not remotely-mutable
self.refs = refs
self.parent = parent

# Populate known_consolidate, blknos, and blklocs lazily
self._known_consolidated = False
Expand Down Expand Up @@ -805,7 +807,9 @@ cdef class BlockManager:
nrefs.append(weakref.ref(blk))

new_axes = [self.axes[0], self.axes[1]._getitem_slice(slobj)]
mgr = type(self)(tuple(nbs), new_axes, nrefs, verify_integrity=False)
mgr = type(self)(
tuple(nbs), new_axes, nrefs, parent=self, verify_integrity=False
)

# We can avoid having to rebuild blklocs/blknos
blklocs = self._blklocs
Expand All @@ -827,4 +831,6 @@ cdef class BlockManager:
new_axes = list(self.axes)
new_axes[axis] = new_axes[axis]._getitem_slice(slobj)

return type(self)(tuple(new_blocks), new_axes, new_refs, verify_integrity=False)
return type(self)(
tuple(new_blocks), new_axes, new_refs, parent=self, verify_integrity=False
)
55 changes: 43 additions & 12 deletions pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import pandas.core.algorithms as algos
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.arrays.sparse import SparseDtype
import pandas.core.common as com
from pandas.core.construction import (
ensure_wrapped_if_datetimelike,
extract_array,
Expand Down Expand Up @@ -148,6 +149,7 @@ class BaseBlockManager(DataManager):
blocks: tuple[Block, ...]
axes: list[Index]
refs: list[weakref.ref | None] | None
parent: object

@property
def ndim(self) -> int:
Expand All @@ -165,6 +167,7 @@ def from_blocks(
blocks: list[Block],
axes: list[Index],
refs: list[weakref.ref | None] | None = None,
parent: object = None,
) -> T:
raise NotImplementedError

Expand Down Expand Up @@ -264,6 +267,8 @@ def _clear_reference_block(self, blkno: int) -> None:
"""
if self.refs is not None:
self.refs[blkno] = None
if com.all_none(*self.refs):
self.parent = None

def get_dtypes(self):
dtypes = np.array([blk.dtype for blk in self.blocks])
Expand Down Expand Up @@ -605,7 +610,9 @@ def _combine(
axes[-1] = index
axes[0] = self.items.take(indexer)

return type(self).from_blocks(new_blocks, axes, new_refs)
return type(self).from_blocks(
new_blocks, axes, new_refs, parent=None if copy else self
)

@property
def nblocks(self) -> int:
Expand Down Expand Up @@ -648,11 +655,14 @@ def copy_func(ax):
new_refs: list[weakref.ref | None] | None
if deep:
new_refs = None
parent = None
else:
new_refs = [weakref.ref(blk) for blk in self.blocks]
parent = self

res.axes = new_axes
res.refs = new_refs
res.parent = parent

if self.ndim > 1:
# Avoid needing to re-compute these
Expand Down Expand Up @@ -744,6 +754,7 @@ def reindex_indexer(
only_slice=only_slice,
use_na_proxy=use_na_proxy,
)
parent = None if com.all_none(*new_refs) else self
else:
new_blocks = [
blk.take_nd(
Expand All @@ -756,11 +767,12 @@ def reindex_indexer(
for blk in self.blocks
]
new_refs = None
parent = None

new_axes = list(self.axes)
new_axes[axis] = new_axis

new_mgr = type(self).from_blocks(new_blocks, new_axes, new_refs)
new_mgr = type(self).from_blocks(new_blocks, new_axes, new_refs, parent=parent)
if axis == 1:
# We can avoid the need to rebuild these
new_mgr._blknos = self.blknos.copy()
Expand Down Expand Up @@ -995,6 +1007,7 @@ def __init__(
blocks: Sequence[Block],
axes: Sequence[Index],
refs: list[weakref.ref | None] | None = None,
parent: object = None,
verify_integrity: bool = True,
) -> None:

Expand Down Expand Up @@ -1059,11 +1072,13 @@ def from_blocks(
blocks: list[Block],
axes: list[Index],
refs: list[weakref.ref | None] | None = None,
parent: object = None,
) -> BlockManager:
"""
Constructor for BlockManager and SingleBlockManager with same signature.
"""
return cls(blocks, axes, refs, verify_integrity=False)
parent = parent if _using_copy_on_write() else None
return cls(blocks, axes, refs, parent, verify_integrity=False)

# ----------------------------------------------------------------
# Indexing
Expand All @@ -1085,7 +1100,7 @@ def fast_xs(self, loc: int) -> SingleBlockManager:
block = new_block(result, placement=slice(0, len(result)), ndim=1)
# in the case of a single block, the new block is a view
ref = weakref.ref(self.blocks[0])
return SingleBlockManager(block, self.axes[0], [ref])
return SingleBlockManager(block, self.axes[0], [ref], parent=self)

dtype = interleaved_dtype([blk.dtype for blk in self.blocks])

Expand Down Expand Up @@ -1119,7 +1134,7 @@ def fast_xs(self, loc: int) -> SingleBlockManager:
block = new_block(result, placement=slice(0, len(result)), ndim=1)
return SingleBlockManager(block, self.axes[0])

def iget(self, i: int) -> SingleBlockManager:
def iget(self, i: int, track_ref: bool = True) -> SingleBlockManager:
"""
Return the data as a SingleBlockManager.
"""
Expand All @@ -1129,7 +1144,9 @@ def iget(self, i: int) -> SingleBlockManager:
# shortcut for select a single-dim from a 2-dim BM
bp = BlockPlacement(slice(0, len(values)))
nb = type(block)(values, placement=bp, ndim=1)
return SingleBlockManager(nb, self.axes[1], [weakref.ref(block)])
ref = weakref.ref(block) if track_ref else None
parent = self if track_ref else None
return SingleBlockManager(nb, self.axes[1], [ref], parent)

def iget_values(self, i: int) -> ArrayLike:
"""
Expand Down Expand Up @@ -1371,7 +1388,9 @@ def column_setitem(self, loc: int, idx: int | slice | np.ndarray, value) -> None
self.blocks = tuple(blocks)
self._clear_reference_block(blkno)

col_mgr = self.iget(loc)
# this manager is only created temporarily to mutate the values in place
# so don't track references, otherwise the `setitem` would perform CoW again
col_mgr = self.iget(loc, track_ref=False)
new_mgr = col_mgr.setitem((idx,), value)
self.iset(loc, new_mgr._block.values, inplace=True)

Expand Down Expand Up @@ -1469,7 +1488,9 @@ def idelete(self, indexer) -> BlockManager:
nbs, new_refs = self._slice_take_blocks_ax0(taker, only_slice=True)
new_columns = self.items[~is_deleted]
axes = [new_columns, self.axes[1]]
return type(self)(tuple(nbs), axes, new_refs, verify_integrity=False)
# TODO this might not be needed (can a delete ever be done in chained manner?)
parent = None if com.all_none(*new_refs) else self
return type(self)(tuple(nbs), axes, new_refs, parent, verify_integrity=False)

# ----------------------------------------------------------------
# Block-wise Operation
Expand Down Expand Up @@ -1875,6 +1896,7 @@ def __init__(
block: Block,
axis: Index,
refs: list[weakref.ref | None] | None = None,
parent: object = None,
verify_integrity: bool = False,
fastpath=lib.no_default,
) -> None:
Expand All @@ -1893,13 +1915,15 @@ def __init__(
self.axes = [axis]
self.blocks = (block,)
self.refs = refs
self.parent = parent if _using_copy_on_write() else None

@classmethod
def from_blocks(
cls,
blocks: list[Block],
axes: list[Index],
refs: list[weakref.ref | None] | None = None,
parent: object = None,
) -> SingleBlockManager:
"""
Constructor for BlockManager and SingleBlockManager with same signature.
Expand All @@ -1908,7 +1932,7 @@ def from_blocks(
assert len(axes) == 1
if refs is not None:
assert len(refs) == 1
return cls(blocks[0], axes[0], refs, verify_integrity=False)
return cls(blocks[0], axes[0], refs, parent, verify_integrity=False)

@classmethod
def from_array(cls, array: ArrayLike, index: Index) -> SingleBlockManager:
Expand All @@ -1928,7 +1952,10 @@ def to_2d_mgr(self, columns: Index) -> BlockManager:
new_blk = type(blk)(arr, placement=bp, ndim=2)
axes = [columns, self.axes[0]]
refs: list[weakref.ref | None] = [weakref.ref(blk)]
return BlockManager([new_blk], axes=axes, refs=refs, verify_integrity=False)
parent = self if _using_copy_on_write() else None
return BlockManager(
[new_blk], axes=axes, refs=refs, parent=parent, verify_integrity=False
)

def _has_no_reference(self, i: int = 0) -> bool:
"""
Expand Down Expand Up @@ -2010,7 +2037,7 @@ def getitem_mgr(self, indexer: slice | npt.NDArray[np.bool_]) -> SingleBlockMana
new_idx = self.index[indexer]
# TODO(CoW) in theory only need to track reference if new_array is a view
ref = weakref.ref(blk)
return type(self)(block, new_idx, [ref])
return type(self)(block, new_idx, [ref], parent=self)

def get_slice(self, slobj: slice, axis: AxisInt = 0) -> SingleBlockManager:
# Assertion disabled for performance
Expand All @@ -2023,7 +2050,9 @@ def get_slice(self, slobj: slice, axis: AxisInt = 0) -> SingleBlockManager:
bp = BlockPlacement(slice(0, len(array)))
block = type(blk)(array, placement=bp, ndim=1)
new_index = self.index._getitem_slice(slobj)
return type(self)(block, new_index, [weakref.ref(blk)])
# TODO this method is only used in groupby SeriesSplitter at the moment,
# so passing refs / parent is not yet covered by the tests
return type(self)(block, new_index, [weakref.ref(blk)], parent=self)

@property
def index(self) -> Index:
Expand Down Expand Up @@ -2070,6 +2099,7 @@ def setitem_inplace(self, indexer, value) -> None:
if _using_copy_on_write() and not self._has_no_reference(0):
self.blocks = (self._block.copy(),)
self.refs = None
self.parent = None
self._cache.clear()

super().setitem_inplace(indexer, value)
Expand All @@ -2086,6 +2116,7 @@ def idelete(self, indexer) -> SingleBlockManager:
self._cache.clear()
# clear reference since delete always results in a new array
self.refs = None
self.parent = None
return self

def fast_xs(self, loc):
Expand Down
Loading

0 comments on commit b858de0

Please sign in to comment.