Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: don't extract common sub expr in CASE WHEN clause #8833

Merged
merged 3 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;

use crate::utils::is_volatile_expression;
use crate::{utils, OptimizerConfig, OptimizerRule};

use arrow::datatypes::DataType;
Expand All @@ -29,7 +30,7 @@ use datafusion_common::tree_node::{
use datafusion_common::{
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
};
use datafusion_expr::expr::{is_volatile, Alias};
use datafusion_expr::expr::Alias;
use datafusion_expr::logical_plan::{
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
};
Expand Down Expand Up @@ -518,7 +519,7 @@ enum ExprMask {
}

impl ExprMask {
fn ignores(&self, expr: &Expr) -> Result<bool> {
fn ignores(&self, expr: &Expr) -> bool {
let is_normal_minus_aggregates = matches!(
expr,
Expr::Literal(..)
Expand All @@ -529,14 +530,12 @@ impl ExprMask {
| Expr::Wildcard { .. }
);

let is_volatile = is_volatile(expr)?;

let is_aggr = matches!(expr, Expr::AggregateFunction(..));

Ok(match self {
Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
})
match self {
Self::Normal => is_normal_minus_aggregates || is_aggr,
Self::NormalAndAggregates => is_normal_minus_aggregates,
}
}
}

Expand Down Expand Up @@ -614,7 +613,12 @@ impl ExprIdentifierVisitor<'_> {
impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
type N = Expr;

fn pre_visit(&mut self, _expr: &Expr) -> Result<VisitRecursion> {
fn pre_visit(&mut self, expr: &Expr) -> Result<VisitRecursion> {
// related to https://github.com/apache/arrow-datafusion/issues/8814
// If the expr contain volatile expression or is a case expression, skip it.
if matches!(expr, Expr::Case(..)) || is_volatile_expression(expr)? {
Copy link
Contributor Author

@haohuaijin haohuaijin Jan 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check volatile expression is this place is because the previous method only can deal with a single random() function, the below query wil return true in main branch

SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this problem could exist in other function/expression like COALESCE | OR

It can be tracked as a future ticket.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inner expressions in these expressions may be short-circuited.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

track in #8874

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I comment in the #8874, Do we have any method for this rule to make sure the new function/expr or other cases we can't find will not bring the same bug

return Ok(VisitRecursion::Skip);
}
self.visit_stack
.push(VisitRecord::EnterMark(self.node_count));
self.node_count += 1;
Expand All @@ -628,7 +632,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {

let (idx, sub_expr_desc) = self.pop_enter_mark();
// skip exprs should not be recognize.
if self.expr_mask.ignores(expr)? {
if self.expr_mask.ignores(expr) {
self.id_array[idx].0 = self.series_number;
let desc = Self::desc_expr(expr);
self.visit_stack.push(VisitRecord::ExprItem(desc));
Expand Down
39 changes: 5 additions & 34 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use crate::optimizer::ApplyOrder;
use crate::utils::is_volatile_expression;
use crate::{OptimizerConfig, OptimizerRule};

use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
Expand All @@ -34,7 +35,7 @@ use datafusion_expr::logical_plan::{
use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned};
use datafusion_expr::{
and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator,
ScalarFunctionDefinition, TableProviderFilterPushDown, Volatility,
ScalarFunctionDefinition, TableProviderFilterPushDown,
};

use itertools::Itertools;
Expand Down Expand Up @@ -739,7 +740,9 @@ impl OptimizerRule for PushDownFilter {

(field.qualified_name(), expr)
})
.partition(|(_, value)| is_volatile_expression(value));
.partition(|(_, value)| {
is_volatile_expression(value).unwrap_or(true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why use unwrap_or(true)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't know whether the scalar function is volatile, default set it to a volatile function.

});

let mut push_predicates = vec![];
let mut keep_predicates = vec![];
Expand Down Expand Up @@ -1028,38 +1031,6 @@ pub fn replace_cols_by_name(
})
}

/// check whether the expression is volatile predicates
fn is_volatile_expression(e: &Expr) -> bool {
let mut is_volatile = false;
e.apply(&mut |expr| {
Ok(match expr {
Expr::ScalarFunction(f) => match &f.func_def {
ScalarFunctionDefinition::BuiltIn(fun)
if fun.volatility() == Volatility::Volatile =>
{
is_volatile = true;
VisitRecursion::Stop
}
ScalarFunctionDefinition::UDF(fun)
if fun.signature().volatility == Volatility::Volatile =>
{
is_volatile = true;
VisitRecursion::Stop
}
ScalarFunctionDefinition::Name(_) => {
return internal_err!(
"Function `Expr` with name should be resolved."
);
}
_ => VisitRecursion::Continue,
},
_ => VisitRecursion::Continue,
})
})
.unwrap();
is_volatile
}

/// check whether the expression uses the columns in `check_map`.
fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
let mut is_contain = false;
Expand Down
16 changes: 16 additions & 0 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
//! Collection of utility functions that are leveraged by the query optimizer rules

use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::{Column, DFSchemaRef};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr::is_volatile;
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::utils as expr_utils;
use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator};
Expand Down Expand Up @@ -92,6 +94,20 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) {
trace!("{description}::\n{}\n", plan.display_indent_schema());
}

/// check whether the expression is volatile predicates
pub(crate) fn is_volatile_expression(e: &Expr) -> Result<bool> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move is_volatile_expression to utils for reuse.

let mut is_volatile_expr = false;
e.apply(&mut |expr| {
Ok(if is_volatile(expr)? {
is_volatile_expr = true;
VisitRecursion::Stop
} else {
VisitRecursion::Continue
})
})?;
Ok(is_volatile_expr)
}

/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
///
/// See [`split_conjunction_owned`] for more details and an example.
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,6 @@ NULL

# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away
query B
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0)
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0)
----
false
19 changes: 19 additions & 0 deletions datafusion/sqllogictest/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1112,3 +1112,22 @@ SELECT abs(x), abs(x) + abs(y) FROM t;

statement ok
DROP TABLE t;

# related to https://github.com/apache/arrow-datafusion/issues/8814
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻

statement ok
create table t(x int, y int) as values (1,1), (2,2), (3,3), (0,0), (4,0);

query II
SELECT
CASE WHEN B.x > 0 THEN A.x / B.x ELSE 0 END AS value1,
CASE WHEN B.x > 0 AND B.y > 0 THEN A.x / B.x ELSE 0 END AS value3
FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B;
----
0 0
0 0
0 0
0 0
0 0

statement ok
DROP TABLE t;
6 changes: 3 additions & 3 deletions datafusion/sqllogictest/test_files/tpch/q14.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ where
----
logical_plan
Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue
--Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
----Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, part.p_type
--Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
----Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_type
------Inner Join: lineitem.l_partkey = part.p_partkey
--------Projection: lineitem.l_partkey, lineitem.l_extendedprice, lineitem.l_discount
----------Filter: lineitem.l_shipdate >= Date32("9374") AND lineitem.l_shipdate < Date32("9404")
Expand All @@ -45,7 +45,7 @@ ProjectionExec: expr=[100 * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%")
--AggregateExec: mode=Final, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
----CoalescePartitionsExec
------AggregateExec: mode=Partial, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
--------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, p_type@4 as p_type]
--------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, p_type@4 as p_type]
----------CoalesceBatchesExec: target_batch_size=8192
------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)]
--------------CoalesceBatchesExec: target_batch_size=8192
Expand Down
Loading