diff --git a/crates/iceberg/src/arrow/schema.rs b/crates/iceberg/src/arrow/schema.rs
index 2ff43e0f0..a416d1229 100644
--- a/crates/iceberg/src/arrow/schema.rs
+++ b/crates/iceberg/src/arrow/schema.rs
@@ -24,8 +24,8 @@ use arrow_array::types::{
     validate_decimal_precision_and_scale, Decimal128Type, TimestampMicrosecondType,
 };
 use arrow_array::{
-    BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, Int64Array,
-    PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
+    BooleanArray, Date32Array, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array,
+    Int64Array, PrimitiveArray, Scalar, StringArray, TimestampMicrosecondArray,
 };
 use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit};
 use bitvec::macros::internal::funty::Fundamental;
@@ -636,6 +636,9 @@ pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + Send + Sync>> {
             Ok(Box::new(StringArray::new_scalar(value.as_str())))
         }
+        (PrimitiveType::Date, PrimitiveLiteral::Int(value)) => {
+            Ok(Box::new(Date32Array::new_scalar(*value)))
+        }
         (PrimitiveType::Timestamp, PrimitiveLiteral::Long(value)) => {
             Ok(Box::new(TimestampMicrosecondArray::new_scalar(*value)))
         }
diff --git a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs
new file mode 100644
index 000000000..110e4f7e4
--- /dev/null
+++ b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs
@@ -0,0 +1,335 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::VecDeque;
+
+use datafusion::common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
+use datafusion::common::Column;
+use datafusion::error::DataFusionError;
+use datafusion::logical_expr::{Expr, Operator};
+use datafusion::scalar::ScalarValue;
+use iceberg::expr::{Predicate, Reference};
+use iceberg::spec::Datum;
+
+pub struct ExprToPredicateVisitor {
+    stack: VecDeque<Option<Predicate>>,
+}
+impl ExprToPredicateVisitor {
+    /// Create a new predicate conversion visitor.
+    pub fn new() -> Self {
+        Self {
+            stack: VecDeque::new(),
+        }
+    }
+    /// Get the predicate from the stack.
+    pub fn get_predicate(&self) -> Option<Predicate> {
+        self.stack
+            .iter()
+            .filter_map(|opt| opt.clone())
+            .reduce(Predicate::and)
+    }
+
+    /// Convert a column expression to an iceberg predicate.
+    fn convert_column_expr(
+        &self,
+        col: &Column,
+        op: &Operator,
+        lit: &ScalarValue,
+    ) -> Option<Predicate> {
+        let reference = Reference::new(col.name.clone());
+        let datum = scalar_value_to_datum(lit)?;
+        Some(binary_op_to_predicate(reference, op, datum))
+    }
+
+    /// Convert a compound expression to an iceberg predicate.
+    ///
+    /// The strategy is to support the following cases:
+    /// - for an AND expression, the result keeps whichever operands converted to valid predicates, whether both did or just one
+    /// - for an OR expression, a predicate is returned only if both operands converted to valid predicates
+    fn convert_compound_expr(&self, valid_preds: &[Predicate], op: &Operator) -> Option<Predicate> {
+        let valid_preds_count = valid_preds.len();
+        match (op, valid_preds_count) {
+            (Operator::And, 1) => valid_preds.first().cloned(),
+            (Operator::And, 2) => Some(Predicate::and(
+                valid_preds[0].clone(),
+                valid_preds[1].clone(),
+            )),
+            (Operator::Or, 2) => Some(Predicate::or(
+                valid_preds[0].clone(),
+                valid_preds[1].clone(),
+            )),
+            _ => None,
+        }
+    }
+}
+
+// Implement TreeNodeVisitor for ExprToPredicateVisitor
+impl<'n> TreeNodeVisitor<'n> for ExprToPredicateVisitor {
+    type Node = Expr;
+
+    fn f_down(&mut self, _node: &'n Expr) -> Result<TreeNodeRecursion, DataFusionError> {
+        Ok(TreeNodeRecursion::Continue)
+    }
+
+    fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion, DataFusionError> {
+        if let Expr::BinaryExpr(binary) = expr {
+            match (&*binary.left, &binary.op, &*binary.right) {
+                // process simple binary expressions, e.g. col > 1
+                (Expr::Column(col), op, Expr::Literal(lit)) => {
+                    let col_pred = self.convert_column_expr(col, op, lit);
+                    self.stack.push_back(col_pred);
+                }
+                // process reversed binary expressions, e.g. 1 < col
+                (Expr::Literal(lit), op, Expr::Column(col)) => {
+                    let col_pred = op
+                        .swap()
+                        .and_then(|swapped_op| self.convert_column_expr(col, &swapped_op, lit));
+                    self.stack.push_back(col_pred);
+                }
+                // process compound expressions (logical operators such as AND or OR with child predicates)
+                (_left, op, _right) if op.is_logic_operator() => {
+                    let right_pred = self.stack.pop_back().flatten();
+                    let left_pred = self.stack.pop_back().flatten();
+                    let children: Vec<_> = [left_pred, right_pred].into_iter().flatten().collect();
+                    let compound_pred = self.convert_compound_expr(&children, op);
+                    self.stack.push_back(compound_pred);
+                }
+                _ => return Ok(TreeNodeRecursion::Continue),
+            }
+        }
+        Ok(TreeNodeRecursion::Continue)
+    }
+}
+
+const MILLIS_PER_DAY: i64 = 24 * 60 * 60 * 1000;
+/// Convert a scalar value to an iceberg datum.
+fn scalar_value_to_datum(value: &ScalarValue) -> Option<Datum> {
+    match value {
+        ScalarValue::Int8(Some(v)) => Some(Datum::int(*v as i32)),
+        ScalarValue::Int16(Some(v)) => Some(Datum::int(*v as i32)),
+        ScalarValue::Int32(Some(v)) => Some(Datum::int(*v)),
+        ScalarValue::Int64(Some(v)) => Some(Datum::long(*v)),
+        ScalarValue::Float32(Some(v)) => Some(Datum::double(*v as f64)),
+        ScalarValue::Float64(Some(v)) => Some(Datum::double(*v)),
+        ScalarValue::Utf8(Some(v)) => Some(Datum::string(v.clone())),
+        ScalarValue::LargeUtf8(Some(v)) => Some(Datum::string(v.clone())),
+        ScalarValue::Date32(Some(v)) => Some(Datum::date(*v)),
+        ScalarValue::Date64(Some(v)) => Some(Datum::date((*v / MILLIS_PER_DAY) as i32)),
+        _ => None,
+    }
+}
+
+/// Convert a DataFusion binary [`Operator`] on a column reference and datum to an iceberg [`Predicate`].
+fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate {
+    match op {
+        Operator::Eq => reference.equal_to(datum),
+        Operator::NotEq => reference.not_equal_to(datum),
+        Operator::Lt => reference.less_than(datum),
+        Operator::LtEq => reference.less_than_or_equal_to(datum),
+        Operator::Gt => reference.greater_than(datum),
+        Operator::GtEq => reference.greater_than_or_equal_to(datum),
+        _ => Predicate::AlwaysTrue,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::VecDeque;
+
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::common::tree_node::TreeNode;
+    use datafusion::common::DFSchema;
+    use datafusion::prelude::SessionContext;
+    use iceberg::expr::{Predicate, Reference};
+    use iceberg::spec::Datum;
+
+    use super::ExprToPredicateVisitor;
+
+    fn create_test_schema() -> DFSchema {
+        let arrow_schema = Schema::new(vec![
+            Field::new("foo", DataType::Int32, false),
+            Field::new("bar", DataType::Utf8, false),
+        ]);
+        DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap()
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_single_condition() {
+        let sql = "foo > 1";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        assert_eq!(
+            predicate,
+            Reference::new("foo").greater_than(Datum::long(1))
+        );
+    }
+    #[test]
+    fn test_predicate_conversion_with_single_unsupported_condition() {
+        let sql = "foo is null";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate();
+        assert_eq!(predicate, None);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_single_condition_rev() {
+        let sql = "1 < foo";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        assert_eq!(
+            predicate,
+            Reference::new("foo").greater_than(Datum::long(1))
+        );
+    }
+    #[test]
+    fn test_predicate_conversion_with_and_condition() {
+        let sql = "foo > 1 and bar = 'test'";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        let expected_predicate = Predicate::and(
+            Reference::new("foo").greater_than(Datum::long(1)),
+            Reference::new("bar").equal_to(Datum::string("test")),
+        );
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_and_condition_unsupported() {
+        let sql = "foo > 1 and bar is not null";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        let expected_predicate = Reference::new("foo").greater_than(Datum::long(1));
+        assert_eq!(predicate, expected_predicate);
+    }
+    #[test]
+    fn test_predicate_conversion_with_and_condition_both_unsupported() {
+        let sql = "foo in (1, 2, 3) and bar is not null";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate();
+        let expected_predicate = None;
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_or_condition_unsupported() {
+        let sql = "foo > 1 or bar is not null";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate();
+        let expected_predicate = None;
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_complex_binary_expr() {
+        let sql = "(foo > 1 and bar = 'test') or foo < 0 ";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        let inner_predicate = Predicate::and(
+            Reference::new("foo").greater_than(Datum::long(1)),
+            Reference::new("bar").equal_to(Datum::string("test")),
+        );
+        let expected_predicate = Predicate::or(
+            inner_predicate,
+            Reference::new("foo").less_than(Datum::long(0)),
+        );
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    #[test]
+    fn test_predicate_conversion_with_complex_binary_expr_unsupported() {
+        let sql = "(foo > 1 or bar in ('test', 'test2')) and foo < 0 ";
+        let df_schema = create_test_schema();
+        let expr = SessionContext::new()
+            .parse_sql_expr(sql, &df_schema)
+            .unwrap();
+        let mut visitor = ExprToPredicateVisitor::new();
+        expr.visit(&mut visitor).unwrap();
+        let predicate = visitor.get_predicate().unwrap();
+        let expected_predicate = Reference::new("foo").less_than(Datum::long(0));
+        assert_eq!(predicate, expected_predicate);
+    }
+
+    // Test that `get_predicate` AND-combines the convertible entries on the stack.
+    #[test]
+    fn test_get_result_multiple() {
+        let predicates = vec![
+            Some(Reference::new("foo").greater_than(Datum::long(1))),
+            None,
+            Some(Reference::new("bar").equal_to(Datum::string("test"))),
+        ];
+        let stack = VecDeque::from(predicates);
+        let visitor = ExprToPredicateVisitor { stack };
+        assert_eq!(
+            visitor.get_predicate(),
+            Some(Predicate::and(
+                Reference::new("foo").greater_than(Datum::long(1)),
+                Reference::new("bar").equal_to(Datum::string("test")),
+            ))
+        );
+    }
+
+    #[test]
+    fn test_get_result_single() {
+        let predicates = vec![Some(Reference::new("foo").greater_than(Datum::long(1)))];
+        let stack = VecDeque::from(predicates);
+        let visitor = ExprToPredicateVisitor { stack };
+        assert_eq!(
+            visitor.get_predicate(),
+            Some(Reference::new("foo").greater_than(Datum::long(1)))
+        );
+    }
+}
diff --git a/crates/integrations/datafusion/src/physical_plan/mod.rs b/crates/integrations/datafusion/src/physical_plan/mod.rs
index 5ae586a0a..2fab109d7 100644
--- a/crates/integrations/datafusion/src/physical_plan/mod.rs
+++ b/crates/integrations/datafusion/src/physical_plan/mod.rs
@@ -15,4 +15,5 @@
 // specific language governing permissions and limitations
 // under the License.
 
+pub(crate) mod expr_to_predicate;
 pub(crate) mod scan;
diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs
index 576acea6b..c53ce76d5 100644
--- a/crates/integrations/datafusion/src/physical_plan/scan.rs
+++ b/crates/integrations/datafusion/src/physical_plan/scan.rs
@@ -22,6 +22,7 @@ use std::vec;
 
 use datafusion::arrow::array::RecordBatch;
 use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
+use datafusion::common::tree_node::TreeNode;
 use datafusion::error::Result as DFResult;
 use datafusion::execution::{SendableRecordBatchStream, TaskContext};
 use datafusion::physical_expr::EquivalenceProperties;
@@ -29,9 +30,12 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{
     DisplayAs, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties,
 };
+use datafusion::prelude::Expr;
 use futures::{Stream, TryStreamExt};
+use iceberg::expr::Predicate;
 use iceberg::table::Table;
 
+use crate::physical_plan::expr_to_predicate::ExprToPredicateVisitor;
 use crate::to_datafusion_error;
 
 /// Manages the scanning process of an Iceberg [`Table`], encapsulating the
@@ -47,6 +51,8 @@ pub(crate) struct IcebergTableScan {
     plan_properties: PlanProperties,
     /// Projection column names, None means all columns
     projection: Option<Vec<String>>,
+    /// Filters to apply to the table scan
+    predicates: Option<Predicate>,
 }
 
 impl IcebergTableScan {
@@ -55,15 +61,18 @@ impl IcebergTableScan {
         table: Table,
         schema: ArrowSchemaRef,
         projection: Option<&Vec<usize>>,
+        filters: &[Expr],
     ) -> Self {
         let plan_properties = Self::compute_properties(schema.clone());
         let projection = get_column_names(schema.clone(), projection);
+        let predicates = convert_filters_to_predicate(filters);
 
         Self {
             table,
             schema,
             plan_properties,
             projection,
+            predicates,
         }
     }
 
@@ -109,7 +118,11 @@ impl ExecutionPlan for IcebergTableScan {
         _partition: usize,
         _context: Arc<TaskContext>,
     ) -> DFResult<SendableRecordBatchStream> {
-        let fut = get_batch_stream(self.table.clone(), self.projection.clone());
+        let fut = get_batch_stream(
+            self.table.clone(),
+            self.projection.clone(),
+            self.predicates.clone(),
+        );
         let stream = futures::stream::once(fut).try_flatten();
 
         Ok(Box::pin(RecordBatchStreamAdapter::new(
@@ -143,11 +156,15 @@ impl DisplayAs for IcebergTableScan {
 async fn get_batch_stream(
     table: Table,
     column_names: Option<Vec<String>>,
+    predicates: Option<Predicate>,
 ) -> DFResult<Pin<Box<dyn Stream<Item = DFResult<RecordBatch>> + Send>>> {
-    let scan_builder = match column_names {
+    let mut scan_builder = match column_names {
         Some(column_names) => table.scan().select(column_names),
         None => table.scan().select_all(),
     };
+    if let Some(pred) = predicates {
+        scan_builder = scan_builder.with_filter(pred);
+    }
     let table_scan = scan_builder.build().map_err(to_datafusion_error)?;
 
     let stream = table_scan
@@ -155,10 +172,25 @@ async fn get_batch_stream(
         .plan_files()
         .await
         .map_err(to_datafusion_error)?
         .map_err(to_datafusion_error);
-
     Ok(Box::pin(stream))
 }
+/// Converts DataFusion filters ([`Expr`]) to an iceberg [`Predicate`].
+/// Returns `None` if none of the filters can be converted, in which case no predicate is applied to the scan.
+/// Otherwise, returns the successfully converted filters combined with the `AND` operator.
+fn convert_filters_to_predicate(filters: &[Expr]) -> Option<Predicate> {
+    filters
+        .iter()
+        .filter_map(|expr| {
+            let mut visitor = ExprToPredicateVisitor::new();
+            if expr.visit(&mut visitor).is_ok() {
+                visitor.get_predicate()
+            } else {
+                None
+            }
+        })
+        .reduce(Predicate::and)
+}
 fn get_column_names(
     schema: ArrowSchemaRef,
     projection: Option<&Vec<usize>>,
diff --git a/crates/integrations/datafusion/src/table.rs b/crates/integrations/datafusion/src/table.rs
index 8d70d9488..016c6c00f 100644
--- a/crates/integrations/datafusion/src/table.rs
+++ b/crates/integrations/datafusion/src/table.rs
@@ -23,7 +23,7 @@ use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
 use datafusion::catalog::Session;
 use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::Result as DFResult;
-use datafusion::logical_expr::Expr;
+use datafusion::logical_expr::{BinaryExpr, Expr, TableProviderFilterPushDown};
 use datafusion::physical_plan::ExecutionPlan;
 use iceberg::arrow::schema_to_arrow_schema;
 use iceberg::table::Table;
@@ -76,13 +76,30 @@ impl TableProvider for IcebergTableProvider {
         &self,
         _state: &dyn Session,
         projection: Option<&Vec<usize>>,
-        _filters: &[Expr],
+        filters: &[Expr],
         _limit: Option<usize>,
     ) -> DFResult<Arc<dyn ExecutionPlan>> {
         Ok(Arc::new(IcebergTableScan::new(
             self.table.clone(),
             self.schema.clone(),
             projection,
+            filters,
         )))
     }
+
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> std::result::Result<Vec<TableProviderFilterPushDown>, datafusion::error::DataFusionError>
+    {
+        let filter_support = filters
+            .iter()
+            .map(|e| match e {
+                Expr::BinaryExpr(BinaryExpr { .. }) => TableProviderFilterPushDown::Inexact,
+                _ => TableProviderFilterPushDown::Unsupported,
+            })
+            .collect::<Vec<TableProviderFilterPushDown>>();
+
+        Ok(filter_support)
+    }
 }
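Reviewer note (not part of the patch): a minimal sketch of how a SQL filter handed to `TableProvider::scan` ends up as an Iceberg predicate, mirroring `test_predicate_conversion_with_and_condition_unsupported` above. It is written as it would sit inside the `iceberg-datafusion` crate, since `ExprToPredicateVisitor` is `pub(crate)`; the schema, table name, test name, and SQL text are illustrative only.

```rust
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::tree_node::TreeNode;
use datafusion::common::DFSchema;
use datafusion::prelude::SessionContext;
use iceberg::expr::Reference;
use iceberg::spec::Datum;

use crate::physical_plan::expr_to_predicate::ExprToPredicateVisitor;

#[test]
fn sql_filter_becomes_iceberg_predicate() {
    // Hypothetical table shape: an integer `foo` and a string `bar`.
    let arrow_schema = Schema::new(vec![
        Field::new("foo", DataType::Int32, false),
        Field::new("bar", DataType::Utf8, false),
    ]);
    let df_schema = DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap();

    // Parse a SQL filter into a DataFusion `Expr`, then walk it bottom-up with the
    // visitor, exactly as `convert_filters_to_predicate` does for each scan filter.
    let expr = SessionContext::new()
        .parse_sql_expr("foo > 1 AND bar IS NOT NULL", &df_schema)
        .unwrap();
    let mut visitor = ExprToPredicateVisitor::new();
    expr.visit(&mut visitor).unwrap();

    // `bar IS NOT NULL` cannot be expressed, so it is dropped; only `foo > 1` is pushed down.
    assert_eq!(
        visitor.get_predicate(),
        Some(Reference::new("foo").greater_than(Datum::long(1)))
    );
}
```

Correctness does not depend on the dropped term: because `supports_filters_pushdown` reports `Inexact`, DataFusion keeps its own filter on top of the scan and re-applies the full expression, so the pushed-down predicate only reduces the data Iceberg has to read.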
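A second sketch, covering the date handling added in `get_arrow_datum` and `scalar_value_to_datum`: Arrow `Date32` values already count days since the Unix epoch and map to an Iceberg date datum directly, while `Date64` values count milliseconds and are divided by `MILLIS_PER_DAY` first. The constant is copied from `expr_to_predicate.rs`; the day value is illustrative.

```rust
use iceberg::spec::Datum;

// Mirrors the constant defined in expr_to_predicate.rs.
const MILLIS_PER_DAY: i64 = 24 * 60 * 60 * 1000;

fn main() {
    // A Date32-style value: days since 1970-01-01 (illustrative number).
    let days: i32 = 19_000;
    let from_date32 = Datum::date(days);

    // The same instant as a Date64-style value: milliseconds since the epoch,
    // converted to days before building the datum, as scalar_value_to_datum does.
    let millis: i64 = i64::from(days) * MILLIS_PER_DAY;
    let from_date64 = Datum::date((millis / MILLIS_PER_DAY) as i32);

    // Both paths produce the same Iceberg date datum.
    assert_eq!(from_date32, from_date64);
    println!("{:?} == {:?}", from_date32, from_date64);
}
```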