diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index daa296b31a3f..eceedb244043 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -21,15 +21,8 @@ use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::sync::Arc; -use crate::error::BallistaError; -use crate::execution_plans::{ - ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec, -}; -use crate::serde::protobuf::repartition_exec_node::PartitionMethod; -use crate::serde::protobuf::ShuffleReaderPartition; -use crate::serde::scheduler::PartitionLocation; -use crate::serde::{from_proto_binary_op, proto_error, protobuf}; -use crate::{convert_box_required, convert_required, into_required}; +use log::debug; + use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::catalog::catalog::{ CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider, @@ -46,7 +39,8 @@ use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunc use datafusion::physical_plan::avro::{AvroExec, AvroReadOptions}; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; -use datafusion::physical_plan::hash_join::PartitionMode; +use datafusion::physical_plan::joins::cross_join::CrossJoinExec; +use datafusion::physical_plan::joins::hash_join::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::parquet::ParquetPartition; use datafusion::physical_plan::planner::DefaultPhysicalPlanner; @@ -56,7 +50,6 @@ use datafusion::physical_plan::window_functions::{ use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec}; use datafusion::physical_plan::{ coalesce_batches::CoalesceBatchesExec, - cross_join::CrossJoinExec, csv::CsvExec, empty::EmptyExec, expressions::{ @@ -65,7 +58,6 @@ use datafusion::physical_plan::{ }, filter::FilterExec, functions::{self, BuiltinScalarFunction, ScalarFunctionExpr}, - hash_join::HashJoinExec, limit::{GlobalLimitExec, LocalLimitExec}, parquet::ParquetExec, projection::ProjectionExec, @@ -78,10 +70,19 @@ use datafusion::physical_plan::{ AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr, }; use datafusion::prelude::CsvReadOptions; -use log::debug; use protobuf::physical_expr_node::ExprType; use protobuf::physical_plan_node::PhysicalPlanType; +use crate::error::BallistaError; +use crate::execution_plans::{ + ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec, +}; +use crate::serde::protobuf::repartition_exec_node::PartitionMethod; +use crate::serde::protobuf::ShuffleReaderPartition; +use crate::serde::scheduler::PartitionLocation; +use crate::serde::{from_proto_binary_op, proto_error, protobuf}; +use crate::{convert_box_required, convert_required, into_required}; + impl TryInto> for &protobuf::PhysicalPlanNode { type Error = BallistaError; diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index 990b92435b79..0e78349ddc33 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -22,6 +22,7 @@ pub mod to_proto; mod roundtrip_tests { use std::{convert::TryInto, sync::Arc}; + use datafusion::physical_plan::joins::hash_join::{HashJoinExec, PartitionMode}; use 
datafusion::{ arrow::{ compute::sort::SortOptions, @@ -34,7 +35,6 @@ mod roundtrip_tests { expressions::{Avg, Column, PhysicalSortExpr}, filter::FilterExec, hash_aggregate::{AggregateMode, HashAggregateExec}, - hash_join::{HashJoinExec, PartitionMode}, limit::{GlobalLimitExec, LocalLimitExec}, sorts::sort::SortExec, AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning, @@ -43,9 +43,10 @@ mod roundtrip_tests { scalar::ScalarValue, }; + use crate::execution_plans::ShuffleWriterExec; + use super::super::super::error::Result; use super::super::protobuf; - use crate::execution_plans::ShuffleWriterExec; fn roundtrip_test(exec_plan: Arc) -> Result<()> { let proto: protobuf::PhysicalPlanNode = exec_plan.clone().try_into()?; diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 175e54550610..420c9a99e773 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -28,7 +28,6 @@ use std::{ use datafusion::logical_plan::JoinType; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion::physical_plan::cross_join::CrossJoinExec; use datafusion::physical_plan::csv::CsvExec; use datafusion::physical_plan::expressions::{ CaseExpr, InListExpr, IsNotNullExpr, IsNullExpr, NegativeExpr, NotExpr, @@ -36,7 +35,8 @@ use datafusion::physical_plan::expressions::{ use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::AggregateMode; -use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode}; +use datafusion::physical_plan::joins::cross_join::CrossJoinExec; +use datafusion::physical_plan::joins::hash_join::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::parquet::{ParquetExec, ParquetPartition}; use datafusion::physical_plan::projection::ProjectionExec; diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index 5c97de9cc588..38b0ce34d5f7 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -56,7 +56,7 @@ use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{BinaryExpr, Column, Literal}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::HashAggregateExec; -use datafusion::physical_plan::hash_join::HashJoinExec; +use datafusion::physical_plan::joins::hash_join::HashJoinExec; use datafusion::physical_plan::parquet::ParquetExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs index ecbe938e30cf..23142a987214 100644 --- a/ballista/rust/scheduler/src/planner.rs +++ b/ballista/rust/scheduler/src/planner.rs @@ -251,7 +251,7 @@ mod test { use ballista_core::serde::protobuf; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; - use datafusion::physical_plan::hash_join::HashJoinExec; + use datafusion::physical_plan::joins::hash_join::HashJoinExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{ coalesce_partitions::CoalescePartitionsExec, projection::ProjectionExec, diff --git 
a/datafusion/src/arrow_dyn_list_array.rs b/datafusion/src/arrow_dyn_list_array.rs
new file mode 100644
index 000000000000..4f2d9b77e236
--- /dev/null
+++ b/datafusion/src/arrow_dyn_list_array.rs
@@ -0,0 +1,149 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! DynMutableListArray from arrow/io/avro/read/nested.rs
+
+use arrow::array::{Array, ListArray, MutableArray, Offset};
+use arrow::bitmap::MutableBitmap;
+use arrow::buffer::MutableBuffer;
+use arrow::datatypes::DataType;
+use arrow::error::ArrowError;
+use std::sync::Arc;
+
+/// Auxiliary struct: a mutable list array whose child values are dynamically typed
+#[derive(Debug)]
+pub struct DynMutableListArray<O: Offset> {
+    data_type: DataType,
+    offsets: MutableBuffer<O>,
+    values: Box<dyn MutableArray>,
+    validity: Option<MutableBitmap>,
+}
+
+impl<O: Offset> DynMutableListArray<O> {
+    pub fn new_from(
+        values: Box<dyn MutableArray>,
+        data_type: DataType,
+        capacity: usize,
+    ) -> Self {
+        let mut offsets = MutableBuffer::<O>::with_capacity(capacity + 1);
+        offsets.push(O::default());
+        assert_eq!(values.len(), 0);
+        ListArray::<O>::get_child_field(&data_type);
+        Self {
+            data_type,
+            offsets,
+            values,
+            validity: None,
+        }
+    }
+
+    /// Creates a new [`DynMutableListArray`] from a [`MutableArray`] and capacity.
+    pub fn new_with_capacity(values: Box<dyn MutableArray>, capacity: usize) -> Self {
+        let data_type = ListArray::<O>::default_datatype(values.data_type().clone());
+        Self::new_from(values, data_type, capacity)
+    }
+
+    /// The values
+    pub fn mut_values(&mut self) -> &mut dyn MutableArray {
+        self.values.as_mut()
+    }
+
+    #[inline]
+    pub fn try_push_valid(&mut self) -> Result<(), ArrowError> {
+        let size = self.values.len();
+        let size = O::from_usize(size).ok_or(ArrowError::KeyOverflowError)?; // todo: make this error
+        assert!(size >= *self.offsets.last().unwrap());
+
+        self.offsets.push(size);
+        if let Some(validity) = &mut self.validity {
+            validity.push(true)
+        }
+        Ok(())
+    }
+
+    #[inline]
+    fn push_null(&mut self) {
+        self.offsets.push(self.last_offset());
+        match &mut self.validity {
+            Some(validity) => validity.push(false),
+            None => self.init_validity(),
+        }
+    }
+
+    #[inline]
+    fn last_offset(&self) -> O {
+        *self.offsets.last().unwrap()
+    }
+
+    fn init_validity(&mut self) {
+        let len = self.offsets.len() - 1;
+
+        let mut validity = MutableBitmap::new();
+        validity.extend_constant(len, true);
+        validity.set(len - 1, false);
+        self.validity = Some(validity)
+    }
+}
+
+impl<O: Offset> MutableArray for DynMutableListArray<O> {
+    fn len(&self) -> usize {
+        self.offsets.len() - 1
+    }
+
+    fn validity(&self) -> Option<&MutableBitmap> {
+        self.validity.as_ref()
+    }
+
+    fn as_box(&mut self) -> Box<dyn Array> {
+        Box::new(ListArray::from_data(
+            self.data_type.clone(),
+            std::mem::take(&mut self.offsets).into(),
+            self.values.as_arc(),
+            std::mem::take(&mut self.validity).map(|x| x.into()),
+        ))
+    }
+
+    fn as_arc(&mut self) -> Arc<dyn Array> {
+        Arc::new(ListArray::from_data(
+            self.data_type.clone(),
+            std::mem::take(&mut self.offsets).into(),
+            self.values.as_arc(),
+            std::mem::take(&mut self.validity).map(|x| x.into()),
+        ))
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+
+    #[inline]
+    fn push_null(&mut self) {
+        self.push_null()
+    }
+
+    fn shrink_to_fit(&mut self) {
+        todo!();
+    }
+}
diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs
index bb90b1703931..2539715a719a 100644
--- a/datafusion/src/lib.rs
+++ b/datafusion/src/lib.rs
@@ -230,6 +230,7 @@ pub mod variable;
 
 // re-export dependencies from arrow-rs to minimise version maintenance for crate users
 pub use arrow;
+mod arrow_dyn_list_array;
 mod arrow_temporal_util;
 
 #[cfg(test)]
diff --git a/datafusion/src/physical_optimizer/coalesce_batches.rs b/datafusion/src/physical_optimizer/coalesce_batches.rs
index 9af8911062df..38624876922c 100644
--- a/datafusion/src/physical_optimizer/coalesce_batches.rs
+++ b/datafusion/src/physical_optimizer/coalesce_batches.rs
@@ -18,15 +18,18 @@
 //! CoalesceBatches optimizer that groups small batches of rows together
 //! into bigger batches to avoid the overhead of many small batches
 
-use super::optimizer::PhysicalOptimizerRule;
+use std::sync::Arc;
+
+use crate::physical_plan::joins::hash_join::HashJoinExec;
 use crate::{
     error::Result,
     physical_plan::{
         coalesce_batches::CoalesceBatchesExec, filter::FilterExec,
-        hash_join::HashJoinExec, repartition::RepartitionExec,
+        repartition::RepartitionExec,
     },
 };
-use std::sync::Arc;
+
+use super::optimizer::PhysicalOptimizerRule;
 
 /// Optimizer that introduces CoalesceBatchesExec to avoid overhead with small batches
 pub struct CoalesceBatches {}
diff --git a/datafusion/src/physical_optimizer/hash_build_probe_order.rs b/datafusion/src/physical_optimizer/hash_build_probe_order.rs
index 0b87ceb1a4e2..f82f14ec1148 100644
--- a/datafusion/src/physical_optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/physical_optimizer/hash_build_probe_order.rs
@@ -20,17 +20,17 @@ use std::sync::Arc;
 
 use arrow::datatypes::Schema;
 
+use crate::error::Result;
 use crate::execution::context::ExecutionConfig;
 use crate::logical_plan::JoinType;
-use crate::physical_plan::cross_join::CrossJoinExec;
 use crate::physical_plan::expressions::Column;
-use crate::physical_plan::hash_join::HashJoinExec;
+use crate::physical_plan::joins::cross_join::CrossJoinExec;
+use crate::physical_plan::joins::hash_join::HashJoinExec;
 use crate::physical_plan::projection::ProjectionExec;
 use crate::physical_plan::{ExecutionPlan, PhysicalExpr};
 
 use super::optimizer::PhysicalOptimizerRule;
 use super::utils::optimize_children;
-use crate::error::Result;
 
 /// BuildProbeOrder reorders the build and probe phase of
 /// hash joins. This uses the number of rows that a datasource has.
@@ -153,16 +153,15 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder {
 
 #[cfg(test)]
 mod tests {
-    use crate::{
-        physical_plan::{hash_join::PartitionMode, Statistics},
-        test::exec::StatisticsExec,
-    };
-
-    use super::*;
     use std::sync::Arc;
 
    use arrow::datatypes::{DataType, Field, Schema};
 
+    use crate::physical_plan::joins::hash_join::PartitionMode;
+    use crate::{physical_plan::Statistics, test::exec::StatisticsExec};
+
+    use super::*;
+
     fn create_big_and_small() -> (Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>) {
         let big = Arc::new(StatisticsExec::new(
             Statistics {
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index e2e849085484..6607144657d8 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -142,6 +142,18 @@ impl PhysicalSortExpr {
     }
 }
 
+/// Converts sort expressions into a Vec<SortColumn> that can be passed to the arrow sort kernel
+pub fn exprs_to_sort_columns(
+    batch: &RecordBatch,
+    expr: &[PhysicalSortExpr],
+) -> Result<Vec<SortColumn>> {
+    expr.iter()
+        .map(|e| e.evaluate_to_sort_column(batch))
+        .collect::<Result<Vec<_>>>()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs
index 494fe3f3dd5b..f5d20202a8c2 100644
--- a/datafusion/src/physical_plan/hash_utils.rs
+++ b/datafusion/src/physical_plan/hash_utils.rs
@@ -17,96 +17,17 @@ //!
Functionality used both on logical and physical plans -use crate::error::{DataFusionError, Result}; +use std::sync::Arc; + pub use ahash::{CallHasher, RandomState}; use arrow::array::{ Array, ArrayRef, BooleanArray, DictionaryArray, DictionaryKey, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, }; -use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use std::collections::HashSet; -use std::sync::Arc; - -use crate::logical_plan::JoinType; -use crate::physical_plan::expressions::Column; - -/// The on clause of the join, as vector of (left, right) columns. -pub type JoinOn = Vec<(Column, Column)>; -/// Reference for JoinOn. -pub type JoinOnRef<'a> = &'a [(Column, Column)]; - -/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. -/// They are valid whenever their columns' intersection equals the set `on` -pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> { - let left: HashSet = left - .fields() - .iter() - .enumerate() - .map(|(idx, f)| Column::new(f.name(), idx)) - .collect(); - let right: HashSet = right - .fields() - .iter() - .enumerate() - .map(|(idx, f)| Column::new(f.name(), idx)) - .collect(); - - check_join_set_is_valid(&left, &right, on) -} - -/// Checks whether the sets left, right and on compose a valid join. -/// They are valid whenever their intersection equals the set `on` -fn check_join_set_is_valid( - left: &HashSet, - right: &HashSet, - on: &[(Column, Column)], -) -> Result<()> { - let on_left = &on.iter().map(|on| on.0.clone()).collect::>(); - let left_missing = on_left.difference(left).collect::>(); - - let on_right = &on.iter().map(|on| on.1.clone()).collect::>(); - let right_missing = on_right.difference(right).collect::>(); - - if !left_missing.is_empty() | !right_missing.is_empty() { - return Err(DataFusionError::Plan(format!( - "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}", - left_missing, - right_missing, - ))); - }; - - let remaining = right - .difference(on_right) - .cloned() - .collect::>(); - - let collisions = left.intersection(&remaining).collect::>(); +use arrow::datatypes::{DataType, TimeUnit}; - if !collisions.is_empty() { - return Err(DataFusionError::Plan(format!( - "The left schema and the right schema have the following columns with the same name without being on the ON statement: {:?}. Consider aliasing them.", - collisions, - ))); - }; - - Ok(()) -} - -/// Creates a schema for a join operation. 
-/// The fields from the left side are first -pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema { - let fields: Vec = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - let left_fields = left.fields().iter(); - let right_fields = right.fields().iter(); - // left then right - left_fields.chain(right_fields).cloned().collect() - } - JoinType::Semi | JoinType::Anti => left.fields().clone(), - }; - Schema::new(fields) -} +use crate::error::{DataFusionError, Result}; // Combines two hashes into one hash #[inline] @@ -597,66 +518,9 @@ mod tests { use arrow::array::TryExtend; use arrow::array::{MutableDictionaryArray, MutableUtf8Array}; - use super::*; - - fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { - let left = left - .iter() - .map(|x| x.to_owned()) - .collect::>(); - let right = right - .iter() - .map(|x| x.to_owned()) - .collect::>(); - check_join_set_is_valid(&left, &right, on) - } - - #[test] - fn check_valid() -> Result<()> { - let left = vec![Column::new("a", 0), Column::new("b1", 1)]; - let right = vec![Column::new("a", 0), Column::new("b2", 1)]; - let on = &[(Column::new("a", 0), Column::new("a", 0))]; - - check(&left, &right, on)?; - Ok(()) - } - - #[test] - fn check_not_in_right() { - let left = vec![Column::new("a", 0), Column::new("b", 1)]; - let right = vec![Column::new("b", 0)]; - let on = &[(Column::new("a", 0), Column::new("a", 0))]; - - assert!(check(&left, &right, on).is_err()); - } - - #[test] - fn check_not_in_left() { - let left = vec![Column::new("b", 0)]; - let right = vec![Column::new("a", 0)]; - let on = &[(Column::new("a", 0), Column::new("a", 0))]; - - assert!(check(&left, &right, on).is_err()); - } + use crate::physical_plan::joins::check_join_set_is_valid; - #[test] - fn check_collision() { - // column "a" would appear both in left and right - let left = vec![Column::new("a", 0), Column::new("c", 1)]; - let right = vec![Column::new("a", 0), Column::new("b", 1)]; - let on = &[(Column::new("a", 0), Column::new("b", 1))]; - - assert!(check(&left, &right, on).is_err()); - } - - #[test] - fn check_in_right() { - let left = vec![Column::new("a", 0), Column::new("c", 1)]; - let right = vec![Column::new("b", 0)]; - let on = &[(Column::new("a", 0), Column::new("b", 0))]; - - assert!(check(&left, &right, on).is_ok()); - } + use super::*; #[test] fn create_hashes_for_float_arrays() -> Result<()> { diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/joins/cross_join.rs similarity index 99% rename from datafusion/src/physical_plan/cross_join.rs rename to datafusion/src/physical_plan/joins/cross_join.rs index 0edc3cc0d1aa..e3f9b1bc36e0 100644 --- a/datafusion/src/physical_plan/cross_join.rs +++ b/datafusion/src/physical_plan/joins/cross_join.rs @@ -18,32 +18,30 @@ //! Defines the cross join plan for loading the left side of the cross join //! 
and producing batches in parallel for the right partitions -use futures::{lock::Mutex, StreamExt}; +use std::time::Instant; use std::{any::Any, sync::Arc, task::Poll}; -use crate::physical_plan::memory::MemoryStream; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; - +use async_trait::async_trait; +use futures::{lock::Mutex, StreamExt}; use futures::{Stream, TryStreamExt}; +use log::debug; -use super::{ - coalesce_partitions::CoalescePartitionsExec, hash_utils::check_join_is_valid, - ColumnStatistics, Statistics, +use crate::physical_plan::joins::check_join_is_valid; +use crate::physical_plan::memory::MemoryStream; +use crate::physical_plan::{ + coalesce_batches::concat_batches, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, +}; +use crate::physical_plan::{ + coalesce_partitions::CoalescePartitionsExec, ColumnStatistics, Statistics, }; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, }; -use async_trait::async_trait; -use std::time::Instant; - -use super::{ - coalesce_batches::concat_batches, DisplayFormatType, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, -}; -use log::debug; /// Data of the left side type JoinLeftData = RecordBatch; diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/joins/hash_join.rs similarity index 93% rename from datafusion/src/physical_plan/hash_join.rs rename to datafusion/src/physical_plan/joins/hash_join.rs index 259cba65db56..9b2508ab4218 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/joins/hash_join.rs @@ -18,46 +18,42 @@ //! Defines the join plan for executing partitions in parallel and then joining the results //! into a set of partitions. 
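[Illustration, not part of the patch: the module doc above describes the classic two-phase hash join. A minimal sketch of that build/probe scheme over plain Rust collections, with a HashMap standing in for DataFusion's RawTable-based hash table; all names here are hypothetical.]

use std::collections::HashMap;

fn hash_join_sketch<'a>(
    build: &[(i32, &'a str)],
    probe: &[(i32, &'a str)],
) -> Vec<(&'a str, &'a str)> {
    // Build phase: index the (usually smaller) build side by join key.
    let mut table: HashMap<i32, Vec<&'a str>> = HashMap::new();
    for &(key, payload) in build {
        table.entry(key).or_default().push(payload);
    }
    // Probe phase: stream the other side and emit one output row per match.
    let mut out = Vec::new();
    for &(key, payload) in probe {
        if let Some(matches) = table.get(&key) {
            for &m in matches {
                out.push((m, payload));
            }
        }
    }
    out
}

[In `CollectLeft` mode the build side is collected once and shared across output partitions; in `Partitioned` mode both inputs are hash-partitioned first so each partition joins independently.]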
-use ahash::RandomState; - -use smallvec::{smallvec, SmallVec}; +use std::fmt; use std::sync::Arc; use std::{any::Any, usize}; use std::{time::Instant, vec}; -use async_trait::async_trait; -use futures::{Stream, StreamExt, TryStreamExt}; -use tokio::sync::Mutex; - +use ahash::RandomState; +use arrow::compute::take; use arrow::datatypes::*; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use arrow::{array::*, buffer::MutableBuffer}; - -use arrow::compute::take; - +use async_trait::async_trait; +use futures::{Stream, StreamExt, TryStreamExt}; use hashbrown::raw::RawTable; +use log::debug; +use smallvec::{smallvec, SmallVec}; +use tokio::sync::Mutex; -use super::{ - coalesce_partitions::CoalescePartitionsExec, - hash_utils::{build_join_schema, check_join_is_valid, JoinOn}, +use crate::error::{DataFusionError, Result}; +use crate::logical_plan::JoinType; +use crate::physical_plan::coalesce_batches::concat_batches; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::joins::{ + build_join_schema, check_join_is_valid, column_indices_from_schema, equal_rows, + ColumnIndex, JoinOn, }; -use super::{ +use crate::physical_plan::PhysicalExpr; +use crate::physical_plan::{ expressions::Column, metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, }; -use super::{hash_utils::create_hashes, Statistics}; -use crate::error::{DataFusionError, Result}; -use crate::logical_plan::JoinType; - -use super::{ +use crate::physical_plan::{hash_utils::create_hashes, Statistics}; +use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; -use crate::physical_plan::coalesce_batches::concat_batches; -use crate::physical_plan::PhysicalExpr; -use log::debug; -use std::fmt; type StringArray = Utf8Array; type LargeStringArray = Utf8Array; @@ -156,14 +152,6 @@ pub enum PartitionMode { CollectLeft, } -/// Information about the index and placement (left or right) of the columns -struct ColumnIndex { - /// Index of the column - index: usize, - /// Whether the column is at the left or right side - is_left: bool, -} - impl HashJoinExec { /// Tries to create a new [HashJoinExec]. 
/// # Error @@ -220,38 +208,6 @@ impl HashJoinExec { pub fn partition_mode(&self) -> &PartitionMode { &self.mode } - - /// Calculates column indices and left/right placement on input / output schemas and jointype - fn column_indices_from_schema(&self) -> ArrowResult> { - let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { - JoinType::Inner - | JoinType::Left - | JoinType::Full - | JoinType::Semi - | JoinType::Anti => (true, self.left.schema(), self.right.schema()), - JoinType::Right => (false, self.right.schema(), self.left.schema()), - }; - let mut column_indices = Vec::with_capacity(self.schema.fields().len()); - for field in self.schema.fields() { - let (is_primary, index) = match primary_schema.index_of(field.name()) { - Ok(i) => Ok((true, i)), - Err(_) => { - match secondary_schema.index_of(field.name()) { - Ok(i) => Ok((false, i)), - _ => Err(DataFusionError::Internal( - format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string() - )) - } - } - }.map_err(DataFusionError::into_arrow_external_error)?; - - let is_left = - is_primary && primary_is_left || !is_primary && !primary_is_left; - column_indices.push(ColumnIndex { index, is_left }); - } - - Ok(column_indices) - } } #[async_trait] @@ -412,7 +368,12 @@ impl ExecutionPlan for HashJoinExec { let right_stream = self.right.execute(partition).await?; let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); - let column_indices = self.column_indices_from_schema()?; + let column_indices = column_indices_from_schema( + &self.join_type, + &self.left.schema(), + &self.right.schema(), + &self.schema, + )?; let num_rows = left_data.1.num_rows(); let visited_left_side = match self.join_type { JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => { @@ -780,57 +741,6 @@ fn build_join_indexes( } } -macro_rules! 
equal_rows_elem { - ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ - let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); - let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); - - match (left_array.is_null($left), right_array.is_null($right)) { - (false, false) => left_array.value($left) == right_array.value($right), - _ => false, - } - }}; -} - -/// Left and right row have equal values -fn equal_rows( - left: usize, - right: usize, - left_arrays: &[ArrayRef], - right_arrays: &[ArrayRef], -) -> Result { - let mut err = None; - let res = left_arrays - .iter() - .zip(right_arrays) - .all(|(l, r)| match l.data_type() { - DataType::Null => true, - DataType::Boolean => equal_rows_elem!(BooleanArray, l, r, left, right), - DataType::Int8 => equal_rows_elem!(Int8Array, l, r, left, right), - DataType::Int16 => equal_rows_elem!(Int16Array, l, r, left, right), - DataType::Int32 => equal_rows_elem!(Int32Array, l, r, left, right), - DataType::Int64 => equal_rows_elem!(Int64Array, l, r, left, right), - DataType::UInt8 => equal_rows_elem!(UInt8Array, l, r, left, right), - DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right), - DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right), - DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right), - DataType::Timestamp(_, None) => { - equal_rows_elem!(Int64Array, l, r, left, right) - } - DataType::Utf8 => equal_rows_elem!(StringArray, l, r, left, right), - DataType::LargeUtf8 => equal_rows_elem!(LargeStringArray, l, r, left, right), - _ => { - // This is internal because we should have caught this before. - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - }); - - err.unwrap_or(Ok(res)) -} - // Produces a batch for left-side rows that have/have not been matched during the whole join fn produce_from_matched( visited_left_side: &[bool], @@ -964,6 +874,8 @@ impl Stream for HashJoinStream { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::{ assert_batches_sorted_eq, physical_plan::{ @@ -973,7 +885,6 @@ mod tests { }; use super::*; - use std::sync::Arc; fn build_table( a: (&str, &Vec), diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs new file mode 100644 index 000000000000..bbf47cb21625 --- /dev/null +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -0,0 +1,359 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
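[Illustration, not part of the patch: intended use of the join helpers defined below. The paths and re-export visibility are assumptions; `check_join_is_valid` accepts a join when every `on` column exists on its side, and `build_join_schema` emits left fields followed by right fields.]

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::logical_plan::JoinType;
use datafusion::physical_plan::expressions::Column;
// hypothetical import paths, assuming these items remain public:
use datafusion::physical_plan::joins::{build_join_schema, check_join_is_valid, JoinOn};

fn validate_join() -> datafusion::error::Result<()> {
    let left = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
    let right = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
    let on: JoinOn = vec![(Column::new("id", 0), Column::new("id", 0))];
    // Valid: both sides contain every column named in `on`.
    check_join_is_valid(&left, &right, &on)?;
    // For an inner join the output schema is left fields, then right fields.
    let schema = build_join_schema(&left, &right, &JoinType::Inner);
    assert_eq!(schema.fields().len(), 2);
    Ok(())
}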
+
+pub mod cross_join;
+pub mod hash_join;
+mod smj_utils;
+pub mod sort_merge_join;
+
+use crate::error::{DataFusionError, Result};
+use crate::logical_plan::JoinType;
+use crate::physical_plan::expressions::Column;
+use arrow::array::ArrayRef;
+use arrow::array::*;
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow::error::Result as ArrowResult;
+use std::cmp::Ordering;
+use std::collections::HashSet;
+use std::sync::Arc;
+
+type StringArray = Utf8Array<i32>;
+type LargeStringArray = Utf8Array<i64>;
+
+/// The on clause of the join, as vector of (left, right) columns.
+pub type JoinOn = Vec<(Column, Column)>;
+/// Reference for JoinOn.
+pub type JoinOnRef<'a> = &'a [(Column, Column)];
+
+/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join.
+/// They are valid whenever their columns' intersection equals the set `on`
+pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
+    let left: HashSet<Column> = left
+        .fields()
+        .iter()
+        .enumerate()
+        .map(|(idx, f)| Column::new(f.name(), idx))
+        .collect();
+    let right: HashSet<Column> = right
+        .fields()
+        .iter()
+        .enumerate()
+        .map(|(idx, f)| Column::new(f.name(), idx))
+        .collect();
+
+    check_join_set_is_valid(&left, &right, on)
+}
+
+/// Checks whether the sets left, right and on compose a valid join.
+/// They are valid whenever their intersection equals the set `on`
+pub fn check_join_set_is_valid(
+    left: &HashSet<Column>,
+    right: &HashSet<Column>,
+    on: &[(Column, Column)],
+) -> Result<()> {
+    let on_left = &on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>();
+    let left_missing = on_left.difference(left).collect::<HashSet<_>>();
+
+    let on_right = &on.iter().map(|on| on.1.clone()).collect::<HashSet<_>>();
+    let right_missing = on_right.difference(right).collect::<HashSet<_>>();
+
+    if !left_missing.is_empty() | !right_missing.is_empty() {
+        return Err(DataFusionError::Plan(format!(
+            "The left or right side of the join is missing columns required by \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}",
+            left_missing,
+            right_missing,
+        )));
+    };
+
+    let remaining = right
+        .difference(on_right)
+        .cloned()
+        .collect::<HashSet<Column>>();
+
+    let collisions = left.intersection(&remaining).collect::<HashSet<_>>();
+
+    if !collisions.is_empty() {
+        return Err(DataFusionError::Plan(format!(
+            "The left schema and the right schema have the following columns with the same name without being on the ON statement: {:?}. Consider aliasing them.",
+            collisions,
+        )));
+    };
+
+    Ok(())
+}
+
+/// Creates a schema for a join operation.
+/// The fields from the left side are first
+pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema {
+    let fields: Vec<Field> = match join_type {
+        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
+            let left_fields = left.fields().iter();
+            let right_fields = right.fields().iter();
+            // left then right
+            left_fields.chain(right_fields).cloned().collect()
+        }
+        JoinType::Semi | JoinType::Anti => left.fields().clone(),
+    };
+    Schema::new(fields)
+}
+
+/// Information about the index and placement (left or right) of the columns
+pub struct ColumnIndex {
+    /// Index of the column
+    pub index: usize,
+    /// Whether the column is at the left or right side
+    pub is_left: bool,
+}
+
+/// Calculates column indices and left/right placement from the input/output schemas and join type
+pub fn column_indices_from_schema(
+    join_type: &JoinType,
+    left_schema: &Arc<Schema>,
+    right_schema: &Arc<Schema>,
+    schema: &Arc<Schema>,
+) -> ArrowResult<Vec<ColumnIndex>> {
+    let (primary_is_left, primary_schema, secondary_schema) = match join_type {
+        JoinType::Inner
+        | JoinType::Left
+        | JoinType::Full
+        | JoinType::Semi
+        | JoinType::Anti => (true, left_schema, right_schema),
+        JoinType::Right => (false, right_schema, left_schema),
+    };
+    let mut column_indices = Vec::with_capacity(schema.fields().len());
+    for field in schema.fields() {
+        let (is_primary, index) = match primary_schema.index_of(field.name()) {
+            Ok(i) => Ok((true, i)),
+            Err(_) => {
+                match secondary_schema.index_of(field.name()) {
+                    Ok(i) => Ok((false, i)),
+                    _ => Err(DataFusionError::Internal(
+                        format!("During execution, the column {} was not found on either the left or right side of the join", field.name())
+                    ))
+                }
+            }
+        }.map_err(DataFusionError::into_arrow_external_error)?;
+
+        let is_left = is_primary && primary_is_left || !is_primary && !primary_is_left;
+        column_indices.push(ColumnIndex { index, is_left });
+    }
+
+    Ok(column_indices)
+}
+
+macro_rules!
equal_rows_elem { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => left_array.value($left) == right_array.value($right), + _ => false, + } + }}; +} + +/// Left and right row have equal values +fn equal_rows( + left: usize, + right: usize, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], +) -> Result { + let mut err = None; + let res = left_arrays + .iter() + .zip(right_arrays) + .all(|(l, r)| match l.data_type() { + DataType::Null => true, + DataType::Boolean => equal_rows_elem!(BooleanArray, l, r, left, right), + DataType::Int8 => equal_rows_elem!(Int8Array, l, r, left, right), + DataType::Int16 => equal_rows_elem!(Int16Array, l, r, left, right), + DataType::Int32 => equal_rows_elem!(Int32Array, l, r, left, right), + DataType::Int64 => equal_rows_elem!(Int64Array, l, r, left, right), + DataType::UInt8 => equal_rows_elem!(UInt8Array, l, r, left, right), + DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right), + DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right), + DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right), + DataType::Timestamp(_, None) => { + equal_rows_elem!(Int64Array, l, r, left, right) + } + DataType::Utf8 => equal_rows_elem!(StringArray, l, r, left, right), + DataType::LargeUtf8 => equal_rows_elem!(LargeStringArray, l, r, left, right), + _ => { + // This is internal because we should have caught this before. + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + }); + + err.unwrap_or(Ok(res)) +} + +macro_rules! cmp_rows_elem { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $res: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => { + let cmp = left_array + .value($left) + .partial_cmp(&right_array.value($right)) + .unwrap(); + if cmp != Ordering::Equal { + $res = cmp; + break; + } + } + _ => unreachable!(), + } + }}; +} + +macro_rules! 
cmp_rows_elem_str { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $res: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => { + let cmp = left_array + .value($left) + .partial_cmp(right_array.value($right)) + .unwrap(); + if cmp != Ordering::Equal { + $res = cmp; + break; + } + } + _ => unreachable!(), + } + }}; +} + +/// compare left row with right row +fn comp_rows( + left: usize, + right: usize, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], +) -> Result { + let mut res = Ordering::Equal; + for (l, r) in left_arrays.iter().zip(right_arrays) { + match l.data_type() { + DataType::Null => {} + DataType::Boolean => cmp_rows_elem!(BooleanArray, l, r, left, right, res), + DataType::Int8 => cmp_rows_elem!(Int8Array, l, r, left, right, res), + DataType::Int16 => cmp_rows_elem!(Int16Array, l, r, left, right, res), + DataType::Int32 => cmp_rows_elem!(Int32Array, l, r, left, right, res), + DataType::Int64 => cmp_rows_elem!(Int64Array, l, r, left, right, res), + DataType::UInt8 => cmp_rows_elem!(UInt8Array, l, r, left, right, res), + DataType::UInt16 => cmp_rows_elem!(UInt16Array, l, r, left, right, res), + DataType::UInt32 => cmp_rows_elem!(UInt32Array, l, r, left, right, res), + DataType::UInt64 => cmp_rows_elem!(UInt64Array, l, r, left, right, res), + DataType::Timestamp(_, None) => { + cmp_rows_elem!(Int64Array, l, r, left, right, res) + } + DataType::Utf8 => cmp_rows_elem_str!(StringArray, l, r, left, right, res), + DataType::LargeUtf8 => { + cmp_rows_elem_str!(LargeStringArray, l, r, left, right, res) + } + _ => { + // This is internal because we should have caught this before. 
+ return Err(DataFusionError::Internal( + "Unsupported data type in sort merge join comparator".to_string(), + )); + } + } + } + + Ok(res) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::TryExtend; + use arrow::array::{MutableDictionaryArray, MutableUtf8Array}; + + use crate::physical_plan::joins::check_join_set_is_valid; + + use super::*; + + fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { + let left = left + .iter() + .map(|x| x.to_owned()) + .collect::>(); + let right = right + .iter() + .map(|x| x.to_owned()) + .collect::>(); + check_join_set_is_valid(&left, &right, on) + } + + #[test] + fn check_valid() -> Result<()> { + let left = vec![Column::new("a", 0), Column::new("b1", 1)]; + let right = vec![Column::new("a", 0), Column::new("b2", 1)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; + + check(&left, &right, on)?; + Ok(()) + } + + #[test] + fn check_not_in_right() { + let left = vec![Column::new("a", 0), Column::new("b", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_not_in_left() { + let left = vec![Column::new("b", 0)]; + let right = vec![Column::new("a", 0)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_collision() { + // column "a" would appear both in left and right + let left = vec![Column::new("a", 0), Column::new("c", 1)]; + let right = vec![Column::new("a", 0), Column::new("b", 1)]; + let on = &[(Column::new("a", 0), Column::new("b", 1))]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_in_right() { + let left = vec![Column::new("a", 0), Column::new("c", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[(Column::new("a", 0), Column::new("b", 0))]; + + assert!(check(&left, &right, on).is_ok()); + } +} diff --git a/datafusion/src/physical_plan/joins/smj_utils.rs b/datafusion/src/physical_plan/joins/smj_utils.rs new file mode 100644 index 000000000000..4614f9442dca --- /dev/null +++ b/datafusion/src/physical_plan/joins/smj_utils.rs @@ -0,0 +1,1190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
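[Illustration, not part of the patch: the file below implements a streaming sort-merge join, with one streamed side and one buffered side that accumulates all batches sharing the current key. A minimal sketch of the underlying merge over two key-sorted slices; the function name is hypothetical.]

use std::cmp::Ordering;

fn merge_join_sketch(streamed: &[i32], buffered: &[i32]) -> Vec<(i32, i32)> {
    let (mut s, mut b) = (0, 0);
    let mut out = Vec::new();
    while s < streamed.len() && b < buffered.len() {
        match streamed[s].cmp(&buffered[b]) {
            Ordering::Less => s += 1,    // advance the streamed cursor
            Ordering::Greater => b += 1, // advance the buffered cursor
            Ordering::Equal => {
                // Both cursors sit on an equal-key run; emit the cross product of
                // the two runs. Buffering these runs across record batches is what
                // BufferingSideBuffer below exists for.
                let key = streamed[s];
                let s_end = s + streamed[s..].iter().take_while(|v| **v == key).count();
                let b_end = b + buffered[b..].iter().take_while(|v| **v == key).count();
                for i in s..s_end {
                    for j in b..b_end {
                        out.push((streamed[i], buffered[j]));
                    }
                }
                s = s_end;
                b = b_end;
            }
        }
    }
    out
}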
+ +use crate::arrow_dyn_list_array::DynMutableListArray; +use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::logical_plan::JoinType; +use crate::physical_plan::expressions::{ + exprs_to_sort_columns, Column, PhysicalSortExpr, +}; +use crate::physical_plan::joins::{comp_rows, equal_rows, ColumnIndex}; +use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; +use arrow::array::*; +use arrow::array::{ArrayRef, MutableArray, MutableBooleanArray}; +use arrow::compute::partition::lexicographical_partition_ranges; +use arrow::datatypes::*; +use arrow::error::ArrowError; +use arrow::error::Result as ArrowResult; +use arrow::record_batch::RecordBatch; +use futures::{Stream, StreamExt}; +use std::cmp::Ordering; +use std::collections::VecDeque; +use std::iter::repeat; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +pub(crate) fn join_arrays(rb: &RecordBatch, on_column: &Vec) -> Vec { + on_column + .iter() + .map(|c| rb.column(c.index()).clone()) + .collect() +} + +#[derive(Clone)] +struct PartitionedRecordBatch { + batch: RecordBatch, + ranges: Vec>, +} + +impl PartitionedRecordBatch { + fn new( + batch: Option, + expr: &[PhysicalSortExpr], + ) -> ArrowResult> { + match batch { + Some(batch) => { + let columns = exprs_to_sort_columns(&batch, expr) + .map_err(DataFusionError::into_arrow_external_error)?; + let ranges = lexicographical_partition_ranges( + &columns.iter().map(|x| x.into()).collect::>(), + )? + .collect::>(); + Ok(Some(Self { batch, ranges })) + } + None => Ok(None), + } + } + + #[inline] + pub fn is_last_range(&self, range: &Range) -> bool { + range.end == self.batch.num_rows() + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Streaming Side +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct StreamingSideBuffer { + batch: Option, + cur_row: usize, + cur_range: usize, + num_rows: usize, + num_ranges: usize, + is_new_key: bool, + on_column: Vec, +} + +impl StreamingSideBuffer { + fn new(on_column: Vec) -> Self { + Self { + batch: None, + cur_row: 0, + cur_range: 0, + num_rows: 0, + num_ranges: 0, + is_new_key: true, + on_column, + } + } + + fn join_arrays(&self) -> Vec { + join_arrays(&self.batch.unwrap().batch, &self.on_column) + } + + fn reset(&mut self, prb: Option) { + self.batch = prb; + if let Some(prb) = &self.batch { + self.cur_row = 0; + self.cur_range = 0; + self.num_rows = prb.batch.num_rows(); + self.num_ranges = prb.ranges.len(); + self.is_new_key = true; + }; + } + + fn key_any_null(&self) -> bool { + match &self.batch { + None => return true, + Some(batch) => { + for c in self.on_column { + let array = batch.batch.column(c.index()); + if array.is_null(self.cur_row) { + return true; + } + } + false + } + } + } + + #[inline] + fn is_finished(&self) -> bool { + self.batch.is_none() || self.num_rows == self.cur_row + 1 + } + + #[inline] + fn is_last_key_in_batch(&self) -> bool { + self.batch.is_none() || self.num_ranges == self.cur_range + 1 + } + + fn advance(&mut self) { + self.cur_row += 1; + self.is_new_key = false; + if !self.is_last_key_in_batch() { + let ranges = self.batch.unwrap().ranges; + if self.cur_row == ranges[self.cur_range + 1].start { + self.cur_range += 1; + self.is_new_key = true; + } + } + } + + fn advance_key(&mut self) { + let ranges = self.batch.unwrap().ranges; + self.cur_range += 1; + self.cur_row = 
ranges[self.cur_range].start; + self.is_new_key = true; + } + + fn repeat_cell( + &self, + times: usize, + to: &mut Box, + column_index: &ColumnIndex, + ) { + repeat_cell( + &self.batch.unwrap().batch, + self.cur_row, + times, + to, + column_index, + ); + } + + fn copy_slices( + &self, + slice: &Slice, + array: &mut Box, + column_index: &ColumnIndex, + ) { + let batches = vec![&self.batch.unwrap().batch]; + let slices = vec![slice.clone()]; + copy_slices(&batches, &slices, array, column_index); + } +} + +struct StreamingSideStream { + input: SendableRecordBatchStream, + buffer: StreamingSideBuffer, + input_is_finished: bool, + sort: Vec, +} + +impl StreamingSideStream { + fn new( + input: SendableRecordBatchStream, + on: Vec, + sort: Vec, + ) -> Self { + let buffer = StreamingSideBuffer::new(on); + Self { + input, + buffer, + input_is_finished: false, + sort, + } + } + + fn input_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.input_is_finished { + Poll::Ready(None) + } else { + match self.input.poll_next_unpin(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + None => { + self.input_is_finished = true; + Poll::Ready(None) + } + batch => { + let batch = batch.transpose()?; + let prb = PartitionedRecordBatch::new(batch, &self.sort)?; + self.buffer.reset(prb); + Poll::Ready(Some(Ok(()))) + } + }, + } + } + } + + fn advance( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.buffer.is_finished() && self.input_is_finished { + Poll::Ready(None) + } else { + if self.buffer.is_finished() { + match self.input_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + None => Poll::Ready(None), + Some(x) => { + x?; + if self.buffer.is_finished() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(()))) + } + } + }, + } + } else { + self.buffer.advance(); + Poll::Ready(Some(Ok(()))) + } + } + } + + fn advance_key( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.buffer.is_finished() && self.input_is_finished { + Poll::Ready(None) + } else { + if self.buffer.is_finished() || self.buffer.is_last_key_in_batch() { + match self.input_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + None => Poll::Ready(None), + Some(x) => { + x?; + if self.buffer.is_finished() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(()))) + } + } + }, + } + } else { + self.buffer.advance_key(); + Poll::Ready(Some(Ok(()))) + } + } + } + + fn advance_key_skip_null( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.buffer.is_finished() && self.input_is_finished { + Poll::Ready(None) + } else { + loop { + match self.advance_key(cx) { + Poll::Ready(x) => match x { + None => return Poll::Ready(None), + Some(x) => { + x?; + if !self.buffer.key_any_null() { + return Poll::Ready(Some(Ok(()))); + } + } + }, + Poll::Pending => return Poll::Pending, + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Buffering Side +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Holding ranges for same key over several bathes +struct BufferingSideBuffer { + /// batches that contains the current key + /// TODO: make this spillable as well for skew on join key at buffer side + batches: VecDeque, + /// ranges in each PartitionedRecordBatch that contains the current key + ranges: VecDeque>, + /// row index in first batch to the record that starts this 
batch + key_idx: Option, + /// total number of rows for the current key + row_num: usize, + /// hold found but not currently used batch, to continue iteration + next_key_batch: Option, + /// Join on column + on_column: Vec, +} + +impl BufferingSideBuffer { + fn new(on_column: Vec) -> Self { + Self { + batches: VecDeque::new(), + ranges: VecDeque::new(), + key_idx: None, + row_num: 0, + next_key_batch: None, + on_column, + } + } + + fn join_arrays(&self) -> Vec { + join_arrays(&self.batches[0].batch, &self.on_column) + } + + fn key_any_null(&self) -> bool { + match &self.key_idx { + None => return true, + Some(key_idx) => { + let first_batch = &self.batches[0].batch; + for c in self.on_column { + let array = first_batch.column(c.index()); + if array.is_null(*key_idx) { + return true; + } + } + false + } + } + } + + fn is_finished(&self) -> ArrowResult { + match self.key_idx { + None => Ok(true), + Some(_) => match (self.batches.back(), self.ranges.back()) { + (Some(batch), Some(range)) => Ok(batch.is_last_range(range)), + _ => Err(ArrowError::Other(format!( + "Batches length {} not equal to ranges length {}", + self.batches.len(), + self.ranges.len() + ))), + }, + } + } + + /// Whether the running key ends at the current batch `prb`, true for continues, false for ends. + fn running_key(&mut self, prb: &PartitionedRecordBatch) -> ArrowResult { + let first_range = &prb.ranges[0]; + let range_len = first_range.len(); + let current_batch = &prb.batch; + let single_range = prb.ranges.len() == 1; + + // compare the first record in batch with the current key pointed by key_idx + match self.key_idx { + None => { + self.batches.push_back(prb.clone()); + self.ranges.push_back(first_range.clone()); + self.key_idx = Some(0); + self.row_num += range_len; + Ok(single_range) + } + Some(key_idx) => { + let key_arrays = join_arrays(&self.batches[0].batch, &self.on_column); + let current_arrays = join_arrays(current_batch, &self.on_column); + let equal = equal_rows(key_idx, 0, &key_arrays, ¤t_arrays) + .map_err(DataFusionError::into_arrow_external_error)?; + if equal { + self.batches.push_back(prb.clone()); + self.ranges.push_back(first_range.clone()); + self.row_num += range_len; + Ok(single_range) + } else { + self.next_key_batch = Some(prb.clone()); + Ok(false) // running key ends + } + } + } + } + + fn cleanup(&mut self) { + self.batches.drain(..); + self.ranges.drain(..); + self.next_key_batch = None; + } + + fn reset(&mut self, prb: &PartitionedRecordBatch) { + self.cleanup(); + self.batches.push_back(prb.clone()); + let first_range = &prb.ranges[0]; + self.ranges.push_back(first_range.clone()); + self.key_idx = Some(0); + self.row_num = first_range.len(); + } + + /// Advance the cursor to the next key seen by this buffer + fn advance_key(&mut self) { + assert_eq!(self.batches.len(), self.ranges.len()); + if self.batches.len() > 1 { + self.batches.drain(0..(self.batches.len() - 1)); + self.ranges.drain(0..(self.batches.len() - 1)); + } + + if let Some(batch) = self.batches.pop_back() { + let tail_range = self.ranges.pop_back().unwrap(); + self.batches.push_back(batch); + let next_range_idx = batch + .ranges + .iter() + .enumerate() + .find(|(idx, range)| range.start == tail_range.start) + .unwrap() + .0; + self.key_idx = Some(tail_range.end); + self.ranges.push_back(batch.ranges[next_range_idx].clone()); + self.row_num = batch.ranges[next_range_idx].len(); + } + } + + /// Locate the starting idx for each of the ranges in the current buffer. 
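+    /// (Editor's illustration, not in the original patch: buffered ranges of
+    /// lengths 3, 2 and 4 yield start indices [0, 3, 5] plus a trailing
+    /// usize::MAX sentinel, which lets `slices_from_batches` locate the range
+    /// containing any global row index with a linear scan.)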
+ fn range_start_indices(&self) -> Vec { + let mut idx = 0; + let mut start_indices: Vec = vec![]; + self.ranges.iter().for_each(|r| { + start_indices.push(idx); + idx += r.len(); + }); + start_indices.push(usize::MAX); + start_indices + } + + /// Locate buffered records start from `buffered_idx` of `len`gth + fn slices_from_batches( + &self, + start_indices: &Vec, + buffered_idx: usize, + len: usize, + ) -> Vec { + let ranges = &self.ranges; + let mut idx = buffered_idx; + let mut slices: Vec = vec![]; + let mut remaining = len; + let find = start_indices + .iter() + .enumerate() + .find(|(i, start_idx)| **start_idx >= idx) + .unwrap(); + let mut batch_idx = if *find.1 == idx { find.0 } else { find.0 - 1 }; + + while remaining > 0 { + let current_range = &ranges[batch_idx]; + let range_start_idx = start_indices[batch_idx]; + let start_idx = idx - range_start_idx + current_range.start; + let range_available = current_range.len() - (idx - range_start_idx); + + if range_available >= remaining { + slices.push(Slice { + batch_idx, + start_idx, + len: remaining, + }); + remaining = 0; + } else { + slices.push(Slice { + batch_idx, + start_idx, + len: range_available, + }); + remaining -= range_available; + batch_idx += 1; + idx += range_available; + } + } + slices + } + + fn copy_slices( + &self, + slices: &Vec, + array: &mut Box, + column_index: &ColumnIndex, + ) { + let batches = self + .batches + .iter() + .map(|prb| &prb.batch) + .collect::>(); + copy_slices(&batches, slices, array, column_index); + } +} + +/// Slice of batch at `batch_idx` inside BufferingSideBuffer. +#[derive(Copy, Clone)] +struct Slice { + batch_idx: usize, + start_idx: usize, + len: usize, +} + +struct BufferingSideStream { + input: SendableRecordBatchStream, + buffer: BufferingSideBuffer, + input_is_finished: bool, + cumulating: bool, + sort: Vec, +} + +impl BufferingSideStream { + fn new( + input: SendableRecordBatchStream, + on: Vec, + sort: Vec, + ) -> Self { + let buffer = BufferingSideBuffer::new(on); + Self { + input, + buffer, + input_is_finished: false, + cumulating: false, + sort, + } + } + + fn input_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.input_is_finished { + Poll::Ready(None) + } else { + match self.input.poll_next_unpin(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + None => { + self.input_is_finished = true; + Poll::Ready(None) + } + batch => { + let batch = batch.transpose()?; + let prb = + PartitionedRecordBatch::new(batch, &self.sort).transpose(); + Poll::Ready(prb) + } + }, + } + } + } + + fn advance_key( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.buffer.is_finished()? && self.input_is_finished { + return Poll::Ready(None); + } else { + if self.cumulating { + match self.cumulate_same_keys(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(x) => match x { + Some(x) => { + x?; + return Poll::Ready(Some(Ok(()))); + } + None => unreachable!(), + }, + } + } + + if self.buffer.is_finished()? 
{ + return match &self.buffer.next_key_batch { + None => match self.input_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + None => Poll::Ready(None), + Some(x) => { + let prb = x?; + self.buffer.reset(&prb); + if prb.ranges.len() == 1 { + self.cumulating = true; + Poll::Pending + } else { + Poll::Ready(Some(Ok(()))) + } + } + }, + }, + Some(batch) => { + self.buffer.reset(batch); + if batch.ranges.len() == 1 { + self.cumulating = true; + Poll::Pending + } else { + Poll::Ready(Some(Ok(()))) + } + } + }; + } else { + self.buffer.advance_key(); + if self.buffer.batches[0].is_last_range(&self.buffer.ranges[0]) { + self.cumulating = true; + } else { + return Poll::Ready(Some(Ok(()))); + } + } + } + + unreachable!() + } + + fn advance_key_skip_null( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.buffer.is_finished()? && self.input_is_finished { + Poll::Ready(None) + } else { + loop { + match self.advance_key(cx) { + Poll::Ready(x) => match x { + None => return Poll::Ready(None), + Some(x) => { + x?; + if !self.buffer.key_any_null() { + return Poll::Ready(Some(Ok(()))); + } + } + }, + Poll::Pending => return Poll::Pending, + } + } + } + } + + fn cumulate_same_keys( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + match self.input_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(x) => match x { + None => { + self.cumulating = false; + return Poll::Ready(Some(Ok(()))); + } + Some(x) => { + let prb = x?; + let buffer_more = self.buffer.running_key(&prb)?; + if !buffer_more { + self.cumulating = false; + return Poll::Ready(Some(Ok(()))); + } + } + }, + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Output +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct OutputBuffer { + arrays: Vec>, + target_batch_size: usize, + slots_available: usize, + schema: Arc, +} + +impl OutputBuffer { + fn new(target_batch_size: usize, schema: Arc) -> ArrowResult { + let arrays = new_arrays(&schema, target_batch_size)?; + Ok(Self { + arrays, + target_batch_size, + slots_available: target_batch_size, + schema, + }) + } + + fn output_and_reset(mut self) -> ArrowResult { + let result = make_batch(self.schema.clone(), self.arrays); + self.arrays = new_arrays(&self.schema, self.target_batch_size)?; + self.slots_available = self.target_batch_size; + result + } + + fn append(&mut self, size: usize) { + assert!(size <= self.slots_available); + self.slots_available -= size; + } + + #[inline] + fn is_full(&self) -> bool { + self.slots_available == 0 + } +} + +fn new_arrays( + schema: &Arc, + batch_size: usize, +) -> ArrowResult>> { + let arrays: Vec> = schema + .fields() + .iter() + .map(|field| { + let dt = field.data_type.to_logical_type(); + make_mutable(dt, batch_size) + }) + .collect::>()?; + Ok(arrays) +} + +fn make_batch( + schema: Arc, + mut arrays: Vec>, +) -> ArrowResult { + let columns = arrays.iter_mut().map(|array| array.as_arc()).collect(); + RecordBatch::try_new(schema, columns) +} + +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use arrow::datatypes::PrimitiveType::*; + use arrow::types::{days_ms, months_days_ns}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! 
+
+fn new_arrays(
+    schema: &Arc<Schema>,
+    batch_size: usize,
+) -> ArrowResult<Vec<Box<dyn MutableArray>>> {
+    let arrays: Vec<Box<dyn MutableArray>> = schema
+        .fields()
+        .iter()
+        .map(|field| {
+            let dt = field.data_type.to_logical_type();
+            make_mutable(dt, batch_size)
+        })
+        .collect::<ArrowResult<_>>()?;
+    Ok(arrays)
+}
+
+fn make_batch(
+    schema: Arc<Schema>,
+    mut arrays: Vec<Box<dyn MutableArray>>,
+) -> ArrowResult<RecordBatch> {
+    let columns = arrays.iter_mut().map(|array| array.as_arc()).collect();
+    RecordBatch::try_new(schema, columns)
+}
+
+macro_rules! with_match_primitive_type {(
+    $key_type:expr, | $_:tt $T:ident | $($body:tt)*
+) => ({
+    macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
+    use arrow::datatypes::PrimitiveType::*;
+    use arrow::types::{days_ms, months_days_ns};
+    match $key_type {
+        Int8 => __with_ty__! { i8 },
+        Int16 => __with_ty__! { i16 },
+        Int32 => __with_ty__! { i32 },
+        Int64 => __with_ty__! { i64 },
+        Int128 => __with_ty__! { i128 },
+        DaysMs => __with_ty__! { days_ms },
+        MonthDayNano => __with_ty__! { months_days_ns },
+        UInt8 => __with_ty__! { u8 },
+        UInt16 => __with_ty__! { u16 },
+        UInt32 => __with_ty__! { u32 },
+        UInt64 => __with_ty__! { u64 },
+        Float32 => __with_ty__! { f32 },
+        Float64 => __with_ty__! { f64 },
+    }
+})}
+
+fn make_mutable(
+    data_type: &DataType,
+    capacity: usize,
+) -> ArrowResult<Box<dyn MutableArray>> {
+    Ok(match data_type.to_physical_type() {
+        PhysicalType::Boolean => Box::new(MutableBooleanArray::with_capacity(capacity))
+            as Box<dyn MutableArray>,
+        PhysicalType::Primitive(primitive) => {
+            with_match_primitive_type!(primitive, |$T| {
+                Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(data_type.clone()))
+                    as Box<dyn MutableArray>
+            })
+        }
+        PhysicalType::Binary => {
+            Box::new(MutableBinaryArray::<i32>::with_capacity(capacity))
+                as Box<dyn MutableArray>
+        }
+        PhysicalType::Utf8 => Box::new(MutableUtf8Array::<i32>::with_capacity(capacity))
+            as Box<dyn MutableArray>,
+        _ => match data_type {
+            DataType::List(inner) => {
+                let values = make_mutable(inner.data_type(), 0)?;
+                Box::new(DynMutableListArray::<i32>::new_with_capacity(
+                    values, capacity,
+                )) as Box<dyn MutableArray>
+            }
+            DataType::FixedSizeBinary(size) => Box::new(
+                MutableFixedSizeBinaryArray::with_capacity(*size as usize, capacity),
+            ) as Box<dyn MutableArray>,
+            other => {
+                return Err(ArrowError::NotYetImplemented(format!(
+                    "making mutable of type {} is not implemented yet",
+                    other
+                )))
+            }
+        },
+    })
+}
+
+macro_rules! repeat_n {
+    ($TO:ty, $FROM:ty, $N:expr, $to: ident, $from: ident, $idx: ident) => {{
+        let to = $to.as_mut_any().downcast_mut::<$TO>().unwrap();
+        let from = $from.as_any().downcast_ref::<$FROM>().unwrap();
+        let repeat_iter = from
+            .slice($idx, 1)
+            .iter()
+            .flat_map(|v| repeat(v).take($N))
+            .collect::<Vec<_>>();
+        to.extend_trusted_len(repeat_iter.into_iter());
+    }};
+}
+
+/// Repeat the streamed-side cell located at `idx` `times` times into the output
+fn repeat_cell(
+    batch: &RecordBatch,
+    idx: usize,
+    times: usize,
+    to: &mut Box<dyn MutableArray>,
+    column_index: &ColumnIndex,
+) {
+    let from = batch.column(column_index.index);
+    match to.data_type().to_physical_type() {
+        PhysicalType::Boolean => {
+            repeat_n!(MutableBooleanArray, BooleanArray, times, to, from, idx)
+        }
+        PhysicalType::Primitive(primitive) => match primitive {
+            PrimitiveType::Int8 => repeat_n!(Int8Vec, Int8Array, times, to, from, idx),
+            PrimitiveType::Int16 => repeat_n!(Int16Vec, Int16Array, times, to, from, idx),
+            PrimitiveType::Int32 => repeat_n!(Int32Vec, Int32Array, times, to, from, idx),
+            PrimitiveType::Int64 => repeat_n!(Int64Vec, Int64Array, times, to, from, idx),
+            PrimitiveType::Float32 => {
+                repeat_n!(Float32Vec, Float32Array, times, to, from, idx)
+            }
+            PrimitiveType::Float64 => {
+                repeat_n!(Float64Vec, Float64Array, times, to, from, idx)
+            }
+            _ => todo!(),
+        },
+        PhysicalType::Utf8 => {
+            repeat_n!(MutableUtf8Array<i32>, Utf8Array<i32>, times, to, from, idx)
+        }
+        _ => todo!(),
+    }
+}
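+
+// For example (values are illustrative): with a streamed Int32 column
+// [10, 20, 30], `repeat_cell(&batch, 1, 3, &mut to, &column_index)` appends
+// [20, 20, 20] to the mutable output column -- one copy of the current
+// streamed cell for each of the three matching buffered rows.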
+
+macro_rules! copy_slices {
+    ($TO:ty, $FROM:ty, $array: ident, $batches: ident, $slices: ident, $column_index: ident) => {{
+        let to = $array.as_mut_any().downcast_mut::<$TO>().unwrap();
+        for pos in $slices {
+            let from = $batches[pos.batch_idx]
+                .column($column_index.index)
+                .slice(pos.start_idx, pos.len);
+            let from = from.as_any().downcast_ref::<$FROM>().unwrap();
+            to.extend_trusted_len(from.iter());
+        }
+    }};
+}
+
+fn copy_slices(
+    batches: &Vec<&RecordBatch>,
+    slices: &Vec<Slice>,
+    array: &mut Box<dyn MutableArray>,
+    column_index: &ColumnIndex,
+) {
+    // copy every buffered-side `slice` of this column into the mutable output array
+    match array.data_type().to_physical_type() {
+        PhysicalType::Boolean => {
+            copy_slices!(
+                MutableBooleanArray,
+                BooleanArray,
+                array,
+                batches,
+                slices,
+                column_index
+            )
+        }
+        PhysicalType::Primitive(primitive) => match primitive {
+            PrimitiveType::Int8 => {
+                copy_slices!(Int8Vec, Int8Array, array, batches, slices, column_index)
+            }
+            PrimitiveType::Int16 => {
+                copy_slices!(Int16Vec, Int16Array, array, batches, slices, column_index)
+            }
+            PrimitiveType::Int32 => {
+                copy_slices!(Int32Vec, Int32Array, array, batches, slices, column_index)
+            }
+            PrimitiveType::Int64 => {
+                copy_slices!(Int64Vec, Int64Array, array, batches, slices, column_index)
+            }
+            PrimitiveType::Float32 => copy_slices!(
+                Float32Vec,
+                Float32Array,
+                array,
+                batches,
+                slices,
+                column_index
+            ),
+            PrimitiveType::Float64 => copy_slices!(
+                Float64Vec,
+                Float64Array,
+                array,
+                batches,
+                slices,
+                column_index
+            ),
+            _ => todo!(),
+        },
+        PhysicalType::Utf8 => {
+            copy_slices!(
+                MutableUtf8Array<i32>,
+                Utf8Array<i32>,
+                array,
+                batches,
+                slices,
+                column_index
+            )
+        }
+        _ => todo!(),
+    }
+}
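+
+// How the two primitives compose for an inner join (values illustrative): for
+// the current equal-key pair, every output column is filled either by
+// repeating the single streamed row or by copying the matching buffered rows:
+//
+//   streamed row:  (a1 = 1)           repeat_cell -> [1, 1, 1]
+//   buffered rows: (b2 = 10, 20, 30)  copy_slices -> [10, 20, 30]
+//
+// giving three joined output rows (1, 10), (1, 20), (1, 30).
+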
+pub struct SortMergeJoinCommon {
+    streamed: StreamingSideStream,
+    buffered: BufferingSideStream,
+    schema: Arc<Schema>,
+    /// Information of index and left / right placement of columns
+    column_indices: Vec<ColumnIndex>,
+    join_type: JoinType,
+    runtime: Arc<RuntimeEnv>,
+    output: OutputBuffer,
+}
+
+impl SortMergeJoinCommon {
+    pub fn new(
+        streamed: SendableRecordBatchStream,
+        buffered: SendableRecordBatchStream,
+        on_streamed: Vec<Column>,
+        on_buffered: Vec<Column>,
+        streamed_sort: Vec<PhysicalSortExpr>,
+        buffered_sort: Vec<PhysicalSortExpr>,
+        column_indices: Vec<ColumnIndex>,
+        schema: Arc<Schema>,
+        join_type: JoinType,
+        runtime: Arc<RuntimeEnv>,
+    ) -> Result<Self> {
+        let streamed = StreamingSideStream::new(streamed, on_streamed, streamed_sort);
+        let buffered = BufferingSideStream::new(buffered, on_buffered, buffered_sort);
+        let output = OutputBuffer::new(runtime.batch_size(), schema.clone())
+            .map_err(DataFusionError::ArrowError)?;
+        Ok(Self {
+            streamed,
+            buffered,
+            schema,
+            column_indices,
+            join_type,
+            runtime,
+            output,
+        })
+    }
+
+    fn compare_stream_buffer(&self) -> Result<Ordering> {
+        let stream_arrays = &self.streamed.buffer.join_arrays();
+        let buffer_arrays = &self.buffered.buffer.join_arrays();
+        comp_rows(
+            self.streamed.buffer.cur_row,
+            self.buffered.buffer.key_idx.unwrap(),
+            stream_arrays,
+            buffer_arrays,
+        )
+    }
+}
+
+pub struct InnerJoiner {
+    inner: SortMergeJoinCommon,
+    matched: bool,
+    buffered_idx: usize,
+    start_indices: Vec<usize>,
+    buffer_remaining: usize,
+    advance_stream: bool,
+    advance_buffer: bool,
+    continues_match: bool,
+}
+
+impl InnerJoiner {
+    pub fn new(inner: SortMergeJoinCommon) -> Self {
+        Self {
+            inner,
+            matched: false,
+            buffered_idx: 0,
+            start_indices: vec![],
+            buffer_remaining: 0,
+            advance_stream: true,
+            advance_buffer: true,
+            continues_match: false,
+        }
+    }
+
+    fn find_next(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<ArrowResult<()>>> {
+        if self.continues_match {
+            // the current streamed key may continue past the last polled row
+            match Pin::new(&mut self.inner.streamed).advance(cx) {
+                Poll::Ready(x) => match x {
+                    None => Poll::Ready(None),
+                    Some(y) => {
+                        y?;
+                        if self.inner.streamed.buffer.is_new_key {
+                            self.continues_match = false;
+                            self.advance_stream = false;
+                            self.advance_buffer = true;
+                            self.matched = false;
+                            self.find_next(cx)
+                        } else {
+                            self.continues_match = true;
+                            self.advance_stream = false;
+                            self.advance_buffer = false;
+                            self.matched = false;
+                            Poll::Ready(Some(Ok(())))
+                        }
+                    }
+                },
+                Poll::Pending => Poll::Pending,
+            }
+        } else {
+            if self.advance_stream {
+                match Pin::new(&mut self.inner.streamed).advance_key_skip_null(cx) {
+                    Poll::Ready(x) => match x {
+                        None => return Poll::Ready(None),
+                        Some(y) => {
+                            y?;
+                            self.continues_match = true;
+                        }
+                    },
+                    Poll::Pending => return Poll::Pending,
+                }
+            }
+
+            if self.advance_buffer {
+                // advance the buffered side to its next non-null key
+                match Pin::new(&mut self.inner.buffered).advance_key_skip_null(cx) {
+                    Poll::Ready(x) => match x {
+                        None => return Poll::Ready(None),
+                        Some(y) => {
+                            y?;
+                        }
+                    },
+                    Poll::Pending => return Poll::Pending,
+                }
+            }
+
+            let cmp = self
+                .inner
+                .compare_stream_buffer()
+                .map_err(DataFusionError::into_arrow_external_error)?;
+
+            match cmp {
+                Ordering::Less => {
+                    self.advance_stream = true;
+                    self.advance_buffer = false;
+                    self.find_next(cx)
+                }
+                Ordering::Equal => {
+                    self.advance_stream = true;
+                    self.advance_buffer = true;
+                    self.matched = true;
+                    self.buffered_idx = 0;
+                    self.start_indices = self.inner.buffered.buffer.range_start_indices();
+                    self.buffer_remaining = self.inner.buffered.buffer.row_num;
+                    Poll::Ready(Some(Ok(())))
+                }
+                Ordering::Greater => {
+                    self.advance_stream = false;
+                    self.advance_buffer = true;
+                    self.find_next(cx)
+                }
+            }
+        }
+    }
+
+    fn fill_output(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<ArrowResult<RecordBatch>>> {
+        let output = &mut self.inner.output;
+        let streamed = &self.inner.streamed.buffer;
+        let buffered = &self.inner.buffered.buffer;
+
+        let slots_available = output.slots_available;
+        let rows_to_output;
+        if slots_available >= self.buffer_remaining {
+            // the remaining matches for this key fit into the output buffer
+            self.matched = false;
+            rows_to_output = self.buffer_remaining;
+            self.buffer_remaining = 0;
+        } else {
+            rows_to_output = slots_available;
+            self.buffer_remaining -= rows_to_output;
+        }
+
+        let slices = buffered.slices_from_batches(
+            &self.start_indices,
+            self.buffered_idx,
+            rows_to_output,
+        );
+
+        output
+            .arrays
+            .iter_mut()
+            .zip(self.inner.schema.fields().iter())
+            .zip(self.inner.column_indices.iter())
+            .for_each(|((array, _field), column_index)| {
+                if column_index.is_left {
+                    // repeat the current streamed cell `rows_to_output` times
+                    streamed.repeat_cell(rows_to_output, array, column_index);
+                } else {
+                    // copy buffered rows starting at `buffered_idx`, `rows_to_output` long
+                    buffered.copy_slices(&slices, array, column_index);
+                }
+            });
+
+        self.inner.output.append(rows_to_output);
+        self.buffered_idx += rows_to_output;
+
+        if self.inner.output.is_full() {
+            let result = self.inner.output.output_and_reset();
+            Poll::Ready(Some(result))
+        } else {
+            self.poll_next(cx)
+        }
+    }
+}
+
+impl Stream for InnerJoiner {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        if !self.matched {
+            let find = self.find_next(cx);
+            match find {
+                Poll::Ready(x) => match x {
+                    None => Poll::Ready(None),
+                    Some(y) => match y {
+                        Ok(_) => self.fill_output(cx),
+                        Err(err) => Poll::Ready(Some(Err(ArrowError::External(
+                            "Failed while finding next match for inner join".to_owned(),
+                            err.into(),
+                        )))),
+                    },
+                },
+                Poll::Pending => Poll::Pending,
+            }
+        } else {
+            self.fill_output(cx)
+        }
+    }
+}
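+
+// State machine of the inner joiner, summarizing the code above:
+//
+//   poll_next
+//     +-- matched == false -> find_next: advance the streamed and/or buffered
+//     |      side until compare_stream_buffer() returns Ordering::Equal, then
+//     |      snapshot the buffered key group (start_indices, buffer_remaining)
+//     +-- matched == true  -> fill_output: emit joined rows into the
+//            OutputBuffer, producing a RecordBatch whenever it fills, and fall
+//            back to find_next once the buffered group is exhausted.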
+
+impl RecordBatchStream for InnerJoiner {
+    fn schema(&self) -> SchemaRef {
+        self.inner.schema.clone()
+    }
+}
diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs
new file mode 100644
index 000000000000..b1b0292c3afe
--- /dev/null
+++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs
@@ -0,0 +1,1112 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the sort-merge join plan, which joins two inputs that are both
+//! sorted on the join keys by merging them key by key.
+
+use std::iter::repeat;
+use std::sync::Arc;
+use std::vec;
+use std::{any::Any, usize};
+
+use arrow::array::*;
+use arrow::datatypes::*;
+use arrow::error::{ArrowError, Result as ArrowResult};
+use arrow::record_batch::RecordBatch;
+use async_trait::async_trait;
+use futures::{Stream, StreamExt};
+
+use crate::arrow_dyn_list_array::DynMutableListArray;
+use crate::error::{DataFusionError, Result};
+use crate::execution::runtime_env::RuntimeEnv;
+use crate::execution::runtime_env::RUNTIME_ENV;
+use crate::logical_plan::JoinType;
+use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr};
+use crate::physical_plan::joins::smj_utils::{
+    join_arrays, InnerJoiner, SortMergeJoinCommon,
+};
+use crate::physical_plan::joins::{
+    build_join_schema, check_join_is_valid, column_indices_from_schema, comp_rows,
+    equal_rows, ColumnIndex, JoinOn,
+};
+use crate::physical_plan::sorts::external_sort::ExternalSortExec;
+use crate::physical_plan::stream::RecordBatchReceiverStream;
+use crate::physical_plan::Statistics;
+use crate::physical_plan::{
+    expressions::Column,
+    metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
+};
+use crate::physical_plan::{
+    DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream,
+};
+use arrow::compute::partition::lexicographical_partition_ranges;
+use std::cmp::Ordering;
+use std::collections::VecDeque;
+use std::ops::Range;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+use tokio::sync::mpsc::{Receiver, Sender};
+
+type StringArray = Utf8Array<i32>;
+type LargeStringArray = Utf8Array<i64>;
+
+/// Sort-merge join execution plan: executes partitions in parallel and merges
+/// the two sorted inputs into a set of joined output partitions.
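+///
+/// In outline (a summary of the implementation in this module, not additional
+/// behaviour): both children must already be sorted on the join keys; one
+/// side is consumed row by row as the streamed side while the other is
+/// buffered key by key, and whenever the two current keys compare equal every
+/// streamed row of the key is combined with every buffered row of that key.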
+#[derive(Debug)]
+pub struct SortMergeJoinExec {
+    /// left side of the join, used as the streamed side for every join type
+    /// except `JoinType::Right`
+    left: Arc<dyn ExecutionPlan>,
+    /// right side of the join, used as the buffered side for every join type
+    /// except `JoinType::Right`
+    right: Arc<dyn ExecutionPlan>,
+    /// Set of common columns used to join on
+    on: Vec<(Column, Column)>,
+    /// How the join is performed
+    join_type: JoinType,
+    /// The schema once the join is applied
+    schema: SchemaRef,
+    /// Execution metrics
+    metrics: ExecutionPlanMetricsSet,
+}
+
+struct OuterMatchResult {
+    get_match: bool,
+    buffered_ended: bool,
+    more_output: bool,
+}
+
+/// Metrics for SortMergeJoinExec
+#[derive(Debug)]
+struct SortMergeJoinMetrics {
+    /// Total time for merging the streamed side with the buffered side
+    join_time: metrics::Time,
+    /// Number of batches consumed by this operator
+    input_batches: metrics::Count,
+    /// Number of rows consumed by this operator
+    input_rows: metrics::Count,
+    /// Number of batches produced by this operator
+    output_batches: metrics::Count,
+    /// Number of rows produced by this operator
+    output_rows: metrics::Count,
+}
+
+impl SortMergeJoinMetrics {
+    pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
+        let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition);
+
+        let input_batches =
+            MetricBuilder::new(metrics).counter("input_batches", partition);
+
+        let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition);
+
+        let output_batches =
+            MetricBuilder::new(metrics).counter("output_batches", partition);
+
+        let output_rows = MetricBuilder::new(metrics).output_rows(partition);
+
+        Self {
+            join_time,
+            input_batches,
+            input_rows,
+            output_batches,
+            output_rows,
+        }
+    }
+}
+
+impl SortMergeJoinExec {
+    /// Tries to create a new [SortMergeJoinExec].
+    /// # Error
+    /// This function errors when it is not possible to join the left and right sides on keys `on`.
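+    ///
+    /// A hypothetical construction over two already-sorted inputs (the inputs
+    /// and the `id` column name are illustrative, not part of this crate):
+    ///
+    /// ```ignore
+    /// let on = vec![(
+    ///     Column::new_with_schema("id", &left.schema())?,
+    ///     Column::new_with_schema("id", &right.schema())?,
+    /// )];
+    /// let join = SortMergeJoinExec::try_new(left, right, on, &JoinType::Inner)?;
+    /// ```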
+    pub fn try_new(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: JoinOn,
+        join_type: &JoinType,
+    ) -> Result<Self> {
+        let left_schema = left.schema();
+        let right_schema = right.schema();
+        check_join_is_valid(&left_schema, &right_schema, &on)?;
+
+        let schema = Arc::new(build_join_schema(&left_schema, &right_schema, join_type));
+
+        Ok(SortMergeJoinExec {
+            left,
+            right,
+            on,
+            join_type: *join_type,
+            schema,
+            metrics: ExecutionPlanMetricsSet::new(),
+        })
+    }
+
+    /// left side of the join
+    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.left
+    }
+
+    /// right side of the join
+    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.right
+    }
+
+    /// Set of common columns used to join on
+    pub fn on(&self) -> &[(Column, Column)] {
+        &self.on
+    }
+
+    /// How the join is performed
+    pub fn join_type(&self) -> &JoinType {
+        &self.join_type
+    }
+}
+
+#[async_trait]
+impl ExecutionPlan for SortMergeJoinExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.left.clone(), self.right.clone()]
+    }
+
+    fn with_new_children(
+        &self,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        match children.len() {
+            2 => Ok(Arc::new(SortMergeJoinExec::try_new(
+                children[0].clone(),
+                children[1].clone(),
+                self.on.clone(),
+                &self.join_type,
+            )?)),
+            _ => Err(DataFusionError::Internal(
+                "SortMergeJoinExec wrong number of children".to_string(),
+            )),
+        }
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        self.right.output_partitioning()
+    }
+
+    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+        let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
+        let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
+        let left = self.left.execute(partition).await?;
+        let right = self.right.execute(partition).await?;
+
+        let column_indices = column_indices_from_schema(
+            &self.join_type,
+            &self.left.schema(),
+            &self.right.schema(),
+            &self.schema,
+        )?;
+
+        // both children are expected to be ExternalSortExec, so that the sort
+        // expressions of the two inputs can be recovered here
+        let left_sort = self
+            .left
+            .as_any()
+            .downcast_ref::<ExternalSortExec>()
+            .unwrap()
+            .expr()
+            .iter()
+            .map(|s| s.clone())
+            .collect::<Vec<_>>();
+        let right_sort = self
+            .right
+            .as_any()
+            .downcast_ref::<ExternalSortExec>()
+            .unwrap()
+            .expr()
+            .iter()
+            .map(|s| s.clone())
+            .collect::<Vec<_>>();
+
+        let join = match self.join_type {
+            JoinType::Inner
+            | JoinType::Left
+            | JoinType::Full
+            | JoinType::Semi
+            | JoinType::Anti => SortMergeJoinCommon::new(
+                left,
+                right,
+                on_left,
+                on_right,
+                left_sort,
+                right_sort,
+                column_indices,
+                self.schema.clone(),
+                self.join_type,
+                RUNTIME_ENV.clone(),
+            )?,
+            JoinType::Right => SortMergeJoinCommon::new(
+                right,
+                left,
+                on_right,
+                on_left,
+                right_sort,
+                left_sort,
+                column_indices,
+                self.schema.clone(),
+                self.join_type,
+                RUNTIME_ENV.clone(),
+            )?,
+        };
+        Ok(match self.join_type {
+            JoinType::Inner => Box::pin(InnerJoiner::new(join)),
+            JoinType::Left => todo!(),
+            JoinType::Right => todo!(),
+            JoinType::Full => todo!(),
+            JoinType::Semi => todo!(),
+            JoinType::Anti => todo!(),
+        })
+    }
+
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                write!(
+                    f,
+                    "SortMergeJoinExec: join_type={:?}, on={:?}",
+                    self.join_type, self.on
+                )
+            }
+        }
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn statistics(&self) -> Statistics {
+        // TODO stats: it is not possible in general to know the output size of joins
+        // There are some special cases though, for example:
+        // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
+        Statistics::default()
+    }
+}
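+
+// `execute` above unwraps a downcast of each child to ExternalSortExec, so the
+// operator only works when the planner places explicit sorts on the join keys
+// beneath it. A hypothetical wiring (assuming an
+// `ExternalSortExec::try_new(exprs, input)` constructor shaped like
+// `SortExec::try_new`; all names are illustrative):
+//
+//   let sorted_left = Arc::new(ExternalSortExec::try_new(left_sort_exprs, left)?);
+//   let sorted_right = Arc::new(ExternalSortExec::try_new(right_sort_exprs, right)?);
+//   let smj = SortMergeJoinExec::try_new(sorted_left, sorted_right, on, &JoinType::Inner)?;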
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use crate::{
+        assert_batches_sorted_eq,
+        physical_plan::{
+            common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec,
+        },
+        test::{build_table_i32, columns},
+    };
+
+    use super::*;
+    use crate::physical_plan::PhysicalExpr;
+
+    fn build_table(
+        a: (&str, &Vec<i32>),
+        b: (&str, &Vec<i32>),
+        c: (&str, &Vec<i32>),
+    ) -> Arc<dyn ExecutionPlan> {
+        let batch = build_table_i32(a, b, c);
+        let schema = batch.schema().clone();
+        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+    }
+
+    fn join(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: JoinOn,
+        join_type: &JoinType,
+    ) -> Result<SortMergeJoinExec> {
+        SortMergeJoinExec::try_new(left, right, on, join_type)
+    }
+
+    async fn join_collect(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: JoinOn,
+        join_type: &JoinType,
+    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+        let join = join(left, right, on, join_type)?;
+        let columns = columns(&join.schema());
+
+        let stream = join.execute(0).await?;
+        let batches = common::collect(stream).await?;
+
+        Ok((columns, batches))
+    }
+
+    async fn partitioned_join_collect(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: JoinOn,
+        join_type: &JoinType,
+    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+        let partition_count = 4;
+
+        let (left_expr, right_expr) = on
+            .iter()
+            .map(|(l, r)| {
+                (
+                    Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
+                    Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
+                )
+            })
+            .unzip();
+
+        let join = SortMergeJoinExec::try_new(
+            Arc::new(RepartitionExec::try_new(
+                left,
+                Partitioning::Hash(left_expr, partition_count),
+            )?),
+            Arc::new(RepartitionExec::try_new(
+                right,
+                Partitioning::Hash(right_expr, partition_count),
+            )?),
+            on,
+            join_type,
+        )?;
+
+        let columns = columns(&join.schema());
+
+        let mut batches = vec![];
+        for i in 0..partition_count {
+            let stream = join.execute(i).await?;
+            let more_batches = common::collect(stream).await?;
+            batches.extend(
+                more_batches
+                    .into_iter()
+                    .filter(|b| b.num_rows() > 0)
+                    .collect::<Vec<_>>(),
+            );
+        }
+
+        Ok((columns, batches))
+    }
+
+    #[tokio::test]
+    async fn join_inner_one() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 3]),
+            ("b1", &vec![4, 5, 5]), // this has a repetition
+            ("c1", &vec![7, 8, 9]),
+        );
+        let right = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![4, 5, 6]),
+            ("c2", &vec![70, 80, 90]),
+        );
+
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b1", &right.schema())?,
+        )];
+
+        let (columns, batches) =
+            join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner)
+                .await?;
+
+        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
+
+        let expected = vec![
+            "+----+----+----+----+----+----+",
+            "| a1 | b1 | c1 | a2 | b1 | c2 |",
+            "+----+----+----+----+----+----+",
+            "| 1  | 4  | 7  | 10 | 4  | 70 |",
+            "| 2  | 5  | 8  | 20 | 5  | 80 |",
+            "| 3  | 5  | 9  | 20 | 5  | 80 |",
+            "+----+----+----+----+----+----+",
+        ];
+        assert_batches_sorted_eq!(expected, &batches);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn partitioned_join_inner_one() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 3]),
+            ("b1", &vec![4, 5, 5]), // this has a repetition
+            ("c1", &vec![7, 8, 9]),
+        );
+        let right = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![4, 5, 6]),
+            ("c2", &vec![70, 80, 90]),
+        );
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b1", &right.schema())?,
+        )];
+
+        let (columns, batches) = partitioned_join_collect(
+            left.clone(),
+            right.clone(),
+            on.clone(),
&JoinType::Inner, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_one_no_shared_column_names() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); + + assert_eq!(batches.len(), 1); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + /// Test where the left has 2 parts, the right with 1 part => 1 part + #[tokio::test] + async fn join_inner_one_two_parts_left() -> Result<()> { + let batch1 = build_table_i32( + ("a1", &vec![1, 2]), + ("b2", &vec![1, 2]), + ("c1", &vec![7, 8]), + ); + let batch2 = + build_table_i32(("a1", &vec![2]), ("b2", &vec![2]), ("c1", &vec![9])); + let schema = batch1.schema().clone(); + let left = Arc::new( + MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); + + assert_eq!(batches.len(), 1); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | 
b2 | c1 | a1 | b2 | c2 |",
+            "+----+----+----+----+----+----+",
+            "| 1  | 1  | 7  | 1  | 1  | 70 |",
+            "| 2  | 2  | 8  | 2  | 2  | 80 |",
+            "| 2  | 2  | 9  | 2  | 2  | 80 |",
+            "+----+----+----+----+----+----+",
+        ];
+
+        assert_batches_sorted_eq!(expected, &batches);
+
+        Ok(())
+    }
+
+    /// Test where the left has 1 part, the right has 2 parts => 2 parts
+    #[tokio::test]
+    async fn join_inner_one_two_parts_right() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 3]),
+            ("b1", &vec![4, 5, 5]), // this has a repetition
+            ("c1", &vec![7, 8, 9]),
+        );
+
+        let batch1 = build_table_i32(
+            ("a2", &vec![10, 20]),
+            ("b1", &vec![4, 6]),
+            ("c2", &vec![70, 80]),
+        );
+        let batch2 =
+            build_table_i32(("a2", &vec![30]), ("b1", &vec![5]), ("c2", &vec![90]));
+        let schema = batch1.schema().clone();
+        let right = Arc::new(
+            MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(),
+        );
+
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema())?,
+            Column::new_with_schema("b1", &right.schema())?,
+        )];
+
+        let join = join(left, right, on, &JoinType::Inner)?;
+
+        let columns = columns(&join.schema());
+        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
+
+        // first part
+        let stream = join.execute(0).await?;
+        let batches = common::collect(stream).await?;
+        assert_eq!(batches.len(), 1);
+
+        let expected = vec![
+            "+----+----+----+----+----+----+",
+            "| a1 | b1 | c1 | a2 | b1 | c2 |",
+            "+----+----+----+----+----+----+",
+            "| 1  | 4  | 7  | 10 | 4  | 70 |",
+            "+----+----+----+----+----+----+",
+        ];
+        assert_batches_sorted_eq!(expected, &batches);
+
+        // second part
+        let stream = join.execute(1).await?;
+        let batches = common::collect(stream).await?;
+        assert_eq!(batches.len(), 1);
+        let expected = vec![
+            "+----+----+----+----+----+----+",
+            "| a1 | b1 | c1 | a2 | b1 | c2 |",
+            "+----+----+----+----+----+----+",
+            "| 2  | 5  | 8  | 30 | 5  | 90 |",
+            "| 3  | 5  | 9  | 30 | 5  | 90 |",
+            "+----+----+----+----+----+----+",
+        ];
+
+        assert_batches_sorted_eq!(expected, &batches);
+
+        Ok(())
+    }
+
+    fn build_table_two_batches(
+        a: (&str, &Vec<i32>),
+        b: (&str, &Vec<i32>),
+        c: (&str, &Vec<i32>),
+    ) -> Arc<dyn ExecutionPlan> {
+        let batch = build_table_i32(a, b, c);
+        let schema = batch.schema().clone();
+        Arc::new(
+            MemoryExec::try_new(&[vec![batch.clone(), batch]], schema, None).unwrap(),
+        )
+    }
+
+    #[tokio::test]
+    async fn join_left_multi_batch() {
+        let left = build_table(
+            ("a1", &vec![1, 2, 3]),
+            ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
+            ("c1", &vec![7, 8, 9]),
+        );
+        let right = build_table_two_batches(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![4, 5, 6]),
+            ("c2", &vec![70, 80, 90]),
+        );
+        let on = vec![(
+            Column::new_with_schema("b1", &left.schema()).unwrap(),
+            Column::new_with_schema("b1", &right.schema()).unwrap(),
+        )];
+
+        let join = join(left, right, on, &JoinType::Left).unwrap();
+
+        let columns = columns(&join.schema());
+        assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
+
+        let stream = join.execute(0).await.unwrap();
+        let batches = common::collect(stream).await.unwrap();
+
+        let expected = vec![
+            "+----+----+----+----+----+----+",
+            "| a1 | b1 | c1 | a2 | b1 | c2 |",
+            "+----+----+----+----+----+----+",
+            "| 1  | 4  | 7  | 10 | 4  | 70 |",
+            "| 1  | 4  | 7  | 10 | 4  | 70 |",
+            "| 2  | 5  | 8  | 20 | 5  | 80 |",
+            "| 2  | 5  | 8  | 20 | 5  | 80 |",
+            "| 3  | 7  | 9  |    | 7  |    |",
+            "+----+----+----+----+----+----+",
+        ];
+
+        assert_batches_sorted_eq!(expected, &batches);
+    }
+
+    #[tokio::test]
+    async fn join_full_multi_batch() {
+        let left = build_table(
+            ("a1", &vec![1, 2, 
3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + // create two identical batches for the right side + let right = build_table_two_batches( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; + + let join = join(left, right, on, &JoinType::Full).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_left_empty_right() { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b1", right.schema()).unwrap(), + )]; + let schema = right.schema().clone(); + let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); + let join = join(left, right, on, &JoinType::Left).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | | 4 | |", + "| 2 | 5 | 8 | | 5 | |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_full_empty_right() { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", &vec![])); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", right.schema()).unwrap(), + )]; + let schema = right.schema().clone(); + let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); + let join = join(left, right, on, &JoinType::Full).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | | | |", + "| 2 | 5 | 8 | | | |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_left_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = 
build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_left_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = partitioned_join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Left, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_semi() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right + ("c2", &vec![70, 80, 90, 100]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let join = join(left, right, on, &JoinType::Semi)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_anti() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3, 5]), + ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9, 11]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right + ("c2", &vec![70, 80, 90, 100]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let join = join(left, right, on, &JoinType::Anti)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 3 | 7 | 9 |", + "| 5 | 7 | 
11 |", + "+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_right_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + partitioned_join_collect(left, right, on, &JoinType::Right).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_full_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; + + let join = join(left, right, on, &JoinType::Full)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 3758d058a34c..f9b039c23445 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -17,16 +17,13 @@ //! Traits for physical query plan, supporting parallel execution for partitioned relations. 
-pub use self::metrics::Metric;
-use self::metrics::MetricsSet;
-use self::{
-    coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan,
-};
-use crate::physical_plan::expressions::{PhysicalSortExpr, SortColumn};
-use crate::{
-    error::{DataFusionError, Result},
-    scalar::ScalarValue,
-};
+use std::fmt;
+use std::fmt::{Debug, Display};
+use std::ops::Range;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::{any::Any, pin::Pin};
+
 use arrow::array::ArrayRef;
 use arrow::compute::merge_sort::SortOptions;
 use arrow::compute::partition::lexicographical_partition_ranges;
@@ -34,14 +31,23 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
-pub use display::DisplayFormatType;
 use futures::stream::Stream;
-use std::fmt;
-use std::fmt::{Debug, Display};
-use std::ops::Range;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-use std::{any::Any, pin::Pin};
+
+pub use display::DisplayFormatType;
+
+use crate::physical_plan::expressions::{PhysicalSortExpr, SortColumn};
+use crate::{
+    error::{DataFusionError, Result},
+    scalar::ScalarValue,
+};
+
+pub use self::metrics::Metric;
+use self::metrics::MetricsSet;
+/// Physical planner interface
+pub use self::planner::PhysicalPlanner;
+use self::{
+    coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan,
+};

 /// Trait for types that stream [arrow::record_batch::RecordBatch]
 pub trait RecordBatchStream: Stream<Item = ArrowResult<RecordBatch>> {
@@ -86,9 +92,6 @@ impl Stream for EmptyRecordBatchStream {
     }
 }

-/// Physical planner interface
-pub use self::planner::PhysicalPlanner;
-
 /// Statistics for a physical plan node
 /// Fields are optional and can be inexact because the sources
 /// sometimes provide approximate estimates for performance reasons
@@ -612,7 +615,6 @@ pub mod avro;
 pub mod coalesce_batches;
 pub mod coalesce_partitions;
 pub mod common;
-pub mod cross_join;
 #[cfg(feature = "crypto_expressions")]
 pub mod crypto_expressions;
 pub mod csv;
@@ -625,8 +627,8 @@ pub mod expressions;
 pub mod filter;
 pub mod functions;
 pub mod hash_aggregate;
-pub mod hash_join;
 pub mod hash_utils;
+pub mod joins;
 pub mod json;
 pub mod limit;
 pub mod math_expressions;
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index a337b886110c..3088d787d4dd 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -17,11 +17,15 @@ //! Physical query planner
-use super::analyze::AnalyzeExec;
-use super::{
-    aggregates, empty::EmptyExec, expressions::binary, functions,
-    hash_join::PartitionMode, udaf, union::UnionExec, windows,
-};
+use std::sync::Arc;
+
+use arrow::compute::cast::can_cast_types;
+use arrow::compute::sort::SortOptions;
+use arrow::datatypes::*;
+use log::debug;
+
+use expressions::col;
+
 use crate::execution::context::ExecutionContextState;
 use crate::logical_plan::{
     unnormalize_cols, DFSchema, Expr, LogicalPlan, Operator,
@@ -29,20 +33,21 @@ use crate::logical_plan::{
     UserDefinedLogicalNode,
 };
 use crate::physical_optimizer::optimizer::PhysicalOptimizerRule;
-use crate::physical_plan::cross_join::CrossJoinExec;
 use crate::physical_plan::explain::ExplainExec;
 use crate::physical_plan::expressions;
 use crate::physical_plan::expressions::{CaseExpr, Column, Literal, PhysicalSortExpr};
 use crate::physical_plan::filter::FilterExec;
 use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
-use crate::physical_plan::hash_join::HashJoinExec;
+use crate::physical_plan::joins::cross_join::CrossJoinExec;
+use crate::physical_plan::joins::hash_join::HashJoinExec;
+use crate::physical_plan::joins::hash_join::PartitionMode;
 use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use crate::physical_plan::projection::ProjectionExec;
 use crate::physical_plan::repartition::RepartitionExec;
 use crate::physical_plan::sorts::sort::SortExec;
 use crate::physical_plan::udf;
 use crate::physical_plan::windows::WindowAggExec;
-use crate::physical_plan::{hash_utils, Partitioning};
+use crate::physical_plan::{joins, Partitioning};
 use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr};
 use crate::scalar::ScalarValue;
 use crate::sql::utils::{generate_sort_key, window_expr_common_partition_keys};
@@ -52,12 +57,11 @@ use crate::{
     physical_plan::displayable,
 };

-use arrow::compute::cast::can_cast_types;
-use arrow::compute::sort::SortOptions;
-use arrow::datatypes::*;
-use expressions::col;
-use log::debug;
-use std::sync::Arc;
+use super::analyze::AnalyzeExec;
+use super::{
+    aggregates, empty::EmptyExec, expressions::binary, functions, udaf, union::UnionExec,
+    windows,
+};

 fn create_function_physical_name(
     fun: &str,
@@ -677,7 +681,7 @@ impl DefaultPhysicalPlanner {
                         Column::new(&r.name, right_df_schema.index_of_column(r)?),
                     ))
                 })
-                .collect::<Result<hash_utils::JoinOn>>()?;
+                .collect::<Result<joins::JoinOn>>()?;

             if ctx_state.config.target_partitions > 1
                 && ctx_state.config.repartition_joins
@@ -1392,7 +1396,13 @@ fn tuple_err<T, R>(value: (Result<T>, Result<R>)) -> Result<(T, R)> {

 #[cfg(test)]
 mod tests {
-    use super::*;
+    use fmt::Debug;
+    use std::convert::TryFrom;
+    use std::{any::Any, fmt};
+
+    use arrow::datatypes::{DataType, Field};
+    use async_trait::async_trait;
+
     use crate::logical_plan::{DFField, DFSchema, DFSchemaRef};
     use crate::physical_plan::{
         csv::CsvReadOptions, expressions, DisplayFormatType, Partitioning, Statistics,
@@ -1402,11 +1412,8 @@ mod tests {
         logical_plan::{col, lit, sum, LogicalPlanBuilder},
         physical_plan::SendableRecordBatchStream,
     };
-    use arrow::datatypes::{DataType, Field};
-    use async_trait::async_trait;
-    use fmt::Debug;
-    use std::convert::TryFrom;
-    use std::{any::Any, fmt};
+
+    use super::*;

     fn make_ctx_state() -> ExecutionContextState {
         ExecutionContextState::new()
diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs
index b11f41d0e163..48c72fb0026d 100644
--- a/datafusion/src/physical_plan/sorts/sort.rs
+++ b/datafusion/src/physical_plan/sorts/sort.rs
@@ -19,7 +19,7 @@ use super::{RecordBatchStream, SendableRecordBatchStream};
 use crate::error::{DataFusionError, Result};
-use crate::physical_plan::expressions::PhysicalSortExpr;
+use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr};
 use crate::physical_plan::metrics::{
     BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
 };
@@ -191,14 +191,12 @@ pub fn sort_batch(
     schema: SchemaRef,
     expr: &[PhysicalSortExpr],
 ) -> ArrowResult<RecordBatch> {
-    let columns = expr
-        .iter()
-        .map(|e| e.evaluate_to_sort_column(&batch))
-        .collect::<Result<Vec<SortColumn>>>()
+    let columns = exprs_to_sort_columns(&batch, expr)
         .map_err(DataFusionError::into_arrow_external_error)?;
-    let columns = columns.iter().map(|x| x.into()).collect::<Vec<_>>();
-
-    let indices = lexsort_to_indices::<u32>(&columns, None)?;
+    let indices = lexsort_to_indices::<u32>(
+        &columns.iter().map(|x| x.into()).collect::<Vec<_>>(),
+        None,
+    )?;

     // reorder all rows based on sorted indices
     RecordBatch::try_new(