Skip to content

Commit

Permalink
fix: volatile expressions should not be target of common subexpt elim…
Browse files Browse the repository at this point in the history
…ination (#8520)

* fix: volatile expressions should not be target of common subexpt elimination

* Fix clippy

* For review

* Return error for unresolved scalar function

* Improve error message
  • Loading branch information
viirya authored Dec 14, 2023
1 parent 79c17e3 commit 5909866
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 8 deletions.
75 changes: 74 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,24 @@ impl ScalarFunctionDefinition {
ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(),
}
}

/// Whether this function is volatile, i.e. whether it can return different results
/// when evaluated multiple times with the same input.
pub fn is_volatile(&self) -> Result<bool> {
match self {
ScalarFunctionDefinition::BuiltIn(fun) => {
Ok(fun.volatility() == crate::Volatility::Volatile)
}
ScalarFunctionDefinition::UDF(udf) => {
Ok(udf.signature().volatility == crate::Volatility::Volatile)
}
ScalarFunctionDefinition::Name(func) => {
internal_err!(
"Cannot determine volatility of unresolved function: {func}"
)
}
}
}
}

impl ScalarFunction {
Expand Down Expand Up @@ -1692,14 +1710,28 @@ fn create_names(exprs: &[Expr]) -> Result<String> {
.join(", "))
}

/// Whether the given expression is volatile, i.e. whether it can return different results
/// when evaluated multiple times with the same input.
pub fn is_volatile(expr: &Expr) -> Result<bool> {
match expr {
Expr::ScalarFunction(func) => func.func_def.is_volatile(),
_ => Ok(false),
}
}

#[cfg(test)]
mod test {
use crate::expr::Cast;
use crate::expr_fn::col;
use crate::{case, lit, Expr};
use crate::{
case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction,
ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature,
Volatility,
};
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_common::{Result, ScalarValue};
use std::sync::Arc;

#[test]
fn format_case_when() -> Result<()> {
Expand Down Expand Up @@ -1800,4 +1832,45 @@ mod test {
"UInt32(1) OR UInt32(2)"
);
}

#[test]
fn test_is_volatile_scalar_func_definition() {
// BuiltIn
assert!(
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random)
.is_volatile()
.unwrap()
);
assert!(
!ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs)
.is_volatile()
.unwrap()
);

// UDF
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation =
Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
let udf = Arc::new(ScalarUDF::new(
"TestScalarUDF",
&Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
&return_type,
&fun,
));
assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());

let udf = Arc::new(ScalarUDF::new(
"TestScalarUDF",
&Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile),
&return_type,
&fun,
));
assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());

// Unresolved function
ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc"))
.is_volatile()
.expect_err("Shouldn't determine volatility of unresolved function");
}
}
18 changes: 11 additions & 7 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use datafusion_common::tree_node::{
use datafusion_common::{
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::expr::{is_volatile, Alias};
use datafusion_expr::logical_plan::{
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
};
Expand Down Expand Up @@ -113,6 +113,8 @@ impl CommonSubexprEliminate {
let Projection { expr, input, .. } = projection;
let input_schema = Arc::clone(input.schema());
let mut expr_set = ExprSet::new();

// Visit expr list and build expr identifier to occuring count map (`expr_set`).
let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;

let (mut new_expr, new_input) =
Expand Down Expand Up @@ -516,7 +518,7 @@ enum ExprMask {
}

impl ExprMask {
fn ignores(&self, expr: &Expr) -> bool {
fn ignores(&self, expr: &Expr) -> Result<bool> {
let is_normal_minus_aggregates = matches!(
expr,
Expr::Literal(..)
Expand All @@ -527,12 +529,14 @@ impl ExprMask {
| Expr::Wildcard { .. }
);

let is_volatile = is_volatile(expr)?;

let is_aggr = matches!(expr, Expr::AggregateFunction(..));

match self {
Self::Normal => is_normal_minus_aggregates || is_aggr,
Self::NormalAndAggregates => is_normal_minus_aggregates,
}
Ok(match self {
Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
})
}
}

Expand Down Expand Up @@ -624,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {

let (idx, sub_expr_desc) = self.pop_enter_mark();
// skip exprs should not be recognize.
if self.expr_mask.ignores(expr) {
if self.expr_mask.ignores(expr)? {
self.id_array[idx].0 = self.series_number;
let desc = Self::desc_expr(expr);
self.visit_stack.push(VisitRecord::ExprItem(desc));
Expand Down
6 changes: 6 additions & 0 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,9 @@ query ?
SELECT find_in_set(NULL, NULL)
----
NULL

# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away
query B
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0)
----
false

0 comments on commit 5909866

Please sign in to comment.