Skip to content

Commit

Permalink
ARROW-6669: [Rust] [DataFusion] Implement binary expression for physi…
Browse files Browse the repository at this point in the history
…cal plan

This PR adds the binary expression to the new physical execution plan, with support for comparison operators (`<`, `<=`, `>`, `>=`, `==`, `!=`) and boolean operators `AND` and `OR`.

Other binary expressions, such as math expressions will be added in a future PR.

Closes #5478 from andygrove/ARROW-6669 and squashes the following commits:

83bfa77 <Andy Grove> formatting
af8d298 <Andy Grove> address PR feedback
9ad3b7f <Andy Grove> formatting
bb82a24 <Andy Grove> use expect() instead of unwrap() when downcasting arrays
9b94cc8 <Andy Grove> Implement binary expression with support for comparison and boolean operators

Authored-by: Andy Grove <[email protected]>
Signed-off-by: Paddy Horan <[email protected]>
  • Loading branch information
andygrove authored and paddyhoran committed Sep 24, 2019
1 parent 5a918ce commit c6faaed
Showing 1 changed file with 204 additions and 3 deletions.
207 changes: 204 additions & 3 deletions rust/datafusion/src/execution/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ use std::sync::Arc;

use crate::error::{ExecutionError, Result};
use crate::execution::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::logicalplan::ScalarValue;
use crate::logicalplan::{Operator, ScalarValue};
use arrow::array::{
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::array::{
Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder,
Int8Builder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
};
use arrow::compute::kernels::boolean::{and, or};
use arrow::compute::kernels::cast::cast;
use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;

Expand Down Expand Up @@ -197,6 +199,140 @@ pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
Arc::new(Sum::new(expr))
}

/// Invoke a compute kernel on a pair of arrays
macro_rules! compute_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
let rr = $RIGHT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
Ok(Arc::new($OP(&ll, &rr)?))
}};
}

/// Invoke a compute kernel on a pair of arrays
macro_rules! comparison_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
match $LEFT.data_type() {
DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
other => Err(ExecutionError::General(format!(
"Unsupported data type {:?}",
other
))),
}
}};
}

/// Invoke a boolean kernel on a pair of arrays
macro_rules! boolean_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
let ll = $LEFT
.as_any()
.downcast_ref::<BooleanArray>()
.expect("boolean_op failed to downcast array");
let rr = $RIGHT
.as_any()
.downcast_ref::<BooleanArray>()
.expect("boolean_op failed to downcast array");
Ok(Arc::new($OP(&ll, &rr)?))
}};
}
/// Binary expression
pub struct BinaryExpr {
left: Arc<dyn PhysicalExpr>,
op: Operator,
right: Arc<dyn PhysicalExpr>,
}

impl BinaryExpr {
/// Create new binary expression
pub fn new(
left: Arc<dyn PhysicalExpr>,
op: Operator,
right: Arc<dyn PhysicalExpr>,
) -> Self {
Self { left, op, right }
}
}

impl PhysicalExpr for BinaryExpr {
fn name(&self) -> String {
format!("{:?}", self.op)
}

fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
self.left.data_type(input_schema)
}

fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let left = self.left.evaluate(batch)?;
let right = self.right.evaluate(batch)?;
if left.data_type() != right.data_type() {
return Err(ExecutionError::General(format!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
self.op,
left.data_type(),
right.data_type()
)));
}
match &self.op {
Operator::Lt => comparison_op!(left, right, lt),
Operator::LtEq => comparison_op!(left, right, lt_eq),
Operator::Gt => comparison_op!(left, right, gt),
Operator::GtEq => comparison_op!(left, right, gt_eq),
Operator::Eq => comparison_op!(left, right, eq),
Operator::NotEq => comparison_op!(left, right, neq),
Operator::And => {
if left.data_type() == &DataType::Boolean {
boolean_op!(left, right, and)
} else {
return Err(ExecutionError::General(format!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
self.op,
left.data_type(),
right.data_type()
)));
}
}
Operator::Or => {
if left.data_type() == &DataType::Boolean {
boolean_op!(left, right, or)
} else {
return Err(ExecutionError::General(format!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
self.op,
left.data_type(),
right.data_type()
)));
}
}
_ => Err(ExecutionError::General("Unsupported operator".to_string())),
}
}
}

/// Create a binary expression
pub fn binary(
l: Arc<dyn PhysicalExpr>,
op: Operator,
r: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
Arc::new(BinaryExpr::new(l, op, r))
}

/// CAST expression casts an expression to a specific data type
pub struct CastExpr {
/// The expression to cast
Expand Down Expand Up @@ -335,6 +471,71 @@ mod tests {
use arrow::array::BinaryArray;
use arrow::datatypes::*;

#[test]
fn binary_comparison() -> Result<()> {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(a), Arc::new(b)],
)?;

// expression: "a < b"
let lt = binary(col(0), Operator::Lt, col(1));
let result = lt.evaluate(&batch)?;
assert_eq!(result.len(), 5);

let expected = vec![false, false, true, true, true];
let result = result
.as_any()
.downcast_ref::<BooleanArray>()
.expect("failed to downcast to BooleanArray");
for i in 0..5 {
assert_eq!(result.value(i), expected[i]);
}

Ok(())
}

#[test]
fn binary_nested() -> Result<()> {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(a), Arc::new(b)],
)?;

// expression: "a < b OR a == b"
let expr = binary(
binary(col(0), Operator::Lt, col(1)),
Operator::Or,
binary(col(0), Operator::Eq, col(1)),
);
let result = expr.evaluate(&batch)?;
assert_eq!(result.len(), 5);

let expected = vec![true, true, false, true, false];
let result = result
.as_any()
.downcast_ref::<BooleanArray>()
.expect("failed to downcast to BooleanArray");
for i in 0..5 {
print!("{}", i);
assert_eq!(result.value(i), expected[i]);
}

Ok(())
}

#[test]
fn literal_i32() -> Result<()> {
// create an arbitrary record bacth
Expand Down

0 comments on commit c6faaed

Please sign in to comment.