Skip to content

Commit

Permalink
feat (datafusion integration): convert datafusion expr filters to Ice…
Browse files Browse the repository at this point in the history
…berg Predicate (#588)

* adding main function and tests

* adding tests, removing integration test for now

* fixing typos and lints

* fixing typing issue

* - added support in schmema to convert Date32 to correct arrow type
- refactored scan to use new predicate converter as visitor and seperated it to a new mod
- added support for simple predicates with column cast expressions
- added testing, mostly around date functions

* fixing format and lic

* reducing number of tests (17 -> 7)

* fix formats

* fix naming

* refactoring to use TreeNodeVisitor

* fixing fmt

* small refactor

* adding swapped op and fixing CR comments

---------

Co-authored-by: Alon Agmon <[email protected]>
  • Loading branch information
a-agmon and Alon Agmon authored Sep 23, 2024
1 parent e967deb commit 1533c43
Show file tree
Hide file tree
Showing 5 changed files with 395 additions and 7 deletions.
7 changes: 5 additions & 2 deletions crates/iceberg/src/arrow/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ use arrow_array::types::{
validate_decimal_precision_and_scale, Decimal128Type, TimestampMicrosecondType,
};
use arrow_array::{
BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, Int64Array,
PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
BooleanArray, Date32Array, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array,
Int64Array, PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
};
use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit};
use bitvec::macros::internal::funty::Fundamental;
Expand Down Expand Up @@ -646,6 +646,9 @@ pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + Send
(PrimitiveType::String, PrimitiveLiteral::String(value)) => {
Ok(Box::new(StringArray::new_scalar(value.as_str())))
}
(PrimitiveType::Date, PrimitiveLiteral::Int(value)) => {
Ok(Box::new(Date32Array::new_scalar(*value)))
}
(PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => {
Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value)))
}
Expand Down
335 changes: 335 additions & 0 deletions crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,335 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::collections::VecDeque;

use datafusion::common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
use datafusion::common::Column;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{Expr, Operator};
use datafusion::scalar::ScalarValue;
use iceberg::expr::{Predicate, Reference};
use iceberg::spec::Datum;

pub struct ExprToPredicateVisitor {
stack: VecDeque<Option<Predicate>>,
}
impl ExprToPredicateVisitor {
/// Create a new predicate conversion visitor.
pub fn new() -> Self {
Self {
stack: VecDeque::new(),
}
}
/// Get the predicate from the stack.
pub fn get_predicate(&self) -> Option<Predicate> {
self.stack
.iter()
.filter_map(|opt| opt.clone())
.reduce(Predicate::and)
}

/// Convert a column expression to an iceberg predicate.
fn convert_column_expr(
&self,
col: &Column,
op: &Operator,
lit: &ScalarValue,
) -> Option<Predicate> {
let reference = Reference::new(col.name.clone());
let datum = scalar_value_to_datum(lit)?;
Some(binary_op_to_predicate(reference, op, datum))
}

/// Convert a compound expression to an iceberg predicate.
///
/// The strategy is to support the following cases:
/// - if its an AND expression then the result will be the valid predicates, whether there are 2 or just 1
/// - if its an OR expression then a predicate will be returned only if there are 2 valid predicates on both sides
fn convert_compound_expr(&self, valid_preds: &[Predicate], op: &Operator) -> Option<Predicate> {
let valid_preds_count = valid_preds.len();
match (op, valid_preds_count) {
(Operator::And, 1) => valid_preds.first().cloned(),
(Operator::And, 2) => Some(Predicate::and(
valid_preds[0].clone(),
valid_preds[1].clone(),
)),
(Operator::Or, 2) => Some(Predicate::or(
valid_preds[0].clone(),
valid_preds[1].clone(),
)),
_ => None,
}
}
}

// Implement TreeNodeVisitor for ExprToPredicateVisitor
impl<'n> TreeNodeVisitor<'n> for ExprToPredicateVisitor {
type Node = Expr;

fn f_down(&mut self, _node: &'n Expr) -> Result<TreeNodeRecursion, DataFusionError> {
Ok(TreeNodeRecursion::Continue)
}

fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion, DataFusionError> {
if let Expr::BinaryExpr(binary) = expr {
match (&*binary.left, &binary.op, &*binary.right) {
// process simple binary expressions, e.g. col > 1
(Expr::Column(col), op, Expr::Literal(lit)) => {
let col_pred = self.convert_column_expr(col, op, lit);
self.stack.push_back(col_pred);
}
// // process reversed binary expressions, e.g. 1 < col
(Expr::Literal(lit), op, Expr::Column(col)) => {
let col_pred = op
.swap()
.and_then(|negated_op| self.convert_column_expr(col, &negated_op, lit));
self.stack.push_back(col_pred);
}
// process compound expressions (involving logical operators. e.g., AND or OR and children)
(_left, op, _right) if op.is_logic_operator() => {
let right_pred = self.stack.pop_back().flatten();
let left_pred = self.stack.pop_back().flatten();
let children: Vec<_> = [left_pred, right_pred].into_iter().flatten().collect();
let compound_pred = self.convert_compound_expr(&children, op);
self.stack.push_back(compound_pred);
}
_ => return Ok(TreeNodeRecursion::Continue),
}
}
Ok(TreeNodeRecursion::Continue)
}
}

const MILLIS_PER_DAY: i64 = 24 * 60 * 60 * 1000;
/// Convert a scalar value to an iceberg datum.
fn scalar_value_to_datum(value: &ScalarValue) -> Option<Datum> {
match value {
ScalarValue::Int8(Some(v)) => Some(Datum::int(*v as i32)),
ScalarValue::Int16(Some(v)) => Some(Datum::int(*v as i32)),
ScalarValue::Int32(Some(v)) => Some(Datum::int(*v)),
ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)),
ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)),
ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)),
ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())),
ScalarValue::LargeUtf8(Some(v)) => Some(Datum::string(v.clone())),
ScalarValue::Date32(Some(v)) => Some(Datum::date(*v)),
ScalarValue::Date64(Some(v)) => Some(Datum::date((*v / MILLIS_PER_DAY) as i32)),
_ => None,
}
}

/// convert the data fusion Exp to an iceberg [`Predicate`]
fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate {
match op {
Operator::Eq => reference.equal_to(datum),
Operator::NotEq => reference.not_equal_to(datum),
Operator::Lt => reference.less_than(datum),
Operator::LtEq => reference.less_than_or_equal_to(datum),
Operator::Gt => reference.greater_than(datum),
Operator::GtEq => reference.greater_than_or_equal_to(datum),
_ => Predicate::AlwaysTrue,
}
}

#[cfg(test)]
mod tests {
use std::collections::VecDeque;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::tree_node::TreeNode;
use datafusion::common::DFSchema;
use datafusion::prelude::SessionContext;
use iceberg::expr::{Predicate, Reference};
use iceberg::spec::Datum;

use super::ExprToPredicateVisitor;

fn create_test_schema() -> DFSchema {
let arrow_schema = Schema::new(vec![
Field::new("foo", DataType::Int32, false),
Field::new("bar", DataType::Utf8, false),
]);
DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap()
}

#[test]
fn test_predicate_conversion_with_single_condition() {
let sql = "foo > 1";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let mut visitor = ExprToPredicateVisitor::new();
expr.visit(&mut visitor).unwrap();
let predicate = visitor.get_predicate().unwrap();
assert_eq!(
predicate,
Reference::new("foo").greater_than(Datum::long(1))
);
}
#[test]
fn test_predicate_conversion_with_single_unsupported_condition() {
let sql = "foo is null";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let mut visitor = ExprToPredicateVisitor::new();
expr.visit(&mut visitor).unwrap();
let predicate = visitor.get_predicate();
assert_eq!(predicate, None);
}

#[test]
fn test_predicate_conversion_with_single_condition_rev() {
let sql = "1 < foo";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let mut visitor = ExprToPredicateVisitor::new();
expr.visit(&mut visitor).unwrap();
let predicate = visitor.get_predicate().unwrap();
assert_eq!(
predicate,
Reference::new("foo").greater_than(Datum::long(1))
);
}
#[test]
fn test_predicate_conversion_with_and_condition() {
let sql = "foo > 1 and bar = 'test'";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let mut visitor = ExprToPredicateVisitor::new();
expr.visit(&mut visitor).unwrap();
let predicate = visitor.get_predicate().unwrap();
let expected_predicate = Predicate::and(
Reference::new("foo").greater_than(Datum::long(1)),
Reference::new("bar").equal_to(Datum::string("test")),
);
assert_eq!(predicate, expected_predicate);
}

#[test]
fn test_predicate_conversion_with_and_condition_unsupported() {
let sql = "foo > 1 and bar is not null";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let mut visitor = ExprToPredicateVisitor::new();
expr.visit(&mut visitor).unwrap();
let predicate = visitor.get_predicate().unwrap();
let expected_predicate = Reference::new("foo").greater_than(Datum::long(1));
assert_eq!(predicate, expected_predicate);
}
#[test]
fn test_predicate_conversion_with_and_condition_both_unsupported() {
let sql = "foo in (1, 2, 3) and bar is not null";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let mut visitor = ExprToPredicateVisitor::new();
expr.visit(&mut visitor).unwrap();
let predicate = visitor.get_predicate();
let expected_predicate = None;
assert_eq!(predicate, expected_predicate);
}

#[test]
fn test_predicate_conversion_with_or_condition_unsupported() {
let sql = "foo > 1 or bar is not null";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let mut visitor = ExprToPredicateVisitor::new();
expr.visit(&mut visitor).unwrap();
let predicate = visitor.get_predicate();
let expected_predicate = None;
assert_eq!(predicate, expected_predicate);
}

#[test]
fn test_predicate_conversion_with_complex_binary_expr() {
let sql = "(foo > 1 and bar = 'test') or foo < 0 ";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let mut visitor = ExprToPredicateVisitor::new();
expr.visit(&mut visitor).unwrap();
let predicate = visitor.get_predicate().unwrap();
let inner_predicate = Predicate::and(
Reference::new("foo").greater_than(Datum::long(1)),
Reference::new("bar").equal_to(Datum::string("test")),
);
let expected_predicate = Predicate::or(
inner_predicate,
Reference::new("foo").less_than(Datum::long(0)),
);
assert_eq!(predicate, expected_predicate);
}

#[test]
fn test_predicate_conversion_with_complex_binary_expr_unsupported() {
let sql = "(foo > 1 or bar in ('test', 'test2')) and foo < 0 ";
let df_schema = create_test_schema();
let expr = SessionContext::new()
.parse_sql_expr(sql, &df_schema)
.unwrap();
let mut visitor = ExprToPredicateVisitor::new();
expr.visit(&mut visitor).unwrap();
let predicate = visitor.get_predicate().unwrap();
let expected_predicate = Reference::new("foo").less_than(Datum::long(0));
assert_eq!(predicate, expected_predicate);
}

#[test]
// test the get result method
fn test_get_result_multiple() {
let predicates = vec![
Some(Reference::new("foo").greater_than(Datum::long(1))),
None,
Some(Reference::new("bar").equal_to(Datum::string("test"))),
];
let stack = VecDeque::from(predicates);
let visitor = ExprToPredicateVisitor { stack };
assert_eq!(
visitor.get_predicate(),
Some(Predicate::and(
Reference::new("foo").greater_than(Datum::long(1)),
Reference::new("bar").equal_to(Datum::string("test")),
))
);
}

#[test]
fn test_get_result_single() {
let predicates = vec![Some(Reference::new("foo").greater_than(Datum::long(1)))];
let stack = VecDeque::from(predicates);
let visitor = ExprToPredicateVisitor { stack };
assert_eq!(
visitor.get_predicate(),
Some(Reference::new("foo").greater_than(Datum::long(1)))
);
}
}
1 change: 1 addition & 0 deletions crates/integrations/datafusion/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
// specific language governing permissions and limitations
// under the License.

pub(crate) mod expr_to_predicate;
pub(crate) mod scan;
Loading

0 comments on commit 1533c43

Please sign in to comment.