From 3c1e4c0fd23012afa00dfd05a0571b61023b3d21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Wed, 1 Mar 2023 22:29:36 +0300 Subject: [PATCH] Support for Sliding Windows Joins with Symmetric Hash Join (SHJ) (#5322) * Prunable symmetric hash join implementation * Minor changes after merge * Filter mapping inside SymmetricHashJoin * Commenting on * Minor changes after merge * Simplify interval arithmetic library code * Make the interval arithmetics library more robust * After merge corrections * Simplifications to constraint propagation code * Revamp some APIs and enhance comments - Utilize estimate_bounds without propagation for better API. - Remove coupling between node_index & PhysicalExpr pairing and graph. - Better commenting on symmetric hash join while using graph * Resolve a propagation bug and make the propagation return an opt. status * Refactor and simplify CP code, improve comments * Code deduplication between pipeline fixer and utils, also enhance comments. * Refactor on input stream consumer on SymmetricHashJoin * After merge resolution, before proto update * Revert unnecessary changes in some exprs Also, cargo.lock update. 
* Remove support indicators to interval library, rename module to use the standard name * Simplify PipelineFixer, remove clones, improve comments * Enhance the symmetric hash join code with reviews * Revamp according to reviews * Use a simple, stateless, one-liner DFS to check for IA support * Move test function to a test_utils module * Simplify DAG creation code * Reducing code change * Comment improvements and simplifications * Revamp SortedFilterExpr usage and enhance comments * Update fifo.rs * Remove unnecessary clones, improve comments and code structure * Remove leaf searches from CP iterations, improve code organization/comments * Bug fix in cp_solver, revamp some comments * Update with correct testing * Test for future support on fuzzy matches between exprs * Compute connected nodes in CP solver via a DFS, improve comments * Revamp OneSideHashJoin constructor and new unit test * Update on concat_batches usage * Revamping according to comments. * Simplifications, refactoring * Minor fix * Fix typo in the new_zero function --------- Co-authored-by: Mehmet Ozan Kabak --- datafusion-cli/Cargo.lock | 17 + datafusion/common/src/scalar.rs | 23 + datafusion/core/src/execution/context.rs | 3 + .../physical_optimizer/pipeline_checker.rs | 4 +- .../src/physical_optimizer/pipeline_fixer.rs | 173 +- .../core/src/physical_optimizer/pruning.rs | 9 +- .../core/src/physical_plan/joins/hash_join.rs | 177 +- .../physical_plan/joins/hash_join_utils.rs | 577 ++++ .../core/src/physical_plan/joins/mod.rs | 16 +- .../physical_plan/joins/nested_loop_join.rs | 5 +- .../joins/symmetric_hash_join.rs | 2473 +++++++++++++++++ .../core/src/physical_plan/joins/utils.rs | 121 +- datafusion/core/src/physical_plan/memory.rs | 14 +- datafusion/core/src/test_util.rs | 43 + datafusion/core/tests/fifo.rs | 122 +- datafusion/expr/src/operator.rs | 29 + datafusion/physical-expr/Cargo.toml | 2 + .../physical-expr/src/expressions/binary.rs | 45 + .../physical-expr/src/expressions/cast.rs | 23 + 
.../physical-expr/src/intervals/cp_solver.rs | 1038 +++++++ .../src/intervals/interval_aritmetic.rs | 533 ++++ datafusion/physical-expr/src/intervals/mod.rs | 26 + .../physical-expr/src/intervals/test_utils.rs | 67 + datafusion/physical-expr/src/lib.rs | 1 + datafusion/physical-expr/src/physical_expr.rs | 22 + datafusion/physical-expr/src/utils.rs | 291 +- 26 files changed, 5653 insertions(+), 201 deletions(-) create mode 100644 datafusion/core/src/physical_plan/joins/hash_join_utils.rs create mode 100644 datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs create mode 100644 datafusion/physical-expr/src/intervals/cp_solver.rs create mode 100644 datafusion/physical-expr/src/intervals/interval_aritmetic.rs create mode 100644 datafusion/physical-expr/src/intervals/mod.rs create mode 100644 datafusion/physical-expr/src/intervals/test_utils.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index e83c77d7a15d..5e40db04560f 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -793,6 +793,7 @@ dependencies = [ "md-5", "num-traits", "paste", + "petgraph", "rand", "regex", "sha2", @@ -964,6 +965,12 @@ dependencies = [ "windows-sys 0.45.0", ] +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flatbuffers" version = "23.1.21" @@ -1801,6 +1808,16 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" +[[package]] +name = "petgraph" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5014253a1331579ce62aa67443b4a658c5e7dd03d4bc6d302b94474888143" +dependencies = [ + "fixedbitset", + "indexmap", +] + [[package]] name = "pin-project-lite" version = "0.2.9" diff --git a/datafusion/common/src/scalar.rs 
b/datafusion/common/src/scalar.rs index 66c1f3f12552..40d1eb25e07a 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1019,6 +1019,29 @@ impl ScalarValue { Self::List(scalars, Box::new(Field::new("item", child_type, true))) } + // Create a zero value in the given type. + pub fn new_zero(datatype: &DataType) -> Result { + assert!(datatype.is_primitive()); + Ok(match datatype { + DataType::Boolean => ScalarValue::Boolean(Some(false)), + DataType::Int8 => ScalarValue::Int8(Some(0)), + DataType::Int16 => ScalarValue::Int16(Some(0)), + DataType::Int32 => ScalarValue::Int32(Some(0)), + DataType::Int64 => ScalarValue::Int64(Some(0)), + DataType::UInt8 => ScalarValue::UInt8(Some(0)), + DataType::UInt16 => ScalarValue::UInt16(Some(0)), + DataType::UInt32 => ScalarValue::UInt32(Some(0)), + DataType::UInt64 => ScalarValue::UInt64(Some(0)), + DataType::Float32 => ScalarValue::Float32(Some(0.0)), + DataType::Float64 => ScalarValue::Float64(Some(0.0)), + _ => { + return Err(DataFusionError::NotImplemented(format!( + "Can't create a zero scalar from data_type \"{datatype:?}\"" + ))); + } + }) + } + /// Getter for the `DataType` of the value pub fn get_datatype(&self) -> DataType { match self { diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c7f56733a501..5b39e54dbe86 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -1584,6 +1584,9 @@ impl SessionState { // repartitioning and local sorting steps to meet distribution and ordering requirements. // Therefore, it should run before EnforceDistribution and EnforceSorting. Arc::new(JoinSelection::new()), + // Enforce sort before PipelineFixer + Arc::new(EnforceDistribution::new()), + Arc::new(EnforceSorting::new()), // If the query is processing infinite inputs, the PipelineFixer rule applies the // necessary transformations to make the query runnable (if it is not already runnable). 
// If the query can not be made runnable, the rule emits an error with a diagnostic message. diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 96f0b0ff6932..8a6b0e003adf 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -305,7 +305,7 @@ mod sql_tests { FROM test LIMIT 5".to_string(), cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "Window Error".to_string() + error_operator: "Sort Error".to_string() }; case.run().await?; @@ -328,7 +328,7 @@ mod sql_tests { SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 FROM test".to_string(), cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "Window Error".to_string() + error_operator: "Sort Error".to_string() }; case.run().await?; Ok(()) diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs index 1ca21bb88b7f..99bcf264fadd 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs @@ -29,11 +29,19 @@ use crate::physical_optimizer::pipeline_checker::{ check_finiteness_requirements, PipelineStatePropagator, }; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::joins::{HashJoinExec, PartitionMode}; +use crate::physical_plan::joins::utils::JoinSide; +use crate::physical_plan::joins::{ + convert_sort_expr_with_filter_schema, HashJoinExec, PartitionMode, + SymmetricHashJoinExec, +}; use crate::physical_plan::rewrite::TreeNodeRewritable; use crate::physical_plan::ExecutionPlan; use datafusion_common::DataFusionError; use datafusion_expr::logical_plan::JoinType; +use datafusion_physical_expr::expressions::{BinaryExpr, CastExpr, Column, Literal}; +use datafusion_physical_expr::intervals::{is_datatype_supported, 
is_operator_supported}; +use datafusion_physical_expr::PhysicalExpr; + use std::sync::Arc; /// The [PipelineFixer] rule tries to modify a given plan so that it can @@ -48,8 +56,13 @@ impl PipelineFixer { Self {} } } +/// [PipelineFixer] subrules are functions of this type. Such functions take a +/// single [PipelineStatePropagator] argument, which stores state variables +/// indicating the unboundedness status of the current [ExecutionPlan] as +/// the [PipelineFixer] rule traverses the entire plan tree. type PipelineFixerSubrule = - dyn Fn(&PipelineStatePropagator) -> Option>; + dyn Fn(PipelineStatePropagator) -> Option>; + impl PhysicalOptimizerRule for PipelineFixer { fn optimize( &self, @@ -57,8 +70,10 @@ impl PhysicalOptimizerRule for PipelineFixer { _config: &ConfigOptions, ) -> Result> { let pipeline = PipelineStatePropagator::new(plan); - let physical_optimizer_subrules: Vec> = - vec![Box::new(hash_join_swap_subrule)]; + let physical_optimizer_subrules: Vec> = vec![ + Box::new(hash_join_convert_symmetric_subrule), + Box::new(hash_join_swap_subrule), + ]; let state = pipeline.transform_up(&|p| { apply_subrules_and_check_finiteness_requirements( p, @@ -77,6 +92,104 @@ impl PhysicalOptimizerRule for PipelineFixer { } } +/// Indicates whether interval arithmetic is supported for the given expression. +/// Currently, we do not support all [PhysicalExpr]s for interval calculations. +/// We do not support every type of [Operator]s either. Over time, this check +/// will relax as more types of [PhysicalExpr]s and [Operator]s are supported. +/// Currently, [CastExpr], [BinaryExpr], [Column] and [Literal] is supported. 
+fn check_support(expr: &Arc) -> bool { + let expr_any = expr.as_any(); + let expr_supported = if let Some(binary_expr) = expr_any.downcast_ref::() + { + is_operator_supported(binary_expr.op()) + } else { + expr_any.is::() || expr_any.is::() || expr_any.is::() + }; + expr_supported && expr.children().iter().all(check_support) +} + +/// This function returns whether a given hash join is replaceable by a +/// symmetric hash join. Basically, the requirement is that involved +/// [PhysicalExpr]s, [Operator]s and data types need to be supported, +/// and order information must cover every column in the filter expression. +fn is_suitable_for_symmetric_hash_join(hash_join: &HashJoinExec) -> Result { + if let Some(filter) = hash_join.filter() { + let left = hash_join.left(); + if let Some(left_ordering) = left.output_ordering() { + let right = hash_join.right(); + if let Some(right_ordering) = right.output_ordering() { + let expr_supported = check_support(filter.expression()); + let left_convertible = convert_sort_expr_with_filter_schema( + &JoinSide::Left, + filter, + &left.schema(), + &left_ordering[0], + )? + .is_some(); + let right_convertible = convert_sort_expr_with_filter_schema( + &JoinSide::Right, + filter, + &right.schema(), + &right_ordering[0], + )? + .is_some(); + let fields_supported = filter + .schema() + .fields() + .iter() + .all(|f| is_datatype_supported(f.data_type())); + return Ok(expr_supported + && fields_supported + && left_convertible + && right_convertible); + } + } + } + Ok(false) +} + +/// This subrule checks if one can replace a hash join with a symmetric hash +/// join so that the pipeline does not break due to the join operation in +/// question. If possible, it makes this replacement; otherwise, it has no +/// effect. 
+fn hash_join_convert_symmetric_subrule( + input: PipelineStatePropagator, +) -> Option> { + let plan = input.plan; + if let Some(hash_join) = plan.as_any().downcast_ref::() { + let ub_flags = input.children_unbounded; + let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); + let new_plan = if left_unbounded && right_unbounded { + match is_suitable_for_symmetric_hash_join(hash_join) { + Ok(true) => SymmetricHashJoinExec::try_new( + hash_join.left().clone(), + hash_join.right().clone(), + hash_join + .on() + .iter() + .map(|(l, r)| (l.clone(), r.clone())) + .collect(), + hash_join.filter().unwrap().clone(), + hash_join.join_type(), + hash_join.null_equals_null(), + ) + .map(|e| Arc::new(e) as _), + Ok(false) => Ok(plan), + Err(e) => return Some(Err(e)), + } + } else { + Ok(plan) + }; + Some(new_plan.map(|plan| PipelineStatePropagator { + plan, + unbounded: left_unbounded || right_unbounded, + children_unbounded: ub_flags, + })) + } else { + None + } +} + /// This subrule will swap build/probe sides of a hash join depending on whether its inputs /// may produce an infinite stream of records. The rule ensures that the left (build) side /// of the hash join always operates on an input stream that will produce a finite set of. 
@@ -119,12 +232,12 @@ impl PhysicalOptimizerRule for PipelineFixer { /// /// ``` fn hash_join_swap_subrule( - input: &PipelineStatePropagator, + input: PipelineStatePropagator, ) -> Option> { - let plan = input.plan.clone(); - let children = &input.children_unbounded; + let plan = input.plan; if let Some(hash_join) = plan.as_any().downcast_ref::() { - let (left_unbounded, right_unbounded) = (children[0], children[1]); + let ub_flags = input.children_unbounded; + let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); let new_plan = if left_unbounded && !right_unbounded { if matches!( *hash_join.join_type(), @@ -140,12 +253,11 @@ fn hash_join_swap_subrule( } else { Ok(plan) }; - let new_state = new_plan.map(|plan| PipelineStatePropagator { + Some(new_plan.map(|plan| PipelineStatePropagator { plan, unbounded: left_unbounded || right_unbounded, - children_unbounded: vec![left_unbounded, right_unbounded], - }); - Some(new_state) + children_unbounded: ub_flags, + })) } else { None } @@ -182,13 +294,46 @@ fn apply_subrules_and_check_finiteness_requirements( physical_optimizer_subrules: &Vec>, ) -> Result> { for sub_rule in physical_optimizer_subrules { - if let Some(value) = sub_rule(&input).transpose()? { + if let Some(value) = sub_rule(input.clone()).transpose()? 
{ input = value; } } check_finiteness_requirements(input) } +#[cfg(test)] +mod util_tests { + use crate::physical_optimizer::pipeline_fixer::check_support; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, Column, NegativeExpr}; + use datafusion_physical_expr::PhysicalExpr; + use std::sync::Arc; + + #[test] + fn check_expr_supported() { + let supported_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("a", 0)), + )) as Arc; + assert!(check_support(&supported_expr)); + let supported_expr_2 = Arc::new(Column::new("a", 0)) as Arc; + assert!(check_support(&supported_expr_2)); + let unsupported_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + Arc::new(Column::new("a", 0)), + )) as Arc; + assert!(!check_support(&unsupported_expr)); + let unsupported_expr_2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Or, + Arc::new(NegativeExpr::new(Arc::new(Column::new("a", 0)))), + )) as Arc; + assert!(!check_support(&unsupported_expr_2)); + } +} + #[cfg(test)] mod hash_join_tests { use super::*; @@ -574,7 +719,7 @@ mod hash_join_tests { children_unbounded: vec![left_unbounded, right_unbounded], }; let optimized_hash_join = - hash_join_swap_subrule(&initial_hash_join_state).unwrap()?; + hash_join_swap_subrule(initial_hash_join_state).unwrap()?; let optimized_join_plan = optimized_hash_join.plan; // If swap did happen diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index fbf000148dcc..80b72e68f6ea 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -46,10 +46,9 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; - use datafusion_physical_expr::rewrite::{TreeNodeRewritable, 
TreeNodeRewriter}; -use datafusion_physical_expr::utils::get_phys_expr_columns; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; use log::trace; /// Interface to pass statistics information to [`PruningPredicate`] @@ -447,8 +446,8 @@ impl<'a> PruningExpressionBuilder<'a> { required_columns: &'a mut RequiredStatColumns, ) -> Result { // find column name; input could be a more complicated expression - let left_columns = get_phys_expr_columns(left); - let right_columns = get_phys_expr_columns(right); + let left_columns = collect_columns(left); + let right_columns = collect_columns(right); let (column_expr, scalar_expr, columns, correct_operator) = match (left_columns.len(), right_columns.len()) { (1, 0) => (left, right, left_columns, op), diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs index 2c339996480c..f30e18aa78b8 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join.rs @@ -84,17 +84,17 @@ use super::{ }; use crate::physical_plan::joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, - get_final_indices_from_bit_map, need_produce_result_in_final, + get_final_indices_from_bit_map, need_produce_result_in_final, JoinSide, }; use log::debug; use std::fmt; use std::task::Poll; -// Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value. +// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. // // Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used // to put the indices in a certain bucket. 
-// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the left side, +// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, // we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. // E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 // As the key is a hash value, we need to check possible hash collisions in the probe stage @@ -102,7 +102,7 @@ use std::task::Poll; // but the values don't match. Those are checked in the [equal_rows] macro // TODO: speed up collision check and move away from using a hashbrown HashMap // https://github.com/apache/arrow-datafusion/issues/50 -struct JoinHashMap(RawTable<(u64, SmallVec<[u64; 1]>)>); +pub struct JoinHashMap(pub RawTable<(u64, SmallVec<[u64; 1]>)>); impl fmt::Debug for JoinHashMap { fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { @@ -310,7 +310,7 @@ impl ExecutionPlan for HashJoinExec { /// Specifies whether this plan generates an infinite stream of records. /// If the plan does not support pipelining, but it its input(s) are - /// infinite, returns an error to indicate this. + /// infinite, returns an error to indicate this. 
fn unbounded_output(&self, children: &[bool]) -> Result { let (left, right) = (children[0], children[1]); // If left is unbounded, or right is unbounded with JoinType::Right, @@ -609,7 +609,7 @@ async fn partitioned_left_input( /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, /// assuming that the [RecordBatch] corresponds to the `index`th -fn update_hash( +pub fn update_hash( on: &[Column], batch: &RecordBatch, hash_map: &mut JoinHashMap, @@ -680,41 +680,50 @@ impl RecordBatchStream for HashJoinStream { } } -// Get left and right indices which is satisfies the on condition (include equal_conditon and filter_in_join) in the Join +/// Gets build and probe indices which satisfy the on condition (including +/// the equality condition and the join filter) in the join. #[allow(clippy::too_many_arguments)] -fn build_join_indices( - batch: &RecordBatch, - left_data: &JoinLeftData, - on_left: &[Column], - on_right: &[Column], +pub fn build_join_indices( + probe_batch: &RecordBatch, + build_hashmap: &JoinHashMap, + build_input_buffer: &RecordBatch, + on_build: &[Column], + on_probe: &[Column], filter: Option<&JoinFilter>, random_state: &RandomState, null_equals_null: &bool, + hashes_buffer: &mut Vec, + offset: Option, + build_side: JoinSide, ) -> Result<(UInt64Array, UInt32Array)> { - // Get the indices which is satisfies the equal join condition, like `left.a1 = right.a2` - let (left_indices, right_indices) = build_equal_condition_join_indices( - left_data, - batch, - on_left, - on_right, + // Get the indices that satisfy the equality condition, like `left.a1 = right.a2` + let (build_indices, probe_indices) = build_equal_condition_join_indices( + build_hashmap, + build_input_buffer, + probe_batch, + on_build, + on_probe, random_state, null_equals_null, + hashes_buffer, + offset, )?; if let Some(filter) = filter { - // Filter the indices which is satisfies the non-equal join condition, like `left.b1 = 10` + // Filter the indices 
which satisfy the non-equal join condition, like `left.b1 = 10` apply_join_filter_to_indices( - &left_data.1, - batch, - left_indices, - right_indices, + build_input_buffer, + probe_batch, + build_indices, + probe_indices, filter, + build_side, ) } else { - Ok((left_indices, right_indices)) + Ok((build_indices, probe_indices)) } } -// Returns the index of equal condition join result: left_indices and right_indices +// Returns build/probe indices satisfying the equality condition. // On LEFT.b1 = RIGHT.b2 // LEFT Table: // a1 b1 c1 @@ -742,71 +751,79 @@ fn build_join_indices( // "| 13 | 10 | 130 | 12 | 10 | 120 |", // "| 9 | 8 | 90 | 8 | 8 | 80 |", // "+----+----+-----+----+----+-----+" -// And the result of left and right indices -// left indices: 5, 6, 6, 4 -// right indices: 3, 4, 5, 3 -fn build_equal_condition_join_indices( - left_data: &JoinLeftData, - right: &RecordBatch, - left_on: &[Column], - right_on: &[Column], +// And the result of build and probe indices are: +// Build indices: 5, 6, 6, 4 +// Probe indices: 3, 4, 5, 3 +#[allow(clippy::too_many_arguments)] +pub fn build_equal_condition_join_indices( + build_hashmap: &JoinHashMap, + build_input_buffer: &RecordBatch, + probe_batch: &RecordBatch, + build_on: &[Column], + probe_on: &[Column], random_state: &RandomState, null_equals_null: &bool, + hashes_buffer: &mut Vec, + offset: Option, ) -> Result<(UInt64Array, UInt32Array)> { - let keys_values = right_on + let keys_values = probe_on .iter() - .map(|c| Ok(c.evaluate(right)?.into_array(right.num_rows()))) + .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) .collect::>>()?; - let left_join_values = left_on + let build_join_values = build_on .iter() - .map(|c| Ok(c.evaluate(&left_data.1)?.into_array(left_data.1.num_rows()))) + .map(|c| { + Ok(c.evaluate(build_input_buffer)? 
+ .into_array(build_input_buffer.num_rows())) + }) .collect::>>()?; - let hashes_buffer = &mut vec![0; keys_values[0].len()]; + hashes_buffer.clear(); + hashes_buffer.resize(probe_batch.num_rows(), 0); let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - let left = &left_data.0; // Using a buffer builder to avoid slower normal builder - let mut left_indices = UInt64BufferBuilder::new(0); - let mut right_indices = UInt32BufferBuilder::new(0); - - // Visit all of the right rows + let mut build_indices = UInt64BufferBuilder::new(0); + let mut probe_indices = UInt32BufferBuilder::new(0); + let offset_value = offset.unwrap_or(0); + // Visit all of the probe rows for (row, hash_value) in hash_values.iter().enumerate() { // Get the hash and find it in the build index - // For every item on the left and right we check if it matches + // For every item on the build and probe we check if it matches // This possibly contains rows with hash collisions, // So we have to check here whether rows are equal or not - if let Some((_, indices)) = - left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) + if let Some((_, indices)) = build_hashmap + .0 + .get(*hash_value, |(hash, _)| *hash_value == *hash) { for &i in indices { + // Check hash collisions + let offset_build_index = i as usize - offset_value; // Check hash collisions if equal_rows( - i as usize, + offset_build_index, row, - &left_join_values, + &build_join_values, &keys_values, *null_equals_null, )? 
{ - left_indices.append(i); - right_indices.append(row as u32); + build_indices.append(offset_build_index as u64); + probe_indices.append(row as u32); } } } } - let left = ArrayData::builder(DataType::UInt64) - .len(left_indices.len()) - .add_buffer(left_indices.finish()) - .build() - .unwrap(); - let right = ArrayData::builder(DataType::UInt32) - .len(right_indices.len()) - .add_buffer(right_indices.finish()) - .build() - .unwrap(); + let build = ArrayData::builder(DataType::UInt64) + .len(build_indices.len()) + .add_buffer(build_indices.finish()) + .build()?; + let probe = ArrayData::builder(DataType::UInt32) + .len(probe_indices.len()) + .add_buffer(probe_indices.finish()) + .build()?; Ok(( - PrimitiveArray::::from(left), - PrimitiveArray::::from(right), + PrimitiveArray::::from(build), + PrimitiveArray::::from(probe), )) } @@ -1168,7 +1185,7 @@ impl HashJoinStream { BooleanBufferBuilder::new(0) } }); - + let mut hashes_buffer = vec![]; self.right .poll_next_unpin(cx) .map(|maybe_batch| match maybe_batch { @@ -1181,12 +1198,16 @@ impl HashJoinStream { // get the matched two indices for the on condition let left_right_indices = build_join_indices( &batch, - left_data, + &left_data.0, + &left_data.1, &self.on_left, &self.on_right, self.filter.as_ref(), &self.random_state, &self.null_equals_null, + &mut hashes_buffer, + None, + JoinSide::Left, ); let result = match left_right_indices { @@ -1214,6 +1235,7 @@ impl HashJoinStream { left_side, right_side, &self.column_indices, + JoinSide::Left, ); self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(batch.num_rows()); @@ -1245,6 +1267,7 @@ impl HashJoinStream { left_side, right_side, &self.column_indices, + JoinSide::Left, ); if let Ok(ref batch) = result { @@ -1280,26 +1303,31 @@ impl Stream for HashJoinStream { #[cfg(test)] mod tests { + use std::sync::Arc; + + use super::*; use crate::physical_expr::expressions::BinaryExpr; + use crate::prelude::SessionContext; use crate::{ 
assert_batches_sorted_eq, physical_plan::{ - common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, + common, + expressions::Column, + hash_utils::create_hashes, + joins::{hash_join::build_equal_condition_join_indices, utils::JoinSide}, + memory::MemoryExec, + repartition::RepartitionExec, }, test::exec::MockExec, test::{build_table_i32, columns}, }; - use arrow::array::UInt32Builder; - use arrow::array::UInt64Builder; - use arrow::datatypes::Field; + use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::Operator; - use super::*; - use crate::physical_plan::joins::utils::JoinSide; - use crate::prelude::SessionContext; use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::Literal; - use std::sync::Arc; + use smallvec::smallvec; fn build_table( a: (&str, &Vec), @@ -2643,12 +2671,15 @@ mod tests { let left_data = (JoinHashMap(hashmap_left), left); let (l, r) = build_equal_condition_join_indices( - &left_data, + &left_data.0, + &left_data.1, &right, &[Column::new("a", 0)], &[Column::new("a", 0)], &random_state, &false, + &mut vec![0; right.num_rows()], + None, )?; let mut left_ids = UInt64Builder::with_capacity(0); diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs new file mode 100644 index 000000000000..6bfc8a1fcf19 --- /dev/null +++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs @@ -0,0 +1,577 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file contains common subroutines for regular and symmetric hash join +//! related functionality, used both in join calculations and optimization rules. + +use std::collections::HashMap; +use std::sync::Arc; +use std::usize; + +use arrow::datatypes::SchemaRef; + +use datafusion_common::DataFusionError; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::intervals::Interval; +use datafusion_physical_expr::rewrite::TreeNodeRewritable; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; + +use crate::common::Result; +use crate::physical_plan::joins::utils::{JoinFilter, JoinSide}; + +fn check_filter_expr_contains_sort_information( + expr: &Arc, + reference: &Arc, +) -> bool { + expr.eq(reference) + || expr + .children() + .iter() + .any(|e| check_filter_expr_contains_sort_information(e, reference)) +} + +/// Create a one to one mapping from main columns to filter columns using +/// filter column indices. 
A column index looks like: +/// ```text +/// ColumnIndex { +/// index: 0, // field index in main schema +/// side: JoinSide::Left, // child side +/// } +/// ``` +pub fn map_origin_col_to_filter_col( + filter: &JoinFilter, + schema: &SchemaRef, + side: &JoinSide, +) -> Result> { + let filter_schema = filter.schema(); + let mut col_to_col_map: HashMap = HashMap::new(); + for (filter_schema_index, index) in filter.column_indices().iter().enumerate() { + if index.side.eq(side) { + // Get the main field from column index: + let main_field = schema.field(index.index); + // Create a column expression: + let main_col = Column::new_with_schema(main_field.name(), schema.as_ref())?; + // Since the order of filter.column_indices() is the same as + // that of intermediate schema fields, we can get the column directly. + let filter_field = filter_schema.field(filter_schema_index); + let filter_col = Column::new(filter_field.name(), filter_schema_index); + // Insert mapping: + col_to_col_map.insert(main_col, filter_col); + } + } + Ok(col_to_col_map) +} + +/// This function analyzes [PhysicalSortExpr] graphs with respect to monotonicity +/// (sorting) properties. This is necessary since monotonically increasing and/or +/// decreasing expressions are required when using join filter expressions for +/// data pruning purposes. +/// +/// The method works as follows: +/// 1. Maps the original columns to the filter columns using the `map_origin_col_to_filter_col` function. +/// 2. Collects all columns in the sort expression using the `collect_columns` function. +/// 3. Checks if all columns are included in the resulting column map. +/// 4. If all columns are included, the sort expression is converted into a filter expression using the `transform_up` and `convert_filter_columns` functions. +/// 5. Searches for the converted expression within the filter expression using the `check_filter_expr_contains_sort_information` function. +/// 6. 
If an exact match is encountered, returns the converted filter expression as `Some(Arc)`. +/// 7. If all columns are not included or the exact match is not encountered, returns `None`. +/// +/// Examples: +/// Consider the filter expression "a + b > c + 10 AND a + b < c + 100". +/// 1. If the expression "a@ + d@" is sorted, it will not be accepted since the "d@" column is not part of the filter. +/// 2. If the expression "d@" is sorted, it will not be accepted since the "d@" column is not part of the filter. +/// 3. If the expression "a@ + b@ + c@" is sorted, all columns are represented in the filter expression. However, +/// there is no exact match, so this expression does not indicate pruning. +pub fn convert_sort_expr_with_filter_schema( + side: &JoinSide, + filter: &JoinFilter, + schema: &SchemaRef, + sort_expr: &PhysicalSortExpr, +) -> Result>> { + let column_map = map_origin_col_to_filter_col(filter, schema, side)?; + let expr = sort_expr.expr.clone(); + // Get main schema columns: + let expr_columns = collect_columns(&expr); + // Calculation is possible with `column_map` since sort exprs belong to a child. + let all_columns_are_included = + expr_columns.iter().all(|col| column_map.contains_key(col)); + if all_columns_are_included { + // Since we are sure that one to one column mapping includes all columns, we convert + // the sort expression into a filter expression. + let converted_filter_expr = + expr.transform_up(&|p| convert_filter_columns(p, &column_map))?; + // Search the converted `PhysicalExpr` in filter expression; if an exact + // match is found, use this sorted expression in graph traversals. + if check_filter_expr_contains_sort_information( + filter.expression(), + &converted_filter_expr, + ) { + return Ok(Some(converted_filter_expr)); + } + } + Ok(None) +} + +/// This function is used to build the filter expression based on the sort order of input columns. 
+/// +/// It first calls the [convert_sort_expr_with_filter_schema] method to determine if the sort +/// order of columns can be used in the filter expression. If it returns a [Some] value, the +/// method wraps the result in a [SortedFilterExpr] instance with the original sort expression and +/// the converted filter expression. Otherwise, this function returns an error. +/// +/// The [SortedFilterExpr] instance contains information about the sort order of columns that can +/// be used in the filter expression, which can be used to optimize the query execution process. +pub fn build_filter_input_order( + side: JoinSide, + filter: &JoinFilter, + schema: &SchemaRef, + order: &PhysicalSortExpr, +) -> Result { + if let Some(expr) = + convert_sort_expr_with_filter_schema(&side, filter, schema, order)? + { + Ok(SortedFilterExpr::new(order.clone(), expr)) + } else { + Err(DataFusionError::Plan(format!( + "The {side} side of the join does not have an expression sorted." + ))) + } +} + +/// Convert a physical expression into a filter expression using the given +/// column mapping information. +fn convert_filter_columns( + input: Arc, + column_map: &HashMap, +) -> Result>> { + // Attempt to downcast the input expression to a Column type. + Ok(if let Some(col) = input.as_any().downcast_ref::() { + // If the downcast is successful, retrieve the corresponding filter column. + column_map.get(col).map(|c| Arc::new(c.clone()) as _) + } else { + // If the downcast fails, return the input expression as is. + Some(input) + }) +} + +/// The [SortedFilterExpr] object represents a sorted filter expression. It +/// contains the following information: The origin expression, the filter +/// expression, an interval encapsulating expression bounds, and a stable +/// index identifying the expression in the expression DAG. +/// +/// Physical schema of a [JoinFilter]'s intermediate batch combines two sides +/// and uses new column names. 
In this process, a column exchange is done so +/// we can utilize sorting information while traversing the filter expression +/// DAG for interval calculations. When evaluating the inner buffer, we use +/// `origin_sorted_expr`. +#[derive(Debug, Clone)] +pub struct SortedFilterExpr { + /// Sorted expression from a join side (i.e. a child of the join) + origin_sorted_expr: PhysicalSortExpr, + /// Expression adjusted for filter schema. + filter_expr: Arc, + /// Interval containing expression bounds + interval: Interval, + /// Node index in the expression DAG + node_index: usize, +} + +impl SortedFilterExpr { + /// Constructor + pub fn new( + origin_sorted_expr: PhysicalSortExpr, + filter_expr: Arc, + ) -> Self { + Self { + origin_sorted_expr, + filter_expr, + interval: Interval::default(), + node_index: 0, + } + } + /// Get origin expr information + pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { + &self.origin_sorted_expr + } + /// Get filter expr information + pub fn filter_expr(&self) -> &Arc { + &self.filter_expr + } + /// Get interval information + pub fn interval(&self) -> &Interval { + &self.interval + } + /// Sets interval + pub fn set_interval(&mut self, interval: Interval) { + self.interval = interval; + } + /// Node index in ExprIntervalGraph + pub fn node_index(&self) -> usize { + self.node_index + } + /// Node index setter in ExprIntervalGraph + pub fn set_node_index(&mut self, node_index: usize) { + self.node_index = node_index; + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + use crate::physical_plan::{ + expressions::Column, + expressions::PhysicalSortExpr, + joins::utils::{ColumnIndex, JoinFilter, JoinSide}, + }; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{binary, cast, col, lit}; + use std::sync::Arc; + + /// Filter expr for a + b > c + 10 AND a + b < c + 100 + pub(crate) fn 
complicated_filter( + filter_schema: &Schema, + ) -> Result> { + let left_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Gt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(10))), + filter_schema, + )?, + filter_schema, + )?; + + let right_expr = binary( + cast( + binary( + col("0", filter_schema)?, + Operator::Plus, + col("1", filter_schema)?, + filter_schema, + )?, + filter_schema, + DataType::Int64, + )?, + Operator::Lt, + binary( + cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(100))), + filter_schema, + )?, + filter_schema, + )?; + binary(left_expr, Operator::And, right_expr, filter_schema) + } + + #[test] + fn test_column_exchange() -> Result<()> { + let left_child_schema = + Schema::new(vec![Field::new("left_1", DataType::Int32, true)]); + // Sorting information for the left side: + let left_child_sort_expr = PhysicalSortExpr { + expr: col("left_1", &left_child_schema)?, + options: SortOptions::default(), + }; + + let right_child_schema = Schema::new(vec![ + Field::new("right_1", DataType::Int32, true), + Field::new("right_2", DataType::Int32, true), + ]); + // Sorting information for the right side: + let right_child_sort_expr = PhysicalSortExpr { + expr: binary( + col("right_1", &right_child_schema)?, + Operator::Plus, + col("right_2", &right_child_schema)?, + &right_child_schema, + )?, + options: SortOptions::default(), + }; + + let intermediate_schema = Schema::new(vec![ + Field::new("filter_1", DataType::Int32, true), + Field::new("filter_2", DataType::Int32, true), + Field::new("filter_3", DataType::Int32, true), + ]); + // Our filter expression is: left_1 > right_1 + right_2. 
+ let filter_left = col("filter_1", &intermediate_schema)?; + let filter_right = binary( + col("filter_2", &intermediate_schema)?, + Operator::Plus, + col("filter_3", &intermediate_schema)?, + &intermediate_schema, + )?; + let filter_expr = binary( + filter_left.clone(), + Operator::Gt, + filter_right.clone(), + &intermediate_schema, + )?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 1, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + let left_sort_filter_expr = build_filter_input_order( + JoinSide::Left, + &filter, + &Arc::new(left_child_schema), + &left_child_sort_expr, + )?; + assert!(left_child_sort_expr.eq(left_sort_filter_expr.origin_sorted_expr())); + + let right_sort_filter_expr = build_filter_input_order( + JoinSide::Right, + &filter, + &Arc::new(right_child_schema), + &right_child_sort_expr, + )?; + assert!(right_child_sort_expr.eq(right_sort_filter_expr.origin_sorted_expr())); + + // Assert that adjusted (left) filter expression matches with `left_child_sort_expr`: + assert!(filter_left.eq(left_sort_filter_expr.filter_expr())); + // Assert that adjusted (right) filter expression matches with `right_child_sort_expr`: + assert!(filter_right.eq(right_sort_filter_expr.filter_expr())); + Ok(()) + } + + #[test] + fn test_column_collector() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&schema)?; + let columns = collect_columns(&filter_expr); + assert_eq!(columns.len(), 3); + Ok(()) + } + + #[test] + fn find_expr_inside_expr() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), 
+ ]); + let filter_expr = complicated_filter(&schema)?; + + let expr_1 = Arc::new(Column::new("gnz", 0)) as _; + assert!(!check_filter_expr_contains_sort_information( + &filter_expr, + &expr_1 + )); + + let expr_2 = col("1", &schema)? as _; + + assert!(check_filter_expr_contains_sort_information( + &filter_expr, + &expr_2 + )); + + let expr_3 = cast( + binary( + col("0", &schema)?, + Operator::Plus, + col("1", &schema)?, + &schema, + )?, + &schema, + DataType::Int64, + )?; + + assert!(check_filter_expr_contains_sort_information( + &filter_expr, + &expr_3 + )); + + let expr_4 = Arc::new(Column::new("1", 42)) as _; + + assert!(!check_filter_expr_contains_sort_information( + &filter_expr, + &expr_4, + )); + Ok(()) + } + + #[test] + fn build_sorted_expr() -> Result<()> { + let left_schema = Schema::new(vec![ + Field::new("la1", DataType::Int32, false), + Field::new("lb1", DataType::Int32, false), + Field::new("lc1", DataType::Int32, false), + Field::new("lt1", DataType::Int32, false), + Field::new("la2", DataType::Int32, false), + Field::new("la1_des", DataType::Int32, false), + ]); + + let right_schema = Schema::new(vec![ + Field::new("ra1", DataType::Int32, false), + Field::new("rb1", DataType::Int32, false), + Field::new("rc1", DataType::Int32, false), + Field::new("rt1", DataType::Int32, false), + Field::new("ra2", DataType::Int32, false), + Field::new("ra1_des", DataType::Int32, false), + ]); + + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + let left_schema = Arc::new(left_schema); + let 
right_schema = Arc::new(right_schema); + + assert!(build_filter_input_order( + JoinSide::Left, + &filter, + &left_schema, + &PhysicalSortExpr { + expr: col("la1", left_schema.as_ref())?, + options: SortOptions::default(), + } + ) + .is_ok()); + assert!(build_filter_input_order( + JoinSide::Left, + &filter, + &left_schema, + &PhysicalSortExpr { + expr: col("lt1", left_schema.as_ref())?, + options: SortOptions::default(), + } + ) + .is_err()); + assert!(build_filter_input_order( + JoinSide::Right, + &filter, + &right_schema, + &PhysicalSortExpr { + expr: col("ra1", right_schema.as_ref())?, + options: SortOptions::default(), + } + ) + .is_ok()); + assert!(build_filter_input_order( + JoinSide::Right, + &filter, + &right_schema, + &PhysicalSortExpr { + expr: col("rb1", right_schema.as_ref())?, + options: SortOptions::default(), + } + ) + .is_err()); + + Ok(()) + } + + // Test the case when we have an "ORDER BY a + b", and join filter condition includes "a - b". + #[test] + fn sorted_filter_expr_build() -> Result<()> { + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + ]); + let filter_expr = binary( + col("0", &intermediate_schema)?, + Operator::Minus, + col("1", &intermediate_schema)?, + &intermediate_schema, + )?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let sorted = PhysicalSortExpr { + expr: binary( + col("a", &schema)?, + Operator::Plus, + col("b", &schema)?, + &schema, + )?, + options: SortOptions::default(), + }; + + let res = convert_sort_expr_with_filter_schema( + &JoinSide::Left, + &filter, + &Arc::new(schema), + &sorted, + )?; + assert!(res.is_none()); + Ok(()) + } 
+} diff --git a/datafusion/core/src/physical_plan/joins/mod.rs b/datafusion/core/src/physical_plan/joins/mod.rs index 63762ab3cf1b..8ad50514f0b0 100644 --- a/datafusion/core/src/physical_plan/joins/mod.rs +++ b/datafusion/core/src/physical_plan/joins/mod.rs @@ -17,10 +17,19 @@ //! DataFusion Join implementations +pub use cross_join::CrossJoinExec; +pub use hash_join::HashJoinExec; +pub use hash_join_utils::convert_sort_expr_with_filter_schema; +pub use nested_loop_join::NestedLoopJoinExec; +// Note: SortMergeJoin is not used in plans yet +pub use sort_merge_join::SortMergeJoinExec; +pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; +mod hash_join_utils; mod nested_loop_join; mod sort_merge_join; +mod symmetric_hash_join; pub mod utils; #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -34,10 +43,3 @@ pub enum PartitionMode { /// It will also consider swapping the left and right inputs for the Join Auto, } - -pub use cross_join::CrossJoinExec; -pub use hash_join::HashJoinExec; -pub use nested_loop_join::NestedLoopJoinExec; - -// Note: SortMergeJoin is not used in plans yet -pub use sort_merge_join::SortMergeJoinExec; diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs index e3834573eb2f..c283b11f8f6a 100644 --- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs +++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs @@ -24,7 +24,7 @@ use crate::physical_plan::joins::utils::{ build_batch_from_indices, build_join_schema, check_join_is_valid, combine_join_equivalence_properties, estimate_join_statistics, get_anti_indices, get_anti_u64_indices, get_final_indices_from_bit_map, get_semi_indices, - get_semi_u64_indices, ColumnIndex, JoinFilter, OnceAsync, OnceFut, + get_semi_u64_indices, ColumnIndex, JoinFilter, JoinSide, OnceAsync, OnceFut, }; use crate::physical_plan::{ DisplayFormatType, Distribution, ExecutionPlan, Partitioning, 
RecordBatchStream, @@ -348,6 +348,7 @@ fn build_join_indices( left_indices, right_indices, filter, + JoinSide::Left, ) } else { Ok((left_indices, right_indices)) @@ -412,6 +413,7 @@ impl NestedLoopJoinStream { left_side, right_side, &self.column_indices, + JoinSide::Left, ); self.is_exhausted = true; Some(result) @@ -516,6 +518,7 @@ fn join_left_and_right_batch( left_side, right_side, column_indices, + JoinSide::Left, ) } Err(e) => Err(e), diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs new file mode 100644 index 000000000000..c377be6d1ea0 --- /dev/null +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -0,0 +1,2473 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This file implements the symmetric hash join algorithm with range-based +//! data pruning to join two (potentially infinite) streams. +//! +//! A [SymmetricHashJoinExec] plan takes two children plan (with appropriate +//! output ordering) and produces the join output according to the given join +//! type and other options. +//! +//! This plan uses the [OneSideHashJoiner] object to facilitate join calculations +//! for both its children. 
+ +use std::collections::{HashMap, VecDeque}; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; +use std::task::Poll; +use std::vec; +use std::{any::Any, usize}; + +use ahash::RandomState; +use arrow::array::{ + ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, + PrimitiveBuilder, +}; +use arrow::compute::concat_batches; +use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use futures::{Stream, StreamExt}; +use hashbrown::{raw::RawTable, HashSet}; + +use datafusion_common::{utils::bisect, ScalarValue}; +use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval}; + +use crate::error::{DataFusionError, Result}; +use crate::execution::context::TaskContext; +use crate::logical_expr::JoinType; +use crate::physical_plan::{ + expressions::Column, + expressions::PhysicalSortExpr, + joins::{ + hash_join::{build_join_indices, update_hash, JoinHashMap}, + hash_join_utils::{build_filter_input_order, SortedFilterExpr}, + utils::{ + build_batch_from_indices, build_join_schema, check_join_is_valid, + combine_join_equivalence_properties, partitioned_join_output_partitioning, + ColumnIndex, JoinFilter, JoinOn, JoinSide, + }, + }, + metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, + DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, +}; + +/// A symmetric hash join with range conditions is when both streams are hashed on the +/// join key and the resulting hash tables are used to join the streams. +/// The join is considered symmetric because the hash table is built on the join keys from both +/// streams, and the matching of rows is based on the values of the join keys in both streams. 
+/// This type of join is efficient in streaming context as it allows for fast lookups in the hash +/// table, rather than having to scan through one or both of the streams to find matching rows, also it +/// only considers the elements from the stream that fall within a certain sliding window (w/ range conditions), +/// making it more efficient and less likely to store stale data. This enables operating on unbounded streaming +/// data without any memory issues. +/// +/// For each input stream, create a hash table. +/// - For each new [RecordBatch] in build side, hash and insert into inputs hash table. Update offsets. +/// - Test if input is equal to a predefined set of other inputs. +/// - If so record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch]. +/// - Try to prune other side (probe) with new [RecordBatch]. +/// - If the join type indicates that the unmatched rows results must be produced (LEFT, FULL etc.), +/// output the [RecordBatch] when a pruning happens or at the end of the data. +/// +/// +/// ``` text +/// +-------------------------+ +/// | | +/// left stream ---------| Left OneSideHashJoiner |---+ +/// | | | +/// +-------------------------+ | +/// | +/// |--------- Joined output +/// | +/// +-------------------------+ | +/// | | | +/// right stream ---------| Right OneSideHashJoiner |---+ +/// | | +/// +-------------------------+ +/// +/// Prune build side when the new RecordBatch comes to the probe side. We utilize interval arithmetic +/// on JoinFilter's sorted PhysicalExprs to calculate the joinable range. 
+/// +/// +/// PROBE SIDE BUILD SIDE +/// BUFFER BUFFER +/// +-------------+ +------------+ +/// | | | | Unjoinable +/// | | | | Range +/// | | | | +/// | | |--------------------------------- +/// | | | | | +/// | | | | | +/// | | / | | +/// | | | | | +/// | | | | | +/// | | | | | +/// | | | | | +/// | | | | | Joinable +/// | |/ | | Range +/// | || | | +/// |+-----------+|| | | +/// || Record || | | +/// || Batch || | | +/// |+-----------+|| | | +/// +-------------+\ +------------+ +/// | +/// \ +/// |--------------------------------- +/// +/// This happens when range conditions are provided on sorted columns. E.g. +/// +/// SELECT * FROM left_table, right_table +/// ON +/// left_key = right_key AND +/// left_time > right_time - INTERVAL 12 MINUTES AND left_time < right_time + INTERVAL 2 HOUR +/// +/// or +/// SELECT * FROM left_table, right_table +/// ON +/// left_key = right_key AND +/// left_sorted > right_sorted - 3 AND left_sorted < right_sorted + 10 +/// +/// For general purpose, in the second scenario, when the new data comes to probe side, the conditions can be used to +/// determine a specific threshold for discarding rows from the inner buffer. For example, if the sort order the +/// two columns ("left_sorted" and "right_sorted") are ascending (it can be different in another scenarios) +/// and the join condition is "left_sorted > right_sorted - 3" and the latest value on the right input is 1234, meaning +/// that the left side buffer must only keep rows where "leftTime > rightTime - 3 > 1234 - 3 > 1231" , +/// making the smallest value in 'left_sorted' 1231 and any rows below (since ascending) +/// than that can be dropped from the inner buffer. 
+/// ``` +pub struct SymmetricHashJoinExec { + /// Left side stream + pub(crate) left: Arc, + /// Right side stream + pub(crate) right: Arc, + /// Set of common columns used to join on + pub(crate) on: Vec<(Column, Column)>, + /// Filters applied when finding matching rows + pub(crate) filter: JoinFilter, + /// How the join is performed + pub(crate) join_type: JoinType, + /// Order information of filter expressions + sorted_filter_exprs: Vec, + /// Left required sort + left_required_sort_exprs: Vec, + /// Right required sort + right_required_sort_exprs: Vec, + /// Expression graph for interval calculations + physical_expr_graph: ExprIntervalGraph, + /// The schema once the join is applied + schema: SchemaRef, + /// Shares the `RandomState` for the hashing algorithm + random_state: RandomState, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Information of index and left / right placement of columns + column_indices: Vec, + /// If null_equals_null is true, null == null else null != null + pub(crate) null_equals_null: bool, +} + +#[derive(Debug)] +struct SymmetricHashJoinSideMetrics { + /// Number of batches consumed by this operator + input_batches: metrics::Count, + /// Number of rows consumed by this operator + input_rows: metrics::Count, +} + +/// Metrics for SymmetricHashJoinExec +#[derive(Debug)] +struct SymmetricHashJoinMetrics { + /// Number of left batches/rows consumed by this operator + left: SymmetricHashJoinSideMetrics, + /// Number of right batches/rows consumed by this operator + right: SymmetricHashJoinSideMetrics, + /// Number of batches produced by this operator + output_batches: metrics::Count, + /// Number of rows produced by this operator + output_rows: metrics::Count, +} + +impl SymmetricHashJoinMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows",
partition); + let left = SymmetricHashJoinSideMetrics { + input_batches, + input_rows, + }; + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let right = SymmetricHashJoinSideMetrics { + input_batches, + input_rows, + }; + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + left, + right, + output_batches, + output_rows, + } + } +} + +impl SymmetricHashJoinExec { + /// Tries to create a new [SymmetricHashJoinExec]. + /// # Errors + /// This function errors when: + /// - It is not possible to join the left and right sides on keys `on`, or + /// - It fails to construct [SortedFilterExpr]s, or + /// - It fails to create the [ExprIntervalGraph]. + pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: &JoinType, + null_equals_null: &bool, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + + // Error out if no "on" constraints are given: + if on.is_empty() { + return Err(DataFusionError::Plan( + "On constraints in SymmetricHashJoinExec should be non-empty".to_string(), + )); + } + + // Check if the join is valid with the given on constraints: + check_join_is_valid(&left_schema, &right_schema, &on)?; + + // Build the join schema from the left and right schemas: + let (schema, column_indices) = + build_join_schema(&left_schema, &right_schema, join_type); + + // Set a random state for the join: + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + // Create an expression DAG for the join filter: + let mut physical_expr_graph = + ExprIntervalGraph::try_new(filter.expression().clone())?; + + // Interval calculations require each column to exhibit monotonicity + // independently.
However, a `PhysicalSortExpr` object defines a + // lexicographical ordering, so we can only use their first elements + // when deducing column monotonicities. + // TODO: Extend the `PhysicalSortExpr` mechanism to express independent + // (i.e. simultaneous) ordering properties of columns. + let (left_ordering, right_ordering) = match ( + left.output_ordering(), + right.output_ordering(), + ) { + (Some([left_ordering, ..]), Some([right_ordering, ..])) => { + (left_ordering, right_ordering) + } + _ => { + return Err(DataFusionError::Plan( + "Symmetric hash join requires its children to have an output ordering".to_string(), + )); + } + }; + + // Build the sorted filter expression for the left child: + let left_filter_expression = build_filter_input_order( + JoinSide::Left, + &filter, + &left.schema(), + left_ordering, + )?; + + // Build the sorted filter expression for the right child: + let right_filter_expression = build_filter_input_order( + JoinSide::Right, + &filter, + &right.schema(), + right_ordering, + )?; + + // Store the left and right sorted filter expressions in a vector + let mut sorted_filter_exprs = + vec![left_filter_expression, right_filter_expression]; + + // Gather node indices of converted filter expressions in `SortedFilterExpr` + // using the filter columns vector: + let child_node_indexes = physical_expr_graph.gather_node_indices( + &sorted_filter_exprs + .iter() + .map(|sorted_expr| sorted_expr.filter_expr().clone()) + .collect::>(), + ); + + // Inject calculated node indices into SortedFilterExpr: + for (sorted_expr, (_, index)) in sorted_filter_exprs + .iter_mut() + .zip(child_node_indexes.iter()) + { + sorted_expr.set_node_index(*index); + } + + let left_required_sort_exprs = vec![left_ordering.clone()]; + let right_required_sort_exprs = vec![right_ordering.clone()]; + + Ok(SymmetricHashJoinExec { + left, + right, + on, + filter, + join_type: *join_type, + sorted_filter_exprs, + left_required_sort_exprs, + right_required_sort_exprs, +
physical_expr_graph, + schema: Arc::new(schema), + random_state, + metrics: ExecutionPlanMetricsSet::new(), + column_indices, + null_equals_null: *null_equals_null, + }) + } + + /// left stream + pub fn left(&self) -> &Arc { + &self.left + } + + /// right stream + pub fn right(&self) -> &Arc { + &self.right + } + + /// Set of common columns used to join on + pub fn on(&self) -> &[(Column, Column)] { + &self.on + } + + /// Filters applied before join output + pub fn filter(&self) -> &JoinFilter { + &self.filter + } + + /// How the join is performed + pub fn join_type(&self) -> &JoinType { + &self.join_type + } + + /// Get null_equals_null + pub fn null_equals_null(&self) -> &bool { + &self.null_equals_null + } +} + +impl Debug for SymmetricHashJoinExec { + fn fmt(&self, _f: &mut Formatter<'_>) -> fmt::Result { + todo!() + } +} + +impl ExecutionPlan for SymmetricHashJoinExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn required_input_ordering(&self) -> Vec> { + vec![ + Some(&self.left_required_sort_exprs), + Some(&self.right_required_sort_exprs), + ] + } + + fn unbounded_output(&self, children: &[bool]) -> Result { + Ok(children.iter().any(|u| *u)) + } + + fn required_input_distribution(&self) -> Vec { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .unzip(); + // TODO: This will change when we extend collected executions. 
+ vec![ + if self.left.output_partitioning().partition_count() == 1 { + Distribution::SinglePartition + } else { + Distribution::HashPartitioned(left_expr) + }, + if self.right.output_partitioning().partition_count() == 1 { + Distribution::SinglePartition + } else { + Distribution::HashPartitioned(right_expr) + }, + ] + } + + fn output_partitioning(&self) -> Partitioning { + let left_columns_len = self.left.schema().fields.len(); + partitioned_join_output_partitioning( + self.join_type, + self.left.output_partitioning(), + self.right.output_partitioning(), + left_columns_len, + ) + } + + // TODO: Output ordering might be kept for some cases. + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + let left_columns_len = self.left.schema().fields.len(); + combine_join_equivalence_properties( + self.join_type, + self.left.equivalence_properties(), + self.right.equivalence_properties(), + left_columns_len, + self.on(), + self.schema(), + ) + } + + fn children(&self) -> Vec> { + vec![self.left.clone(), self.right.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(SymmetricHashJoinExec::try_new( + children[0].clone(), + children[1].clone(), + self.on.clone(), + self.filter.clone(), + &self.join_type, + &self.null_equals_null, + )?)) + } + + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default => { + let display_filter = format!(", filter={:?}", self.filter.expression()); + write!( + f, + "SymmetricHashJoinExec: join_type={:?}, on={:?}{}", + self.join_type, self.on, display_filter + ) + } + } + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Statistics { + // TODO stats: it is not possible in general to know the output size of joins + Statistics::default() + } + + fn execute( + &self, + partition: usize, + context: Arc, + 
) -> Result { + let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); + let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); + let left_side_joiner = OneSideHashJoiner::new( + JoinSide::Left, + self.sorted_filter_exprs[0].clone(), + on_left, + self.left.schema(), + ); + let right_side_joiner = OneSideHashJoiner::new( + JoinSide::Right, + self.sorted_filter_exprs[1].clone(), + on_right, + self.right.schema(), + ); + let left_stream = self.left.execute(partition, context.clone())?; + let right_stream = self.right.execute(partition, context)?; + + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics), + physical_expr_graph: self.physical_expr_graph.clone(), + null_equals_null: self.null_equals_null, + final_result: false, + probe_side: JoinSide::Left, + })) + } +} + +/// A stream that issues [RecordBatch]es as they arrive from the right of the join. +struct SymmetricHashJoinStream { + /// Left stream + left_stream: SendableRecordBatchStream, + /// right stream + right_stream: SendableRecordBatchStream, + /// Input schema + schema: Arc, + /// join filter + filter: JoinFilter, + /// type of the join + join_type: JoinType, + // left hash joiner + left: OneSideHashJoiner, + /// right hash joiner + right: OneSideHashJoiner, + /// Information of index and left / right placement of columns + column_indices: Vec, + // Range pruner. 
+ physical_expr_graph: ExprIntervalGraph, + /// Random state used for hashing initialization + random_state: RandomState, + /// If null_equals_null is true, null == null else null != null + null_equals_null: bool, + /// Metrics + metrics: SymmetricHashJoinMetrics, + /// Flag indicating whether there is nothing to process anymore + final_result: bool, + /// The current probe side. We choose build and probe side according to this attribute. + probe_side: JoinSide, +} + +impl RecordBatchStream for SymmetricHashJoinStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for SymmetricHashJoinStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.poll_next_impl(cx) + } +} + +fn prune_hash_values( + prune_length: usize, + hashmap: &mut JoinHashMap, + row_hash_values: &mut VecDeque, + offset: u64, +) -> Result<()> { + // Create a (hash)-(row number set) map + let mut hash_value_map: HashMap> = HashMap::new(); + for index in 0..prune_length { + let hash_value = row_hash_values.pop_front().unwrap(); + if let Some(set) = hash_value_map.get_mut(&hash_value) { + set.insert(offset + index as u64); + } else { + let mut set = HashSet::new(); + set.insert(offset + index as u64); + hash_value_map.insert(hash_value, set); + } + } + for (hash_value, index_set) in hash_value_map.iter() { + if let Some((_, separation_chain)) = hashmap + .0 + .get_mut(*hash_value, |(hash, _)| hash_value == hash) + { + separation_chain.retain(|n| !index_set.contains(n)); + if separation_chain.is_empty() { + hashmap + .0 + .remove_entry(*hash_value, |(hash, _)| hash_value == hash); + } + } + } + Ok(()) +} + +/// Calculate the filter expression intervals. +/// +/// This function updates the `interval` field of each `SortedFilterExpr` based +/// on the first or the last value of the expression in `build_input_buffer` +/// and `probe_batch`. 
+/// +/// # Arguments +/// +/// * `build_input_buffer` - The [RecordBatch] on the build side of the join. +/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. +/// * `probe_batch` - The `RecordBatch` on the probe side of the join. +/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. +/// +/// ### Note +/// ```text +/// +/// Interval arithmetic is used to calculate viable join ranges for build-side +/// pruning. This is done by first creating an interval for join filter values in +/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the +/// ordering (descending/ascending) of the filter expression. Here, FV denotes the +/// first value on the build side. This range is then compared with the probe side +/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering +/// (descending/ascending) of the probe side. Here, LV denotes the last value on +/// the probe side. +/// +/// As a concrete example, consider the following query: +/// +/// SELECT * FROM left_table, right_table +/// WHERE +/// left_key = right_key AND +/// a > b - 3 AND +/// a < b + 10 +/// +/// where columns "a" and "b" come from tables "left_table" and "right_table", +/// respectively. When a new `RecordBatch` arrives at the right side, the +/// condition a > b - 3 will possibly indicate a prunable range for the left +/// side. Conversely, when a new `RecordBatch` arrives at the left side, the +/// condition a < b + 10 will possibly indicate prunability for the right side. +/// Let’s inspect what happens when a new `RecordBatch` arrives at the right +/// side (i.e.
when the left side is the build side): +/// +/// Build Probe +/// +-------+ +-------+ +/// | a | z | | b | y | +/// |+--|--+| |+--|--+| +/// | 1 | 2 | | 4 | 3 | +/// |+--|--+| |+--|--+| +/// | 3 | 1 | | 4 | 3 | +/// |+--|--+| |+--|--+| +/// | 5 | 7 | | 6 | 1 | +/// |+--|--+| |+--|--+| +/// | 7 | 1 | | 6 | 3 | +/// +-------+ +-------+ +/// +/// In this case, the interval representing viable (i.e. joinable) values for +/// column "a" is [1, ∞], and the interval representing possible future values +/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate +/// intervals for the whole filter expression and propagate join constraint by +/// traversing the expression graph. +/// ``` +fn calculate_filter_expr_intervals( + build_input_buffer: &RecordBatch, + build_sorted_filter_expr: &mut SortedFilterExpr, + probe_batch: &RecordBatch, + probe_sorted_filter_expr: &mut SortedFilterExpr, +) -> Result<()> { + // If either build or probe side has no data, return early: + if build_input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { + return Ok(()); + } + // Evaluate build side filter expression and convert the result to an array + let build_array = build_sorted_filter_expr + .origin_sorted_expr() + .expr + .evaluate(&build_input_buffer.slice(0, 1))? + .into_array(1); + // Evaluate probe side filter expression and convert the result to an array + let probe_array = probe_sorted_filter_expr + .origin_sorted_expr() + .expr + .evaluate(&probe_batch.slice(probe_batch.num_rows() - 1, 1))? 
+ .into_array(1); + + // Update intervals for both build and probe side filter expressions + for (array, sorted_expr) in vec![ + (build_array, build_sorted_filter_expr), + (probe_array, probe_sorted_filter_expr), + ] { + // Convert the array to a ScalarValue: + let value = ScalarValue::try_from_array(&array, 0)?; + // Create a ScalarValue representing positive or negative infinity for the same data type: + let infinite = ScalarValue::try_from(value.get_datatype())?; + // Update the interval with lower and upper bounds based on the sort option + sorted_expr.set_interval( + if sorted_expr.origin_sorted_expr().options.descending { + Interval { + lower: infinite, + upper: value, + } + } else { + Interval { + lower: value, + upper: infinite, + } + }, + ); + } + Ok(()) +} + +/// Determine the pruning length for `buffer`. +/// +/// This function evaluates the build side filter expression, converts the +/// result into an array and determines the pruning length by performing a +/// binary search on the array. +/// +/// # Arguments +/// +/// * `buffer`: The record batch to be pruned. +/// * `build_side_filter_expr`: The filter expression on the build side used +/// to determine the pruning length. +/// +/// # Returns +/// +/// A [Result] object that contains the pruning length. The function will return +/// an error if there is an issue evaluating the build side filter expression. +fn determine_prune_length( + buffer: &RecordBatch, + build_side_filter_expr: &SortedFilterExpr, +) -> Result { + let origin_sorted_expr = build_side_filter_expr.origin_sorted_expr(); + let interval = build_side_filter_expr.interval(); + // Evaluate the build side filter expression and convert it into an array + let batch_arr = origin_sorted_expr + .expr + .evaluate(buffer)? 
+ .into_array(buffer.num_rows()); + + // Get the lower or upper interval based on the sort direction + let target = if origin_sorted_expr.options.descending { + interval.upper.clone() + } else { + interval.lower.clone() + }; + + // Perform binary search on the array to determine the length of the record batch to be pruned + bisect::(&[batch_arr], &[target], &[origin_sorted_expr.options]) +} + +/// This method determines if the result of the join should be produced in the final step or not. +/// +/// # Arguments +/// +/// * `build_side` - Enum indicating the side of the join used as the build side. +/// * `join_type` - Enum indicating the type of join to be performed. +/// +/// # Returns +/// +/// A boolean indicating whether the result of the join should be produced in the final step or not. +/// The result will be true if the build side is JoinSide::Left and the join type is one of +/// JoinType::Left, JoinType::LeftAnti, JoinType::Full or JoinType::LeftSemi. +/// If the build side is JoinSide::Right, the result will be true if the join type +/// is one of JoinType::Right, JoinType::RightAnti, JoinType::Full, or JoinType::RightSemi. +fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> bool { + if build_side == JoinSide::Left { + matches!( + join_type, + JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi + ) + } else { + matches!( + join_type, + JoinType::Right | JoinType::RightAnti | JoinType::Full | JoinType::RightSemi + ) + } +} + +/// Get the anti join indices from the visited hash set. +/// +/// This method returns the indices from the original input that were not present in the visited hash set. +/// +/// # Arguments +/// +/// * `prune_length` - The length of the pruned record batch. +/// * `deleted_offset` - The offset to the indices. +/// * `visited_rows` - The hash set of visited indices. +/// +/// # Returns +/// +/// A `PrimitiveArray` of the anti join indices. 
+fn get_anti_indices( + prune_length: usize, + deleted_offset: usize, + visited_rows: &HashSet, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = BooleanBufferBuilder::new(prune_length); + bitmap.append_n(prune_length, false); + // mark the indices as true if they are present in the visited hash set + for v in 0..prune_length { + let row = v + deleted_offset; + bitmap.set_bit(v, visited_rows.contains(&row)); + } + // get the anti index + (0..prune_length) + .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) + .collect() +} + +/// This method creates a boolean buffer from the visited rows hash set +/// and the indices of the pruned record batch slice. +/// +/// It gets the indices from the original input that were present in the visited hash set. +/// +/// # Arguments +/// +/// * `prune_length` - The length of the pruned record batch. +/// * `deleted_offset` - The offset to the indices. +/// * `visited_rows` - The hash set of visited indices. +/// +/// # Returns +/// +/// A [PrimitiveArray] of the specified type T, containing the semi indices. +fn get_semi_indices( + prune_length: usize, + deleted_offset: usize, + visited_rows: &HashSet, +) -> PrimitiveArray +where + NativeAdapter: From<::Native>, +{ + let mut bitmap = BooleanBufferBuilder::new(prune_length); + bitmap.append_n(prune_length, false); + // mark the indices as true if they are present in the visited hash set + (0..prune_length).for_each(|v| { + let row = &(v + deleted_offset); + bitmap.set_bit(v, visited_rows.contains(row)); + }); + // get the semi index + (0..prune_length) + .filter_map(|idx| (bitmap.get_bit(idx)).then_some(T::Native::from_usize(idx))) + .collect::>() +} +/// Records the visited indices from the input `PrimitiveArray` of type `T` into the given hash set `visited`. +/// This function will insert the indices (offset by `offset`) into the `visited` hash set. 
+/// +/// # Arguments +/// +/// * `visited` - A hash set to store the visited indices. +/// * `offset` - An offset to the indices in the `PrimitiveArray`. +/// * `indices` - The input `PrimitiveArray` of type `T` which stores the indices to be recorded. +/// +fn record_visited_indices( + visited: &mut HashSet, + offset: usize, + indices: &PrimitiveArray, +) { + for i in indices.values() { + visited.insert(i.as_usize() + offset); + } +} + +/// Calculate indices by join type. +/// +/// This method returns a tuple of two arrays: build and probe indices. +/// The length of both arrays will be the same. +/// +/// # Arguments +/// +/// * `build_side`: Join side which defines the build side. +/// * `prune_length`: Length of the prune data. +/// * `visited_rows`: Hash set of visited rows of the build side. +/// * `deleted_offset`: Deleted offset of the build side. +/// * `join_type`: The type of join to be performed. +/// +/// # Returns +/// +/// A tuple of two arrays of primitive types representing the build and probe indices. 
+/// +fn calculate_indices_by_join_type( + build_side: JoinSide, + prune_length: usize, + visited_rows: &HashSet, + deleted_offset: usize, + join_type: JoinType, +) -> Result<(PrimitiveArray, PrimitiveArray)> +where + NativeAdapter: From<::Native>, +{ + // Store the result in a tuple + let result = match (build_side, join_type) { + // In the case of `Left` or `Right` join, or `Full` join, get the anti indices + (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) + | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) + | (_, JoinType::Full) => { + let build_unmatched_indices = + get_anti_indices(prune_length, deleted_offset, visited_rows); + let mut builder = + PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); + builder.append_nulls(build_unmatched_indices.len()); + let probe_indices = builder.finish(); + (build_unmatched_indices, probe_indices) + } + // In the case of `LeftSemi` or `RightSemi` join, get the semi indices + (JoinSide::Left, JoinType::LeftSemi) | (JoinSide::Right, JoinType::RightSemi) => { + let build_unmatched_indices = + get_semi_indices(prune_length, deleted_offset, visited_rows); + let mut builder = + PrimitiveBuilder::::with_capacity(build_unmatched_indices.len()); + builder.append_nulls(build_unmatched_indices.len()); + let probe_indices = builder.finish(); + (build_unmatched_indices, probe_indices) + } + // The case of other join types is not considered + _ => unreachable!(), + }; + Ok(result) +} + +struct OneSideHashJoiner { + /// Build side + build_side: JoinSide, + /// Build side filter sort information + sorted_filter_expr: SortedFilterExpr, + /// Input record batch buffer + input_buffer: RecordBatch, + /// Columns from the side + on: Vec, + /// Hashmap + hashmap: JoinHashMap, + /// To optimize hash deleting in case of pruning, we hold them in memory + row_hash_values: VecDeque, + /// Reuse the hashes buffer + hashes_buffer: Vec, + /// Matched rows + visited_rows: HashSet, + /// Offset + offset: usize, + /// Deleted 
offset + deleted_offset: usize, + /// Side is exhausted + exhausted: bool, +} + +impl OneSideHashJoiner { + pub fn new( + build_side: JoinSide, + sorted_filter_expr: SortedFilterExpr, + on: Vec, + schema: SchemaRef, + ) -> Self { + Self { + build_side, + input_buffer: RecordBatch::new_empty(schema), + on, + hashmap: JoinHashMap(RawTable::with_capacity(10_000)), + row_hash_values: VecDeque::new(), + hashes_buffer: vec![], + sorted_filter_expr, + visited_rows: HashSet::new(), + offset: 0, + deleted_offset: 0, + exhausted: false, + } + } + + /// Updates the internal state of the [OneSideHashJoiner] with the incoming batch. + /// + /// # Arguments + /// + /// * `batch` - The incoming [RecordBatch] to be merged with the internal input buffer + /// * `random_state` - The random state used to hash values + /// + /// # Returns + /// + /// Returns a [Result] encapsulating any intermediate errors. + fn update_internal_state( + &mut self, + batch: &RecordBatch, + random_state: &RandomState, + ) -> Result<()> { + // Merge the incoming batch with the existing input buffer: + self.input_buffer = concat_batches(&batch.schema(), [&self.input_buffer, batch])?; + // Resize the hashes buffer to the number of rows in the incoming batch: + self.hashes_buffer.resize(batch.num_rows(), 0); + // Update the hashmap with the join key values and hashes of the incoming batch: + update_hash( + &self.on, + batch, + &mut self.hashmap, + self.offset, + random_state, + &mut self.hashes_buffer, + )?; + // Add the hashes buffer to the hash value deque: + self.row_hash_values.extend(self.hashes_buffer.iter()); + Ok(()) + } + + /// This method performs a join between the build side input buffer and the probe side batch. + /// + /// # Arguments + /// + /// * `schema` - A reference to the schema of the output record batch. + /// * `join_type` - The type of join to be performed. + /// * `on_probe` - An array of columns on which the join will be performed. The columns are from the probe side of the join. 
+ /// * `filter` - An optional filter on the join condition. + /// * `probe_batch` - The second record batch to be joined. + /// * `probe_visited` - A hash set to store the visited indices from the probe batch. + /// * `probe_offset` - The offset of the probe side for visited indices calculations. + /// * `column_indices` - An array of columns to be selected for the result of the join. + /// * `random_state` - The random state for the join. + /// * `null_equals_null` - A boolean indicating whether NULL values should be treated as equal when joining. + /// + /// # Returns + /// + /// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`. + /// If the join type is one of the above four, the function will return [None]. + #[allow(clippy::too_many_arguments)] + fn join_with_probe_batch( + &mut self, + schema: &SchemaRef, + join_type: JoinType, + on_probe: &[Column], + filter: &JoinFilter, + probe_batch: &RecordBatch, + probe_visited: &mut HashSet, + probe_offset: usize, + column_indices: &[ColumnIndex], + random_state: &RandomState, + null_equals_null: &bool, + ) -> Result> { + if self.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 { + return Ok(Some(RecordBatch::new_empty(schema.clone()))); + } + let (build_indices, probe_indices) = build_join_indices( + probe_batch, + &self.hashmap, + &self.input_buffer, + &self.on, + on_probe, + Some(filter), + random_state, + null_equals_null, + &mut self.hashes_buffer, + Some(self.deleted_offset), + self.build_side, + )?; + if need_to_produce_result_in_final(self.build_side, join_type) { + record_visited_indices( + &mut self.visited_rows, + self.deleted_offset, + &build_indices, + ); + } + if need_to_produce_result_in_final(self.build_side.negate(), join_type) { + record_visited_indices(probe_visited, probe_offset, &probe_indices); + } + if matches!( + join_type, + JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | 
JoinType::RightSemi + ) { + Ok(None) + } else { + build_batch_from_indices( + schema, + &self.input_buffer, + probe_batch, + build_indices, + probe_indices, + column_indices, + self.build_side, + ) + .map(Some) + } + } + + /// This function produces unmatched record results based on the build side, + /// join type and other parameters. + /// + /// The method uses first `prune_length` rows from the build side input buffer + /// to produce results. + /// + /// # Arguments + /// + /// * `output_schema` - The schema of the final output record batch. + /// * `prune_length` - The length of the determined prune length. + /// * `probe_schema` - The schema of the probe [RecordBatch]. + /// * `join_type` - The type of join to be performed. + /// * `column_indices` - Indices of columns that are being joined. + /// + /// # Returns + /// + /// * `Option` - The final output record batch if required, otherwise [None]. + fn build_side_determined_results( + &self, + output_schema: &SchemaRef, + prune_length: usize, + probe_schema: SchemaRef, + join_type: JoinType, + column_indices: &[ColumnIndex], + ) -> Result> { + // Check if we need to produce a result in the final output: + if need_to_produce_result_in_final(self.build_side, join_type) { + // Calculate the indices for build and probe sides based on join type and build side: + let (build_indices, probe_indices) = calculate_indices_by_join_type( + self.build_side, + prune_length, + &self.visited_rows, + self.deleted_offset, + join_type, + )?; + + // Create an empty probe record batch: + let empty_probe_batch = RecordBatch::new_empty(probe_schema); + // Build the final result from the indices of build and probe sides: + build_batch_from_indices( + output_schema.as_ref(), + &self.input_buffer, + &empty_probe_batch, + build_indices, + probe_indices, + column_indices, + self.build_side, + ) + .map(Some) + } else { + // If we don't need to produce a result, return None + Ok(None) + } + } + + /// Prunes the internal buffer. 
+ /// + /// Argument `probe_batch` is used to update the intervals of the sorted + /// filter expressions. The updated build interval determines the new length + /// of the build side. If there are rows to prune, they are removed from the + /// internal buffer. + /// + /// # Arguments + /// + /// * `schema` - The schema of the final output record batch + /// * `probe_batch` - Incoming RecordBatch of the probe side. + /// * `probe_side_sorted_filter_expr` - Probe side mutable sorted filter expression. + /// * `join_type` - The type of join (e.g. inner, left, right, etc.). + /// * `column_indices` - A vector of column indices that specifies which columns from the + /// build side should be included in the output. + /// * `physical_expr_graph` - A mutable reference to the physical expression graph. + /// + /// # Returns + /// + /// If there are rows to prune, returns the pruned build side record batch wrapped in an `Ok` variant. + /// Otherwise, returns `Ok(None)`. + fn prune_with_probe_batch( + &mut self, + schema: &SchemaRef, + probe_batch: &RecordBatch, + probe_side_sorted_filter_expr: &mut SortedFilterExpr, + join_type: JoinType, + column_indices: &[ColumnIndex], + physical_expr_graph: &mut ExprIntervalGraph, + ) -> Result> { + // Check if the input buffer is empty: + if self.input_buffer.num_rows() == 0 { + return Ok(None); + } + // Convert the sorted filter expressions into a vector of (node_index, interval) + // tuples for use when updating the interval graph. 
+ let mut filter_intervals = vec![ + ( + self.sorted_filter_expr.node_index(), + self.sorted_filter_expr.interval().clone(), + ), + ( + probe_side_sorted_filter_expr.node_index(), + probe_side_sorted_filter_expr.interval().clone(), + ), + ]; + // Use the join filter intervals to update the physical expression graph: + physical_expr_graph.update_ranges(&mut filter_intervals)?; + // Get the new join filter interval for build side: + let calculated_build_side_interval = filter_intervals.remove(0).1; + // Check if the intervals changed, exit early if not: + if calculated_build_side_interval.eq(self.sorted_filter_expr.interval()) { + return Ok(None); + } + // Determine the pruning length if there was a change in the intervals: + self.sorted_filter_expr + .set_interval(calculated_build_side_interval); + let prune_length = + determine_prune_length(&self.input_buffer, &self.sorted_filter_expr)?; + // If we can not prune, exit early: + if prune_length == 0 { + return Ok(None); + } + // Compute the result, and perform pruning if there are rows to prune: + let result = self.build_side_determined_results( + schema, + prune_length, + probe_batch.schema(), + join_type, + column_indices, + ); + prune_hash_values( + prune_length, + &mut self.hashmap, + &mut self.row_hash_values, + self.deleted_offset as u64, + )?; + for row in self.deleted_offset..(self.deleted_offset + prune_length) { + self.visited_rows.remove(&row); + } + self.input_buffer = self + .input_buffer + .slice(prune_length, self.input_buffer.num_rows() - prune_length); + self.deleted_offset += prune_length; + result + } +} + +fn combine_two_batches( + output_schema: &SchemaRef, + left_batch: Option, + right_batch: Option, +) -> Result> { + match (left_batch, right_batch) { + (Some(batch), None) | (None, Some(batch)) => { + // If only one of the batches are present, return it: + Ok(Some(batch)) + } + (Some(left_batch), Some(right_batch)) => { + // If both batches are present, concatenate them: + 
concat_batches(output_schema, &[left_batch, right_batch]) + .map_err(DataFusionError::ArrowError) + .map(Some) + } + (None, None) => { + // If neither is present, return an empty batch: + Ok(None) + } + } +} + +impl SymmetricHashJoinStream { + /// Polls the next result of the join operation. + /// + /// If the result of the join is ready, it returns the next record batch. + /// If the join has completed and there are no more results, it returns + /// `Poll::Ready(None)`. If the join operation is not complete, but the + /// current stream is not ready yet, it returns `Poll::Pending`. + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + // If the final result has already been obtained, return `Poll::Ready(None)`: + if self.final_result { + return Poll::Ready(None); + } + // If both streams have been exhausted, return the final result: + if self.right.exhausted && self.left.exhausted { + // Get left side results: + let left_result = self.left.build_side_determined_results( + &self.schema, + self.left.input_buffer.num_rows(), + self.right.input_buffer.schema(), + self.join_type, + &self.column_indices, + )?; + // Get right side results: + let right_result = self.right.build_side_determined_results( + &self.schema, + self.right.input_buffer.num_rows(), + self.left.input_buffer.schema(), + self.join_type, + &self.column_indices, + )?; + self.final_result = true; + // Combine results: + let result = + combine_two_batches(&self.schema, left_result, right_result)?; + // Update the metrics if we have a batch; otherwise, continue the loop. + if let Some(batch) = &result { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Ok(result).transpose()); + } else { + continue; + } + } + + // Determine which stream should be polled next. The side the + // RecordBatch comes from becomes the probe side. 
+ let ( + input_stream, + probe_hash_joiner, + build_hash_joiner, + build_join_side, + probe_side_metrics, + ) = if self.probe_side.eq(&JoinSide::Left) { + ( + &mut self.left_stream, + &mut self.left, + &mut self.right, + JoinSide::Right, + &mut self.metrics.left, + ) + } else { + ( + &mut self.right_stream, + &mut self.right, + &mut self.left, + JoinSide::Left, + &mut self.metrics.right, + ) + }; + // Poll the next batch from `input_stream`: + match input_stream.poll_next_unpin(cx) { + // Batch is available + Poll::Ready(Some(Ok(probe_batch))) => { + // Update the metrics for the stream that was polled: + probe_side_metrics.input_batches.add(1); + probe_side_metrics.input_rows.add(probe_batch.num_rows()); + // Update the internal state of the hash joiner for the build side: + probe_hash_joiner + .update_internal_state(&probe_batch, &self.random_state)?; + // Calculate filter intervals: + calculate_filter_expr_intervals( + &build_hash_joiner.input_buffer, + &mut build_hash_joiner.sorted_filter_expr, + &probe_batch, + &mut probe_hash_joiner.sorted_filter_expr, + )?; + // Join the two sides: + let equal_result = build_hash_joiner.join_with_probe_batch( + &self.schema, + self.join_type, + &probe_hash_joiner.on, + &self.filter, + &probe_batch, + &mut probe_hash_joiner.visited_rows, + probe_hash_joiner.offset, + &self.column_indices, + &self.random_state, + &self.null_equals_null, + )?; + // Increment the offset for the probe hash joiner: + probe_hash_joiner.offset += probe_batch.num_rows(); + // Prune the build side input buffer using the expression + // DAG and filter intervals: + let anti_result = build_hash_joiner.prune_with_probe_batch( + &self.schema, + &probe_batch, + &mut probe_hash_joiner.sorted_filter_expr, + self.join_type, + &self.column_indices, + &mut self.physical_expr_graph, + )?; + // Combine results: + let result = + combine_two_batches(&self.schema, equal_result, anti_result)?; + // Choose next poll side. 
If the other side is not exhausted, + // switch the probe side before returning the result. + if !build_hash_joiner.exhausted { + self.probe_side = build_join_side; + } + // Update the metrics if we have a batch; otherwise, continue the loop. + if let Some(batch) = &result { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Ok(result).transpose()); + } + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + // Mark the probe side exhausted: + probe_hash_joiner.exhausted = true; + // Change the probe side: + self.probe_side = build_join_side; + } + Poll::Pending => { + if !build_hash_joiner.exhausted { + self.probe_side = build_join_side; + } else { + return Poll::Pending; + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::fs::File; + + use arrow::array::ArrayRef; + use arrow::array::{Int32Array, TimestampNanosecondArray}; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::util::pretty::pretty_format_batches; + use rstest::*; + use tempfile::TempDir; + + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{binary, col, Column}; + use datafusion_physical_expr::intervals::test_utils::gen_conjunctive_numeric_expr; + use datafusion_physical_expr::PhysicalExpr; + + use crate::physical_plan::joins::{ + hash_join_utils::tests::complicated_filter, HashJoinExec, PartitionMode, + }; + use crate::physical_plan::{ + collect, common, memory::MemoryExec, repartition::RepartitionExec, + }; + use crate::prelude::{SessionConfig, SessionContext}; + use crate::test_util; + + use super::*; + + const TABLE_SIZE: i32 = 1_000; + + fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { + // compare + let first_formatted = pretty_format_batches(collected_1).unwrap().to_string(); + let second_formatted = pretty_format_batches(collected_2).unwrap().to_string(); + + let mut 
first_formatted_sorted: Vec<&str> = + first_formatted.trim().lines().collect(); + first_formatted_sorted.sort_unstable(); + + let mut second_formatted_sorted: Vec<&str> = + second_formatted.trim().lines().collect(); + second_formatted_sorted.sort_unstable(); + + for (i, (first_line, second_line)) in first_formatted_sorted + .iter() + .zip(&second_formatted_sorted) + .enumerate() + { + assert_eq!((i, first_line), (i, second_line)); + } + } + #[allow(clippy::too_many_arguments)] + async fn partitioned_sym_join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: &JoinType, + null_equals_null: bool, + context: Arc, + ) -> Result> { + let partition_count = 4; + + let left_expr = on + .iter() + .map(|(l, _)| Arc::new(l.clone()) as _) + .collect::>(); + + let right_expr = on + .iter() + .map(|(_, r)| Arc::new(r.clone()) as _) + .collect::>(); + + let join = SymmetricHashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + filter, + join_type, + &null_equals_null, + )?; + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i, context.clone())?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok(batches) + } + #[allow(clippy::too_many_arguments)] + async fn partitioned_hash_join_with_filter( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: &JoinType, + null_equals_null: bool, + context: Arc, + ) -> Result> { + let partition_count = 4; + + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .unzip(); + + let join = HashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, 
partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + Some(filter), + join_type, + PartitionMode::Partitioned, + &null_equals_null, + )?; + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i, context.clone())?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok(batches) + } + + pub fn split_record_batches( + batch: &RecordBatch, + batch_size: usize, + ) -> Result> { + let row_num = batch.num_rows(); + let number_of_batch = row_num / batch_size; + let mut sizes = vec![batch_size; number_of_batch]; + sizes.push(row_num - (batch_size * number_of_batch)); + let mut result = vec![]; + for (i, size) in sizes.iter().enumerate() { + result.push(batch.slice(i * batch_size, *size)); + } + Ok(result) + } + + fn join_expr_tests_fixture( + expr_id: usize, + left_col: Arc, + right_col: Arc, + ) -> Arc { + match expr_id { + // left_col + 1 > right_col + 5 AND left_col + 3 < right_col + 10 + 0 => gen_conjunctive_numeric_expr( + left_col, + right_col, + Operator::Plus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + 1, + 5, + 3, + 10, + ), + // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 + 1 => gen_conjunctive_numeric_expr( + left_col, + right_col, + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + 1, + 5, + 3, + 10, + ), + // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 + 2 => gen_conjunctive_numeric_expr( + left_col, + right_col, + Operator::Minus, + Operator::Plus, + Operator::Minus, + Operator::Plus, + 1, + 5, + 3, + 10, + ), + // left_col - 10 > right_col - 5 AND left_col - 3 < right_col + 10 + 3 => gen_conjunctive_numeric_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + 10, + 5, + 3, + 10, + ), + // left_col - 10 > 
right_col - 5 AND left_col - 30 < right_col - 3 + 4 => gen_conjunctive_numeric_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + 10, + 5, + 30, + 3, + ), + _ => unreachable!(), + } + } + fn build_sides_record_batches( + table_size: i32, + key_cardinality: (i32, i32), + ) -> Result<(RecordBatch, RecordBatch)> { + let null_ratio: f64 = 0.4; + let initial_range = 0..table_size; + let index = (table_size as f64 * null_ratio).round() as i32; + let rest_of = index..table_size; + let ordered: ArrayRef = Arc::new(Int32Array::from_iter( + initial_range.clone().collect::>(), + )); + let ordered_des = Arc::new(Int32Array::from_iter( + initial_range.clone().rev().collect::>(), + )); + let cardinality = Arc::new(Int32Array::from_iter( + initial_range.clone().map(|x| x % 4).collect::>(), + )); + let cardinality_key = Arc::new(Int32Array::from_iter( + initial_range + .clone() + .map(|x| x % key_cardinality.0) + .collect::>(), + )); + let ordered_asc_null_first = Arc::new(Int32Array::from_iter({ + std::iter::repeat(None) + .take(index as usize) + .chain(rest_of.clone().map(Some)) + .collect::>>() + })); + let ordered_asc_null_last = Arc::new(Int32Array::from_iter({ + rest_of + .clone() + .map(Some) + .chain(std::iter::repeat(None).take(index as usize)) + .collect::>>() + })); + + let ordered_desc_null_first = Arc::new(Int32Array::from_iter({ + std::iter::repeat(None) + .take(index as usize) + .chain(rest_of.rev().map(Some)) + .collect::>>() + })); + + let time = Arc::new(TimestampNanosecondArray::from( + initial_range + .map(|x| 1664264591000000000 + (5000000000 * (x as i64))) + .collect::>(), + )); + + let left = RecordBatch::try_from_iter(vec![ + ("la1", ordered.clone()), + ("lb1", cardinality.clone()), + ("lc1", cardinality_key.clone()), + ("lt1", time.clone()), + ("la2", ordered.clone()), + ("la1_des", ordered_des.clone()), + ("l_asc_null_first", ordered_asc_null_first.clone()), + ("l_asc_null_last", 
ordered_asc_null_last.clone()), + ("l_desc_null_first", ordered_desc_null_first.clone()), + ])?; + let right = RecordBatch::try_from_iter(vec![ + ("ra1", ordered.clone()), + ("rb1", cardinality), + ("rc1", cardinality_key), + ("rt1", time), + ("ra2", ordered), + ("ra1_des", ordered_des), + ("r_asc_null_first", ordered_asc_null_first), + ("r_asc_null_last", ordered_asc_null_last), + ("r_desc_null_first", ordered_desc_null_first), + ])?; + Ok((left, right)) + } + + fn create_memory_table( + left_batch: RecordBatch, + right_batch: RecordBatch, + left_sorted: Vec, + right_sorted: Vec, + batch_size: usize, + ) -> Result<(Arc, Arc)> { + Ok(( + Arc::new( + MemoryExec::try_new( + &[split_record_batches(&left_batch, batch_size).unwrap()], + left_batch.schema(), + None, + )? + .with_sort_information(left_sorted), + ), + Arc::new( + MemoryExec::try_new( + &[split_record_batches(&right_batch, batch_size).unwrap()], + right_batch.schema(), + None, + )? + .with_sort_information(right_sorted), + ), + )) + } + + async fn experiment( + left: Arc, + right: Arc, + filter: JoinFilter, + join_type: JoinType, + on: JoinOn, + task_ctx: Arc, + ) -> Result<()> { + let first_batches = partitioned_sym_join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter.clone(), + &join_type, + false, + task_ctx.clone(), + ) + .await?; + let second_batches = partitioned_hash_join_with_filter( + left, right, on, filter, &join_type, false, task_ctx, + ) + .await?; + compare_batches(&first_batches, &second_batches); + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn complex_join_all_one_ascending_numeric( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (31, 71), + (99, 12), + )] + cardinality: (i32, i32), + ) -> Result<()> { + // a + b > c + 10 AND a + b < c + 100 + let 
config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: binary( + col("la1", left_schema)?, + Operator::Plus, + col("la2", left_schema)?, + left_schema, + )?, + options: SortOptions::default(), + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", right_schema)?, + options: SortOptions::default(), + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn join_all_one_ascending_numeric( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (31, 71), + (99, 12), + )] + cardinality: (i32, i32), + #[values(0, 1, 2, 3, 4)] case_expr: usize, + ) -> Result<()> { + let config = 
SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("la1", left_schema)?, + options: SortOptions::default(), + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", right_schema)?, + options: SortOptions::default(), + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn single_test() -> Result<()> { + let case_expr = 1; + let cardinality = (11, 21); + let join_type = JoinType::Full; + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("la1_des", left_schema)?, + options: 
SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1_des", right_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 5, + side: JoinSide::Left, + }, + ColumnIndex { + index: 5, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn join_all_one_descending_numeric_particular( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (31, 71), + (99, 12), + )] + cardinality: (i32, i32), + #[values(0, 1, 2, 3, 4)] case_expr: usize, + ) -> Result<()> { + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("la1_des", left_schema)?, + options: SortOptions { + descending: true, + 
nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1_des", right_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 5, + side: JoinSide::Left, + }, + ColumnIndex { + index: 5, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn join_change_in_planner() -> Result<()> { + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::with_config(config); + let tmp_dir = TempDir::new().unwrap(); + let left_file_path = tmp_dir.path().join("left.csv"); + File::create(left_file_path.clone()).unwrap(); + test_util::test_create_unbounded_sorted_file( + &ctx, + left_file_path.clone(), + "left", + ) + .await?; + let right_file_path = tmp_dir.path().join("right.csv"); + File::create(right_file_path.clone()).unwrap(); + test_util::test_create_unbounded_sorted_file( + &ctx, + right_file_path.clone(), + "right", + ) + .await?; + let df = ctx.sql("EXPLAIN SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; + let physical_plan = df.create_physical_plan().await?; + let task_ctx = ctx.task_ctx(); + let results = collect(physical_plan.clone(), 
task_ctx).await.unwrap(); + let formatted = pretty_format_batches(&results).unwrap().to_string(); + let found = formatted + .lines() + .any(|line| line.contains("SymmetricHashJoinExec")); + assert!(found); + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_first() -> Result<()> { + let join_type = JoinType::Full; + let cardinality = (10, 11); + let case_expr = 1; + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("l_asc_null_first", left_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("r_asc_null_first", right_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 6, + side: JoinSide::Left, + }, + ColumnIndex { + index: 6, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_last() -> Result<()> 
{ + let join_type = JoinType::Full; + let cardinality = (10, 11); + let case_expr = 1; + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("l_asc_null_last", left_schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("r_asc_null_last", right_schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 7, + side: JoinSide::Left, + }, + ColumnIndex { + index: 7, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_first_descending() -> Result<()> { + let join_type = JoinType::Full; + let cardinality = (10, 11); + let case_expr = 1; + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + 
build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("l_desc_null_first", left_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("r_desc_null_first", right_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = join_expr_tests_fixture( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + ); + let column_indices = vec![ + ColumnIndex { + index: 8, + side: JoinSide::Left, + }, + ColumnIndex { + index: 8, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn complex_join_all_one_ascending_numeric_missing_stat() -> Result<()> { + let cardinality = (3, 4); + let join_type = JoinType::Full; + + // a + b > c + 10 AND a + b < c + 100 + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("la1", left_schema)?, + options: SortOptions::default(), + }]; + + let 
right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", right_schema)?, + options: SortOptions::default(), + }]; + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 13)?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let filter_expr = complicated_filter(&intermediate_schema)?; + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 4, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment(left, right, filter, join_type, on, task_ctx).await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn test_one_side_hash_joiner_visited_rows( + #[values( + (JoinType::Inner, true), + (JoinType::Left,false), + (JoinType::Right, true), + (JoinType::RightSemi, true), + (JoinType::LeftSemi, false), + (JoinType::LeftAnti, false), + (JoinType::RightAnti, true), + (JoinType::Full, false), + )] + case: (JoinType, bool), + ) -> Result<()> { + // Set a random state for the join + let join_type = case.0; + let should_be_empty = case.1; + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + // Ensure there will be matching rows + let (left_batch, right_batch) = build_sides_record_batches(20, (1, 1))?; + let left_schema = left_batch.schema(); + let right_schema = right_batch.schema(); + + // Build the join schema from the left and right schemas + let (schema, join_column_indices) = + 
build_join_schema(&left_schema, &right_schema, &join_type); + let join_schema = Arc::new(schema); + + // Sort information for MemoryExec + let left_sorted = vec![PhysicalSortExpr { + expr: col("la1", &left_schema)?, + options: SortOptions::default(), + }]; + // Sort information for MemoryExec + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1", &right_schema)?, + options: SortOptions::default(), + }]; + // Construct MemoryExec + let (left, right) = + create_memory_table(left_batch, right_batch, left_sorted, right_sorted, 10)?; + + // Filter columns, ensure first batches will have matching rows. + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + ]); + let filter_expr = gen_conjunctive_numeric_expr( + col("0", &intermediate_schema)?, + col("1", &intermediate_schema)?, + Operator::Plus, + Operator::Minus, + Operator::Plus, + Operator::Plus, + 0, + 3, + 0, + 3, + ); + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + let left_sorted_filter_expr = SortedFilterExpr::new( + PhysicalSortExpr { + expr: col("la1", &left_schema)?, + options: SortOptions::default(), + }, + Arc::new(Column::new("0", 0)), + ); + let mut left_side_joiner = OneSideHashJoiner::new( + JoinSide::Left, + left_sorted_filter_expr, + vec![Column::new_with_schema("lc1", &left_schema)?], + left_schema, + ); + + let right_sorted_filter_expr = SortedFilterExpr::new( + PhysicalSortExpr { + expr: col("ra1", &right_schema)?, + options: SortOptions::default(), + }, + Arc::new(Column::new("1", 0)), + ); + let mut right_side_joiner = OneSideHashJoiner::new( + JoinSide::Right, + right_sorted_filter_expr, + vec![Column::new_with_schema("rc1", &right_schema)?], + right_schema, + ); + + let mut left_stream = left.execute(0, task_ctx.clone())?; + 
let mut right_stream = right.execute(0, task_ctx)?; + + let initial_left_batch = left_stream.next().await.unwrap()?; + left_side_joiner.update_internal_state(&initial_left_batch, &random_state)?; + assert_eq!( + left_side_joiner.input_buffer.num_rows(), + initial_left_batch.num_rows() + ); + + let initial_right_batch = right_stream.next().await.unwrap()?; + right_side_joiner.update_internal_state(&initial_right_batch, &random_state)?; + assert_eq!( + right_side_joiner.input_buffer.num_rows(), + initial_right_batch.num_rows() + ); + + left_side_joiner.join_with_probe_batch( + &join_schema, + join_type, + &right_side_joiner.on, + &filter, + &initial_right_batch, + &mut right_side_joiner.visited_rows, + right_side_joiner.offset, + &join_column_indices, + &random_state, + &false, + )?; + assert_eq!(left_side_joiner.visited_rows.is_empty(), should_be_empty); + Ok(()) + } +} diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs index 150c8bc9b317..a756d2ba8938 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -17,11 +17,6 @@ //! 
Join related functionality used both on logical and physical plans -use crate::error::{DataFusionError, Result, SharedResult}; -use crate::logical_expr::JoinType; -use crate::physical_plan::expressions::Column; -use crate::physical_plan::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; -use crate::physical_plan::SchemaRef; use arrow::array::{ new_null_array, Array, BooleanBufferBuilder, PrimitiveArray, UInt32Array, UInt32Builder, UInt64Array, @@ -29,22 +24,32 @@ use arrow::array::{ use arrow::compute; use arrow::datatypes::{Field, Schema, UInt32Type, UInt64Type}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; -use datafusion_common::cast::as_boolean_array; -use datafusion_common::ScalarValue; -use datafusion_physical_expr::{EquivalentClass, PhysicalExpr}; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use parking_lot::Mutex; use std::cmp::max; use std::collections::HashSet; +use std::fmt::{Display, Formatter}; use std::future::Future; use std::sync::Arc; use std::task::{Context, Poll}; +use std::usize; + +use datafusion_common::cast::as_boolean_array; +use datafusion_common::{ScalarValue, SharedResult}; + +use datafusion_physical_expr::rewrite::TreeNodeRewritable; +use datafusion_physical_expr::{EquivalentClass, PhysicalExpr}; + +use crate::error::{DataFusionError, Result}; +use crate::logical_expr::JoinType; +use crate::physical_plan::expressions::Column; +use crate::physical_plan::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; +use crate::physical_plan::SchemaRef; use crate::physical_plan::{ ColumnStatistics, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics, }; -use datafusion_physical_expr::rewrite::TreeNodeRewritable; /// The on clause of the join, as vector of (left, right) columns. 
pub type JoinOn = Vec<(Column, Column)>; @@ -221,8 +226,17 @@ pub fn cross_join_equivalence_properties( new_properties } +impl Display for JoinSide { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + JoinSide::Left => write!(f, "left"), + JoinSide::Right => write!(f, "right"), + } + } +} + /// Used in ColumnIndex to distinguish which side the index is for -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum JoinSide { /// Left side of the join Left, @@ -230,6 +244,16 @@ pub enum JoinSide { Right, } +impl JoinSide { + /// Inverse the join side + pub fn negate(&self) -> Self { + match self { + JoinSide::Left => JoinSide::Right, + JoinSide::Right => JoinSide::Left, + } + } +} + /// Information about the index and placement (left or right) of the columns #[derive(Debug, Clone)] pub struct ColumnIndex { @@ -743,26 +767,26 @@ pub(crate) fn get_final_indices_from_bit_map( (left_indices, right_indices) } -/// Use the `left_indices` and `right_indices` to restructure tuples, and apply the `filter` to -/// all of them to get the matched left and right indices. 
pub(crate) fn apply_join_filter_to_indices( - left: &RecordBatch, - right: &RecordBatch, - left_indices: UInt64Array, - right_indices: UInt32Array, + build_input_buffer: &RecordBatch, + probe_batch: &RecordBatch, + build_indices: UInt64Array, + probe_indices: UInt32Array, filter: &JoinFilter, + build_side: JoinSide, ) -> Result<(UInt64Array, UInt32Array)> { - if left_indices.is_empty() && right_indices.is_empty() { - return Ok((left_indices, right_indices)); + if build_indices.is_empty() && probe_indices.is_empty() { + return Ok((build_indices, probe_indices)); }; let intermediate_batch = build_batch_from_indices( filter.schema(), - left, - right, - PrimitiveArray::from(left_indices.data().clone()), - PrimitiveArray::from(right_indices.data().clone()), + build_input_buffer, + probe_batch, + PrimitiveArray::from(build_indices.data().clone()), + PrimitiveArray::from(probe_indices.data().clone()), filter.column_indices(), + build_side, )?; let filter_result = filter .expression() @@ -771,12 +795,11 @@ pub(crate) fn apply_join_filter_to_indices( let mask = as_boolean_array(&filter_result)?; let left_filtered = PrimitiveArray::::from( - compute::filter(&left_indices, mask)?.data().clone(), + compute::filter(&build_indices, mask)?.data().clone(), ); let right_filtered = PrimitiveArray::::from( - compute::filter(&right_indices, mask)?.data().clone(), + compute::filter(&probe_indices, mask)?.data().clone(), ); - Ok((left_filtered, right_filtered)) } @@ -784,16 +807,17 @@ pub(crate) fn apply_join_filter_to_indices( /// The resulting batch has [Schema] `schema`. 
pub(crate) fn build_batch_from_indices( schema: &Schema, - left: &RecordBatch, - right: &RecordBatch, - left_indices: UInt64Array, - right_indices: UInt32Array, + build_input_buffer: &RecordBatch, + probe_batch: &RecordBatch, + build_indices: UInt64Array, + probe_indices: UInt32Array, column_indices: &[ColumnIndex], + build_side: JoinSide, ) -> Result { if schema.fields().is_empty() { let options = RecordBatchOptions::new() .with_match_field_names(true) - .with_row_count(Some(left_indices.len())); + .with_row_count(Some(build_indices.len())); return Ok(RecordBatch::try_new_with_options( Arc::new(schema.clone()), @@ -808,27 +832,24 @@ pub(crate) fn build_batch_from_indices( let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); for column_index in column_indices { - let array = match column_index.side { - JoinSide::Left => { - let array = left.column(column_index.index); - if array.is_empty() || left_indices.null_count() == left_indices.len() { - // Outer join would generate a null index when finding no match at our side. - // Therefore, it's possible we are empty but need to populate an n-length null array, - // where n is the length of the index array. - assert_eq!(left_indices.null_count(), left_indices.len()); - new_null_array(array.data_type(), left_indices.len()) - } else { - compute::take(array.as_ref(), &left_indices, None)? - } + let array = if column_index.side == build_side { + let array = build_input_buffer.column(column_index.index); + if array.is_empty() || build_indices.null_count() == build_indices.len() { + // Outer join would generate a null index when finding no match at our side. + // Therefore, it's possible we are empty but need to populate an n-length null array, + // where n is the length of the index array. + assert_eq!(build_indices.null_count(), build_indices.len()); + new_null_array(array.data_type(), build_indices.len()) + } else { + compute::take(array.as_ref(), &build_indices, None)? 
} - JoinSide::Right => { - let array = right.column(column_index.index); - if array.is_empty() || right_indices.null_count() == right_indices.len() { - assert_eq!(right_indices.null_count(), right_indices.len()); - new_null_array(array.data_type(), right_indices.len()) - } else { - compute::take(array.as_ref(), &right_indices, None)? - } + } else { + let array = probe_batch.column(column_index.index); + if array.is_empty() || probe_indices.null_count() == probe_indices.len() { + assert_eq!(probe_indices.null_count(), probe_indices.len()); + new_null_array(array.data_type(), probe_indices.len()) + } else { + compute::take(array.as_ref(), &probe_indices, None)? } }; columns.push(array); diff --git a/datafusion/core/src/physical_plan/memory.rs b/datafusion/core/src/physical_plan/memory.rs index 557daca61db9..f0cd48fa4f9d 100644 --- a/datafusion/core/src/physical_plan/memory.rs +++ b/datafusion/core/src/physical_plan/memory.rs @@ -45,6 +45,8 @@ pub struct MemoryExec { projected_schema: SchemaRef, /// Optional projection projection: Option>, + // Optional sort information + sort_information: Option>, } impl fmt::Debug for MemoryExec { @@ -77,7 +79,7 @@ impl ExecutionPlan for MemoryExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None + self.sort_information.as_deref() } fn with_new_children( @@ -144,8 +146,18 @@ impl MemoryExec { schema, projected_schema, projection, + sort_information: None, }) } + + /// Set sort information + pub fn with_sort_information( + mut self, + sort_information: Vec, + ) -> Self { + self.sort_information = Some(sort_information); + self + } } /// Iterator over batches diff --git a/datafusion/core/src/test_util.rs b/datafusion/core/src/test_util.rs index 66059d713211..37f3a1667f37 100644 --- a/datafusion/core/src/test_util.rs +++ b/datafusion/core/src/test_util.rs @@ -27,11 +27,13 @@ use crate::datasource::datasource::TableProviderFactory; use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; use 
crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; +use crate::execution::options::ReadOptions; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; +use crate::prelude::{CsvReadOptions, SessionContext}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; @@ -512,3 +514,44 @@ mod tests { assert!(PathBuf::from(res).is_dir()); } } + +/// This function creates an unbounded sorted file for testing purposes. +pub async fn test_create_unbounded_sorted_file( + ctx: &SessionContext, + file_path: PathBuf, + table_name: &str, +) -> datafusion_common::Result<()> { + // Create schema: + let schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::UInt32, false), + Field::new("a2", DataType::UInt32, false), + ])); + // Specify the ordering: + let file_sort_order = [datafusion_expr::col("a1")] + .into_iter() + .map(|e| { + let ascending = true; + let nulls_first = false; + e.sort(ascending, nulls_first) + }) + .collect::>(); + // Mark infinite and provide schema: + let fifo_options = CsvReadOptions::new() + .schema(schema.as_ref()) + .has_header(false) + .mark_infinite(true); + // Get listing options: + let options_sort = fifo_options + .to_listing_options(&ctx.copied_config()) + .with_file_sort_order(Some(file_sort_order)); + // Register table: + ctx.register_listing_table( + table_name, + file_path.as_os_str().to_str().unwrap(), + options_sort, + Some(schema), + None, + ) + .await?; + Ok(()) +} diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs index 7b524454d2ba..d8b66b0a29cd 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo.rs @@ -18,17 +18,23 @@ //! This test demonstrates the DataFusion FIFO capabilities. //! 
#[cfg(not(target_os = "windows"))] +#[cfg(test)] mod unix_test { + use arrow::array::Array; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{ prelude::{CsvReadOptions, SessionConfig, SessionContext}, - test_util::{aggr_test_schema, arrow_test_data}, + test_util::{ + aggr_test_schema, arrow_test_data, test_create_unbounded_sorted_file, + }, }; use datafusion_common::{DataFusionError, Result}; use futures::StreamExt; use itertools::enumerate; use nix::sys::stat; use nix::unistd; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; use rstest::*; use std::fs::{File, OpenOptions}; use std::io::Write; @@ -38,8 +44,10 @@ mod unix_test { use std::sync::mpsc::{Receiver, Sender}; use std::sync::{Arc, Mutex}; use std::thread; + use std::thread::JoinHandle; use std::time::{Duration, Instant}; use tempfile::TempDir; + // ! For the sake of the test, do not alter the numbers. ! // Session batch size const TEST_BATCH_SIZE: usize = 20; @@ -133,7 +141,7 @@ mod unix_test { } // This test provides a relatively realistic end-to-end scenario where - // we ensure that we swap join sides correctly to accommodate a FIFO source. + // we swap join sides to accommodate a FIFO source. 
#[rstest] #[timeout(std::time::Duration::from_secs(30))] #[tokio::test(flavor = "multi_thread", worker_threads = 5)] @@ -147,7 +155,7 @@ mod unix_test { let (tx, rx): (Sender, Receiver) = mpsc::channel(); // Create a new temporary FIFO file let tmp_dir = TempDir::new()?; - let fifo_path = create_fifo_file(&tmp_dir, "fisrt_fifo.csv")?; + let fifo_path = create_fifo_file(&tmp_dir, "first_fifo.csv")?; // Prevent move let fifo_path_thread = fifo_path.clone(); // Timeout for a long period of BrokenPipe error @@ -217,4 +225,112 @@ mod unix_test { assert_eq!(interleave(&result), unbounded_file); Ok(()) } + + #[derive(Debug, PartialEq)] + enum JoinOperation { + LeftUnmatched, + RightUnmatched, + Equal, + } + + // This test provides a relatively realistic end-to-end scenario where + // we change the join into a [SymmetricHashJoin] to accommodate two + // unbounded (FIFO) sources. + #[rstest] + #[timeout(std::time::Duration::from_secs(30))] + #[tokio::test(flavor = "multi_thread")] + async fn unbounded_file_with_symmetric_join() -> Result<()> { + // To make unbounded deterministic + let waiting = Arc::new(Mutex::new(true)); + let thread_bools = vec![waiting.clone(), waiting.clone()]; + // Create a new temporary FIFO file + let tmp_dir = TempDir::new()?; + let file_names = vec!["first_fifo.csv", "second_fifo.csv"]; + // The sender endpoint can be copied + let (threads, file_paths): (Vec>, Vec) = file_names + .iter() + .zip(thread_bools.iter()) + .map(|(file_name, lock)| { + let waiting_thread = lock.clone(); + let fifo_path = create_fifo_file(&tmp_dir, file_name).unwrap(); + let return_path = fifo_path.clone(); + // Timeout for a long period of BrokenPipe error + let broken_pipe_timeout = Duration::from_secs(5); + // Spawn a new thread to write to the FIFO file + let fifo_writer = thread::spawn(move || { + let mut rng = StdRng::seed_from_u64(42); + let file = OpenOptions::new() + .write(true) + .open(fifo_path.clone()) + .unwrap(); + // Reference time to use when deciding to 
fail the test + let execution_start = Instant::now(); + // Join filter + let a1_iter = (0..TEST_DATA_SIZE).map(|x| { + if rng.gen_range(0.0..1.0) < 0.3 { + x - 1 + } else { + x + } + }); + // Join key + let a2_iter = (0..TEST_DATA_SIZE).map(|x| x % 10); + for (cnt, (a1, a2)) in a1_iter.zip(a2_iter).enumerate() { + // Wait a reading sign for unbounded execution + // After first batch FIFO reading, we will wait for a batch created. + while *waiting_thread.lock().unwrap() && TEST_BATCH_SIZE + 1 < cnt + { + thread::sleep(Duration::from_millis(200)); + } + let line = format!("{a1},{a2}\n").to_owned(); + write_to_fifo(&file, &line, execution_start, broken_pipe_timeout) + .unwrap(); + } + }); + (fifo_writer, return_path) + }) + .unzip(); + let config = SessionConfig::new() + .with_batch_size(TEST_BATCH_SIZE) + .set_bool("datafusion.execution.coalesce_batches", false) + .with_target_partitions(1); + let ctx = SessionContext::with_config(config); + test_create_unbounded_sorted_file(&ctx, file_paths[0].clone(), "left").await?; + test_create_unbounded_sorted_file(&ctx, file_paths[1].clone(), "right").await?; + // Execute the query + let df = ctx.sql("SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; + let mut stream = df.execute_stream().await?; + let mut operations = vec![]; + while let Some(Ok(batch)) = stream.next().await { + *waiting.lock().unwrap() = false; + let op = if batch.column(0).null_count() > 0 { + JoinOperation::LeftUnmatched + } else if batch.column(2).null_count() > 0 { + JoinOperation::RightUnmatched + } else { + JoinOperation::Equal + }; + operations.push(op); + } + + // The SymmetricHashJoin executor produces FULL join results at every + // pruning, which happens before it reaches the end of input and more + // than once. 
In this test, we feed partially joinable data to both + // sides in order to ensure that both left/right unmatched results are + // generated more than once during the test. + assert!( + operations + .iter() + .filter(|&n| JoinOperation::RightUnmatched.eq(n)) + .count() + > 1 + && operations + .iter() + .filter(|&n| JoinOperation::LeftUnmatched.eq(n)) + .count() + > 1 + ); + threads.into_iter().for_each(|j| j.join().unwrap()); + Ok(()) + } } diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs index ca6fb75276ff..d6acffdcf4ad 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr/src/operator.rs @@ -110,6 +110,35 @@ impl Operator { } } + /// Return true if the operator is a comparison operator. + /// + /// For example, 'Binary(a, >, b)' would be a comparison expression. + pub fn is_comparison_operator(&self) -> bool { + matches!( + self, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + ) + } + + /// Return true if the operator is a logic operator. + /// + /// For example, 'Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))' would + /// be a logical expression. + pub fn is_logic_operator(&self) -> bool { + matches!(self, Operator::And | Operator::Or) + } + /// Return the operator where swapping lhs and rhs wouldn't change the result. /// /// For example `Binary(50, >=, a)` could also be represented as `Binary(a, <=, 50)`. 
diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index d8cc89d79925..460a58be51ff 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -60,6 +60,7 @@ lazy_static = { version = "^1.4.0" } md-5 = { version = "^0.10.0", optional = true } num-traits = { version = "0.2", default-features = false } paste = "^1.0" +petgraph = "0.6.2" rand = "0.8" regex = { version = "^1.4.3", optional = true } sha2 = { version = "^0.10.1", optional = true } @@ -69,6 +70,7 @@ uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] criterion = "0.4" rand = "0.8" +rstest = "0.16.0" [[bench]] harness = false diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 0024deb3d4d8..48244a0d07d4 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -74,6 +74,8 @@ use arrow::datatypes::{DataType, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; use super::column::Column; +use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; +use crate::intervals::{apply_operator, Interval}; use crate::physical_expr::down_cast_any_ref; use crate::{analysis_expect, AnalysisContext, ExprBoundaries, PhysicalExpr}; use datafusion_common::cast::{as_boolean_array, as_decimal128_array}; @@ -813,6 +815,49 @@ impl PhysicalExpr for BinaryExpr { _ => context.with_boundaries(None), } } + + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + // Get children intervals: + let left_interval = children[0]; + let right_interval = children[1]; + // Calculate current node's interval: + apply_operator(&self.op, left_interval, right_interval) + } + + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + // Get children intervals. 
The graph supplies them in the same order as this expression's children: + let left_interval = children[0]; + let right_interval = children[1]; + let (left, right) = if self.op.is_logic_operator() { + // TODO: Currently, this implementation only supports the AND operator + // and does not require any further propagation. In the future, + // upon adding support for additional logical operators, this + // method will require modification to support propagating the + // changes accordingly. + return Ok(vec![]); + } else if self.op.is_comparison_operator() { + if let Interval { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(false)), + } = interval + { + // TODO: We will handle strictly false clauses by negating + // the comparison operator (e.g. GT to LE, LT to GE) + // once open/closed intervals are supported. + return Ok(vec![]); + } + // Propagate the comparison operator. + propagate_comparison(&self.op, left_interval, right_interval)? + } else { + // Propagate the arithmetic operator. + propagate_arithmetic(&self.op, interval, left_interval, right_interval)? 
+ }; + Ok(vec![left, right]) + } } impl PartialEq for BinaryExpr { diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 4d01131353d8..e3d1b1d67317 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::fmt; use std::sync::Arc; +use crate::intervals::Interval; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::compute; @@ -68,6 +69,10 @@ impl CastExpr { pub fn cast_type(&self) -> &DataType { &self.cast_type } + /// The cast options used when casting + pub fn cast_options(&self) -> &CastOptions { + &self.cast_options + } } impl fmt::Display for CastExpr { @@ -111,6 +116,24 @@ impl PhysicalExpr for CastExpr { }, ))) } + + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + // Cast current node's interval to the right type: + children[0].cast_to(&self.cast_type, &self.cast_options) + } + + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + let child_interval = children[0]; + // Get child's datatype: + let cast_type = child_interval.get_datatype(); + Ok(vec![Some( + interval.cast_to(&cast_type, &self.cast_options)?, + )]) + } } impl PartialEq for CastExpr { diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs new file mode 100644 index 000000000000..302a86cdc927 --- /dev/null +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -0,0 +1,1038 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Constraint propagator/solver for custom PhysicalExpr graphs. + +use std::collections::HashSet; +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use arrow_schema::DataType; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Operator; +use petgraph::graph::NodeIndex; +use petgraph::stable_graph::{DefaultIx, StableGraph}; +use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef}; +use petgraph::Outgoing; + +use crate::expressions::Literal; +use crate::intervals::interval_aritmetic::{apply_operator, Interval}; +use crate::utils::{build_dag, ExprTreeNode}; +use crate::PhysicalExpr; + +// Interval arithmetic provides a way to perform mathematical operations on +// intervals, which represent a range of possible values rather than a single +// point value. This allows for the propagation of ranges through mathematical +// operations, and can be used to compute bounds for a complicated expression. +// The key idea is that by breaking down a complicated expression into simpler +// terms, and then combining the bounds for those simpler terms, one can +// obtain bounds for the overall expression. +// +// For example, consider a mathematical expression such as x^2 + y = 4. Since +// it would be a binary tree in [PhysicalExpr] notation, this type of an +// hierarchical computation is well-suited for a graph based implementation. +// In such an implementation, an equation system f(x) = 0 is represented by a +// directed acyclic expression graph (DAEG). 
+// +// In order to use interval arithmetic to compute bounds for this expression, +// one would first determine intervals that represent the possible values of x +// and y. Let's say that the interval for x is [1, 2] and the interval for y +// is [-3, 1]. In the chart below, you can see how the computation takes place. +// +// This way of using interval arithmetic to compute bounds for a complex +// expression by combining the bounds for the constituent terms within the +// original expression allows us to reason about the range of possible values +// of the expression. This information later can be used in range pruning of +// the provably unnecessary parts of `RecordBatch`es. +// +// References +// 1 - Kabak, Mehmet Ozan. Analog Circuit Start-Up Behavior Analysis: An Interval +// Arithmetic Based Approach, Chapter 4. Stanford University, 2015. +// 2 - Moore, Ramon E. Interval analysis. Vol. 4. Englewood Cliffs: Prentice-Hall, 1966. +// 3 - F. Messine, "Deterministic global optimization using interval constraint +// propagation techniques," RAIRO-Operations Research, vol. 38, no. 04, +// pp. 277{293, 2004. +// +// ``` text +// Computing bounds for an expression using interval arithmetic. Constraint propagation through a top-down evaluation of the expression +// graph using inverse semantics. +// +// [-2, 5] ∩ [4, 4] = [4, 4] [4, 4] +// +-----+ +-----+ +-----+ +-----+ +// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ +// | | | | | | | | | | | | | | | | +// | +-----+ | | +-----+ | | +-----+ | | +-----+ | +// | | | | | | | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | 2 | | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | | 2 | [1, 4] | y | [0, 1]* +// |[.] | | | |[.] | | | |[.] | | | |[.] 
| | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | | | [-3, 1] | +// | | | | +// +---+ +---+ +---+ +---+ +// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [1, 2] +// +---+ +---+ +---+ +---+ +// +// (a) Bottom-up evaluation: Step1 (b) Bottom up evaluation: Step2 (a) Top-down propagation: Step1 (b) Top-down propagation: Step2 +// +// [1 - 3, 4 + 1] = [-2, 5] [1 - 3, 4 + 1] = [-2, 5] +// +-----+ +-----+ +-----+ +-----+ +// +----| + |----+ +----| + |----+ +----| + |----+ +----| + |----+ +// | | | | | | | | | | | | | | | | +// | +-----+ | | +-----+ | | +-----+ | | +-----+ | +// | | | | | | | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | 2 |[1, 4] | y | | 2 |[1, 4] | y | | 2 |[3, 4]** | y | | 2 |[1, 4] | y | +// |[.] | | | |[.] | | | |[.] | | | |[.] | | | +// +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ +// | [-3, 1] | [-3, 1] | [0, 1] | [-3, 1] +// | | | | +// +---+ +---+ +---+ +---+ +// | x | [1, 2] | x | [1, 2] | x | [1, 2] | x | [sqrt(3), 2]*** +// +---+ +---+ +---+ +---+ +// +// (c) Bottom-up evaluation: Step3 (d) Bottom-up evaluation: Step4 (c) Top-down propagation: Step3 (d) Top-down propagation: Step4 +// +// * [-3, 1] ∩ ([4, 4] - [1, 4]) = [0, 1] +// ** [1, 4] ∩ ([4, 4] - [0, 1]) = [3, 4] +// *** [1, 2] ∩ [sqrt(3), sqrt(4)] = [sqrt(3), 2] +// ``` + +/// This object implements a directed acyclic expression graph (DAEG) that +/// is used to compute ranges for expressions through interval arithmetic. +#[derive(Clone)] +pub struct ExprIntervalGraph { + graph: StableGraph, + root: NodeIndex, +} + +/// This object encapsulates all possible constraint propagation results. +#[derive(PartialEq, Debug)] +pub enum PropagationResult { + CannotPropagate, + Infeasible, + Success, +} + +/// This is a node in the DAEG; it encapsulates a reference to the actual +/// [PhysicalExpr] as well as an interval containing expression bounds. 
+#[derive(Clone, Debug)] +pub struct ExprIntervalGraphNode { + expr: Arc, + interval: Interval, +} + +impl Display for ExprIntervalGraphNode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.expr) + } +} + +impl ExprIntervalGraphNode { + /// Constructs a new DAEG node with an [-∞, ∞] range. + pub fn new(expr: Arc) -> Self { + ExprIntervalGraphNode { + expr, + interval: Interval::default(), + } + } + + /// Constructs a new DAEG node with the given range. + pub fn new_with_interval(expr: Arc, interval: Interval) -> Self { + ExprIntervalGraphNode { expr, interval } + } + + /// Get the interval object representing the range of the expression. + pub fn interval(&self) -> &Interval { + &self.interval + } + + /// This function creates a DAEG node from Datafusion's [ExprTreeNode] + /// object. Literals are created with definite, singleton intervals while + /// any other expression starts with an indefinite interval ([-∞, ∞]). + pub fn make_node(node: &ExprTreeNode) -> ExprIntervalGraphNode { + let expr = node.expression().clone(); + if let Some(literal) = expr.as_any().downcast_ref::() { + let value = literal.value(); + let interval = Interval { + lower: value.clone(), + upper: value.clone(), + }; + ExprIntervalGraphNode::new_with_interval(expr, interval) + } else { + ExprIntervalGraphNode::new(expr) + } + } +} + +impl PartialEq for ExprIntervalGraphNode { + fn eq(&self, other: &ExprIntervalGraphNode) -> bool { + self.expr.eq(&other.expr) + } +} + +// This function returns the inverse operator of the given operator. +fn get_inverse_op(op: Operator) -> Operator { + match op { + Operator::Plus => Operator::Minus, + Operator::Minus => Operator::Plus, + _ => unreachable!(), + } +} + +/// This function refines intervals `left_child` and `right_child` by applying +/// constraint propagation through `parent` via operation. The main idea is +/// that we can shrink ranges of variables x and y using parent interval p. 
+/// +/// Assuming that x, y and p have ranges [xL, xU], [yL, yU], and [pL, pU], we +/// apply the following operations: +/// - For plus operation, specifically, we would first do +/// - [xL, xU] <- ([pL, pU] - [yL, yU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([pL, pU] - [xL, xU]) ∩ [yL, yU]. +/// - For minus operation, specifically, we would first do +/// - [xL, xU] <- ([yL, yU] + [pL, pU]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([xL, xU] - [pL, pU]) ∩ [yL, yU]. +pub fn propagate_arithmetic( + op: &Operator, + parent: &Interval, + left_child: &Interval, + right_child: &Interval, +) -> Result<(Option, Option)> { + let inverse_op = get_inverse_op(*op); + // First, propagate to the left: + match apply_operator(&inverse_op, parent, right_child)?.intersect(left_child)? { + // Left is feasible: + Some(value) => { + // Propagate to the right using the new left. + let right = match op { + Operator::Minus => apply_operator(op, &value, parent), + Operator::Plus => apply_operator(&inverse_op, parent, &value), + _ => unreachable!(), + }? + .intersect(right_child)?; + // Return intervals for both children: + Ok((Some(value), right)) + } + // If the left child is infeasible, short-circuit. + None => Ok((None, None)), + } +} + +/// This function provides a target parent interval for comparison operators. +/// If we have expression > 0, expression must have the range [0, ∞]. +/// If we have expression < 0, expression must have the range [-∞, 0]. +/// Currently, we only support strict inequalities since open/closed intervals +/// are not implemented yet. 
+fn comparison_operator_target(datatype: &DataType, op: &Operator) -> Result { + let unbounded = ScalarValue::try_from(datatype)?; + let zero = ScalarValue::new_zero(datatype)?; + Ok(match *op { + Operator::Gt => Interval { + lower: zero, + upper: unbounded, + }, + Operator::Lt => Interval { + lower: unbounded, + upper: zero, + }, + _ => unreachable!(), + }) +} + +/// This function propagates constraints arising from comparison operators. +/// The main idea is that we can analyze an inequality like x > y through the +/// equivalent inequality x - y > 0. Assuming that x and y have ranges [xL, xU] +/// and [yL, yU], we simply apply constraint propagation across [xL, xU], +/// [yL, yU] and [0, ∞]. Specifically, we would first do +/// - [xL, xU] <- ([yL, yU] + [0, ∞]) ∩ [xL, xU], and then +/// - [yL, yU] <- ([xL, xU] - [0, ∞]) ∩ [yL, yU]. +pub fn propagate_comparison( + op: &Operator, + left_child: &Interval, + right_child: &Interval, +) -> Result<(Option, Option)> { + let parent = comparison_operator_target(&left_child.get_datatype(), op)?; + propagate_arithmetic(&Operator::Minus, &parent, left_child, right_child) +} + +impl ExprIntervalGraph { + pub fn try_new(expr: Arc) -> Result { + // Build the full graph: + let (root, graph) = build_dag(expr, &ExprIntervalGraphNode::make_node)?; + Ok(Self { graph, root }) + } + + pub fn node_count(&self) -> usize { + self.graph.node_count() + } + + // Sometimes, we do not want to calculate and/or propagate intervals all the + // way down to leaf expressions. For example, assume that we have a + // `SymmetricHashJoin` which has a child with an output ordering like: + // + // PhysicalSortExpr { + // expr: BinaryExpr('a', +, 'b'), + // sort_option: .. + // } + // + // i.e. its output order comes from a clause like "ORDER BY a + b". In such + // a case, we must calculate the interval for the BinaryExpr('a', +, 'b') + // instead of the columns inside this BinaryExpr, because this interval + // decides whether we prune or not. 
Therefore, children `PhysicalExpr`s of + // this `BinaryExpr` may be pruned for performance. The figure below + // explains this example visually. + // + // Note that we just remove the nodes from the DAEG, do not make any change + // to the plan itself. + // + // ```text + // + // +-----+ +-----+ + // | GT | | GT | + // +--------| |-------+ +--------| |-------+ + // | +-----+ | | +-----+ | + // | | | | + // +-----+ | +-----+ | + // |Cast | | |Cast | | + // | | | --\ | | | + // +-----+ | ---------- +-----+ | + // | | --/ | | + // | | | | + // +-----+ +-----+ +-----+ +-----+ + // +--|Plus |--+ +--|Plus |--+ |Plus | +--|Plus |--+ + // | | | | | | | | | | | | | | + // Prune from here | +-----+ | | +-----+ | +-----+ | +-----+ | + // ------------------------------------ | | | | + // | | | | | | + // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ + // | a | | b | | c | | 2 | | c | | 2 | + // | | | | | | | | | | | | + // +-----+ +-----+ +-----+ +-----+ +-----+ +-----+ + // + // ``` + + /// This function associates stable node indices with [PhysicalExpr]s so + /// that we can match Arc and NodeIndex objects during + /// membership tests. + pub fn gather_node_indices( + &mut self, + exprs: &[Arc], + ) -> Vec<(Arc, usize)> { + let graph = &self.graph; + let mut bfs = Bfs::new(graph, self.root); + // We collect the node indices (usize) of [PhysicalExpr]s in the order + // given by argument `exprs`. To preserve this order, we initialize each + // expression's node index with usize::MAX, and then find the corresponding + // node indices by traversing the graph. 
+ let mut removals = vec![]; + let mut expr_node_indices = exprs + .iter() + .map(|e| (e.clone(), usize::MAX)) + .collect::>(); + while let Some(node) = bfs.next(graph) { + // Get the plan corresponding to this node: + let expr = &graph[node].expr; + // If the current expression is among `exprs`, slate its children + // for removal: + if let Some(value) = exprs.iter().position(|e| expr.eq(e)) { + // Update the node index of the associated `PhysicalExpr`: + expr_node_indices[value].1 = node.index(); + for edge in graph.edges_directed(node, Outgoing) { + // Slate the child for removal, do not remove immediately. + removals.push(edge.id()); + } + } + } + for edge_idx in removals { + self.graph.remove_edge(edge_idx); + } + // Get the set of node indices reachable from the root node: + let connected_nodes = self.connected_nodes(); + // Remove nodes not connected to the root node: + self.graph + .retain_nodes(|_, index| connected_nodes.contains(&index)); + expr_node_indices + } + + /// Returns the set of node indices reachable from the root node via a + /// simple depth-first search. + fn connected_nodes(&self) -> HashSet { + let mut nodes = HashSet::new(); + let mut dfs = Dfs::new(&self.graph, self.root); + while let Some(node) = dfs.next(&self.graph) { + nodes.insert(node); + } + nodes + } + + /// This function assigns given ranges to expressions in the DAEG. + /// The argument `assignments` associates indices of sought expressions + /// with their corresponding new ranges. + pub fn assign_intervals(&mut self, assignments: &[(usize, Interval)]) { + for (index, interval) in assignments { + let node_index = NodeIndex::from(*index as DefaultIx); + self.graph[node_index].interval = interval.clone(); + } + } + + /// This function fetches ranges of expressions from the DAEG. The argument + /// `assignments` associates indices of sought expressions with their ranges, + /// which this function modifies to reflect the intervals in the DAEG. 
+ pub fn update_intervals(&self, assignments: &mut [(usize, Interval)]) { + for (index, interval) in assignments.iter_mut() { + let node_index = NodeIndex::from(*index as DefaultIx); + *interval = self.graph[node_index].interval.clone(); + } + } + + /// Computes bounds for an expression using interval arithmetic via a + /// bottom-up traversal. + /// + /// # Arguments + /// * `leaf_bounds` - &[(usize, Interval)]. Provide NodeIndex, Interval tuples for leaf variables. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use datafusion_common::ScalarValue; + /// use datafusion_expr::Operator; + /// use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; + /// use datafusion_physical_expr::intervals::{Interval, ExprIntervalGraph}; + /// use datafusion_physical_expr::PhysicalExpr; + /// let expr = Arc::new(BinaryExpr::new( + /// Arc::new(Column::new("gnz", 0)), + /// Operator::Plus, + /// Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + /// )); + /// let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + /// // Do it once, while constructing. + /// let node_indices = graph + /// .gather_node_indices(&[Arc::new(Column::new("gnz", 0))]); + /// let left_index = node_indices.get(0).unwrap().1; + /// // Provide intervals for leaf variables (here, there is only one). 
+ /// let intervals = vec![( + /// left_index, + /// Interval { + /// lower: ScalarValue::Int32(Some(10)), + /// upper: ScalarValue::Int32(Some(20)), + /// }, + /// )]; + /// // Evaluate bounds for the composite expression: + /// graph.assign_intervals(&intervals); + /// assert_eq!( + /// graph.evaluate_bounds().unwrap(), + /// &Interval { + /// lower: ScalarValue::Int32(Some(20)), + /// upper: ScalarValue::Int32(Some(30)) + /// } + /// ) + /// + /// ``` + pub fn evaluate_bounds(&mut self) -> Result<&Interval> { + let mut dfs = DfsPostOrder::new(&self.graph, self.root); + while let Some(node) = dfs.next(&self.graph) { + let neighbors = self.graph.neighbors_directed(node, Outgoing); + let mut children_intervals = neighbors + .map(|child| self.graph[child].interval()) + .collect::>(); + // If the current expression is a leaf, its interval should already + // be set externally, just continue with the evaluation procedure: + if !children_intervals.is_empty() { + // Reverse to align with [PhysicalExpr]'s children: + children_intervals.reverse(); + self.graph[node].interval = + self.graph[node].expr.evaluate_bounds(&children_intervals)?; + } + } + Ok(&self.graph[self.root].interval) + } + + /// Updates/shrinks bounds for leaf expressions using interval arithmetic + /// via a top-down traversal. + fn propagate_constraints(&mut self) -> Result { + let mut bfs = Bfs::new(&self.graph, self.root); + while let Some(node) = bfs.next(&self.graph) { + let neighbors = self.graph.neighbors_directed(node, Outgoing); + let mut children = neighbors.collect::>(); + // If the current expression is a leaf, its range is now final. 
+ // So, just continue with the propagation procedure: + if children.is_empty() { + continue; + } + // Reverse to align with [PhysicalExpr]'s children: + children.reverse(); + let children_intervals = children + .iter() + .map(|child| self.graph[*child].interval()) + .collect::>(); + let node_interval = self.graph[node].interval(); + let propagated_intervals = self.graph[node] + .expr + .propagate_constraints(node_interval, &children_intervals)?; + for (child, interval) in children.into_iter().zip(propagated_intervals) { + if let Some(interval) = interval { + self.graph[child].interval = interval; + } else { + // The constraint is infeasible, report: + return Ok(PropagationResult::Infeasible); + } + } + } + Ok(PropagationResult::Success) + } + + /// Updates intervals for all expressions in the DAEG by successive + /// bottom-up and top-down traversals. + pub fn update_ranges( + &mut self, + leaf_bounds: &mut [(usize, Interval)], + ) -> Result { + self.assign_intervals(leaf_bounds); + match self.evaluate_bounds()? 
{ + Interval { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(false)), + } => Ok(PropagationResult::Infeasible), + Interval { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(true)), + } => { + let result = self.propagate_constraints(); + self.update_intervals(leaf_bounds); + result + } + _ => Ok(PropagationResult::CannotPropagate), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::intervals::test_utils::gen_conjunctive_numeric_expr; + use itertools::Itertools; + + use crate::expressions::{BinaryExpr, Column}; + use datafusion_common::ScalarValue; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use rstest::*; + + fn experiment( + expr: Arc, + exprs_with_interval: (Arc, Arc), + left_interval: (Option, Option), + right_interval: (Option, Option), + left_waited: (Option, Option), + right_waited: (Option, Option), + result: PropagationResult, + ) -> Result<()> { + let col_stats = vec![ + ( + exprs_with_interval.0.clone(), + Interval { + lower: ScalarValue::Int32(left_interval.0), + upper: ScalarValue::Int32(left_interval.1), + }, + ), + ( + exprs_with_interval.1.clone(), + Interval { + lower: ScalarValue::Int32(right_interval.0), + upper: ScalarValue::Int32(right_interval.1), + }, + ), + ]; + let expected = vec![ + ( + exprs_with_interval.0.clone(), + Interval { + lower: ScalarValue::Int32(left_waited.0), + upper: ScalarValue::Int32(left_waited.1), + }, + ), + ( + exprs_with_interval.1.clone(), + Interval { + lower: ScalarValue::Int32(right_waited.0), + upper: ScalarValue::Int32(right_waited.1), + }, + ), + ]; + let mut graph = ExprIntervalGraph::try_new(expr)?; + let expr_indexes = graph + .gather_node_indices(&col_stats.iter().map(|(e, _)| e.clone()).collect_vec()); + + let mut col_stat_nodes = col_stats + .iter() + .zip(expr_indexes.iter()) + .map(|((_, interval), (_, index))| (*index, interval.clone())) + .collect_vec(); + let expected_nodes = expected + .iter() + 
.zip(expr_indexes.iter()) + .map(|((_, interval), (_, index))| (*index, interval.clone())) + .collect_vec(); + + let exp_result = graph.update_ranges(&mut col_stat_nodes[..])?; + assert_eq!(exp_result, result); + col_stat_nodes + .iter() + .zip(expected_nodes.iter()) + .for_each(|((_, res), (_, expected))| assert_eq!(res, expected)); + Ok(()) + } + + fn generate_case( + expr: Arc, + left_col: Arc, + right_col: Arc, + seed: u64, + expr_left: i32, + expr_right: i32, + ) -> Result<()> { + let mut r = StdRng::seed_from_u64(seed); + + let (left_interval, right_interval, left_waited, right_waited) = if ASC { + let left = (Some(r.gen_range(0..1000)), None); + let right = (Some(r.gen_range(0..1000)), None); + ( + left, + right, + ( + Some(std::cmp::max(left.0.unwrap(), right.0.unwrap() + expr_left)), + None, + ), + ( + Some(std::cmp::max( + right.0.unwrap(), + left.0.unwrap() + expr_right, + )), + None, + ), + ) + } else { + let left = (None, Some(r.gen_range(0..1000))); + let right = (None, Some(r.gen_range(0..1000))); + ( + left, + right, + ( + None, + Some(std::cmp::min(left.1.unwrap(), right.1.unwrap() + expr_left)), + ), + ( + None, + Some(std::cmp::min( + right.1.unwrap(), + left.1.unwrap() + expr_right, + )), + ), + ) + }; + experiment( + expr, + (left_col, right_col), + left_interval, + right_interval, + left_waited, + right_waited, + PropagationResult::Success, + )?; + Ok(()) + } + + #[test] + fn testing_not_possible() -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark + 5 > right_watermark + + let left_and_1 = Arc::new(BinaryExpr::new( + left_col.clone(), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )); + let expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, right_col.clone())); + experiment( + expr, + (left_col, right_col), + (Some(10), Some(20)), + (Some(100), None), + (Some(10), Some(20)), + (Some(100), None), +
PropagationResult::Infeasible, + )?; + Ok(()) + } + + #[rstest] + #[test] + fn case_1( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark + 1 > right_watermark + 11 AND left_watermark + 3 < right_watermark + 33 + let expr = gen_conjunctive_numeric_expr( + left_col.clone(), + right_col.clone(), + Operator::Plus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + 1, + 11, + 3, + 33, + ); + // l > r + 10 AND r > l - 30 + let l_gt_r = 10; + let r_gt_l = -30; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 10 AND l < r + 30 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; + + Ok(()) + } + #[rstest] + #[test] + fn case_2( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 < right_watermark + 10 + let expr = gen_conjunctive_numeric_expr( + left_col.clone(), + right_col.clone(), + Operator::Minus, + Operator::Plus, + Operator::Plus, + Operator::Plus, + 1, + 5, + 3, + 10, + ); + // l > r + 6 AND r > l - 7 + let l_gt_r = 6; + let r_gt_l = -7; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 6 AND l < r + 7 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; + + Ok(()) + } + + #[rstest] + #[test] + fn case_3( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, + ) -> Result<()> { + let left_col = 
Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 < right_watermark + 10 + let expr = gen_conjunctive_numeric_expr( + left_col.clone(), + right_col.clone(), + Operator::Minus, + Operator::Plus, + Operator::Minus, + Operator::Plus, + 1, + 5, + 3, + 10, + ); + // l > r + 6 AND r > l - 13 + let l_gt_r = 6; + let r_gt_l = -13; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 6 AND l < r + 13 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; + + Ok(()) + } + #[rstest] + #[test] + fn case_4( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark - 10 > right_watermark - 5 AND left_watermark - 3 < right_watermark + 10 + let expr = gen_conjunctive_numeric_expr( + left_col.clone(), + right_col.clone(), + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Plus, + 10, + 5, + 3, + 10, + ); + // l > r + 5 AND r > l - 13 + let l_gt_r = 5; + let r_gt_l = -13; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 5 AND l < r + 13 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; + Ok(()) + } + + #[rstest] + #[test] + fn case_5( + #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64, + ) -> Result<()> { + let left_col = Arc::new(Column::new("left_watermark", 0)); + let right_col = Arc::new(Column::new("right_watermark", 0)); + // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3 + + let expr = 
gen_conjunctive_numeric_expr( + left_col.clone(), + right_col.clone(), + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + 10, + 5, + 30, + 3, + ); + // l > r + 5 AND r > l - 27 + let l_gt_r = 5; + let r_gt_l = -27; + generate_case::( + expr.clone(), + left_col.clone(), + right_col.clone(), + seed, + l_gt_r, + r_gt_l, + )?; + // Descending tests + // r < l - 5 AND l < r + 27 + let r_lt_l = -l_gt_r; + let l_lt_r = -r_gt_l; + generate_case::(expr, left_col, right_col, seed, l_lt_r, r_lt_l)?; + Ok(()) + } + + #[test] + fn test_gather_node_indices_dont_remove() -> Result<()> { + // Expression: a@0 + b@1 + 1 > a@0 - b@1, given a@0 + b@1. + // Do not remove a@0 or b@1, only remove edges since a@0 - b@1 also + // depends on leaf nodes a@0 and b@1. + let left_expr = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )); + + let right_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Minus, + Arc::new(Column::new("b", 1)), + )); + let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); + let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + // Define a test leaf node. + let leaf_node = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )); + // Store the current node count. + let prev_node_count = graph.node_count(); + // Gather the index of node in the expression graph that match the test leaf node. + graph.gather_node_indices(&[leaf_node]); + // Store the final node count. + let final_node_count = graph.node_count(); + // Assert that the final node count is equal the previous node count. + // This means we did not remove any node. 
+ assert_eq!(prev_node_count, final_node_count); + Ok(()) + } + #[test] + fn test_gather_node_indices_remove() -> Result<()> { + // Expression: a@0 + b@1 + 1 > y@0 - z@1, given a@0 + b@1. + // We expect to remove two nodes since we do not need a@ and b@. + let left_expr = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )); + + let right_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("y", 0)), + Operator::Minus, + Arc::new(Column::new("z", 1)), + )); + let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); + let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + // Define a test leaf node. + let leaf_node = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )); + // Store the current node count. + let prev_node_count = graph.node_count(); + // Gather the index of node in the expression graph that match the test leaf node. + graph.gather_node_indices(&[leaf_node]); + // Store the final node count. + let final_node_count = graph.node_count(); + // Assert that the final node count is two less than the previous node + // count; i.e. that we did remove two nodes. + assert_eq!(prev_node_count, final_node_count + 2); + Ok(()) + } + + #[test] + fn test_gather_node_indices_remove_one() -> Result<()> { + // Expression: a@0 + b@1 + 1 > a@0 - z@1, given a@0 + b@1. + // We expect to remove one node since we still need a@0 but not b@1.
+ let left_expr = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )); + + let right_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Minus, + Arc::new(Column::new("z", 1)), + )); + let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); + let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + // Define a test leaf node. + let leaf_node = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )); + // Store the current node count. + let prev_node_count = graph.node_count(); + // Gather the index of node in the expression graph that match the test leaf node. + graph.gather_node_indices(&[leaf_node]); + // Store the final node count. + let final_node_count = graph.node_count(); + // Assert that the final node count is one less than the previous node + // count; i.e. that we did remove one node. + assert_eq!(prev_node_count, final_node_count + 1); + Ok(()) + } + + #[test] + fn test_gather_node_indices_cannot_provide() -> Result<()> { + // Expression: a@0 + 1 + b@1 > y@0 - z@1 -> provide a@0 + b@1 + // TODO: We expect nodes a@0 and b@1 to be pruned, and intervals to be provided from the a@0 + b@1 node. + // However, we do not have an exact node for a@0 + b@1 due to the binary tree structure of the expressions. + // Pruning and interval providing for BinaryExpr expressions are more challenging without exact matches. + // Currently, we only support exact matches for BinaryExprs, but we plan to extend support beyond exact matches in the future.
+ let left_expr = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )); + + let right_expr = Arc::new(BinaryExpr::new( + Arc::new(Column::new("y", 0)), + Operator::Minus, + Arc::new(Column::new("z", 1)), + )); + let expr = Arc::new(BinaryExpr::new(left_expr, Operator::Gt, right_expr)); + let mut graph = ExprIntervalGraph::try_new(expr).unwrap(); + // Define a test leaf node. + let leaf_node = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Column::new("b", 1)), + )); + // Store the current node count. + let prev_node_count = graph.node_count(); + // Gather the index of node in the expression graph that match the test leaf node. + graph.gather_node_indices(&[leaf_node]); + // Store the final node count. + let final_node_count = graph.node_count(); + // Assert that the final node count is equal the previous node count (i.e., no node was pruned). + assert_eq!(prev_node_count, final_node_count); + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs new file mode 100644 index 000000000000..7fc3641b25ef --- /dev/null +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -0,0 +1,533 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Interval arithmetic library + +use std::borrow::Borrow; +use std::fmt; +use std::fmt::{Display, Formatter}; + +use arrow::compute::{cast_with_options, CastOptions}; +use arrow::datatypes::DataType; +use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::Operator; + +use crate::aggregate::min_max::{max, min}; + +/// This type represents an interval, which is used to calculate reliable +/// bounds for expressions. Currently, we only support addition and +/// subtraction, but more capabilities will be added in the future. +/// Upper/lower bounds having NULL values indicate an unbounded side. For +/// example; [10, 20], [10, ∞], [-∞, 100] and [-∞, ∞] are all valid intervals. 
+#[derive(Debug, PartialEq, Clone, Eq, Hash)] +pub struct Interval { + pub lower: ScalarValue, + pub upper: ScalarValue, +} + +impl Default for Interval { + fn default() -> Self { + Interval { + lower: ScalarValue::Null, + upper: ScalarValue::Null, + } + } +} + +impl Display for Interval { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "Interval [{}, {}]", self.lower, self.upper) + } +} + +impl Interval { + pub(crate) fn cast_to( + &self, + data_type: &DataType, + cast_options: &CastOptions, + ) -> Result { + Ok(Interval { + lower: cast_scalar_value(&self.lower, data_type, cast_options)?, + upper: cast_scalar_value(&self.upper, data_type, cast_options)?, + }) + } + + pub(crate) fn get_datatype(&self) -> DataType { + self.lower.get_datatype() + } + + /// Decide if this interval is certainly greater than, possibly greater than, + /// or can't be greater than `other` by returning [true, true], + /// [false, true] or [false, false] respectively. + pub(crate) fn gt(&self, other: &Interval) -> Interval { + let flags = if !self.upper.is_null() + && !other.lower.is_null() + && (self.upper <= other.lower) + { + (false, false) + } else if !self.lower.is_null() + && !other.upper.is_null() + && (self.lower > other.upper) + { + (true, true) + } else { + (false, true) + }; + Interval { + lower: ScalarValue::Boolean(Some(flags.0)), + upper: ScalarValue::Boolean(Some(flags.1)), + } + } + + /// Decide if this interval is certainly less than, possibly less than, + /// or can't be less than `other` by returning [true, true], + /// [false, true] or [false, false] respectively. + pub(crate) fn lt(&self, other: &Interval) -> Interval { + other.gt(self) + } + + /// Decide if this interval is certainly equal to, possibly equal to, + /// or can't be equal to `other` by returning [true, true], + /// [false, true] or [false, false] respectively. 
+ pub(crate) fn equal(&self, other: &Interval) -> Interval { + let flags = if !self.lower.is_null() + && (self.lower == self.upper) + && (other.lower == other.upper) + && (self.lower == other.lower) + { + (true, true) + } else if (!self.lower.is_null() + && !other.upper.is_null() + && (self.lower > other.upper)) + || (!self.upper.is_null() + && !other.lower.is_null() + && (self.upper < other.lower)) + { + (false, false) + } else { + (false, true) + }; + Interval { + lower: ScalarValue::Boolean(Some(flags.0)), + upper: ScalarValue::Boolean(Some(flags.1)), + } + } + + /// Compute the logical conjunction of this (boolean) interval with the + /// given boolean interval. + pub(crate) fn and(&self, other: &Interval) -> Result { + let flags = match (self, other) { + ( + Interval { + lower: ScalarValue::Boolean(Some(lower)), + upper: ScalarValue::Boolean(Some(upper)), + }, + Interval { + lower: ScalarValue::Boolean(Some(other_lower)), + upper: ScalarValue::Boolean(Some(other_upper)), + }, + ) => { + if *lower && *other_lower { + (true, true) + } else if *upper && *other_upper { + (false, true) + } else { + (false, false) + } + } + _ => { + return Err(DataFusionError::Internal( + "Incompatible types for logical conjunction".to_string(), + )) + } + }; + Ok(Interval { + lower: ScalarValue::Boolean(Some(flags.0)), + upper: ScalarValue::Boolean(Some(flags.1)), + }) + } + + /// Compute the intersection of the interval with the given interval. + /// If the intersection is empty, return None. + pub(crate) fn intersect(&self, other: &Interval) -> Result> { + let lower = if self.lower.is_null() { + other.lower.clone() + } else if other.lower.is_null() { + self.lower.clone() + } else { + max(&self.lower, &other.lower)? + }; + let upper = if self.upper.is_null() { + other.upper.clone() + } else if other.upper.is_null() { + self.upper.clone() + } else { + min(&self.upper, &other.upper)? 
+ }; + Ok(if !lower.is_null() && !upper.is_null() && lower > upper { + // This None value signals an empty interval. + None + } else { + Some(Interval { lower, upper }) + }) + } + + // Compute the negation of the interval. + #[allow(dead_code)] + pub(crate) fn arithmetic_negate(&self) -> Result { + Ok(Interval { + lower: self.upper.arithmetic_negate()?, + upper: self.lower.arithmetic_negate()?, + }) + } + + /// Add the given interval (`other`) to this interval. Say we have + /// intervals [a1, b1] and [a2, b2], then their sum is [a1 + a2, b1 + b2]. + /// Note that this represents all possible values the sum can take if + /// one can choose single values arbitrarily from each of the operands. + pub fn add>(&self, other: T) -> Result { + let rhs = other.borrow(); + let lower = if self.lower.is_null() || rhs.lower.is_null() { + ScalarValue::try_from(self.lower.get_datatype()) + } else { + self.lower.add(&rhs.lower) + }?; + let upper = if self.upper.is_null() || rhs.upper.is_null() { + ScalarValue::try_from(self.upper.get_datatype()) + } else { + self.upper.add(&rhs.upper) + }?; + Ok(Interval { lower, upper }) + } + + /// Subtract the given interval (`other`) from this interval. Say we have + /// intervals [a1, b1] and [a2, b2], then their difference is [a1 - b2, b1 - a2]. + /// Note that this represents all possible values the difference can take + /// if one can choose single values arbitrarily from each of the operands. + pub fn sub>(&self, other: T) -> Result { + let rhs = other.borrow(); + let lower = if self.lower.is_null() || rhs.upper.is_null() { + ScalarValue::try_from(self.lower.get_datatype()) + } else { + self.lower.sub(&rhs.upper) + }?; + let upper = if self.upper.is_null() || rhs.lower.is_null() { + ScalarValue::try_from(self.upper.get_datatype()) + } else { + self.upper.sub(&rhs.lower) + }?; + Ok(Interval { lower, upper }) + } +} + +/// Indicates whether interval arithmetic is supported for the given operator.
+pub fn is_operator_supported(op: &Operator) -> bool { + matches!( + op, + &Operator::Plus + | &Operator::Minus + | &Operator::And + | &Operator::Gt + | &Operator::Lt + ) +} + +/// Indicates whether interval arithmetic is supported for the given data type. +pub fn is_datatype_supported(data_type: &DataType) -> bool { + matches!( + data_type, + &DataType::Int64 + | &DataType::Int32 + | &DataType::Int16 + | &DataType::Int8 + | &DataType::UInt64 + | &DataType::UInt32 + | &DataType::UInt16 + | &DataType::UInt8 + ) +} + +pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { + match *op { + Operator::Eq => Ok(lhs.equal(rhs)), + Operator::Gt => Ok(lhs.gt(rhs)), + Operator::Lt => Ok(lhs.lt(rhs)), + Operator::And => lhs.and(rhs), + Operator::Plus => lhs.add(rhs), + Operator::Minus => lhs.sub(rhs), + _ => Ok(Interval { + lower: ScalarValue::Null, + upper: ScalarValue::Null, + }), + } +} + +/// Cast scalar value to the given data type using an arrow kernel. +fn cast_scalar_value( + value: &ScalarValue, + data_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let cast_array = cast_with_options(&value.to_array(), data_type, cast_options)?; + ScalarValue::try_from_array(&cast_array, 0) +} + +#[cfg(test)] +mod tests { + use crate::intervals::Interval; + use datafusion_common::{Result, ScalarValue}; + + #[test] + fn intersect_test() -> Result<()> { + let possible_cases = vec![ + (Some(1000), None, None, None, Some(1000), None), + (None, Some(1000), None, None, None, Some(1000)), + (None, None, Some(1000), None, Some(1000), None), + (None, None, None, Some(1000), None, Some(1000)), + (Some(1000), None, Some(1000), None, Some(1000), None), + ( + None, + Some(1000), + Some(999), + Some(1002), + Some(999), + Some(1000), + ), + (None, Some(1000), Some(1000), None, Some(1000), Some(1000)), // singleton + (None, None, None, None, None, None), + ]; + + for case in possible_cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + 
upper: ScalarValue::Int64(case.1) + } + .intersect(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + })? + .unwrap(), + Interval { + lower: ScalarValue::Int64(case.4), + upper: ScalarValue::Int64(case.5) + } + ) + } + + let empty_cases = vec![ + (None, Some(1000), Some(1001), None), + (Some(1001), None, None, Some(1000)), + (None, Some(1000), Some(1001), Some(1002)), + (Some(1001), Some(1002), None, Some(1000)), + ]; + + for case in empty_cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .intersect(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + })?, + None + ) + } + + Ok(()) + } + + #[test] + fn gt_test() { + let cases = vec![ + (Some(1000), None, None, None, false, true), + (None, Some(1000), None, None, false, true), + (None, None, Some(1000), None, false, true), + (None, None, None, Some(1000), false, true), + (None, Some(1000), Some(1000), None, false, false), + (None, Some(1000), Some(1001), None, false, false), + (Some(1000), None, Some(1000), None, false, true), + (None, Some(1000), Some(1001), Some(1002), false, false), + (None, Some(1000), Some(999), Some(1002), false, true), + (Some(1002), None, Some(999), Some(1002), false, true), + (Some(1003), None, Some(999), Some(1002), true, true), + (None, None, None, None, false, true), + ]; + + for case in cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .gt(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + }), + Interval { + lower: ScalarValue::Boolean(Some(case.4)), + upper: ScalarValue::Boolean(Some(case.5)) + } + ) + } + } + + #[test] + fn lt_test() { + let cases = vec![ + (Some(1000), None, None, None, false, true), + (None, Some(1000), None, None, false, true), + (None, None, Some(1000), None, false, true), + (None, None, None, Some(1000), false, true), + (None, 
Some(1000), Some(1000), None, false, true), + (None, Some(1000), Some(1001), None, true, true), + (Some(1000), None, Some(1000), None, false, true), + (None, Some(1000), Some(1001), Some(1002), true, true), + (None, Some(1000), Some(999), Some(1002), false, true), + (None, None, None, None, false, true), + ]; + + for case in cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .lt(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + }), + Interval { + lower: ScalarValue::Boolean(Some(case.4)), + upper: ScalarValue::Boolean(Some(case.5)) + } + ) + } + } + + #[test] + fn and_test() -> Result<()> { + let cases = vec![ + (false, true, false, false, false, false), + (false, false, false, true, false, false), + (false, true, false, true, false, true), + (false, true, true, true, false, true), + (false, false, false, false, false, false), + (true, true, true, true, true, true), + ]; + + for case in cases { + assert_eq!( + Interval { + lower: ScalarValue::Boolean(Some(case.0)), + upper: ScalarValue::Boolean(Some(case.1)) + } + .and(&Interval { + lower: ScalarValue::Boolean(Some(case.2)), + upper: ScalarValue::Boolean(Some(case.3)) + })?, + Interval { + lower: ScalarValue::Boolean(Some(case.4)), + upper: ScalarValue::Boolean(Some(case.5)) + } + ) + } + Ok(()) + } + + #[test] + fn add_test() -> Result<()> { + let cases = vec![ + (Some(1000), None, None, None, None, None), + (None, Some(1000), None, None, None, None), + (None, None, Some(1000), None, None, None), + (None, None, None, Some(1000), None, None), + (Some(1000), None, Some(1000), None, Some(2000), None), + (None, Some(1000), Some(999), Some(1002), None, Some(2002)), + (None, Some(1000), Some(1000), None, None, None), + ( + Some(2001), + Some(1), + Some(1005), + Some(-999), + Some(3006), + Some(-998), + ), + (None, None, None, None, None, None), + ]; + + for case in cases { + assert_eq!( + Interval { + lower: 
ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .add(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + })?, + Interval { + lower: ScalarValue::Int64(case.4), + upper: ScalarValue::Int64(case.5) + } + ) + } + Ok(()) + } + + #[test] + fn sub_test() -> Result<()> { + let cases = vec![ + (Some(1000), None, None, None, None, None), + (None, Some(1000), None, None, None, None), + (None, None, Some(1000), None, None, None), + (None, None, None, Some(1000), None, None), + (Some(1000), None, Some(1000), None, None, None), + (None, Some(1000), Some(999), Some(1002), None, Some(1)), + (None, Some(1000), Some(1000), None, None, Some(0)), + ( + Some(2001), + Some(1000), + Some(1005), + Some(999), + Some(1002), + Some(-5), + ), + (None, None, None, None, None, None), + ]; + + for case in cases { + assert_eq!( + Interval { + lower: ScalarValue::Int64(case.0), + upper: ScalarValue::Int64(case.1) + } + .sub(&Interval { + lower: ScalarValue::Int64(case.2), + upper: ScalarValue::Int64(case.3) + })?, + Interval { + lower: ScalarValue::Int64(case.4), + upper: ScalarValue::Int64(case.5) + } + ) + } + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/intervals/mod.rs b/datafusion/physical-expr/src/intervals/mod.rs new file mode 100644 index 000000000000..45616534cb17 --- /dev/null +++ b/datafusion/physical-expr/src/intervals/mod.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Interval calculations +//! + +pub mod cp_solver; +pub mod interval_aritmetic; + +pub mod test_utils; +pub use cp_solver::ExprIntervalGraph; +pub use interval_aritmetic::*; diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs new file mode 100644 index 000000000000..ba02f4ff7aac --- /dev/null +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Test utilities for the interval arithmetic library + +use std::sync::Arc; + +use crate::expressions::{BinaryExpr, Literal}; +use crate::PhysicalExpr; +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; + +#[allow(clippy::too_many_arguments)] +/// This test function generates a conjunctive statement with two numeric +/// terms with the following form: +/// left_col (op_1) a > right_col (op_2) b AND left_col (op_3) c < right_col (op_4) d +pub fn gen_conjunctive_numeric_expr( + left_col: Arc, + right_col: Arc, + op_1: Operator, + op_2: Operator, + op_3: Operator, + op_4: Operator, + a: i32, + b: i32, + c: i32, + d: i32, +) -> Arc { + let left_and_1 = Arc::new(BinaryExpr::new( + left_col.clone(), + op_1, + Arc::new(Literal::new(ScalarValue::Int32(Some(a)))), + )); + let left_and_2 = Arc::new(BinaryExpr::new( + right_col.clone(), + op_2, + Arc::new(Literal::new(ScalarValue::Int32(Some(b)))), + )); + + let right_and_1 = Arc::new(BinaryExpr::new( + left_col, + op_3, + Arc::new(Literal::new(ScalarValue::Int32(Some(c)))), + )); + let right_and_2 = Arc::new(BinaryExpr::new( + right_col, + op_4, + Arc::new(Literal::new(ScalarValue::Int32(Some(d)))), + )); + let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, left_and_2)); + let right_expr = Arc::new(BinaryExpr::new(right_and_1, Operator::Lt, right_and_2)); + Arc::new(BinaryExpr::new(left_expr, Operator::And, right_expr)) +} diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 7a2ea6872fa7..c93c8f3c86e0 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -26,6 +26,7 @@ pub mod execution_props; pub mod expressions; pub mod functions; pub mod hash_utils; +pub mod intervals; pub mod math_expressions; mod physical_expr; pub mod planner; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index f4e9593c8264..e2588cc2aa81 100644 --- 
a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -30,6 +30,7 @@ use std::fmt::{Debug, Display}; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, filter_record_batch, is_not_null, SlicesIterator}; +use crate::intervals::Interval; use std::any::Any; use std::sync::Arc; @@ -81,6 +82,27 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { fn analyze(&self, context: AnalysisContext) -> AnalysisContext { context } + + /// Computes bounds for the expression using interval arithmetic. + fn evaluate_bounds(&self, _children: &[&Interval]) -> Result { + Err(DataFusionError::NotImplemented(format!( + "Not implemented for {self}" + ))) + } + + /// Updates/shrinks bounds for the expression using interval arithmetic. + /// If constraint propagation reveals an infeasibility, returns [None] for + /// the child causing infeasibility. If none of the children intervals + /// change, may return an empty vector instead of cloning `children`. + fn propagate_constraints( + &self, + _interval: &Interval, + _children: &[&Interval], + ) -> Result>> { + Err(DataFusionError::NotImplemented(format!( + "Not implemented for {self}" + ))) + } } /// Shared [`PhysicalExpr`]. diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index 612d0e0b8ea0..a80a92bc5a5a 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -16,19 +16,15 @@ // under the License. 
use crate::equivalence::EquivalentClass; -use crate::expressions::BinaryExpr; -use crate::expressions::Column; -use crate::expressions::UnKnownColumn; -use crate::rewrite::RewriteRecursion; -use crate::rewrite::TreeNodeRewritable; -use crate::rewrite::TreeNodeRewriter; -use crate::PhysicalSortExpr; -use crate::{EquivalenceProperties, PhysicalExpr}; -use datafusion_common::DataFusionError; -use datafusion_expr::Operator; - +use crate::expressions::{BinaryExpr, Column, UnKnownColumn}; +use crate::rewrite::{TreeNodeRewritable, TreeNodeRewriter}; +use crate::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::SchemaRef; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::Operator; +use petgraph::graph::NodeIndex; +use petgraph::stable_graph::StableGraph; use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; @@ -239,43 +235,142 @@ pub fn ordering_satisfy_concrete EquivalenceProperties>( } } -/// Extract referenced [`Column`]s within a [`PhysicalExpr`]. -/// -/// This works recursively. 
-pub fn get_phys_expr_columns(pred: &Arc) -> HashSet { - let mut rewriter = ColumnCollector::default(); - pred.clone() - .transform_using(&mut rewriter) - .expect("never fail"); - rewriter.cols +#[derive(Clone, Debug)] +pub struct ExprTreeNode { + expr: Arc, + data: Option, + child_nodes: Vec>, } -#[derive(Debug, Default)] -struct ColumnCollector { - cols: HashSet, +impl ExprTreeNode { + pub fn new(expr: Arc) -> Self { + ExprTreeNode { + expr, + data: None, + child_nodes: vec![], + } + } + + pub fn expression(&self) -> &Arc { + &self.expr + } + + pub fn children(&self) -> Vec> { + self.expr + .children() + .into_iter() + .map(ExprTreeNode::new) + .collect() + } } -impl TreeNodeRewriter> for ColumnCollector { - fn pre_visit( - &mut self, - node: &Arc, - ) -> Result { - if let Some(column) = node.as_any().downcast_ref::() { - self.cols.insert(column.clone()); - } - Ok(RewriteRecursion::Continue) +impl TreeNodeRewritable for ExprTreeNode { + fn map_children(mut self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + self.child_nodes = self + .children() + .into_iter() + .map(transform) + .collect::>>()?; + Ok(self) } +} + +/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a +/// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting +/// identical expressions in one node. Caller specifies the node type in the +/// DAEG via the `constructor` argument, which constructs nodes in the DAEG +/// from the [ExprTreeNode] ancillary object. +struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> T> { + // The resulting DAEG (expression DAG). + graph: StableGraph, + // A vector of visited expression nodes and their corresponding node indices. + visited_plans: Vec<(Arc, NodeIndex)>, + // A function to convert an input expression node to T. 
+ constructor: &'a F, +} +impl<'a, T, F: Fn(&ExprTreeNode) -> T> + TreeNodeRewriter> for PhysicalExprDAEGBuilder<'a, T, F> +{ + // This method mutates an expression node by transforming it to a physical expression + // and adding it to the graph. The method returns the mutated expression node. fn mutate( &mut self, - expr: Arc, - ) -> Result, DataFusionError> { - Ok(expr) + mut node: ExprTreeNode, + ) -> Result> { + // Get the expression associated with the input expression node. + let expr = &node.expr; + + // Check if the expression has already been visited. + let node_idx = match self.visited_plans.iter().find(|(e, _)| expr.eq(e)) { + // If the expression has been visited, return the corresponding node index. + Some((_, idx)) => *idx, + // If the expression has not been visited, add a new node to the graph and + // add edges to its child nodes. Add the visited expression to the vector + // of visited expressions and return the newly created node index. + None => { + let node_idx = self.graph.add_node((self.constructor)(&node)); + for expr_node in node.child_nodes.iter() { + self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); + } + self.visited_plans.push((expr.clone(), node_idx)); + node_idx + } + }; + // Set the data field of the input expression node to the corresponding node index. + node.data = Some(node_idx); + // Return the mutated expression node. + Ok(node) } } +// A function that builds a directed acyclic graph of physical expression trees. +pub fn build_dag( + expr: Arc, + constructor: &F, +) -> Result<(NodeIndex, StableGraph)> +where + F: Fn(&ExprTreeNode) -> T, +{ + // Create a new expression tree node from the input expression. + let init = ExprTreeNode::new(expr); + // Create a new `PhysicalExprDAEGBuilder` instance. + let mut builder = PhysicalExprDAEGBuilder { + graph: StableGraph::::new(), + visited_plans: Vec::<(Arc, NodeIndex)>::new(), + constructor, + }; + // Use the builder to transform the expression tree node into a DAG. 
+ let root = init.transform_using(&mut builder)?; + // Return a tuple containing the root node index and the DAG. + Ok((root.data.unwrap(), builder.graph)) +} + +fn collect_columns_recursive( + expr: &Arc, + columns: &mut HashSet, +) { + if let Some(column) = expr.as_any().downcast_ref::() { + if !columns.iter().any(|c| c.eq(column)) { + columns.insert(column.clone()); + } + } + expr.children() + .iter() + .for_each(|e| collect_columns_recursive(e, columns)) +} + +/// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. +pub fn collect_columns(expr: &Arc) -> HashSet { + let mut columns = HashSet::::new(); + collect_columns_recursive(expr, &mut columns); + columns +} + /// Re-assign column indices referenced in predicate according to given schema. -/// /// This may be helpful when dealing with projections. pub fn reassign_predicate_columns( pred: Arc, @@ -283,19 +378,19 @@ pub fn reassign_predicate_columns( ignore_not_found: bool, ) -> Result, DataFusionError> { let mut rewriter = ColumnAssigner { - schema: schema.clone(), + schema, ignore_not_found, }; - pred.clone().transform_using(&mut rewriter) + pred.transform_using(&mut rewriter) } #[derive(Debug)] -struct ColumnAssigner { - schema: SchemaRef, +struct ColumnAssigner<'a> { + schema: &'a SchemaRef, ignore_not_found: bool, } -impl TreeNodeRewriter> for ColumnAssigner { +impl<'a> TreeNodeRewriter> for ColumnAssigner<'a> { fn mutate( &mut self, expr: Arc, @@ -315,16 +410,122 @@ impl TreeNodeRewriter> for ColumnAssigner { #[cfg(test)] mod tests { - use super::*; - use crate::expressions::Column; + use crate::expressions::{binary, cast, col, lit, Column, Literal}; use crate::PhysicalSortExpr; use arrow::compute::SortOptions; - use datafusion_common::Result; + use datafusion_common::{Result, ScalarValue}; + use std::fmt::{Display, Formatter}; - use arrow_schema::Schema; + use arrow_schema::{DataType, Field, Schema}; + use petgraph::visit::Bfs; use std::sync::Arc; + #[derive(Clone)] + struct 
DummyProperty { + expr_type: String, + } + + /// This is a dummy node in the DAEG; it stores a reference to the actual + /// [PhysicalExpr] as well as a dummy property. + #[derive(Clone)] + struct PhysicalExprDummyNode { + pub expr: Arc, + pub property: DummyProperty, + } + + impl Display for PhysicalExprDummyNode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.expr) + } + } + + fn make_dummy_node(node: &ExprTreeNode) -> PhysicalExprDummyNode { + let expr = node.expression().clone(); + let dummy_property = if expr.as_any().is::() { + "Binary" + } else if expr.as_any().is::() { + "Column" + } else if expr.as_any().is::() { + "Literal" + } else { + "Other" + } + .to_owned(); + PhysicalExprDummyNode { + expr, + property: DummyProperty { + expr_type: dummy_property, + }, + } + } + + #[test] + fn test_build_dag() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + Field::new("2", DataType::Int32, true), + ]); + let expr = binary( + cast( + binary( + col("0", &schema)?, + Operator::Plus, + col("1", &schema)?, + &schema, + )?, + &schema, + DataType::Int64, + )?, + Operator::Gt, + binary( + cast(col("2", &schema)?, &schema, DataType::Int64)?, + Operator::Plus, + lit(ScalarValue::Int64(Some(10))), + &schema, + )?, + &schema, + )?; + let mut vector_dummy_props = vec![]; + let (root, graph) = build_dag(expr, &make_dummy_node)?; + let mut bfs = Bfs::new(&graph, root); + while let Some(node_index) = bfs.next(&graph) { + let node = &graph[node_index]; + vector_dummy_props.push(node.property.clone()); + } + + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Binary") + .count(), + 3 + ); + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Column") + .count(), + 3 + ); + assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Literal") + .count(), + 1 + ); + 
assert_eq!( + vector_dummy_props + .iter() + .filter(|property| property.expr_type == "Other") + .count(), + 2 + ); + Ok(()) + } + #[test] fn expr_list_eq_test() -> Result<()> { let list1: Vec> = vec![