From 60d0721cc0cd0f524bdd5c73c8c2e5c55216d925 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sun, 27 Oct 2024 11:42:52 +0400 Subject: [PATCH] fix: Address inadvertent quadratic behaviour in `expand_columns` (#19469) --- crates/polars-plan/src/dsl/selector.rs | 4 +- .../src/plans/conversion/expr_expansion.rs | 100 ++++++------------ 2 files changed, 35 insertions(+), 69 deletions(-) diff --git a/crates/polars-plan/src/dsl/selector.rs b/crates/polars-plan/src/dsl/selector.rs index 16e7d7b374e0..7877edb152df 100644 --- a/crates/polars-plan/src/dsl/selector.rs +++ b/crates/polars-plan/src/dsl/selector.rs @@ -11,7 +11,7 @@ pub enum Selector { Add(Box, Box), Sub(Box, Box), ExclusiveOr(Box, Box), - InterSect(Box, Box), + Intersect(Box, Box), Root(Box), } @@ -34,7 +34,7 @@ impl BitAnd for Selector { #[allow(clippy::suspicious_arithmetic_impl)] fn bitand(self, rhs: Self) -> Self::Output { - Selector::InterSect(Box::new(self), Box::new(rhs)) + Selector::Intersect(Box::new(self), Box::new(rhs)) } } diff --git a/crates/polars-plan/src/plans/conversion/expr_expansion.rs b/crates/polars-plan/src/plans/conversion/expr_expansion.rs index bec3fbe852cd..4709641662f9 100644 --- a/crates/polars-plan/src/plans/conversion/expr_expansion.rs +++ b/crates/polars-plan/src/plans/conversion/expr_expansion.rs @@ -1,5 +1,4 @@ //! this contains code used for rewriting projections, expanding wildcards, regex selection etc. -use std::ops::BitXor; use super::*; @@ -176,26 +175,28 @@ fn expand_columns( schema: &Schema, exclude: &PlHashSet, ) -> PolarsResult<()> { - let mut is_valid = true; + if !expr.into_iter().all(|e| match e { + // check for invalid expansions such as `col([a, b]) + col([c, d])` + Expr::Columns(ref members) => members.as_ref() == names, + _ => true, + }) { + polars_bail!(ComputeError: "expanding more than one `col` is not allowed"); + } for name in names { if !exclude.contains(name) { - let new_expr = expr.clone(); - let (new_expr, new_expr_valid) = replace_columns_with_column(new_expr, names, name); - is_valid &= new_expr_valid; - // we may have regex col in columns. - #[allow(clippy::collapsible_else_if)] + let new_expr = expr.clone().map_expr(|e| match e { + Expr::Columns(_) => Expr::Column((*name).clone()), + Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), + e => e, + }); + #[cfg(feature = "regex")] - { - replace_regex(&new_expr, result, schema, exclude)?; - } + replace_regex(&new_expr, result, schema, exclude)?; + #[cfg(not(feature = "regex"))] - { - let new_expr = rewrite_special_aliases(new_expr)?; - result.push(new_expr) - } + result.push(rewrite_special_aliases(new_expr)?); } } - polars_ensure!(is_valid, ComputeError: "expanding more than one `col` is not allowed"); Ok(()) } @@ -246,30 +247,6 @@ fn replace_dtype_or_index_with_column( }) } -/// This replaces the columns Expr with a Column Expr. It also removes the Exclude Expr from the -/// expression chain. -pub(super) fn replace_columns_with_column( - mut expr: Expr, - names: &[PlSmallStr], - column_name: &PlSmallStr, -) -> (Expr, bool) { - let mut is_valid = true; - expr = expr.map_expr(|e| match e { - Expr::Columns(members) => { - // `col([a, b]) + col([c, d])` - if members.as_ref() == names { - Expr::Column(column_name.clone()) - } else { - is_valid = false; - Expr::Columns(members) - } - }, - Expr::Exclude(input, _) => Arc::unwrap_or_clone(input), - e => e, - }); - (expr, is_valid) -} - fn dtypes_match(d1: &DataType, d2: &DataType) -> bool { match (d1, d2) { // note: allow Datetime "*" wildcard for timezones... @@ -562,7 +539,7 @@ fn expand_function_inputs( }) } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] struct ExpansionFlags { multiple_columns: bool, has_nth: bool, @@ -819,42 +796,31 @@ fn replace_selector_inner( members.extend(scratch.drain(..)) }, Selector::Add(lhs, rhs) => { + let mut tmp_members: PlIndexSet = Default::default(); replace_selector_inner(*lhs, members, scratch, schema, keys)?; - let mut rhs_members: PlIndexSet = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - members.extend(rhs_members) + replace_selector_inner(*rhs, &mut tmp_members, scratch, schema, keys)?; + members.extend(tmp_members) }, Selector::ExclusiveOr(lhs, rhs) => { - let mut lhs_members = Default::default(); - replace_selector_inner(*lhs, &mut lhs_members, scratch, schema, keys)?; + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; - let mut rhs_members = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - - let xor_members = lhs_members.bitxor(&rhs_members); - *members = xor_members; + *members = tmp_members.symmetric_difference(members).cloned().collect(); }, - Selector::InterSect(lhs, rhs) => { - replace_selector_inner(*lhs, members, scratch, schema, keys)?; + Selector::Intersect(lhs, rhs) => { + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; - let mut rhs_members = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - - *members = members.intersection(&rhs_members).cloned().collect() + *members = tmp_members.intersection(members).cloned().collect(); }, Selector::Sub(lhs, rhs) => { - replace_selector_inner(*lhs, members, scratch, schema, keys)?; + let mut tmp_members = Default::default(); + replace_selector_inner(*lhs, &mut tmp_members, scratch, schema, keys)?; + replace_selector_inner(*rhs, members, scratch, schema, keys)?; - let mut rhs_members = Default::default(); - replace_selector_inner(*rhs, &mut rhs_members, scratch, schema, keys)?; - - let mut new_members = PlIndexSet::with_capacity(members.len()); - for e in members.drain(..) { - if !rhs_members.contains(&e) { - new_members.insert(e); - } - } - *members = new_members; + *members = tmp_members.difference(members).cloned().collect(); }, } Ok(())