diff --git a/Cargo.toml b/Cargo.toml
index 61ce0276328d..317f0b648b5c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,5 +31,5 @@ members = [
 exclude = ["python"]
 
 [patch.crates-io]
-arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", rev = "b7e991366104d1647b955a828e0551256ef2e7c9" }
-arrow-flight = { git = "https://github.com/jorgecarleitao/arrow2.git", rev = "b7e991366104d1647b955a828e0551256ef2e7c9" }
+arrow2 = { path = "/Users/shenyijie/oss/arrow2" }
+arrow-flight = { path = "/Users/shenyijie/oss/arrow2/arrow-flight" }
diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
index 56ee938c1930..63b60f7ac8ec 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs
@@ -42,6 +42,7 @@ use datafusion::arrow::io::ipc::read::FileReader;
 use datafusion::arrow::io::ipc::write::FileWriter;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::error::{DataFusionError, Result};
+use datafusion::physical_plan::common::IPCWriterWrapper;
 use datafusion::physical_plan::hash_utils::create_hashes;
 use datafusion::physical_plan::metrics::{
     self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
@@ -197,7 +198,7 @@ impl ShuffleWriterExec {
 
             // we won't necessary produce output for every possible partition, so we
             // create writers on demand
-            let mut writers: Vec<Option<ShuffleWriter>> = vec![];
+            let mut writers: Vec<Option<IPCWriterWrapper>> = vec![];
             for _ in 0..num_output_partitions {
                 writers.push(None);
             }
@@ -267,8 +268,10 @@
                             let path = path.to_str().unwrap();
                             info!("Writing results to {}", path);
 
-                            let mut writer =
-                                ShuffleWriter::new(path, stream.schema().as_ref())?;
+                            let mut writer = IPCWriterWrapper::new(
+                                path,
+                                stream.schema().as_ref(),
+                            )?;
 
                             writer.write(&output_batch)?;
                             writers[output_partition] = Some(writer);
@@ -433,56 +436,6 @@ fn result_schema() -> SchemaRef {
     ]))
 }
 
-struct ShuffleWriter {
-    path: String,
-    writer: FileWriter<BufWriter<File>>,
-    num_batches: u64,
-    num_rows: u64,
-    num_bytes: u64,
-}
-
-impl ShuffleWriter {
-    fn new(path: &str, schema: &Schema) -> Result<Self> {
-        let file = File::create(path)
-            .map_err(|e| {
-                BallistaError::General(format!(
-                    "Failed to create partition file at {}: {:?}",
-                    path, e
-                ))
-            })
-            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
-        let buffer_writer = std::io::BufWriter::new(file);
-        Ok(Self {
-            num_batches: 0,
-            num_rows: 0,
-            num_bytes: 0,
-            path: path.to_owned(),
-            writer: FileWriter::try_new(buffer_writer, schema)?,
-        })
-    }
-
-    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
-        self.writer.write(batch)?;
-        self.num_batches += 1;
-        self.num_rows += batch.num_rows() as u64;
-        let num_bytes: usize = batch
-            .columns()
-            .iter()
-            .map(|array| estimated_bytes_size(array.as_ref()))
-            .sum();
-        self.num_bytes += num_bytes as u64;
-        Ok(())
-    }
-
-    fn finish(&mut self) -> Result<()> {
-        self.writer.finish().map_err(DataFusionError::ArrowError)
-    }
-
-    fn path(&self) -> &str {
-        &self.path
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index bd76c8a847c7..eceedb244043 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -21,15 +21,8 @@ use std::collections::HashMap;
 use std::convert::{TryFrom, TryInto};
 use std::sync::Arc;
 
-use crate::error::BallistaError;
-use crate::execution_plans::{
-    ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec,
-};
-use crate::serde::protobuf::repartition_exec_node::PartitionMethod;
-use crate::serde::protobuf::ShuffleReaderPartition;
-use crate::serde::scheduler::PartitionLocation;
-use crate::serde::{from_proto_binary_op, proto_error, protobuf};
-use crate::{convert_box_required, convert_required, into_required};
+use log::debug;
+
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::catalog::catalog::{
     CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider,
@@ -46,7 +39,8 @@ use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunc
 use datafusion::physical_plan::avro::{AvroExec, AvroReadOptions};
 use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
-use datafusion::physical_plan::hash_join::PartitionMode;
+use datafusion::physical_plan::joins::cross_join::CrossJoinExec;
+use datafusion::physical_plan::joins::hash_join::{HashJoinExec, PartitionMode};
 use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
 use datafusion::physical_plan::parquet::ParquetPartition;
 use datafusion::physical_plan::planner::DefaultPhysicalPlanner;
@@ -56,7 +50,6 @@ use datafusion::physical_plan::window_functions::{
 use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec};
 use datafusion::physical_plan::{
     coalesce_batches::CoalesceBatchesExec,
-    cross_join::CrossJoinExec,
     csv::CsvExec,
     empty::EmptyExec,
     expressions::{
@@ -65,22 +58,31 @@ use datafusion::physical_plan::{
     },
     filter::FilterExec,
     functions::{self, BuiltinScalarFunction, ScalarFunctionExpr},
-    hash_join::HashJoinExec,
     limit::{GlobalLimitExec, LocalLimitExec},
     parquet::ParquetExec,
     projection::ProjectionExec,
     repartition::RepartitionExec,
-    sort::{SortExec, SortOptions},
+    sorts::sort::SortExec,
+    sorts::SortOptions,
     Partitioning,
 };
 use datafusion::physical_plan::{
     AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr,
 };
 use datafusion::prelude::CsvReadOptions;
-use log::debug;
 use protobuf::physical_expr_node::ExprType;
 use protobuf::physical_plan_node::PhysicalPlanType;
 
+use crate::error::BallistaError;
+use crate::execution_plans::{
+    ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec,
+};
+use crate::serde::protobuf::repartition_exec_node::PartitionMethod;
+use crate::serde::protobuf::ShuffleReaderPartition;
+use crate::serde::scheduler::PartitionLocation;
+use crate::serde::{from_proto_binary_op, proto_error, protobuf};
+use crate::{convert_box_required, convert_required, into_required};
+
 impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
     type Error = BallistaError;
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 6d4f95b0a342..0e78349ddc33 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -22,6 +22,7 @@ pub mod to_proto;
 mod roundtrip_tests {
     use std::{convert::TryInto, sync::Arc};
 
+    use datafusion::physical_plan::joins::hash_join::{HashJoinExec, PartitionMode};
     use datafusion::{
         arrow::{
             compute::sort::SortOptions,
@@ -34,18 +35,18 @@ mod roundtrip_tests {
             expressions::{Avg, Column, PhysicalSortExpr},
             filter::FilterExec,
             hash_aggregate::{AggregateMode, HashAggregateExec},
-            hash_join::{HashJoinExec, PartitionMode},
             limit::{GlobalLimitExec, LocalLimitExec},
-            sort::SortExec,
+            sorts::sort::SortExec,
             AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning,
             PhysicalExpr,
         },
         scalar::ScalarValue,
     };
 
+    use crate::execution_plans::ShuffleWriterExec;
+
     use super::super::super::error::Result;
     use super::super::protobuf;
-    use crate::execution_plans::ShuffleWriterExec;
 
     fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> {
         let proto: protobuf::PhysicalPlanNode = exec_plan.clone().try_into()?;
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 22a49cb881ba..420c9a99e773 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -28,7 +28,6 @@ use std::{
 
 use datafusion::logical_plan::JoinType;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
-use datafusion::physical_plan::cross_join::CrossJoinExec;
 use datafusion::physical_plan::csv::CsvExec;
 use datafusion::physical_plan::expressions::{
     CaseExpr, InListExpr, IsNotNullExpr, IsNullExpr, NegativeExpr, NotExpr,
@@ -36,11 +35,12 @@ use datafusion::physical_plan::expressions::{
 use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr};
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::hash_aggregate::AggregateMode;
-use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode};
+use datafusion::physical_plan::joins::cross_join::CrossJoinExec;
+use datafusion::physical_plan::joins::hash_join::{HashJoinExec, PartitionMode};
 use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use datafusion::physical_plan::parquet::{ParquetExec, ParquetPartition};
 use datafusion::physical_plan::projection::ProjectionExec;
-use datafusion::physical_plan::sort::SortExec;
+use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::{
     physical_plan::expressions::{Count, Literal},
     scalar::ScalarValue,
diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs
index e4307b6ae1c4..38b0ce34d5f7 100644
--- a/ballista/rust/core/src/utils.rs
+++ b/ballista/rust/core/src/utils.rs
@@ -56,10 +56,10 @@ use datafusion::physical_plan::empty::EmptyExec;
 use datafusion::physical_plan::expressions::{BinaryExpr, Column, Literal};
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::hash_aggregate::HashAggregateExec;
-use datafusion::physical_plan::hash_join::HashJoinExec;
+use datafusion::physical_plan::joins::hash_join::HashJoinExec;
 use datafusion::physical_plan::parquet::ParquetExec;
 use datafusion::physical_plan::projection::ProjectionExec;
-use datafusion::physical_plan::sort::SortExec;
+use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::{
     metrics, AggregateExpr, ExecutionPlan, Metric, PhysicalExpr, RecordBatchStream,
 };
diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs
index 6ed6fb6c7ebd..23142a987214 100644
--- a/ballista/rust/scheduler/src/planner.rs
+++ b/ballista/rust/scheduler/src/planner.rs
@@ -251,8 +251,8 @@ mod test {
     use ballista_core::serde::protobuf;
     use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
     use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
-    use datafusion::physical_plan::hash_join::HashJoinExec;
-    use datafusion::physical_plan::sort::SortExec;
+    use datafusion::physical_plan::joins::hash_join::HashJoinExec;
+    use datafusion::physical_plan::sorts::sort::SortExec;
     use datafusion::physical_plan::{
         coalesce_partitions::CoalescePartitionsExec, projection::ProjectionExec,
     };
diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index 93ec642628b3..4a5826fdf5b8 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -40,7 +40,7 @@ path = "src/lib.rs"
 default = ["crypto_expressions", "regex_expressions", "unicode_expressions"]
 simd = ["arrow/simd"]
 crypto_expressions = ["md-5", "sha2"]
-regex_expressions = ["regex", "lazy_static"]
+regex_expressions = ["regex"]
 unicode_expressions = ["unicode-segmentation"]
 # Used for testing ONLY: causes all values to hash to the same value (test for collisions)
 force_hash_collisions = []
@@ -67,15 +67,16 @@ sha2 = { version = "^0.9.1", optional = true }
 ordered-float = "2.0"
 unicode-segmentation = { version = "^1.7.1", optional = true }
 regex = { version = "^1.4.3", optional = true }
-lazy_static = { version = "^1.4.0", optional = true }
+lazy_static = { version = "^1.4.0" }
 smallvec = { version = "1.6", features = ["union"] }
 rand = "0.8"
 avro-rs = { version = "0.13", features = ["snappy"], optional = true }
 num-traits = { version = "0.2", optional = true }
+uuid = { version = "0.8", features = ["v4"] }
+tempfile = "3"
 
 [dev-dependencies]
 criterion = "0.3"
-tempfile = "3"
 doc-comment = "0.3"
 
 [[bench]]
diff --git a/datafusion/benches/aggregate_query_sql.rs b/datafusion/benches/aggregate_query_sql.rs
index d2df31416558..56560273ba60 100644
--- a/datafusion/benches/aggregate_query_sql.rs
+++ b/datafusion/benches/aggregate_query_sql.rs
@@ -132,5 +132,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     });
 }
 
-criterion_group!(benches, criterion_benchmark);
+criterion_group!(name = benches;
+    config = Criterion::default().measurement_time(std::time::Duration::from_secs(30));
+    targets = criterion_benchmark);
 criterion_main!(benches);
diff --git a/datafusion/src/arrow_dyn_list_array.rs b/datafusion/src/arrow_dyn_list_array.rs
new file mode 100644
index 000000000000..7d9dc0d5d258
--- /dev/null
+++ b/datafusion/src/arrow_dyn_list_array.rs
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! DynMutableListArray from arrow/io/avro/read/nested.rs
+
+use arrow::array::{Array, ListArray, MutableArray, Offset};
+use arrow::bitmap::MutableBitmap;
+use arrow::buffer::MutableBuffer;
+use arrow::datatypes::DataType;
+use arrow::error::ArrowError;
+use std::sync::Arc;
+
+/// Auxiliary struct
+#[derive(Debug)]
+pub struct DynMutableListArray<O: Offset> {
+    data_type: DataType,
+    offsets: MutableBuffer<O>,
+    values: Box<dyn MutableArray>,
+    validity: Option<MutableBitmap>,
+}
+
+impl<O: Offset> DynMutableListArray<O> {
+    pub fn new_from(
+        values: Box<dyn MutableArray>,
+        data_type: DataType,
+        capacity: usize,
+    ) -> Self {
+        let mut offsets = MutableBuffer::<O>::with_capacity(capacity + 1);
+        offsets.push(O::default());
+        assert_eq!(values.len(), 0);
+        ListArray::<O>::get_child_field(&data_type);
+        Self {
+            data_type,
+            offsets,
+            values,
+            validity: None,
+        }
+    }
+
+    /// Creates a new [`DynMutableListArray`] from a [`MutableArray`] and capacity.
+    pub fn new_with_capacity(values: Box<dyn MutableArray>, capacity: usize) -> Self {
+        let data_type = ListArray::<O>::default_datatype(values.data_type().clone());
+        Self::new_from(values, data_type, capacity)
+    }
+
+    /// The values
+    #[allow(dead_code)]
+    pub fn mut_values(&mut self) -> &mut dyn MutableArray {
+        self.values.as_mut()
+    }
+
+    #[inline]
+    #[allow(dead_code)]
+    pub fn try_push_valid(&mut self) -> Result<(), ArrowError> {
+        let size = self.values.len();
+        let size = O::from_usize(size).ok_or(ArrowError::KeyOverflowError)?; // todo: make this error
+        assert!(size >= *self.offsets.last().unwrap());
+
+        self.offsets.push(size);
+        if let Some(validity) = &mut self.validity {
+            validity.push(true)
+        }
+        Ok(())
+    }
+
+    #[inline]
+    fn push_null(&mut self) {
+        self.offsets.push(self.last_offset());
+        match &mut self.validity {
+            Some(validity) => validity.push(false),
+            None => self.init_validity(),
+        }
+    }
+
+    #[inline]
+    fn last_offset(&self) -> O {
+        *self.offsets.last().unwrap()
+    }
+
+    fn init_validity(&mut self) {
+        let len = self.offsets.len() - 1;
+
+        let mut validity = MutableBitmap::new();
+        validity.extend_constant(len, true);
+        validity.set(len - 1, false);
+        self.validity = Some(validity)
+    }
+}
+
+impl<O: Offset> MutableArray for DynMutableListArray<O> {
+    fn len(&self) -> usize {
+        self.offsets.len() - 1
+    }
+
+    fn validity(&self) -> Option<&MutableBitmap> {
+        self.validity.as_ref()
+    }
+
+    fn as_box(&mut self) -> Box<dyn Array> {
+        Box::new(ListArray::<O>::from_data(
+            self.data_type.clone(),
+            std::mem::take(&mut self.offsets).into(),
+            self.values.as_arc(),
+            std::mem::take(&mut self.validity).map(|x| x.into()),
+        ))
+    }
+
+    fn as_arc(&mut self) -> Arc<dyn Array> {
+        Arc::new(ListArray::<O>::from_data(
+            self.data_type.clone(),
+            std::mem::take(&mut self.offsets).into(),
+            self.values.as_arc(),
+            std::mem::take(&mut self.validity).map(|x| x.into()),
+        ))
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn as_mut_any(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+
+    #[inline]
+    fn push_null(&mut self) {
+        self.push_null()
+    }
+
+    fn shrink_to_fit(&mut self) {
+        todo!();
+    }
+}
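Note (illustrative, not part of the patch): the intended driving pattern for DynMutableListArray, assuming arrow2's MutablePrimitiveArray as the inner mutable array — push values into the child array, then seal one list row with try_push_valid():

use arrow::array::{Array, MutableArray, MutablePrimitiveArray};
use arrow::error::ArrowError;
use std::sync::Arc;

fn build_list() -> Result<Arc<dyn Array>, ArrowError> {
    let values = Box::new(MutablePrimitiveArray::<i32>::new());
    let mut list = DynMutableListArray::<i32>::new_with_capacity(values, 8);
    // Append the child values for one list row...
    let prim = list
        .mut_values()
        .as_mut_any()
        .downcast_mut::<MutablePrimitiveArray<i32>>()
        .unwrap();
    prim.push(Some(1));
    prim.push(Some(2));
    // ...then close the row: the current child length becomes the next offset.
    list.try_push_valid()?;
    Ok(list.as_arc()) // freeze into an immutable ListArray<i32>
}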
diff --git a/datafusion/src/error.rs b/datafusion/src/error.rs
index a47bfac8b622..89726385329e 100644
--- a/datafusion/src/error.rs
+++ b/datafusion/src/error.rs
@@ -61,6 +61,9 @@ pub enum DataFusionError {
     /// Error returned during execution of the query.
     /// Examples include files not found, errors in parsing certain types.
     Execution(String),
+    /// Error raised when a consumer cannot acquire memory from the memory manager;
+    /// in that case we can simply cancel the execution of the partition.
+    OutOfMemory(String),
 }
 
 impl DataFusionError {
@@ -129,6 +132,9 @@ impl Display for DataFusionError {
             DataFusionError::Execution(ref desc) => {
                 write!(f, "Execution error: {}", desc)
             }
+            DataFusionError::OutOfMemory(ref desc) => {
+                write!(f, "Out of memory error: {}", desc)
+            }
         }
     }
 }
diff --git a/datafusion/src/execution/disk_manager.rs b/datafusion/src/execution/disk_manager.rs
new file mode 100644
index 000000000000..9632374687fe
--- /dev/null
+++ b/datafusion/src/execution/disk_manager.rs
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Manages files generated during query execution; files are
+//! hashed among the directories listed in RuntimeConfig::local_dirs.
+
+use crate::error::{DataFusionError, Result};
+use std::collections::hash_map::DefaultHasher;
+use std::fs;
+use std::fs::File;
+use std::hash::{Hash, Hasher};
+use std::path::{Path, PathBuf};
+use uuid::Uuid;
+
+/// Manages files generated during query execution, e.g. spill files generated
+/// while processing a dataset larger than available memory.
+pub struct DiskManager {
+    local_dirs: Vec<String>,
+}
+
+impl DiskManager {
+    /// Create local dirs inside the user-provided dirs through conf
+    pub fn new(conf_dirs: &Vec<String>) -> Result<Self> {
+        Ok(Self {
+            local_dirs: create_local_dirs(conf_dirs)?,
+        })
+    }
+
+    /// Create a file in the conf dirs in a randomized manner and return its path
+    pub fn create_tmp_file(&self) -> Result<String> {
+        create_tmp_file(&self.local_dirs)
+    }
+
+    #[allow(dead_code)]
+    fn cleanup_resource(&mut self) -> Result<()> {
+        for dir in self.local_dirs.drain(..) {
+            fs::remove_dir(dir)?;
+        }
+        Ok(())
+    }
+}
+
+/// Setup local dirs by creating one new dir in each of the given dirs
+fn create_local_dirs(local_dir: &Vec<String>) -> Result<Vec<String>> {
+    local_dir
+        .into_iter()
+        .map(|root| create_directory(root, "datafusion"))
+        .collect()
+}
+
+const MAX_DIR_CREATION_ATTEMPTS: i32 = 10;
+
+fn create_directory(root: &str, prefix: &str) -> Result<String> {
+    let mut attempt = 0;
+    while attempt < MAX_DIR_CREATION_ATTEMPTS {
+        let mut path = PathBuf::from(root);
+        path.push(format!("{}-{}", prefix, Uuid::new_v4().to_string()));
+        let path = path.as_path();
+        if !path.exists() {
+            fs::create_dir(path)?;
+            return Ok(path.canonicalize().unwrap().to_str().unwrap().to_string());
+        }
+        attempt += 1;
+    }
+    Err(DataFusionError::Execution(format!(
+        "Failed to create a temp dir under {} after {} attempts",
+        root, MAX_DIR_CREATION_ATTEMPTS
+    )))
+}
+
+fn get_file(file_name: &str, local_dirs: &Vec<String>) -> String {
+    let mut hasher = DefaultHasher::new();
+    file_name.hash(&mut hasher);
+    let hash = hasher.finish();
+    let dir = &local_dirs[hash.rem_euclid(local_dirs.len() as u64) as usize];
+    let mut path = PathBuf::new();
+    path.push(dir);
+    path.push(file_name);
+    path.to_str().unwrap().to_string()
+}
+
+fn create_tmp_file(local_dirs: &Vec<String>) -> Result<String> {
+    let name = Uuid::new_v4().to_string();
+    let mut path = get_file(&*name, local_dirs);
+    while Path::new(path.as_str()).exists() {
+        path = get_file(&*Uuid::new_v4().to_string(), local_dirs);
+    }
+    File::create(&path)?;
+    Ok(path)
+}
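Note (illustrative, not part of the patch): the DiskManager is self-contained enough to exercise directly; "/tmp" below stands in for whatever RuntimeConfig::local_dirs would normally provide:

use datafusion::error::Result;
use datafusion::execution::disk_manager::DiskManager;

fn spill_file_demo() -> Result<()> {
    // One subdirectory ("datafusion-<uuid>") is created under each conf dir.
    let dm = DiskManager::new(&vec!["/tmp".to_string()])?;
    // Returns the path of a freshly created, uniquely named temp file,
    // hashed into one of the managed dirs.
    let path = dm.create_tmp_file()?;
    println!("would spill to {}", path);
    Ok(())
}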
diff --git a/datafusion/src/execution/memory_management/memory_pool.rs b/datafusion/src/execution/memory_management/memory_pool.rs
new file mode 100644
index 000000000000..a94630a3d3bf
--- /dev/null
+++ b/datafusion/src/execution/memory_management/memory_pool.rs
@@ -0,0 +1,273 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution memory pool that enforces a memory allocation strategy
+
+use crate::execution::memory_management::MemoryConsumerId;
+use async_trait::async_trait;
+use hashbrown::HashMap;
+use log::{info, warn};
+use std::cmp::min;
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use tokio::runtime::Handle;
+use tokio::sync::{Notify, RwLock};
+
+#[async_trait]
+pub(crate) trait ExecutionMemoryPool: Sync + Send + Debug {
+    fn memory_available(&self) -> usize;
+    fn memory_used(&self) -> usize;
+    fn memory_used_partition(&self, partition_id: usize) -> usize;
+    async fn acquire_memory(&self, required: usize, consumer: &MemoryConsumerId)
+        -> usize;
+    async fn update_usage(
+        &self,
+        granted_size: usize,
+        real_size: usize,
+        consumer: &MemoryConsumerId,
+    );
+    async fn release_memory(&self, release_size: usize, partition_id: usize);
+    async fn release_all(&self, partition_id: usize) -> usize;
+}
+
+pub(crate) struct DummyExecutionMemoryPool {
+    pool_size: usize,
+}
+
+impl DummyExecutionMemoryPool {
+    pub fn new() -> Self {
+        Self {
+            pool_size: usize::MAX,
+        }
+    }
+}
+
+impl Debug for DummyExecutionMemoryPool {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("DummyExecutionMemoryPool")
+            .field("total", &self.pool_size)
+            .finish()
+    }
+}
+
+#[async_trait]
+impl ExecutionMemoryPool for DummyExecutionMemoryPool {
+    fn memory_available(&self) -> usize {
+        usize::MAX
+    }
+
+    fn memory_used(&self) -> usize {
+        0
+    }
+
+    fn memory_used_partition(&self, _partition_id: usize) -> usize {
+        0
+    }
+
+    async fn acquire_memory(
+        &self,
+        required: usize,
+        _consumer: &MemoryConsumerId,
+    ) -> usize {
+        required
+    }
+
+    async fn update_usage(
+        &self,
+        _granted_size: usize,
+        _real_size: usize,
+        _consumer: &MemoryConsumerId,
+    ) {
+    }
+
+    async fn release_memory(&self, _release_size: usize, _partition_id: usize) {}
+
+    async fn release_all(&self, _partition_id: usize) -> usize {
+        usize::MAX
+    }
+}
+
+pub(crate) struct ConstraintExecutionMemoryPool {
+    pool_size: usize,
+    /// memory usage per partition
+    memory_usage: RwLock<HashMap<usize, usize>>,
+    notify: Notify,
+}
+
+impl ConstraintExecutionMemoryPool {
+    pub fn new(size: usize) -> Self {
+        Self {
+            pool_size: size,
+            memory_usage: RwLock::new(HashMap::new()),
+            notify: Notify::new(),
+        }
+    }
+}
+
+impl Debug for ConstraintExecutionMemoryPool {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ConstraintExecutionMemoryPool")
+            .field("total", &self.pool_size)
+            .field("used", &self.memory_used())
+            .finish()
+    }
+}
+
+#[async_trait]
+impl ExecutionMemoryPool for ConstraintExecutionMemoryPool {
+    fn memory_available(&self) -> usize {
+        self.pool_size - self.memory_used()
+    }
+
+    fn memory_used(&self) -> usize {
+        Handle::current()
+            .block_on(async { self.memory_usage.read().await.values().sum() })
+    }
+
+    fn memory_used_partition(&self, partition_id: usize) -> usize {
+        Handle::current().block_on(async {
+            let partition_usage = self.memory_usage.read().await;
+            match partition_usage.get(&partition_id) {
+                None => 0,
+                Some(v) => *v,
+            }
+        })
+    }
+
+    async fn acquire_memory(
+        &self,
+        required: usize,
+        consumer: &MemoryConsumerId,
+    ) -> usize {
+        assert!(required > 0);
+        let partition_id = consumer.partition_id;
+        {
+            let mut partition_usage = self.memory_usage.write().await;
+            if !partition_usage.contains_key(&partition_id) {
+                partition_usage.entry(partition_id).or_insert(0);
+                // This will later cause waiting tasks to wake up and check the
+                // number of active partitions again
+                self.notify.notify_waiters();
+            }
+        }
+
+        // Keep looping until we're either sure that we don't want to grant this request (because this
+        // partition would have more than 1 / num_active_partition of the memory) or we have enough free
+        // memory to give it (we always let each partition get at least 1 / (2 * num_active_partition)).
+        loop {
+            let partition_usage = self.memory_usage.read().await;
+            let num_active_partition = partition_usage.len();
+            let current_mem = *partition_usage.get(&partition_id).unwrap();
+
+            let max_memory_per_partition = self.pool_size / num_active_partition;
+            let min_memory_per_partition = self.pool_size / (2 * num_active_partition);
+
+            // How much we can grant this partition; keep its share within 0 <= X <= 1 / num_active_partition
+            let max_grant = match max_memory_per_partition.checked_sub(current_mem) {
+                None => 0,
+                Some(max_available) => min(required, max_available),
+            };
+
+            let total_used: usize = partition_usage.values().sum();
+            let total_available = self.pool_size - total_used;
+            // Only give it as much memory as is free, which might be none if it reached 1 / num_active_partition
+            let to_grant = min(max_grant, total_available);
+
+            // We want to let each partition get at least 1 / (2 * num_active_partition) before blocking;
+            // if we can't give it this much now, wait for other partitions to free up memory
+            // (this happens if older partitions allocated lots of memory before N grew)
+            if to_grant < required && current_mem + to_grant < min_memory_per_partition {
+                info!(
+                    "{:?} waiting for at least 1/2N of pool to be free",
+                    consumer
+                );
+                let _ = self.notify.notified().await;
+            } else {
+                drop(partition_usage);
+                let mut partition_usage = self.memory_usage.write().await;
+                *partition_usage.entry(partition_id).or_insert(0) += to_grant;
+                return to_grant;
+            }
+        }
+    }
+
+    async fn update_usage(
+        &self,
+        granted_size: usize,
+        real_size: usize,
+        consumer: &MemoryConsumerId,
+    ) {
+        assert!(granted_size > 0);
+        assert!(real_size > 0);
+        if granted_size == real_size {
+            return;
+        } else {
+            let mut partition_usage = self.memory_usage.write().await;
+            if granted_size > real_size {
+                *partition_usage.entry(consumer.partition_id).or_insert(0) -=
+                    granted_size - real_size;
+            } else {
+                // TODO: this would have caused OOM already if the size estimated ahead
+                // is much smaller than that of the actual allocation
+                *partition_usage.entry(consumer.partition_id).or_insert(0) +=
+                    real_size - granted_size;
+            }
+        }
+    }
+
+    async fn release_memory(&self, release_size: usize, partition_id: usize) {
+        let partition_usage = self.memory_usage.read().await;
+        let current_mem = match partition_usage.get(&partition_id) {
+            None => 0,
+            Some(v) => *v,
+        };
+
+        let to_free = if current_mem < release_size {
+            warn!(
+                "Release called to free {} but partition only holds {} from the pool",
+                release_size, current_mem
+            );
+            current_mem
+        } else {
+            release_size
+        };
+        if partition_usage.contains_key(&partition_id) {
+            drop(partition_usage);
+            let mut partition_usage = self.memory_usage.write().await;
+            let entry = partition_usage.entry(partition_id).or_insert(0);
+            *entry -= to_free;
+            if *entry == 0 {
+                partition_usage.remove(&partition_id);
+            }
+        }
+        self.notify.notify_waiters();
+    }
+
+    async fn release_all(&self, partition_id: usize) -> usize {
+        let partition_usage = self.memory_usage.read().await;
+        let mut current_mem = 0;
+        match partition_usage.get(&partition_id) {
+            None => return current_mem,
+            Some(v) => current_mem = *v,
+        }
+
+        drop(partition_usage);
+        let mut partition_usage = self.memory_usage.write().await;
+        partition_usage.remove(&partition_id);
+        self.notify.notify_waiters();
+        return current_mem;
+    }
+}
diff --git a/datafusion/src/execution/memory_management/mod.rs b/datafusion/src/execution/memory_management/mod.rs
new file mode 100644
index 000000000000..7b1c067f70ae
--- /dev/null
+++ b/datafusion/src/execution/memory_management/mod.rs
@@ -0,0 +1,403 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Manages all available memory during query execution
+
+pub mod memory_pool;
+
+use crate::error::DataFusionError::OutOfMemory;
+use crate::error::{DataFusionError, Result};
+use crate::execution::memory_management::memory_pool::{
+    ConstraintExecutionMemoryPool, DummyExecutionMemoryPool, ExecutionMemoryPool,
+};
+use async_trait::async_trait;
+use futures::lock::Mutex;
+use hashbrown::HashMap;
+use log::{debug, info};
+use std::fmt;
+use std::fmt::{Debug, Display, Formatter};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, Weak};
+
+static mut CONSUMER_ID: AtomicUsize = AtomicUsize::new(0);
+
+#[derive(Clone)]
+/// Memory manager that enforces how execution memory is shared between all kinds of memory consumers.
+/// Execution memory refers to that used for computation in sorts, aggregations, joins and shuffles.
+pub struct MemoryManager {
+    execution_pool: Arc<dyn ExecutionMemoryPool>,
+    partition_memory_manager: Arc<Mutex<HashMap<usize, PartitionMemoryManager>>>,
+}
+
+impl MemoryManager {
+    /// Create a memory manager based on the configured execution pool size.
+    pub fn new(exec_pool_size: usize) -> Self {
+        let execution_pool: Arc<dyn ExecutionMemoryPool> = if exec_pool_size == usize::MAX
+        {
+            Arc::new(DummyExecutionMemoryPool::new())
+        } else {
+            Arc::new(ConstraintExecutionMemoryPool::new(exec_pool_size))
+        };
+        Self {
+            execution_pool,
+            partition_memory_manager: Arc::new(Mutex::new(HashMap::new())),
+        }
+    }
+
+    /// Acquire `required` bytes of memory from the manager
+    pub async fn acquire_exec_memory(
+        self: &Arc<Self>,
+        required: usize,
+        consumer_id: &MemoryConsumerId,
+    ) -> Result<usize> {
+        let partition_id = consumer_id.partition_id;
+        let mut all_managers = self.partition_memory_manager.lock().await;
+        let partition_manager = all_managers
+            .entry(partition_id)
+            .or_insert_with(|| PartitionMemoryManager::new(partition_id, self.clone()));
+        partition_manager
+            .acquire_exec_memory(required, consumer_id)
+            .await
+    }
+
+    /// Register a consumer with the manager, which enables memory tracking and
+    /// spilling based on the memory used.
+    pub async fn register_consumer(self: &Arc<Self>, consumer: Arc<dyn MemoryConsumer>) {
+        let partition_id = consumer.partition_id();
+        let mut all_managers = self.partition_memory_manager.lock().await;
+        let partition_manager = all_managers
+            .entry(partition_id)
+            .or_insert_with(|| PartitionMemoryManager::new(partition_id, self.clone()));
+        partition_manager.register_consumer(consumer).await;
+    }
+
+    pub(crate) async fn acquire_exec_pool_memory(
+        &self,
+        required: usize,
+        consumer: &MemoryConsumerId,
+    ) -> usize {
+        self.execution_pool.acquire_memory(required, consumer).await
+    }
+
+    pub(crate) async fn release_exec_pool_memory(
+        &self,
+        release_size: usize,
+        partition_id: usize,
+    ) {
+        self.execution_pool
+            .release_memory(release_size, partition_id)
+            .await
+    }
+
+    /// Revise pool usage while handling a variable-length data structure.
+    /// In this case, we may estimate and allocate in advance, and revise the usage
+    /// after the construction of the data structure.
+    #[allow(dead_code)]
+    pub(crate) async fn update_exec_pool_usage(
+        &self,
+        granted_size: usize,
+        real_size: usize,
+        consumer: &MemoryConsumerId,
+    ) {
+        self.execution_pool
+            .update_usage(granted_size, real_size, consumer)
+            .await
+    }
+
+    /// Called during the shutdown procedure of a partition, for memory reclamation.
+    #[allow(dead_code)]
+    pub(crate) async fn release_all_exec_pool_for_partition(
+        &self,
+        partition_id: usize,
+    ) -> usize {
+        self.execution_pool.release_all(partition_id).await
+    }
+
+    #[allow(dead_code)]
+    pub(crate) fn exec_memory_used(&self) -> usize {
+        self.execution_pool.memory_used()
+    }
+
+    pub(crate) fn exec_memory_used_for_partition(&self, partition_id: usize) -> usize {
+        self.execution_pool.memory_used_partition(partition_id)
+    }
+}
+
+fn next_id() -> usize {
+    unsafe { CONSUMER_ID.fetch_add(1, Ordering::SeqCst) }
+}
+
+/// Memory manager that tracks all consumers for a specific partition.
+/// Triggers a spill for one or more consumers when memory is insufficient.
+pub struct PartitionMemoryManager {
+    memory_manager: Weak<MemoryManager>,
+    partition_id: usize,
+    consumers: Mutex<HashMap<MemoryConsumerId, Arc<dyn MemoryConsumer>>>,
+}
+
+impl PartitionMemoryManager {
+    /// Create a manager for a partition
+    pub fn new(partition_id: usize, memory_manager: Arc<MemoryManager>) -> Self {
+        Self {
+            memory_manager: Arc::downgrade(&memory_manager),
+            partition_id,
+            consumers: Mutex::new(HashMap::new()),
+        }
+    }
+
+    /// Register a memory consumer at its first appearance
+    pub async fn register_consumer(&self, consumer: Arc<dyn MemoryConsumer>) {
+        let mut consumers = self.consumers.lock().await;
+        let id = consumer.id().clone();
+        consumers.insert(id, consumer);
+    }
+
+    /// Try to acquire `required` bytes of execution memory for the consumer and return the
+    /// number of bytes obtained, or return an OutOfMemory error if not enough memory is
+    /// available even after possible spills.
+    pub async fn acquire_exec_memory(
+        &self,
+        required: usize,
+        consumer_id: &MemoryConsumerId,
+    ) -> Result<usize> {
+        let mut consumers = self.consumers.lock().await;
+        let memory_manager = self.memory_manager.upgrade().ok_or_else(|| {
+            DataFusionError::Execution("Failed to get MemoryManager".to_string())
+        })?;
+        let mut got = memory_manager
+            .acquire_exec_pool_memory(required, consumer_id)
+            .await;
+        if got < required {
+            // Try to release memory from other consumers first.
+            // Sort the consumers according to their memory usage and spill from the
+            // consumer that holds the maximum memory, to reduce the total frequency of
+            // spilling
+
+            let mut all_consumers: Vec<Arc<dyn MemoryConsumer>> = vec![];
+            for c in consumers.iter() {
+                all_consumers.push(c.1.clone());
+            }
+            all_consumers.sort_by(|a, b| b.get_used().cmp(&a.get_used()));
+
+            for c in all_consumers.iter_mut() {
+                if c.id() == consumer_id {
+                    continue;
+                }
+
+                let released = c.spill(required - got, consumer_id).await?;
+                if released > 0 {
+                    debug!(
+                        "Partition {} released {} from consumer {}",
+                        self.partition_id,
+                        released,
+                        c.str_repr()
+                    );
+                    got += memory_manager
+                        .acquire_exec_pool_memory(required - got, consumer_id)
+                        .await;
+                    if got > required {
+                        break;
+                    }
+                }
+            }
+        }
+
+        if got < required {
+            // spill itself
+            let consumer = consumers.get_mut(consumer_id).unwrap();
+            let released = consumer.spill(required - got, consumer_id).await?;
+            if released > 0 {
+                debug!(
+                    "Partition {} released {} from consumer itself {}",
+                    self.partition_id,
+                    released,
+                    consumer.str_repr()
+                );
+                got += memory_manager
+                    .acquire_exec_pool_memory(required - got, consumer_id)
+                    .await;
+            }
+        }
+
+        if got < required {
+            return Err(OutOfMemory(format!(
+                "Unable to acquire {} bytes of memory, got {}",
+                required, got
+            )));
+        }
+
+        debug!("{} acquired {}", consumer_id, got);
+        Ok(got)
+    }
+
+    /// Log current memory usage for all consumers in this partition
+    pub async fn show_memory_usage(&self) -> Result<()> {
+        info!("Memory usage for partition {}", self.partition_id);
+        let consumers = self.consumers.lock().await;
+        let mut used = 0;
+        for (_, c) in consumers.iter() {
+            let cur_used = c.get_used();
+            used += cur_used;
+            if cur_used > 0 {
+                info!(
+                    "Consumer {} acquired {}",
+                    c.str_repr(),
+                    human_readable_size(cur_used as usize)
+                )
+            }
+        }
+        let no_consumer_size = self
+            .memory_manager
+            .upgrade()
+            .ok_or_else(|| {
+                DataFusionError::Execution("Failed to get MemoryManager".to_string())
+            })?
+            .exec_memory_used_for_partition(self.partition_id)
+            - (used as usize);
+        info!(
+            "{} bytes of memory were used for partition {} without specific consumer",
+            human_readable_size(no_consumer_size),
+            self.partition_id
+        );
+        Ok(())
+    }
+}
+
+#[derive(Clone, Debug, Hash, Eq, PartialEq)]
+/// Id that uniquely identifies a memory consumer
+pub struct MemoryConsumerId {
+    /// partition the consumer belongs to
+    pub partition_id: usize,
+    /// unique id
+    pub id: usize,
+}
+
+impl MemoryConsumerId {
+    /// Auto-incremented new Id
+    pub fn new(partition_id: usize) -> Self {
+        let id = next_id();
+        Self { partition_id, id }
+    }
+}
+
+impl Display for MemoryConsumerId {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        write!(f, "{}:{}", self.partition_id, self.id)
+    }
+}
+
+#[async_trait]
+/// A memory consumer that supports spilling.
+pub trait MemoryConsumer: Send + Sync + Debug {
+    /// Display name of the consumer
+    fn name(&self) -> String;
+
+    /// Unique id of the consumer
+    fn id(&self) -> &MemoryConsumerId;
+
+    /// Pointer to the MemoryManager
+    fn memory_manager(&self) -> Arc<MemoryManager>;
+
+    /// Partition that the consumer belongs to
+    fn partition_id(&self) -> usize {
+        self.id().partition_id
+    }
+
+    /// Try to allocate `required` bytes as needed
+    async fn allocate(&self, required: usize) -> Result<()> {
+        let got = self
+            .memory_manager()
+            .acquire_exec_memory(required, self.id())
+            .await?;
+        self.update_used(got as isize);
+        Ok(())
+    }
+
+    /// Spill at least `size` bytes to disk and update related counters
+    async fn spill(&self, size: usize, trigger: &MemoryConsumerId) -> Result<usize> {
+        let released = self.spill_inner(size, trigger).await?;
+        if released > 0 {
+            self.memory_manager()
+                .release_exec_pool_memory(released, self.id().partition_id)
+                .await;
+            self.update_used(-(released as isize));
+            self.spilled_bytes_add(released);
+            self.spilled_count_increment();
+        }
+        Ok(released)
+    }
+
+    /// Spill at least `size` bytes to disk and free memory
+    async fn spill_inner(&self, size: usize, trigger: &MemoryConsumerId)
+        -> Result<usize>;
+
+    /// Get current memory usage for the consumer itself
+    fn get_used(&self) -> isize;
+
+    /// Update memory usage
+    fn update_used(&self, delta: isize);
+
+    /// Get total number of spilled bytes so far
+    fn spilled_bytes(&self) -> usize;
+
+    /// Update the spilled-bytes counter
+    fn spilled_bytes_add(&self, add: usize);
+
+    /// Get total number of triggered spills so far
+    fn spilled_count(&self) -> usize;
+
+    /// Update the spill counter
+    fn spilled_count_increment(&self);
+
+    /// String representation for the consumer
+    fn str_repr(&self) -> String {
+        format!("{}({})", self.name(), self.id())
+    }
+
+    #[inline]
+    /// Log during spilling
+    fn log_spill(&self, size: usize) {
+        info!(
+            "{} spilling of {} bytes to disk ({} times so far)",
+            self.str_repr(),
+            size,
+            self.spilled_count()
+        );
+    }
+}
+
+const TB: u64 = 1 << 40;
+const GB: u64 = 1 << 30;
+const MB: u64 = 1 << 20;
+const KB: u64 = 1 << 10;
+
+fn human_readable_size(size: usize) -> String {
+    let size = size as u64;
+    let (value, unit) = {
+        if size >= 2 * TB {
+            (size as f64 / TB as f64, "TB")
+        } else if size >= 2 * GB {
+            (size as f64 / GB as f64, "GB")
+        } else if size >= 2 * MB {
+            (size as f64 / MB as f64, "MB")
+        } else if size >= 2 * KB {
+            (size as f64 / KB as f64, "KB")
+        } else {
+            (size as f64, "B")
+        }
+    };
+    format!("{:.1} {}", value, unit)
+}
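Note (illustrative only): a sketch of what an implementation of the new MemoryConsumer trait might look like for a spilling operator; the struct, its fields, and the no-op spill body are hypothetical, with atomics standing in for real operator state:

use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering};
use std::sync::Arc;

#[derive(Debug)]
struct SortSpillConsumer {
    id: MemoryConsumerId,
    memory_manager: Arc<MemoryManager>,
    used: AtomicIsize,
    spilled_bytes: AtomicUsize,
    spilled_count: AtomicUsize,
}

#[async_trait]
impl MemoryConsumer for SortSpillConsumer {
    fn name(&self) -> String {
        "SortSpillConsumer".to_owned()
    }
    fn id(&self) -> &MemoryConsumerId {
        &self.id
    }
    fn memory_manager(&self) -> Arc<MemoryManager> {
        self.memory_manager.clone()
    }
    async fn spill_inner(&self, size: usize, _trigger: &MemoryConsumerId) -> Result<usize> {
        self.log_spill(size);
        // A real operator would write its buffered batches to a DiskManager
        // temp file here and return the number of bytes actually freed.
        Ok(size)
    }
    fn get_used(&self) -> isize {
        self.used.load(Ordering::SeqCst)
    }
    fn update_used(&self, delta: isize) {
        self.used.fetch_add(delta, Ordering::SeqCst);
    }
    fn spilled_bytes(&self) -> usize {
        self.spilled_bytes.load(Ordering::SeqCst)
    }
    fn spilled_bytes_add(&self, add: usize) {
        self.spilled_bytes.fetch_add(add, Ordering::SeqCst);
    }
    fn spilled_count(&self) -> usize {
        self.spilled_count.load(Ordering::SeqCst)
    }
    fn spilled_count_increment(&self) {
        self.spilled_count.fetch_add(1, Ordering::SeqCst);
    }
}

With only these methods supplied, the trait's default allocate() and spill() drive the acquire-then-spill protocol against the partition's pool.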
diff --git a/datafusion/src/execution/mod.rs b/datafusion/src/execution/mod.rs
index ff44dd43f834..b4d5aa0b9fc4 100644
--- a/datafusion/src/execution/mod.rs
+++ b/datafusion/src/execution/mod.rs
@@ -19,3 +19,6 @@
 
 pub mod context;
 pub mod dataframe_impl;
+pub mod disk_manager;
+pub mod memory_management;
+pub mod runtime_env;
diff --git a/datafusion/src/execution/runtime_env.rs b/datafusion/src/execution/runtime_env.rs
new file mode 100644
index 000000000000..d0cd4718ffa9
--- /dev/null
+++ b/datafusion/src/execution/runtime_env.rs
@@ -0,0 +1,121 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution runtime environment that tracks memory, disk and various configurations
+//! that are used during physical plan execution.
+
+use crate::error::Result;
+use crate::execution::disk_manager::DiskManager;
+use crate::execution::memory_management::{MemoryConsumer, MemoryManager};
+use std::sync::Arc;
+
+lazy_static! {
+    /// Employ lazy_static temporarily for RuntimeEnv, to avoid plumbing it through
+    /// all `async fn execute(&self, partition: usize, runtime: Arc<RuntimeEnv>)`
+    pub static ref RUNTIME_ENV: Arc<RuntimeEnv> = {
+        let config = RuntimeConfig::new();
+        Arc::new(RuntimeEnv::new(config).unwrap())
+    };
+}
+
+#[derive(Clone)]
+/// Execution runtime environment
+pub struct RuntimeEnv {
+    /// Runtime configuration
+    pub config: RuntimeConfig,
+    /// Runtime memory management
+    pub memory_manager: Arc<MemoryManager>,
+    /// Manage temporary files during query execution
+    pub disk_manager: Arc<DiskManager>,
+}
+
+impl RuntimeEnv {
+    /// Create an env based on configuration
+    pub fn new(config: RuntimeConfig) -> Result<Self> {
+        let memory_manager = Arc::new(MemoryManager::new(config.max_memory));
+        let disk_manager = Arc::new(DiskManager::new(&config.local_dirs)?);
+        Ok(Self {
+            config,
+            memory_manager,
+            disk_manager,
+        })
+    }
+
+    /// Get the execution batch size based on config
+    pub fn batch_size(&self) -> usize {
+        self.config.batch_size
+    }
+
+    /// Register the consumer to get it tracked
+    pub async fn register_consumer(&self, memory_consumer: Arc<dyn MemoryConsumer>) {
+        self.memory_manager.register_consumer(memory_consumer).await;
+    }
+}
+
+#[derive(Clone)]
+/// Execution runtime configuration
+pub struct RuntimeConfig {
+    /// Default batch size when creating new batches
+    pub batch_size: usize,
+    /// Max execution memory allowed for DataFusion
+    pub max_memory: usize,
+    /// Local dirs to store temporary files during execution
+    pub local_dirs: Vec<String>,
+}
+
+impl RuntimeConfig {
+    /// New with default values
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    /// Customize the batch size
+    pub fn with_batch_size(mut self, n: usize) -> Self {
+        // batch size must be greater than zero
+        assert!(n > 0);
+        self.batch_size = n;
+        self
+    }
+
+    /// Customize the max execution memory
+    pub fn with_max_execution_memory(mut self, max_memory: usize) -> Self {
+        assert!(max_memory > 0);
+        self.max_memory = max_memory;
+        self
+    }
+
+    /// Customize the local dirs for temporary files
+    pub fn with_local_dirs(mut self, local_dirs: Vec<String>) -> Self {
+        assert!(local_dirs.len() > 0);
+        self.local_dirs = local_dirs;
+        self
+    }
+}
+
+impl Default for RuntimeConfig {
+    fn default() -> Self {
+        let tmp_dir = tempfile::tempdir().unwrap();
+        let path = tmp_dir.path().to_str().unwrap().to_string();
+        std::mem::forget(tmp_dir);
+
+        Self {
+            batch_size: 8192,
+            max_memory: usize::MAX,
+            local_dirs: vec![path],
+        }
+    }
+}
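Note (illustrative): wiring the new runtime pieces together through RuntimeConfig's builder methods; the sizes and the "/tmp" spill dir below are arbitrary:

use datafusion::error::Result;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};

fn make_runtime() -> Result<RuntimeEnv> {
    let config = RuntimeConfig::new()
        .with_batch_size(4096)
        .with_max_execution_memory(512 * 1024 * 1024) // 512 MiB execution pool
        .with_local_dirs(vec!["/tmp".to_string()]);
    // Builds the MemoryManager and DiskManager from the config.
    RuntimeEnv::new(config)
}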
diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs
index bb90b1703931..2539715a719a 100644
--- a/datafusion/src/lib.rs
+++ b/datafusion/src/lib.rs
@@ -230,6 +230,7 @@ pub mod variable;
 
 // re-export dependencies from arrow-rs to minimise version maintenance for crate users
 pub use arrow;
+mod arrow_dyn_list_array;
 mod arrow_temporal_util;
 
 #[cfg(test)]
diff --git a/datafusion/src/physical_optimizer/coalesce_batches.rs b/datafusion/src/physical_optimizer/coalesce_batches.rs
index 9af8911062df..38624876922c 100644
--- a/datafusion/src/physical_optimizer/coalesce_batches.rs
+++ b/datafusion/src/physical_optimizer/coalesce_batches.rs
@@ -18,15 +18,18 @@
 //! CoalesceBatches optimizer that groups batches together rows
 //! in bigger batches to avoid overhead with small batches
 
-use super::optimizer::PhysicalOptimizerRule;
+use std::sync::Arc;
+
+use crate::physical_plan::joins::hash_join::HashJoinExec;
 use crate::{
     error::Result,
     physical_plan::{
         coalesce_batches::CoalesceBatchesExec, filter::FilterExec,
-        hash_join::HashJoinExec, repartition::RepartitionExec,
+        repartition::RepartitionExec,
     },
 };
-use std::sync::Arc;
+
+use super::optimizer::PhysicalOptimizerRule;
 
 /// Optimizer that introduces CoalesceBatchesExec to avoid overhead with small batches
 pub struct CoalesceBatches {}
diff --git a/datafusion/src/physical_optimizer/hash_build_probe_order.rs b/datafusion/src/physical_optimizer/hash_build_probe_order.rs
index 0b87ceb1a4e2..f82f14ec1148 100644
--- a/datafusion/src/physical_optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/physical_optimizer/hash_build_probe_order.rs
@@ -20,17 +20,17 @@ use std::sync::Arc;
 
 use arrow::datatypes::Schema;
 
+use crate::error::Result;
 use crate::execution::context::ExecutionConfig;
 use crate::logical_plan::JoinType;
-use crate::physical_plan::cross_join::CrossJoinExec;
 use crate::physical_plan::expressions::Column;
-use crate::physical_plan::hash_join::HashJoinExec;
+use crate::physical_plan::joins::cross_join::CrossJoinExec;
+use crate::physical_plan::joins::hash_join::HashJoinExec;
 use crate::physical_plan::projection::ProjectionExec;
 use crate::physical_plan::{ExecutionPlan, PhysicalExpr};
 
 use super::optimizer::PhysicalOptimizerRule;
 use super::utils::optimize_children;
-use crate::error::Result;
 
 /// BuildProbeOrder reorders the build and probe phase of
 /// hash joins. This uses the amount of rows that a datasource has.
@@ -153,16 +153,15 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder {
 
 #[cfg(test)]
 mod tests {
-    use crate::{
-        physical_plan::{hash_join::PartitionMode, Statistics},
-        test::exec::StatisticsExec,
-    };
-
-    use super::*;
     use std::sync::Arc;
 
     use arrow::datatypes::{DataType, Field, Schema};
 
+    use crate::physical_plan::joins::hash_join::PartitionMode;
+    use crate::{physical_plan::Statistics, test::exec::StatisticsExec};
+
+    use super::*;
+
     fn create_big_and_small() -> (Arc<dyn ExecutionPlan>, Arc<dyn ExecutionPlan>) {
         let big = Arc::new(StatisticsExec::new(
             Statistics {
diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs
index ae320bb55733..5584a64ed2e1 100644
--- a/datafusion/src/physical_plan/common.rs
+++ b/datafusion/src/physical_plan/common.rs
@@ -25,11 +25,13 @@ use arrow::compute::concat;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::error::ArrowError;
 use arrow::error::Result as ArrowResult;
+use arrow::io::ipc::write::FileWriter;
 use arrow::record_batch::RecordBatch;
 use futures::channel::mpsc;
 use futures::{SinkExt, Stream, StreamExt, TryStreamExt};
 use std::fs;
-use std::fs::metadata;
+use std::fs::{metadata, File};
+use std::io::BufWriter;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 use tokio::task::JoinHandle;
@@ -227,6 +229,63 @@ pub fn compute_record_batch_statistics(
     }
 }
 
+/// Write in Arrow IPC format.
+pub struct IPCWriterWrapper {
+    /// path
+    pub path: String,
+    /// Inner writer
+    pub writer: FileWriter<BufWriter<File>>,
+    /// batches written
+    pub num_batches: u64,
+    /// rows written
+    pub num_rows: u64,
+    /// bytes written
+    pub num_bytes: u64,
+}
+
+impl IPCWriterWrapper {
+    /// Create a new writer
+    pub fn new(path: &str, schema: &Schema) -> Result<Self> {
+        let file = File::create(path).map_err(|e| DataFusionError::IoError(e))?;
+        let buffer_writer = std::io::BufWriter::new(file);
+        Ok(Self {
+            num_batches: 0,
+            num_rows: 0,
+            num_bytes: 0,
+            path: path.to_owned(),
+            writer: FileWriter::try_new(buffer_writer, schema)?,
+        })
+    }
+
+    /// Write one single batch
+    pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
+        self.writer.write(batch)?;
+        self.num_batches += 1;
+        self.num_rows += batch.num_rows() as u64;
+        let num_bytes: usize = batch_memory_size(batch);
+        self.num_bytes += num_bytes as u64;
+        Ok(())
+    }
+
+    /// Finish the writer
+    pub fn finish(&mut self) -> Result<()> {
+        self.writer.finish().map_err(DataFusionError::ArrowError)
+    }
+
+    /// Path written to
+    pub fn path(&self) -> &str {
+        &self.path
+    }
+}
+
+/// Estimate batch memory footprint
+pub fn batch_memory_size(rb: &RecordBatch) -> usize {
+    rb.columns()
+        .iter()
+        .map(|c| estimated_bytes_size(c.as_ref()))
+        .sum()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
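Note (illustrative): the writer that used to be private to Ballista's shuffle is now reusable from physical_plan::common; a sketch of the call pattern ("/tmp/part-0.arrow" is an arbitrary path):

use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::physical_plan::common::IPCWriterWrapper;

fn write_partition(batches: &[RecordBatch], schema: &Schema) -> Result<()> {
    let mut writer = IPCWriterWrapper::new("/tmp/part-0.arrow", schema)?;
    for batch in batches {
        writer.write(batch)?; // updates num_batches / num_rows / num_bytes
    }
    writer.finish()?;
    println!(
        "wrote {} batches, {} rows, ~{} bytes to {}",
        writer.num_batches, writer.num_rows, writer.num_bytes, writer.path()
    );
    Ok(())
}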
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index e2e849085484..6607144657d8 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -142,6 +142,18 @@ impl PhysicalSortExpr {
     }
 }
 
+/// Convert sort expressions into a Vec<SortColumn> that can be passed into the arrow sort kernel
+pub fn exprs_to_sort_columns(
+    batch: &RecordBatch,
+    expr: &[PhysicalSortExpr],
+) -> Result<Vec<SortColumn>> {
+    let columns = expr
+        .iter()
+        .map(|e| e.evaluate_to_sort_column(&batch))
+        .collect::<Result<Vec<SortColumn>>>();
+    columns
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index 72c1a54ff611..e57b0e5e5c0b 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -34,6 +34,7 @@ use crate::physical_plan::{
 };
 use crate::{
     error::{DataFusionError, Result},
+    execution::memory_management::MemoryConsumerId,
     scalar::ScalarValue,
 };
 
@@ -212,8 +213,11 @@ impl ExecutionPlan for HashAggregateExec {
 
         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
 
+        let streamer_id = MemoryConsumerId::new(partition);
+
         if self.group_expr.is_empty() {
             Ok(Box::pin(HashAggregateStream::new(
+                streamer_id,
                 self.mode,
                 self.schema.clone(),
                 self.aggr_expr.clone(),
@@ -740,6 +744,7 @@ pin_project! {
 
 /// Special case aggregate with no groups
 async fn compute_hash_aggregate(
+    _id: MemoryConsumerId,
     mode: AggregateMode,
     schema: SchemaRef,
     aggr_expr: Vec<Arc<dyn AggregateExpr>>,
@@ -776,6 +781,7 @@ async fn compute_hash_aggregate(
 impl HashAggregateStream {
     /// Create a new HashAggregateStream
     pub fn new(
+        id: MemoryConsumerId,
         mode: AggregateMode,
         schema: SchemaRef,
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
@@ -788,6 +794,7 @@ impl HashAggregateStream {
         let elapsed_compute = baseline_metrics.elapsed_compute().clone();
         tokio::spawn(async move {
             let result = compute_hash_aggregate(
+                id,
                 mode,
                 schema_clone,
                 aggr_expr,
diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs
index 494fe3f3dd5b..346b69db26f0 100644
--- a/datafusion/src/physical_plan/hash_utils.rs
+++ b/datafusion/src/physical_plan/hash_utils.rs
@@ -17,96 +17,17 @@
 
 //! Functionality used both on logical and physical plans
 
-use crate::error::{DataFusionError, Result};
+use std::sync::Arc;
+
 pub use ahash::{CallHasher, RandomState};
 use arrow::array::{
     Array, ArrayRef, BooleanArray, DictionaryArray, DictionaryKey, Float32Array,
     Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array,
     UInt32Array, UInt64Array, UInt8Array, Utf8Array,
 };
-use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
-use std::collections::HashSet;
-use std::sync::Arc;
-
-use crate::logical_plan::JoinType;
-use crate::physical_plan::expressions::Column;
-
-/// The on clause of the join, as vector of (left, right) columns.
-pub type JoinOn = Vec<(Column, Column)>;
-/// Reference for JoinOn.
-pub type JoinOnRef<'a> = &'a [(Column, Column)];
-
-/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join.
-/// They are valid whenever their columns' intersection equals the set `on`
-pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
-    let left: HashSet<Column> = left
-        .fields()
-        .iter()
-        .enumerate()
-        .map(|(idx, f)| Column::new(f.name(), idx))
-        .collect();
-    let right: HashSet<Column> = right
-        .fields()
-        .iter()
-        .enumerate()
-        .map(|(idx, f)| Column::new(f.name(), idx))
-        .collect();
-
-    check_join_set_is_valid(&left, &right, on)
-}
-
-/// Checks whether the sets left, right and on compose a valid join.
-/// They are valid whenever their intersection equals the set `on`
-fn check_join_set_is_valid(
-    left: &HashSet<Column>,
-    right: &HashSet<Column>,
-    on: &[(Column, Column)],
-) -> Result<()> {
-    let on_left = &on.iter().map(|on| on.0.clone()).collect::<HashSet<Column>>();
-    let left_missing = on_left.difference(left).collect::<HashSet<_>>();
-
-    let on_right = &on.iter().map(|on| on.1.clone()).collect::<HashSet<Column>>();
-    let right_missing = on_right.difference(right).collect::<HashSet<_>>();
-
-    if !left_missing.is_empty() | !right_missing.is_empty() {
-        return Err(DataFusionError::Plan(format!(
-            "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}",
-            left_missing,
-            right_missing,
-        )));
-    };
-
-    let remaining = right
-        .difference(on_right)
-        .cloned()
-        .collect::<HashSet<Column>>();
-
-    let collisions = left.intersection(&remaining).collect::<HashSet<_>>();
+use arrow::datatypes::{DataType, TimeUnit};
 
-    if !collisions.is_empty() {
-        return Err(DataFusionError::Plan(format!(
-            "The left schema and the right schema have the following columns with the same name without being on the ON statement: {:?}. Consider aliasing them.",
-            collisions,
-        )));
-    };
-
-    Ok(())
-}
-
-/// Creates a schema for a join operation.
-/// The fields from the left side are first
-pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema {
-    let fields: Vec<Field> = match join_type {
-        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
-            let left_fields = left.fields().iter();
-            let right_fields = right.fields().iter();
-            // left then right
-            left_fields.chain(right_fields).cloned().collect()
-        }
-        JoinType::Semi | JoinType::Anti => left.fields().clone(),
-    };
-    Schema::new(fields)
-}
+use crate::error::{DataFusionError, Result};
 
 // Combines two hashes into one hash
 #[inline]
@@ -599,65 +520,6 @@ mod tests {
 
     use super::*;
 
-    fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> {
-        let left = left
-            .iter()
-            .map(|x| x.to_owned())
-            .collect::<HashSet<Column>>();
-        let right = right
-            .iter()
-            .map(|x| x.to_owned())
-            .collect::<HashSet<Column>>();
-        check_join_set_is_valid(&left, &right, on)
-    }
-
-    #[test]
-    fn check_valid() -> Result<()> {
-        let left = vec![Column::new("a", 0), Column::new("b1", 1)];
-        let right = vec![Column::new("a", 0), Column::new("b2", 1)];
-        let on = &[(Column::new("a", 0), Column::new("a", 0))];
-
-        check(&left, &right, on)?;
-        Ok(())
-    }
-
-    #[test]
-    fn check_not_in_right() {
-        let left = vec![Column::new("a", 0), Column::new("b", 1)];
-        let right = vec![Column::new("b", 0)];
-        let on = &[(Column::new("a", 0), Column::new("a", 0))];
-
-        assert!(check(&left, &right, on).is_err());
-    }
-
-    #[test]
-    fn check_not_in_left() {
-        let left = vec![Column::new("b", 0)];
-        let right = vec![Column::new("a", 0)];
-        let on = &[(Column::new("a", 0), Column::new("a", 0))];
-
-        assert!(check(&left, &right, on).is_err());
-    }
-
-    #[test]
-    fn check_collision() {
-        // column "a" would appear both in left and right
-        let left = vec![Column::new("a", 0), Column::new("c", 1)];
-        let right = vec![Column::new("a", 0), Column::new("b", 1)];
-        let on = &[(Column::new("a", 0), Column::new("b", 1))];
-
-        assert!(check(&left, &right, on).is_err());
-    }
-
-    #[test]
-    fn check_in_right() {
-        let left = vec![Column::new("a", 0), Column::new("c", 1)];
-        let right = vec![Column::new("b", 0)];
-        let on = &[(Column::new("a", 0), Column::new("b", 0))];
-
-        assert!(check(&left, &right, on).is_ok());
-    }
-
     #[test]
     fn create_hashes_for_float_arrays() -> Result<()> {
         let f32_arr = Arc::new(Float32Array::from_slice(&[0.12, 0.5, 1f32, 444.7]));
diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/joins/cross_join.rs
similarity index 99%
rename from datafusion/src/physical_plan/cross_join.rs
rename to datafusion/src/physical_plan/joins/cross_join.rs
index 0edc3cc0d1aa..e3f9b1bc36e0 100644
--- a/datafusion/src/physical_plan/cross_join.rs
+++ b/datafusion/src/physical_plan/joins/cross_join.rs
@@ -18,32 +18,30 @@
 //! Defines the cross join plan for loading the left side of the cross join
 //! and producing batches in parallel for the right partitions
 
-use futures::{lock::Mutex, StreamExt};
+use std::time::Instant;
 use std::{any::Any, sync::Arc, task::Poll};
 
-use crate::physical_plan::memory::MemoryStream;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
-
+use async_trait::async_trait;
+use futures::{lock::Mutex, StreamExt};
 use futures::{Stream, TryStreamExt};
+use log::debug;
 
-use super::{
-    coalesce_partitions::CoalescePartitionsExec, hash_utils::check_join_is_valid,
-    ColumnStatistics, Statistics,
+use crate::physical_plan::joins::check_join_is_valid;
+use crate::physical_plan::memory::MemoryStream;
+use crate::physical_plan::{
+    coalesce_batches::concat_batches, DisplayFormatType, ExecutionPlan, Partitioning,
+    RecordBatchStream, SendableRecordBatchStream,
+};
+use crate::physical_plan::{
+    coalesce_partitions::CoalescePartitionsExec, ColumnStatistics, Statistics,
 };
 use crate::{
     error::{DataFusionError, Result},
     scalar::ScalarValue,
 };
-use async_trait::async_trait;
-use std::time::Instant;
-
-use super::{
-    coalesce_batches::concat_batches, DisplayFormatType, ExecutionPlan, Partitioning,
-    RecordBatchStream, SendableRecordBatchStream,
-};
-use log::debug;
 
 /// Data of the left side
 type JoinLeftData = RecordBatch;
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/joins/hash_join.rs
similarity index 93%
rename from datafusion/src/physical_plan/hash_join.rs
rename to datafusion/src/physical_plan/joins/hash_join.rs
index 259cba65db56..fdb19cf80fbc 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/joins/hash_join.rs
@@ -18,49 +18,42 @@
 //! Defines the join plan for executing partitions in parallel and then joining the results
 //! into a set of partitions.
-use ahash::RandomState; - -use smallvec::{smallvec, SmallVec}; +use std::fmt; use std::sync::Arc; use std::{any::Any, usize}; use std::{time::Instant, vec}; -use async_trait::async_trait; -use futures::{Stream, StreamExt, TryStreamExt}; -use tokio::sync::Mutex; - +use ahash::RandomState; +use arrow::compute::take; use arrow::datatypes::*; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use arrow::{array::*, buffer::MutableBuffer}; - -use arrow::compute::take; - +use async_trait::async_trait; +use futures::{Stream, StreamExt, TryStreamExt}; use hashbrown::raw::RawTable; +use log::debug; +use smallvec::{smallvec, SmallVec}; +use tokio::sync::Mutex; -use super::{ - coalesce_partitions::CoalescePartitionsExec, - hash_utils::{build_join_schema, check_join_is_valid, JoinOn}, +use crate::error::{DataFusionError, Result}; +use crate::logical_plan::JoinType; +use crate::physical_plan::coalesce_batches::concat_batches; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::joins::{ + build_join_schema, check_join_is_valid, column_indices_from_schema, equal_rows, + ColumnIndex, JoinOn, }; -use super::{ +use crate::physical_plan::PhysicalExpr; +use crate::physical_plan::{ expressions::Column, metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, }; -use super::{hash_utils::create_hashes, Statistics}; -use crate::error::{DataFusionError, Result}; -use crate::logical_plan::JoinType; - -use super::{ +use crate::physical_plan::{hash_utils::create_hashes, Statistics}; +use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; -use crate::physical_plan::coalesce_batches::concat_batches; -use crate::physical_plan::PhysicalExpr; -use log::debug; -use std::fmt; - -type StringArray = Utf8Array; -type LargeStringArray = Utf8Array; // Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value. // @@ -156,14 +149,6 @@ pub enum PartitionMode { CollectLeft, } -/// Information about the index and placement (left or right) of the columns -struct ColumnIndex { - /// Index of the column - index: usize, - /// Whether the column is at the left or right side - is_left: bool, -} - impl HashJoinExec { /// Tries to create a new [HashJoinExec]. 
/// # Error @@ -220,38 +205,6 @@ impl HashJoinExec { pub fn partition_mode(&self) -> &PartitionMode { &self.mode } - - /// Calculates column indices and left/right placement on input / output schemas and jointype - fn column_indices_from_schema(&self) -> ArrowResult> { - let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { - JoinType::Inner - | JoinType::Left - | JoinType::Full - | JoinType::Semi - | JoinType::Anti => (true, self.left.schema(), self.right.schema()), - JoinType::Right => (false, self.right.schema(), self.left.schema()), - }; - let mut column_indices = Vec::with_capacity(self.schema.fields().len()); - for field in self.schema.fields() { - let (is_primary, index) = match primary_schema.index_of(field.name()) { - Ok(i) => Ok((true, i)), - Err(_) => { - match secondary_schema.index_of(field.name()) { - Ok(i) => Ok((false, i)), - _ => Err(DataFusionError::Internal( - format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string() - )) - } - } - }.map_err(DataFusionError::into_arrow_external_error)?; - - let is_left = - is_primary && primary_is_left || !is_primary && !primary_is_left; - column_indices.push(ColumnIndex { index, is_left }); - } - - Ok(column_indices) - } } #[async_trait] @@ -412,7 +365,12 @@ impl ExecutionPlan for HashJoinExec { let right_stream = self.right.execute(partition).await?; let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); - let column_indices = self.column_indices_from_schema()?; + let column_indices = column_indices_from_schema( + &self.join_type, + &self.left.schema(), + &self.right.schema(), + &self.schema, + )?; let num_rows = left_data.1.num_rows(); let visited_left_side = match self.join_type { JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => { @@ -780,57 +738,6 @@ fn build_join_indexes( } } -macro_rules! 
equal_rows_elem { - ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ - let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); - let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); - - match (left_array.is_null($left), right_array.is_null($right)) { - (false, false) => left_array.value($left) == right_array.value($right), - _ => false, - } - }}; -} - -/// Left and right row have equal values -fn equal_rows( - left: usize, - right: usize, - left_arrays: &[ArrayRef], - right_arrays: &[ArrayRef], -) -> Result { - let mut err = None; - let res = left_arrays - .iter() - .zip(right_arrays) - .all(|(l, r)| match l.data_type() { - DataType::Null => true, - DataType::Boolean => equal_rows_elem!(BooleanArray, l, r, left, right), - DataType::Int8 => equal_rows_elem!(Int8Array, l, r, left, right), - DataType::Int16 => equal_rows_elem!(Int16Array, l, r, left, right), - DataType::Int32 => equal_rows_elem!(Int32Array, l, r, left, right), - DataType::Int64 => equal_rows_elem!(Int64Array, l, r, left, right), - DataType::UInt8 => equal_rows_elem!(UInt8Array, l, r, left, right), - DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right), - DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right), - DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right), - DataType::Timestamp(_, None) => { - equal_rows_elem!(Int64Array, l, r, left, right) - } - DataType::Utf8 => equal_rows_elem!(StringArray, l, r, left, right), - DataType::LargeUtf8 => equal_rows_elem!(LargeStringArray, l, r, left, right), - _ => { - // This is internal because we should have caught this before. - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - }); - - err.unwrap_or(Ok(res)) -} - // Produces a batch for left-side rows that have/have not been matched during the whole join fn produce_from_matched( visited_left_side: &[bool], @@ -964,6 +871,8 @@ impl Stream for HashJoinStream { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::{ assert_batches_sorted_eq, physical_plan::{ @@ -973,7 +882,6 @@ mod tests { }; use super::*; - use std::sync::Arc; fn build_table( a: (&str, &Vec), diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs new file mode 100644 index 000000000000..3117e27e944b --- /dev/null +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -0,0 +1,356 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Various join implementations. 
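+//!
+//! As a rough usage sketch (the schemas `left` and `right` are placeholders for
+//! any two input schemas), a caller validates the join keys and derives the
+//! output schema like so:
+//!
+//!     let on = vec![(Column::new("id", 0), Column::new("id", 0))];
+//!     check_join_is_valid(&left, &right, &on)?;
+//!     let schema = build_join_schema(&left, &right, &JoinType::Inner);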
+
+pub mod cross_join;
+pub mod hash_join;
+pub mod sort_merge_join;
+
+use crate::error::{DataFusionError, Result};
+use crate::logical_plan::JoinType;
+use crate::physical_plan::expressions::Column;
+use arrow::array::ArrayRef;
+use arrow::array::*;
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow::error::Result as ArrowResult;
+use std::cmp::Ordering;
+use std::collections::HashSet;
+use std::sync::Arc;
+
+type StringArray = Utf8Array<i32>;
+type LargeStringArray = Utf8Array<i64>;
+
+/// The on clause of the join, as a vector of (left, right) columns.
+pub type JoinOn = Vec<(Column, Column)>;
+/// Reference for JoinOn.
+pub type JoinOnRef<'a> = &'a [(Column, Column)];
+
+/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join.
+/// They are valid whenever their columns' intersection equals the set `on`.
+pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
+    let left: HashSet<Column> = left
+        .fields()
+        .iter()
+        .enumerate()
+        .map(|(idx, f)| Column::new(f.name(), idx))
+        .collect();
+    let right: HashSet<Column> = right
+        .fields()
+        .iter()
+        .enumerate()
+        .map(|(idx, f)| Column::new(f.name(), idx))
+        .collect();
+
+    check_join_set_is_valid(&left, &right, on)
+}
+
+/// Checks whether the sets left, right and on compose a valid join.
+/// They are valid whenever their intersection equals the set `on`.
+pub fn check_join_set_is_valid(
+    left: &HashSet<Column>,
+    right: &HashSet<Column>,
+    on: &[(Column, Column)],
+) -> Result<()> {
+    let on_left = &on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>();
+    let left_missing = on_left.difference(left).collect::<HashSet<_>>();
+
+    let on_right = &on.iter().map(|on| on.1.clone()).collect::<HashSet<_>>();
+    let right_missing = on_right.difference(right).collect::<HashSet<_>>();
+
+    if !left_missing.is_empty() || !right_missing.is_empty() {
+        return Err(DataFusionError::Plan(format!(
+            "The left or right side of the join does not have all columns in \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}",
+            left_missing,
+            right_missing,
+        )));
+    };
+
+    let remaining = right
+        .difference(on_right)
+        .cloned()
+        .collect::<HashSet<Column>>();
+
+    let collisions = left.intersection(&remaining).collect::<HashSet<_>>();
+
+    if !collisions.is_empty() {
+        return Err(DataFusionError::Plan(format!(
+            "The left schema and the right schema have the following columns with the same name that are not part of the ON statement: {:?}. Consider aliasing them.",
+            collisions,
+        )));
+    };
+
+    Ok(())
+}
+
+/// Creates a schema for a join operation.
+/// The fields from the left side are listed first.
+pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema {
+    let fields: Vec<Field> = match join_type {
+        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
+            let left_fields = left.fields().iter();
+            let right_fields = right.fields().iter();
+            // left fields first, then right fields
+            left_fields.chain(right_fields).cloned().collect()
+        }
+        JoinType::Semi | JoinType::Anti => left.fields().clone(),
+    };
+    Schema::new(fields)
+}
+
+/// Information about the index and placement (left or right) of the columns
+pub(crate) struct ColumnIndex {
+    /// Index of the column
+    index: usize,
+    /// Whether the column is at the left or right side
+    is_left: bool,
+}
+
+/// Calculates column indices and left/right placement from the input/output schemas and join type
+pub(crate) fn column_indices_from_schema(
+    join_type: &JoinType,
+    left_schema: &Arc<Schema>,
+    right_schema: &Arc<Schema>,
+    schema: &Arc<Schema>,
+) -> ArrowResult<Vec<ColumnIndex>> {
+    let (primary_is_left, primary_schema, secondary_schema) = match join_type {
+        JoinType::Inner
+        | JoinType::Left
+        | JoinType::Full
+        | JoinType::Semi
+        | JoinType::Anti => (true, left_schema, right_schema),
+        JoinType::Right => (false, right_schema, left_schema),
+    };
+    let mut column_indices = Vec::with_capacity(schema.fields().len());
+    for field in schema.fields() {
+        let (is_primary, index) = match primary_schema.index_of(field.name()) {
+            Ok(i) => Ok((true, i)),
+            Err(_) => {
+                match secondary_schema.index_of(field.name()) {
+                    Ok(i) => Ok((false, i)),
+                    _ => Err(DataFusionError::Internal(
+                        format!("During execution, the column {} was not found in either the left or the right side of the join", field.name())
+                    ))
+                }
+            }
+        }.map_err(DataFusionError::into_arrow_external_error)?;
+
+        let is_left =
+            (is_primary && primary_is_left) || (!is_primary && !primary_is_left);
+        column_indices.push(ColumnIndex { index, is_left });
+    }
+
+    Ok(column_indices)
+}
+
+macro_rules!
equal_rows_elem { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => left_array.value($left) == right_array.value($right), + _ => false, + } + }}; +} + +/// Left and right row have equal values +fn equal_rows( + left: usize, + right: usize, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], +) -> Result { + let mut err = None; + let res = left_arrays + .iter() + .zip(right_arrays) + .all(|(l, r)| match l.data_type() { + DataType::Null => true, + DataType::Boolean => equal_rows_elem!(BooleanArray, l, r, left, right), + DataType::Int8 => equal_rows_elem!(Int8Array, l, r, left, right), + DataType::Int16 => equal_rows_elem!(Int16Array, l, r, left, right), + DataType::Int32 => equal_rows_elem!(Int32Array, l, r, left, right), + DataType::Int64 => equal_rows_elem!(Int64Array, l, r, left, right), + DataType::UInt8 => equal_rows_elem!(UInt8Array, l, r, left, right), + DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right), + DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right), + DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right), + DataType::Timestamp(_, None) => { + equal_rows_elem!(Int64Array, l, r, left, right) + } + DataType::Utf8 => equal_rows_elem!(StringArray, l, r, left, right), + DataType::LargeUtf8 => equal_rows_elem!(LargeStringArray, l, r, left, right), + _ => { + // This is internal because we should have caught this before. + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + }); + + err.unwrap_or(Ok(res)) +} + +macro_rules! cmp_rows_elem { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $res: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => { + let cmp = left_array + .value($left) + .partial_cmp(&right_array.value($right)) + .unwrap(); + if cmp != Ordering::Equal { + $res = cmp; + break; + } + } + _ => unreachable!(), + } + }}; +} + +macro_rules! 
cmp_rows_elem_str { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $res: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => { + let cmp = left_array + .value($left) + .partial_cmp(right_array.value($right)) + .unwrap(); + if cmp != Ordering::Equal { + $res = cmp; + break; + } + } + _ => unreachable!(), + } + }}; +} + +/// compare left row with right row +fn comp_rows( + left: usize, + right: usize, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], +) -> Result { + let mut res = Ordering::Equal; + for (l, r) in left_arrays.iter().zip(right_arrays) { + match l.data_type() { + DataType::Null => {} + DataType::Boolean => cmp_rows_elem!(BooleanArray, l, r, left, right, res), + DataType::Int8 => cmp_rows_elem!(Int8Array, l, r, left, right, res), + DataType::Int16 => cmp_rows_elem!(Int16Array, l, r, left, right, res), + DataType::Int32 => cmp_rows_elem!(Int32Array, l, r, left, right, res), + DataType::Int64 => cmp_rows_elem!(Int64Array, l, r, left, right, res), + DataType::UInt8 => cmp_rows_elem!(UInt8Array, l, r, left, right, res), + DataType::UInt16 => cmp_rows_elem!(UInt16Array, l, r, left, right, res), + DataType::UInt32 => cmp_rows_elem!(UInt32Array, l, r, left, right, res), + DataType::UInt64 => cmp_rows_elem!(UInt64Array, l, r, left, right, res), + DataType::Timestamp(_, None) => { + cmp_rows_elem!(Int64Array, l, r, left, right, res) + } + DataType::Utf8 => cmp_rows_elem_str!(StringArray, l, r, left, right, res), + DataType::LargeUtf8 => { + cmp_rows_elem_str!(LargeStringArray, l, r, left, right, res) + } + _ => { + // This is internal because we should have caught this before. 
+ return Err(DataFusionError::Internal( + "Unsupported data type in sort merge join comparator".to_string(), + )); + } + } + } + + Ok(res) +} + +#[cfg(test)] +mod tests { + + use crate::physical_plan::joins::check_join_set_is_valid; + + use super::*; + + fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { + let left = left + .iter() + .map(|x| x.to_owned()) + .collect::>(); + let right = right + .iter() + .map(|x| x.to_owned()) + .collect::>(); + check_join_set_is_valid(&left, &right, on) + } + + #[test] + fn check_valid() -> Result<()> { + let left = vec![Column::new("a", 0), Column::new("b1", 1)]; + let right = vec![Column::new("a", 0), Column::new("b2", 1)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; + + check(&left, &right, on)?; + Ok(()) + } + + #[test] + fn check_not_in_right() { + let left = vec![Column::new("a", 0), Column::new("b", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_not_in_left() { + let left = vec![Column::new("b", 0)]; + let right = vec![Column::new("a", 0)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_collision() { + // column "a" would appear both in left and right + let left = vec![Column::new("a", 0), Column::new("c", 1)]; + let right = vec![Column::new("a", 0), Column::new("b", 1)]; + let on = &[(Column::new("a", 0), Column::new("b", 1))]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_in_right() { + let left = vec![Column::new("a", 0), Column::new("c", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[(Column::new("a", 0), Column::new("b", 0))]; + + assert!(check(&left, &right, on).is_ok()); + } +} diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs new file mode 100644 index 000000000000..79af2ae026b7 --- /dev/null +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -0,0 +1,2327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the join plan for executing partitions in parallel and then joining the results +//! into a set of partitions. 
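+//!
+//! Both inputs are expected to arrive sorted on the join keys. The driver then
+//! walks a "streamed" side against a "buffered" side, comparing the current key
+//! of each and advancing whichever side lags behind; equal keys produce joined
+//! output. A simplified view of that loop (the helper names here are
+//! illustrative, not items of this module):
+//!
+//!     loop {
+//!         match comp_rows(stream_row, buffer_row, &stream_keys, &buffer_keys)? {
+//!             Ordering::Less => advance_streamed()?,    // streamed key lags
+//!             Ordering::Equal => join_matching_rows()?, // emit joined rows
+//!             Ordering::Greater => advance_buffered()?, // buffered key lags
+//!         }
+//!     }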
+ +use std::iter::repeat; +use std::sync::Arc; +use std::vec; +use std::{any::Any, usize}; + +use arrow::array::*; +use arrow::datatypes::*; +use arrow::error::{ArrowError, Result as ArrowResult}; +use arrow::record_batch::RecordBatch; +use async_trait::async_trait; +use futures::StreamExt; + +use crate::arrow_dyn_list_array::DynMutableListArray; +use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::runtime_env::RUNTIME_ENV; +use crate::logical_plan::JoinType; +use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr}; +use crate::physical_plan::joins::{ + build_join_schema, check_join_is_valid, column_indices_from_schema, comp_rows, + equal_rows, ColumnIndex, JoinOn, +}; +use crate::physical_plan::sorts::external_sort::ExternalSortExec; +use crate::physical_plan::stream::RecordBatchReceiverStream; +use crate::physical_plan::Statistics; +use crate::physical_plan::{ + expressions::Column, + metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, +}; +use crate::physical_plan::{ + DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, +}; +use arrow::compute::partition::lexicographical_partition_ranges; +use std::cmp::Ordering; +use std::collections::VecDeque; +use std::ops::Range; +use tokio::sync::mpsc::{Receiver, Sender}; + +fn join_arrays(rb: &RecordBatch, on_column: &Vec) -> Vec { + on_column + .iter() + .map(|c| rb.column(c.index()).clone()) + .collect() +} + +fn range_start_indices(buffered_ranges: &VecDeque>) -> Vec { + let mut idx = 0; + let mut start_indices: Vec = vec![]; + buffered_ranges.iter().for_each(|r| { + start_indices.push(idx); + idx += r.len(); + }); + start_indices.push(usize::MAX); + start_indices +} + +/// Locate buffered records start from `buffered_idx` of `len`gth +/// inside buffered batches. +fn slices_from_batches( + buffered_ranges: &VecDeque>, + start_indices: &Vec, + buffered_idx: usize, + len: usize, +) -> Vec { + let mut idx = buffered_idx; + let mut slices: Vec = vec![]; + let mut remaining = len; + let find = start_indices + .iter() + .enumerate() + .find(|(_, start_idx)| **start_idx >= idx) + .unwrap(); + let mut batch_idx = if *find.1 == idx { find.0 } else { find.0 - 1 }; + + while remaining > 0 { + let current_range = &buffered_ranges[batch_idx]; + let range_start_idx = start_indices[batch_idx]; + let start_idx = idx - range_start_idx + current_range.start; + let range_available = current_range.len() - (idx - range_start_idx); + + if range_available >= remaining { + slices.push(Slice { + batch_idx, + start_idx, + len: remaining, + }); + remaining = 0; + } else { + slices.push(Slice { + batch_idx, + start_idx, + len: range_available, + }); + remaining -= range_available; + batch_idx += 1; + idx += range_available; + } + } + slices +} + +/// Slice of batch at `batch_idx` inside BufferedBatches. +struct Slice { + batch_idx: usize, + start_idx: usize, + len: usize, +} + +#[derive(Clone)] +struct PartitionedRecordBatch { + batch: RecordBatch, + ranges: Vec>, +} + +impl PartitionedRecordBatch { + fn new( + batch: Option, + expr: &[PhysicalSortExpr], + ) -> Result> { + match batch { + Some(batch) => { + let columns = exprs_to_sort_columns(&batch, expr)?; + let ranges = lexicographical_partition_ranges( + &columns.iter().map(|x| x.into()).collect::>(), + )? 
+ .collect::>(); + Ok(Some(Self { batch, ranges })) + } + None => Ok(None), + } + } + + #[inline] + fn is_last_range(&self, range: &Range) -> bool { + range.end == self.batch.num_rows() + } +} + +struct StreamingBatch { + batch: Option, + cur_row: usize, + cur_range: usize, + num_rows: usize, + num_ranges: usize, + is_new_key: bool, + on_column: Vec, + sort: Vec, +} + +impl StreamingBatch { + fn new(on_column: Vec, sort: Vec) -> Self { + Self { + batch: None, + cur_row: 0, + cur_range: 0, + num_rows: 0, + num_ranges: 0, + is_new_key: true, + on_column, + sort, + } + } + + fn rest_batch(&mut self, prb: Option) { + self.batch = prb; + if let Some(prb) = &self.batch { + self.cur_row = 0; + self.cur_range = 0; + self.num_rows = prb.batch.num_rows(); + self.num_ranges = prb.ranges.len(); + self.is_new_key = true; + }; + } + + fn key_any_null(&self) -> bool { + match &self.batch { + None => return true, + Some(batch) => { + for c in &self.on_column { + let array = batch.batch.column(c.index()); + if array.is_null(self.cur_row) { + return true; + } + } + false + } + } + } + + #[inline] + fn is_finished(&self) -> bool { + self.batch.is_none() || self.num_rows == self.cur_row + 1 + } + + #[inline] + fn is_last_key_in_batch(&self) -> bool { + self.batch.is_none() || self.num_ranges == self.cur_range + 1 + } + + fn advance(&mut self) { + self.cur_row += 1; + self.is_new_key = false; + if !self.is_last_key_in_batch() { + let ranges = &self.batch.as_ref().unwrap().ranges; + if self.cur_row == ranges[self.cur_range + 1].start { + self.cur_range += 1; + self.is_new_key = true; + } + } else { + self.batch = None; + } + } + + fn advance_key(&mut self) { + let ranges = &self.batch.as_ref().unwrap().ranges; + self.cur_range += 1; + self.cur_row = ranges[self.cur_range].start; + self.is_new_key = true; + } +} + +/// Holding ranges for same key over several bathes +struct BufferedBatches { + /// batches that contains the current key + /// TODO: make this spillable as well for skew on join key at buffer side + batches: VecDeque, + /// ranges in each PartitionedRecordBatch that contains the current key + ranges: VecDeque>, + /// row index in first batch to the record that starts this batch + key_idx: Option, + /// total number of rows for the current key + row_num: usize, + /// hold found but not currently used batch, to continue iteration + next_key_batch: Vec, + /// Join on column + on_column: Vec, + sort: Vec, + /// last range's index in the last batch + range_idx: usize, +} + +impl BufferedBatches { + fn new(on_column: Vec, sort: Vec) -> Self { + Self { + batches: VecDeque::new(), + ranges: VecDeque::new(), + key_idx: None, + row_num: 0, + next_key_batch: vec![], + on_column, + sort, + range_idx: 0, + } + } + + fn key_any_null(&self) -> bool { + match &self.key_idx { + None => return true, + Some(key_idx) => { + let first_batch = &self.batches[0].batch; + for c in &self.on_column { + let array = first_batch.column(c.index()); + if array.is_null(*key_idx) { + return true; + } + } + false + } + } + } + + fn is_finished(&self) -> Result { + match self.key_idx { + None => Ok(true), + Some(_) => match (self.batches.back(), self.ranges.back()) { + (Some(batch), Some(range)) => Ok(batch.is_last_range(range)), + _ => Err(DataFusionError::Execution(format!( + "Batches length {} not equal to ranges length {}", + self.batches.len(), + self.ranges.len() + ))), + }, + } + } + + /// Whether the running key ends at the current batch `prb`, true for continues, false for ends. 
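+    /// Returns `Ok(true)` when the current key's run may continue past `prb`
+    /// (i.e. `prb` consists of a single range carrying that key), and
+    /// `Ok(false)` once the run is known to end.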
+ fn running_key(&mut self, prb: &PartitionedRecordBatch) -> Result { + let first_range = &prb.ranges[0]; + let range_len = first_range.len(); + let current_batch = &prb.batch; + let single_range = prb.ranges.len() == 1; + + // compare the first record in batch with the current key pointed by key_idx + match self.key_idx { + None => { + self.batches.push_back(prb.clone()); + self.ranges.push_back(first_range.clone()); + self.key_idx = Some(0); + self.range_idx = 0; + self.row_num += range_len; + Ok(single_range) + } + Some(key_idx) => { + let key_arrays = join_arrays(&self.batches[0].batch, &self.on_column); + let current_arrays = join_arrays(current_batch, &self.on_column); + let equal = equal_rows(key_idx, 0, &key_arrays, ¤t_arrays)?; + if equal { + self.batches.push_back(prb.clone()); + self.ranges.push_back(first_range.clone()); + self.range_idx = 0; + self.row_num += range_len; + Ok(single_range) + } else { + self.next_key_batch.push(prb.clone()); + Ok(false) // running key ends + } + } + } + } + + fn cleanup(&mut self) { + self.batches.drain(..); + self.ranges.drain(..); + self.next_key_batch.drain(..); + } + + fn reset_batch(&mut self, prb: &PartitionedRecordBatch) { + self.cleanup(); + self.batches.push_back(prb.clone()); + let first_range = &prb.ranges[0]; + self.ranges.push_back(first_range.clone()); + self.key_idx = Some(0); + self.row_num = first_range.len(); + } + + /// Advance the cursor to the next key seen by this buffer + fn advance_in_current_batch(&mut self) { + assert_eq!(self.batches.len(), self.ranges.len()); + if self.batches.len() > 1 { + self.batches.drain(0..(self.batches.len() - 1)); + self.ranges.drain(0..(self.batches.len() - 1)); + } + + self.range_idx += 1; + if let Some(batch) = self.batches.pop_back() { + let tail_range = self.ranges.pop_back().unwrap(); + self.key_idx = Some(tail_range.end); + self.ranges.push_back(batch.ranges[self.range_idx].clone()); + self.row_num = batch.ranges[self.range_idx].len(); + self.batches.push_back(batch); + } + } +} + +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use arrow::datatypes::PrimitiveType::*; + use arrow::types::{days_ms, months_days_ns}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! 
{ f64 }, + } +})} + +fn make_mutable( + data_type: &DataType, + capacity: usize, +) -> ArrowResult> { + Ok(match data_type.to_physical_type() { + PhysicalType::Boolean => Box::new(MutableBooleanArray::with_capacity(capacity)) + as Box, + PhysicalType::Primitive(primitive) => { + with_match_primitive_type!(primitive, |$T| { + Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(data_type.clone())) + as Box + }) + } + PhysicalType::Binary => { + Box::new(MutableBinaryArray::::with_capacity(capacity)) + as Box + } + PhysicalType::Utf8 => Box::new(MutableUtf8Array::::with_capacity(capacity)) + as Box, + _ => match data_type { + DataType::List(inner) => { + let values = make_mutable(inner.data_type(), 0)?; + Box::new(DynMutableListArray::::new_with_capacity( + values, capacity, + )) as Box + } + DataType::FixedSizeBinary(size) => Box::new( + MutableFixedSizeBinaryArray::with_capacity(*size as usize, capacity), + ) as Box, + _ => { + return Err(ArrowError::NotYetImplemented(format!( + "making mutable of type {} is not implemented yet", + data_type + ))); + } + }, + }) +} + +fn new_arrays( + schema: &Arc, + batch_size: usize, +) -> ArrowResult>> { + let arrays: Vec> = schema + .fields() + .iter() + .map(|field| { + let dt = field.data_type.to_logical_type(); + make_mutable(dt, batch_size) + }) + .collect::>()?; + Ok(arrays) +} + +fn make_batch( + schema: Arc, + mut arrays: Vec>, +) -> ArrowResult { + let columns = arrays.iter_mut().map(|array| array.as_arc()).collect(); + RecordBatch::try_new(schema, columns) +} + +struct OutputBuffer { + arrays: Vec>, + target_batch_size: usize, + slots_available: usize, + schema: Arc, +} + +impl OutputBuffer { + fn new(target_batch_size: usize, schema: Arc) -> Result { + let arrays = new_arrays(&schema, target_batch_size) + .map_err(DataFusionError::ArrowError)?; + Ok(Self { + arrays, + target_batch_size, + slots_available: target_batch_size, + schema, + }) + } + + fn output_and_reset(&mut self) -> ArrowResult { + let result = make_batch(self.schema.clone(), self.arrays.drain(..).collect()); + let mut new = new_arrays(&self.schema, self.target_batch_size)?; + self.arrays.append(&mut new); + self.slots_available = self.target_batch_size; + result + } + + fn append(&mut self, size: usize) { + assert!(size <= self.slots_available); + self.slots_available -= size; + } + + #[inline] + fn is_full(&self) -> bool { + self.slots_available == 0 + } +} + +macro_rules! repeat_n { + ($TO:ty, $FROM:ty, $N:expr, $to: ident, $from: ident, $idx: ident) => {{ + let to = $to.as_mut_any().downcast_mut::<$TO>().unwrap(); + let from = $from + .as_any() + .downcast_ref::<$FROM>() + .unwrap() + .slice($idx, 1); + let repeat_iter = from + .iter() + .flat_map(|v| repeat(v).take($N)) + .collect::>(); + to.extend_trusted_len(repeat_iter.into_iter()); + }}; +} + +macro_rules! 
copy_slices { + ($TO:ty, $FROM:ty, $array: ident, $batches: ident, $slices: ident, $column_index: ident) => {{ + let to = $array.as_mut_any().downcast_mut::<$TO>().unwrap(); + for pos in $slices { + let from = $batches[pos.batch_idx] + .column($column_index.index) + .slice(pos.start_idx, pos.len); + let from = from.as_any().downcast_ref::<$FROM>().unwrap(); + to.extend_trusted_len(from.iter()); + } + }}; +} + +/// repeat times of cell located by `idx` at streamed side to output +fn repeat_streamed_cell( + stream_batch: &RecordBatch, + idx: usize, + times: usize, + to: &mut Box, + column_index: &ColumnIndex, +) { + let from = stream_batch.column(column_index.index); + match to.data_type().to_physical_type() { + PhysicalType::Boolean => { + repeat_n!(MutableBooleanArray, BooleanArray, times, to, from, idx) + } + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int8 => repeat_n!(Int8Vec, Int8Array, times, to, from, idx), + PrimitiveType::Int16 => repeat_n!(Int16Vec, Int16Array, times, to, from, idx), + PrimitiveType::Int32 => repeat_n!(Int32Vec, Int32Array, times, to, from, idx), + PrimitiveType::Int64 => repeat_n!(Int64Vec, Int64Array, times, to, from, idx), + PrimitiveType::Float32 => { + repeat_n!(Float32Vec, Float32Array, times, to, from, idx) + } + PrimitiveType::Float64 => { + repeat_n!(Float64Vec, Float64Array, times, to, from, idx) + } + _ => todo!(), + }, + PhysicalType::Utf8 => { + repeat_n!(MutableUtf8Array, Utf8Array, times, to, from, idx) + } + _ => todo!(), + } +} + +fn copy_slices( + batches: &Vec<&RecordBatch>, + slices: &Vec, + array: &mut Box, + column_index: &ColumnIndex, +) { + // output buffered start `buffered_idx`, len `rows_to_output` + match array.data_type().to_physical_type() { + PhysicalType::Boolean => { + copy_slices!( + MutableBooleanArray, + BooleanArray, + array, + batches, + slices, + column_index + ) + } + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int8 => { + copy_slices!(Int8Vec, Int8Array, array, batches, slices, column_index) + } + PrimitiveType::Int16 => { + copy_slices!(Int16Vec, Int16Array, array, batches, slices, column_index) + } + PrimitiveType::Int32 => { + copy_slices!(Int32Vec, Int32Array, array, batches, slices, column_index) + } + PrimitiveType::Int64 => { + copy_slices!(Int64Vec, Int64Array, array, batches, slices, column_index) + } + PrimitiveType::Float32 => copy_slices!( + Float32Vec, + Float32Array, + array, + batches, + slices, + column_index + ), + PrimitiveType::Float64 => copy_slices!( + Float64Vec, + Float64Array, + array, + batches, + slices, + column_index + ), + _ => todo!(), + }, + PhysicalType::Utf8 => { + copy_slices!( + MutableUtf8Array, + Utf8Array, + array, + batches, + slices, + column_index + ) + } + _ => todo!(), + } +} + +struct SortMergeJoinDriver { + streamed: SendableRecordBatchStream, + buffered: SendableRecordBatchStream, + /// Information of index and left / right placement of columns + column_indices: Vec, + stream_batch: StreamingBatch, + buffered_batches: BufferedBatches, + output: OutputBuffer, +} + +impl SortMergeJoinDriver { + fn new( + streamed: SendableRecordBatchStream, + buffered: SendableRecordBatchStream, + on_streamed: Vec, + on_buffered: Vec, + streamed_sort: Vec, + buffered_sort: Vec, + column_indices: Vec, + schema: Arc, + runtime: Arc, + ) -> Result { + let batch_size = runtime.batch_size(); + Ok(Self { + streamed, + buffered, + column_indices, + stream_batch: StreamingBatch::new(on_streamed, streamed_sort), + buffered_batches: 
BufferedBatches::new(on_buffered, buffered_sort), + output: OutputBuffer::new(batch_size, schema)?, + }) + } + + async fn inner_join_driver( + &mut self, + sender: &Sender>, + ) -> Result<()> { + while self.find_inner_next().await? { + loop { + self.join_eq_records(sender).await?; + self.stream_batch.advance(); + if self.stream_batch.is_new_key { + break; + } + } + } + + Ok(()) + } + + async fn outer_join_driver( + &mut self, + sender: &Sender>, + ) -> Result<()> { + let mut buffer_ends = false; + + loop { + let OuterMatchResult { + get_match, + buffered_ended, + more_output, + } = self.find_outer_next(buffer_ends).await?; + if !more_output { + break; + } + buffer_ends = buffered_ended; + if get_match { + loop { + self.join_eq_records(sender).await?; + self.stream_batch.advance(); + if self.stream_batch.is_new_key { + break; + } + } + } else { + self.stream_copy_buffer_null(sender).await?; + } + } + + Ok(()) + } + + async fn full_outer_driver( + &mut self, + sender: &Sender>, + ) -> Result<()> { + let mut stream_ends = false; + let mut buffer_ends = false; + let mut advance_stream = true; + let mut advance_buffer = true; + + loop { + if advance_buffer { + buffer_ends = !self.advance_buffered_key().await?; + } + if advance_stream { + stream_ends = !self.advance_streamed_key().await?; + } + + if stream_ends && buffer_ends { + break; + } else if stream_ends { + self.stream_null_buffer_copy(sender).await?; + advance_buffer = true; + advance_stream = false; + } else if buffer_ends { + self.stream_copy_buffer_null(sender).await?; + advance_stream = true; + advance_buffer = false; + } else { + if self.stream_batch.key_any_null() { + self.stream_copy_buffer_null(sender).await?; + advance_stream = true; + advance_buffer = false; + continue; + } + if self.buffered_batches.key_any_null() { + self.stream_null_buffer_copy(sender).await?; + advance_buffer = true; + advance_stream = false; + continue; + } + + let current_cmp = self.compare_stream_buffer()?; + match current_cmp { + Ordering::Less => { + self.stream_copy_buffer_null(sender).await?; + advance_stream = true; + advance_buffer = false; + } + Ordering::Equal => { + loop { + self.join_eq_records(sender).await?; + self.stream_batch.advance(); + if self.stream_batch.is_new_key { + break; + } + } + advance_stream = false; // we already reach the next key of stream + advance_buffer = true; + } + Ordering::Greater => { + self.stream_null_buffer_copy(sender).await?; + advance_buffer = true; + advance_stream = false; + } + } + } + } + Ok(()) + } + + async fn semi_join_driver( + &mut self, + sender: &Sender>, + ) -> Result<()> { + while self.find_inner_next().await? 
{ + self.stream_copy_buffer_omit(sender).await?; + } + Ok(()) + } + + async fn anti_join_driver( + &mut self, + sender: &Sender>, + ) -> Result<()> { + let mut buffer_ends = false; + + loop { + let OuterMatchResult { + get_match, + buffered_ended, + more_output, + } = self.find_outer_next(buffer_ends).await?; + if !more_output { + break; + } + buffer_ends = buffered_ended; + if get_match { + // do nothing + } else { + self.stream_copy_buffer_omit(sender).await?; + } + } + + Ok(()) + } + + async fn join_eq_records( + &mut self, + sender: &Sender>, + ) -> Result<()> { + let mut remaining = self.buffered_batches.row_num; + let stream_batch = &self.stream_batch.batch.as_ref().unwrap().batch; + let stream_row = self.stream_batch.cur_row; + + let batches = self + .buffered_batches + .batches + .iter() + .map(|prb| &prb.batch) + .collect::>(); + let buffered_ranges = &self.buffered_batches.ranges; + + let mut unfinished = true; + let mut buffered_idx = 0; + let start_indices = range_start_indices(buffered_ranges); + + // output each buffered matching record once + while unfinished { + let output_slots_available = self.output.slots_available; + let rows_to_output = if output_slots_available >= remaining { + unfinished = false; + remaining + } else { + remaining -= output_slots_available; + output_slots_available + }; + + // get slices for buffered side for the current output + let slices = slices_from_batches( + buffered_ranges, + &start_indices, + buffered_idx, + rows_to_output, + ); + + self.output + .arrays + .iter_mut() + .zip(self.column_indices.iter()) + .for_each(|(array, column_index)| { + if column_index.is_left { + // repeat streamed `rows_to_output` times + repeat_streamed_cell( + stream_batch, + stream_row, + rows_to_output, + array, + column_index, + ); + } else { + // copy buffered start from: `buffered_idx`, len: `rows_to_output` + copy_slices(&batches, &slices, array, column_index); + } + }); + + self.output.append(rows_to_output); + buffered_idx += rows_to_output; + + if self.output.is_full() { + let result = self.output.output_and_reset(); + if let Err(e) = sender.send(result).await { + println!("ERROR batch via inner join stream: {}", e); + }; + } + } + Ok(()) + } + + async fn stream_copy_buffer_null( + &mut self, + sender: &Sender>, + ) -> Result<()> { + let stream_batch = &self.stream_batch.batch.as_ref().unwrap().batch; + let batch = vec![stream_batch]; + let stream_range = &self.stream_batch.batch.as_ref().unwrap().ranges + [self.stream_batch.cur_range]; + let mut remaining = stream_range.len(); + + let mut unfinished = true; + let mut streamed_idx = self.stream_batch.cur_row; + + // output each buffered matching record once + while unfinished { + let output_slots_available = self.output.slots_available; + let rows_to_output = if output_slots_available >= remaining { + unfinished = false; + remaining + } else { + remaining -= output_slots_available; + output_slots_available + }; + + let slice = vec![Slice { + batch_idx: 0, + start_idx: streamed_idx, + len: rows_to_output, + }]; + + self.output + .arrays + .iter_mut() + .zip(self.column_indices.iter()) + .for_each(|(array, column_index)| { + if column_index.is_left { + copy_slices(&batch, &slice, array, column_index); + } else { + (0..rows_to_output).for_each(|_| array.push_null()); + } + }); + + self.output.append(rows_to_output); + streamed_idx += rows_to_output; + + if self.output.is_full() { + let result = self.output.output_and_reset(); + if let Err(e) = sender.send(result).await { + println!("ERROR batch via outer join 
stream: {}", e); + }; + } + } + Ok(()) + } + + async fn stream_null_buffer_copy( + &mut self, + sender: &Sender>, + ) -> Result<()> { + let mut remaining = self.buffered_batches.row_num; + + let batches = self + .buffered_batches + .batches + .iter() + .map(|prb| &prb.batch) + .collect::>(); + let buffered_ranges = &self.buffered_batches.ranges; + + let mut unfinished = true; + let mut buffered_idx = 0; + let start_indices = range_start_indices(buffered_ranges); + + // output each buffered matching record once + while unfinished { + let output_slots_available = self.output.slots_available; + let rows_to_output = if output_slots_available >= remaining { + unfinished = false; + remaining + } else { + remaining -= output_slots_available; + output_slots_available + }; + + // get slices for buffered side for the current output + let slices = slices_from_batches( + buffered_ranges, + &start_indices, + buffered_idx, + rows_to_output, + ); + + self.output + .arrays + .iter_mut() + .zip(self.column_indices.iter()) + .for_each(|(array, column_index)| { + if column_index.is_left { + (0..rows_to_output).for_each(|_| array.push_null()); + } else { + // copy buffered start from: `buffered_idx`, len: `rows_to_output` + copy_slices(&batches, &slices, array, column_index); + } + }); + + self.output.append(rows_to_output); + buffered_idx += rows_to_output; + + if self.output.is_full() { + let result = self.output.output_and_reset(); + if let Err(e) = sender.send(result).await { + println!("ERROR batch via outer join stream: {}", e); + }; + } + } + Ok(()) + } + + async fn stream_copy_buffer_omit( + &mut self, + sender: &Sender>, + ) -> Result<()> { + let stream_batch = &self.stream_batch.batch.as_ref().unwrap().batch; + let batch = vec![stream_batch]; + let stream_range = &self.stream_batch.batch.as_ref().unwrap().ranges + [self.stream_batch.cur_range]; + let mut remaining = stream_range.len(); + + let mut unfinished = true; + let mut streamed_idx = self.stream_batch.cur_row; + + // output each buffered matching record once + while unfinished { + let output_slots_available = self.output.slots_available; + let rows_to_output = if output_slots_available >= remaining { + unfinished = false; + remaining + } else { + remaining -= output_slots_available; + output_slots_available + }; + + let slice = vec![Slice { + batch_idx: 0, + start_idx: streamed_idx, + len: rows_to_output, + }]; + + self.output + .arrays + .iter_mut() + .zip(self.column_indices.iter()) + .for_each(|(array, column_index)| { + copy_slices(&batch, &slice, array, column_index); + }); + + self.output.append(rows_to_output); + streamed_idx += rows_to_output; + + if self.output.is_full() { + let result = self.output.output_and_reset(); + if let Err(e) = sender.send(result).await { + println!("ERROR batch via semi/anti join stream: {}", e); + }; + } + } + Ok(()) + } + + async fn find_inner_next(&mut self) -> Result { + if self.stream_batch.key_any_null() { + let more_stream = self.advance_streamed_key_null_free().await?; + if !more_stream { + return Ok(false); + } + } + + if self.buffered_batches.key_any_null() { + let more_buffer = self.advance_buffered_key_null_free().await?; + if !more_buffer { + return Ok(false); + } + } + + loop { + let current_cmp = self.compare_stream_buffer()?; + match current_cmp { + Ordering::Less => { + let more_stream = self.advance_streamed_key_null_free().await?; + if !more_stream { + return Ok(false); + } + } + Ordering::Equal => return Ok(true), + Ordering::Greater => { + let more_buffer = 
self.advance_buffered_key_null_free().await?; + if !more_buffer { + return Ok(false); + } + } + } + } + } + + async fn find_outer_next(&mut self, buffer_ends: bool) -> Result { + let more_stream = self.advance_streamed_key().await?; + if buffer_ends { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: true, + more_output: more_stream, + }); + } else { + if !more_stream { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: false, + more_output: false, + }); + } + + if self.buffered_batches.key_any_null() { + let more_buffer = self.advance_buffered_key_null_free().await?; + if !more_buffer { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: true, + more_output: true, + }); + } + } + + loop { + if self.stream_batch.key_any_null() { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: false, + more_output: true, + }); + } + + let current_cmp = self.compare_stream_buffer()?; + match current_cmp { + Ordering::Less => { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: false, + more_output: true, + }) + } + Ordering::Equal => { + return Ok(OuterMatchResult { + get_match: true, + buffered_ended: false, + more_output: true, + }) + } + Ordering::Greater => { + let more_buffer = self.advance_buffered_key_null_free().await?; + if !more_buffer { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: true, + more_output: true, + }); + } + } + } + } + } + } + + async fn get_stream_next(&mut self) -> Result<()> { + let batch = self.streamed.next().await.transpose()?; + let prb = PartitionedRecordBatch::new(batch, &self.stream_batch.sort)?; + self.stream_batch.rest_batch(prb); + Ok(()) + } + + /// true for has next, false for ended + async fn advance_streamed_key(&mut self) -> Result { + if self.stream_batch.is_finished() || self.stream_batch.is_last_key_in_batch() { + self.get_stream_next().await?; + Ok(!self.stream_batch.is_finished()) + } else { + self.stream_batch.advance_key(); + Ok(true) + } + } + + /// true for has next, false for ended + async fn advance_streamed_key_null_free(&mut self) -> Result { + let mut more_stream_keys = self.advance_streamed_key().await?; + loop { + if more_stream_keys && self.stream_batch.key_any_null() { + more_stream_keys = self.advance_streamed_key().await?; + } else { + break; + } + } + Ok(more_stream_keys) + } + + async fn get_buffered_next(&mut self) -> Result> { + let batch = self.buffered.next().await.transpose()?; + PartitionedRecordBatch::new(batch, &self.buffered_batches.sort) + } + + /// true for has next, false for ended + async fn advance_buffered_key(&mut self) -> Result { + if self.buffered_batches.is_finished()? 
{
+            if self.buffered_batches.next_key_batch.is_empty() {
+                let batch = self.get_buffered_next().await?;
+                match batch {
+                    None => return Ok(false),
+                    Some(batch) => {
+                        self.buffered_batches.reset_batch(&batch);
+                        if batch.ranges.len() == 1 {
+                            self.cumulate_same_keys().await?;
+                        }
+                        self.buffered_batches.range_idx = 0;
+                    }
+                }
+            } else {
+                assert_eq!(self.buffered_batches.next_key_batch.len(), 1);
+                let batch = self.buffered_batches.next_key_batch.pop().unwrap();
+                self.buffered_batches.reset_batch(&batch);
+                if batch.ranges.len() == 1 {
+                    self.cumulate_same_keys().await?;
+                }
+                self.buffered_batches.range_idx = 0;
+            }
+        } else {
+            self.buffered_batches.advance_in_current_batch();
+            if self.buffered_batches.batches[0]
+                .is_last_range(&self.buffered_batches.ranges[0])
+            {
+                self.cumulate_same_keys().await?;
+                self.buffered_batches.range_idx = 0;
+            }
+        }
+        // every path reaching this point holds a current key
+        Ok(true)
+    }
+
+    /// Pull buffered batches until the current key's run ends or the buffer
+    /// side is exhausted.
+    async fn cumulate_same_keys(&mut self) -> Result<()> {
+        loop {
+            let batch = self.get_buffered_next().await?;
+            match batch {
+                None => return Ok(()),
+                Some(batch) => {
+                    let more_batches = self.buffered_batches.running_key(&batch)?;
+                    if !more_batches {
+                        return Ok(());
+                    }
+                }
+            }
+        }
+    }
+
+    async fn advance_buffered_key_null_free(&mut self) -> Result<bool> {
+        let mut more_buffered_keys = self.advance_buffered_key().await?;
+        loop {
+            if more_buffered_keys && self.buffered_batches.key_any_null() {
+                more_buffered_keys = self.advance_buffered_key().await?;
+            } else {
+                break;
+            }
+        }
+        Ok(more_buffered_keys)
+    }
+
+    fn compare_stream_buffer(&self) -> Result<Ordering> {
+        let stream_arrays = join_arrays(
+            &self.stream_batch.batch.as_ref().unwrap().batch,
+            &self.stream_batch.on_column,
+        );
+        let buffer_arrays = join_arrays(
+            &self.buffered_batches.batches[0].batch,
+            &self.buffered_batches.on_column,
+        );
+        comp_rows(
+            self.stream_batch.cur_row,
+            self.buffered_batches.key_idx.unwrap(),
+            &stream_arrays,
+            &buffer_arrays,
+        )
+    }
+}
+
+struct OuterMatchResult {
+    get_match: bool,
+    buffered_ended: bool,
+    more_output: bool,
+}
+
+/// Sort-merge join execution plan: executes partitions in parallel and combines
+/// them into a set of partitions.
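+///
+/// Both children are expected to be sorted on the join keys already; the
+/// current implementation downcasts them to [ExternalSortExec] at execution
+/// time to recover their sort expressions.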
+#[derive(Debug)] +pub struct SortMergeJoinExec { + /// left (build) side which gets hashed + left: Arc, + /// right (probe) side which are filtered by the hash table + right: Arc, + /// Set of common columns used to join on + on: Vec<(Column, Column)>, + /// How the join is performed + join_type: JoinType, + /// The schema once the join is applied + schema: SchemaRef, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +/// Metrics for SortMergeJoinExec +#[derive(Debug)] +struct SortMergeJoinMetrics { + /// Total time for joining probe-side batches to the build-side batches + join_time: metrics::Time, + /// Number of batches consumed by this operator + input_batches: metrics::Count, + /// Number of rows consumed by this operator + input_rows: metrics::Count, + /// Number of batches produced by this operator + output_batches: metrics::Count, + /// Number of rows produced by this operator + output_rows: metrics::Count, +} + +impl SortMergeJoinMetrics { + #[allow(dead_code)] + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + join_time, + input_batches, + input_rows, + output_batches, + output_rows, + } + } +} + +impl SortMergeJoinExec { + /// Tries to create a new [SortMergeJoinExec]. + /// # Error + /// This function errors when it is not possible to join the left and right sides on keys `on`. 
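+    ///
+    /// A minimal construction sketch (`left` and `right` stand for any two
+    /// suitably sorted `ExecutionPlan`s that share a join key column `b1`):
+    ///
+    ///     let on = vec![(
+    ///         Column::new_with_schema("b1", &left.schema())?,
+    ///         Column::new_with_schema("b1", &right.schema())?,
+    ///     )];
+    ///     let join = SortMergeJoinExec::try_new(left, right, on, &JoinType::Inner)?;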
+ pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + check_join_is_valid(&left_schema, &right_schema, &on)?; + + let schema = Arc::new(build_join_schema(&left_schema, &right_schema, join_type)); + + Ok(SortMergeJoinExec { + left, + right, + on, + join_type: *join_type, + schema, + metrics: ExecutionPlanMetricsSet::new(), + }) + } + + /// left (build) side which gets hashed + pub fn left(&self) -> &Arc { + &self.left + } + + /// right (probe) side which are filtered by the hash table + pub fn right(&self) -> &Arc { + &self.right + } + + /// Set of common columns used to join on + pub fn on(&self) -> &[(Column, Column)] { + &self.on + } + + /// How the join is performed + pub fn join_type(&self) -> &JoinType { + &self.join_type + } +} + +#[async_trait] +impl ExecutionPlan for SortMergeJoinExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn output_partitioning(&self) -> Partitioning { + self.right.output_partitioning() + } + + fn children(&self) -> Vec> { + vec![self.left.clone(), self.right.clone()] + } + + fn with_new_children( + &self, + children: Vec>, + ) -> Result> { + match children.len() { + 2 => Ok(Arc::new(SortMergeJoinExec::try_new( + children[0].clone(), + children[1].clone(), + self.on.clone(), + &self.join_type, + )?)), + _ => Err(DataFusionError::Internal( + "HashJoinExec wrong number of children".to_string(), + )), + } + } + + async fn execute(&self, partition: usize) -> Result { + let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); + let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); + let left = self.left.execute(partition).await?; + let right = self.right.execute(partition).await?; + + let column_indices = column_indices_from_schema( + &self.join_type, + &self.left.schema(), + &self.right.schema(), + &self.schema, + )?; + + let (tx, rx): ( + Sender>, + Receiver>, + ) = tokio::sync::mpsc::channel(2); + + let left_sort = self + .left + .as_any() + .downcast_ref::() + .unwrap() + .expr() + .iter() + .map(|s| s.clone()) + .collect::>(); + let right_sort = self + .right + .as_any() + .downcast_ref::() + .unwrap() + .expr() + .iter() + .map(|s| s.clone()) + .collect::>(); + + let mut driver = match self.join_type { + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::Semi + | JoinType::Anti => SortMergeJoinDriver::new( + left, + right, + on_left, + on_right, + left_sort, + right_sort, + column_indices, + self.schema.clone(), + RUNTIME_ENV.clone(), + )?, + JoinType::Right => SortMergeJoinDriver::new( + right, + left, + on_right, + on_left, + right_sort, + left_sort, + column_indices, + self.schema.clone(), + RUNTIME_ENV.clone(), + )?, + }; + + match self.join_type { + JoinType::Inner => driver.inner_join_driver(&tx).await?, + JoinType::Left => driver.outer_join_driver(&tx).await?, + JoinType::Right => driver.outer_join_driver(&tx).await?, + JoinType::Full => driver.full_outer_driver(&tx).await?, + JoinType::Semi => driver.semi_join_driver(&tx).await?, + JoinType::Anti => driver.anti_join_driver(&tx).await?, + } + + let result = RecordBatchReceiverStream::create(&self.schema, rx); + + Ok(result) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default => { + write!( + f, + 
"SortMergeJoinExec: join_type={:?}, on={:?}", + self.join_type, self.on + ) + } + } + } + + fn statistics(&self) -> Statistics { + // TODO stats: it is not possible in general to know the output size of joins + // There are some special cases though, for example: + // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` + Statistics::default() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::{ + assert_batches_sorted_eq, + physical_plan::{ + common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, + }, + test::{build_table_i32, columns}, + }; + + use super::*; + use crate::physical_plan::PhysicalExpr; + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema().clone(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + fn join( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result { + SortMergeJoinExec::try_new(left, right, on, join_type) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result<(Vec, Vec)> { + let join = join(left, right, on, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + Ok((columns, batches)) + } + + async fn partitioned_join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result<(Vec, Vec)> { + let partition_count = 4; + + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); + + let join = SortMergeJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + join_type, + )?; + + let columns = columns(&join.schema()); + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i).await?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok((columns, batches)) + } + + #[tokio::test] + async fn join_inner_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_inner_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + 
("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = partitioned_join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Inner, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_one_no_shared_column_names() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); + + assert_eq!(batches.len(), 1); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + /// Test where the left has 2 parts, the right with 1 part => 1 part + #[tokio::test] + async fn join_inner_one_two_parts_left() -> Result<()> { + let batch1 = build_table_i32( + ("a1", &vec![1, 2]), + ("b2", &vec![1, 2]), + ("c1", &vec![7, 8]), + ); + let batch2 = + build_table_i32(("a1", &vec![2]), ("b2", &vec![2]), ("c1", &vec![9])); + let schema = batch1.schema().clone(); + let left = Arc::new( + MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", 
&right.schema())?, + ), + ]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); + + assert_eq!(batches.len(), 1); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + /// Test where the left has 1 part, the right has 2 parts => 2 parts + #[tokio::test] + async fn join_inner_one_two_parts_right() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + + let batch1 = build_table_i32( + ("a2", &vec![10, 20]), + ("b1", &vec![4, 6]), + ("c2", &vec![70, 80]), + ); + let batch2 = + build_table_i32(("a2", &vec![30]), ("b1", &vec![5]), ("c2", &vec![90])); + let schema = batch1.schema().clone(); + let right = Arc::new( + MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let join = join(left, right, on, &JoinType::Inner)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + // first part + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + assert_eq!(batches.len(), 1); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + // second part + let stream = join.execute(1).await?; + let batches = common::collect(stream).await?; + assert_eq!(batches.len(), 1); + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 2 | 5 | 8 | 30 | 5 | 90 |", + "| 3 | 5 | 9 | 30 | 5 | 90 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + fn build_table_two_batches( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema().clone(); + Arc::new( + MemoryExec::try_new(&[vec![batch.clone(), batch]], schema, None).unwrap(), + ) + } + + #[tokio::test] + async fn join_left_multi_batch() { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table_two_batches( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + )]; + + let join = join(left, right, on, &JoinType::Left).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 
5 | 8 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_full_multi_batch() { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + // create two identical batches for the right side + let right = build_table_two_batches( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; + + let join = join(left, right, on, &JoinType::Full).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_left_empty_right() { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b1", right.schema()).unwrap(), + )]; + let schema = right.schema().clone(); + let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); + let join = join(left, right, on, &JoinType::Left).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | | 4 | |", + "| 2 | 5 | 8 | | 5 | |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_full_empty_right() { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", &vec![])); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", right.schema()).unwrap(), + )]; + let schema = right.schema().clone(); + let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); + let join = join(left, right, on, &JoinType::Full).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | | | |", + "| 2 | 5 | 8 | | | |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", 
+ ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_left_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_left_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = partitioned_join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Left, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_semi() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right + ("c2", &vec![70, 80, 90, 100]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let join = join(left, right, on, &JoinType::Semi)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_anti() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3, 5]), + ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9, 11]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right + ("c2", &vec![70, 80, 90, 100]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let join = join(left, right, on, &JoinType::Anti)?; + + let columns = 
columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 3 | 7 | 9 |", + "| 5 | 7 | 11 |", + "+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + async fn join_right_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_right_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + partitioned_join_collect(left, right, on, &JoinType::Right).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_full_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; + + let join = join(left, right, on, &JoinType::Full)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 52ce6d3ad311..f9b039c23445 100644 --- a/datafusion/src/physical_plan/mod.rs +++ 
b/datafusion/src/physical_plan/mod.rs @@ -17,16 +17,13 @@ //! Traits for physical query plan, supporting parallel execution for partitioned relations. -pub use self::metrics::Metric; -use self::metrics::MetricsSet; -use self::{ - coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan, -}; -use crate::physical_plan::expressions::{PhysicalSortExpr, SortColumn}; -use crate::{ - error::{DataFusionError, Result}, - scalar::ScalarValue, -}; +use std::fmt; +use std::fmt::{Debug, Display}; +use std::ops::Range; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{any::Any, pin::Pin}; + use arrow::array::ArrayRef; use arrow::compute::merge_sort::SortOptions; use arrow::compute::partition::lexicographical_partition_ranges; @@ -34,14 +31,23 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use async_trait::async_trait; -pub use display::DisplayFormatType; use futures::stream::Stream; -use std::fmt; -use std::fmt::{Debug, Display}; -use std::ops::Range; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::{any::Any, pin::Pin}; + +pub use display::DisplayFormatType; + +use crate::physical_plan::expressions::{PhysicalSortExpr, SortColumn}; +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; + +pub use self::metrics::Metric; +use self::metrics::MetricsSet; +/// Physical planner interface +pub use self::planner::PhysicalPlanner; +use self::{ + coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan, +}; /// Trait for types that stream [arrow::record_batch::RecordBatch] pub trait RecordBatchStream: Stream> { @@ -86,9 +92,6 @@ impl Stream for EmptyRecordBatchStream { } } -/// Physical planner interface -pub use self::planner::PhysicalPlanner; - /// Statistics for a physical plan node /// Fields are optional and can be inexact because the sources /// sometimes provide approximate estimates for performance reasons @@ -612,7 +615,6 @@ pub mod avro; pub mod coalesce_batches; pub mod coalesce_partitions; pub mod common; -pub mod cross_join; #[cfg(feature = "crypto_expressions")] pub mod crypto_expressions; pub mod csv; @@ -625,8 +627,8 @@ pub mod expressions; pub mod filter; pub mod functions; pub mod hash_aggregate; -pub mod hash_join; pub mod hash_utils; +pub mod joins; pub mod json; pub mod limit; pub mod math_expressions; @@ -638,8 +640,7 @@ pub mod projection; #[cfg(feature = "regex_expressions")] pub mod regex_expressions; pub mod repartition; -pub mod sort; -pub mod sort_preserving_merge; +pub mod sorts; pub mod source; pub mod stream; pub mod string_expressions; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 718ebaab0f4f..3088d787d4dd 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -17,11 +17,15 @@ //! 
Physical query planner -use super::analyze::AnalyzeExec; -use super::{ - aggregates, empty::EmptyExec, expressions::binary, functions, - hash_join::PartitionMode, udaf, union::UnionExec, windows, -}; +use std::sync::Arc; + +use arrow::compute::cast::can_cast_types; +use arrow::compute::sort::SortOptions; +use arrow::datatypes::*; +use log::debug; + +use expressions::col; + use crate::execution::context::ExecutionContextState; use crate::logical_plan::{ unnormalize_cols, DFSchema, Expr, LogicalPlan, Operator, @@ -29,20 +33,21 @@ use crate::logical_plan::{ UserDefinedLogicalNode, }; use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; -use crate::physical_plan::cross_join::CrossJoinExec; use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions; use crate::physical_plan::expressions::{CaseExpr, Column, Literal, PhysicalSortExpr}; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; -use crate::physical_plan::hash_join::HashJoinExec; +use crate::physical_plan::joins::cross_join::CrossJoinExec; +use crate::physical_plan::joins::hash_join::HashJoinExec; +use crate::physical_plan::joins::hash_join::PartitionMode; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sort::SortExec; +use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::udf; use crate::physical_plan::windows::WindowAggExec; -use crate::physical_plan::{hash_utils, Partitioning}; +use crate::physical_plan::{joins, Partitioning}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; use crate::scalar::ScalarValue; use crate::sql::utils::{generate_sort_key, window_expr_common_partition_keys}; @@ -52,12 +57,11 @@ use crate::{ physical_plan::displayable, }; -use arrow::compute::cast::can_cast_types; -use arrow::compute::sort::SortOptions; -use arrow::datatypes::*; -use expressions::col; -use log::debug; -use std::sync::Arc; +use super::analyze::AnalyzeExec; +use super::{ + aggregates, empty::EmptyExec, expressions::binary, functions, udaf, union::UnionExec, + windows, +}; fn create_function_physical_name( fun: &str, @@ -677,7 +681,7 @@ impl DefaultPhysicalPlanner { Column::new(&r.name, right_df_schema.index_of_column(r)?), )) }) - .collect::>()?; + .collect::>()?; if ctx_state.config.target_partitions > 1 && ctx_state.config.repartition_joins @@ -1392,7 +1396,13 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { #[cfg(test)] mod tests { - use super::*; + use fmt::Debug; + use std::convert::TryFrom; + use std::{any::Any, fmt}; + + use arrow::datatypes::{DataType, Field}; + use async_trait::async_trait; + use crate::logical_plan::{DFField, DFSchema, DFSchemaRef}; use crate::physical_plan::{ csv::CsvReadOptions, expressions, DisplayFormatType, Partitioning, Statistics, @@ -1402,11 +1412,8 @@ mod tests { logical_plan::{col, lit, sum, LogicalPlanBuilder}, physical_plan::SendableRecordBatchStream, }; - use arrow::datatypes::{DataType, Field}; - use async_trait::async_trait; - use fmt::Debug; - use std::convert::TryFrom; - use std::{any::Any, fmt}; + + use super::*; fn make_ctx_state() -> ExecutionContextState { ExecutionContextState::new() diff --git a/datafusion/src/physical_plan/sorts/external_sort.rs b/datafusion/src/physical_plan/sorts/external_sort.rs new file mode 100644 index 000000000000..d66b6f718338 --- /dev/null 
+++ b/datafusion/src/physical_plan/sorts/external_sort.rs @@ -0,0 +1,696 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the External-Sort plan + +use crate::error::{DataFusionError, Result}; +use crate::execution::memory_management::{ + MemoryConsumer, MemoryConsumerId, MemoryManager, +}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::runtime_env::RUNTIME_ENV; +use crate::physical_plan::common::{ + batch_memory_size, IPCWriterWrapper, SizedRecordBatchStream, +}; +use crate::physical_plan::expressions::PhysicalSortExpr; +use crate::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, +}; +use crate::physical_plan::sorts::in_mem_sort::InMemSortStream; +use crate::physical_plan::sorts::sort::sort_batch; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream; +use crate::physical_plan::sorts::SpillableStream; +use crate::physical_plan::stream::RecordBatchReceiverStream; +use crate::physical_plan::{ + DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + SendableRecordBatchStream, Statistics, +}; +use arrow::datatypes::SchemaRef; +use arrow::error::Result as ArrowResult; +use arrow::io::ipc::read::{read_file_metadata, FileReader}; +use arrow::record_batch::RecordBatch; +use async_trait::async_trait; +use futures::lock::Mutex; +use futures::StreamExt; +use log::{error, info}; +use std::any::Any; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::fs::File; +use std::io::BufReader; +use std::sync::atomic::{AtomicBool, AtomicIsize, AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::mpsc::{Receiver as TKReceiver, Sender as TKSender}; +use tokio::task; + +struct ExternalSorter { + id: MemoryConsumerId, + schema: SchemaRef, + in_mem_batches: Mutex>, + spills: Mutex>, + /// Sort expressions + expr: Vec, + runtime: Arc, + metrics: ExecutionPlanMetricsSet, + used: AtomicIsize, + spilled_bytes: AtomicUsize, + spilled_count: AtomicUsize, + insert_finished: AtomicBool, +} + +impl ExternalSorter { + pub fn new( + partition_id: usize, + schema: SchemaRef, + expr: Vec, + runtime: Arc, + ) -> Self { + Self { + id: MemoryConsumerId::new(partition_id), + schema, + in_mem_batches: Mutex::new(vec![]), + spills: Mutex::new(vec![]), + expr, + runtime, + metrics: ExecutionPlanMetricsSet::new(), + used: AtomicIsize::new(0), + spilled_bytes: AtomicUsize::new(0), + spilled_count: AtomicUsize::new(0), + insert_finished: AtomicBool::new(false), + } + } + + pub(crate) fn finish_insert(&self) { + self.insert_finished.store(true, Ordering::SeqCst); + } + + async fn spill_while_inserting(&self) -> Result { + info!( + "{} spilling sort data of {} to disk while inserting ({} time(s) so far)", + self.str_repr(), + self.get_used(), + self.spilled_count() 
+        );
+
+        let partition = self.partition_id();
+        let mut in_mem_batches = self.in_mem_batches.lock().await;
+        // as long as there are in-memory batches, spilling them is guaranteed
+        // to free some memory; with none buffered there is nothing to release
+        if in_mem_batches.is_empty() {
+            return Ok(0);
+        }
+
+        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+
+        let path = self.runtime.disk_manager.create_tmp_file()?;
+        let stream = in_mem_merge_sort(
+            &mut *in_mem_batches,
+            self.schema.clone(),
+            &*self.expr,
+            self.runtime.batch_size(),
+            baseline_metrics,
+        )
+        .await;
+
+        let total_size = spill(&mut stream?, path.clone(), self.schema.clone()).await?;
+
+        let mut spills = self.spills.lock().await;
+        self.spilled_count.fetch_add(1, Ordering::SeqCst);
+        self.spilled_bytes.fetch_add(total_size, Ordering::SeqCst);
+        spills.push(path);
+        Ok(total_size)
+    }
+
+    async fn insert_batch(&self, input: RecordBatch) -> Result<()> {
+        let size = batch_memory_size(&input);
+        self.allocate(size).await?;
+        // sort each batch as it is inserted, while it is still likely to be
+        // cache-resident
+        let sorted_batch = sort_batch(input, self.schema.clone(), &*self.expr)?;
+        let mut in_mem_batches = self.in_mem_batches.lock().await;
+        in_mem_batches.push(sorted_batch);
+        Ok(())
+    }
+
+    /// Merge-sorts the in-memory batches together with any spills into a total
+    /// order, using `SortPreservingMergeStream` (SPMS). The in-memory stream is
+    /// always placed at index 0 of the SPMS inputs so that it can still be
+    /// spilled if `spill()` is called on this consumer while merging.
+    async fn sort(&self) -> Result<SendableRecordBatchStream> {
+        let partition = self.partition_id();
+        let mut in_mem_batches = self.in_mem_batches.lock().await;
+        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+        let mut streams: Vec<SpillableStream> = vec![];
+        let in_mem_stream = in_mem_merge_sort(
+            &mut *in_mem_batches,
+            self.schema.clone(),
+            &self.expr,
+            self.runtime.batch_size(),
+            baseline_metrics,
+        )
+        .await?;
+        streams.push(SpillableStream::new_spillable(in_mem_stream));
+
+        let mut spills = self.spills.lock().await;
+
+        for spill in spills.drain(..)
{
+            let stream = read_spill_as_stream(spill, self.schema.clone()).await?;
+            streams.push(SpillableStream::new_unspillable(stream));
+        }
+        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+
+        Ok(Box::pin(
+            SortPreservingMergeStream::new_from_stream(
+                streams,
+                self.schema.clone(),
+                &self.expr,
+                self.runtime.batch_size(),
+                baseline_metrics,
+                partition,
+                self.runtime.clone(),
+            )
+            .await,
+        ))
+    }
+}
+
+impl Debug for ExternalSorter {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ExternalSorter")
+            .field("id", &self.id())
+            .field("memory_used", &self.get_used())
+            .field("spilled_bytes", &self.spilled_bytes())
+            .field("spilled_count", &self.spilled_count())
+            .finish()
+    }
+}
+
+#[async_trait]
+impl MemoryConsumer for ExternalSorter {
+    fn name(&self) -> String {
+        "ExternalSorter".to_owned()
+    }
+
+    fn id(&self) -> &MemoryConsumerId {
+        &self.id
+    }
+
+    fn memory_manager(&self) -> Arc<MemoryManager> {
+        self.runtime.memory_manager.clone()
+    }
+
+    async fn spill_inner(
+        &self,
+        _size: usize,
+        _trigger: &MemoryConsumerId,
+    ) -> Result<usize> {
+        if !self.insert_finished.load(Ordering::SeqCst) {
+            self.spill_while_inserting().await
+        } else {
+            Ok(0)
+        }
+    }
+
+    fn get_used(&self) -> isize {
+        self.used.load(Ordering::SeqCst)
+    }
+
+    fn update_used(&self, delta: isize) {
+        self.used.fetch_add(delta, Ordering::SeqCst);
+    }
+
+    fn spilled_bytes(&self) -> usize {
+        self.spilled_bytes.load(Ordering::SeqCst)
+    }
+
+    fn spilled_bytes_add(&self, add: usize) {
+        self.spilled_bytes.fetch_add(add, Ordering::SeqCst);
+    }
+
+    fn spilled_count(&self) -> usize {
+        self.spilled_count.load(Ordering::SeqCst)
+    }
+
+    fn spilled_count_increment(&self) {
+        self.spilled_count.fetch_add(1, Ordering::SeqCst);
+    }
+}
+
+/// Consumes `sorted_batches`, merging them into a single in-memory, totally
+/// ordered stream
+async fn in_mem_merge_sort(
+    sorted_batches: &mut Vec<RecordBatch>,
+    schema: SchemaRef,
+    expressions: &[PhysicalSortExpr],
+    target_batch_size: usize,
+    baseline_metrics: BaselineMetrics,
+) -> Result<SendableRecordBatchStream> {
+    if sorted_batches.len() == 1 {
+        Ok(Box::pin(SizedRecordBatchStream::new(
+            schema,
+            vec![Arc::new(sorted_batches.pop().unwrap())],
+        )))
+    } else {
+        let new = sorted_batches.drain(..).collect();
+        assert_eq!(sorted_batches.len(), 0);
+        Ok(Box::pin(InMemSortStream::new(
+            new,
+            schema,
+            expressions,
+            target_batch_size,
+            baseline_metrics,
+        )?))
+    }
+}
+
+async fn spill(
+    in_mem_stream: &mut SendableRecordBatchStream,
+    path: String,
+    schema: SchemaRef,
+) -> Result<usize> {
+    let (sender, receiver): (
+        TKSender<ArrowResult<RecordBatch>>,
+        TKReceiver<ArrowResult<RecordBatch>>,
+    ) = tokio::sync::mpsc::channel(2);
+    while let Some(item) = in_mem_stream.next().await {
+        sender.send(item).await.ok();
+    }
+    let path_clone = path.clone();
+    let res =
+        task::spawn_blocking(move || write_sorted(receiver, path_clone, schema)).await;
+    match res {
+        Ok(r) => r,
+        Err(e) => Err(DataFusionError::Execution(format!(
+            "Error occurred while spilling {}",
+            e
+        ))),
+    }
+}
+
+async fn read_spill_as_stream(
+    path: String,
+    schema: SchemaRef,
+) -> Result<SendableRecordBatchStream> {
+    let (sender, receiver): (
+        TKSender<ArrowResult<RecordBatch>>,
+        TKReceiver<ArrowResult<RecordBatch>>,
+    ) = tokio::sync::mpsc::channel(2);
+    let path_clone = path.clone();
+    task::spawn_blocking(move || {
+        if let Err(e) = read_spill(sender, path_clone) {
+            error!("Failure while reading spill file: {}.
Error: {}", path, e); + } + }); + Ok(RecordBatchReceiverStream::create(&schema, receiver)) +} + +pub(crate) async fn convert_stream_disk_based( + in_mem_stream: &mut SendableRecordBatchStream, + path: String, + schema: SchemaRef, +) -> Result<(SendableRecordBatchStream, usize)> { + let size = spill(in_mem_stream, path.clone(), schema.clone()).await?; + read_spill_as_stream(path.clone(), schema.clone()) + .await + .map(|s| (s, size)) +} + +fn write_sorted( + mut receiver: TKReceiver>, + path: String, + schema: SchemaRef, +) -> Result { + let mut writer = IPCWriterWrapper::new(path.as_ref(), schema.as_ref())?; + while let Some(batch) = receiver.blocking_recv() { + writer.write(&batch?)?; + } + writer.finish()?; + info!( + "Spilled {} batches of total {} rows to disk, memory released {}", + writer.num_batches, writer.num_rows, writer.num_bytes + ); + Ok(writer.num_bytes as usize) +} + +fn read_spill(sender: TKSender>, path: String) -> Result<()> { + let mut file = BufReader::new(File::open(&path)?); + let file_meta = read_file_metadata(&mut file)?; + let reader = FileReader::new(&mut file, file_meta, None); + for batch in reader { + sender + .blocking_send(batch) + .map_err(|e| DataFusionError::Execution(format!("{}", e)))?; + } + Ok(()) +} + +/// Sort execution plan +#[derive(Debug)] +pub struct ExternalSortExec { + /// Input schema + input: Arc, + /// Sort expressions + expr: Vec, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Preserve partitions of input plan + preserve_partitioning: bool, +} + +impl ExternalSortExec { + /// Create a new sort execution plan + pub fn try_new( + expr: Vec, + input: Arc, + ) -> Result { + Ok(Self::new_with_partitioning(expr, input, false)) + } + + /// Create a new sort execution plan with the option to preserve + /// the partitioning of the input plan + pub fn new_with_partitioning( + expr: Vec, + input: Arc, + preserve_partitioning: bool, + ) -> Self { + Self { + expr, + input, + metrics: ExecutionPlanMetricsSet::new(), + preserve_partitioning, + } + } + + /// Input schema + pub fn input(&self) -> &Arc { + &self.input + } + + /// Sort expressions + pub fn expr(&self) -> &[PhysicalSortExpr] { + &self.expr + } +} + +#[async_trait] +impl ExecutionPlan for ExternalSortExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + /// Get the output partitioning of this plan + fn output_partitioning(&self) -> Partitioning { + if self.preserve_partitioning { + self.input.output_partitioning() + } else { + Partitioning::UnknownPartitioning(1) + } + } + + fn required_child_distribution(&self) -> Distribution { + if self.preserve_partitioning { + Distribution::UnspecifiedDistribution + } else { + Distribution::SinglePartition + } + } + + fn with_new_children( + &self, + children: Vec>, + ) -> Result> { + match children.len() { + 1 => Ok(Arc::new(ExternalSortExec::try_new( + self.expr.clone(), + children[0].clone(), + )?)), + _ => Err(DataFusionError::Internal( + "ExternalSortExec wrong number of children".to_string(), + )), + } + } + + async fn execute(&self, partition: usize) -> Result { + if !self.preserve_partitioning { + if 0 != partition { + return Err(DataFusionError::Internal(format!( + "ExternalSortExec invalid partition {}", + partition + ))); + } + + // sort needs to operate on a single partition currently + if 1 != 
self.input.output_partitioning().partition_count() {
+                return Err(DataFusionError::Internal(
+                    "ExternalSortExec requires a single input partition".to_owned(),
+                ));
+            }
+        }
+
+        let input = self.input.execute(partition).await?;
+        external_sort(input, partition, self.expr.clone(), RUNTIME_ENV.clone()).await
+    }
+
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
+                write!(f, "SortExec: [{}]", expr.join(","))
+            }
+        }
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn statistics(&self) -> Statistics {
+        self.input.statistics()
+    }
+}
+
+/// Sort based on `ExternalSorter`
+pub async fn external_sort(
+    mut input: SendableRecordBatchStream,
+    partition_id: usize,
+    expr: Vec<PhysicalSortExpr>,
+    runtime: Arc<RuntimeEnv>,
+) -> Result<SendableRecordBatchStream> {
+    let schema = input.schema();
+    let sorter = Arc::new(ExternalSorter::new(
+        partition_id,
+        schema.clone(),
+        expr,
+        runtime.clone(),
+    ));
+    runtime.register_consumer(sorter.clone()).await;
+
+    while let Some(batch) = input.next().await {
+        let batch = batch?;
+        sorter.insert_batch(batch).await?;
+    }
+
+    sorter.finish_insert();
+    sorter.sort().await
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+    use crate::physical_plan::expressions::col;
+    use crate::physical_plan::memory::MemoryExec;
+    use crate::physical_plan::sorts::SortOptions;
+    use crate::physical_plan::{
+        collect,
+        csv::{CsvExec, CsvReadOptions},
+    };
+    use crate::test;
+    use arrow::array::*;
+    use arrow::datatypes::*;
+
+    #[tokio::test]
+    async fn test_sort() -> Result<()> {
+        let schema = test::aggr_test_schema();
+        let partitions = 4;
+        let path = test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;
+        let csv = CsvExec::try_new(
+            &path,
+            CsvReadOptions::new().schema(&schema),
+            None,
+            1024,
+            None,
+        )?;
+
+        let sort_exec = Arc::new(ExternalSortExec::try_new(
+            vec![
+                // c1 string column
+                PhysicalSortExpr {
+                    expr: col("c1", &schema)?,
+                    options: SortOptions::default(),
+                },
+                // c2 uint32 column
+                PhysicalSortExpr {
+                    expr: col("c2", &schema)?,
+                    options: SortOptions::default(),
+                },
+                // c7 uint8 column
+                PhysicalSortExpr {
+                    expr: col("c7", &schema)?,
+                    options: SortOptions::default(),
+                },
+            ],
+            Arc::new(CoalescePartitionsExec::new(Arc::new(csv))),
+        )?);
+
+        let result: Vec<RecordBatch> = collect(sort_exec).await?;
+        assert_eq!(result.len(), 1);
+
+        let columns = result[0].columns();
+
+        let c1 = columns[0]
+            .as_any()
+            .downcast_ref::<Utf8Array<i32>>()
+            .unwrap();
+        assert_eq!(c1.value(0), "a");
+        assert_eq!(c1.value(c1.len() - 1), "e");
+
+        let c2 = columns[1].as_any().downcast_ref::<UInt32Array>().unwrap();
+        assert_eq!(c2.value(0), 1);
+        assert_eq!(c2.value(c2.len() - 1), 5);
+
+        let c7 = columns[6].as_any().downcast_ref::<UInt8Array>().unwrap();
+        assert_eq!(c7.value(0), 15);
+        assert_eq!(c7.value(c7.len() - 1), 254);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_lex_sort_by_float() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Float32, true),
+            Field::new("b", DataType::Float64, true),
+        ]));
+
+        // define data.
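+        // The expected ordering below assumes arrow's total order for floats,
+        // where NaN compares greater than every other value, while null
+        // placement follows each column's `nulls_first` option.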
+ let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![ + Some(f32::NAN), + None, + None, + Some(f32::NAN), + Some(1.0_f32), + Some(1.0_f32), + Some(2.0_f32), + Some(3.0_f32), + ])), + Arc::new(Float64Array::from(vec![ + Some(200.0_f64), + Some(20.0_f64), + Some(10.0_f64), + Some(100.0_f64), + Some(f64::NAN), + None, + None, + Some(f64::NAN), + ])), + ], + )?; + + let sort_exec = Arc::new(ExternalSortExec::try_new( + vec![ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + ], + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None)?), + )?); + + assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type()); + assert_eq!(DataType::Float64, *sort_exec.schema().field(1).data_type()); + + let result: Vec = collect(sort_exec.clone()).await?; + // let metrics = sort_exec.metrics().unwrap(); + // assert!(metrics.elapsed_compute().unwrap() > 0); + // assert_eq!(metrics.output_rows().unwrap(), 8); + assert_eq!(result.len(), 1); + + let columns = result[0].columns(); + + assert_eq!(DataType::Float32, *columns[0].data_type()); + assert_eq!(DataType::Float64, *columns[1].data_type()); + + let a = columns[0].as_any().downcast_ref::().unwrap(); + let b = columns[1].as_any().downcast_ref::().unwrap(); + + // convert result to strings to allow comparing to expected result containing NaN + let result: Vec<(Option, Option)> = (0..result[0].num_rows()) + .map(|i| { + let aval = if a.is_valid(i) { + Some(a.value(i).to_string()) + } else { + None + }; + let bval = if b.is_valid(i) { + Some(b.value(i).to_string()) + } else { + None + }; + (aval, bval) + }) + .collect(); + + let expected: Vec<(Option, Option)> = vec![ + (None, Some("10".to_owned())), + (None, Some("20".to_owned())), + (Some("NaN".to_owned()), Some("100".to_owned())), + (Some("NaN".to_owned()), Some("200".to_owned())), + (Some("3".to_owned()), Some("NaN".to_owned())), + (Some("2".to_owned()), None), + (Some("1".to_owned()), Some("NaN".to_owned())), + (Some("1".to_owned()), None), + ]; + + assert_eq!(expected, result); + + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/sorts/in_mem_sort.rs b/datafusion/src/physical_plan/sorts/in_mem_sort.rs new file mode 100644 index 000000000000..4491db2a80f1 --- /dev/null +++ b/datafusion/src/physical_plan/sorts/in_mem_sort.rs @@ -0,0 +1,241 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
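+
+//! In-memory merge of individually sorted record batches: a binary heap of
+//! `SortKeyCursor`s yields rows in total order (see `InMemSortStream` below)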
+
+use std::collections::BinaryHeap;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use arrow::array::growable::make_growable;
+use arrow::compute::sort::SortOptions;
+use arrow::datatypes::SchemaRef;
+use arrow::error::ArrowError;
+use arrow::error::Result as ArrowResult;
+use arrow::record_batch::RecordBatch;
+use futures::Stream;
+
+use crate::error::Result;
+use crate::physical_plan::metrics::BaselineMetrics;
+use crate::physical_plan::sorts::{RowIndex, SortKeyCursor};
+use crate::physical_plan::{
+    expressions::PhysicalSortExpr, PhysicalExpr, RecordBatchStream,
+};
+
+pub(crate) struct InMemSortStream {
+    /// The schema of the RecordBatches yielded by this stream
+    schema: SchemaRef,
+    /// The sorted input batches; each cursor in `min_heap` refers back to a
+    /// batch in this vector by its index
+    batches: Vec<Arc<RecordBatch>>,
+    /// The accumulated row indexes for the next record batch
+    in_progress: Vec<RowIndex>,
+    /// The desired RecordBatch size to yield
+    target_batch_size: usize,
+    /// used to record execution metrics
+    baseline_metrics: BaselineMetrics,
+    /// If the stream has encountered an error
+    aborted: bool,
+    /// min heap for record comparison
+    min_heap: BinaryHeap<SortKeyCursor>,
+}
+
+impl InMemSortStream {
+    pub(crate) fn new(
+        sorted_batches: Vec<RecordBatch>,
+        schema: SchemaRef,
+        expressions: &[PhysicalSortExpr],
+        target_batch_size: usize,
+        baseline_metrics: BaselineMetrics,
+    ) -> Result<Self> {
+        let len = sorted_batches.len();
+        let mut cursors = Vec::with_capacity(len);
+        let mut min_heap = BinaryHeap::with_capacity(len);
+
+        let column_expressions: Vec<Arc<dyn PhysicalExpr>> =
+            expressions.iter().map(|x| x.expr.clone()).collect();
+
+        // The sort options for each expression
+        let sort_options: Arc<Vec<SortOptions>> =
+            Arc::new(expressions.iter().map(|x| x.options).collect());
+
+        sorted_batches
+            .into_iter()
+            .enumerate()
+            .try_for_each(|(idx, batch)| {
+                let batch = Arc::new(batch);
+                let cursor = match SortKeyCursor::new(
+                    idx,
+                    batch.clone(),
+                    &column_expressions,
+                    sort_options.clone(),
+                ) {
+                    Ok(cursor) => cursor,
+                    Err(e) => return Err(e),
+                };
+                min_heap.push(cursor);
+                cursors.insert(idx, batch);
+                Ok(())
+            })?;
+
+        Ok(Self {
+            schema,
+            batches: cursors,
+            target_batch_size,
+            baseline_metrics,
+            aborted: false,
+            in_progress: vec![],
+            min_heap,
+        })
+    }
+
+    /// Returns the cursor of the batch to pull the next row from, or `None`
+    /// if all cursors are exhausted
+    fn next_cursor(&mut self) -> Result<Option<SortKeyCursor>> {
+        Ok(self.min_heap.pop())
+    }
+
+    /// Drains the in_progress row indexes, and builds a new RecordBatch from them
+    fn build_record_batch(&mut self) -> ArrowResult<RecordBatch> {
+        let columns = self
+            .schema
+            .fields()
+            .iter()
+            .enumerate()
+            .map(|(column_idx, _)| {
+                let arrays = self
+                    .batches
+                    .iter()
+                    .map(|batch| batch.column(column_idx).as_ref())
+                    .collect::<Vec<_>>();
+
+                let mut array_data =
+                    make_growable(&arrays, false, self.in_progress.len());
+
+                if self.in_progress.is_empty() {
+                    return array_data.as_arc();
+                }
+
+                let first = &self.in_progress[0];
+                let mut buffer_idx = first.stream_idx;
+                let mut start_row_idx = first.row_idx;
+                let mut end_row_idx = start_row_idx + 1;
+
+                for row_index in self.in_progress.iter().skip(1) {
+                    let next_buffer_idx = row_index.stream_idx;
+
+                    if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx {
+                        // subsequent row in same batch
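+                        // runs of consecutive rows from the same batch are
+                        // coalesced into a single `extend` on the growable
+                        // array below instead of being copied row by row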
+ end_row_idx += 1; + continue; + } + + // emit current batch of rows for current buffer + array_data.extend( + buffer_idx, + start_row_idx, + end_row_idx - start_row_idx, + ); + + // start new batch of rows + buffer_idx = next_buffer_idx; + start_row_idx = row_index.row_idx; + end_row_idx = start_row_idx + 1; + } + + // emit final batch of rows + array_data.extend(buffer_idx, start_row_idx, end_row_idx - start_row_idx); + array_data.as_arc() + }) + .collect(); + + self.in_progress.clear(); + RecordBatch::try_new(self.schema.clone(), columns) + } + + #[inline] + fn poll_next_inner( + self: &mut Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll>> { + if self.aborted { + return Poll::Ready(None); + } + + loop { + // NB timer records time taken on drop, so there are no + // calls to `timer.done()` below. + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let _timer = elapsed_compute.timer(); + + match self.next_cursor() { + Ok(Some(mut cursor)) => { + let batch_idx = cursor.batch_idx; + let row_idx = cursor.advance(); + + // insert the cursor back to min_heap if the record batch is not exhausted + if !cursor.is_finished() { + self.min_heap.push(cursor); + } + + self.in_progress.push(RowIndex { + stream_idx: batch_idx, + cursor_idx: 0, + row_idx, + }); + } + Ok(None) if self.in_progress.is_empty() => return Poll::Ready(None), + Ok(None) => return Poll::Ready(Some(self.build_record_batch())), + Err(e) => { + self.aborted = true; + return Poll::Ready(Some(Err(ArrowError::External( + "".to_string(), + Box::new(e), + )))); + } + }; + + if self.in_progress.len() == self.target_batch_size { + return Poll::Ready(Some(self.build_record_batch())); + } + } + } +} + +impl Stream for InMemSortStream { + type Item = ArrowResult; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll = self.poll_next_inner(cx); + self.baseline_metrics.record_poll(poll) + } +} + +impl RecordBatchStream for InMemSortStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} diff --git a/datafusion/src/physical_plan/sorts/mod.rs b/datafusion/src/physical_plan/sorts/mod.rs new file mode 100644 index 000000000000..691ffb836e68 --- /dev/null +++ b/datafusion/src/physical_plan/sorts/mod.rs @@ -0,0 +1,294 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Sort functionalities + +pub mod external_sort; +mod in_mem_sort; +pub mod sort; +pub mod sort_preserving_merge; + +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream}; +use arrow::array::ord::DynComparator; +pub use arrow::compute::sort::SortOptions; +use arrow::record_batch::RecordBatch; +use arrow::{array::ArrayRef, error::Result as ArrowResult}; +use futures::channel::mpsc; +use futures::stream::FusedStream; +use futures::Stream; +use hashbrown::HashMap; +use std::borrow::BorrowMut; +use std::cmp::Ordering; +use std::fmt::{Debug, Formatter}; +use std::pin::Pin; +use std::sync::{Arc, RwLock}; +use std::task::{Context, Poll}; + +/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of +/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys. +/// +/// Additionally it maintains a row cursor that can be advanced through the rows +/// of the provided `RecordBatch` +/// +/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to +/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores +/// a row comparator for each other cursor that it is compared to. +struct SortKeyCursor { + columns: Vec, + cur_row: usize, + num_rows: usize, + + // An index uniquely identifying the record batch scanned by this cursor. + batch_idx: usize, + batch: Arc, + + // A collection of comparators that compare rows in this cursor's batch to + // the cursors in other batches. Other batches are uniquely identified by + // their batch_idx. + batch_comparators: RwLock>>, + sort_options: Arc>, +} + +impl std::fmt::Debug for SortKeyCursor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SortKeyCursor") + .field("columns", &self.columns) + .field("cur_row", &self.cur_row) + .field("num_rows", &self.num_rows) + .field("batch_idx", &self.batch_idx) + .field("batch", &self.batch) + .field("batch_comparators", &"") + .finish() + } +} + +impl SortKeyCursor { + fn new( + batch_idx: usize, + batch: Arc, + sort_key: &[Arc], + sort_options: Arc>, + ) -> Result { + let columns: Vec = sort_key + .iter() + .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))) + .collect::>()?; + Ok(Self { + cur_row: 0, + num_rows: batch.num_rows(), + columns, + batch, + batch_idx, + batch_comparators: RwLock::new(HashMap::new()), + sort_options, + }) + } + + fn is_finished(&self) -> bool { + self.num_rows == self.cur_row + } + + fn advance(&mut self) -> usize { + assert!(!self.is_finished()); + let t = self.cur_row; + self.cur_row += 1; + t + } + + /// Compares the sort key pointed to by this instance's row cursor with that of another + fn compare(&self, other: &SortKeyCursor) -> Result { + if self.columns.len() != other.columns.len() { + return Err(DataFusionError::Internal(format!( + "SortKeyCursors had inconsistent column counts: {} vs {}", + self.columns.len(), + other.columns.len() + ))); + } + + if self.columns.len() != self.sort_options.len() { + return Err(DataFusionError::Internal(format!( + "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}", + self.columns.len(), + self.sort_options.len() + ))); + } + + let zipped: Vec<((&ArrayRef, &ArrayRef), &SortOptions)> = self + .columns + .iter() + .zip(other.columns.iter()) + .zip(self.sort_options.iter()) + .collect::>(); + + self.init_cmp_if_needed(other, &zipped)?; + + let map = self.batch_comparators.read().unwrap(); + let cmp = 
+        let cmp = map.get(&other.batch_idx).ok_or_else(|| {
+            DataFusionError::Execution(format!(
+                "Failed to find comparator for {} cmp {}",
+                self.batch_idx, other.batch_idx
+            ))
+        })?;
+
+        for (i, ((l, r), sort_options)) in zipped.iter().enumerate() {
+            match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
+                (false, true) if sort_options.nulls_first => return Ok(Ordering::Less),
+                (false, true) => return Ok(Ordering::Greater),
+                (true, false) if sort_options.nulls_first => {
+                    return Ok(Ordering::Greater)
+                }
+                (true, false) => return Ok(Ordering::Less),
+                (false, false) => {}
+                (true, true) => match cmp[i](self.cur_row, other.cur_row) {
+                    Ordering::Equal => {}
+                    o if sort_options.descending => return Ok(o.reverse()),
+                    o => return Ok(o),
+                },
+            }
+        }
+
+        Ok(Ordering::Equal)
+    }
+
+    /// Initialize a collection of comparators for comparing
+    /// columnar arrays of this cursor and "other" if needed.
+    fn init_cmp_if_needed(
+        &self,
+        other: &SortKeyCursor,
+        zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)],
+    ) -> Result<()> {
+        let hm = self.batch_comparators.read().unwrap();
+        if !hm.contains_key(&other.batch_idx) {
+            drop(hm);
+            let mut map = self.batch_comparators.write().unwrap();
+            let cmp = map
+                .borrow_mut()
+                .entry(other.batch_idx)
+                .or_insert_with(|| Vec::with_capacity(other.columns.len()));
+
+            for (i, ((l, r), _)) in zipped.iter().enumerate() {
+                if i >= cmp.len() {
+                    // initialise comparators
+                    cmp.push(arrow::array::ord::build_compare(l.as_ref(), r.as_ref())?);
+                }
+            }
+        }
+        Ok(())
+    }
+}
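Aside, not part of the patch: a std-only sketch of the null-ordering rules `compare` applies above. A null on one side wins or loses depending on `nulls_first`, equal non-null keys fall through to the next sort column, and `descending` reverses only non-null comparisons. `Opts` and `compare_keys` are illustrative stand-ins for arrow2's `SortOptions` and the per-column match:

```rust
use std::cmp::Ordering;

/// Stand-in for arrow2's SortOptions.
struct Opts {
    descending: bool,
    nulls_first: bool,
}

/// Compare two nullable keys under the given options, mirroring the
/// match in SortKeyCursor::compare.
fn compare_keys(l: Option<i32>, r: Option<i32>, opts: &Opts) -> Ordering {
    match (l, r) {
        (None, Some(_)) if opts.nulls_first => Ordering::Less,
        (None, Some(_)) => Ordering::Greater,
        (Some(_), None) if opts.nulls_first => Ordering::Greater,
        (Some(_), None) => Ordering::Less,
        (None, None) => Ordering::Equal, // both null: defer to the next sort column
        (Some(l), Some(r)) if opts.descending => l.cmp(&r).reverse(),
        (Some(l), Some(r)) => l.cmp(&r),
    }
}

fn main() {
    let asc_nulls_first = Opts { descending: false, nulls_first: true };
    assert_eq!(compare_keys(None, Some(1), &asc_nulls_first), Ordering::Less);

    let desc = Opts { descending: true, nulls_first: true };
    // descending reverses value comparisons, but not null placement
    assert_eq!(compare_keys(Some(1), Some(2), &desc), Ordering::Greater);
    assert_eq!(compare_keys(None, Some(2), &desc), Ordering::Less);
}
```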
+
+/// A `RowIndex` identifies a specific row from those buffered
+/// by a `SortPreservingMergeStream`
+#[derive(Debug, Clone)]
+struct RowIndex {
+    /// The index of the stream
+    stream_idx: usize,
+    /// The index of the cursor within the stream's VecDeque.
+    /// in_mem_sort has only one batch per stream, so its cursor_idx is always 0.
+    cursor_idx: usize,
+    /// The row index
+    row_idx: usize,
+}
+
+impl Ord for SortKeyCursor {
+    // The comparison is deliberately reversed (`other` vs `self`) so that
+    // `BinaryHeap`, a max-heap, pops the cursor with the smallest sort key first.
+    fn cmp(&self, other: &Self) -> Ordering {
+        other.compare(self).unwrap()
+    }
+}
+
+impl PartialEq for SortKeyCursor {
+    fn eq(&self, other: &Self) -> bool {
+        other.compare(self).unwrap() == Ordering::Equal
+    }
+}
+
+impl Eq for SortKeyCursor {}
+
+impl PartialOrd for SortKeyCursor {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        other.compare(self).ok()
+    }
+}
+
+pub(crate) struct SpillableStream {
+    pub stream: SendableRecordBatchStream,
+    pub spillable: bool,
+}
+
+impl Debug for SpillableStream {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "SpillableStream {}", self.spillable)
+    }
+}
+
+impl SpillableStream {
+    pub(crate) fn new_spillable(stream: SendableRecordBatchStream) -> Self {
+        Self {
+            stream,
+            spillable: true,
+        }
+    }
+
+    pub(crate) fn new_unspillable(stream: SendableRecordBatchStream) -> Self {
+        Self {
+            stream,
+            spillable: false,
+        }
+    }
+}
+
+#[derive(Debug)]
+enum StreamWrapper {
+    Receiver(mpsc::Receiver<ArrowResult<RecordBatch>>),
+    Stream(Option<SpillableStream>),
+}
+
+impl Stream for StreamWrapper {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match self.get_mut() {
+            StreamWrapper::Receiver(ref mut receiver) => Pin::new(receiver).poll_next(cx),
+            StreamWrapper::Stream(ref mut stream) => {
+                let inner = match stream {
+                    None => return Poll::Ready(None),
+                    Some(inner) => inner,
+                };
+
+                match Pin::new(&mut inner.stream).poll_next(cx) {
+                    Poll::Ready(msg) => {
+                        if msg.is_none() {
+                            *stream = None
+                        }
+                        Poll::Ready(msg)
+                    }
+                    Poll::Pending => Poll::Pending,
+                }
+            }
+        }
+    }
+}
+
+impl FusedStream for StreamWrapper {
+    fn is_terminated(&self) -> bool {
+        match self {
+            StreamWrapper::Receiver(receiver) => receiver.is_terminated(),
+            StreamWrapper::Stream(stream) => stream.is_none(),
+        }
+    }
+}
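Aside, not part of the patch: the reversed `Ord` above is what lets the `BinaryHeap` used by `in_mem_sort` (its `min_heap` field) behave as a min-heap. A std-only illustration of the same trick:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn main() {
    // With plain values, BinaryHeap is a max-heap and pops the largest first...
    let mut max_heap: BinaryHeap<i32> = [3, 1, 2].into_iter().collect();
    assert_eq!(max_heap.pop(), Some(3));

    // ...so a merge that needs the *smallest* row next either wraps items in
    // `Reverse`, or, as SortKeyCursor does, flips the comparison itself.
    let mut min_heap: BinaryHeap<Reverse<i32>> =
        [3, 1, 2].into_iter().map(Reverse).collect();
    assert_eq!(min_heap.pop(), Some(Reverse(1)));
}
```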
diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs
similarity index 96%
rename from datafusion/src/physical_plan/sort.rs
rename to datafusion/src/physical_plan/sorts/sort.rs
index 260d1bb2d6d3..48c72fb0026d 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sorts/sort.rs
@@ -17,14 +17,14 @@
 
 //! Defines the SORT plan
 
-use super::metrics::{
+use super::{RecordBatchStream, SendableRecordBatchStream};
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr};
+use crate::physical_plan::metrics::{
     BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
 };
-use super::{RecordBatchStream, SendableRecordBatchStream, Statistics};
-use crate::error::{DataFusionError, Result};
-use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::{
-    common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
+    common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, Statistics,
 };
 pub use arrow::compute::sort::SortOptions;
 use arrow::compute::{sort::lexsort_to_indices, take};
@@ -185,21 +185,18 @@ impl ExecutionPlan for SortExec {
     }
 }
 
-fn sort_batch(
+/// Sort the record batch based on `expr` and reorder its rows according to the sort result.
+pub fn sort_batch(
     batch: RecordBatch,
     schema: SchemaRef,
     expr: &[PhysicalSortExpr],
 ) -> ArrowResult<RecordBatch> {
-    let columns = expr
-        .iter()
-        .map(|e| e.evaluate_to_sort_column(&batch))
-        .collect::<Result<Vec<_>>>()
+    let columns = exprs_to_sort_columns(&batch, expr)
         .map_err(DataFusionError::into_arrow_external_error)?;
-    let columns = columns.iter().map(|x| x.into()).collect::<Vec<_>>();
-
-    // sort combined record batch
-    // TODO: pushup the limit expression to sort
-    let indices = lexsort_to_indices::<u32>(&columns, None)?;
+    let indices = lexsort_to_indices::<u32>(
+        &columns.iter().map(|x| x.into()).collect::<Vec<_>>(),
+        None,
+    )?;
 
     // reorder all rows based on sorted indices
     RecordBatch::try_new(
@@ -240,6 +237,7 @@ impl SortStream {
             // combine all record batches into one for each column
             let combined = common::combine_batches(&batches, schema.clone())?;
             // sort combined record batch
+            // TODO: push the limit expression down into the sort
             let result = combined
                 .map(|batch| sort_batch(batch, schema, &expr))
                 .transpose()?
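Aside, not part of the patch: the TODO above is about pushing a `LIMIT` into the sort, and `lexsort_to_indices` appears to already accept a limit argument (the `None` in the call above). A std-only sketch of why that helps, using `select_nth_unstable_by` as the partial-sort stand-in (`top_k_indices` is a hypothetical helper, not arrow code):

```rust
/// Return the indices of the k smallest values, in the order a full sort
/// would produce them, without fully sorting the rest of the input.
fn top_k_indices<T: Ord>(values: &[T], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..values.len()).collect();
    let k = k.min(idx.len());
    if k == 0 {
        return Vec::new();
    }
    // Partition so the k smallest (by the values they point at) come first...
    idx.select_nth_unstable_by(k - 1, |&a, &b| values[a].cmp(&values[b]));
    idx.truncate(k);
    // ...then only that prefix needs a real sort.
    idx.sort_unstable_by(|&a, &b| values[a].cmp(&values[b]));
    idx
}

fn main() {
    let data = [5, 1, 4, 2, 3];
    assert_eq!(top_k_indices(&data, 2), vec![1, 3]); // the rows holding 1 and 2
}
```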
diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
similarity index 83%
rename from datafusion/src/physical_plan/sort_preserving_merge.rs
rename to datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
index 311a4c9de893..955121f9c4b1 100644
--- a/datafusion/src/physical_plan/sort_preserving_merge.rs
+++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -17,7 +17,6 @@
 
 //! Defines the sort preserving merge plan
 
-use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use std::any::Any;
 use std::cmp::Ordering;
 use std::collections::VecDeque;
@@ -25,8 +24,7 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use arrow::array::ord::DynComparator;
-use arrow::array::{growable::make_growable, ord::build_compare, ArrayRef};
+use arrow::array::growable::make_growable;
 use arrow::compute::sort::SortOptions;
 use arrow::datatypes::SchemaRef;
 use arrow::error::ArrowError;
@@ -35,15 +33,28 @@ use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
 use futures::channel::mpsc;
 use futures::stream::FusedStream;
-use futures::{Stream, StreamExt};
-use hashbrown::HashMap;
+use futures::{Future, Stream, StreamExt};
 
 use crate::error::{DataFusionError, Result};
+use crate::execution::memory_management::{
+    MemoryConsumer, MemoryConsumerId, MemoryManager,
+};
+use crate::execution::runtime_env::RuntimeEnv;
+use crate::execution::runtime_env::RUNTIME_ENV;
+use crate::physical_plan::metrics::{
+    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
+};
+use crate::physical_plan::sorts::external_sort::convert_stream_disk_based;
+use crate::physical_plan::sorts::{
+    RowIndex, SortKeyCursor, SpillableStream, StreamWrapper,
+};
 use crate::physical_plan::{
     common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType,
     Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
     SendableRecordBatchStream, Statistics,
 };
+use futures::lock::Mutex;
+use std::fmt::{Debug, Formatter};
 
 /// Sort preserving merge execution plan
 ///
@@ -159,13 +170,18 @@ impl ExecutionPlan for SortPreservingMergeExec {
             })
             .collect();
 
-        Ok(Box::pin(SortPreservingMergeStream::new(
-            streams,
-            self.schema(),
-            &self.expr,
-            self.target_batch_size,
-            baseline_metrics,
-        )))
+        Ok(Box::pin(
+            SortPreservingMergeStream::new_from_receiver(
+                streams,
+                self.schema(),
+                &self.expr,
+                self.target_batch_size,
+                baseline_metrics,
+                partition,
+                RUNTIME_ENV.clone(),
+            )
+            .await,
+        ))
     }
 }
@@ -192,153 +208,131 @@ impl ExecutionPlan for SortPreservingMergeExec {
     }
 }
 
-/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of
-/// `PhysicalExpr` that when evaluated on the `RecordBatch` yield the sort keys.
-///
-/// Additionally it maintains a row cursor that can be advanced through the rows
-/// of the provided `RecordBatch`
-///
-/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to
-/// by this row cursor, with that of another `SortKeyCursor`. A cursor stores
-/// a row comparator for each other cursor that it is compared to.
-struct SortKeyCursor {
-    columns: Vec<ArrayRef>,
-    cur_row: usize,
-    num_rows: usize,
-
-    // An index uniquely identifying the record batch scanned by this cursor.
-    batch_idx: usize,
-    batch: RecordBatch,
-
-    // A collection of comparators that compare rows in this cursor's batch to
-    // the cursors in other batches. Other batches are uniquely identified by
-    // their batch_idx.
-    batch_comparators: HashMap<usize, Vec<DynComparator>>,
+struct MergingStreams {
+    /// ConsumerId
+    id: MemoryConsumerId,
+    /// The sorted input streams to merge together
+    pub(crate) streams: Mutex<Vec<StreamWrapper>>,
+    /// The schema of the RecordBatches yielded by this stream
+    schema: SchemaRef,
+    /// Runtime
+    runtime: Arc<RuntimeEnv>,
 }
 
-impl<'a> std::fmt::Debug for SortKeyCursor {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("SortKeyCursor")
-            .field("columns", &self.columns)
-            .field("cur_row", &self.cur_row)
-            .field("num_rows", &self.num_rows)
-            .field("batch_idx", &self.batch_idx)
-            .field("batch", &self.batch)
-            .field("batch_comparators", &"<FUNC>")
+impl Debug for MergingStreams {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("MergingStreams")
+            .field("id", &self.id())
             .finish()
     }
 }
 
-impl SortKeyCursor {
-    fn new(
-        batch_idx: usize,
-        batch: RecordBatch,
-        sort_key: &[Arc<dyn PhysicalExpr>],
-    ) -> Result<Self> {
-        let columns = sort_key
-            .iter()
-            .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
-            .collect::<Result<Vec<_>>>()?;
-        Ok(Self {
-            cur_row: 0,
-            num_rows: batch.num_rows(),
-            columns,
-            batch,
-            batch_idx,
-            batch_comparators: HashMap::new(),
-        })
+impl MergingStreams {
+    pub fn new(
+        partition: usize,
+        input_streams: Vec<StreamWrapper>,
+        schema: SchemaRef,
+        runtime: Arc<RuntimeEnv>,
+    ) -> Self {
+        Self {
+            id: MemoryConsumerId::new(partition),
+            streams: Mutex::new(input_streams),
+            schema,
+            runtime,
+        }
     }
 
-    fn is_finished(&self) -> bool {
-        self.num_rows == self.cur_row
+    async fn spill_underlying_stream(
+        &self,
+        stream_idx: usize,
+        path: String,
+    ) -> Result<usize> {
+        let mut streams = self.streams.lock().await;
+        let origin_stream = &mut streams[stream_idx];
+        match origin_stream {
+            StreamWrapper::Receiver(_) => {
+                return Err(DataFusionError::Execution(
+                    "Unexpected spilling of a receiver stream in SortPreservingMerge"
+                        .to_string(),
+                ))
+            }
+            StreamWrapper::Stream(stream) => match stream {
+                None => Ok(0),
+                Some(ref mut stream) => {
+                    return if stream.spillable {
+                        let (disk_stream, spill_size) = convert_stream_disk_based(
+                            &mut stream.stream,
+                            path,
+                            self.schema.clone(),
+                        )
+                        .await?;
+                        streams[stream_idx] = StreamWrapper::Stream(Some(
+                            SpillableStream::new_unspillable(disk_stream),
+                        ));
+                        Ok(spill_size)
+                    } else {
+                        Ok(0)
+                    }
+                }
+            },
+        }
     }
+}
 
-    fn advance(&mut self) -> usize {
-        assert!(!self.is_finished());
-        let t = self.cur_row;
-        self.cur_row += 1;
-        t
+#[async_trait]
+impl MemoryConsumer for MergingStreams {
+    fn name(&self) -> String {
+        "MergingStreams".to_owned()
     }
-    /// Compares the sort key pointed to by this instance's row cursor with that of another
-    fn compare(
-        &mut self,
-        other: &SortKeyCursor,
-        options: &[SortOptions],
-    ) -> Result<Ordering> {
-        if self.columns.len() != other.columns.len() {
-            return Err(DataFusionError::Internal(format!(
-                "SortKeyCursors had inconsistent column counts: {} vs {}",
-                self.columns.len(),
-                other.columns.len()
-            )));
-        }
+    fn id(&self) -> &MemoryConsumerId {
+        &self.id
     }
 
-        if self.columns.len() != options.len() {
-            return Err(DataFusionError::Internal(format!(
-                "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}",
-                self.columns.len(),
-                options.len()
-            )));
-        }
+    fn memory_manager(&self) -> Arc<MemoryManager> {
+        self.runtime.memory_manager.clone()
    }
 
-        let zipped = self
-            .columns
-            .iter()
-            .zip(other.columns.iter())
-            .zip(options.iter());
-
-        // Recall or initialise a collection of comparators for comparing
-        // columnar arrays of this cursor and "other".
-        let cmp = self
-            .batch_comparators
-            .entry(other.batch_idx)
-            .or_insert_with(|| Vec::with_capacity(other.columns.len()));
-
-        for (i, ((l, r), sort_options)) in zipped.enumerate() {
-            if i >= cmp.len() {
-                // initialise comparators as potentially needed
-                cmp.push(build_compare(l.as_ref(), r.as_ref())?);
-            }
+    async fn spill_inner(
+        &self,
+        _size: usize,
+        _trigger: &MemoryConsumerId,
+    ) -> Result<usize> {
+        let path = self.runtime.disk_manager.create_tmp_file()?;
+        self.spill_underlying_stream(0, path).await
+    }
 
-            match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
-                (false, true) if sort_options.nulls_first => return Ok(Ordering::Less),
-                (false, true) => return Ok(Ordering::Greater),
-                (true, false) if sort_options.nulls_first => {
-                    return Ok(Ordering::Greater)
-                }
-                (true, false) => return Ok(Ordering::Less),
-                (false, false) => {}
-                (true, true) => match cmp[i](self.cur_row, other.cur_row) {
-                    Ordering::Equal => {}
-                    o if sort_options.descending => return Ok(o.reverse()),
-                    o => return Ok(o),
-                },
-            }
-        }
+    fn get_used(&self) -> isize {
+        todo!()
    }
 
-        Ok(Ordering::Equal)
+    fn update_used(&self, _delta: isize) {
+        todo!()
     }
-}
 
-/// A `RowIndex` identifies a specific row from those buffered
-/// by a `SortPreservingMergeStream`
-#[derive(Debug, Clone)]
-struct RowIndex {
-    /// The index of the stream
-    stream_idx: usize,
-    /// The index of the cursor within the stream's VecDequeue
-    cursor_idx: usize,
-    /// The row index
-    row_idx: usize,
+    fn spilled_bytes(&self) -> usize {
+        todo!()
+    }
+
+    fn spilled_bytes_add(&self, _add: usize) {
+        todo!()
+    }
+
+    fn spilled_count(&self) -> usize {
+        todo!()
+    }
+
+    fn spilled_count_increment(&self) {
+        todo!()
+    }
 }
 
 #[derive(Debug)]
-struct SortPreservingMergeStream {
+pub(crate) struct SortPreservingMergeStream {
     /// The schema of the RecordBatches yielded by this stream
     schema: SchemaRef,
     /// The sorted input streams to merge together
-    streams: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+    streams: Arc<MergingStreams>,
     /// For each input stream maintain a dequeue of SortKeyCursor
     ///
     /// Exhausted cursors will be popped off the front once all
@@ -349,7 +343,7 @@ struct SortPreservingMergeStream {
     /// The physical expressions to sort by
     column_expressions: Vec<Arc<dyn PhysicalExpr>>,
     /// The sort options for each expression
-    sort_options: Vec<SortOptions>,
+    sort_options: Arc<Vec<SortOptions>>,
     /// The desired RecordBatch size to yield
     target_batch_size: usize,
     /// used to record execution metrics
@@ -362,24 +356,78 @@ struct SortPreservingMergeStream {
 }
 
 impl SortPreservingMergeStream {
-    fn new(
-        streams: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+    pub(crate) async fn new_from_receiver(
+        receivers: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+        schema: SchemaRef,
+        expressions: &[PhysicalSortExpr],
+        target_batch_size: usize,
+        baseline_metrics: BaselineMetrics,
+        partition: usize,
+        runtime: Arc<RuntimeEnv>,
+    ) -> Self {
+        let cursors = (0..receivers.len())
+            .into_iter()
+            .map(|_| VecDeque::new())
+            .collect();
+
+        let receivers = receivers
+            .into_iter()
+            .map(|s| StreamWrapper::Receiver(s))
+            .collect();
+        let streams = Arc::new(MergingStreams::new(
+            partition,
+            receivers,
+            schema.clone(),
+            runtime.clone(),
+        ));
+        runtime.register_consumer(streams.clone()).await;
+
+        Self {
+            schema,
+            cursors,
+            streams,
+            column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
+            sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()),
+            target_batch_size,
+            baseline_metrics,
+            aborted: false,
+            in_progress: vec![],
+            next_batch_index: 0,
+        }
+    }
+
+    pub(crate) async fn new_from_stream(
+        streams: Vec<SpillableStream>,
         schema: SchemaRef,
         expressions: &[PhysicalSortExpr],
         target_batch_size: usize,
         baseline_metrics: BaselineMetrics,
+        partition: usize,
+        runtime: Arc<RuntimeEnv>,
     ) -> Self {
         let cursors = (0..streams.len())
             .into_iter()
             .map(|_| VecDeque::new())
             .collect();
 
+        let streams = streams
+            .into_iter()
+            .map(|s| StreamWrapper::Stream(Some(s)))
+            .collect::<Vec<_>>();
+        let streams = Arc::new(MergingStreams::new(
+            partition,
+            streams,
+            schema.clone(),
+            runtime.clone(),
+        ));
+        runtime.register_consumer(streams.clone()).await;
+
         Self {
             schema,
             cursors,
             streams,
             column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
-            sort_options: expressions.iter().map(|x| x.options).collect(),
+            sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()),
             target_batch_size,
             baseline_metrics,
             aborted: false,
@@ -403,37 +451,45 @@ impl SortPreservingMergeStream {
             }
         }
 
-        let stream = &mut self.streams[idx];
-        if stream.is_terminated() {
-            return Poll::Ready(Ok(()));
-        }
+        let mut streams_future = self.streams.streams.lock();
 
-        // Fetch a new input record and create a cursor from it
-        match futures::ready!(stream.poll_next_unpin(cx)) {
-            None => return Poll::Ready(Ok(())),
-            Some(Err(e)) => {
-                return Poll::Ready(Err(e));
-            }
-            Some(Ok(batch)) => {
-                let cursor = match SortKeyCursor::new(
-                    self.next_batch_index, // assign this batch an ID
-                    batch,
-                    &self.column_expressions,
-                ) {
-                    Ok(cursor) => cursor,
-                    Err(e) => {
-                        return Poll::Ready(Err(ArrowError::External(
-                            "".to_string(),
-                            Box::new(e),
-                        )));
+        match Pin::new(&mut streams_future).poll(cx) {
+            Poll::Ready(mut streams) => {
+                let stream = &mut streams[idx];
+                if stream.is_terminated() {
+                    return Poll::Ready(Ok(()));
+                }
+
+                // Fetch a new input record and create a cursor from it
+                match futures::ready!(stream.poll_next_unpin(cx)) {
+                    None => return Poll::Ready(Ok(())),
+                    Some(Err(e)) => {
+                        return Poll::Ready(Err(e));
+                    }
+                    Some(Ok(batch)) => {
+                        let cursor = match SortKeyCursor::new(
+                            self.next_batch_index, // assign this batch an ID
+                            Arc::new(batch),
+                            &self.column_expressions,
+                            self.sort_options.clone(),
+                        ) {
+                            Ok(cursor) => cursor,
+                            Err(e) => {
+                                return Poll::Ready(Err(ArrowError::External(
+                                    "".to_string(),
+                                    Box::new(e),
+                                )));
+                            }
+                        };
+                        self.next_batch_index += 1;
+                        self.cursors[idx].push_back(cursor)
                     }
-                };
-                self.next_batch_index += 1;
-                self.cursors[idx].push_back(cursor)
+                }
+
+                Poll::Ready(Ok(()))
             }
+            Poll::Pending => Poll::Pending,
         }
-
-        Poll::Ready(Ok(()))
     }
 
     /// Returns the index of the next stream to pull a row from, or None
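Aside, not part of the patch: the change above polls `futures::lock::Mutex`'s lock future directly from inside a synchronous `poll_*` method (possible because `MutexLockFuture` is `Unpin`, which is also what makes the `Pin::new` in the hunk compile). A minimal sketch of that pattern, where `try_lock_and_read` is a hypothetical helper:

```rust
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::lock::Mutex;

/// Poll an async mutex from a synchronous poll-style function.
fn try_lock_and_read(m: &Mutex<i32>, cx: &mut Context<'_>) -> Poll<i32> {
    // `lock()` returns a future rather than blocking.
    let mut lock = m.lock();
    match Pin::new(&mut lock).poll(cx) {
        Poll::Ready(guard) => Poll::Ready(*guard),
        // The half-polled lock future is dropped here and recreated on the
        // next call, which is also what the merge stream does on each poll.
        Poll::Pending => Poll::Pending,
    }
}

fn main() {
    let m = Mutex::new(42);
    let waker = futures::task::noop_waker();
    let mut cx = Context::from_waker(&waker);
    assert_eq!(try_lock_and_read(&m, &mut cx), Poll::Ready(42));
}
```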
@@ -449,9 +505,7 @@ impl SortPreservingMergeStream {
                 match min_cursor {
                     None => min_cursor = Some((idx, candidate)),
                     Some((_, ref mut min)) => {
-                        if min.compare(candidate, &self.sort_options)?
-                            == Ordering::Greater
-                        {
+                        if min.compare(candidate)? == Ordering::Greater {
                             min_cursor = Some((idx, candidate))
                         }
                     }
@@ -658,7 +712,7 @@ mod tests {
     use crate::physical_plan::csv::CsvExec;
     use crate::physical_plan::expressions::col;
     use crate::physical_plan::memory::MemoryExec;
-    use crate::physical_plan::sort::SortExec;
+    use crate::physical_plan::sorts::sort::SortExec;
     use crate::physical_plan::{collect, common};
     use crate::test;
 
@@ -1215,13 +1269,16 @@ mod tests {
         let metrics = ExecutionPlanMetricsSet::new();
         let baseline_metrics = BaselineMetrics::new(&metrics, 0);
 
-        let merge_stream = SortPreservingMergeStream::new(
+        let merge_stream = SortPreservingMergeStream::new_from_receiver(
             streams,
             batches.schema(),
             sort.as_slice(),
             1024,
             baseline_metrics,
-        );
+            0,
+            RUNTIME_ENV.clone(),
+        )
+        .await;
 
         let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap();
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 7b1e00196608..795a8990d8d7 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -2403,7 +2403,7 @@ async fn explain_analyze_baseline_metrics() {
     fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool {
         use datafusion::physical_plan;
 
-        plan.as_any().downcast_ref::<physical_plan::sort::SortExec>().is_some()
+        plan.as_any().downcast_ref::<physical_plan::sorts::sort::SortExec>().is_some()
             || plan.as_any().downcast_ref::().is_some()
             // CoalescePartitionsExec doesn't do any work so is not included
             || plan.as_any().downcast_ref::().is_some()
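Aside, not part of the patch: the `expected_to_have_metrics` check above relies on `as_any` type probing of a trait object. A minimal std-only sketch of the pattern, where `ExecutionPlanLike` and `SortExecStub` are made-up stand-ins:

```rust
use std::any::Any;

trait ExecutionPlanLike {
    fn as_any(&self) -> &dyn Any;
}

struct SortExecStub;

impl ExecutionPlanLike for SortExecStub {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn main() {
    let plan: Box<dyn ExecutionPlanLike> = Box::new(SortExecStub);
    // The same probe the test uses to decide whether a node reports metrics.
    assert!(plan.as_any().downcast_ref::<SortExecStub>().is_some());
}
```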