Skip to content

Commit

Permalink
String dtype: use 'str' string alias and representation for NaN-varia…
Browse files Browse the repository at this point in the history
…nt of the dtype (pandas-dev#59388)
  • Loading branch information
jorisvandenbossche authored and WillAyd committed Aug 22, 2024
1 parent 1dab487 commit 8e6b5c2
Show file tree
Hide file tree
Showing 79 changed files with 306 additions and 192 deletions.
6 changes: 5 additions & 1 deletion pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np

from pandas._config import using_string_dtype
from pandas._config.localization import (
can_set_locale,
get_locales,
Expand Down Expand Up @@ -110,7 +111,10 @@
ALL_FLOAT_DTYPES: list[Dtype] = [*FLOAT_NUMPY_DTYPES, *FLOAT_EA_DTYPES]

COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
STRING_DTYPES: list[Dtype] = [str, "str", "U"]
if using_string_dtype():
STRING_DTYPES: list[Dtype] = [str, "U"]
else:
STRING_DTYPES: list[Dtype] = [str, "str", "U"] # type: ignore[no-redef]

DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"]
TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"]
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,10 @@ def __getitem__(self, item: PositionalIndexer):
if isinstance(item, np.ndarray):
if not len(item):
# Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string]
if self._dtype.name == "string" and self._dtype.storage == "pyarrow":
if (
isinstance(self._dtype, StringDtype)
and self._dtype.storage == "pyarrow"
):
# TODO(infer_string) should this be large_string?
pa_dtype = pa.string()
else:
Expand Down
24 changes: 18 additions & 6 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
cast,
)
Expand Down Expand Up @@ -114,9 +113,12 @@ class StringDtype(StorageExtensionDtype):
string[pyarrow]
"""

# error: Cannot override instance variable (previously declared on
# base class "StorageExtensionDtype") with class variable
name: ClassVar[str] = "string" # type: ignore[misc]
@property
def name(self) -> str: # type: ignore[override]
if self._na_value is libmissing.NA:
return "string"
else:
return "str"

#: StringDtype().na_value uses pandas.NA except the implementation that
# follows NumPy semantics, which uses nan.
Expand All @@ -133,7 +135,7 @@ def __init__(
) -> None:
# infer defaults
if storage is None:
if using_string_dtype() and na_value is not libmissing.NA:
if na_value is not libmissing.NA:
if HAS_PYARROW:
storage = "pyarrow"
else:
Expand Down Expand Up @@ -166,11 +168,19 @@ def __init__(
self.storage = storage
self._na_value = na_value

def __repr__(self) -> str:
if self._na_value is libmissing.NA:
return f"{self.name}[{self.storage}]"
else:
# TODO add more informative repr
return self.name

def __eq__(self, other: object) -> bool:
# we need to override the base class __eq__ because na_value (NA or NaN)
# cannot be checked with normal `==`
if isinstance(other, str):
if other == self.name:
# TODO should dtype == "string" work for the NaN variant?
if other == "string" or other == self.name: # noqa: PLR1714
return True
try:
other = self.construct_from_string(other)
Expand Down Expand Up @@ -227,6 +237,8 @@ def construct_from_string(cls, string) -> Self:
)
if string == "string":
return cls()
elif string == "str" and using_string_dtype():
return cls(na_value=np.nan)
elif string == "string[python]":
return cls(storage="python")
elif string == "string[pyarrow]":
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4979,7 +4979,9 @@ def select_dtypes(self, include=None, exclude=None) -> Self:
-----
* To select all *numeric* types, use ``np.number`` or ``'number'``
* To select strings you must use the ``object`` dtype, but note that
this will return *all* object dtype columns
this will return *all* object dtype columns. With
``pd.options.future.infer_string`` enabled, using ``"str"`` will
work to select all string columns.
* See the `numpy dtype hierarchy
<https://numpy.org/doc/stable/reference/arrays.scalars.html>`__
* To select datetimes, use ``np.datetime64``, ``'datetime'`` or
Expand Down
7 changes: 6 additions & 1 deletion pandas/core/interchange/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,12 @@ def dtype_to_arrow_c_fmt(dtype: DtypeObj) -> str:
if format_str is not None:
return format_str

if lib.is_np_dtype(dtype, "M"):
if isinstance(dtype, pd.StringDtype):
# TODO(infer_string) this should be LARGE_STRING for pyarrow storage,
# but current tests don't cover this distinction
return ArrowCTypes.STRING

elif lib.is_np_dtype(dtype, "M"):
# Selecting the first char of resolution string:
# dtype.str -> '<M8[ns]' -> 'n'
resolution = np.datetime_data(dtype)[0][0]
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/apply/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_numba_unsupported_dtypes(apply_axis):

with pytest.raises(
ValueError,
match="Column b must have a numeric dtype. Found 'object|string' instead",
match="Column b must have a numeric dtype. Found 'object|str' instead",
):
df.apply(f, engine="numba", axis=apply_axis)

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/apply/test_series_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def test_apply_categorical(by_row, using_infer_string):
result = ser.apply(lambda x: "A")
exp = Series(["A"] * 7, name="XX", index=list("abcdefg"))
tm.assert_series_equal(result, exp)
assert result.dtype == object if not using_infer_string else "string[pyarrow_numpy]"
assert result.dtype == object if not using_infer_string else "str"


@pytest.mark.parametrize("series", [["1-1", "1-1", np.nan], ["1-1", "1-2", np.nan]])
Expand Down
12 changes: 9 additions & 3 deletions pandas/tests/arrays/boolean/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas._testing as tm


def test_astype():
def test_astype(using_infer_string):
# with missing values
arr = pd.array([True, False, None], dtype="boolean")

Expand All @@ -20,8 +20,14 @@ def test_astype():
tm.assert_numpy_array_equal(result, expected)

result = arr.astype("str")
expected = np.array(["True", "False", "<NA>"], dtype=f"{tm.ENDIAN}U5")
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(
["True", "False", None], dtype=pd.StringDtype(na_value=np.nan)
)
tm.assert_extension_array_equal(result, expected)
else:
expected = np.array(["True", "False", "<NA>"], dtype=f"{tm.ENDIAN}U5")
tm.assert_numpy_array_equal(result, expected)

# no missing values
arr = pd.array([True, False, True], dtype="boolean")
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/categorical/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_astype(self, ordered):
expected = np.array(cat)
tm.assert_numpy_array_equal(result, expected)

msg = r"Cannot cast object|string dtype to float64"
msg = r"Cannot cast object|str dtype to float64"
with pytest.raises(ValueError, match=msg):
cat.astype(float)

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/categorical/test_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_print(self, using_infer_string):
if using_infer_string:
expected = [
"['a', 'b', 'b', 'a', 'a', 'c', 'c', 'c']",
"Categories (3, string): [a < b < c]",
"Categories (3, str): [a < b < c]",
]
else:
expected = [
Expand Down
17 changes: 13 additions & 4 deletions pandas/tests/arrays/floating/test_astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,21 @@ def test_astype_to_integer_array():
tm.assert_extension_array_equal(result, expected)


def test_astype_str():
def test_astype_str(using_infer_string):
a = pd.array([0.1, 0.2, None], dtype="Float64")
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")

tm.assert_numpy_array_equal(a.astype(str), expected)
tm.assert_numpy_array_equal(a.astype("str"), expected)
if using_infer_string:
expected = pd.array(["0.1", "0.2", None], dtype=pd.StringDtype(na_value=np.nan))
tm.assert_extension_array_equal(a.astype("str"), expected)

# TODO(infer_string) this should also be a string array like above
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")
tm.assert_numpy_array_equal(a.astype(str), expected)
else:
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")

tm.assert_numpy_array_equal(a.astype(str), expected)
tm.assert_numpy_array_equal(a.astype("str"), expected)


def test_astype_copy():
Expand Down
17 changes: 13 additions & 4 deletions pandas/tests/arrays/integer/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,21 @@ def test_to_numpy_na_raises(dtype):
a.to_numpy(dtype=dtype)


def test_astype_str():
def test_astype_str(using_infer_string):
a = pd.array([1, 2, None], dtype="Int64")
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")

tm.assert_numpy_array_equal(a.astype(str), expected)
tm.assert_numpy_array_equal(a.astype("str"), expected)
if using_infer_string:
expected = pd.array(["1", "2", None], dtype=pd.StringDtype(na_value=np.nan))
tm.assert_extension_array_equal(a.astype("str"), expected)

# TODO(infer_string) this should also be a string array like above
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")
tm.assert_numpy_array_equal(a.astype(str), expected)
else:
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")

tm.assert_numpy_array_equal(a.astype(str), expected)
tm.assert_numpy_array_equal(a.astype("str"), expected)


def test_astype_boolean():
Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/arrays/interval/test_interval_pyarrow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays import IntervalArray
Expand Down Expand Up @@ -82,7 +80,6 @@ def test_arrow_array_missing():
assert result.storage.equals(expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.filterwarnings(
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
)
Expand Down
4 changes: 0 additions & 4 deletions pandas/tests/arrays/period/test_arrow_compat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import pytest

from pandas._config import using_string_dtype

from pandas.compat.pyarrow import pa_version_under10p1

from pandas.core.dtypes.dtypes import PeriodDtype
Expand Down Expand Up @@ -79,7 +77,6 @@ def test_arrow_array_missing():
assert result.storage.equals(expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_arrow_table_roundtrip():
from pandas.core.arrays.arrow.extension_types import ArrowPeriodType

Expand All @@ -99,7 +96,6 @@ def test_arrow_table_roundtrip():
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_arrow_load_from_zero_chunks():
# GH-41040

Expand Down
35 changes: 21 additions & 14 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_repr(dtype):
assert repr(df) == expected

if dtype.na_value is np.nan:
expected = "0 a\n1 NaN\n2 b\nName: A, dtype: string"
expected = "0 a\n1 NaN\n2 b\nName: A, dtype: str"
else:
expected = "0 a\n1 <NA>\n2 b\nName: A, dtype: string"
assert repr(df.A) == expected
Expand All @@ -75,10 +75,10 @@ def test_repr(dtype):
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
elif dtype.storage == "pyarrow" and dtype.na_value is np.nan:
arr_name = "ArrowStringArrayNumpySemantics"
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: string"
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str"
elif dtype.storage == "python" and dtype.na_value is np.nan:
arr_name = "StringArrayNumpySemantics"
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: string"
expected = f"<{arr_name}>\n['a', nan, 'b']\nLength: 3, dtype: str"
else:
arr_name = "StringArray"
expected = f"<{arr_name}>\n['a', <NA>, 'b']\nLength: 3, dtype: string"
Expand Down Expand Up @@ -502,7 +502,7 @@ def test_fillna_args(dtype):
tm.assert_extension_array_equal(res, expected)

if dtype.storage == "pyarrow":
msg = "Invalid value '1' for dtype string"
msg = "Invalid value '1' for dtype str"
else:
msg = "Cannot set non-string value '1' into a StringArray."
with pytest.raises(TypeError, match=msg):
Expand All @@ -524,7 +524,7 @@ def test_arrow_array(dtype):
assert arr.equals(expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning")
def test_arrow_roundtrip(dtype, string_storage, using_infer_string):
# roundtrip possible from arrow 1.0.0
Expand All @@ -539,14 +539,17 @@ def test_arrow_roundtrip(dtype, string_storage, using_infer_string):
assert table.field("a").type == "large_string"
with pd.option_context("string_storage", string_storage):
result = table.to_pandas()
assert isinstance(result["a"].dtype, pd.StringDtype)
expected = df.astype(f"string[{string_storage}]")
tm.assert_frame_equal(result, expected)
# ensure the missing value is represented by NA and not np.nan or None
assert result.loc[2, "a"] is result["a"].dtype.na_value
if dtype.na_value is np.nan and not using_string_dtype():
assert result["a"].dtype == "object"
else:
assert isinstance(result["a"].dtype, pd.StringDtype)
expected = df.astype(f"string[{string_storage}]")
tm.assert_frame_equal(result, expected)
# ensure the missing value is represented by NA and not np.nan or None
assert result.loc[2, "a"] is result["a"].dtype.na_value


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning")
def test_arrow_load_from_zero_chunks(dtype, string_storage, using_infer_string):
# GH-41040
Expand All @@ -563,9 +566,13 @@ def test_arrow_load_from_zero_chunks(dtype, string_storage, using_infer_string):
table = pa.table([pa.chunked_array([], type=pa.string())], schema=table.schema)
with pd.option_context("string_storage", string_storage):
result = table.to_pandas()
assert isinstance(result["a"].dtype, pd.StringDtype)
expected = df.astype(f"string[{string_storage}]")
tm.assert_frame_equal(result, expected)

if dtype.na_value is np.nan and not using_string_dtype():
assert result["a"].dtype == "object"
else:
assert isinstance(result["a"].dtype, pd.StringDtype)
expected = df.astype(f"string[{string_storage}]")
tm.assert_frame_equal(result, expected)


def test_value_counts_na(dtype):
Expand Down
6 changes: 4 additions & 2 deletions pandas/tests/arrays/string_/test_string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest

from pandas.compat import HAS_PYARROW
import pandas.util._test_decorators as td

import pandas as pd
Expand All @@ -27,8 +28,9 @@ def test_eq_all_na():


def test_config(string_storage, request, using_infer_string):
if using_infer_string and string_storage == "python":
# python string storage with na_value=NaN is not yet implemented
if using_infer_string and string_storage == "python" and HAS_PYARROW:
# string storage with na_value=NaN always uses pyarrow if available
# -> does not yet honor the option
request.applymarker(pytest.mark.xfail(reason="TODO(infer_string)"))

with pd.option_context("string_storage", string_storage):
Expand Down
7 changes: 5 additions & 2 deletions pandas/tests/arrays/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,9 @@ def test_searchsorted(self):
assert result == 10

@pytest.mark.parametrize("box", [None, "index", "series"])
def test_searchsorted_castable_strings(self, arr1d, box, string_storage):
def test_searchsorted_castable_strings(
self, arr1d, box, string_storage, using_infer_string
):
arr = arr1d
if box is None:
pass
Expand Down Expand Up @@ -331,7 +333,8 @@ def test_searchsorted_castable_strings(self, arr1d, box, string_storage):
TypeError,
match=re.escape(
f"value should be a '{arr1d._scalar_type.__name__}', 'NaT', "
"or array of those. Got string array instead."
"or array of those. Got "
f"{'str' if using_infer_string else 'string'} array instead."
),
):
arr.searchsorted([str(arr[1]), "baz"])
Expand Down
Loading

0 comments on commit 8e6b5c2

Please sign in to comment.