Skip to content

Commit

Permalink
BUG: eval and query not working with ea dtypes (#50764)
Browse files Browse the repository at this point in the history
* BUG: eval and query not working with ea dtypes

* Fix windows build

* Fix

* Fix another bug

* Fix eval

* Fix

* Fix

* Add arrow tests

* Fix pyarrow-less ci

* Add try except

* Add warning

* Adjust warning

* Fix warning

* Fix

* Update test_query_eval.py

* Update pandas/core/computation/eval.py

---------

Co-authored-by: Matthew Roeschke <[email protected]>
  • Loading branch information
phofl and mroeschke authored Feb 9, 2023
1 parent 4c8b2ea commit b9a4335
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,7 @@ Conversion
- Bug in :meth:`DataFrame.astype` not copying data when converting to pyarrow dtype (:issue:`50984`)
- Bug in :func:`to_datetime` was not respecting ``exact`` argument when ``format`` was an ISO8601 format (:issue:`12649`)
- Bug in :meth:`TimedeltaArray.astype` raising ``TypeError`` when converting to a pyarrow duration type (:issue:`49795`)
- Bug in :meth:`DataFrame.eval` and :meth:`DataFrame.query` raising for extension array dtypes (:issue:`29618`, :issue:`50261`, :issue:`31913`)
-

Strings
Expand Down
20 changes: 20 additions & 0 deletions pandas/core/computation/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,23 @@ def result_type_many(*arrays_and_dtypes):
except ValueError:
# we have > NPY_MAXARGS terms in our expression
return reduce(np.result_type, arrays_and_dtypes)
except TypeError:
from pandas.core.dtypes.cast import find_common_type
from pandas.core.dtypes.common import is_extension_array_dtype

arr_and_dtypes = list(arrays_and_dtypes)
ea_dtypes, non_ea_dtypes = [], []
for arr_or_dtype in arr_and_dtypes:
if is_extension_array_dtype(arr_or_dtype):
ea_dtypes.append(arr_or_dtype)
else:
non_ea_dtypes.append(arr_or_dtype)

if non_ea_dtypes:
try:
np_dtype = np.result_type(*non_ea_dtypes)
except ValueError:
np_dtype = reduce(np.result_type, arrays_and_dtypes)
return find_common_type(ea_dtypes + [np_dtype])

return find_common_type(ea_dtypes)
19 changes: 19 additions & 0 deletions pandas/core/computation/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
from typing import TYPE_CHECKING
import warnings

from pandas.util._exceptions import find_stack_level
from pandas.util._validators import validate_bool_kwarg

from pandas.core.dtypes.common import is_extension_array_dtype

from pandas.core.computation.engines import ENGINES
from pandas.core.computation.expr import (
PARSERS,
Expand Down Expand Up @@ -333,6 +336,22 @@ def eval(

parsed_expr = Expr(expr, engine=engine, parser=parser, env=env)

if engine == "numexpr" and (
is_extension_array_dtype(parsed_expr.terms.return_type)
or getattr(parsed_expr.terms, "operand_types", None) is not None
and any(
is_extension_array_dtype(elem)
for elem in parsed_expr.terms.operand_types
)
):
warnings.warn(
"Engine has switched to 'python' because numexpr does not support "
"extension array dtypes. Please set your engine to python manually.",
RuntimeWarning,
stacklevel=find_stack_level(),
)
engine = "python"

# construct the engine and evaluate the parsed expression
eng = ENGINES[engine]
eng_inst = eng(parsed_expr)
Expand Down
77 changes: 77 additions & 0 deletions pandas/tests/frame/test_query_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest

from pandas.compat import is_platform_windows
from pandas.errors import (
NumExprClobberingError,
UndefinedVariableError,
Expand Down Expand Up @@ -1291,3 +1292,79 @@ def func(*_):

with pytest.raises(TypeError, match="Only named functions are supported"):
df.eval("@funcs[0].__call__()")

def test_ea_dtypes(self, any_numeric_ea_and_arrow_dtype):
# GH#29618
df = DataFrame(
[[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype
)
warning = RuntimeWarning if NUMEXPR_INSTALLED else None
with tm.assert_produces_warning(warning):
result = df.eval("c = b - a")
expected = DataFrame(
[[1, 2, 1], [3, 4, 1]],
columns=["a", "b", "c"],
dtype=any_numeric_ea_and_arrow_dtype,
)
tm.assert_frame_equal(result, expected)

def test_ea_dtypes_and_scalar(self):
# GH#29618
df = DataFrame([[1, 2], [3, 4]], columns=["a", "b"], dtype="Float64")
warning = RuntimeWarning if NUMEXPR_INSTALLED else None
with tm.assert_produces_warning(warning):
result = df.eval("c = b - 1")
expected = DataFrame(
[[1, 2, 1], [3, 4, 3]], columns=["a", "b", "c"], dtype="Float64"
)
tm.assert_frame_equal(result, expected)

def test_ea_dtypes_and_scalar_operation(self, any_numeric_ea_and_arrow_dtype):
# GH#29618
df = DataFrame(
[[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype
)
result = df.eval("c = 2 - 1")
expected = DataFrame(
{
"a": Series([1, 3], dtype=any_numeric_ea_and_arrow_dtype),
"b": Series([2, 4], dtype=any_numeric_ea_and_arrow_dtype),
"c": Series(
[1, 1], dtype="int64" if not is_platform_windows() else "int32"
),
}
)
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"])
def test_query_ea_dtypes(self, dtype):
if dtype == "int64[pyarrow]":
pytest.importorskip("pyarrow")
# GH#50261
df = DataFrame({"a": Series([1, 2], dtype=dtype)})
ref = {2} # noqa:F841
result = df.query("a in @ref")
expected = DataFrame({"a": Series([2], dtype=dtype, index=[1])})
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize("engine", ["python", "numexpr"])
@pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"])
def test_query_ea_equality_comparison(self, dtype, engine):
# GH#50261
warning = RuntimeWarning if engine == "numexpr" else None
if engine == "numexpr" and not NUMEXPR_INSTALLED:
pytest.skip("numexpr not installed")
if dtype == "int64[pyarrow]":
pytest.importorskip("pyarrow")
df = DataFrame(
{"A": Series([1, 1, 2], dtype="Int64"), "B": Series([1, 2, 2], dtype=dtype)}
)
with tm.assert_produces_warning(warning):
result = df.query("A == B", engine=engine)
expected = DataFrame(
{
"A": Series([1, 2], dtype="Int64", index=[0, 2]),
"B": Series([1, 2], dtype=dtype, index=[0, 2]),
}
)
tm.assert_frame_equal(result, expected)

0 comments on commit b9a4335

Please sign in to comment.