Skip to content

Commit

Permalink
REF: standardize get_indexer/get_indexer_non_unique (#39343)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Jan 23, 2021
1 parent 9ed521e commit 309cf3a
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 15 deletions.
33 changes: 19 additions & 14 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3289,11 +3289,10 @@ def get_indexer(
if not self._index_as_unique:
raise InvalidIndexError(self._requires_unique_msg)

# Treat boolean labels passed to a numeric index as not found. Without
# this fix False and True would be treated as 0 and 1 respectively.
# (GH #16877)
if target.is_boolean() and self.is_numeric():
return ensure_platform_int(np.repeat(-1, target.size))
if not self._should_compare(target) and not is_interval_dtype(self.dtype):
# IntervalIndex get special treatment bc numeric scalars can be
# matched to Interval scalars
return self._get_indexer_non_comparable(target, method=method, unique=True)

pself, ptarget = self._maybe_promote(target)
if pself is not self or ptarget is not target:
Expand All @@ -3310,8 +3309,9 @@ def _get_indexer(
tolerance = self._convert_tolerance(tolerance, target)

if not is_dtype_equal(self.dtype, target.dtype):
this = self.astype(object)
target = target.astype(object)
dtype = find_common_type([self.dtype, target.dtype])
this = self.astype(dtype, copy=False)
target = target.astype(dtype, copy=False)
return this.get_indexer(
target, method=method, limit=limit, tolerance=tolerance
)
Expand Down Expand Up @@ -5060,19 +5060,15 @@ def set_value(self, arr, key, value):
def get_indexer_non_unique(self, target):
target = ensure_index(target)

if target.is_boolean() and self.is_numeric():
# Treat boolean labels passed to a numeric index as not found. Without
# this fix False and True would be treated as 0 and 1 respectively.
# (GH #16877)
if not self._should_compare(target) and not is_interval_dtype(self.dtype):
# IntervalIndex get special treatment bc numeric scalars can be
# matched to Interval scalars
return self._get_indexer_non_comparable(target, method=None, unique=False)

pself, ptarget = self._maybe_promote(target)
if pself is not self or ptarget is not target:
return pself.get_indexer_non_unique(ptarget)

if not self._should_compare(target):
return self._get_indexer_non_comparable(target, method=None, unique=False)

if not is_dtype_equal(self.dtype, target.dtype):
# TODO: if object, could use infer_dtype to preempt costly
# conversion if still non-comparable?
Expand Down Expand Up @@ -5193,6 +5189,15 @@ def _should_compare(self, other: Index) -> bool:
"""
Check if `self == other` can ever have non-False entries.
"""

if (other.is_boolean() and self.is_numeric()) or (
self.is_boolean() and other.is_numeric()
):
# GH#16877 Treat boolean labels passed to a numeric index as not
# found. Without this fix False and True would be treated as 0 and 1
# respectively.
return False

other = unpack_nested_dtype(other)
dtype = other.dtype
return self._is_comparable_dtype(dtype) or is_object_dtype(dtype)
Expand Down
31 changes: 31 additions & 0 deletions pandas/tests/indexes/numeric/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,37 @@ def test_get_loc_missing_nan(self):


class TestGetIndexer:
@pytest.mark.parametrize("method", ["pad", "backfill", "nearest"])
def test_get_indexer_with_method_numeric_vs_bool(self, method):
left = Index([1, 2, 3])
right = Index([True, False])

with pytest.raises(TypeError, match="Cannot compare"):
left.get_indexer(right, method=method)

with pytest.raises(TypeError, match="Cannot compare"):
right.get_indexer(left, method=method)

def test_get_indexer_numeric_vs_bool(self):
left = Index([1, 2, 3])
right = Index([True, False])

res = left.get_indexer(right)
expected = -1 * np.ones(len(right), dtype=np.intp)
tm.assert_numpy_array_equal(res, expected)

res = right.get_indexer(left)
expected = -1 * np.ones(len(left), dtype=np.intp)
tm.assert_numpy_array_equal(res, expected)

res = left.get_indexer_non_unique(right)[0]
expected = -1 * np.ones(len(right), dtype=np.intp)
tm.assert_numpy_array_equal(res, expected)

res = right.get_indexer_non_unique(left)[0]
expected = -1 * np.ones(len(left), dtype=np.intp)
tm.assert_numpy_array_equal(res, expected)

def test_get_indexer_float64(self):
idx = Float64Index([0.0, 1.0, 2.0])
tm.assert_numpy_array_equal(
Expand Down
11 changes: 11 additions & 0 deletions pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,17 @@ def test_asof(self, index):
d = index[0].to_pydatetime()
assert isinstance(index.asof(d), Timestamp)

def test_asof_numeric_vs_bool_raises(self):
left = Index([1, 2, 3])
right = Index([True, False])

msg = "'<' not supported between instances"
with pytest.raises(TypeError, match=msg):
left.asof(right)

with pytest.raises(TypeError, match=msg):
right.asof(left)

def test_asof_datetime_partial(self):
index = date_range("2010-01-01", periods=2, freq="m")
expected = Timestamp("2010-02-28")
Expand Down
5 changes: 4 additions & 1 deletion pandas/tests/series/methods/test_reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,10 @@ def test_reindex_datetimeindexes_tz_naive_and_aware():
idx = date_range("20131101", tz="America/Chicago", periods=7)
newidx = date_range("20131103", periods=10, freq="H")
s = Series(range(7), index=idx)
msg = "Cannot compare tz-naive and tz-aware timestamps"
msg = (
r"Cannot compare dtypes datetime64\[ns, America/Chicago\] "
r"and datetime64\[ns\]"
)
with pytest.raises(TypeError, match=msg):
s.reindex(newidx, method="ffill")

Expand Down

0 comments on commit 309cf3a

Please sign in to comment.