Skip to content

Commit

Permalink
BUG: MultiIndex.difference not keeping ea dtype (#48606)
Browse files Browse the repository at this point in the history
* BUG: MultiIndex.difference not keeping ea dtype

* Add asv

* Add whatsnew

* Reduce asv

* Add ea asv

* Fix mypy

* Add whatsnew

* Fix mypy
  • Loading branch information
phofl authored Sep 23, 2022
1 parent 1209160 commit 44a4f16
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 9 deletions.
39 changes: 39 additions & 0 deletions asv_bench/benchmarks/multiindex_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,45 @@ def time_operation(self, index_structure, dtype, method):
getattr(self.left, method)(self.right)


class Difference:
    """Benchmark ``MultiIndex.difference`` for several second-level dtypes.

    The left operand is the product of ``range(1000)`` with a second level of
    ``N // 1000`` entries whose kind is selected by the ``dtype`` parameter.
    The right operand is the first five entries of the left operand.
    """

    params = [
        ("datetime", "int", "string", "ea_int"),
    ]
    param_names = ["dtype"]

    def setup(self, dtype):
        """Build ``self.left`` and ``self.right`` for the requested ``dtype``.

        Only the index for ``dtype`` is constructed; the other variants are
        never used by the benchmark, so building them would be wasted work.

        Raises
        ------
        KeyError
            If ``dtype`` is not one of the supported parameter values.
        """
        N = 10**4 * 2
        level1 = range(1000)
        n_level2 = N // 1000

        def ea_int_level():
            # Nullable integer level that contains one missing value.
            values = Series(range(n_level2), dtype="Int64")
            values[0] = NA
            return values

        # Deferred constructors: only the selected one is ever called.
        level2_builders = {
            "datetime": lambda: date_range(start="1/1/2000", periods=n_level2),
            "int": lambda: range(n_level2),
            "ea_int": ea_int_level,
            "string": lambda: tm.makeStringIndex(n_level2).values,
        }

        level2 = level2_builders[dtype]()
        self.left = MultiIndex.from_product([level1, level2])
        self.right = self.left[:5]

    def time_difference(self, dtype):
        """Time ``left.difference(right)`` for the index built in ``setup``."""
        self.left.difference(self.right)


class Unique:
params = [
(("Int64", NA), ("int64", 0)),
Expand Down
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.6.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ Performance improvements
- Performance improvement in :meth:`MultiIndex.argsort` and :meth:`MultiIndex.sort_values` (:issue:`48406`)
- Performance improvement in :meth:`MultiIndex.size` (:issue:`48723`)
- Performance improvement in :meth:`MultiIndex.union` without missing values and without duplicates (:issue:`48505`)
- Performance improvement in :meth:`MultiIndex.difference` (:issue:`48606`)
- Performance improvement in :meth:`.DataFrameGroupBy.mean`, :meth:`.SeriesGroupBy.mean`, :meth:`.DataFrameGroupBy.var`, and :meth:`.SeriesGroupBy.var` for extension array dtypes (:issue:`37493`)
- Performance improvement in :meth:`MultiIndex.isin` when ``level=None`` (:issue:`48622`)
- Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`)
Expand Down Expand Up @@ -210,6 +211,7 @@ Missing

MultiIndex
^^^^^^^^^^
- Bug in :meth:`MultiIndex.difference` losing extension array dtype (:issue:`48606`)
- Bug in :meth:`MultiIndex.set_levels` raising ``IndexError`` when setting empty level (:issue:`48636`)
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
- Bug in :meth:`MultiIndex.intersection` losing extension array (:issue:`48604`)
Expand Down
7 changes: 6 additions & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3684,7 +3684,12 @@ def _difference(self, other, sort):
indexer = indexer.take((indexer != -1).nonzero()[0])

label_diff = np.setdiff1d(np.arange(this.size), indexer, assume_unique=True)
the_diff = this._values.take(label_diff)

the_diff: MultiIndex | ArrayLike
if isinstance(this, ABCMultiIndex):
the_diff = this.take(label_diff)
else:
the_diff = this._values.take(label_diff)
the_diff = _maybe_try_sort(the_diff, sort)

return the_diff
Expand Down
11 changes: 3 additions & 8 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3746,18 +3746,13 @@ def _wrap_intersection_result(self, other, result) -> MultiIndex:
_, result_names = self._convert_can_do_setop(other)
return result.set_names(result_names)

def _wrap_difference_result(self, other, result) -> MultiIndex:
def _wrap_difference_result(self, other, result: MultiIndex) -> MultiIndex:
_, result_names = self._convert_can_do_setop(other)

if len(result) == 0:
return MultiIndex(
levels=[[]] * self.nlevels,
codes=[[]] * self.nlevels,
names=result_names,
verify_integrity=False,
)
return result.remove_unused_levels().set_names(result_names)
else:
return MultiIndex.from_tuples(result, sortorder=0, names=result_names)
return result.set_names(result_names)

def _convert_can_do_setop(self, other):
result_names = self.names
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/indexes/multi/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,27 @@ def test_setops_disallow_true(method):
getattr(idx1, method)(idx2, sort=True)


@pytest.mark.parametrize("val", [pd.NA, 100])
def test_difference_keep_ea_dtypes(any_numeric_ea_dtype, val):
    # GH#48606
    # The first level is built with a nullable extension dtype; the result of
    # ``difference`` is compared against an index built with that same dtype,
    # for both a non-empty and an empty result.
    dtype = any_numeric_ea_dtype
    left = MultiIndex.from_arrays(
        [Series([1, 2], dtype=dtype), [2, 1]], names=["a", None]
    )
    right = MultiIndex.from_arrays([Series([1, 2, val], dtype=dtype), [1, 1, 3]])

    # ``left`` holds (1, 2) and (2, 1); ``right`` contains (2, 1) but not
    # (1, 2), so (1, 2) is the only surviving entry. The expected index is
    # built without names.
    expected_non_empty = MultiIndex.from_arrays([Series([1], dtype=dtype), [2]])
    tm.assert_index_equal(left.difference(right), expected_non_empty)

    # Subtracting the same entries in descending order leaves nothing; the
    # expected empty index keeps the names and the first level's dtype.
    reversed_left = left.sort_values(ascending=False)
    expected_empty = MultiIndex.from_arrays(
        [Series([], dtype=dtype), Series([], dtype=int)],
        names=["a", None],
    )
    tm.assert_index_equal(left.difference(reversed_left), expected_empty)


@pytest.mark.parametrize("val", [pd.NA, 5])
def test_symmetric_difference_keeping_ea_dtype(any_numeric_ea_dtype, val):
# GH#48607
Expand Down

0 comments on commit 44a4f16

Please sign in to comment.