Skip to content

Commit

Permalink
fix(rust, python): fix cse_plan invalid projection removal (#9700)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 4, 2023
1 parent 230b443 commit 227c850
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 8 deletions.
11 changes: 8 additions & 3 deletions polars/polars-lazy/polars-plan/src/logical_plan/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,15 @@ impl LogicalPlan {
} else {
"UNION"
};
// 3 levels of indentation
// - 0 => UNION ... END UNION
// - 1 => PLAN 0, PLAN 1, ... PLAN N
// - 2 => actual formatting of plans
let sub_sub_indent = sub_indent + 2;
write!(f, "{:indent$}{}", "", name)?;
for (i, plan) in inputs.iter().enumerate() {
write!(f, "\n{:indent$}PLAN {i}:", "")?;
plan._format(f, sub_indent)?;
write!(f, "\n{:sub_indent$}PLAN {i}:", "")?;
plan._format(f, sub_sub_indent)?;
}
write!(f, "\n{:indent$}END {}", "", name)
}
Expand Down Expand Up @@ -227,7 +232,7 @@ impl LogicalPlan {
} => {
write!(f, "{:indent$}AGGREGATE", "")?;
write!(f, "\n{:indent$}\t{aggs:?} BY {keys:?} FROM", "")?;
write!(f, "\n{:indent$}\t{input:?}", "")
input._format(f, sub_indent)
}
Join {
input_left,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::BTreeSet;

use polars_core::prelude::*;

use super::*;
Expand All @@ -13,7 +15,19 @@ use crate::logical_plan::functions::FunctionNode;
/// It is important that this optimization is ran after projection pushdown.
///
/// The schema reported after this optimization is also
pub(super) struct FastProjectionAndCollapse {}
pub(super) struct FastProjectionAndCollapse {
/// keep track of nodes that are already processed when they
/// can be expensive. Schema materialization can be for instance.
processed: BTreeSet<Node>,
}

impl FastProjectionAndCollapse {
pub(super) fn new() -> Self {
Self {
processed: Default::default(),
}
}
}

fn impl_fast_projection(
input: Node,
Expand Down Expand Up @@ -78,8 +92,13 @@ impl OptimizationRule for FastProjectionAndCollapse {
}),
// cleanup projections set in projection pushdown just above caches
// they are not needed.
cache_lp @ Cache { .. } => {
if cache_lp.schema(lp_arena).len() == columns.len() {
cache_lp @ Cache { .. } if self.processed.insert(node) => {
let cache_schema = cache_lp.schema(lp_arena);
if cache_schema.len() == columns.len()
&& cache_schema.iter_names().zip(columns.iter()).all(
|(left_name, right_name)| left_name.as_str() == right_name.as_ref(),
)
{
Some(cache_lp.clone())
} else {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ pub fn optimize(

// make sure its before slice pushdown.
if projection_pushdown {
rules.push(Box::new(FastProjectionAndCollapse {}));
rules.push(Box::new(FastProjectionAndCollapse::new()));
}
rules.push(Box::new(DelayRechunk::new()));

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-utils/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fn index_of<T>(slice: &[T], item: &T) -> Option<usize> {
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct Node(pub usize);

impl Default for Node {
Expand Down
46 changes: 46 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,49 @@ def test_cse_schema_6081() -> None:
"value": [1, 2, 2],
"min_value": [1, 1, 2],
}


def test_cse_9630() -> None:
df1 = pl.DataFrame(
{
"key": [1],
"x": [1],
}
).lazy()

df2 = pl.DataFrame(
{
"key": [1],
"y": [2],
}
).lazy()

joined_df2 = df1.join(df2, on="key")

all_subsections = (
pl.concat(
[
df1.select("key", pl.col("x").alias("value")),
joined_df2.select("key", pl.col("y").alias("value")),
]
)
.groupby("key")
.agg(
[
pl.col("value"),
]
)
)

intersected_df1 = all_subsections.join(df1, on="key")
intersected_df2 = all_subsections.join(df2, on="key")

assert intersected_df1.join(intersected_df2, on=["key"], how="left").collect(
common_subplan_elimination=True
).to_dict(False) == {
"key": [1],
"value": [[1, 2]],
"x": [1],
"value_right": [[1, 2]],
"y": [2],
}

0 comments on commit 227c850

Please sign in to comment.