Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

String dtype: avoid surfacing pyarrow exception in binary operations #59610

Merged
38 changes: 32 additions & 6 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,12 @@ def __invert__(self) -> Self:
return type(self)(pc.invert(self._pa_array))

def __neg__(self) -> Self:
return type(self)(pc.negate_checked(self._pa_array))
try:
return type(self)(pc.negate_checked(self._pa_array))
except pa.ArrowNotImplementedError as err:
raise TypeError(
f"unary '-' not supported for dtype '{self.dtype}'"
) from err

def __pos__(self) -> Self:
return type(self)(self._pa_array)
Expand Down Expand Up @@ -736,8 +741,19 @@ def _cmp_method(self, other, op) -> ArrowExtensionArray:
)
return ArrowExtensionArray(result)

def _op_method_error_message(self, other, op) -> str:
if hasattr(other, "dtype"):
other_type = f"dtype '{other.dtype}'"
else:
other_type = f"object of type {type(other)}"
return (
f"operation '{op.__name__}' not supported for "
f"dtype '{self.dtype}' with {other_type}"
)

def _evaluate_op_method(self, other, op, arrow_funcs) -> Self:
pa_type = self._pa_array.type
other_original = other
other = self._box_pa(other)

if (
Expand All @@ -747,10 +763,15 @@ def _evaluate_op_method(self, other, op, arrow_funcs) -> Self:
):
if op in [operator.add, roperator.radd]:
sep = pa.scalar("", type=pa_type)
if op is operator.add:
result = pc.binary_join_element_wise(self._pa_array, other, sep)
elif op is roperator.radd:
result = pc.binary_join_element_wise(other, self._pa_array, sep)
try:
if op is operator.add:
result = pc.binary_join_element_wise(self._pa_array, other, sep)
elif op is roperator.radd:
result = pc.binary_join_element_wise(other, self._pa_array, sep)
except pa.ArrowNotImplementedError as err:
raise TypeError(
self._op_method_error_message(other_original, op)
) from err
return type(self)(result)
elif op in [operator.mul, roperator.rmul]:
binary = self._pa_array
Expand Down Expand Up @@ -782,9 +803,14 @@ def _evaluate_op_method(self, other, op, arrow_funcs) -> Self:

pc_func = arrow_funcs[op.__name__]
if pc_func is NotImplemented:
if pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
raise TypeError(self._op_method_error_message(other_original, op))
raise NotImplementedError(f"{op.__name__} not implemented.")

result = pc_func(self._pa_array, other)
try:
result = pc_func(self._pa_array, other)
except pa.ArrowNotImplementedError as err:
raise TypeError(self._op_method_error_message(other_original, op)) from err
return type(self)(result)

def _logical_method(self, other, op) -> Self:
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,8 +825,11 @@ def _cmp_method(self, other, op):
f"Lengths of operands do not match: {len(self)} != {len(other)}"
)

other = np.asarray(other)
# for array-likes, first filter out NAs before converting to numpy
if not is_array_like(other):
other = np.asarray(other)
other = other[valid]
other = np.asarray(other)

if op.__name__ in ops.ARITHMETIC_BINOPS:
result = np.empty_like(self._ndarray, dtype="object")
Expand Down
25 changes: 6 additions & 19 deletions pandas/tests/arithmetic/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW
import pandas.util._test_decorators as td

import pandas as pd
Expand Down Expand Up @@ -318,27 +315,17 @@ def test_add(self):
expected = pd.Index(["1a", "1b", "1c"])
tm.assert_index_equal("1" + index, expected)

@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
def test_sub_fail(self, using_infer_string):
def test_sub_fail(self):
index = pd.Index([str(i) for i in range(10)])

if using_infer_string:
import pyarrow as pa

err = pa.lib.ArrowNotImplementedError
msg = "has no kernel"
else:
err = TypeError
msg = "unsupported operand type|Cannot broadcast"
with pytest.raises(err, match=msg):
msg = "unsupported operand type|Cannot broadcast|sub' not supported"
with pytest.raises(TypeError, match=msg):
index - "a"
with pytest.raises(err, match=msg):
with pytest.raises(TypeError, match=msg):
index - index
with pytest.raises(err, match=msg):
with pytest.raises(TypeError, match=msg):
index - index.tolist()
with pytest.raises(err, match=msg):
with pytest.raises(TypeError, match=msg):
index.tolist() - index

def test_sub_object(self):
Expand Down
26 changes: 7 additions & 19 deletions pandas/tests/arrays/boolean/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW

import pandas as pd
import pandas._testing as tm

Expand Down Expand Up @@ -94,19 +90,8 @@ def test_op_int8(left_array, right_array, opname):
# -----------------------------------------------------------------------------


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string):
def test_error_invalid_values(data, all_arithmetic_operators):
# invalid ops

if using_infer_string:
import pyarrow as pa

err = (TypeError, pa.lib.ArrowNotImplementedError, NotImplementedError)
else:
err = TypeError

op = all_arithmetic_operators
s = pd.Series(data)
ops = getattr(s, op)
Expand All @@ -116,7 +101,8 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"did not contain a loop with signature matching types|"
"BooleanArray cannot perform the operation|"
"not supported for the input types, and the inputs could not be safely coerced "
"to any supported types according to the casting rule ''safe''"
"to any supported types according to the casting rule ''safe''|"
"not supported for dtype"
)
with pytest.raises(TypeError, match=msg):
ops("foo")
Expand All @@ -125,9 +111,10 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
r"unsupported operand type\(s\) for",
"Concatenation operation is not implemented for NumPy arrays",
"has no kernel",
"not supported for dtype",
]
)
with pytest.raises(err, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Timestamp("20180101"))

# invalid array-likes
Expand All @@ -140,7 +127,8 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"not all arguments converted during string formatting",
"has no kernel",
"not implemented",
"not supported for dtype",
]
)
with pytest.raises(err, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Series("foo", index=s.index))
23 changes: 8 additions & 15 deletions pandas/tests/arrays/floating/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays import FloatingArray
Expand Down Expand Up @@ -124,19 +122,11 @@ def test_arith_zero_dim_ndarray(other):
# -----------------------------------------------------------------------------


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string):
def test_error_invalid_values(data, all_arithmetic_operators):
op = all_arithmetic_operators
s = pd.Series(data)
ops = getattr(s, op)

if using_infer_string:
import pyarrow as pa

errs = (TypeError, pa.lib.ArrowNotImplementedError, NotImplementedError)
else:
errs = TypeError

# invalid scalars
msg = "|".join(
[
Expand All @@ -152,15 +142,17 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"Concatenation operation is not implemented for NumPy arrays",
"has no kernel",
"not implemented",
"not supported for dtype",
"Can only string multiply by an integer",
]
)
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops("foo")
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Timestamp("20180101"))

# invalid array-likes
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Series("foo", index=s.index))

msg = "|".join(
Expand All @@ -181,9 +173,10 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"cannot subtract DatetimeArray from ndarray",
"has no kernel",
"not implemented",
"not supported for dtype",
]
)
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Series(pd.date_range("20180101", periods=len(s))))


Expand Down
34 changes: 11 additions & 23 deletions pandas/tests/arrays/integer/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

import pandas as pd
import pandas._testing as tm
from pandas.core import ops
Expand Down Expand Up @@ -174,19 +172,11 @@ def test_numpy_zero_dim_ndarray(other):
# -----------------------------------------------------------------------------


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string):
def test_error_invalid_values(data, all_arithmetic_operators):
op = all_arithmetic_operators
s = pd.Series(data)
ops = getattr(s, op)

if using_infer_string:
import pyarrow as pa

errs = (TypeError, pa.lib.ArrowNotImplementedError, NotImplementedError)
else:
errs = TypeError

# invalid scalars
msg = "|".join(
[
Expand All @@ -201,24 +191,21 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"has no kernel",
"not implemented",
"The 'out' kwarg is necessary. Use numpy.strings.multiply without it.",
"not supported for dtype",
]
)
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops("foo")
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Timestamp("20180101"))

# invalid array-likes
str_ser = pd.Series("foo", index=s.index)
# with pytest.raises(TypeError, match=msg):
if (
all_arithmetic_operators
in [
"__mul__",
"__rmul__",
]
and not using_infer_string
): # (data[~data.isna()] >= 0).all():
if all_arithmetic_operators in [
"__mul__",
"__rmul__",
]: # (data[~data.isna()] >= 0).all():
res = ops(str_ser)
expected = pd.Series(["foo" * x for x in data], index=s.index)
expected = expected.fillna(np.nan)
Expand All @@ -227,7 +214,7 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
# more-correct than np.nan here.
tm.assert_series_equal(res, expected)
else:
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(str_ser)

msg = "|".join(
Expand All @@ -242,9 +229,10 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"cannot subtract DatetimeArray from ndarray",
"has no kernel",
"not implemented",
"not supported for dtype",
]
)
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Series(pd.date_range("20180101", periods=len(s))))


Expand Down
10 changes: 1 addition & 9 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BaseOpsUtil:

def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
) -> type[Exception] | tuple[type[Exception], ...] | None:
# Find the Exception, if any we expect to raise calling
# obj.__op_name__(other)

Expand All @@ -39,14 +39,6 @@ def _get_expected_exception(
else:
result = self.frame_scalar_exc

if using_string_dtype() and result is not None:
import pyarrow as pa

result = ( # type: ignore[assignment]
result,
pa.lib.ArrowNotImplementedError,
NotImplementedError,
)
return result

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def data_for_grouping():
class TestDecimalArray(base.ExtensionTests):
def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
) -> type[Exception] | tuple[type[Exception], ...] | None:
return None

def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
Expand Down
Loading