Fix optimize projections bug (#8960)
* Fix optimize projections bug

* Add new dataframe test
mustafasrepo authored Jan 25, 2024
1 parent d6ab343 commit 4ac7de1
Showing 3 changed files with 137 additions and 41 deletions.
63 changes: 59 additions & 4 deletions datafusion/core/tests/dataframe/mod.rs
@@ -34,18 +34,19 @@ use std::sync::Arc;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
-use datafusion::execution::context::SessionContext;
+use datafusion::execution::context::{SessionContext, SessionState};
use datafusion::prelude::JoinType;
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
use datafusion::test_util::parquet_test_data;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use datafusion_common::{assert_contains, DataFusionError, ScalarValue, UnnestOptions};
+use datafusion_execution::config::SessionConfig;
+use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::{
-    array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col,
-    scalar_subquery, sum, wildcard, AggregateFunction, Expr, ExprSchemable, WindowFrame,
-    WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
+    array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col,
+    scalar_subquery, sum, when, wildcard, AggregateFunction, Expr, ExprSchemable,
+    WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_physical_expr::var_provider::{VarProvider, VarType};

@@ -1430,6 +1431,60 @@ async fn unnest_analyze_metrics() -> Result<()> {
    Ok(())
}

+#[tokio::test]
+async fn consecutive_projection_same_schema() -> Result<()> {
+    let config = SessionConfig::new();
+    let runtime = Arc::new(RuntimeEnv::default());
+    let state = SessionState::new_with_config_rt(config, runtime);
+    let ctx = SessionContext::new_with_state(state);
+
+    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+    let batch =
+        RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![0, 1]))])
+            .unwrap();
+
+    let df = ctx.read_batch(batch).unwrap();
+    df.clone().show().await.unwrap();
+
+    // Add `t` column full of nulls
+    let df = df
+        .with_column("t", cast(Expr::Literal(ScalarValue::Null), DataType::Int32))
+        .unwrap();
+    df.clone().show().await.unwrap();
+
+    let df = df
+        // (case when id = 1 then 10 else t) as t
+        .with_column(
+            "t",
+            when(col("id").eq(lit(1)), lit(10))
+                .otherwise(col("t"))
+                .unwrap(),
+        )
+        .unwrap()
+        // (case when id = 1 then 10 else t) as t2
+        .with_column(
+            "t2",
+            when(col("id").eq(lit(1)), lit(10))
+                .otherwise(col("t"))
+                .unwrap(),
+        )
+        .unwrap();
+
+    let results = df.collect().await?;
+    let expected = [
+        "+----+----+----+",
+        "| id | t  | t2 |",
+        "+----+----+----+",
+        "| 0  |    |    |",
+        "| 1  | 10 | 10 |",
+        "+----+----+----+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
async fn create_test_table(name: &str) -> Result<DataFrame> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Utf8, false),
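Each `with_column` call above replaces or adds a column, so the plan ends up with consecutive projections over the same column names and types even though the later projections compute CASE expressions rather than passing values through. The standalone sketch below illustrates that observation; the field names and nullability here are illustrative assumptions, not DataFusion's exact derived schema:

// Standalone sketch (assumed field names/nullability): two projections can have
// byte-identical schemas while computing different values, so schema equality
// on its own cannot prove the outer projection redundant.
use datafusion::arrow::datatypes::{DataType, Field, Schema};

fn main() {
    // Roughly the schema after `SELECT id, CAST(NULL AS INT) AS t` ...
    let inner = Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("t", DataType::Int32, true),
    ]);
    // ... and after `SELECT id, CASE WHEN id = 1 THEN 10 ELSE t END AS t` on top of it.
    let outer = Schema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("t", DataType::Int32, true),
    ]);
    assert_eq!(inner, outer); // same schema, different semantics
}

Because the schemas compare equal, the pre-fix rule dropped the outer projection, and the CASE expression with it; that is the wrong result the new test guards against (see issue #8942 referenced in the optimizer test below).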
27 changes: 25 additions & 2 deletions datafusion/optimizer/src/optimize_projections.rs
@@ -867,7 +867,9 @@ fn rewrite_projection_given_requirements(
    return if let Some(input) =
        optimize_projections(&proj.input, config, &required_indices)?
    {
-        if &projection_schema(&input, &exprs_used)? == input.schema() {
+        if &projection_schema(&input, &exprs_used)? == input.schema()
+            && exprs_used.iter().all(is_expr_trivial)
+        {
            Ok(Some(input))
        } else {
            Projection::try_new(exprs_used, Arc::new(input))
@@ -899,7 +901,7 @@ mod tests {
    use datafusion_common::{Result, TableReference};
    use datafusion_expr::{
        binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not,
-        table_scan, try_cast, Expr, Like, LogicalPlan, Operator,
+        table_scan, try_cast, when, Expr, Like, LogicalPlan, Operator,
    };

    fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
@@ -1163,4 +1165,25 @@ mod tests {
\n TableScan: test projection=[a]";
assert_optimized_plan_equal(&plan, expected)
}

+    // Test outer projection isn't discarded despite the same schema as inner
+    // https://github.com/apache/arrow-datafusion/issues/8942
+    #[test]
+    fn test_derived_column() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), lit(0).alias("d")])?
+            .project(vec![
+                col("a"),
+                when(col("a").eq(lit(1)), lit(10))
+                    .otherwise(col("d"))?
+                    .alias("d"),
+            ])?
+            .build()?;
+
+        let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE d END AS d\
+        \n  Projection: test.a, Int32(0) AS d\
+        \n    TableScan: test projection=[a]";
+        assert_optimized_plan_equal(&plan, expected)
+    }
}
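The guard added in rewrite_projection_given_requirements now elides a projection only when, in addition to the schemas matching, every projected expression is trivial. The diff does not show the is_expr_trivial helper it calls; the sketch below is a plausible shape for such a check, written as an assumption based on the name rather than a quote of the committed code:

use datafusion_expr::Expr;

// Assumed definition: an expression is "trivial" when evaluating it cannot
// compute anything new, i.e. it is a bare column reference or a literal.
// Anything else (CASE, casts, arithmetic, ...) keeps its Projection node
// even when the output schema happens to match the input schema.
fn is_expr_trivial(expr: &Expr) -> bool {
    matches!(expr, Expr::Column(_) | Expr::Literal(_))
}

With a predicate of this kind in place, a projection is removed only when it is a pure pass-through of its input, so the CASE-based projections exercised by the tests above survive optimization.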