diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index e9c368d3a41ba..7e9616483c09b 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -37,10 +37,13 @@ use crate::execution::filter::FilterRelation;
 use crate::execution::limit::LimitRelation;
 use crate::execution::physical_plan::common;
 use crate::execution::physical_plan::datasource::DatasourceExec;
-use crate::execution::physical_plan::expressions::{Column, Sum};
+use crate::execution::physical_plan::expressions::{
+    BinaryExpr, CastExpr, Column, Literal, Sum,
+};
 use crate::execution::physical_plan::hash_aggregate::HashAggregateExec;
 use crate::execution::physical_plan::merge::MergeExec;
 use crate::execution::physical_plan::projection::ProjectionExec;
+use crate::execution::physical_plan::selection::SelectionExec;
 use crate::execution::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr};
 use crate::execution::projection::ProjectRelation;
 use crate::execution::relation::{DataSourceRelation, Relation};
@@ -280,6 +283,12 @@ impl ExecutionContext {
                     schema.clone(),
                 )?))
             }
+            LogicalPlan::Selection { input, expr, .. } => {
+                let input = self.create_physical_plan(input, batch_size)?;
+                let input_schema = input.as_ref().schema().clone();
+                let runtime_expr = self.create_physical_expr(expr, &input_schema)?;
+                Ok(Arc::new(SelectionExec::try_new(runtime_expr, input)?))
+            }
             _ => Err(ExecutionError::General(
                 "Unsupported logical plan variant".to_string(),
             )),
@@ -290,13 +299,25 @@ impl ExecutionContext {
     pub fn create_physical_expr(
         &self,
         e: &Expr,
-        _input_schema: &Schema,
+        input_schema: &Schema,
     ) -> Result<Arc<dyn PhysicalExpr>> {
         match e {
             Expr::Column(i) => Ok(Arc::new(Column::new(*i))),
-            _ => Err(ExecutionError::NotImplemented(
-                "Unsupported expression".to_string(),
-            )),
+            Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
+            Expr::BinaryExpr { left, op, right } => Ok(Arc::new(BinaryExpr::new(
+                self.create_physical_expr(left, input_schema)?,
+                op.clone(),
+                self.create_physical_expr(right, input_schema)?,
+            ))),
+            Expr::Cast { expr, data_type } => Ok(Arc::new(CastExpr::try_new(
+                self.create_physical_expr(expr, input_schema)?,
+                input_schema,
+                data_type.clone(),
+            )?)),
+            other => Err(ExecutionError::NotImplemented(format!(
+                "Physical plan does not support logical expression {:?}",
+                other
+            ))),
         }
     }
 
@@ -569,6 +590,29 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn parallel_selection() -> Result<()> {
+        let tmp_dir = TempDir::new("parallel_selection")?;
+        let partition_count = 4;
+        let mut ctx = create_ctx(&tmp_dir, partition_count)?;
+
+        let logical_plan =
+            ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")?;
+        let logical_plan = ctx.optimize(&logical_plan)?;
+
+        let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?;
+
+        let results = ctx.collect(physical_plan.as_ref())?;
+
+        // there should be one batch per partition
+        assert_eq!(results.len(), partition_count);
+
+        let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum();
+        assert_eq!(row_count, 20);
+
+        Ok(())
+    }
+
     #[test]
     fn aggregate() -> Result<()> {
         let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4)?;
diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs
index eb53392ad4d05..f0c34c228db02 100644
--- a/rust/datafusion/src/execution/physical_plan/mod.rs
+++ b/rust/datafusion/src/execution/physical_plan/mod.rs
@@ -88,3 +88,4 @@ pub mod expressions;
 pub mod hash_aggregate;
 pub mod merge;
 pub mod projection;
+pub mod selection;
diff --git a/rust/datafusion/src/execution/physical_plan/selection.rs b/rust/datafusion/src/execution/physical_plan/selection.rs
new file mode 100644
index 0000000000000..ca5e58e3e218e
--- /dev/null
+++ b/rust/datafusion/src/execution/physical_plan/selection.rs
@@ -0,0 +1,188 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the selection execution plan. A selection filters rows based on a predicate
+
+use std::sync::{Arc, Mutex};
+
+use crate::error::{ExecutionError, Result};
+use crate::execution::physical_plan::{
+    BatchIterator, ExecutionPlan, Partition, PhysicalExpr,
+};
+use arrow::array::BooleanArray;
+use arrow::compute::filter;
+use arrow::datatypes::Schema;
+use arrow::record_batch::RecordBatch;
+
+/// Execution plan for a Selection
+///
+/// Each partition of the input plan is wrapped in a partition that evaluates the
+/// predicate against every batch and keeps only the rows for which it is true.
+pub struct SelectionExec {
+    /// The selection predicate expression
+    expr: Arc<dyn PhysicalExpr>,
+    /// The input plan
+    input: Arc<dyn ExecutionPlan>,
+}
+
+impl SelectionExec {
+    /// Create a selection on an input
+    pub fn try_new(
+        expr: Arc<dyn PhysicalExpr>,
+        input: Arc<dyn ExecutionPlan>,
+    ) -> Result<Self> {
+        // the arguments are already owned, so move them rather than cloning the Arcs
+        Ok(Self { expr, input })
+    }
+}
+
+impl ExecutionPlan for SelectionExec {
+    /// Get the schema for this execution plan
+    fn schema(&self) -> Arc<Schema> {
+        // The selection operator does not make any changes to the schema of its input
+        self.input.schema()
+    }
+
+    /// Get the partitions for this execution plan
+    fn partitions(&self) -> Result<Vec<Arc<dyn Partition>>> {
+        let partitions: Vec<Arc<dyn Partition>> = self
+            .input
+            .partitions()?
+            .iter()
+            .map(|p| {
+                let expr = self.expr.clone();
+                let partition: Arc<dyn Partition> = Arc::new(SelectionPartition {
+                    schema: self.input.schema(),
+                    expr,
+                    input: p.clone() as Arc<dyn Partition>,
+                });
+
+                partition
+            })
+            .collect();
+
+        Ok(partitions)
+    }
+}
+
+/// Represents a single partition of a Selection execution plan
+struct SelectionPartition {
+    schema: Arc<Schema>,
+    expr: Arc<dyn PhysicalExpr>,
+    input: Arc<dyn Partition>,
+}
+
+impl Partition for SelectionPartition {
+    /// Execute the Selection
+    fn execute(&self) -> Result<Arc<Mutex<dyn BatchIterator>>> {
+        Ok(Arc::new(Mutex::new(SelectionIterator {
+            schema: self.schema.clone(),
+            expr: self.expr.clone(),
+            input: self.input.execute()?,
+        })))
+    }
+}
+
+/// Selection iterator
+struct SelectionIterator {
+    schema: Arc<Schema>,
+    expr: Arc<dyn PhysicalExpr>,
+    input: Arc<Mutex<dyn BatchIterator>>,
+}
+
+impl BatchIterator for SelectionIterator {
+    /// Get the schema
+    fn schema(&self) -> Arc<Schema> {
+        self.schema.clone()
+    }
+
+    /// Get the next batch
+    ///
+    /// Returns `Ok(None)` once the input is exhausted, and an `InternalError` if the
+    /// predicate does not evaluate to a boolean array.
+    fn next(&mut self) -> Result<Option<RecordBatch>> {
+        let mut input = self.input.lock().unwrap();
+        match input.next()? {
+            Some(batch) => {
+                // evaluate the selection predicate to get a boolean array
+                let predicate_result = self.expr.evaluate(&batch)?;
+
+                if let Some(f) = predicate_result.as_any().downcast_ref::<BooleanArray>()
+                {
+                    // filter each array
+                    let mut filtered_arrays = Vec::with_capacity(batch.num_columns());
+                    for i in 0..batch.num_columns() {
+                        let array = batch.column(i);
+                        let filtered_array = filter(array.as_ref(), f)?;
+                        filtered_arrays.push(filtered_array);
+                    }
+                    Ok(Some(RecordBatch::try_new(
+                        batch.schema().clone(),
+                        filtered_arrays,
+                    )?))
+                } else {
+                    Err(ExecutionError::InternalError(
+                        "Predicate evaluated to non-boolean value".to_string(),
+                    ))
+                }
+            }
+            None => Ok(None),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use crate::execution::physical_plan::csv::CsvExec;
+    use crate::execution::physical_plan::expressions::*;
+    use crate::execution::physical_plan::ExecutionPlan;
+    use crate::logicalplan::{Operator, ScalarValue};
+    use crate::test;
+    use std::iter::Iterator;
+
+    #[test]
+    fn simple_predicate() -> Result<()> {
+        let schema = test::aggr_test_schema();
+
+        let partitions = 4;
+        let path = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;
+
+        let csv = CsvExec::try_new(&path, schema, true, None, 1024)?;
+
+        let predicate: Arc<dyn PhysicalExpr> = binary(
+            binary(col(1), Operator::Gt, lit(ScalarValue::UInt32(1))),
+            Operator::And,
+            binary(col(1), Operator::Lt, lit(ScalarValue::UInt32(4))),
+        );
+
+        let selection: Arc<dyn ExecutionPlan> =
+            Arc::new(SelectionExec::try_new(predicate, Arc::new(csv))?);
+
+        let results = test::execute(selection.as_ref())?;
+
+        results
+            .iter()
+            .for_each(|batch| assert_eq!(13, batch.num_columns()));
+        let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum();
+        assert_eq!(41, row_count);
+
+        Ok(())
+    }
+
+}