Skip to content

Commit

Permalink
REF (string): de-duplicate str_endswith, startswith (#59568)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and jorisvandenbossche committed Oct 10, 2024
1 parent d64b8d8 commit 807d8d5
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 72 deletions.
48 changes: 46 additions & 2 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
from __future__ import annotations

from typing import Literal
from typing import (
TYPE_CHECKING,
Literal,
)

import numpy as np

from pandas.compat import pa_version_under10p1

from pandas.core.dtypes.missing import isna

if not pa_version_under10p1:
import pyarrow as pa
import pyarrow.compute as pc

if TYPE_CHECKING:
from collections.abc import Sized

from pandas._typing import Scalar


class ArrowStringArrayMixin:
_pa_array = None
_pa_array: Sized

def __init__(self, *args, **kwargs) -> None:
raise NotImplementedError
Expand Down Expand Up @@ -90,3 +100,37 @@ def _str_removesuffix(self, suffix: str):
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
if isinstance(pat, str):
result = pc.starts_with(self._pa_array, pattern=pat)
else:
if len(pat) == 0:
# For empty tuple we return null for missing values and False
# for valid values.
result = pc.if_else(pc.is_null(self._pa_array), None, False)
else:
result = pc.starts_with(self._pa_array, pattern=pat[0])

for p in pat[1:]:
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)

def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
if isinstance(pat, str):
result = pc.ends_with(self._pa_array, pattern=pat)
else:
if len(pat) == 0:
# For empty tuple we return null for missing values and False
# for valid values.
result = pc.if_else(pc.is_null(self._pa_array), None, False)
else:
result = pc.ends_with(self._pa_array, pattern=pat[0])

for p in pat[1:]:
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
33 changes: 1 addition & 32 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2311,38 +2311,7 @@ def _str_contains(
result = result.fill_null(na)
return type(self)(result)

def _str_startswith(self, pat: str | tuple[str, ...], na=None):
if isinstance(pat, str):
result = pc.starts_with(self._pa_array, pattern=pat)
else:
if len(pat) == 0:
# For empty tuple, pd.StringDtype() returns null for missing values
# and false for valid values.
result = pc.if_else(pc.is_null(self._pa_array), None, False)
else:
result = pc.starts_with(self._pa_array, pattern=pat[0])

for p in pat[1:]:
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
if not isna(na):
result = result.fill_null(na)
return type(self)(result)

def _str_endswith(self, pat: str | tuple[str, ...], na=None):
if isinstance(pat, str):
result = pc.ends_with(self._pa_array, pattern=pat)
else:
if len(pat) == 0:
# For empty tuple, pd.StringDtype() returns null for missing values
# and false for valid values.
result = pc.if_else(pc.is_null(self._pa_array), None, False)
else:
result = pc.ends_with(self._pa_array, pattern=pat[0])

for p in pat[1:]:
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
if not isna(na):
result = result.fill_null(na)
def _result_converter(self, result):
return type(self)(result)

def _str_replace(
Expand Down
40 changes: 2 additions & 38 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ def _data(self):
# String methods interface

_str_map = BaseStringArray._str_map
_str_startswith = ArrowStringArrayMixin._str_startswith
_str_endswith = ArrowStringArrayMixin._str_endswith

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
Expand All @@ -301,44 +303,6 @@ def _str_contains(
result[isna(result)] = bool(na)
return result

def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
if isinstance(pat, str):
result = pc.starts_with(self._pa_array, pattern=pat)
else:
if len(pat) == 0:
# mimic existing behaviour of string extension array
# and python string method
result = pa.array(
np.zeros(len(self._pa_array), dtype=bool), mask=isna(self._pa_array)
)
else:
result = pc.starts_with(self._pa_array, pattern=pat[0])

for p in pat[1:]:
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
if not isna(na):
result = result.fill_null(na)
return self._convert_bool_result(result)

def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
if isinstance(pat, str):
result = pc.ends_with(self._pa_array, pattern=pat)
else:
if len(pat) == 0:
# mimic existing behaviour of string extension array
# and python string method
result = pa.array(
np.zeros(len(self._pa_array), dtype=bool), mask=isna(self._pa_array)
)
else:
result = pc.ends_with(self._pa_array, pattern=pat[0])

for p in pat[1:]:
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
if not isna(na):
result = result.fill_null(na)
return self._convert_bool_result(result)

def _str_replace(
self,
pat: str | re.Pattern,
Expand Down

0 comments on commit 807d8d5

Please sign in to comment.