Skip to content

Commit

Permalink
feat: Add fail_on_overflow option to BinaryExpr (#11400)
Browse files Browse the repository at this point in the history
* update tests

* update tests

* add rustdoc

* update PartialEq impl

* fix

* address feedback about improving api
  • Loading branch information
andygrove authored Jul 11, 2024
1 parent 7a23ea9 commit 2413155
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 9 deletions.
4 changes: 2 additions & 2 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2312,7 +2312,7 @@ mod tests {
// verify that the plan correctly casts u8 to i64
// the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
// the cast here is implicit so has CastOptions with safe=true
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }";
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }";
assert!(format!("{exec_plan:?}").contains(expected));
Ok(())
}
Expand Down Expand Up @@ -2551,7 +2551,7 @@ mod tests {
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.

let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } } }";
let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }";

let actual = format!("{execution_plan:?}");
assert!(actual.contains(expected), "{}", actual);
Expand Down
126 changes: 119 additions & 7 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ pub struct BinaryExpr {
left: Arc<dyn PhysicalExpr>,
op: Operator,
right: Arc<dyn PhysicalExpr>,
/// Specifies whether an error is returned on overflow or not
fail_on_overflow: bool,
}

impl BinaryExpr {
Expand All @@ -62,7 +64,22 @@ impl BinaryExpr {
op: Operator,
right: Arc<dyn PhysicalExpr>,
) -> Self {
Self { left, op, right }
Self {
left,
op,
right,
fail_on_overflow: false,
}
}

/// Create new binary expression with explicit fail_on_overflow value
pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self {
Self {
left: self.left,
op: self.op,
right: self.right,
fail_on_overflow,
}
}

/// Get the left side of the binary expression
Expand Down Expand Up @@ -273,8 +290,11 @@ impl PhysicalExpr for BinaryExpr {
}

match self.op {
Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
Operator::Divide => return apply(&lhs, &rhs, div),
Operator::Modulo => return apply(&lhs, &rhs, rem),
Expand Down Expand Up @@ -327,11 +347,10 @@ impl PhysicalExpr for BinaryExpr {
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
Ok(Arc::new(BinaryExpr::new(
Arc::clone(&children[0]),
self.op,
Arc::clone(&children[1]),
)))
Ok(Arc::new(
BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1]))
.with_fail_on_overflow(self.fail_on_overflow),
))
}

fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
Expand Down Expand Up @@ -496,7 +515,12 @@ impl PartialEq<dyn Any> for BinaryExpr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| self.left.eq(&x.left) && self.op == x.op && self.right.eq(&x.right))
.map(|x| {
self.left.eq(&x.left)
&& self.op == x.op
&& self.right.eq(&x.right)
&& self.fail_on_overflow.eq(&x.fail_on_overflow)
})
.unwrap_or(false)
}
}
Expand Down Expand Up @@ -661,6 +685,7 @@ mod tests {

use datafusion_common::plan_datafusion_err;
use datafusion_expr::type_coercion::binary::get_input_types;
use datafusion_physical_expr_common::expressions::column::Column;

/// Performs a binary operation, applying any type coercion necessary
fn binary_op(
Expand Down Expand Up @@ -4008,4 +4033,91 @@ mod tests {
.unwrap();
assert_eq!(&casted, &dictionary);
}

#[test]
fn test_add_with_overflow() -> Result<()> {
// create test data
let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
let r = Arc::new(Int32Array::from(vec![2, 1]));
let schema = Arc::new(Schema::new(vec![
Field::new("l", DataType::Int32, false),
Field::new("r", DataType::Int32, false),
]));
let batch = RecordBatch::try_new(schema, vec![l, r])?;

// create expression
let expr = BinaryExpr::new(
Arc::new(Column::new("l", 0)),
Operator::Plus,
Arc::new(Column::new("r", 1)),
)
.with_fail_on_overflow(true);

// evaluate expression
let result = expr.evaluate(&batch);
assert!(result
.err()
.unwrap()
.to_string()
.contains("Overflow happened on: 2147483647 + 1"));
Ok(())
}

#[test]
fn test_subtract_with_overflow() -> Result<()> {
// create test data
let l = Arc::new(Int32Array::from(vec![1, i32::MIN]));
let r = Arc::new(Int32Array::from(vec![2, 1]));
let schema = Arc::new(Schema::new(vec![
Field::new("l", DataType::Int32, false),
Field::new("r", DataType::Int32, false),
]));
let batch = RecordBatch::try_new(schema, vec![l, r])?;

// create expression
let expr = BinaryExpr::new(
Arc::new(Column::new("l", 0)),
Operator::Minus,
Arc::new(Column::new("r", 1)),
)
.with_fail_on_overflow(true);

// evaluate expression
let result = expr.evaluate(&batch);
assert!(result
.err()
.unwrap()
.to_string()
.contains("Overflow happened on: -2147483648 - 1"));
Ok(())
}

#[test]
fn test_mul_with_overflow() -> Result<()> {
// create test data
let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
let r = Arc::new(Int32Array::from(vec![2, 2]));
let schema = Arc::new(Schema::new(vec![
Field::new("l", DataType::Int32, false),
Field::new("r", DataType::Int32, false),
]));
let batch = RecordBatch::try_new(schema, vec![l, r])?;

// create expression
let expr = BinaryExpr::new(
Arc::new(Column::new("l", 0)),
Operator::Multiply,
Arc::new(Column::new("r", 1)),
)
.with_fail_on_overflow(true);

// evaluate expression
let result = expr.evaluate(&batch);
assert!(result
.err()
.unwrap()
.to_string()
.contains("Overflow happened on: 2147483647 * 2"));
Ok(())
}
}

0 comments on commit 2413155

Please sign in to comment.