diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index dace62201e7f..7249160792ac 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -135,6 +135,94 @@ impl<'a> Simplifier<'a> { false } + + fn boolean_folding_for_or( + const_bool: &Option, + bool_expr: Box, + left_right_order: bool, + ) -> Expr { + // See if we can fold 'const_bool OR bool_expr' to a constant boolean + match const_bool { + // TRUE or expr (including NULL) = TRUE + Some(true) => Expr::Literal(ScalarValue::Boolean(Some(true))), + // FALSE or expr (including NULL) = expr + Some(false) => *bool_expr, + None => match *bool_expr { + // NULL or TRUE = TRUE + Expr::Literal(ScalarValue::Boolean(Some(true))) => { + Expr::Literal(ScalarValue::Boolean(Some(true))) + } + // NULL or FALSE = NULL + Expr::Literal(ScalarValue::Boolean(Some(false))) => { + Expr::Literal(ScalarValue::Boolean(None)) + } + // NULL or NULL = NULL + Expr::Literal(ScalarValue::Boolean(None)) => { + Expr::Literal(ScalarValue::Boolean(None)) + } + // NULL or expr can be either NULL or TRUE + // So let us not rewrite it + _ => { + let mut left = + Box::new(Expr::Literal(ScalarValue::Boolean(*const_bool))); + let mut right = bool_expr; + if !left_right_order { + std::mem::swap(&mut left, &mut right); + } + + Expr::BinaryExpr { + left, + op: Operator::Or, + right, + } + } + }, + } + } + + fn boolean_folding_for_and( + const_bool: &Option, + bool_expr: Box, + left_right_order: bool, + ) -> Expr { + // See if we can fold 'const_bool AND bool_expr' to a constant boolean + match const_bool { + // TRUE and expr (including NULL) = expr + Some(true) => *bool_expr, + // FALSE and expr (including NULL) = FALSE + Some(false) => Expr::Literal(ScalarValue::Boolean(Some(false))), + None => match *bool_expr { + // NULL and TRUE = NULL + Expr::Literal(ScalarValue::Boolean(Some(true))) => { + Expr::Literal(ScalarValue::Boolean(None)) + } + // NULL and FALSE = FALSE + Expr::Literal(ScalarValue::Boolean(Some(false))) => { + Expr::Literal(ScalarValue::Boolean(Some(false))) + } + // NULL and NULL = NULL + Expr::Literal(ScalarValue::Boolean(None)) => { + Expr::Literal(ScalarValue::Boolean(None)) + } + // NULL and expr can either be NULL or FALSE + // So let us not rewrite it + _ => { + let mut left = + Box::new(Expr::Literal(ScalarValue::Boolean(*const_bool))); + let mut right = bool_expr; + if !left_right_order { + std::mem::swap(&mut left, &mut right); + } + + Expr::BinaryExpr { + left, + op: Operator::And, + right, + } + } + }, + } + } } impl<'a> ExprRewriter for Simplifier<'a> { @@ -214,46 +302,12 @@ impl<'a> ExprRewriter for Simplifier<'a> { (Expr::Literal(ScalarValue::Boolean(b)), _) if self.is_boolean_type(&right) => { - match b { - Some(true) => Expr::Literal(ScalarValue::Boolean(Some(true))), - Some(false) => match *right { - Expr::Literal(ScalarValue::Boolean(None)) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - _ => *right, - }, - None => match *right { - Expr::Literal(ScalarValue::Boolean(Some(true))) => { - Expr::Literal(ScalarValue::Boolean(Some(true))) - } - Expr::Literal(ScalarValue::Boolean(Some(false))) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - _ => *right, - }, - } + Self::boolean_folding_for_or(b, right, true) } (_, Expr::Literal(ScalarValue::Boolean(b))) if self.is_boolean_type(&left) => { - match b { - Some(true) => Expr::Literal(ScalarValue::Boolean(Some(true))), - Some(false) => match *left { - Expr::Literal(ScalarValue::Boolean(None)) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - _ => *left, - }, - None => match *left { - Expr::Literal(ScalarValue::Boolean(Some(true))) => { - Expr::Literal(ScalarValue::Boolean(Some(true))) - } - Expr::Literal(ScalarValue::Boolean(Some(false))) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - _ => *left, - }, - } + Self::boolean_folding_for_or(b, left, false) } _ => Expr::BinaryExpr { left, @@ -265,42 +319,12 @@ impl<'a> ExprRewriter for Simplifier<'a> { (Expr::Literal(ScalarValue::Boolean(b)), _) if self.is_boolean_type(&right) => { - // match b { - // Some(false) => { - // Expr::Literal(ScalarValue::Boolean(Some(false))) - // } - // _ => *right, - // } - match b { - Some(true) => match *right { - Expr::Literal(ScalarValue::Boolean(None)) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - _ => *right, - }, - Some(false) => { - Expr::Literal(ScalarValue::Boolean(Some(false))) - } - None => match *right { - Expr::Literal(ScalarValue::Boolean(Some(true))) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - Expr::Literal(ScalarValue::Boolean(Some(false))) => { - Expr::Literal(ScalarValue::Boolean(Some(false))) - } - _ => *right, - }, - } + Self::boolean_folding_for_and(b, right, true) } (_, Expr::Literal(ScalarValue::Boolean(b))) if self.is_boolean_type(&left) => { - match b { - Some(false) => { - Expr::Literal(ScalarValue::Boolean(Some(false))) - } - _ => *left, - } + Self::boolean_folding_for_and(b, left, false) } _ => Expr::BinaryExpr { left, @@ -962,6 +986,19 @@ mod tests { lit(ScalarValue::Boolean(None)), ); + // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL) + // it can be either NULL or TRUE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)` + // and should not be rewritten + let expr = Expr::Between { + expr: Box::new(col("c1")), + negated: false, + low: Box::new(lit(0)), + high: Box::new(lit(10)), + }; + let expr = expr.or(Expr::Literal(ScalarValue::Boolean(None))); + let result = expr.clone().rewrite(&mut rewriter)?; + assert_eq!(expr, result); + Ok(()) } #[test] @@ -1016,6 +1053,19 @@ mod tests { lit(ScalarValue::Boolean(Some(false))), ); + // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL) + // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10` + // and should not be rewritten + let expr = Expr::Between { + expr: Box::new(col("c1")), + negated: false, + low: Box::new(lit(0)), + high: Box::new(lit(10)), + }; + let expr = expr.and(Expr::Literal(ScalarValue::Boolean(None))); + let result = expr.clone().rewrite(&mut rewriter)?; + assert_eq!(expr, result); + Ok(()) } }