From 07d6ba9c08d8de4bf8586e1f9dc546c08ffa4e8a Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 14 Jul 2023 10:27:38 +0200 Subject: [PATCH] feat(rust,python,cli): add `LENGTH` and `OCTET_LENGTH` string functions for SQL (#9860) --- .../polars-plan/src/dsl/function_expr/mod.rs | 8 +- .../src/dsl/function_expr/strings.rs | 149 ++++++++++-------- .../polars-lazy/polars-plan/src/dsl/string.rs | 14 +- polars/polars-sql/src/functions.rs | 15 ++ py-polars/src/expr/string.rs | 20 +-- py-polars/tests/unit/test_sql.py | 21 +++ 6 files changed, 138 insertions(+), 89 deletions(-) diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs index 52b1922bffb4..3ffc1abe0380 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs @@ -618,6 +618,9 @@ impl From for SpecialEq> { match func { #[cfg(feature = "regex")] Contains { literal, strict } => map_as_slice!(strings::contains, literal, strict), + CountMatch(pat) => { + map!(strings::count_match, &pat) + } EndsWith { .. } => map_as_slice!(strings::ends_with), StartsWith { .. } => map_as_slice!(strings::starts_with), Extract { pat, group_index } => { @@ -626,9 +629,8 @@ impl From for SpecialEq> { ExtractAll => { map_as_slice!(strings::extract_all) } - CountMatch(pat) => { - map!(strings::count_match, &pat) - } + NChars => map!(strings::n_chars), + Length => map!(strings::lengths), #[cfg(feature = "string_justify")] Zfill(alignment) => { map!(strings::zfill, alignment) diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs index da1d7350a76a..60486b1cc073 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs @@ -17,37 +17,39 @@ use super::*; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, PartialEq, Debug, Eq, Hash)] pub enum StringFunction { + #[cfg(feature = "concat_str")] + ConcatHorizontal(String), + #[cfg(feature = "concat_str")] + ConcatVertical(String), #[cfg(feature = "regex")] Contains { literal: bool, strict: bool, }, - StartsWith, + CountMatch(String), EndsWith, + Explode, Extract { pat: String, group_index: usize, }, - #[cfg(feature = "string_justify")] - Zfill(usize), + ExtractAll, + #[cfg(feature = "string_from_radix")] + FromRadix(u32, bool), + NChars, + Length, #[cfg(feature = "string_justify")] LJust { width: usize, fillchar: char, }, - #[cfg(feature = "string_justify")] - RJust { - width: usize, - fillchar: char, + Lowercase, + LStrip(Option), + #[cfg(feature = "extract_jsonpath")] + JsonExtract { + dtype: Option, + infer_schema_len: Option, }, - ExtractAll, - CountMatch(String), - #[cfg(feature = "temporal")] - Strptime(DataType, StrptimeOptions), - #[cfg(feature = "concat_str")] - ConcatVertical(String), - #[cfg(feature = "concat_str")] - ConcatHorizontal(String), #[cfg(feature = "regex")] Replace { // negative is replace all @@ -55,56 +57,58 @@ pub enum StringFunction { n: i64, literal: bool, }, - Uppercase, - Lowercase, - #[cfg(feature = "nightly")] - Titlecase, - Strip(Option), + #[cfg(feature = "string_justify")] + RJust { + width: usize, + fillchar: char, + }, RStrip(Option), - LStrip(Option), - #[cfg(feature = "string_from_radix")] - FromRadix(u32, bool), Slice(i64, Option), - Explode, + StartsWith, + Strip(Option), + #[cfg(feature = "temporal")] + Strptime(DataType, StrptimeOptions), #[cfg(feature = "dtype-decimal")] ToDecimal(usize), - #[cfg(feature = "extract_jsonpath")] - JsonExtract { - dtype: Option, - infer_schema_len: Option, - }, + #[cfg(feature = "nightly")] + Titlecase, + Uppercase, + #[cfg(feature = "string_justify")] + Zfill(usize), } impl StringFunction { pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult { use StringFunction::*; match self { + #[cfg(feature = "concat_str")] + ConcatVertical(_) | ConcatHorizontal(_) => mapper.with_same_dtype(), #[cfg(feature = "regex")] Contains { .. } => mapper.with_dtype(DataType::Boolean), + CountMatch(_) => mapper.with_dtype(DataType::UInt32), EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean), + Explode => mapper.with_same_dtype(), Extract { .. } => mapper.with_same_dtype(), ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::Utf8))), - CountMatch(_) => mapper.with_dtype(DataType::UInt32), - #[cfg(feature = "string_justify")] - Zfill { .. } | LJust { .. } | RJust { .. } => mapper.with_same_dtype(), - #[cfg(feature = "temporal")] - Strptime(dtype, _) => mapper.with_dtype(dtype.clone()), - #[cfg(feature = "concat_str")] - ConcatVertical(_) | ConcatHorizontal(_) => mapper.with_same_dtype(), + #[cfg(feature = "string_from_radix")] + FromRadix { .. } => mapper.with_dtype(DataType::Int32), + #[cfg(feature = "extract_jsonpath")] + JsonExtract { dtype, .. } => mapper.with_opt_dtype(dtype.clone()), + Length => mapper.with_dtype(DataType::UInt32), + NChars => mapper.with_dtype(DataType::UInt32), #[cfg(feature = "regex")] Replace { .. } => mapper.with_same_dtype(), - Uppercase | Lowercase | Strip(_) | LStrip(_) | RStrip(_) | Slice(_, _) => { - mapper.with_same_dtype() - } + #[cfg(feature = "temporal")] + Strptime(dtype, _) => mapper.with_dtype(dtype.clone()), #[cfg(feature = "nightly")] Titlecase => mapper.with_same_dtype(), - #[cfg(feature = "string_from_radix")] - FromRadix { .. } => mapper.with_dtype(DataType::Int32), - Explode => mapper.with_same_dtype(), #[cfg(feature = "dtype-decimal")] ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)), - #[cfg(feature = "extract_jsonpath")] - JsonExtract { dtype, .. } => mapper.with_opt_dtype(dtype.clone()), + Uppercase | Lowercase | Strip(_) | LStrip(_) | RStrip(_) | Slice(_, _) => { + mapper.with_same_dtype() + } + #[cfg(feature = "string_justify")] + Zfill { .. } | LJust { .. } | RJust { .. } => mapper.with_same_dtype(), } } } @@ -114,42 +118,43 @@ impl Display for StringFunction { let s = match self { #[cfg(feature = "regex")] StringFunction::Contains { .. } => "contains", - StringFunction::StartsWith { .. } => "starts_with", + StringFunction::CountMatch(_) => "count_match", StringFunction::EndsWith { .. } => "ends_with", StringFunction::Extract { .. } => "extract", - #[cfg(feature = "string_justify")] - StringFunction::Zfill(_) => "zfill", + #[cfg(feature = "concat_str")] + StringFunction::ConcatHorizontal(_) => "concat_horizontal", + #[cfg(feature = "concat_str")] + StringFunction::ConcatVertical(_) => "concat_vertical", + StringFunction::Explode => "explode", + StringFunction::ExtractAll => "extract_all", + #[cfg(feature = "string_from_radix")] + StringFunction::FromRadix { .. } => "from_radix", + #[cfg(feature = "extract_jsonpath")] + StringFunction::JsonExtract { .. } => "json_extract", #[cfg(feature = "string_justify")] StringFunction::LJust { .. } => "str.ljust", + StringFunction::LStrip(_) => "lstrip", + StringFunction::Length => "str_lengths", + StringFunction::Lowercase => "lowercase", + StringFunction::NChars => "n_chars", #[cfg(feature = "string_justify")] StringFunction::RJust { .. } => "rjust", - StringFunction::ExtractAll => "extract_all", - StringFunction::CountMatch(_) => "count_match", - #[cfg(feature = "temporal")] - StringFunction::Strptime(_, _) => "strptime", - #[cfg(feature = "concat_str")] - StringFunction::ConcatVertical(_) => "concat_vertical", - #[cfg(feature = "concat_str")] - StringFunction::ConcatHorizontal(_) => "concat_horizontal", + StringFunction::RStrip(_) => "rstrip", #[cfg(feature = "regex")] StringFunction::Replace { .. } => "replace", - StringFunction::Uppercase => "uppercase", - StringFunction::Lowercase => "lowercase", + StringFunction::Slice(_, _) => "str_slice", + StringFunction::StartsWith { .. } => "starts_with", + StringFunction::Strip(_) => "strip", + #[cfg(feature = "temporal")] + StringFunction::Strptime(_, _) => "strptime", #[cfg(feature = "nightly")] StringFunction::Titlecase => "titlecase", - StringFunction::Strip(_) => "strip", - StringFunction::LStrip(_) => "lstrip", - StringFunction::RStrip(_) => "rstrip", - #[cfg(feature = "string_from_radix")] - StringFunction::FromRadix { .. } => "from_radix", - StringFunction::Slice(_, _) => "str_slice", - StringFunction::Explode => "explode", #[cfg(feature = "dtype-decimal")] StringFunction::ToDecimal(_) => "to_decimal", - #[cfg(feature = "extract_jsonpath")] - StringFunction::JsonExtract { .. } => "json_extract", + StringFunction::Uppercase => "uppercase", + #[cfg(feature = "string_justify")] + StringFunction::Zfill(_) => "zfill", }; - write!(f, "str.{s}") } } @@ -170,6 +175,16 @@ pub(super) fn titlecase(s: &Series) -> PolarsResult { Ok(ca.to_titlecase().into_series()) } +pub(super) fn n_chars(s: &Series) -> PolarsResult { + let ca = s.utf8()?; + Ok(ca.str_n_chars().into_series()) +} + +pub(super) fn lengths(s: &Series) -> PolarsResult { + let ca = s.utf8()?; + Ok(ca.str_lengths().into_series()) +} + #[cfg(feature = "regex")] pub(super) fn contains(s: &[Series], literal: bool, strict: bool) -> PolarsResult { let ca = &s[0].utf8()?; diff --git a/polars/polars-lazy/polars-plan/src/dsl/string.rs b/polars/polars-lazy/polars-plan/src/dsl/string.rs index 8d878d4be5a9..fe5409a3ef86 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/string.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/string.rs @@ -455,7 +455,7 @@ impl StringNameSpace { } #[cfg(feature = "string_from_radix")] - /// Parse string in base radix into decimal + /// Parse string in base radix into decimal. pub fn from_radix(self, radix: u32, strict: bool) -> Expr { self.0 .map_private(FunctionExpr::StringExpr(StringFunction::FromRadix( @@ -463,6 +463,18 @@ impl StringNameSpace { ))) } + /// Return the number of characters in the string (not bytes). + pub fn n_chars(self) -> Expr { + self.0 + .map_private(FunctionExpr::StringExpr(StringFunction::NChars)) + } + + /// Return the number of bytes in the string (not characters). + pub fn lengths(self) -> Expr { + self.0 + .map_private(FunctionExpr::StringExpr(StringFunction::Length)) + } + /// Slice the string values. pub fn str_slice(self, start: i64, length: Option) -> Expr { self.0 diff --git a/polars/polars-sql/src/functions.rs b/polars/polars-sql/src/functions.rs index 34a73e9faaa1..3507becbac7b 100644 --- a/polars/polars-sql/src/functions.rs +++ b/polars/polars-sql/src/functions.rs @@ -161,6 +161,11 @@ pub(crate) enum PolarsSqlFunctions { /// SELECT LEFT(column_1, 3) from df; /// ``` Left, + /// SQL 'length' function (characters) + /// ```sql + /// SELECT LENGTH(column_1) from df; + /// ``` + Length, /// SQL 'lower' function /// ```sql /// SELECT LOWER(column_1) from df; @@ -171,6 +176,11 @@ pub(crate) enum PolarsSqlFunctions { /// SELECT LTRIM(column_1) from df; /// ``` LTrim, + /// SQL 'octet_length' function (bytes) + /// ```sql + /// SELECT OCTET_LENGTH(column_1) from df; + /// ``` + OctetLength, /// SQL 'regexp_like' function /// ```sql /// SELECT REGEXP_LIKE(column_1,'xyz', 'i') from df; @@ -368,6 +378,7 @@ impl PolarsSqlFunctions { "ltrim", "max", "min", + "octet_length", "pow", "radians", "round", @@ -428,9 +439,11 @@ impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions { // String functions // ---- "ends_with" => Self::EndsWith, + "length" => Self::Length, "left" => Self::Left, "lower" => Self::Lower, "ltrim" => Self::LTrim, + "octet_length" => Self::OctetLength, "regexp_like" => Self::RegexpLike, "rtrim" => Self::RTrim, "starts_with" => Self::StartsWith, @@ -532,6 +545,7 @@ impl SqlFunctionVisitor<'_> { } })) }), + Length => self.visit_unary(|e| e.str().n_chars()), Lower => self.visit_unary(|e| e.str().to_lowercase()), LTrim => match function.args.len() { 1 => self.visit_unary(|e| e.str().lstrip(None)), @@ -541,6 +555,7 @@ impl SqlFunctionVisitor<'_> { function.args.len() ), }, + OctetLength => self.visit_unary(|e| e.str().lengths()), RegexpLike => match function.args.len() { 2 => self.visit_binary(|e, s| e.str().contains(s, true)), 3 => self.try_visit_ternary(|e, pat, flags| { diff --git a/py-polars/src/expr/string.rs b/py-polars/src/expr/string.rs index 033d686f6b1e..3dc76dc2c1a9 100644 --- a/py-polars/src/expr/string.rs +++ b/py-polars/src/expr/string.rs @@ -89,27 +89,11 @@ impl PyExpr { } fn str_lengths(&self) -> Self { - let function = |s: Series| { - let ca = s.utf8()?; - Ok(Some(ca.str_lengths().into_series())) - }; - self.clone() - .inner - .map(function, GetOutput::from_type(DataType::UInt32)) - .with_fmt("str.lengths") - .into() + self.inner.clone().str().lengths().into() } fn str_n_chars(&self) -> Self { - let function = |s: Series| { - let ca = s.utf8()?; - Ok(Some(ca.str_n_chars().into_series())) - }; - self.clone() - .inner - .map(function, GetOutput::from_type(DataType::UInt32)) - .with_fmt("str.n_chars") - .into() + self.inner.clone().str().n_chars().into() } #[cfg(feature = "lazy_regex")] diff --git a/py-polars/tests/unit/test_sql.py b/py-polars/tests/unit/test_sql.py index 3e2fefe9a1d0..27347cf8edd2 100644 --- a/py-polars/tests/unit/test_sql.py +++ b/py-polars/tests/unit/test_sql.py @@ -682,6 +682,27 @@ def test_sql_round_ndigits_errors() -> None: ctx.execute("SELECT ROUND(n,-1) AS n FROM df") +def test_sql_string_lengths() -> None: + df = pl.DataFrame({"words": ["Café", None, "東京"]}) + + with pl.SQLContext(frame=df) as ctx: + res = ctx.execute( + """ + SELECT + words, + LENGTH(words) AS n_chars, + OCTET_LENGTH(words) AS n_bytes + FROM frame + """ + ).collect() + + assert res.to_dict(False) == { + "words": ["Café", None, "東京"], + "n_chars": [4, None, 2], + "n_bytes": [5, None, 6], + } + + def test_sql_substr() -> None: df = pl.DataFrame( {