Skip to content

Commit

Permalink
feat(python): Warn on inefficient use of map_elements for additiona…
Browse files Browse the repository at this point in the history
…l string functions (#14565)
  • Loading branch information
alexander-beedie authored Feb 18, 2024
1 parent 7adab46 commit 7698c31
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 25 deletions.
73 changes: 51 additions & 22 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
Union,
)

from polars.utils.various import re_escape

if TYPE_CHECKING:
from dis import Instruction

Expand Down Expand Up @@ -147,12 +149,17 @@ class OpNames:
_PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"}
_PYTHON_METHODS_MAP = {
# string
"endswith": "str.ends_with",
"lower": "str.to_lowercase",
"lstrip": "str.strip_chars_start",
"rstrip": "str.strip_chars_end",
"startswith": "str.starts_with",
"strip": "str.strip_chars",
"title": "str.to_titlecase",
"upper": "str.to_uppercase",
# temporal
"isoweekday": "dt.weekday",
"date": "dt.date",
"isoweekday": "dt.weekday",
"time": "dt.time",
}

Expand Down Expand Up @@ -576,7 +583,7 @@ def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str
# But, if e1 << e2 was valid, then e2 must have been positive.
# Hence, the output of 2**e2 can be safely cast to Int64, which
# may be necessary if chaining operations which assume Int64 output.
return f"({e1}*2**{e2}).cast(pl.Int64)"
return f"({e1} * 2**{e2}).cast(pl.Int64)"
elif op == ">>":
# Motivation for the cast is the same as in the '<<' case above.
return f"({e1} / 2**{e2}).cast(pl.Int64)"
Expand Down Expand Up @@ -685,7 +692,7 @@ def _matches(
argvals
Associated argvals that must also match (in same position as opnames).
is_attr
Indicate if the match is expected to represent attribute access.
Indicate if the match represents pure attribute access (cannot be called).
"""
n_required_ops, argvals = len(opnames), argvals or []
idx_offset = idx + n_required_ops
Expand Down Expand Up @@ -744,10 +751,10 @@ def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> i
):
inst = matching_instructions[1]
expr_name = _PYTHON_ATTRS_MAP[inst.argval]
synthetic_call = inst._replace(
px = inst._replace(
opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name
)
updated_instructions.extend([matching_instructions[0], synthetic_call])
updated_instructions.extend([matching_instructions[0], px])

return len(matching_instructions)

Expand All @@ -765,15 +772,15 @@ def _rewrite_builtins(
dtype = _PYTHON_CASTS_MAP[argval]
argval = f"cast(pl.{dtype})"

synthetic_call = inst1._replace(
px = inst1._replace(
opname="POLARS_EXPRESSION",
argval=argval,
argrepr=argval,
offset=inst2.offset,
)
# POLARS_EXPRESSION is mapped as a unary op, so switch instruction order
operand = inst2._replace(offset=inst1.offset)
updated_instructions.extend((operand, synthetic_call))
updated_instructions.extend((operand, px))

return len(matching_instructions)

Expand Down Expand Up @@ -818,22 +825,24 @@ def _rewrite_functions(
return 0
else:
expr_name = inst2.argval
synthetic_call = inst1._replace(

px = inst1._replace(
opname="POLARS_EXPRESSION",
argval=expr_name,
argrepr=expr_name,
offset=inst3.offset,
)

# POLARS_EXPRESSION is mapped as a unary op, so switch instruction order
operand = inst3._replace(offset=inst1.offset)
updated_instructions.extend(
(
operand,
matching_instructions[3 + attribute_count],
synthetic_call,
px,
)
if function_kind["argument_1_unary_opname"]
else (operand, synthetic_call)
else (operand, px)
)
return len(matching_instructions)

Expand All @@ -843,20 +852,40 @@ def _rewrite_methods(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace python method calls with synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=[
OpNames.LOAD_ATTR if _MIN_PY312 else {"LOAD_METHOD"},
OpNames.CALL,
],
argvals=[_PYTHON_METHODS_MAP],
LOAD_METHOD = OpNames.LOAD_ATTR if _MIN_PY312 else {"LOAD_METHOD"}
if matching_instructions := (
# method call with one basic arg, eg: "s.endswith('!')"
self._matches(
idx,
opnames=[LOAD_METHOD, {"LOAD_CONST"}, OpNames.CALL],
argvals=[_PYTHON_METHODS_MAP],
)
or
# method call with no arg, eg: "s.lower()"
self._matches(
idx,
opnames=[LOAD_METHOD, OpNames.CALL],
argvals=[_PYTHON_METHODS_MAP],
)
):
inst = matching_instructions[0]
expr_name = _PYTHON_METHODS_MAP[inst.argval]
synthetic_call = inst._replace(
opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name
)
updated_instructions.append(synthetic_call)
expr = _PYTHON_METHODS_MAP[inst.argval]

if matching_instructions[1].opname == "LOAD_CONST":
param_value = matching_instructions[1].argval
if isinstance(param_value, tuple) and expr in (
"str.starts_with",
"str.ends_with",
):
starts, ends = ("^", "") if "starts" in expr else ("", "$")
rx = "|".join(re_escape(v) for v in param_value)
q = '"' if "'" in param_value else "'"
expr = f"str.contains(r{q}{starts}({rx}){ends}{q})"
else:
expr += f"({param_value!r})"

px = inst._replace(opname="POLARS_EXPRESSION", argval=expr, argrepr=expr)
updated_instructions.append(px)

return len(matching_instructions)

Expand Down
8 changes: 8 additions & 0 deletions py-polars/polars/utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,11 @@ def parse_percentiles(
at_or_above_50_percentiles = [0.5, *at_or_above_50_percentiles]

return [*sub_50_percentiles, *at_or_above_50_percentiles]


def re_escape(s: str) -> str:
    """
    Backslash-escape regex metacharacters in `s` for a Polars (Rust) pattern.

    This is close to the standard library's `re.escape`, but deliberately
    narrower: only the characters listed below (those treated as metachars
    by the rust regex crate) receive a leading backslash; everything else
    is passed through untouched.
    """
    # character-class body: each char listed here is prefixed with a backslash
    rust_metachars = r"\\?()|\[\]{}^$#&~.+*-"
    metachar_pattern = re.compile(f"([{rust_metachars}])")
    return metachar_pattern.sub(lambda match: "\\" + match.group(1), s)
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,29 @@
'(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)',
),
# ---------------------------------------------
# string expr: case/cast ops
# string exprs
# ---------------------------------------------
("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.String).str.to_titlecase()'),
(
"b",
'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()',
'(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()',
),
(
"b",
"lambda x: x.strip().startswith('#')",
"""pl.col("b").str.strip_chars().str.starts_with('#')""",
),
(
"b",
"""lambda x: x.rstrip().endswith(('!','#','?','"'))""",
"""pl.col("b").str.strip_chars_end().str.contains(r'(!|\\#|\\?|")$')""",
),
(
"b",
"""lambda x: x.lstrip().startswith(('!','#','?',"'"))""",
"""pl.col("b").str.strip_chars_start().str.contains(r"^(!|\\#|\\?|')")""",
),
# ---------------------------------------------
# json expr: load/extract
# ---------------------------------------------
Expand Down Expand Up @@ -186,12 +201,12 @@
(
"a",
"lambda x: (3 << (32-x)) & 3",
'(3*2**(32 - pl.col("a"))).cast(pl.Int64) & 3',
'(3 * 2**(32 - pl.col("a"))).cast(pl.Int64) & 3',
),
(
"a",
"lambda x: (x << 32) & 3",
'(pl.col("a")*2**32).cast(pl.Int64) & 3',
'(pl.col("a") * 2**32).cast(pl.Int64) & 3',
),
(
"a",
Expand Down

0 comments on commit 7698c31

Please sign in to comment.