Skip to content

Commit

Permalink
feat(rust,python,cli): add LENGTH and OCTET_LENGTH string functio…
Browse files Browse the repository at this point in the history
…ns for SQL (pola-rs#9860)
  • Loading branch information
alexander-beedie authored and c-peters committed Jul 14, 2023
1 parent bdbe8a0 commit 07d6ba9
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 89 deletions.
8 changes: 5 additions & 3 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,9 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
match func {
#[cfg(feature = "regex")]
Contains { literal, strict } => map_as_slice!(strings::contains, literal, strict),
CountMatch(pat) => {
map!(strings::count_match, &pat)
}
EndsWith { .. } => map_as_slice!(strings::ends_with),
StartsWith { .. } => map_as_slice!(strings::starts_with),
Extract { pat, group_index } => {
Expand All @@ -626,9 +629,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
ExtractAll => {
map_as_slice!(strings::extract_all)
}
CountMatch(pat) => {
map!(strings::count_match, &pat)
}
NChars => map!(strings::n_chars),
Length => map!(strings::lengths),
#[cfg(feature = "string_justify")]
Zfill(alignment) => {
map!(strings::zfill, alignment)
Expand Down
149 changes: 82 additions & 67 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,94 +17,98 @@ use super::*;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum StringFunction {
#[cfg(feature = "concat_str")]
ConcatHorizontal(String),
#[cfg(feature = "concat_str")]
ConcatVertical(String),
#[cfg(feature = "regex")]
Contains {
literal: bool,
strict: bool,
},
StartsWith,
CountMatch(String),
EndsWith,
Explode,
Extract {
pat: String,
group_index: usize,
},
#[cfg(feature = "string_justify")]
Zfill(usize),
ExtractAll,
#[cfg(feature = "string_from_radix")]
FromRadix(u32, bool),
NChars,
Length,
#[cfg(feature = "string_justify")]
LJust {
width: usize,
fillchar: char,
},
#[cfg(feature = "string_justify")]
RJust {
width: usize,
fillchar: char,
Lowercase,
LStrip(Option<String>),
#[cfg(feature = "extract_jsonpath")]
JsonExtract {
dtype: Option<DataType>,
infer_schema_len: Option<usize>,
},
ExtractAll,
CountMatch(String),
#[cfg(feature = "temporal")]
Strptime(DataType, StrptimeOptions),
#[cfg(feature = "concat_str")]
ConcatVertical(String),
#[cfg(feature = "concat_str")]
ConcatHorizontal(String),
#[cfg(feature = "regex")]
Replace {
// negative is replace all
// how many matches to replace
n: i64,
literal: bool,
},
Uppercase,
Lowercase,
#[cfg(feature = "nightly")]
Titlecase,
Strip(Option<String>),
#[cfg(feature = "string_justify")]
RJust {
width: usize,
fillchar: char,
},
RStrip(Option<String>),
LStrip(Option<String>),
#[cfg(feature = "string_from_radix")]
FromRadix(u32, bool),
Slice(i64, Option<u64>),
Explode,
StartsWith,
Strip(Option<String>),
#[cfg(feature = "temporal")]
Strptime(DataType, StrptimeOptions),
#[cfg(feature = "dtype-decimal")]
ToDecimal(usize),
#[cfg(feature = "extract_jsonpath")]
JsonExtract {
dtype: Option<DataType>,
infer_schema_len: Option<usize>,
},
#[cfg(feature = "nightly")]
Titlecase,
Uppercase,
#[cfg(feature = "string_justify")]
Zfill(usize),
}

impl StringFunction {
pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
use StringFunction::*;
match self {
#[cfg(feature = "concat_str")]
ConcatVertical(_) | ConcatHorizontal(_) => mapper.with_same_dtype(),
#[cfg(feature = "regex")]
Contains { .. } => mapper.with_dtype(DataType::Boolean),
CountMatch(_) => mapper.with_dtype(DataType::UInt32),
EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
Explode => mapper.with_same_dtype(),
Extract { .. } => mapper.with_same_dtype(),
ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::Utf8))),
CountMatch(_) => mapper.with_dtype(DataType::UInt32),
#[cfg(feature = "string_justify")]
Zfill { .. } | LJust { .. } | RJust { .. } => mapper.with_same_dtype(),
#[cfg(feature = "temporal")]
Strptime(dtype, _) => mapper.with_dtype(dtype.clone()),
#[cfg(feature = "concat_str")]
ConcatVertical(_) | ConcatHorizontal(_) => mapper.with_same_dtype(),
#[cfg(feature = "string_from_radix")]
FromRadix { .. } => mapper.with_dtype(DataType::Int32),
#[cfg(feature = "extract_jsonpath")]
JsonExtract { dtype, .. } => mapper.with_opt_dtype(dtype.clone()),
Length => mapper.with_dtype(DataType::UInt32),
NChars => mapper.with_dtype(DataType::UInt32),
#[cfg(feature = "regex")]
Replace { .. } => mapper.with_same_dtype(),
Uppercase | Lowercase | Strip(_) | LStrip(_) | RStrip(_) | Slice(_, _) => {
mapper.with_same_dtype()
}
#[cfg(feature = "temporal")]
Strptime(dtype, _) => mapper.with_dtype(dtype.clone()),
#[cfg(feature = "nightly")]
Titlecase => mapper.with_same_dtype(),
#[cfg(feature = "string_from_radix")]
FromRadix { .. } => mapper.with_dtype(DataType::Int32),
Explode => mapper.with_same_dtype(),
#[cfg(feature = "dtype-decimal")]
ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)),
#[cfg(feature = "extract_jsonpath")]
JsonExtract { dtype, .. } => mapper.with_opt_dtype(dtype.clone()),
Uppercase | Lowercase | Strip(_) | LStrip(_) | RStrip(_) | Slice(_, _) => {
mapper.with_same_dtype()
}
#[cfg(feature = "string_justify")]
Zfill { .. } | LJust { .. } | RJust { .. } => mapper.with_same_dtype(),
}
}
}
Expand All @@ -114,42 +118,43 @@ impl Display for StringFunction {
let s = match self {
#[cfg(feature = "regex")]
StringFunction::Contains { .. } => "contains",
StringFunction::StartsWith { .. } => "starts_with",
StringFunction::CountMatch(_) => "count_match",
StringFunction::EndsWith { .. } => "ends_with",
StringFunction::Extract { .. } => "extract",
#[cfg(feature = "string_justify")]
StringFunction::Zfill(_) => "zfill",
#[cfg(feature = "concat_str")]
StringFunction::ConcatHorizontal(_) => "concat_horizontal",
#[cfg(feature = "concat_str")]
StringFunction::ConcatVertical(_) => "concat_vertical",
StringFunction::Explode => "explode",
StringFunction::ExtractAll => "extract_all",
#[cfg(feature = "string_from_radix")]
StringFunction::FromRadix { .. } => "from_radix",
#[cfg(feature = "extract_jsonpath")]
StringFunction::JsonExtract { .. } => "json_extract",
#[cfg(feature = "string_justify")]
StringFunction::LJust { .. } => "str.ljust",
StringFunction::LStrip(_) => "lstrip",
StringFunction::Length => "str_lengths",
StringFunction::Lowercase => "lowercase",
StringFunction::NChars => "n_chars",
#[cfg(feature = "string_justify")]
StringFunction::RJust { .. } => "rjust",
StringFunction::ExtractAll => "extract_all",
StringFunction::CountMatch(_) => "count_match",
#[cfg(feature = "temporal")]
StringFunction::Strptime(_, _) => "strptime",
#[cfg(feature = "concat_str")]
StringFunction::ConcatVertical(_) => "concat_vertical",
#[cfg(feature = "concat_str")]
StringFunction::ConcatHorizontal(_) => "concat_horizontal",
StringFunction::RStrip(_) => "rstrip",
#[cfg(feature = "regex")]
StringFunction::Replace { .. } => "replace",
StringFunction::Uppercase => "uppercase",
StringFunction::Lowercase => "lowercase",
StringFunction::Slice(_, _) => "str_slice",
StringFunction::StartsWith { .. } => "starts_with",
StringFunction::Strip(_) => "strip",
#[cfg(feature = "temporal")]
StringFunction::Strptime(_, _) => "strptime",
#[cfg(feature = "nightly")]
StringFunction::Titlecase => "titlecase",
StringFunction::Strip(_) => "strip",
StringFunction::LStrip(_) => "lstrip",
StringFunction::RStrip(_) => "rstrip",
#[cfg(feature = "string_from_radix")]
StringFunction::FromRadix { .. } => "from_radix",
StringFunction::Slice(_, _) => "str_slice",
StringFunction::Explode => "explode",
#[cfg(feature = "dtype-decimal")]
StringFunction::ToDecimal(_) => "to_decimal",
#[cfg(feature = "extract_jsonpath")]
StringFunction::JsonExtract { .. } => "json_extract",
StringFunction::Uppercase => "uppercase",
#[cfg(feature = "string_justify")]
StringFunction::Zfill(_) => "zfill",
};

write!(f, "str.{s}")
}
}
Expand All @@ -170,6 +175,16 @@ pub(super) fn titlecase(s: &Series) -> PolarsResult<Series> {
Ok(ca.to_titlecase().into_series())
}

/// Compute, for every string value in `s`, its length in characters.
///
/// Fails if `s` cannot be viewed as a Utf8 column.
pub(super) fn n_chars(s: &Series) -> PolarsResult<Series> {
    let char_counts = s.utf8()?.str_n_chars();
    Ok(char_counts.into_series())
}

/// Compute, for every string value in `s`, its length in bytes.
///
/// Fails if `s` cannot be viewed as a Utf8 column.
pub(super) fn lengths(s: &Series) -> PolarsResult<Series> {
    let byte_counts = s.utf8()?.str_lengths();
    Ok(byte_counts.into_series())
}

#[cfg(feature = "regex")]
pub(super) fn contains(s: &[Series], literal: bool, strict: bool) -> PolarsResult<Series> {
let ca = &s[0].utf8()?;
Expand Down
14 changes: 13 additions & 1 deletion polars/polars-lazy/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,14 +455,26 @@ impl StringNameSpace {
}

#[cfg(feature = "string_from_radix")]
/// Parse string in base radix into decimal
/// Parse string in base radix into decimal.
pub fn from_radix(self, radix: u32, strict: bool) -> Expr {
self.0
.map_private(FunctionExpr::StringExpr(StringFunction::FromRadix(
radix, strict,
)))
}

/// Return the number of characters in the string (not bytes).
///
/// Builds a `StringFunction::NChars` function expression over the wrapped expression.
pub fn n_chars(self) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::NChars);
    self.0.map_private(function)
}

/// Return the number of bytes in the string (not characters).
///
/// Builds a `StringFunction::Length` function expression over the wrapped expression.
pub fn lengths(self) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::Length);
    self.0.map_private(function)
}

/// Slice the string values.
pub fn str_slice(self, start: i64, length: Option<u64>) -> Expr {
self.0
Expand Down
15 changes: 15 additions & 0 deletions polars/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ pub(crate) enum PolarsSqlFunctions {
/// SELECT LEFT(column_1, 3) from df;
/// ```
Left,
/// SQL 'length' function (characters)
/// ```sql
/// SELECT LENGTH(column_1) from df;
/// ```
Length,
/// SQL 'lower' function
/// ```sql
/// SELECT LOWER(column_1) from df;
Expand All @@ -171,6 +176,11 @@ pub(crate) enum PolarsSqlFunctions {
/// SELECT LTRIM(column_1) from df;
/// ```
LTrim,
/// SQL 'octet_length' function (bytes)
/// ```sql
/// SELECT OCTET_LENGTH(column_1) from df;
/// ```
OctetLength,
/// SQL 'regexp_like' function
/// ```sql
/// SELECT REGEXP_LIKE(column_1,'xyz', 'i') from df;
Expand Down Expand Up @@ -368,6 +378,7 @@ impl PolarsSqlFunctions {
"ltrim",
"max",
"min",
"octet_length",
"pow",
"radians",
"round",
Expand Down Expand Up @@ -428,9 +439,11 @@ impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions {
// String functions
// ----
"ends_with" => Self::EndsWith,
"length" => Self::Length,
"left" => Self::Left,
"lower" => Self::Lower,
"ltrim" => Self::LTrim,
"octet_length" => Self::OctetLength,
"regexp_like" => Self::RegexpLike,
"rtrim" => Self::RTrim,
"starts_with" => Self::StartsWith,
Expand Down Expand Up @@ -532,6 +545,7 @@ impl SqlFunctionVisitor<'_> {
}
}))
}),
Length => self.visit_unary(|e| e.str().n_chars()),
Lower => self.visit_unary(|e| e.str().to_lowercase()),
LTrim => match function.args.len() {
1 => self.visit_unary(|e| e.str().lstrip(None)),
Expand All @@ -541,6 +555,7 @@ impl SqlFunctionVisitor<'_> {
function.args.len()
),
},
OctetLength => self.visit_unary(|e| e.str().lengths()),
RegexpLike => match function.args.len() {
2 => self.visit_binary(|e, s| e.str().contains(s, true)),
3 => self.try_visit_ternary(|e, pat, flags| {
Expand Down
20 changes: 2 additions & 18 deletions py-polars/src/expr/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,27 +89,11 @@ impl PyExpr {
}

fn str_lengths(&self) -> Self {
let function = |s: Series| {
let ca = s.utf8()?;
Ok(Some(ca.str_lengths().into_series()))
};
self.clone()
.inner
.map(function, GetOutput::from_type(DataType::UInt32))
.with_fmt("str.lengths")
.into()
self.inner.clone().str().lengths().into()
}

fn str_n_chars(&self) -> Self {
let function = |s: Series| {
let ca = s.utf8()?;
Ok(Some(ca.str_n_chars().into_series()))
};
self.clone()
.inner
.map(function, GetOutput::from_type(DataType::UInt32))
.with_fmt("str.n_chars")
.into()
self.inner.clone().str().n_chars().into()
}

#[cfg(feature = "lazy_regex")]
Expand Down
21 changes: 21 additions & 0 deletions py-polars/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,27 @@ def test_sql_round_ndigits_errors() -> None:
ctx.execute("SELECT ROUND(n,-1) AS n FROM df")


def test_sql_string_lengths() -> None:
# Checks that SQL LENGTH counts characters while OCTET_LENGTH counts bytes,
# and that a null input yields null for both.
df = pl.DataFrame({"words": ["Café", None, "東京"]})

with pl.SQLContext(frame=df) as ctx:
res = ctx.execute(
"""
SELECT
words,
LENGTH(words) AS n_chars,
OCTET_LENGTH(words) AS n_bytes
FROM frame
"""
).collect()

# "Café" is 4 characters but 5 bytes ("é" takes 2 bytes in UTF-8);
# "東京" is 2 characters but 6 bytes (3 bytes per character in UTF-8).
assert res.to_dict(False) == {
"words": ["Café", None, "東京"],
"n_chars": [4, None, 2],
"n_bytes": [5, None, 6],
}


def test_sql_substr() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit 07d6ba9

Please sign in to comment.