diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs index f84ed7fafe503..86aba232e6a19 100644 --- a/datafusion/core/src/physical_plan/hash_join.rs +++ b/datafusion/core/src/physical_plan/hash_join.rs @@ -76,6 +76,8 @@ use crate::arrow::datatypes::TimeUnit; use crate::execution::context::TaskContext; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::PhysicalExpr; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::try_cast; use log::debug; use std::fmt; @@ -295,7 +297,28 @@ impl ExecutionPlan for HashJoinExec { partition: usize, context: Arc, ) -> Result { - let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); + let (on_left, on_right): (Vec<_>, Vec<_>) = self + .on + .iter() + .map(|on| { + let l = Arc::new(on.0.clone()); + let r = Arc::new(on.1.clone()); + + let lt = l.data_type(&self.left.schema()).unwrap(); + let rt = r.data_type(&self.right.schema()).unwrap(); + let res_type = + datafusion_expr::binary_rule::coerce_types(<, &Operator::Eq, &rt) + .unwrap(); + + let left_cast = + try_cast(l, &self.left.schema(), res_type.clone()).unwrap(); + let right_cast = + try_cast(r, &self.right.schema(), res_type).unwrap(); + + (left_cast, right_cast) + }) + .unzip(); + // we only want to compute the build side once for PartitionMode::CollectLeft let left_data = { match self.mode { @@ -414,7 +437,6 @@ impl ExecutionPlan for HashJoinExec { // over the right that uses this information to issue new batches. let right_stream = self.right.execute(partition, context.clone()).await?; - let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); let num_rows = left_data.1.num_rows(); let visited_left_side = match self.join_type { @@ -473,7 +495,7 @@ impl ExecutionPlan for HashJoinExec { /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, /// assuming that the [RecordBatch] corresponds to the `index`th fn update_hash( - on: &[Column], + on: &[Arc], batch: &RecordBatch, hash_map: &mut JoinHashMap, offset: usize, @@ -512,9 +534,9 @@ struct HashJoinStream { /// Input schema schema: Arc, /// columns from the left - on_left: Vec, + on_left: Vec>, /// columns from the right used to compute the hash - on_right: Vec, + on_right: Vec>, /// type of the join join_type: JoinType, /// information from the left @@ -539,8 +561,8 @@ struct HashJoinStream { impl HashJoinStream { fn new( schema: Arc, - on_left: Vec, - on_right: Vec, + on_left: Vec>, + on_right: Vec>, join_type: JoinType, left_data: JoinLeftData, right: SendableRecordBatchStream, @@ -624,8 +646,8 @@ fn build_batch_from_indices( fn build_batch( batch: &RecordBatch, left_data: &JoinLeftData, - on_left: &[Column], - on_right: &[Column], + on_left: &[Arc], + on_right: &[Arc], join_type: JoinType, schema: &Schema, column_indices: &[ColumnIndex], @@ -691,8 +713,8 @@ fn build_join_indexes( left_data: &JoinLeftData, right: &RecordBatch, join_type: JoinType, - left_on: &[Column], - right_on: &[Column], + left_on: &[Arc], + right_on: &[Arc], random_state: &RandomState, null_equals_null: &bool, ) -> Result<(UInt64Array, UInt32Array)> { @@ -2002,8 +2024,8 @@ mod tests { &left_data, &right, JoinType::Inner, - &[Column::new("a", 0)], - &[Column::new("a", 0)], + &[Arc::new(Column::new("a", 0))], + &[Arc::new(Column::new("a", 0))], &random_state, &false, )?;