Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rust,python,cli): add LENGTH and OCTET_LENGTH string functions for SQL #9860

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,9 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
match func {
#[cfg(feature = "regex")]
Contains { literal, strict } => map_as_slice!(strings::contains, literal, strict),
CountMatch(pat) => {
map!(strings::count_match, &pat)
}
EndsWith { .. } => map_as_slice!(strings::ends_with),
StartsWith { .. } => map_as_slice!(strings::starts_with),
Extract { pat, group_index } => {
Expand All @@ -626,9 +629,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
ExtractAll => {
map_as_slice!(strings::extract_all)
}
CountMatch(pat) => {
map!(strings::count_match, &pat)
}
NChars => map!(strings::n_chars),
Length => map!(strings::lengths),
#[cfg(feature = "string_justify")]
Zfill(alignment) => {
map!(strings::zfill, alignment)
Expand Down
149 changes: 82 additions & 67 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,94 +17,98 @@ use super::*;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum StringFunction {
#[cfg(feature = "concat_str")]
ConcatHorizontal(String),
#[cfg(feature = "concat_str")]
ConcatVertical(String),
#[cfg(feature = "regex")]
Contains {
literal: bool,
strict: bool,
},
StartsWith,
CountMatch(String),
EndsWith,
Explode,
Extract {
pat: String,
group_index: usize,
},
#[cfg(feature = "string_justify")]
Zfill(usize),
ExtractAll,
#[cfg(feature = "string_from_radix")]
FromRadix(u32, bool),
NChars,
Length,
#[cfg(feature = "string_justify")]
LJust {
width: usize,
fillchar: char,
},
#[cfg(feature = "string_justify")]
RJust {
width: usize,
fillchar: char,
Lowercase,
LStrip(Option<String>),
#[cfg(feature = "extract_jsonpath")]
JsonExtract {
dtype: Option<DataType>,
infer_schema_len: Option<usize>,
},
ExtractAll,
CountMatch(String),
#[cfg(feature = "temporal")]
Strptime(DataType, StrptimeOptions),
#[cfg(feature = "concat_str")]
ConcatVertical(String),
#[cfg(feature = "concat_str")]
ConcatHorizontal(String),
#[cfg(feature = "regex")]
Replace {
// negative is replace all
// how many matches to replace
n: i64,
literal: bool,
},
Uppercase,
Lowercase,
#[cfg(feature = "nightly")]
Titlecase,
Strip(Option<String>),
#[cfg(feature = "string_justify")]
RJust {
width: usize,
fillchar: char,
},
RStrip(Option<String>),
LStrip(Option<String>),
#[cfg(feature = "string_from_radix")]
FromRadix(u32, bool),
Slice(i64, Option<u64>),
Explode,
StartsWith,
Strip(Option<String>),
#[cfg(feature = "temporal")]
Strptime(DataType, StrptimeOptions),
#[cfg(feature = "dtype-decimal")]
ToDecimal(usize),
#[cfg(feature = "extract_jsonpath")]
JsonExtract {
dtype: Option<DataType>,
infer_schema_len: Option<usize>,
},
#[cfg(feature = "nightly")]
Titlecase,
Uppercase,
#[cfg(feature = "string_justify")]
Zfill(usize),
}

impl StringFunction {
pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
use StringFunction::*;
match self {
#[cfg(feature = "concat_str")]
ConcatVertical(_) | ConcatHorizontal(_) => mapper.with_same_dtype(),
#[cfg(feature = "regex")]
Contains { .. } => mapper.with_dtype(DataType::Boolean),
CountMatch(_) => mapper.with_dtype(DataType::UInt32),
EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
Explode => mapper.with_same_dtype(),
Extract { .. } => mapper.with_same_dtype(),
ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::Utf8))),
CountMatch(_) => mapper.with_dtype(DataType::UInt32),
#[cfg(feature = "string_justify")]
Zfill { .. } | LJust { .. } | RJust { .. } => mapper.with_same_dtype(),
#[cfg(feature = "temporal")]
Strptime(dtype, _) => mapper.with_dtype(dtype.clone()),
#[cfg(feature = "concat_str")]
ConcatVertical(_) | ConcatHorizontal(_) => mapper.with_same_dtype(),
#[cfg(feature = "string_from_radix")]
FromRadix { .. } => mapper.with_dtype(DataType::Int32),
#[cfg(feature = "extract_jsonpath")]
JsonExtract { dtype, .. } => mapper.with_opt_dtype(dtype.clone()),
Length => mapper.with_dtype(DataType::UInt32),
NChars => mapper.with_dtype(DataType::UInt32),
#[cfg(feature = "regex")]
Replace { .. } => mapper.with_same_dtype(),
Uppercase | Lowercase | Strip(_) | LStrip(_) | RStrip(_) | Slice(_, _) => {
mapper.with_same_dtype()
}
#[cfg(feature = "temporal")]
Strptime(dtype, _) => mapper.with_dtype(dtype.clone()),
#[cfg(feature = "nightly")]
Titlecase => mapper.with_same_dtype(),
#[cfg(feature = "string_from_radix")]
FromRadix { .. } => mapper.with_dtype(DataType::Int32),
Explode => mapper.with_same_dtype(),
#[cfg(feature = "dtype-decimal")]
ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)),
#[cfg(feature = "extract_jsonpath")]
JsonExtract { dtype, .. } => mapper.with_opt_dtype(dtype.clone()),
Uppercase | Lowercase | Strip(_) | LStrip(_) | RStrip(_) | Slice(_, _) => {
mapper.with_same_dtype()
}
#[cfg(feature = "string_justify")]
Zfill { .. } | LJust { .. } | RJust { .. } => mapper.with_same_dtype(),
}
}
}
Expand All @@ -114,42 +118,43 @@ impl Display for StringFunction {
let s = match self {
#[cfg(feature = "regex")]
StringFunction::Contains { .. } => "contains",
StringFunction::StartsWith { .. } => "starts_with",
StringFunction::CountMatch(_) => "count_match",
StringFunction::EndsWith { .. } => "ends_with",
StringFunction::Extract { .. } => "extract",
#[cfg(feature = "string_justify")]
StringFunction::Zfill(_) => "zfill",
#[cfg(feature = "concat_str")]
StringFunction::ConcatHorizontal(_) => "concat_horizontal",
#[cfg(feature = "concat_str")]
StringFunction::ConcatVertical(_) => "concat_vertical",
StringFunction::Explode => "explode",
StringFunction::ExtractAll => "extract_all",
#[cfg(feature = "string_from_radix")]
StringFunction::FromRadix { .. } => "from_radix",
#[cfg(feature = "extract_jsonpath")]
StringFunction::JsonExtract { .. } => "json_extract",
#[cfg(feature = "string_justify")]
StringFunction::LJust { .. } => "str.ljust",
StringFunction::LStrip(_) => "lstrip",
StringFunction::Length => "str_lengths",
StringFunction::Lowercase => "lowercase",
StringFunction::NChars => "n_chars",
#[cfg(feature = "string_justify")]
StringFunction::RJust { .. } => "rjust",
StringFunction::ExtractAll => "extract_all",
StringFunction::CountMatch(_) => "count_match",
#[cfg(feature = "temporal")]
StringFunction::Strptime(_, _) => "strptime",
#[cfg(feature = "concat_str")]
StringFunction::ConcatVertical(_) => "concat_vertical",
#[cfg(feature = "concat_str")]
StringFunction::ConcatHorizontal(_) => "concat_horizontal",
StringFunction::RStrip(_) => "rstrip",
#[cfg(feature = "regex")]
StringFunction::Replace { .. } => "replace",
StringFunction::Uppercase => "uppercase",
StringFunction::Lowercase => "lowercase",
StringFunction::Slice(_, _) => "str_slice",
StringFunction::StartsWith { .. } => "starts_with",
StringFunction::Strip(_) => "strip",
#[cfg(feature = "temporal")]
StringFunction::Strptime(_, _) => "strptime",
#[cfg(feature = "nightly")]
StringFunction::Titlecase => "titlecase",
StringFunction::Strip(_) => "strip",
StringFunction::LStrip(_) => "lstrip",
StringFunction::RStrip(_) => "rstrip",
#[cfg(feature = "string_from_radix")]
StringFunction::FromRadix { .. } => "from_radix",
StringFunction::Slice(_, _) => "str_slice",
StringFunction::Explode => "explode",
#[cfg(feature = "dtype-decimal")]
StringFunction::ToDecimal(_) => "to_decimal",
#[cfg(feature = "extract_jsonpath")]
StringFunction::JsonExtract { .. } => "json_extract",
StringFunction::Uppercase => "uppercase",
#[cfg(feature = "string_justify")]
StringFunction::Zfill(_) => "zfill",
};

write!(f, "str.{s}")
}
}
Expand All @@ -170,6 +175,16 @@ pub(super) fn titlecase(s: &Series) -> PolarsResult<Series> {
Ok(ca.to_titlecase().into_series())
}

pub(super) fn n_chars(s: &Series) -> PolarsResult<Series> {
let ca = s.utf8()?;
Ok(ca.str_n_chars().into_series())
}

pub(super) fn lengths(s: &Series) -> PolarsResult<Series> {
let ca = s.utf8()?;
Ok(ca.str_lengths().into_series())
}

#[cfg(feature = "regex")]
pub(super) fn contains(s: &[Series], literal: bool, strict: bool) -> PolarsResult<Series> {
let ca = &s[0].utf8()?;
Expand Down
14 changes: 13 additions & 1 deletion polars/polars-lazy/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,14 +456,26 @@ impl StringNameSpace {
}

#[cfg(feature = "string_from_radix")]
/// Parse string in base radix into decimal
/// Parse string in base radix into decimal.
pub fn from_radix(self, radix: u32, strict: bool) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::FromRadix(
radix, strict,
)))
}

/// Return the number of characters in the string (not bytes).
pub fn n_chars(self) -> Expr {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, can we now properly dispatch the python side to those?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops! Missed the final link in the chain 😅 Should be good now - jumping on a plane in about an hour (at the airport now), so if there's anything else then I'll pick it up when I land and get home this evening. (In the meantime, let's see if I can get anything done from the emergency exit row at 10km up :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All good! You will land in a world with this PR merged. ;)

self.0
.map_private(FunctionExpr::StringExpr(StringFunction::NChars))
}

/// Return the number of bytes in the string (not characters).
pub fn lengths(self) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::Length))
}

/// Slice the string values.
pub fn str_slice(self, start: i64, length: Option<u64>) -> Expr {
self.0
Expand Down
15 changes: 15 additions & 0 deletions polars/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ pub(crate) enum PolarsSqlFunctions {
/// SELECT LEFT(column_1, 3) from df;
/// ```
Left,
/// SQL 'length' function (characters)
/// ```sql
/// SELECT LENGTH(column_1) from df;
/// ```
Length,
/// SQL 'lower' function
/// ```sql
/// SELECT LOWER(column_1) from df;
Expand All @@ -171,6 +176,11 @@ pub(crate) enum PolarsSqlFunctions {
/// SELECT LTRIM(column_1) from df;
/// ```
LTrim,
/// SQL 'octet_length' function (bytes)
/// ```sql
/// SELECT OCTET_LENGTH(column_1) from df;
/// ```
OctetLength,
/// SQL 'regexp_like' function
/// ```sql
/// SELECT REGEXP_LIKE(column_1,'xyz', 'i') from df;
Expand Down Expand Up @@ -368,6 +378,7 @@ impl PolarsSqlFunctions {
"ltrim",
"max",
"min",
"octet_length",
"pow",
"radians",
"round",
Expand Down Expand Up @@ -428,9 +439,11 @@ impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions {
// String functions
// ----
"ends_with" => Self::EndsWith,
"length" => Self::Length,
"left" => Self::Left,
"lower" => Self::Lower,
"ltrim" => Self::LTrim,
"octet_length" => Self::OctetLength,
"regexp_like" => Self::RegexpLike,
"rtrim" => Self::RTrim,
"starts_with" => Self::StartsWith,
Expand Down Expand Up @@ -532,6 +545,7 @@ impl SqlFunctionVisitor<'_> {
}
}))
}),
Length => self.visit_unary(|e| e.str().n_chars()),
Lower => self.visit_unary(|e| e.str().to_lowercase()),
LTrim => match function.args.len() {
1 => self.visit_unary(|e| e.str().lstrip(None)),
Expand All @@ -541,6 +555,7 @@ impl SqlFunctionVisitor<'_> {
function.args.len()
),
},
OctetLength => self.visit_unary(|e| e.str().lengths()),
RegexpLike => match function.args.len() {
2 => self.visit_binary(|e, s| e.str().contains(s, true)),
3 => self.try_visit_ternary(|e, pat, flags| {
Expand Down
21 changes: 21 additions & 0 deletions py-polars/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,27 @@ def test_sql_round_ndigits_errors() -> None:
ctx.execute("SELECT ROUND(n,-1) AS n FROM df")


def test_sql_string_lengths() -> None:
df = pl.DataFrame({"words": ["Café", None, "東京"]})

with pl.SQLContext(frame=df) as ctx:
res = ctx.execute(
"""
SELECT
words,
LENGTH(words) AS n_chars,
OCTET_LENGTH(words) AS n_bytes
FROM frame
"""
).collect()

assert res.to_dict(False) == {
"words": ["Café", None, "東京"],
"n_chars": [4, None, 2],
"n_bytes": [5, None, 6],
}


def test_sql_substr() -> None:
df = pl.DataFrame(
{
Expand Down