From 39d91848c9f616f897cba1418e3f6dc1c30a3730 Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 5 Jun 2023 08:23:41 +0200 Subject: [PATCH 1/4] feat(rust, python): building blocks for expression expansion sets --- .../polars-lazy/polars-plan/src/dsl/expr.rs | 250 +----------------- .../polars-plan/src/dsl/expr_dyn_fn.rs | 248 +++++++++++++++++ .../polars-lazy/polars-plan/src/dsl/meta.rs | 34 ++- polars/polars-lazy/polars-plan/src/dsl/mod.rs | 2 + .../polars-plan/src/dsl/selector.rs | 52 ++++ .../src/logical_plan/conversion.rs | 1 + .../polars-plan/src/logical_plan/format.rs | 1 + .../polars-plan/src/logical_plan/iterator.rs | 20 +- .../src/logical_plan/projection.rs | 207 ++++++++++----- py-polars/polars/expr/meta.py | 12 + py-polars/src/expr/meta.rs | 30 +++ py-polars/tests/unit/namespaces/test_meta.py | 16 ++ 12 files changed, 566 insertions(+), 307 deletions(-) create mode 100644 polars/polars-lazy/polars-plan/src/dsl/expr_dyn_fn.rs create mode 100644 polars/polars-lazy/polars-plan/src/dsl/selector.rs diff --git a/polars/polars-lazy/polars-plan/src/dsl/expr.rs b/polars/polars-lazy/polars-plan/src/dsl/expr.rs index b5b91a633872..0cad240c4a84 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/expr.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/expr.rs @@ -1,255 +1,14 @@ use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::ops::Deref; use polars_core::prelude::*; -use polars_core::utils::get_supertype; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[cfg(feature = "serde")] -use serde::{Deserializer, Serializer}; +pub use super::expr_dyn_fn::*; use crate::dsl::function_expr::FunctionExpr; use crate::prelude::*; -/// A wrapper trait for any closure `Fn(Vec) -> PolarsResult` -pub trait SeriesUdf: Send + Sync { - fn call_udf(&self, s: &mut [Series]) -> PolarsResult>; -} - -impl SeriesUdf for F -where - F: Fn(&mut [Series]) -> PolarsResult> + Send + Sync, -{ - fn call_udf(&self, s: &mut [Series]) -> PolarsResult> { - self(s) - } -} - -impl Debug for dyn SeriesUdf { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "SeriesUdf") - } -} - -/// A wrapper trait for any binary closure `Fn(Series, Series) -> PolarsResult` -pub trait SeriesBinaryUdf: Send + Sync { - fn call_udf(&self, a: Series, b: Series) -> PolarsResult; -} - -impl SeriesBinaryUdf for F -where - F: Fn(Series, Series) -> PolarsResult + Send + Sync, -{ - fn call_udf(&self, a: Series, b: Series) -> PolarsResult { - self(a, b) - } -} - -impl Debug for dyn SeriesBinaryUdf { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "SeriesBinaryUdf") - } -} - -impl Default for SpecialEq> { - fn default() -> Self { - panic!("implementation error"); - } -} - -impl Default for SpecialEq> { - fn default() -> Self { - let output_field = move |_: &Schema, _: Context, _: &Field, _: &Field| None; - SpecialEq::new(Arc::new(output_field)) - } -} - -pub trait RenameAliasFn: Send + Sync { - fn call(&self, name: &str) -> PolarsResult; -} - -impl PolarsResult + Send + Sync> RenameAliasFn for F { - fn call(&self, name: &str) -> PolarsResult { - self(name) - } -} - -impl Debug for dyn RenameAliasFn { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "RenameAliasFn") - } -} - -#[derive(Clone)] -/// Wrapper type that has special equality properties -/// depending on the inner type specialization -pub struct SpecialEq(T); - -#[cfg(feature = "serde")] -impl Serialize for SpecialEq { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: Serializer, - { - self.0.serialize(serializer) - } -} - -#[cfg(feature = "serde")] -impl<'a, T: Deserialize<'a>> Deserialize<'a> for SpecialEq { - fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'a>, - { - let t = T::deserialize(deserializer)?; - Ok(SpecialEq(t)) - } -} - -impl SpecialEq { - pub fn new(val: T) -> Self { - SpecialEq(val) - } -} - -impl PartialEq for SpecialEq> { - fn eq(&self, other: &Self) -> bool { - Arc::ptr_eq(&self.0, &other.0) - } -} - -impl PartialEq for SpecialEq { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } -} - -impl Debug for SpecialEq { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "no_eq") - } -} - -impl Deref for SpecialEq { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -pub trait BinaryUdfOutputField: Send + Sync { - fn get_field( - &self, - input_schema: &Schema, - cntxt: Context, - field_a: &Field, - field_b: &Field, - ) -> Option; -} - -impl BinaryUdfOutputField for F -where - F: Fn(&Schema, Context, &Field, &Field) -> Option + Send + Sync, -{ - fn get_field( - &self, - input_schema: &Schema, - cntxt: Context, - field_a: &Field, - field_b: &Field, - ) -> Option { - self(input_schema, cntxt, field_a, field_b) - } -} - -pub trait FunctionOutputField: Send + Sync { - fn get_field(&self, input_schema: &Schema, cntxt: Context, fields: &[Field]) -> Field; -} - -pub type GetOutput = SpecialEq>; - -impl Default for GetOutput { - fn default() -> Self { - SpecialEq::new(Arc::new( - |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| fields[0].clone(), - )) - } -} - -impl GetOutput { - pub fn same_type() -> Self { - Default::default() - } - - pub fn from_type(dt: DataType) -> Self { - SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { - Field::new(flds[0].name(), dt.clone()) - })) - } - - pub fn map_field Field + Send + Sync>(f: F) -> Self { - SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { - f(&flds[0]) - })) - } - - pub fn map_fields Field + Send + Sync>(f: F) -> Self { - SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { - f(flds) - })) - } - - pub fn map_dtype DataType + Send + Sync>(f: F) -> Self { - SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { - let mut fld = flds[0].clone(); - let new_type = f(fld.data_type()); - fld.coerce(new_type); - fld - })) - } - - pub fn float_type() -> Self { - Self::map_dtype(|dt| match dt { - DataType::Float32 => DataType::Float32, - _ => DataType::Float64, - }) - } - - pub fn super_type() -> Self { - Self::map_dtypes(|dtypes| { - let mut st = dtypes[0].clone(); - for dt in &dtypes[1..] { - st = get_supertype(&st, dt).unwrap(); - } - st - }) - } - - pub fn map_dtypes(f: F) -> Self - where - F: 'static + Fn(&[&DataType]) -> DataType + Send + Sync, - { - SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { - let mut fld = flds[0].clone(); - let dtypes = flds.iter().map(|fld| fld.data_type()).collect::>(); - let new_type = f(&dtypes); - fld.coerce(new_type); - fld - })) - } -} - -impl FunctionOutputField for F -where - F: Fn(&Schema, Context, &[Field]) -> Field + Send + Sync, -{ - fn get_field(&self, input_schema: &Schema, cntxt: Context, fields: &[Field]) -> Field { - self(input_schema, cntxt, fields) - } -} - #[derive(PartialEq, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum AggExpr { @@ -399,6 +158,13 @@ pub enum Expr { input: Box, id: usize, }, + /// Expressions in this node should only be expanding + /// e.g. + /// `Expr::Columns` + /// `Expr::Dtypes` + /// `Expr::Wildcard` + /// `Expr::Exclude` + Selector(super::selector::Selector), } // TODO! derive. This is only a temporary fix diff --git a/polars/polars-lazy/polars-plan/src/dsl/expr_dyn_fn.rs b/polars/polars-lazy/polars-plan/src/dsl/expr_dyn_fn.rs new file mode 100644 index 000000000000..07d842a3ea74 --- /dev/null +++ b/polars/polars-lazy/polars-plan/src/dsl/expr_dyn_fn.rs @@ -0,0 +1,248 @@ +use std::fmt::Formatter; +use std::ops::Deref; + +use polars_core::utils::get_supertype; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +#[cfg(feature = "serde")] +use serde::{Deserializer, Serializer}; + +use super::*; + +/// A wrapper trait for any closure `Fn(Vec) -> PolarsResult` +pub trait SeriesUdf: Send + Sync { + fn call_udf(&self, s: &mut [Series]) -> PolarsResult>; +} + +impl SeriesUdf for F +where + F: Fn(&mut [Series]) -> PolarsResult> + Send + Sync, +{ + fn call_udf(&self, s: &mut [Series]) -> PolarsResult> { + self(s) + } +} + +impl Debug for dyn SeriesUdf { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "SeriesUdf") + } +} + +/// A wrapper trait for any binary closure `Fn(Series, Series) -> PolarsResult` +pub trait SeriesBinaryUdf: Send + Sync { + fn call_udf(&self, a: Series, b: Series) -> PolarsResult; +} + +impl SeriesBinaryUdf for F +where + F: Fn(Series, Series) -> PolarsResult + Send + Sync, +{ + fn call_udf(&self, a: Series, b: Series) -> PolarsResult { + self(a, b) + } +} + +impl Debug for dyn SeriesBinaryUdf { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "SeriesBinaryUdf") + } +} + +impl Default for SpecialEq> { + fn default() -> Self { + panic!("implementation error"); + } +} + +impl Default for SpecialEq> { + fn default() -> Self { + let output_field = move |_: &Schema, _: Context, _: &Field, _: &Field| None; + SpecialEq::new(Arc::new(output_field)) + } +} + +pub trait RenameAliasFn: Send + Sync { + fn call(&self, name: &str) -> PolarsResult; +} + +impl PolarsResult + Send + Sync> RenameAliasFn for F { + fn call(&self, name: &str) -> PolarsResult { + self(name) + } +} + +impl Debug for dyn RenameAliasFn { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "RenameAliasFn") + } +} + +#[derive(Clone)] +/// Wrapper type that has special equality properties +/// depending on the inner type specialization +pub struct SpecialEq(T); + +#[cfg(feature = "serde")] +impl Serialize for SpecialEq { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + self.0.serialize(serializer) + } +} + +#[cfg(feature = "serde")] +impl<'a, T: Deserialize<'a>> Deserialize<'a> for SpecialEq { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'a>, + { + let t = T::deserialize(deserializer)?; + Ok(SpecialEq(t)) + } +} + +impl SpecialEq { + pub fn new(val: T) -> Self { + SpecialEq(val) + } +} + +impl PartialEq for SpecialEq> { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.0, &other.0) + } +} + +impl PartialEq for SpecialEq { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Debug for SpecialEq { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "no_eq") + } +} + +impl Deref for SpecialEq { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub trait BinaryUdfOutputField: Send + Sync { + fn get_field( + &self, + input_schema: &Schema, + cntxt: Context, + field_a: &Field, + field_b: &Field, + ) -> Option; +} + +impl BinaryUdfOutputField for F +where + F: Fn(&Schema, Context, &Field, &Field) -> Option + Send + Sync, +{ + fn get_field( + &self, + input_schema: &Schema, + cntxt: Context, + field_a: &Field, + field_b: &Field, + ) -> Option { + self(input_schema, cntxt, field_a, field_b) + } +} + +pub trait FunctionOutputField: Send + Sync { + fn get_field(&self, input_schema: &Schema, cntxt: Context, fields: &[Field]) -> Field; +} + +pub type GetOutput = SpecialEq>; + +impl Default for GetOutput { + fn default() -> Self { + SpecialEq::new(Arc::new( + |_input_schema: &Schema, _cntxt: Context, fields: &[Field]| fields[0].clone(), + )) + } +} + +impl GetOutput { + pub fn same_type() -> Self { + Default::default() + } + + pub fn from_type(dt: DataType) -> Self { + SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { + Field::new(flds[0].name(), dt.clone()) + })) + } + + pub fn map_field Field + Send + Sync>(f: F) -> Self { + SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { + f(&flds[0]) + })) + } + + pub fn map_fields Field + Send + Sync>(f: F) -> Self { + SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { + f(flds) + })) + } + + pub fn map_dtype DataType + Send + Sync>(f: F) -> Self { + SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { + let mut fld = flds[0].clone(); + let new_type = f(fld.data_type()); + fld.coerce(new_type); + fld + })) + } + + pub fn float_type() -> Self { + Self::map_dtype(|dt| match dt { + DataType::Float32 => DataType::Float32, + _ => DataType::Float64, + }) + } + + pub fn super_type() -> Self { + Self::map_dtypes(|dtypes| { + let mut st = dtypes[0].clone(); + for dt in &dtypes[1..] { + st = get_supertype(&st, dt).unwrap(); + } + st + }) + } + + pub fn map_dtypes(f: F) -> Self + where + F: 'static + Fn(&[&DataType]) -> DataType + Send + Sync, + { + SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| { + let mut fld = flds[0].clone(); + let dtypes = flds.iter().map(|fld| fld.data_type()).collect::>(); + let new_type = f(&dtypes); + fld.coerce(new_type); + fld + })) + } +} + +impl FunctionOutputField for F +where + F: Fn(&Schema, Context, &[Field]) -> Field + Send + Sync, +{ + fn get_field(&self, input_schema: &Schema, cntxt: Context, fields: &[Field]) -> Field { + self(input_schema, cntxt, fields) + } +} diff --git a/polars/polars-lazy/polars-plan/src/dsl/meta.rs b/polars/polars-lazy/polars-plan/src/dsl/meta.rs index 96b3960f4439..723df12f3a08 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/meta.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/meta.rs @@ -1,4 +1,5 @@ use super::*; +use crate::dsl::selector::Selector; use crate::logical_plan::projection::is_regex_projection; /// Specialized expressions for Categorical dtypes. @@ -55,7 +56,7 @@ impl MetaNameSpace { pub fn has_multiple_outputs(&self) -> bool { self.0.into_iter().any(|e| match e { - Expr::Wildcard | Expr::Columns(_) | Expr::DtypeColumn(_) => true, + Expr::Selector(_) | Expr::Wildcard | Expr::Columns(_) | Expr::DtypeColumn(_) => true, Expr::Column(name) => is_regex_projection(name), _ => false, }) @@ -67,4 +68,35 @@ impl MetaNameSpace { _ => false, }) } + + pub fn _selector_add(self, other: Expr) -> PolarsResult { + if let Expr::Selector(mut s) = self.0 { + if let Expr::Selector(s_other) = other { + s = &s + &s_other; + } else { + s.add.push(other); + } + Ok(Expr::Selector(s)) + } else { + polars_bail!(ComputeError: "expected selector, got {}", self.0) + } + } + + pub fn _selector_sub(self, other: Expr) -> PolarsResult { + if let Expr::Selector(mut s) = self.0 { + if let Expr::Selector(s_other) = other { + s = &s - &s_other; + } else { + s.subtract.push(other); + } + Ok(Expr::Selector(s)) + } else { + polars_bail!(ComputeError: "expected selector, got {}", self.0) + } + } + + pub fn _into_selector(self) -> PolarsResult { + polars_ensure!(!matches!(self.0, Expr::Selector(_)), ComputeError: "nested selectors not allowed"); + Ok(Expr::Selector(Selector::new(self.0))) + } } diff --git a/polars/polars-lazy/polars-plan/src/dsl/mod.rs b/polars/polars-lazy/polars-plan/src/dsl/mod.rs index e4112b5a5d53..fea74a0aa41b 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/mod.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/mod.rs @@ -12,6 +12,7 @@ pub mod binary; #[cfg(feature = "temporal")] pub mod dt; mod expr; +mod expr_dyn_fn; mod from; pub(crate) mod function_expr; #[cfg(feature = "compile")] @@ -21,6 +22,7 @@ mod list; mod meta; pub(crate) mod names; mod options; +mod selector; #[cfg(feature = "strings")] pub mod string; #[cfg(feature = "dtype-struct")] diff --git a/polars/polars-lazy/polars-plan/src/dsl/selector.rs b/polars/polars-lazy/polars-plan/src/dsl/selector.rs new file mode 100644 index 000000000000..8e8c46c057a5 --- /dev/null +++ b/polars/polars-lazy/polars-plan/src/dsl/selector.rs @@ -0,0 +1,52 @@ +use std::ops::{Add, Sub}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use super::*; + +#[derive(Clone, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct Selector { + pub(crate) add: Vec, + pub(crate) subtract: Vec, +} + +impl Selector { + pub(crate) fn new(e: Expr) -> Self { + Self { + add: vec![e], + subtract: vec![], + } + } +} + +impl Add for &Selector { + type Output = Selector; + + fn add(self, rhs: Self) -> Self::Output { + let mut add = Vec::with_capacity(self.add.len() + rhs.add.len()); + add.extend_from_slice(&self.add); + add.extend_from_slice(&rhs.add); + + let mut subtract = Vec::with_capacity(self.subtract.len() + rhs.subtract.len()); + subtract.extend_from_slice(&self.subtract); + subtract.extend_from_slice(&rhs.subtract); + Selector { add, subtract } + } +} + +impl Sub for &Selector { + type Output = Selector; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: Self) -> Self::Output { + let mut subtract = Vec::with_capacity(self.subtract.len() + rhs.subtract.len()); + subtract.extend_from_slice(&self.subtract); + subtract.extend_from_slice(&rhs.subtract); + Selector { + add: self.add.clone(), + subtract, + } + } +} diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/conversion.rs b/polars/polars-lazy/polars-plan/src/logical_plan/conversion.rs index b1d40e7d775b..60041dcfc40e 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/conversion.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/conversion.rs @@ -157,6 +157,7 @@ pub fn to_aexpr(expr: Expr, arena: &mut Arena) -> Node { Expr::RenameAlias { .. } => panic!("no `rename_alias` expected at this point"), Expr::Columns { .. } => panic!("no `columns` expected at this point"), Expr::DtypeColumn { .. } => panic!("no `dtype-columns` expected at this point"), + Expr::Selector(_) => panic!("no `selector` expected at this point"), }; arena.add(v) } diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/format.rs b/polars/polars-lazy/polars-plan/src/logical_plan/format.rs index 0351e2ff65fc..d3e2df2a5f04 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/format.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/format.rs @@ -417,6 +417,7 @@ impl Debug for Expr { Columns(names) => write!(f, "COLUMNS({names:?})"), DtypeColumn(dt) => write!(f, "COLUMN OF DTYPE: {dt:?}"), Cache { input, .. } => write!(f, "CACHE {input:?}"), + Selector(s) => write!(f, "SET({:?}) - SET({:?})", s.add, s.subtract), } } } diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/iterator.rs b/polars/polars-lazy/polars-plan/src/logical_plan/iterator.rs index f643f1680f1d..e1950831b060 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/iterator.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/iterator.rs @@ -1,3 +1,5 @@ +use polars_arrow::error::PolarsResult; + use crate::prelude::*; macro_rules! push_expr { ($current_expr:expr, $push:ident, $iter:ident) => {{ @@ -91,6 +93,14 @@ macro_rules! push_expr { KeepName(e) => $push(e), Cache { input, .. } => $push(input), RenameAlias { expr, .. } => $push(expr), + Selector(s) => { + for e in s.add.$iter() { + $push(e); + } + for e in s.subtract.$iter() { + $push(e) + } + } } }}; } @@ -116,17 +126,25 @@ impl<'a> ExprMut<'a> { pub fn apply(&mut self, mut f: F) where F: FnMut(&mut Expr) -> bool, + { + let _ = self.try_apply(|e| Ok(f(e))); + } + + pub fn try_apply(&mut self, mut f: F) -> PolarsResult<()> + where + F: FnMut(&mut Expr) -> PolarsResult, { while let Some(current_expr) = self.stack.pop() { // the order is important, we first modify the Expr // before we push its children on the stack. // The modification can make the children invalid. - if !f(current_expr) { + if !f(current_expr)? { break; } let mut push = |e: &'a mut Expr| self.stack.push(e); push_expr!(current_expr, push, iter_mut); } + Ok(()) } } diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs b/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs index 21d33c097a28..7484f4b46f5d 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs @@ -97,8 +97,7 @@ fn expand_regex( schema: &Schema, pattern: &str, ) -> PolarsResult<()> { - let re = regex::Regex::new(pattern) - .unwrap_or_else(|_| panic!("invalid regular expression in column: {pattern}")); + let re = regex::Regex::new(pattern)?; for name in schema.iter_names() { if re.is_match(name) { let mut new_expr = expr.clone(); @@ -334,6 +333,46 @@ fn early_supertype(inputs: &[Expr], schema: &Schema) -> Option { st } +#[derive(Copy, Clone)] +struct ExpansionFlags { + multiple_columns: bool, + has_nth: bool, + has_wildcard: bool, + replace_fill_null_type: bool, + has_selector: bool, +} + +fn find_flags(expr: &Expr) -> ExpansionFlags { + let mut multiple_columns = false; + let mut has_nth = false; + let mut has_wildcard = false; + let mut replace_fill_null_type = false; + let mut has_selector = false; + + // do a single pass and collect all flags at once. + // supertypes/modification that can be done in place are also don e in that pass + for expr in expr { + match expr { + Expr::Columns(_) | Expr::DtypeColumn(_) => multiple_columns = true, + Expr::Nth(_) => has_nth = true, + Expr::Wildcard => has_wildcard = true, + Expr::Selector(_) => has_selector = true, + Expr::Function { + function: FunctionExpr::FillNull { .. }, + .. + } => replace_fill_null_type = true, + _ => {} + } + } + ExpansionFlags { + multiple_columns, + has_nth, + has_wildcard, + replace_fill_null_type, + has_selector, + } +} + /// In case of single col(*) -> do nothing, no selection is the same as select all /// In other cases replace the wildcard with an expression with all columns pub(crate) fn rewrite_projections( @@ -349,68 +388,14 @@ pub(crate) fn rewrite_projections( // functions can have col(["a", "b"]) or col(Utf8) as inputs expr = expand_function_inputs(expr, schema); - let mut multiple_columns = false; - let mut has_nth = false; - let mut has_wildcard = false; - let mut replace_fill_null_type = false; - - // do a single pass and collect all flags at once. - // supertypes/modification that can be done in place are also don e in that pass - for expr in &expr { - match expr { - Expr::Columns(_) | Expr::DtypeColumn(_) => multiple_columns = true, - Expr::Nth(_) => has_nth = true, - Expr::Wildcard => has_wildcard = true, - Expr::Function { - function: FunctionExpr::FillNull { .. }, - .. - } => replace_fill_null_type = true, - _ => {} - } - } - - if has_nth { - replace_nth(&mut expr, schema); + let mut flags = find_flags(&expr); + if flags.has_selector { + replace_selector(&mut expr, schema, keys)?; + // the selector is replaced with Expr::Columns + flags.multiple_columns = true; } - // has multiple column names - // the expanded columns are added to the result - if multiple_columns { - if let Some(e) = expr - .into_iter() - .find(|e| matches!(e, Expr::Columns(_) | Expr::DtypeColumn(_))) - { - match &e { - Expr::Columns(names) => expand_columns(&expr, &mut result, names)?, - Expr::DtypeColumn(dtypes) => { - // keep track of column excluded from the dtypes - let exclude = prepare_excluded(&expr, schema, keys)?; - expand_dtypes(&expr, &mut result, schema, dtypes, &exclude)? - } - _ => {} - } - } - } - // has multiple column names due to wildcards - else if has_wildcard { - // keep track of column excluded from the wildcard - let exclude = prepare_excluded(&expr, schema, keys)?; - // this path prepares the wildcard as input for the Function Expr - replace_wildcard(&expr, &mut result, &exclude, schema)?; - } - // can have multiple column names due to a regex - else { - #[allow(clippy::collapsible_else_if)] - #[cfg(feature = "regex")] - { - replace_regex(&expr, &mut result, schema)? - } - #[cfg(not(feature = "regex"))] - { - let expr = rewrite_special_aliases(expr)?; - result.push(expr) - } - } + replace_and_add_to_results(expr, flags, &mut result, schema, keys)?; // this is done after all expansion (wildcard, column, dtypes) // have been done. This will ensure the conversion to aexpr does @@ -418,7 +403,7 @@ pub(crate) fn rewrite_projections( // the expanded expressions are written to result, so we pick // them up there. - if replace_fill_null_type { + if flags.replace_fill_null_type { for e in &mut result[result_offset..] { e.mutate().apply(|e| { if let Expr::Function { @@ -440,3 +425,99 @@ pub(crate) fn rewrite_projections( } Ok(result) } + +fn replace_and_add_to_results( + mut expr: Expr, + flags: ExpansionFlags, + result: &mut Vec, + schema: &Schema, + keys: &[Expr], +) -> PolarsResult<()> { + if flags.has_nth { + replace_nth(&mut expr, schema); + } + + // has multiple column names + // the expanded columns are added to the result + if flags.multiple_columns { + if let Some(e) = expr + .into_iter() + .find(|e| matches!(e, Expr::Columns(_) | Expr::DtypeColumn(_))) + { + match &e { + Expr::Columns(names) => expand_columns(&expr, result, names)?, + Expr::DtypeColumn(dtypes) => { + // keep track of column excluded from the dtypes + let exclude = prepare_excluded(&expr, schema, keys)?; + expand_dtypes(&expr, result, schema, dtypes, &exclude)? + } + _ => {} + } + } + } + // has multiple column names due to wildcards + else if flags.has_wildcard { + // keep track of column excluded from the wildcard + let exclude = prepare_excluded(&expr, schema, keys)?; + // this path prepares the wildcard as input for the Function Expr + replace_wildcard(&expr, result, &exclude, schema)?; + } + // can have multiple column names due to a regex + else { + #[allow(clippy::collapsible_else_if)] + #[cfg(feature = "regex")] + { + replace_regex(&expr, result, schema)? + } + #[cfg(not(feature = "regex"))] + { + let expr = rewrite_special_aliases(expr)?; + result.push(expr) + } + } + Ok(()) +} + +fn replace_selector(expr: &mut Expr, schema: &Schema, keys: &[Expr]) -> PolarsResult<()> { + // first pass we replace the selectors + // with Expr::Columns + // we expand the `to_add` columns + // and then subtract the `to_subtract` columns + expr.mutate().try_apply(|e| match e { + Expr::Selector(s) => { + let mut to_add = Vec::with_capacity(s.add.len() * 3); + + for add_e in std::mem::take(&mut s.add) { + let local_flags = find_flags(&add_e); + replace_and_add_to_results(add_e, local_flags, &mut to_add, schema, keys)?; + } + + let mut to_subtract = Vec::with_capacity(s.subtract.len() * 3); + for sub_e in std::mem::take(&mut s.subtract) { + let local_flags = find_flags(&sub_e); + replace_and_add_to_results(sub_e, local_flags, &mut to_subtract, schema, keys)?; + } + + let to_subtract = to_subtract + .iter() + .map(|e| { + let Expr::Column(name) = e else {unreachable!()}; + name.as_ref() + }) + .collect::>(); + + let mut final_names = Vec::with_capacity(to_add.len()); + for e in to_add { + let Expr::Column(name) = e else {unreachable!()}; + if !to_subtract.contains(name.as_ref()) { + final_names.push(name.to_string()) + } + } + *e = Expr::Columns(final_names); + + Ok(true) + } + _ => Ok(true), + })?; + Ok(()) +} diff --git a/py-polars/polars/expr/meta.py b/py-polars/polars/expr/meta.py index 5465d47f3514..5b0aab399a9c 100644 --- a/py-polars/polars/expr/meta.py +++ b/py-polars/polars/expr/meta.py @@ -68,3 +68,15 @@ def root_names(self) -> list[str]: def undo_aliases(self) -> Expr: """Undo any renaming operation like ``alias`` or ``keep_name``.""" return wrap_expr(self._pyexpr.meta_undo_aliases()) + + def _as_selector(self) -> Expr: + """Turn this expression in a selector.""" + return wrap_expr(self._pyexpr._meta_as_selector()) + + def _selector_add(self, other: Expr) -> Expr: + """Add selectors.""" + return wrap_expr(self._pyexpr._meta_selector_add(other._pyexpr)) + + def _selector_sub(self, other: Expr) -> Expr: + """Subtract selectors.""" + return wrap_expr(self._pyexpr._meta_selector_sub(other._pyexpr)) diff --git a/py-polars/src/expr/meta.rs b/py-polars/src/expr/meta.rs index 46029c26c435..913ef0035768 100644 --- a/py-polars/src/expr/meta.rs +++ b/py-polars/src/expr/meta.rs @@ -44,4 +44,34 @@ impl PyExpr { fn meta_is_regex_projection(&self) -> bool { self.inner.clone().meta().is_regex_projection() } + + fn _meta_selector_add(&self, other: PyExpr) -> PyResult { + let out = self + .inner + .clone() + .meta() + ._selector_add(other.inner) + .map_err(PyPolarsErr::from)?; + Ok(out.into()) + } + + fn _meta_selector_sub(&self, other: PyExpr) -> PyResult { + let out = self + .inner + .clone() + .meta() + ._selector_sub(other.inner) + .map_err(PyPolarsErr::from)?; + Ok(out.into()) + } + + fn _meta_as_selector(&self) -> PyResult { + let out = self + .inner + .clone() + .meta() + ._into_selector() + .map_err(PyPolarsErr::from)?; + Ok(out.into()) + } } diff --git a/py-polars/tests/unit/namespaces/test_meta.py b/py-polars/tests/unit/namespaces/test_meta.py index 88ebd8fbb50b..1b2f21916895 100644 --- a/py-polars/tests/unit/namespaces/test_meta.py +++ b/py-polars/tests/unit/namespaces/test_meta.py @@ -64,3 +64,19 @@ def test_meta_is_regex_projection() -> None: e = pl.col("^.*$").alias("bar") assert e.meta.is_regex_projection() assert e.meta.has_multiple_outputs() + + +def test_selector_expansion() -> None: + df = pl.DataFrame({name: [] for name in "abcde"}) + + s1 = pl.all().meta._as_selector() + s2 = pl.col(["a", "b"]) + s = s1.meta._selector_sub(s2) + assert df.select(s).columns == ["c", "d", "e"] + + s1 = pl.col("^a|b$").meta._as_selector() + s = s1.meta._selector_add(pl.col(["d", "e"])) + assert df.select(s).columns == ["a", "b", "d", "e"] + + s = s.meta._selector_sub(pl.col("d")) + assert df.select(s).columns == ["a", "b", "e"] From 79a5ebf53c0147f1c63aa38e3a1923d2ac0e3477 Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 5 Jun 2023 11:02:07 +0200 Subject: [PATCH 2/4] features --- polars/polars-lazy/polars-plan/src/logical_plan/projection.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs b/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs index 7484f4b46f5d..dd1e945c7ab0 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs @@ -97,7 +97,8 @@ fn expand_regex( schema: &Schema, pattern: &str, ) -> PolarsResult<()> { - let re = regex::Regex::new(pattern)?; + let re = + regex::Regex::new(pattern).map_err(|e| polars_err!(ComputeError: "invalid regex {}", e))?; for name in schema.iter_names() { if re.is_match(name) { let mut new_expr = expr.clone(); From db58cb297d8cf8e8980c651f8fbf52c112a179e9 Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 5 Jun 2023 11:27:16 +0200 Subject: [PATCH 3/4] ensure duplicates are handled --- polars/polars-lazy/polars-plan/src/logical_plan/projection.rs | 4 +++- py-polars/tests/unit/namespaces/test_meta.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs b/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs index dd1e945c7ab0..a2d76957da2a 100644 --- a/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs +++ b/polars/polars-lazy/polars-plan/src/logical_plan/projection.rs @@ -508,9 +508,11 @@ fn replace_selector(expr: &mut Expr, schema: &Schema, keys: &[Expr]) -> PolarsRe .collect::>(); let mut final_names = Vec::with_capacity(to_add.len()); + // keep a set to ensure we don't create duplicates + let mut added = PlHashSet::with_capacity(to_add.len()); for e in to_add { let Expr::Column(name) = e else {unreachable!()}; - if !to_subtract.contains(name.as_ref()) { + if !to_subtract.contains(name.as_ref()) && added.insert(name.clone()) { final_names.push(name.to_string()) } } diff --git a/py-polars/tests/unit/namespaces/test_meta.py b/py-polars/tests/unit/namespaces/test_meta.py index 1b2f21916895..469bbce12821 100644 --- a/py-polars/tests/unit/namespaces/test_meta.py +++ b/py-polars/tests/unit/namespaces/test_meta.py @@ -80,3 +80,7 @@ def test_selector_expansion() -> None: s = s.meta._selector_sub(pl.col("d")) assert df.select(s).columns == ["a", "b", "e"] + + # add a duplicate, this tests if they are pruned + s = s.meta._selector_add(pl.col("a")) + assert df.select(s).columns == ["a", "b", "e"] \ No newline at end of file From a1a56e9f0b83857d71aac16fc72b3551dcea9eb7 Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 5 Jun 2023 11:30:58 +0200 Subject: [PATCH 4/4] fmt --- py-polars/tests/unit/namespaces/test_meta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-polars/tests/unit/namespaces/test_meta.py b/py-polars/tests/unit/namespaces/test_meta.py index 469bbce12821..822215ec0086 100644 --- a/py-polars/tests/unit/namespaces/test_meta.py +++ b/py-polars/tests/unit/namespaces/test_meta.py @@ -83,4 +83,4 @@ def test_selector_expansion() -> None: # add a duplicate, this tests if they are pruned s = s.meta._selector_add(pl.col("a")) - assert df.select(s).columns == ["a", "b", "e"] \ No newline at end of file + assert df.select(s).columns == ["a", "b", "e"]