Skip to content

Commit

Permalink
Implement binary expression with support for comparison and boolean o…
Browse files Browse the repository at this point in the history
…perators
  • Loading branch information
andygrove committed Sep 24, 2019
1 parent cde09f7 commit 9b94cc8
Showing 1 changed file with 170 additions and 3 deletions.
173 changes: 170 additions & 3 deletions rust/datafusion/src/execution/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ use std::sync::Arc;

use crate::error::{ExecutionError, Result};
use crate::execution::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::logicalplan::ScalarValue;
use crate::logicalplan::{Operator, ScalarValue};
use arrow::array::{
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::compute::kernels::boolean::{and, or};
use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
use arrow::array::{
Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder,
Int8Builder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder,
Expand Down Expand Up @@ -197,6 +199,106 @@ pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
Arc::new(Sum::new(expr))
}

/// Invoke a compute kernel on a pair of arrays
macro_rules! compute_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
let ll = $LEFT.as_any().downcast_ref::<$DT>().unwrap();
let rr = $RIGHT.as_any().downcast_ref::<$DT>().unwrap();
Ok(Arc::new($OP(&ll, &rr)?))
}};
}

/// Invoke a compute kernel on a pair of arrays
macro_rules! comparison_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
match $LEFT.data_type() {
DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
other => Err(ExecutionError::General(format!(
"Unsupported data type {:?}",
other
))),
}
}};
}

/// Invoke a boolean kernel on a pair of arrays
macro_rules! boolean_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
let ll = $LEFT.as_any().downcast_ref::<BooleanArray>().unwrap();
let rr = $RIGHT.as_any().downcast_ref::<BooleanArray>().unwrap();
Ok(Arc::new($OP(&ll, &rr)?))
}};
}
/// Binary expression
pub struct BinaryExpr {
left: Arc<dyn PhysicalExpr>,
op: Operator,
right: Arc<dyn PhysicalExpr>,
}

impl BinaryExpr {
/// Create new binary expression
pub fn new(
left: Arc<dyn PhysicalExpr>,
op: Operator,
right: Arc<dyn PhysicalExpr>,
) -> Self {
Self { left, op, right }
}
}

impl PhysicalExpr for BinaryExpr {
fn name(&self) -> String {
format!("{:?}", self.op)
}

fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
self.left.data_type(input_schema)
}

fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let left = self.left.evaluate(batch)?;
let right = self.right.evaluate(batch)?;
if left.data_type() != right.data_type() {
return Err(ExecutionError::General(format!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
self.op,
left.data_type(),
right.data_type()
)));
}
match &self.op {
Operator::Lt => comparison_op!(left, right, lt),
Operator::LtEq => comparison_op!(left, right, lt_eq),
Operator::Gt => comparison_op!(left, right, gt),
Operator::GtEq => comparison_op!(left, right, gt_eq),
Operator::Eq => comparison_op!(left, right, eq),
Operator::NotEq => comparison_op!(left, right, neq),
Operator::And => boolean_op!(left, right, and),
Operator::Or => boolean_op!(left, right, or),
_ => Err(ExecutionError::General("Unsupported operator".to_string())),
}
}
}

/// Create a binary expression
pub fn binary(
l: Arc<dyn PhysicalExpr>,
op: Operator,
r: Arc<dyn PhysicalExpr>,
) -> Arc<dyn PhysicalExpr> {
Arc::new(BinaryExpr::new(l, op, r))
}

/// CAST expression casts an expression to a specific data type
pub struct CastExpr {
/// The expression to cast
Expand Down Expand Up @@ -335,6 +437,71 @@ mod tests {
use arrow::array::BinaryArray;
use arrow::datatypes::*;

#[test]
fn binary_comparison() -> Result<()> {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(a), Arc::new(b)],
)?;

// expression: "a < b"
let lt = binary(col(0), Operator::Lt, col(1));
let result = lt.evaluate(&batch)?;
assert_eq!(result.len(), 5);

let expected = vec![false, false, true, true, true];
let result = result
.as_any()
.downcast_ref::<BooleanArray>()
.expect("failed to downcast to BooleanArray");
for i in 0..5 {
assert_eq!(result.value(i), expected[i]);
}

Ok(())
}

#[test]
fn binary_nested() -> Result<()> {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(a), Arc::new(b)],
)?;

// expression: "a < b OR a == b"
let expr = binary(
binary(col(0), Operator::Lt, col(1)),
Operator::Or,
binary(col(0), Operator::Eq, col(1)),
);
let result = expr.evaluate(&batch)?;
assert_eq!(result.len(), 5);

let expected = vec![true, true, false, true, false];
let result = result
.as_any()
.downcast_ref::<BooleanArray>()
.expect("failed to downcast to BooleanArray");
for i in 0..5 {
print!("{}", i);
assert_eq!(result.value(i), expected[i]);
}

Ok(())
}

#[test]
fn literal_i32() -> Result<()> {
// create an arbitrary record bacth
Expand Down

0 comments on commit 9b94cc8

Please sign in to comment.