diff --git a/doc/source/whatsnew/v2.3.0.rst b/doc/source/whatsnew/v2.3.0.rst index d57d86f4a1476..da0fcfc2b3f64 100644 --- a/doc/source/whatsnew/v2.3.0.rst +++ b/doc/source/whatsnew/v2.3.0.rst @@ -106,10 +106,10 @@ Conversion Strings ^^^^^^^ - Bug in :meth:`Series.rank` for :class:`StringDtype` with ``storage="pyarrow"`` incorrectly returning integer results in case of ``method="average"`` and raising an error if it would truncate results (:issue:`59768`) +- Bug in :meth:`Series.replace` with :class:`StringDtype` when replacing with a non-string value was not upcasting to ``object`` dtype (:issue:`60282`) - Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`StringDtype` with ``storage="pyarrow"`` (:issue:`59628`) - Bug in ``ser.str.slice`` with negative ``step`` with :class:`ArrowDtype` and :class:`StringDtype` with ``storage="pyarrow"`` giving incorrect results (:issue:`59710`) - Bug in the ``center`` method on :class:`Series` and :class:`Index` object ``str`` accessors with pyarrow-backed dtype not matching the python behavior in corner cases with an odd number of fill characters (:issue:`54792`) -- Interval ^^^^^^^^ diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 9b1f986a7158e..3b881cfd2df2f 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -730,20 +730,9 @@ def _values_for_factorize(self) -> tuple[np.ndarray, libmissing.NAType | float]: return arr, self.dtype.na_value - def __setitem__(self, key, value) -> None: - value = extract_array(value, extract_numpy=True) - if isinstance(value, type(self)): - # extract_array doesn't extract NumpyExtensionArray subclasses - value = value._ndarray - - key = check_array_indexer(self, key) - scalar_key = lib.is_scalar(key) - scalar_value = lib.is_scalar(value) - if scalar_key and not scalar_value: - raise ValueError("setting an array element with a sequence.") - - # validate new items - if scalar_value: + def _maybe_convert_setitem_value(self, value): + """Maybe convert value to be pyarrow compatible.""" + if lib.is_scalar(value): if isna(value): value = self.dtype.na_value elif not isinstance(value, str): @@ -753,8 +742,11 @@ def __setitem__(self, key, value) -> None: "instead." ) else: + value = extract_array(value, extract_numpy=True) if not is_array_like(value): value = np.asarray(value, dtype=object) + elif isinstance(value.dtype, type(self.dtype)): + return value else: # cast categories and friends to arrays to see if values are # compatible, compatibility with arrow backed strings @@ -764,11 +756,26 @@ def __setitem__(self, key, value) -> None: "Invalid value for dtype 'str'. Value should be a " "string or missing value (or array of those)." ) + return value - mask = isna(value) - if mask.any(): - value = value.copy() - value[isna(value)] = self.dtype.na_value + def __setitem__(self, key, value) -> None: + value = self._maybe_convert_setitem_value(value) + + key = check_array_indexer(self, key) + scalar_key = lib.is_scalar(key) + scalar_value = lib.is_scalar(value) + if scalar_key and not scalar_value: + raise ValueError("setting an array element with a sequence.") + + if not scalar_value: + if value.dtype == self.dtype: + value = value._ndarray + else: + value = np.asarray(value) + mask = isna(value) + if mask.any(): + value = value.copy() + value[isna(value)] = self.dtype.na_value super().__setitem__(key, value) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 8850b75323d68..830b84852c704 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -1749,6 +1749,13 @@ def can_hold_element(arr: ArrayLike, element: Any) -> bool: except (ValueError, TypeError): return False + if dtype == "string": + try: + arr._maybe_convert_setitem_value(element) # type: ignore[union-attr] + return True + except (ValueError, TypeError): + return False + # This is technically incorrect, but maintains the behavior of # ExtensionBlock._can_hold_element return True diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index a3ff577966a6d..3264676771d5d 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -77,6 +77,7 @@ ABCNumpyExtensionArray, ABCSeries, ) +from pandas.core.dtypes.inference import is_re from pandas.core.dtypes.missing import ( is_valid_na_for_dtype, isna, @@ -706,7 +707,7 @@ def replace( # bc _can_hold_element is incorrect. return [self.copy(deep=False)] - elif self._can_hold_element(value): + elif self._can_hold_element(value) or (self.dtype == "string" and is_re(value)): # TODO(CoW): Maybe split here as well into columns where mask has True # and rest? blk = self._maybe_copy(inplace) @@ -766,14 +767,24 @@ def _replace_regex( ------- List[Block] """ - if not self._can_hold_element(to_replace): + if not is_re(to_replace) and not self._can_hold_element(to_replace): # i.e. only if self.is_object is True, but could in principle include a # String ExtensionBlock return [self.copy(deep=False)] - rx = re.compile(to_replace) + if is_re(to_replace) and self.dtype not in [object, "string"]: + # only object or string dtype can hold strings, and a regex object + # will only match strings + return [self.copy(deep=False)] - block = self._maybe_copy(inplace) + if not ( + self._can_hold_element(value) or (self.dtype == "string" and is_re(value)) + ): + block = self.astype(np.dtype(object)) + else: + block = self._maybe_copy(inplace) + + rx = re.compile(to_replace) replace_regex(block.values, rx, value, mask) return [block] @@ -793,7 +804,9 @@ def replace_list( # Exclude anything that we know we won't contain pairs = [ - (x, y) for x, y in zip(src_list, dest_list) if self._can_hold_element(x) + (x, y) + for x, y in zip(src_list, dest_list) + if (self._can_hold_element(x) or (self.dtype == "string" and is_re(x))) ] if not len(pairs): return [self.copy(deep=False)] diff --git a/pandas/tests/frame/methods/test_replace.py b/pandas/tests/frame/methods/test_replace.py index 6b872bf48d550..73f44bcc6657e 100644 --- a/pandas/tests/frame/methods/test_replace.py +++ b/pandas/tests/frame/methods/test_replace.py @@ -889,7 +889,6 @@ def test_replace_input_formats_listlike(self): with pytest.raises(ValueError, match=msg): df.replace(to_rep, values[1:]) - @pytest.mark.xfail(using_string_dtype(), reason="can't set float into string") def test_replace_input_formats_scalar(self): df = DataFrame( {"A": [np.nan, 0, np.inf], "B": [0, 2, 5], "C": ["", "asdf", "fd"]} @@ -940,7 +939,6 @@ def test_replace_dict_no_regex(self): result = answer.replace(weights) tm.assert_series_equal(result, expected) - @pytest.mark.xfail(using_string_dtype(), reason="can't set float into string") def test_replace_series_no_regex(self): answer = Series( { @@ -1176,7 +1174,6 @@ def test_replace_commutative(self, df, to_replace, exp): result = df.replace(to_replace) tm.assert_frame_equal(result, expected) - @pytest.mark.xfail(using_string_dtype(), reason="can't set float into string") @pytest.mark.parametrize( "replacer", [ diff --git a/pandas/tests/series/indexing/test_setitem.py b/pandas/tests/series/indexing/test_setitem.py index 82c616132456b..0d62317893326 100644 --- a/pandas/tests/series/indexing/test_setitem.py +++ b/pandas/tests/series/indexing/test_setitem.py @@ -860,24 +860,16 @@ def test_index_where(self, obj, key, expected, raises, val): mask = np.zeros(obj.shape, dtype=bool) mask[key] = True - if raises and obj.dtype == "string": - with pytest.raises(TypeError, match="Invalid value"): - Index(obj).where(~mask, val) - else: - res = Index(obj).where(~mask, val) - expected_idx = Index(expected, dtype=expected.dtype) - tm.assert_index_equal(res, expected_idx) + res = Index(obj).where(~mask, val) + expected_idx = Index(expected, dtype=expected.dtype) + tm.assert_index_equal(res, expected_idx) def test_index_putmask(self, obj, key, expected, raises, val): mask = np.zeros(obj.shape, dtype=bool) mask[key] = True - if raises and obj.dtype == "string": - with pytest.raises(TypeError, match="Invalid value"): - Index(obj).putmask(mask, val) - else: - res = Index(obj).putmask(mask, val) - tm.assert_index_equal(res, Index(expected, dtype=expected.dtype)) + res = Index(obj).putmask(mask, val) + tm.assert_index_equal(res, Index(expected, dtype=expected.dtype)) @pytest.mark.parametrize( diff --git a/pandas/tests/series/methods/test_replace.py b/pandas/tests/series/methods/test_replace.py index 1ebef333f054a..ecfe3d1b39d31 100644 --- a/pandas/tests/series/methods/test_replace.py +++ b/pandas/tests/series/methods/test_replace.py @@ -635,13 +635,11 @@ def test_replace_regex_dtype_series(self, regex): tm.assert_series_equal(result, expected) @pytest.mark.parametrize("regex", [False, True]) - def test_replace_regex_dtype_series_string(self, regex, using_infer_string): - if not using_infer_string: - # then this is object dtype which is already tested above - return + def test_replace_regex_dtype_series_string(self, regex): series = pd.Series(["0"], dtype="str") - with pytest.raises(TypeError, match="Invalid value"): - series.replace(to_replace="0", value=1, regex=regex) + expected = pd.Series([1], dtype=object) + result = series.replace(to_replace="0", value=1, regex=regex) + tm.assert_series_equal(result, expected) def test_replace_different_int_types(self, any_int_numpy_dtype): # GH#45311