Skip to content

Commit

Permalink
refactor(rust): split up dsl::functions module (#9213)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 4, 2023
1 parent f800164 commit dcd6113
Show file tree
Hide file tree
Showing 13 changed files with 1,489 additions and 1,459 deletions.
1,457 changes: 0 additions & 1,457 deletions polars/polars-lazy/polars-plan/src/dsl/functions.rs

This file was deleted.

34 changes: 34 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/functions/arity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use super::*;

macro_rules! prepare_binary_function {
($f:ident) => {
move |s: &mut [Series]| {
let s0 = std::mem::take(&mut s[0]);
let s1 = std::mem::take(&mut s[1]);

$f(s0, s1)
}
};
}

/// Apply a closure on the two columns that are evaluated from `Expr` a and `Expr` b.
///
/// The closure takes two arguments, each a `Series`. `output_type` must be the output dtype of the resulting `Series`.
pub fn map_binary<F: 'static>(a: Expr, b: Expr, f: F, output_type: GetOutput) -> Expr
where
F: Fn(Series, Series) -> PolarsResult<Option<Series>> + Send + Sync,
{
let function = prepare_binary_function!(f);
a.map_many(function, &[b], output_type)
}

/// Like [`map_binary`], but used in a groupby-aggregation context.
///
/// See [`Expr::apply`] for the difference between [`map`](Expr::map) and [`apply`](Expr::apply).
pub fn apply_binary<F: 'static>(a: Expr, b: Expr, f: F, output_type: GetOutput) -> Expr
where
F: Fn(Series, Series) -> PolarsResult<Option<Series>> + Send + Sync,
{
let function = prepare_binary_function!(f);
a.apply_many(function, &[b], output_type)
}
18 changes: 18 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/functions/coerce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#[cfg(feature = "dtype-struct")]
use super::*;

/// Take several expressions and collect them into a [`StructChunked`].
#[cfg(feature = "dtype-struct")]
pub fn as_struct(exprs: &[Expr]) -> Expr {
map_multiple(
|s| StructChunked::new(s[0].name(), s).map(|ca| Some(ca.into_series())),
exprs,
GetOutput::map_fields(|fld| Field::new(fld[0].name(), DataType::Struct(fld.to_vec()))),
)
.with_function_options(|mut options| {
options.input_wildcard_expansion = true;
options.fmt_str = "as_struct";
options.pass_name_to_apply = true;
options
})
}
67 changes: 67 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/functions/concat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use super::*;

#[cfg(all(feature = "concat_str", feature = "strings"))]
/// Horizontally concat string columns in linear time
pub fn concat_str<E: AsRef<[Expr]>>(s: E, separator: &str) -> Expr {
let input = s.as_ref().to_vec();
let separator = separator.to_string();

Expr::Function {
input,
function: StringFunction::ConcatHorizontal(separator).into(),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
auto_explode: true,
..Default::default()
},
}
}

#[cfg(all(feature = "concat_str", feature = "strings"))]
/// Format the results of an array of expressions using a format string
pub fn format_str<E: AsRef<[Expr]>>(format: &str, args: E) -> PolarsResult<Expr> {
let mut args: std::collections::VecDeque<Expr> = args.as_ref().to_vec().into();

// Parse the format string, and separate substrings between placeholders
let segments: Vec<&str> = format.split("{}").collect();

polars_ensure!(
segments.len() - 1 == args.len(),
ShapeMismatch: "number of placeholders should equal the number of arguments"
);

let mut exprs: Vec<Expr> = Vec::new();

for (i, s) in segments.iter().enumerate() {
if i > 0 {
if let Some(arg) = args.pop_front() {
exprs.push(arg);
}
}

if !s.is_empty() {
exprs.push(lit(s.to_string()))
}
}

Ok(concat_str(exprs, ""))
}

/// Concat lists entries.
pub fn concat_list<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(s: E) -> PolarsResult<Expr> {
let s: Vec<_> = s.as_ref().iter().map(|e| e.clone().into()).collect();

polars_ensure!(!s.is_empty(), ComputeError: "`concat_list` needs one or more expressions");

Ok(Expr::Function {
input: s,
function: FunctionExpr::ListExpr(ListFunction::Concat),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: true,
fmt_str: "concat_list",
..Default::default()
},
})
}
Loading

0 comments on commit dcd6113

Please sign in to comment.