From 6654a7df96d18ed79b6861d826c927384cdfcf93 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 4 Nov 2021 12:13:41 +0800 Subject: [PATCH 01/15] wip: sort-merge-join --- .../src/serde/physical_plan/from_proto.rs | 27 +- .../rust/core/src/serde/physical_plan/mod.rs | 5 +- .../core/src/serde/physical_plan/to_proto.rs | 4 +- ballista/rust/core/src/utils.rs | 2 +- ballista/rust/scheduler/src/planner.rs | 2 +- .../physical_optimizer/coalesce_batches.rs | 9 +- .../hash_build_probe_order.rs | 17 +- datafusion/src/physical_plan/hash_utils.rs | 145 +- .../physical_plan/{ => joins}/cross_join.rs | 26 +- .../physical_plan/{ => joins}/hash_join.rs | 44 +- datafusion/src/physical_plan/joins/mod.rs | 174 ++ .../physical_plan/joins/sort_merge_join.rs | 1802 +++++++++++++++++ datafusion/src/physical_plan/mod.rs | 46 +- datafusion/src/physical_plan/planner.rs | 49 +- 14 files changed, 2100 insertions(+), 252 deletions(-) rename datafusion/src/physical_plan/{ => joins}/cross_join.rs (99%) rename datafusion/src/physical_plan/{ => joins}/hash_join.rs (99%) create mode 100644 datafusion/src/physical_plan/joins/mod.rs create mode 100644 datafusion/src/physical_plan/joins/sort_merge_join.rs diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index daa296b31a3f..eceedb244043 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -21,15 +21,8 @@ use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::sync::Arc; -use crate::error::BallistaError; -use crate::execution_plans::{ - ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec, -}; -use crate::serde::protobuf::repartition_exec_node::PartitionMethod; -use crate::serde::protobuf::ShuffleReaderPartition; -use crate::serde::scheduler::PartitionLocation; -use crate::serde::{from_proto_binary_op, proto_error, protobuf}; -use crate::{convert_box_required, convert_required, into_required}; +use log::debug; + use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::catalog::catalog::{ CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider, @@ -46,7 +39,8 @@ use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunc use datafusion::physical_plan::avro::{AvroExec, AvroReadOptions}; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; -use datafusion::physical_plan::hash_join::PartitionMode; +use datafusion::physical_plan::joins::cross_join::CrossJoinExec; +use datafusion::physical_plan::joins::hash_join::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::parquet::ParquetPartition; use datafusion::physical_plan::planner::DefaultPhysicalPlanner; @@ -56,7 +50,6 @@ use datafusion::physical_plan::window_functions::{ use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec}; use datafusion::physical_plan::{ coalesce_batches::CoalesceBatchesExec, - cross_join::CrossJoinExec, csv::CsvExec, empty::EmptyExec, expressions::{ @@ -65,7 +58,6 @@ use datafusion::physical_plan::{ }, filter::FilterExec, functions::{self, BuiltinScalarFunction, ScalarFunctionExpr}, - hash_join::HashJoinExec, limit::{GlobalLimitExec, LocalLimitExec}, parquet::ParquetExec, projection::ProjectionExec, @@ -78,10 +70,19 @@ use datafusion::physical_plan::{ 
AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr, }; use datafusion::prelude::CsvReadOptions; -use log::debug; use protobuf::physical_expr_node::ExprType; use protobuf::physical_plan_node::PhysicalPlanType; +use crate::error::BallistaError; +use crate::execution_plans::{ + ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec, +}; +use crate::serde::protobuf::repartition_exec_node::PartitionMethod; +use crate::serde::protobuf::ShuffleReaderPartition; +use crate::serde::scheduler::PartitionLocation; +use crate::serde::{from_proto_binary_op, proto_error, protobuf}; +use crate::{convert_box_required, convert_required, into_required}; + impl TryInto> for &protobuf::PhysicalPlanNode { type Error = BallistaError; diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index 990b92435b79..0e78349ddc33 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -22,6 +22,7 @@ pub mod to_proto; mod roundtrip_tests { use std::{convert::TryInto, sync::Arc}; + use datafusion::physical_plan::joins::hash_join::{HashJoinExec, PartitionMode}; use datafusion::{ arrow::{ compute::sort::SortOptions, @@ -34,7 +35,6 @@ mod roundtrip_tests { expressions::{Avg, Column, PhysicalSortExpr}, filter::FilterExec, hash_aggregate::{AggregateMode, HashAggregateExec}, - hash_join::{HashJoinExec, PartitionMode}, limit::{GlobalLimitExec, LocalLimitExec}, sorts::sort::SortExec, AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning, @@ -43,9 +43,10 @@ mod roundtrip_tests { scalar::ScalarValue, }; + use crate::execution_plans::ShuffleWriterExec; + use super::super::super::error::Result; use super::super::protobuf; - use crate::execution_plans::ShuffleWriterExec; fn roundtrip_test(exec_plan: Arc) -> Result<()> { let proto: protobuf::PhysicalPlanNode = exec_plan.clone().try_into()?; diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 175e54550610..420c9a99e773 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -28,7 +28,6 @@ use std::{ use datafusion::logical_plan::JoinType; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion::physical_plan::cross_join::CrossJoinExec; use datafusion::physical_plan::csv::CsvExec; use datafusion::physical_plan::expressions::{ CaseExpr, InListExpr, IsNotNullExpr, IsNullExpr, NegativeExpr, NotExpr, @@ -36,7 +35,8 @@ use datafusion::physical_plan::expressions::{ use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::AggregateMode; -use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode}; +use datafusion::physical_plan::joins::cross_join::CrossJoinExec; +use datafusion::physical_plan::joins::hash_join::{HashJoinExec, PartitionMode}; use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::parquet::{ParquetExec, ParquetPartition}; use datafusion::physical_plan::projection::ProjectionExec; diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index 5c97de9cc588..38b0ce34d5f7 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -56,7 +56,7 @@ use datafusion::physical_plan::empty::EmptyExec; use 
datafusion::physical_plan::expressions::{BinaryExpr, Column, Literal}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::hash_aggregate::HashAggregateExec; -use datafusion::physical_plan::hash_join::HashJoinExec; +use datafusion::physical_plan::joins::hash_join::HashJoinExec; use datafusion::physical_plan::parquet::ParquetExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::sorts::sort::SortExec; diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs index ecbe938e30cf..23142a987214 100644 --- a/ballista/rust/scheduler/src/planner.rs +++ b/ballista/rust/scheduler/src/planner.rs @@ -251,7 +251,7 @@ mod test { use ballista_core::serde::protobuf; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; - use datafusion::physical_plan::hash_join::HashJoinExec; + use datafusion::physical_plan::joins::hash_join::HashJoinExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{ coalesce_partitions::CoalescePartitionsExec, projection::ProjectionExec, diff --git a/datafusion/src/physical_optimizer/coalesce_batches.rs b/datafusion/src/physical_optimizer/coalesce_batches.rs index 9af8911062df..38624876922c 100644 --- a/datafusion/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/src/physical_optimizer/coalesce_batches.rs @@ -18,15 +18,18 @@ //! CoalesceBatches optimizer that groups batches together rows //! in bigger batches to avoid overhead with small batches -use super::optimizer::PhysicalOptimizerRule; +use std::sync::Arc; + +use crate::physical_plan::joins::hash_join::HashJoinExec; use crate::{ error::Result, physical_plan::{ coalesce_batches::CoalesceBatchesExec, filter::FilterExec, - hash_join::HashJoinExec, repartition::RepartitionExec, + repartition::RepartitionExec, }, }; -use std::sync::Arc; + +use super::optimizer::PhysicalOptimizerRule; /// Optimizer that introduces CoalesceBatchesExec to avoid overhead with small batches pub struct CoalesceBatches {} diff --git a/datafusion/src/physical_optimizer/hash_build_probe_order.rs b/datafusion/src/physical_optimizer/hash_build_probe_order.rs index 0b87ceb1a4e2..f82f14ec1148 100644 --- a/datafusion/src/physical_optimizer/hash_build_probe_order.rs +++ b/datafusion/src/physical_optimizer/hash_build_probe_order.rs @@ -20,17 +20,17 @@ use std::sync::Arc; use arrow::datatypes::Schema; +use crate::error::Result; use crate::execution::context::ExecutionConfig; use crate::logical_plan::JoinType; -use crate::physical_plan::cross_join::CrossJoinExec; use crate::physical_plan::expressions::Column; -use crate::physical_plan::hash_join::HashJoinExec; +use crate::physical_plan::joins::cross_join::CrossJoinExec; +use crate::physical_plan::joins::hash_join::HashJoinExec; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::{ExecutionPlan, PhysicalExpr}; use super::optimizer::PhysicalOptimizerRule; use super::utils::optimize_children; -use crate::error::Result; /// BuildProbeOrder reorders the build and probe phase of /// hash joins. This uses the amount of rows that a datasource has. 
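For context on the rule being touched here: its core decision is a row-count comparison that swaps the inputs so the smaller side becomes the hash join's build side. A minimal sketch of that heuristic, with illustrative names and a simplified statistics shape rather than the crate's exact API:

/// Swap build/probe order when the left (build) side is estimated to be
/// larger than the right (probe) side; unknown cardinalities keep the
/// existing order.
fn should_swap(left_rows: Option<usize>, right_rows: Option<usize>) -> bool {
    match (left_rows, right_rows) {
        (Some(l), Some(r)) => l > r,
        _ => false,
    }
}

After a swap the optimizer must also flip the (left, right) pairs of the ON clause and place a projection on top so the output column order is unchanged, which is why ProjectionExec is imported above.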
@@ -153,16 +153,15 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder { #[cfg(test)] mod tests { - use crate::{ - physical_plan::{hash_join::PartitionMode, Statistics}, - test::exec::StatisticsExec, - }; - - use super::*; use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema}; + use crate::physical_plan::joins::hash_join::PartitionMode; + use crate::{physical_plan::Statistics, test::exec::StatisticsExec}; + + use super::*; + fn create_big_and_small() -> (Arc, Arc) { let big = Arc::new(StatisticsExec::new( Statistics { diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 494fe3f3dd5b..48cd5bcada7a 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -17,7 +17,9 @@ //! Functionality used both on logical and physical plans -use crate::error::{DataFusionError, Result}; +use std::collections::HashSet; +use std::sync::Arc; + pub use ahash::{CallHasher, RandomState}; use arrow::array::{ Array, ArrayRef, BooleanArray, DictionaryArray, DictionaryKey, Float32Array, @@ -25,89 +27,11 @@ use arrow::array::{ UInt32Array, UInt64Array, UInt8Array, Utf8Array, }; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use std::collections::HashSet; -use std::sync::Arc; +use crate::error::{DataFusionError, Result}; use crate::logical_plan::JoinType; use crate::physical_plan::expressions::Column; -/// The on clause of the join, as vector of (left, right) columns. -pub type JoinOn = Vec<(Column, Column)>; -/// Reference for JoinOn. -pub type JoinOnRef<'a> = &'a [(Column, Column)]; - -/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. -/// They are valid whenever their columns' intersection equals the set `on` -pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> { - let left: HashSet = left - .fields() - .iter() - .enumerate() - .map(|(idx, f)| Column::new(f.name(), idx)) - .collect(); - let right: HashSet = right - .fields() - .iter() - .enumerate() - .map(|(idx, f)| Column::new(f.name(), idx)) - .collect(); - - check_join_set_is_valid(&left, &right, on) -} - -/// Checks whether the sets left, right and on compose a valid join. -/// They are valid whenever their intersection equals the set `on` -fn check_join_set_is_valid( - left: &HashSet, - right: &HashSet, - on: &[(Column, Column)], -) -> Result<()> { - let on_left = &on.iter().map(|on| on.0.clone()).collect::>(); - let left_missing = on_left.difference(left).collect::>(); - - let on_right = &on.iter().map(|on| on.1.clone()).collect::>(); - let right_missing = on_right.difference(right).collect::>(); - - if !left_missing.is_empty() | !right_missing.is_empty() { - return Err(DataFusionError::Plan(format!( - "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}", - left_missing, - right_missing, - ))); - }; - - let remaining = right - .difference(on_right) - .cloned() - .collect::>(); - - let collisions = left.intersection(&remaining).collect::>(); - - if !collisions.is_empty() { - return Err(DataFusionError::Plan(format!( - "The left schema and the right schema have the following columns with the same name without being on the ON statement: {:?}. Consider aliasing them.", - collisions, - ))); - }; - - Ok(()) -} - -/// Creates a schema for a join operation. 
-/// The fields from the left side are first -pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema { - let fields: Vec = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - let left_fields = left.fields().iter(); - let right_fields = right.fields().iter(); - // left then right - left_fields.chain(right_fields).cloned().collect() - } - JoinType::Semi | JoinType::Anti => left.fields().clone(), - }; - Schema::new(fields) -} - // Combines two hashes into one hash #[inline] fn combine_hashes(l: u64, r: u64) -> u64 { @@ -597,66 +521,9 @@ mod tests { use arrow::array::TryExtend; use arrow::array::{MutableDictionaryArray, MutableUtf8Array}; - use super::*; + use crate::physical_plan::joins::check_join_set_is_valid; - fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { - let left = left - .iter() - .map(|x| x.to_owned()) - .collect::>(); - let right = right - .iter() - .map(|x| x.to_owned()) - .collect::>(); - check_join_set_is_valid(&left, &right, on) - } - - #[test] - fn check_valid() -> Result<()> { - let left = vec![Column::new("a", 0), Column::new("b1", 1)]; - let right = vec![Column::new("a", 0), Column::new("b2", 1)]; - let on = &[(Column::new("a", 0), Column::new("a", 0))]; - - check(&left, &right, on)?; - Ok(()) - } - - #[test] - fn check_not_in_right() { - let left = vec![Column::new("a", 0), Column::new("b", 1)]; - let right = vec![Column::new("b", 0)]; - let on = &[(Column::new("a", 0), Column::new("a", 0))]; - - assert!(check(&left, &right, on).is_err()); - } - - #[test] - fn check_not_in_left() { - let left = vec![Column::new("b", 0)]; - let right = vec![Column::new("a", 0)]; - let on = &[(Column::new("a", 0), Column::new("a", 0))]; - - assert!(check(&left, &right, on).is_err()); - } - - #[test] - fn check_collision() { - // column "a" would appear both in left and right - let left = vec![Column::new("a", 0), Column::new("c", 1)]; - let right = vec![Column::new("a", 0), Column::new("b", 1)]; - let on = &[(Column::new("a", 0), Column::new("b", 1))]; - - assert!(check(&left, &right, on).is_err()); - } - - #[test] - fn check_in_right() { - let left = vec![Column::new("a", 0), Column::new("c", 1)]; - let right = vec![Column::new("b", 0)]; - let on = &[(Column::new("a", 0), Column::new("b", 0))]; - - assert!(check(&left, &right, on).is_ok()); - } + use super::*; #[test] fn create_hashes_for_float_arrays() -> Result<()> { diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/joins/cross_join.rs similarity index 99% rename from datafusion/src/physical_plan/cross_join.rs rename to datafusion/src/physical_plan/joins/cross_join.rs index 0edc3cc0d1aa..e3f9b1bc36e0 100644 --- a/datafusion/src/physical_plan/cross_join.rs +++ b/datafusion/src/physical_plan/joins/cross_join.rs @@ -18,32 +18,30 @@ //! Defines the cross join plan for loading the left side of the cross join //! 
and producing batches in parallel for the right partitions -use futures::{lock::Mutex, StreamExt}; +use std::time::Instant; use std::{any::Any, sync::Arc, task::Poll}; -use crate::physical_plan::memory::MemoryStream; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; - +use async_trait::async_trait; +use futures::{lock::Mutex, StreamExt}; use futures::{Stream, TryStreamExt}; +use log::debug; -use super::{ - coalesce_partitions::CoalescePartitionsExec, hash_utils::check_join_is_valid, - ColumnStatistics, Statistics, +use crate::physical_plan::joins::check_join_is_valid; +use crate::physical_plan::memory::MemoryStream; +use crate::physical_plan::{ + coalesce_batches::concat_batches, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, +}; +use crate::physical_plan::{ + coalesce_partitions::CoalescePartitionsExec, ColumnStatistics, Statistics, }; use crate::{ error::{DataFusionError, Result}, scalar::ScalarValue, }; -use async_trait::async_trait; -use std::time::Instant; - -use super::{ - coalesce_batches::concat_batches, DisplayFormatType, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, -}; -use log::debug; /// Data of the left side type JoinLeftData = RecordBatch; diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/joins/hash_join.rs similarity index 99% rename from datafusion/src/physical_plan/hash_join.rs rename to datafusion/src/physical_plan/joins/hash_join.rs index 259cba65db56..5aeb4db92068 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/joins/hash_join.rs @@ -18,46 +18,39 @@ //! Defines the join plan for executing partitions in parallel and then joining the results //! into a set of partitions. 
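As background for the changes below: this operator implements the classic two-phase hash join, building a table over the left input's keys and probing it with the right input. A self-contained sketch of the idea on plain vectors (single i32 key, inner join only; illustrative, not this file's actual code):

use std::collections::HashMap;

/// Return the matching (left_row, right_row) index pairs.
fn hash_join_indices(left: &[i32], right: &[i32]) -> Vec<(usize, usize)> {
    // Build phase: key -> every left row index carrying that key.
    let mut table: HashMap<i32, Vec<usize>> = HashMap::new();
    for (i, k) in left.iter().enumerate() {
        table.entry(*k).or_default().push(i);
    }
    // Probe phase: stream the right side, emitting one pair per match.
    let mut out = Vec::new();
    for (j, k) in right.iter().enumerate() {
        if let Some(rows) = table.get(k) {
            out.extend(rows.iter().map(|&i| (i, j)));
        }
    }
    out
}

The real operator hashes multi-column keys to a u64 and stores only row indices, which is why it must re-verify candidate rows against the actual key values.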
-use ahash::RandomState; - -use smallvec::{smallvec, SmallVec}; +use std::fmt; use std::sync::Arc; use std::{any::Any, usize}; use std::{time::Instant, vec}; -use async_trait::async_trait; -use futures::{Stream, StreamExt, TryStreamExt}; -use tokio::sync::Mutex; - +use ahash::RandomState; +use arrow::compute::take; use arrow::datatypes::*; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use arrow::{array::*, buffer::MutableBuffer}; - -use arrow::compute::take; - +use async_trait::async_trait; +use futures::{Stream, StreamExt, TryStreamExt}; use hashbrown::raw::RawTable; +use log::debug; +use smallvec::{smallvec, SmallVec}; +use tokio::sync::Mutex; -use super::{ - coalesce_partitions::CoalescePartitionsExec, - hash_utils::{build_join_schema, check_join_is_valid, JoinOn}, -}; -use super::{ +use crate::error::{DataFusionError, Result}; +use crate::logical_plan::JoinType; +use crate::physical_plan::coalesce_batches::concat_batches; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::joins::{build_join_schema, check_join_is_valid, JoinOn}; +use crate::physical_plan::PhysicalExpr; +use crate::physical_plan::{ expressions::Column, metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, }; -use super::{hash_utils::create_hashes, Statistics}; -use crate::error::{DataFusionError, Result}; -use crate::logical_plan::JoinType; - -use super::{ +use crate::physical_plan::{hash_utils::create_hashes, Statistics}; +use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; -use crate::physical_plan::coalesce_batches::concat_batches; -use crate::physical_plan::PhysicalExpr; -use log::debug; -use std::fmt; type StringArray = Utf8Array; type LargeStringArray = Utf8Array; @@ -964,6 +957,8 @@ impl Stream for HashJoinStream { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::{ assert_batches_sorted_eq, physical_plan::{ @@ -973,7 +968,6 @@ mod tests { }; use super::*; - use std::sync::Arc; fn build_table( a: (&str, &Vec), diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs new file mode 100644 index 000000000000..cfd0a640e5a5 --- /dev/null +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod cross_join; +pub mod hash_join; +pub mod sort_merge_join; + +use crate::error::{DataFusionError, Result}; +use crate::logical_plan::JoinType; +use crate::physical_plan::expressions::Column; +use arrow::datatypes::{Field, Schema}; +use std::collections::HashSet; + +/// The on clause of the join, as vector of (left, right) columns. +pub type JoinOn = Vec<(Column, Column)>; +/// Reference for JoinOn. 
+pub type JoinOnRef<'a> = &'a [(Column, Column)]; + +/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. +/// They are valid whenever their columns' intersection equals the set `on` +pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> { + let left: HashSet = left + .fields() + .iter() + .enumerate() + .map(|(idx, f)| Column::new(f.name(), idx)) + .collect(); + let right: HashSet = right + .fields() + .iter() + .enumerate() + .map(|(idx, f)| Column::new(f.name(), idx)) + .collect(); + + check_join_set_is_valid(&left, &right, on) +} + +/// Checks whether the sets left, right and on compose a valid join. +/// They are valid whenever their intersection equals the set `on` +pub fn check_join_set_is_valid( + left: &HashSet, + right: &HashSet, + on: &[(Column, Column)], +) -> Result<()> { + let on_left = &on.iter().map(|on| on.0.clone()).collect::>(); + let left_missing = on_left.difference(left).collect::>(); + + let on_right = &on.iter().map(|on| on.1.clone()).collect::>(); + let right_missing = on_right.difference(right).collect::>(); + + if !left_missing.is_empty() | !right_missing.is_empty() { + return Err(DataFusionError::Plan(format!( + "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}", + left_missing, + right_missing, + ))); + }; + + let remaining = right + .difference(on_right) + .cloned() + .collect::>(); + + let collisions = left.intersection(&remaining).collect::>(); + + if !collisions.is_empty() { + return Err(DataFusionError::Plan(format!( + "The left schema and the right schema have the following columns with the same name without being on the ON statement: {:?}. Consider aliasing them.", + collisions, + ))); + }; + + Ok(()) +} + +/// Creates a schema for a join operation. 
+/// The fields from the left side are first +pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema { + let fields: Vec = match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + let left_fields = left.fields().iter(); + let right_fields = right.fields().iter(); + // left then right + left_fields.chain(right_fields).cloned().collect() + } + JoinType::Semi | JoinType::Anti => left.fields().clone(), + }; + Schema::new(fields) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::TryExtend; + use arrow::array::{MutableDictionaryArray, MutableUtf8Array}; + + use crate::physical_plan::joins::check_join_set_is_valid; + + use super::*; + + fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { + let left = left + .iter() + .map(|x| x.to_owned()) + .collect::>(); + let right = right + .iter() + .map(|x| x.to_owned()) + .collect::>(); + check_join_set_is_valid(&left, &right, on) + } + + #[test] + fn check_valid() -> Result<()> { + let left = vec![Column::new("a", 0), Column::new("b1", 1)]; + let right = vec![Column::new("a", 0), Column::new("b2", 1)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; + + check(&left, &right, on)?; + Ok(()) + } + + #[test] + fn check_not_in_right() { + let left = vec![Column::new("a", 0), Column::new("b", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_not_in_left() { + let left = vec![Column::new("b", 0)]; + let right = vec![Column::new("a", 0)]; + let on = &[(Column::new("a", 0), Column::new("a", 0))]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_collision() { + // column "a" would appear both in left and right + let left = vec![Column::new("a", 0), Column::new("c", 1)]; + let right = vec![Column::new("a", 0), Column::new("b", 1)]; + let on = &[(Column::new("a", 0), Column::new("b", 1))]; + + assert!(check(&left, &right, on).is_err()); + } + + #[test] + fn check_in_right() { + let left = vec![Column::new("a", 0), Column::new("c", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[(Column::new("a", 0), Column::new("b", 0))]; + + assert!(check(&left, &right, on).is_ok()); + } +} diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs new file mode 100644 index 000000000000..c0b9ad103059 --- /dev/null +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -0,0 +1,1802 @@ + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the join plan for executing partitions in parallel and then joining the results +//! into a set of partitions. 
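At this stage the new file is scaffolded from the hash join implementation; the target algorithm instead merges two inputs that are pre-sorted on the join keys. A minimal sketch of that merge over sorted single-column slices (illustrative only, not the code below):

/// Inner sort-merge join over sorted i32 keys: advance the cursor with the
/// smaller key, and on equality emit the cross product of the equal-key runs.
fn sort_merge_indices(left: &[i32], right: &[i32]) -> Vec<(usize, usize)> {
    let (mut l, mut r, mut out) = (0, 0, Vec::new());
    while l < left.len() && r < right.len() {
        if left[l] < right[r] {
            l += 1;
        } else if left[l] > right[r] {
            r += 1;
        } else {
            // Delimit the run of equal keys on each side.
            let (ls, rs) = (l, r);
            while l < left.len() && left[l] == left[ls] {
                l += 1;
            }
            while r < right.len() && right[r] == right[rs] {
                r += 1;
            }
            for i in ls..l {
                for j in rs..r {
                    out.push((i, j));
                }
            }
        }
    }
    out
}

The appeal over the hash variant is memory: neither input has to be fully materialized, since only the current run of equal keys needs buffering.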
+
+use std::fmt;
+use std::sync::Arc;
+use std::{any::Any, usize};
+use std::{time::Instant, vec};
+
+use ahash::RandomState;
+use arrow::compute::take;
+use arrow::datatypes::*;
+use arrow::error::Result as ArrowResult;
+use arrow::record_batch::RecordBatch;
+use arrow::{array::*, buffer::MutableBuffer};
+use async_trait::async_trait;
+use futures::{Stream, StreamExt, TryStreamExt};
+use hashbrown::raw::RawTable;
+use log::debug;
+use smallvec::{smallvec, SmallVec};
+use tokio::sync::Mutex;
+
+use crate::error::{DataFusionError, Result};
+use crate::logical_plan::JoinType;
+use crate::physical_plan::coalesce_batches::concat_batches;
+use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+use crate::physical_plan::joins::{build_join_schema, check_join_is_valid, JoinOn};
+use crate::physical_plan::PhysicalExpr;
+use crate::physical_plan::{
+    expressions::Column,
+    metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
+};
+use crate::physical_plan::{hash_utils::create_hashes, Statistics};
+use crate::physical_plan::{
+    DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
+    SendableRecordBatchStream,
+};
+
+type StringArray = Utf8Array<i32>;
+type LargeStringArray = Utf8Array<i64>;
+
+// Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value.
+//
+// Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used
+// to put the indices in a certain bucket.
+// By allocating a `HashMap` with capacity for *at least* the number of rows on the left side,
+// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value.
+// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1.
+// As the key is a hash value, we need to check possible hash collisions in the probe stage.
+// During this stage a row may share the hash value of another row even though the key values
+// differ; such candidates are re-checked in the [equal_rows] function.
+// TODO: speed up collision check and move away from using a hashbrown HashMap
+// https://github.com/apache/arrow-datafusion/issues/50
+struct JoinHashMap(RawTable<(u64, SmallVec<[u64; 1]>)>);
+
+impl fmt::Debug for JoinHashMap {
+    fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        Ok(())
+    }
+}
+
+type JoinLeftData = Arc<(JoinHashMap, RecordBatch)>;
+
+/// Join execution plan that executes partitions in parallel and combines them into a set of
+/// partitions.
+#[derive(Debug)] +pub struct SortMergeJoinExec { + /// left (build) side which gets hashed + left: Arc, + /// right (probe) side which are filtered by the hash table + right: Arc, + /// Set of common columns used to join on + on: Vec<(Column, Column)>, + /// How the join is performed + join_type: JoinType, + /// The schema once the join is applied + schema: SchemaRef, + /// Build-side + build_side: Arc>>, + /// Shares the `RandomState` for the hashing algorithm + random_state: RandomState, + /// Partitioning mode to use + mode: PartitionMode, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, +} + +/// Metrics for SortMergeJoinExec +#[derive(Debug)] +struct SortMergeJoinMetrics { + /// Total time for joining probe-side batches to the build-side batches + join_time: metrics::Time, + /// Number of batches consumed by this operator + input_batches: metrics::Count, + /// Number of rows consumed by this operator + input_rows: metrics::Count, + /// Number of batches produced by this operator + output_batches: metrics::Count, + /// Number of rows produced by this operator + output_rows: metrics::Count, +} + +impl SortMergeJoinMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + join_time, + input_batches, + input_rows, + output_batches, + output_rows, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq)] +/// Partitioning mode to use for hash join +pub enum PartitionMode { + /// Left/right children are partitioned using the left and right keys + Partitioned, + /// Left side will collected into one partition + CollectLeft, +} + +/// Information about the index and placement (left or right) of the columns +struct ColumnIndex { + /// Index of the column + index: usize, + /// Whether the column is at the left or right side + is_left: bool, +} + +impl SortMergeJoinExec { + /// Tries to create a new [HashJoinExec]. + /// # Error + /// This function errors when it is not possible to join the left and right sides on keys `on`. 
+ pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + partition_mode: PartitionMode, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + check_join_is_valid(&left_schema, &right_schema, &on)?; + + let schema = Arc::new(build_join_schema(&left_schema, &right_schema, join_type)); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + Ok(SortMergeJoinExec { + left, + right, + on, + join_type: *join_type, + schema, + build_side: Arc::new(Mutex::new(None)), + random_state, + mode: partition_mode, + metrics: ExecutionPlanMetricsSet::new(), + }) + } + + /// left (build) side which gets hashed + pub fn left(&self) -> &Arc { + &self.left + } + + /// right (probe) side which are filtered by the hash table + pub fn right(&self) -> &Arc { + &self.right + } + + /// Set of common columns used to join on + pub fn on(&self) -> &[(Column, Column)] { + &self.on + } + + /// How the join is performed + pub fn join_type(&self) -> &JoinType { + &self.join_type + } + + /// The partitioning mode of this hash join + pub fn partition_mode(&self) -> &PartitionMode { + &self.mode + } + + /// Calculates column indices and left/right placement on input / output schemas and jointype + fn column_indices_from_schema(&self) -> ArrowResult> { + let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::Semi + | JoinType::Anti => (true, self.left.schema(), self.right.schema()), + JoinType::Right => (false, self.right.schema(), self.left.schema()), + }; + let mut column_indices = Vec::with_capacity(self.schema.fields().len()); + for field in self.schema.fields() { + let (is_primary, index) = match primary_schema.index_of(field.name()) { + Ok(i) => Ok((true, i)), + Err(_) => { + match secondary_schema.index_of(field.name()) { + Ok(i) => Ok((false, i)), + _ => Err(DataFusionError::Internal( + format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string() + )) + } + } + }.map_err(DataFusionError::into_arrow_external_error)?; + + let is_left = + is_primary && primary_is_left || !is_primary && !primary_is_left; + column_indices.push(ColumnIndex { index, is_left }); + } + + Ok(column_indices) + } +} + +#[async_trait] +impl ExecutionPlan for SortMergeJoinExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec> { + vec![self.left.clone(), self.right.clone()] + } + + fn with_new_children( + &self, + children: Vec>, + ) -> Result> { + match children.len() { + 2 => Ok(Arc::new(SortMergeJoinExec::try_new( + children[0].clone(), + children[1].clone(), + self.on.clone(), + &self.join_type, + self.mode, + )?)), + _ => Err(DataFusionError::Internal( + "HashJoinExec wrong number of children".to_string(), + )), + } + } + + fn output_partitioning(&self) -> Partitioning { + self.right.output_partitioning() + } + + async fn execute(&self, partition: usize) -> Result { + let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); + // we only want to compute the build side once for PartitionMode::CollectLeft + let left_data = { + match self.mode { + PartitionMode::CollectLeft => { + let mut build_side = self.build_side.lock().await; + + match build_side.as_ref() { + Some(stream) => stream.clone(), + None => { + let start = Instant::now(); + + // merge all left parts into a single stream + let merge = 
CoalescePartitionsExec::new(self.left.clone());
+                            let stream = merge.execute(0).await?;
+
+                            // This operation performs 2 steps at once:
+                            // 1. creates a [JoinHashMap] of all batches from the stream
+                            // 2. stores the batches in a vector.
+                            let initial = (0, Vec::new());
+                            let (num_rows, batches) = stream
+                                .try_fold(initial, |mut acc, batch| async {
+                                    acc.0 += batch.num_rows();
+                                    acc.1.push(batch);
+                                    Ok(acc)
+                                })
+                                .await?;
+                            let mut hashmap =
+                                JoinHashMap(RawTable::with_capacity(num_rows));
+                            let mut hashes_buffer = Vec::new();
+                            let mut offset = 0;
+                            for batch in batches.iter() {
+                                hashes_buffer.clear();
+                                hashes_buffer.resize(batch.num_rows(), 0);
+                                update_hash(
+                                    &on_left,
+                                    batch,
+                                    &mut hashmap,
+                                    offset,
+                                    &self.random_state,
+                                    &mut hashes_buffer,
+                                )?;
+                                offset += batch.num_rows();
+                            }
+                            // Merge all batches into a single batch, so we
+                            // can directly index into the arrays
+                            let single_batch =
+                                concat_batches(&self.left.schema(), &batches, num_rows)?;
+
+                            let left_side = Arc::new((hashmap, single_batch));
+
+                            *build_side = Some(left_side.clone());
+
+                            debug!(
+                                "Built build-side of hash join containing {} rows in {} ms",
+                                num_rows,
+                                start.elapsed().as_millis()
+                            );
+
+                            left_side
+                        }
+                    }
+                }
+                PartitionMode::Partitioned => {
+                    let start = Instant::now();
+
+                    // Load 1 partition of left side in memory
+                    let stream = self.left.execute(partition).await?;
+
+                    // This operation performs 2 steps at once:
+                    // 1. creates a [JoinHashMap] of all batches from the stream
+                    // 2. stores the batches in a vector.
+                    let initial = (0, Vec::new());
+                    let (num_rows, batches) = stream
+                        .try_fold(initial, |mut acc, batch| async {
+                            acc.0 += batch.num_rows();
+                            acc.1.push(batch);
+                            Ok(acc)
+                        })
+                        .await?;
+                    let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows));
+                    let mut hashes_buffer = Vec::new();
+                    let mut offset = 0;
+                    for batch in batches.iter() {
+                        hashes_buffer.clear();
+                        hashes_buffer.resize(batch.num_rows(), 0);
+                        update_hash(
+                            &on_left,
+                            batch,
+                            &mut hashmap,
+                            offset,
+                            &self.random_state,
+                            &mut hashes_buffer,
+                        )?;
+                        offset += batch.num_rows();
+                    }
+                    // Merge all batches into a single batch, so we
+                    // can directly index into the arrays
+                    let single_batch =
+                        concat_batches(&self.left.schema(), &batches, num_rows)?;
+
+                    let left_side = Arc::new((hashmap, single_batch));
+
+                    debug!(
+                        "Built build-side {} of hash join containing {} rows in {} ms",
+                        partition,
+                        num_rows,
+                        start.elapsed().as_millis()
+                    );
+
+                    left_side
+                }
+            }
+        };
+
+        // we have the batches and the hash map with their keys. We can now create a stream
+        // over the right that uses this information to issue new batches.
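One detail worth calling out before the probe side is wired up: the table built above is keyed by the u64 hash alone, so distinct key values can share a bucket. A toy demonstration of why probing must re-compare actual values (a hypothetical 8-bit checksum stands in for the real hasher):

fn toy_hash(key: &str) -> u8 {
    key.bytes().fold(0u8, |h, b| h.wrapping_add(b))
}

fn main() {
    // "ad" and "bc" collide: equal hashes, different keys.
    assert_eq!(toy_hash("ad"), toy_hash("bc"));
    assert_ne!("ad", "bc");
}

This is exactly why `build_join_indexes` runs every bucket hit through `equal_rows` before emitting an index pair.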
+
+        let right_stream = self.right.execute(partition).await?;
+        let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
+
+        let column_indices = self.column_indices_from_schema()?;
+        let num_rows = left_data.1.num_rows();
+        let visited_left_side = match self.join_type {
+            JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => {
+                vec![false; num_rows]
+            }
+            JoinType::Inner | JoinType::Right => vec![],
+        };
+        Ok(Box::pin(SortMergeJoinStream::new(
+            self.schema.clone(),
+            on_left,
+            on_right,
+            self.join_type,
+            left_data,
+            right_stream,
+            column_indices,
+            self.random_state.clone(),
+            visited_left_side,
+            SortMergeJoinMetrics::new(partition, &self.metrics),
+        )))
+    }
+
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                write!(
+                    f,
+                    "SortMergeJoinExec: mode={:?}, join_type={:?}, on={:?}",
+                    self.mode, self.join_type, self.on
+                )
+            }
+        }
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn statistics(&self) -> Statistics {
+        // TODO stats: it is not possible in general to know the output size of joins
+        // There are some special cases though, for example:
+        // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)`
+        Statistics::default()
+    }
+}
+
+/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on`,
+/// assuming that the first row of `batch` corresponds to the global row index `offset`.
+fn update_hash(
+    on: &[Column],
+    batch: &RecordBatch,
+    hash_map: &mut JoinHashMap,
+    offset: usize,
+    random_state: &RandomState,
+    hashes_buffer: &mut Vec<u64>,
+) -> Result<()> {
+    // evaluate the keys
+    let keys_values = on
+        .iter()
+        .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows())))
+        .collect::<Result<Vec<_>>>()?;
+
+    // calculate the hash values
+    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
+
+    // insert hashes to key of the hashmap
+    for (row, hash_value) in hash_values.iter().enumerate() {
+        let item = hash_map
+            .0
+            .get_mut(*hash_value, |(hash, _)| *hash_value == *hash);
+        if let Some((_, indices)) = item {
+            indices.push((row + offset) as u64);
+        } else {
+            hash_map.0.insert(
+                *hash_value,
+                (*hash_value, smallvec![(row + offset) as u64]),
+                |(hash, _)| *hash,
+            );
+        }
+    }
+    Ok(())
+}
+
+/// A stream that issues [RecordBatch]es as they arrive from the right of the join.
+struct SortMergeJoinStream { + /// Input schema + schema: Arc, + /// columns from the left + on_left: Vec, + /// columns from the right used to compute the hash + on_right: Vec, + /// type of the join + join_type: JoinType, + /// information from the left + left_data: JoinLeftData, + /// right + right: SendableRecordBatchStream, + /// Information of index and left / right placement of columns + column_indices: Vec, + /// Random state used for hashing initialization + random_state: RandomState, + /// Keeps track of the left side rows whether they are visited + visited_left_side: Vec, // TODO: use a more memory efficient data structure, https://github.com/apache/arrow-datafusion/issues/240 + /// There is nothing to process anymore and left side is processed in case of left join + is_exhausted: bool, + /// Metrics + join_metrics: SortMergeJoinMetrics, +} + +#[allow(clippy::too_many_arguments)] +impl SortMergeJoinStream { + fn new( + schema: Arc, + on_left: Vec, + on_right: Vec, + join_type: JoinType, + left_data: JoinLeftData, + right: SendableRecordBatchStream, + column_indices: Vec, + random_state: RandomState, + visited_left_side: Vec, + join_metrics: SortMergeJoinMetrics, + ) -> Self { + SortMergeJoinStream { + schema, + on_left, + on_right, + join_type, + left_data, + right, + column_indices, + random_state, + visited_left_side, + is_exhausted: false, + join_metrics, + } + } +} + +impl RecordBatchStream for SortMergeJoinStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. +/// The resulting batch has [Schema] `schema`. +/// # Error +/// This function errors when: +/// * +fn build_batch_from_indices( + schema: &Schema, + left: &RecordBatch, + right: &RecordBatch, + left_indices: UInt64Array, + right_indices: UInt32Array, + column_indices: &[ColumnIndex], +) -> ArrowResult<(RecordBatch, UInt64Array)> { + // build the columns of the new [RecordBatch]: + // 1. pick whether the column is from the left or right + // 2. based on the pick, `take` items from the different RecordBatches + let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); + + for column_index in column_indices { + let array = if column_index.is_left { + let array = left.column(column_index.index); + take::take(array.as_ref(), &left_indices)?.into() + } else { + let array = right.column(column_index.index); + take::take(array.as_ref(), &right_indices)?.into() + }; + columns.push(array); + } + RecordBatch::try_new(Arc::new(schema.clone()), columns).map(|x| (x, left_indices)) +} + +#[allow(clippy::too_many_arguments)] +fn build_batch( + batch: &RecordBatch, + left_data: &JoinLeftData, + on_left: &[Column], + on_right: &[Column], + join_type: JoinType, + schema: &Schema, + column_indices: &[ColumnIndex], + random_state: &RandomState, +) -> ArrowResult<(RecordBatch, UInt64Array)> { + let (left_indices, right_indices) = + build_join_indexes(left_data, batch, join_type, on_left, on_right, random_state) + .unwrap(); + + if matches!(join_type, JoinType::Semi | JoinType::Anti) { + return Ok(( + RecordBatch::new_empty(Arc::new(schema.clone())), + left_indices, + )); + } + + build_batch_from_indices( + schema, + &left_data.1, + batch, + left_indices, + right_indices, + column_indices, + ) +} + +/// returns a vector with (index from left, index from right). 
+/// The size of this vector corresponds to the total size of a joined batch +// For a join on column A: +// left right +// batch 1 +// A B A D +// --------------- +// 1 a 3 6 +// 2 b 1 2 +// 3 c 2 4 +// batch 2 +// A B A D +// --------------- +// 1 a 5 10 +// 2 b 2 2 +// 4 d 1 1 +// indices (batch, batch_row) +// left right +// (0, 2) (0, 0) +// (0, 0) (0, 1) +// (0, 1) (0, 2) +// (1, 0) (0, 1) +// (1, 1) (0, 2) +// (0, 1) (1, 1) +// (0, 0) (1, 2) +// (1, 1) (1, 1) +// (1, 0) (1, 2) +fn build_join_indexes( + left_data: &JoinLeftData, + right: &RecordBatch, + join_type: JoinType, + left_on: &[Column], + right_on: &[Column], + random_state: &RandomState, +) -> Result<(UInt64Array, UInt32Array)> { + let keys_values = right_on + .iter() + .map(|c| Ok(c.evaluate(right)?.into_array(right.num_rows()))) + .collect::>>()?; + let left_join_values = left_on + .iter() + .map(|c| Ok(c.evaluate(&left_data.1)?.into_array(left_data.1.num_rows()))) + .collect::>>()?; + let hashes_buffer = &mut vec![0; keys_values[0].len()]; + let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; + let left = &left_data.0; + + match join_type { + JoinType::Inner | JoinType::Semi | JoinType::Anti => { + // Using a buffer builder to avoid slower normal builder + let mut left_indices = MutableBuffer::::new(); + let mut right_indices = MutableBuffer::::new(); + + // Visit all of the right rows + for (row, hash_value) in hash_values.iter().enumerate() { + // Get the hash and find it in the build index + + // For every item on the left and right we check if it matches + // This possibly contains rows with hash collisions, + // So we have to check here whether rows are equal or not + if let Some((_, indices)) = + left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) + { + for &i in indices { + // Check hash collisions + if equal_rows(i as usize, row, &left_join_values, &keys_values)? { + left_indices.push(i as u64); + right_indices.push(row as u32); + } + } + } + } + + Ok(( + PrimitiveArray::::from_data( + DataType::UInt64, + left_indices.into(), + None, + ), + PrimitiveArray::::from_data( + DataType::UInt32, + right_indices.into(), + None, + ), + )) + } + JoinType::Left => { + let mut left_indices = MutableBuffer::::new(); + let mut right_indices = MutableBuffer::::new(); + + // First visit all of the rows + for (row, hash_value) in hash_values.iter().enumerate() { + if let Some((_, indices)) = + left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) + { + for &i in indices { + // Collision check + if equal_rows(i as usize, row, &left_join_values, &keys_values)? { + left_indices.push(i as u64); + right_indices.push(row as u32); + } + } + }; + } + Ok(( + PrimitiveArray::::from_data( + DataType::UInt64, + left_indices.into(), + None, + ), + PrimitiveArray::::from_data( + DataType::UInt32, + right_indices.into(), + None, + ), + )) + } + JoinType::Right | JoinType::Full => { + let mut left_indices = MutablePrimitiveArray::::new(); + let mut right_indices = MutablePrimitiveArray::::new(); + + for (row, hash_value) in hash_values.iter().enumerate() { + match left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) { + Some((_, indices)) => { + let mut no_match = true; + for &i in indices { + if equal_rows( + i as usize, + row, + &left_join_values, + &keys_values, + )? 
{ + left_indices.push(Some(i as u64)); + right_indices.push(Some(row as u32)); + no_match = false; + } + } + // If no rows matched left, still must keep the right + // with all nulls for left + if no_match { + left_indices.push(None); + right_indices.push(Some(row as u32)); + } + } + None => { + // when no match, add the row with None for the left side + left_indices.push(None); + right_indices.push(Some(row as u32)); + } + } + } + Ok((left_indices.into(), right_indices.into())) + } + } +} + +macro_rules! equal_rows_elem { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => left_array.value($left) == right_array.value($right), + _ => false, + } + }}; +} + +/// Left and right row have equal values +fn equal_rows( + left: usize, + right: usize, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], +) -> Result { + let mut err = None; + let res = left_arrays + .iter() + .zip(right_arrays) + .all(|(l, r)| match l.data_type() { + DataType::Null => true, + DataType::Boolean => equal_rows_elem!(BooleanArray, l, r, left, right), + DataType::Int8 => equal_rows_elem!(Int8Array, l, r, left, right), + DataType::Int16 => equal_rows_elem!(Int16Array, l, r, left, right), + DataType::Int32 => equal_rows_elem!(Int32Array, l, r, left, right), + DataType::Int64 => equal_rows_elem!(Int64Array, l, r, left, right), + DataType::UInt8 => equal_rows_elem!(UInt8Array, l, r, left, right), + DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right), + DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right), + DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right), + DataType::Timestamp(_, None) => { + equal_rows_elem!(Int64Array, l, r, left, right) + } + DataType::Utf8 => equal_rows_elem!(StringArray, l, r, left, right), + DataType::LargeUtf8 => equal_rows_elem!(LargeStringArray, l, r, left, right), + _ => { + // This is internal because we should have caught this before. 
+ err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + }); + + err.unwrap_or(Ok(res)) +} + +// Produces a batch for left-side rows that have/have not been matched during the whole join +fn produce_from_matched( + visited_left_side: &[bool], + schema: &SchemaRef, + column_indices: &[ColumnIndex], + left_data: &JoinLeftData, + unmatched: bool, +) -> ArrowResult { + // Find indices which didn't match any right row (are false) + let indices = if unmatched { + visited_left_side + .iter() + .enumerate() + .filter(|&(_, &value)| !value) + .map(|(index, _)| index as u64) + .collect::>() + } else { + // produce those that did match + visited_left_side + .iter() + .enumerate() + .filter(|&(_, &value)| value) + .map(|(index, _)| index as u64) + .collect::>() + }; + + // generate batches by taking values from the left side and generating columns filled with null on the right side + let indices = UInt64Array::from_data(DataType::UInt64, indices.into(), None); + + let num_rows = indices.len(); + let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); + for (idx, column_index) in column_indices.iter().enumerate() { + let array = if column_index.is_left { + let array = left_data.1.column(column_index.index); + take::take(array.as_ref(), &indices)?.into() + } else { + let datatype = schema.field(idx).data_type().clone(); + new_null_array(datatype, num_rows).into() + }; + + columns.push(array); + } + RecordBatch::try_new(schema.clone(), columns) +} + +impl Stream for SortMergeJoinStream { + type Item = ArrowResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.right + .poll_next_unpin(cx) + .map(|maybe_batch| match maybe_batch { + Some(Ok(batch)) => { + let timer = self.join_metrics.join_time.timer(); + let result = build_batch( + &batch, + &self.left_data, + &self.on_left, + &self.on_right, + self.join_type, + &self.schema, + &self.column_indices, + &self.random_state, + ); + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + if let Ok((ref batch, ref left_side)) = result { + timer.done(); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + + match self.join_type { + JoinType::Left + | JoinType::Full + | JoinType::Semi + | JoinType::Anti => { + left_side.iter().flatten().for_each(|x| { + self.visited_left_side[*x as usize] = true; + }); + } + JoinType::Inner | JoinType::Right => {} + } + } + Some(result.map(|x| x.0)) + } + other => { + let timer = self.join_metrics.join_time.timer(); + // For the left join, produce rows for unmatched rows + match self.join_type { + JoinType::Left + | JoinType::Full + | JoinType::Semi + | JoinType::Anti + if !self.is_exhausted => + { + let result = produce_from_matched( + &self.visited_left_side, + &self.schema, + &self.column_indices, + &self.left_data, + self.join_type != JoinType::Semi, + ); + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + if let Ok(ref batch) = result { + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + } + timer.done(); + self.is_exhausted = true; + return Some(result); + } + JoinType::Left + | JoinType::Full + | JoinType::Semi + | JoinType::Anti + | JoinType::Inner + | JoinType::Right => {} + } + + other + } + }) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use 
crate::{ + assert_batches_sorted_eq, + physical_plan::{ + common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, + }, + test::{build_table_i32, columns}, + }; + + use super::*; + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema().clone(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + fn join( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result { + SortMergeJoinExec::try_new(left, right, on, join_type, PartitionMode::CollectLeft) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result<(Vec, Vec)> { + let join = join(left, right, on, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + Ok((columns, batches)) + } + + async fn partitioned_join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + ) -> Result<(Vec, Vec)> { + let partition_count = 4; + + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| { + ( + Arc::new(l.clone()) as Arc, + Arc::new(r.clone()) as Arc, + ) + }) + .unzip(); + + let join = SortMergeJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + join_type, + PartitionMode::Partitioned, + )?; + + let columns = columns(&join.schema()); + + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i).await?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + + Ok((columns, batches)) + } + + #[tokio::test] + async fn join_inner_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_inner_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = partitioned_join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Inner, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + 
"+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_one_no_shared_column_names() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_inner_two() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); + + assert_eq!(batches.len(), 1); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + /// Test where the left has 2 parts, the right with 1 part => 1 part + #[tokio::test] + async fn join_inner_one_two_parts_left() -> Result<()> { + let batch1 = build_table_i32( + ("a1", &vec![1, 2]), + ("b2", &vec![1, 2]), + ("c1", &vec![7, 8]), + ); + let batch2 = + build_table_i32(("a1", &vec![2]), ("b2", &vec![2]), ("c1", &vec![9])); + let schema = batch1.schema().clone(); + let left = Arc::new( + MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![ + ( + Column::new_with_schema("a1", &left.schema())?, + Column::new_with_schema("a1", &right.schema())?, + ), + ( + Column::new_with_schema("b2", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + ), + ]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?; + + assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); + + assert_eq!(batches.len(), 1); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", 
+ "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + /// Test where the left has 1 part, the right has 2 parts => 2 parts + #[tokio::test] + async fn join_inner_one_two_parts_right() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + + let batch1 = build_table_i32( + ("a2", &vec![10, 20]), + ("b1", &vec![4, 6]), + ("c2", &vec![70, 80]), + ); + let batch2 = + build_table_i32(("a2", &vec![30]), ("b1", &vec![5]), ("c2", &vec![90])); + let schema = batch1.schema().clone(); + let right = Arc::new( + MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(), + ); + + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let join = join(left, right, on, &JoinType::Inner)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + // first part + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + assert_eq!(batches.len(), 1); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + // second part + let stream = join.execute(1).await?; + let batches = common::collect(stream).await?; + assert_eq!(batches.len(), 1); + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 2 | 5 | 8 | 30 | 5 | 90 |", + "| 3 | 5 | 9 | 30 | 5 | 90 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + fn build_table_two_batches( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema().clone(); + Arc::new( + MemoryExec::try_new(&[vec![batch.clone(), batch]], schema, None).unwrap(), + ) + } + + #[tokio::test] + async fn join_left_multi_batch() { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table_two_batches( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b1", &right.schema()).unwrap(), + )]; + + let join = join(left, right, on, &JoinType::Left).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_full_multi_batch() { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + // create two identical batches 
for the right side + let right = build_table_two_batches( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; + + let join = join(left, right, on, &JoinType::Full).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_left_empty_right() { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b1", right.schema()).unwrap(), + )]; + let schema = right.schema().clone(); + let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); + let join = join(left, right, on, &JoinType::Left).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | | 4 | |", + "| 2 | 5 | 8 | | 5 | |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_full_empty_right() { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", &vec![])); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", right.schema()).unwrap(), + )]; + let schema = right.schema().clone(); + let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap()); + let join = join(left, right, on, &JoinType::Full).unwrap(); + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await.unwrap(); + let batches = common::collect(stream).await.unwrap(); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | | | |", + "| 2 | 5 | 8 | | | |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + } + + #[tokio::test] + async fn join_left_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + 
Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_left_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = partitioned_join_collect( + left.clone(), + right.clone(), + on.clone(), + &JoinType::Left, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | 7 | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_semi() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right + ("c2", &vec![70, 80, 90, 100]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let join = join(left, right, on, &JoinType::Semi)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_anti() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3, 5]), + ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9, 11]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right + ("c2", &vec![70, 80, 90, 100]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let join = join(left, right, on, &JoinType::Anti)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 3 | 7 | 9 |", + "| 5 | 7 | 11 |", + "+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + + #[tokio::test] + 
async fn join_right_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn partitioned_join_right_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b1", &right.schema())?, + )]; + + let (columns, batches) = + partitioned_join_collect(left, right, on, &JoinType::Right).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| | 6 | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_full_one() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema()).unwrap(), + Column::new_with_schema("b2", &right.schema()).unwrap(), + )]; + + let join = join(left, right, on, &JoinType::Full)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[test] + fn join_with_hash_collision() -> Result<()> { + let mut hashmap_left = RawTable::with_capacity(2); + let left = build_table_i32( + ("a", &vec![10, 20]), + ("x", &vec![100, 200]), + ("y", &vec![200, 300]), + ); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let hashes_buff = &mut vec![0; left.num_rows()]; + let hashes = + create_hashes(&[left.columns()[0].clone()], &random_state, hashes_buff)?; + + // Create hash collisions (same hashes) + hashmap_left.insert(hashes[0], (hashes[0], smallvec![0, 1]), |(h, _)| *h); + hashmap_left.insert(hashes[1], (hashes[1], 
smallvec![0, 1]), |(h, _)| *h); + + let right = build_table_i32( + ("a", &vec![10, 20]), + ("b", &vec![0, 0]), + ("c", &vec![30, 40]), + ); + + let left_data = JoinLeftData::new((JoinHashMap(hashmap_left), left)); + let (l, r) = build_join_indexes( + &left_data, + &right, + JoinType::Inner, + &[Column::new("a", 0)], + &[Column::new("a", 0)], + &random_state, + )?; + + let left_ids = UInt64Array::from_slice(&[0, 1]); + let right_ids = UInt32Array::from_slice(&[0, 1]); + + assert_eq!(left_ids, l); + assert_eq!(right_ids, r); + + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 3758d058a34c..f9b039c23445 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -17,16 +17,13 @@ //! Traits for physical query plan, supporting parallel execution for partitioned relations. -pub use self::metrics::Metric; -use self::metrics::MetricsSet; -use self::{ - coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan, -}; -use crate::physical_plan::expressions::{PhysicalSortExpr, SortColumn}; -use crate::{ - error::{DataFusionError, Result}, - scalar::ScalarValue, -}; +use std::fmt; +use std::fmt::{Debug, Display}; +use std::ops::Range; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{any::Any, pin::Pin}; + use arrow::array::ArrayRef; use arrow::compute::merge_sort::SortOptions; use arrow::compute::partition::lexicographical_partition_ranges; @@ -34,14 +31,23 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use async_trait::async_trait; -pub use display::DisplayFormatType; use futures::stream::Stream; -use std::fmt; -use std::fmt::{Debug, Display}; -use std::ops::Range; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::{any::Any, pin::Pin}; + +pub use display::DisplayFormatType; + +use crate::physical_plan::expressions::{PhysicalSortExpr, SortColumn}; +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; + +pub use self::metrics::Metric; +use self::metrics::MetricsSet; +/// Physical planner interface +pub use self::planner::PhysicalPlanner; +use self::{ + coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan, +}; /// Trait for types that stream [arrow::record_batch::RecordBatch] pub trait RecordBatchStream: Stream> { @@ -86,9 +92,6 @@ impl Stream for EmptyRecordBatchStream { } } -/// Physical planner interface -pub use self::planner::PhysicalPlanner; - /// Statistics for a physical plan node /// Fields are optional and can be inexact because the sources /// sometimes provide approximate estimates for performance reasons @@ -612,7 +615,6 @@ pub mod avro; pub mod coalesce_batches; pub mod coalesce_partitions; pub mod common; -pub mod cross_join; #[cfg(feature = "crypto_expressions")] pub mod crypto_expressions; pub mod csv; @@ -625,8 +627,8 @@ pub mod expressions; pub mod filter; pub mod functions; pub mod hash_aggregate; -pub mod hash_join; pub mod hash_utils; +pub mod joins; pub mod json; pub mod limit; pub mod math_expressions; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index a337b886110c..d5579c1ebb24 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -17,11 +17,15 @@ //! 
Physical query planner -use super::analyze::AnalyzeExec; -use super::{ - aggregates, empty::EmptyExec, expressions::binary, functions, - hash_join::PartitionMode, udaf, union::UnionExec, windows, -}; +use std::sync::Arc; + +use arrow::compute::cast::can_cast_types; +use arrow::compute::sort::SortOptions; +use arrow::datatypes::*; +use log::debug; + +use expressions::col; + use crate::execution::context::ExecutionContextState; use crate::logical_plan::{ unnormalize_cols, DFSchema, Expr, LogicalPlan, Operator, @@ -29,20 +33,21 @@ use crate::logical_plan::{ UserDefinedLogicalNode, }; use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; -use crate::physical_plan::cross_join::CrossJoinExec; use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions; use crate::physical_plan::expressions::{CaseExpr, Column, Literal, PhysicalSortExpr}; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; -use crate::physical_plan::hash_join::HashJoinExec; +use crate::physical_plan::joins::cross_join::CrossJoinExec; +use crate::physical_plan::joins::hash_join::HashJoinExec; +use crate::physical_plan::joins::hash_join::PartitionMode; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::udf; use crate::physical_plan::windows::WindowAggExec; -use crate::physical_plan::{hash_utils, Partitioning}; +use crate::physical_plan::{hash_utils, joins, Partitioning}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; use crate::scalar::ScalarValue; use crate::sql::utils::{generate_sort_key, window_expr_common_partition_keys}; @@ -52,12 +57,11 @@ use crate::{ physical_plan::displayable, }; -use arrow::compute::cast::can_cast_types; -use arrow::compute::sort::SortOptions; -use arrow::datatypes::*; -use expressions::col; -use log::debug; -use std::sync::Arc; +use super::analyze::AnalyzeExec; +use super::{ + aggregates, empty::EmptyExec, expressions::binary, functions, udaf, union::UnionExec, + windows, +}; fn create_function_physical_name( fun: &str, @@ -677,7 +681,7 @@ impl DefaultPhysicalPlanner { Column::new(&r.name, right_df_schema.index_of_column(r)?), )) }) - .collect::>()?; + .collect::>()?; if ctx_state.config.target_partitions > 1 && ctx_state.config.repartition_joins @@ -1392,7 +1396,13 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { #[cfg(test)] mod tests { - use super::*; + use fmt::Debug; + use std::convert::TryFrom; + use std::{any::Any, fmt}; + + use arrow::datatypes::{DataType, Field}; + use async_trait::async_trait; + use crate::logical_plan::{DFField, DFSchema, DFSchemaRef}; use crate::physical_plan::{ csv::CsvReadOptions, expressions, DisplayFormatType, Partitioning, Statistics, @@ -1402,11 +1412,8 @@ mod tests { logical_plan::{col, lit, sum, LogicalPlanBuilder}, physical_plan::SendableRecordBatchStream, }; - use arrow::datatypes::{DataType, Field}; - use async_trait::async_trait; - use fmt::Debug; - use std::convert::TryFrom; - use std::{any::Any, fmt}; + + use super::*; fn make_ctx_state() -> ExecutionContextState { ExecutionContextState::new() From 9b05c942fcc19590b0f42f29254c48b1ea64e485 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Mon, 8 Nov 2021 01:25:26 +0800 Subject: [PATCH 02/15] wip inner --- .../src/physical_plan/expressions/mod.rs | 14 + 
.../physical_plan/joins/sort_merge_join.rs | 667 +++++++++++++++--- datafusion/src/physical_plan/sorts/sort.rs | 11 +- 3 files changed, 572 insertions(+), 120 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index e2e849085484..98950b9cd0fa 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -142,6 +142,20 @@ impl PhysicalSortExpr { } } +/// Convert sort expressions into Vec that can be passed into arrow sort kernel +pub fn exprs_to_sort_columns( + batch: &RecordBatch, + expr: &[PhysicalSortExpr], +) -> Result> { + let columns = expr + .iter() + .map(|e| e.evaluate_to_sort_column(&batch)) + .collect::>>() + .map_err(DataFusionError::into_arrow_external_error)?; + let columns = columns.iter().map(|x| x.into()).collect::>(); + Ok(columns) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index c0b9ad103059..34bc014d472c 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -1,4 +1,3 @@ - // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information @@ -24,24 +23,25 @@ use std::sync::Arc; use std::{any::Any, usize}; use std::{time::Instant, vec}; -use ahash::RandomState; use arrow::compute::take; use arrow::datatypes::*; -use arrow::error::Result as ArrowResult; +use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use arrow::{array::*, buffer::MutableBuffer}; use async_trait::async_trait; use futures::{Stream, StreamExt, TryStreamExt}; -use hashbrown::raw::RawTable; use log::debug; use smallvec::{smallvec, SmallVec}; use tokio::sync::Mutex; use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; use crate::logical_plan::JoinType; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr}; use crate::physical_plan::joins::{build_join_schema, check_join_is_valid, JoinOn}; +use crate::physical_plan::stream::RecordBatchReceiverStream; use crate::physical_plan::PhysicalExpr; use crate::physical_plan::{ expressions::Column, @@ -52,10 +52,532 @@ use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; +use arrow::array::growable::GrowablePrimitive; +use arrow::compute::partition::lexicographical_partition_ranges; +use arrow::compute::sort::SortOptions; +use std::cmp::Ordering; +use std::collections::VecDeque; +use std::ops::Range; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::{Receiver, Sender}; type StringArray = Utf8Array; type LargeStringArray = Utf8Array; +#[derive(Clone)] +struct PartitionedRecordBatch { + batch: RecordBatch, + ranges: Vec>, +} + +impl PartitionedRecordBatch { + fn new( + batch: Option, + expr: &[PhysicalSortExpr], + ) -> Result> { + match batch { + Some(batch) => { + let columns = exprs_to_sort_columns(&batch, expr)?; + let ranges = + lexicographical_partition_ranges(&columns)?.collect::>(); + Ok(Some(Self { batch, ranges })) + } + None => Ok(None), + } + } + + #[inline] + fn is_last_range(&self, range: 
&Range<usize>) -> bool {
+        range.end == self.batch.num_rows()
+    }
+}
+
+struct StreamingBatch {
+    batch: Option<PartitionedRecordBatch>,
+    cur_row: usize,
+    cur_range: usize,
+    num_rows: usize,
+    num_ranges: usize,
+    is_new_key: bool,
+    on_column: Vec<Column>,
+    sort: Vec<PhysicalSortExpr>,
+}
+
+impl StreamingBatch {
+    fn new(on_column: Vec<Column>, sort: Vec<PhysicalSortExpr>) -> Self {
+        Self {
+            batch: None,
+            cur_row: 0,
+            cur_range: 0,
+            num_rows: 0,
+            num_ranges: 0,
+            is_new_key: true,
+            on_column,
+            sort,
+        }
+    }
+
+    fn rest_batch(&mut self, prb: Option<PartitionedRecordBatch>) {
+        self.batch = prb;
+        if let Some(prb) = &self.batch {
+            self.cur_row = 0;
+            self.cur_range = 0;
+            self.num_rows = prb.batch.num_rows();
+            self.num_ranges = prb.ranges.len();
+            self.is_new_key = true;
+        };
+    }
+
+    fn key_any_null(&self) -> bool {
+        match &self.batch {
+            None => true,
+            Some(batch) => {
+                for c in &self.on_column {
+                    let array = batch.batch.column(c.index());
+                    if array.is_null(self.cur_row) {
+                        return true;
+                    }
+                }
+                false
+            }
+        }
+    }
+
+    #[inline]
+    fn is_finished(&self) -> bool {
+        self.batch.is_none() || self.num_rows == self.cur_row + 1
+    }
+
+    #[inline]
+    fn is_last_key_in_batch(&self) -> bool {
+        self.batch.is_none() || self.num_ranges == self.cur_range + 1
+    }
+
+    fn advance(&mut self) {
+        self.cur_row += 1;
+        self.is_new_key = false;
+        if !self.is_last_key_in_batch() {
+            let ranges = &self.batch.as_ref().unwrap().ranges;
+            if self.cur_row == ranges[self.cur_range + 1].start {
+                self.cur_range += 1;
+                self.is_new_key = true;
+            }
+        }
+    }
+
+    fn advance_key(&mut self) {
+        let ranges = &self.batch.as_ref().unwrap().ranges;
+        self.cur_range += 1;
+        self.cur_row = ranges[self.cur_range].start;
+        self.is_new_key = true;
+    }
+}
+
+/// Holds ranges for the same key over several batches
+struct BufferedBatches {
+    /// batches that contain the current key
+    /// TODO: make this spillable as well for skew on join key at buffer side
+    batches: VecDeque<PartitionedRecordBatch>,
+    /// ranges in each PartitionedRecordBatch that contain the current key
+    ranges: VecDeque<Range<usize>>,
+    /// row index in the first batch of the record that starts the current key
+    key_idx: Option<usize>,
+    /// total number of rows for the current key
+    row_num: usize,
+    /// holds a batch that was read but not consumed yet, to continue iteration
+    next_key_batch: Option<PartitionedRecordBatch>,
+    /// Join on columns
+    on_column: Vec<Column>,
+    sort: Vec<PhysicalSortExpr>,
+}
+
+#[inline]
+fn range_len(range: &Range<usize>) -> usize {
+    range.end - range.start
+}
+
+impl BufferedBatches {
+    fn new(on_column: Vec<Column>, sort: Vec<PhysicalSortExpr>) -> Self {
+        Self {
+            batches: VecDeque::new(),
+            ranges: VecDeque::new(),
+            key_idx: None,
+            row_num: 0,
+            next_key_batch: None,
+            on_column,
+            sort,
+        }
+    }
+
+    fn key_any_null(&self) -> bool {
+        match &self.key_idx {
+            None => true,
+            Some(key_idx) => {
+                let first_batch = &self.batches[0].batch;
+                for c in &self.on_column {
+                    let array = first_batch.column(c.index());
+                    if array.is_null(*key_idx) {
+                        return true;
+                    }
+                }
+                false
+            }
+        }
+    }
+
+    fn is_finished(&self) -> Result<bool> {
+        match self.key_idx {
+            None => Ok(true),
+            Some(_) => match (self.batches.back(), self.ranges.back()) {
+                (Some(batch), Some(range)) => Ok(batch.is_last_range(range)),
+                _ => Err(DataFusionError::Execution(format!(
+                    "Batches length {} not equal to ranges length {}",
+                    self.batches.len(),
+                    self.ranges.len()
+                ))),
+            },
+        }
+    }
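    // Illustration (editorial, not part of the original patch): both cursors above
    // lean on the ranges stored in `PartitionedRecordBatch`, which come from
    // arrow2's `lexicographical_partition_ranges`: one `Range<usize>` per run of
    // equal keys in a sorted batch. A standalone sketch, with crate paths and
    // exact signatures assumed rather than guaranteed:
    //
    //     use arrow2::array::Int32Array;
    //     use arrow2::compute::partition::lexicographical_partition_ranges;
    //     use arrow2::compute::sort::{SortColumn, SortOptions};
    //
    //     fn key_runs() -> arrow2::error::Result<Vec<std::ops::Range<usize>>> {
    //         let keys = Int32Array::from_slice(&[4, 5, 5, 7]); // already sorted
    //         let columns = [SortColumn {
    //             values: &keys,
    //             options: Some(SortOptions { descending: false, nulls_first: true }),
    //         }];
    //         // expected runs: [0..1, 1..3, 3..4] for keys 4, 5 and 7
    //         Ok(lexicographical_partition_ranges(&columns)?.collect())
    //     }

+    /// Whether the running key continues into the current batch `prb`:
+    /// true if it continues, false if it ends.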
+    fn running_key(&mut self, prb: &PartitionedRecordBatch) -> Result<bool> {
+        let first_range = &prb.ranges[0];
+        let range_len = range_len(first_range);
+        let current_batch = &prb.batch;
+        let single_range = prb.ranges.len() == 1;
+
+        // compare the first record in the batch with the current key pointed to by key_idx
+        match self.key_idx {
+            None => {
+                self.batches.push_back(prb.clone());
+                self.ranges.push_back(first_range.clone());
+                self.key_idx = Some(0);
+                self.row_num += range_len;
+                Ok(single_range)
+            }
+            Some(key_idx) => {
+                let key_arrays = join_arrays(&self.batches[0].batch, &self.on_column);
+                let current_arrays = join_arrays(current_batch, &self.on_column);
+                let equal = equal_rows(key_idx, 0, &key_arrays, &current_arrays)?;
+                if equal {
+                    self.batches.push_back(prb.clone());
+                    self.ranges.push_back(first_range.clone());
+                    self.row_num += range_len;
+                    Ok(single_range)
+                } else {
+                    self.next_key_batch = Some(prb.clone());
+                    Ok(false) // running key ends
+                }
+            }
+        }
+    }
+
+    fn cleanup(&mut self) {
+        self.batches.drain(..);
+        self.ranges.drain(..);
+        self.next_key_batch = None;
+    }
+
+    fn reset_batch(&mut self, prb: &PartitionedRecordBatch) {
+        self.cleanup();
+        self.batches.push_back(prb.clone());
+        let first_range = &prb.ranges[0];
+        self.ranges.push_back(first_range.clone());
+        self.key_idx = Some(0);
+        self.row_num = range_len(first_range);
+    }
+
+    /// Advance the cursor to the next key seen by this buffer
+    fn advance_in_current_batch(&mut self) {
+        assert_eq!(self.batches.len(), self.ranges.len());
+        if self.batches.len() > 1 {
+            // keep only the last batch/range pair; compute the count once,
+            // since the first `drain` changes `self.batches.len()`
+            let to_drop = self.batches.len() - 1;
+            self.batches.drain(0..to_drop);
+            self.ranges.drain(0..to_drop);
+        }
+
+        if let Some(batch) = self.batches.pop_back() {
+            let tail_range = self.ranges.pop_back().unwrap();
+            // the next key's range is the one that starts where the current range ends
+            let next_range_idx = batch
+                .ranges
+                .iter()
+                .position(|x| x.start == tail_range.end)
+                .unwrap();
+            self.key_idx = Some(tail_range.end);
+            self.ranges.push_back(batch.ranges[next_range_idx].clone());
+            self.row_num = range_len(&batch.ranges[next_range_idx]);
+            self.batches.push_back(batch);
+        }
+    }
+}
+
+fn join_arrays(rb: &RecordBatch, on_column: &[Column]) -> Vec<ArrayRef> {
+    on_column
+        .iter()
+        .map(|c| rb.column(c.index()).clone())
+        .collect()
+}
+
+struct SMJStream {
+    streamed: SendableRecordBatchStream,
+    buffered: SendableRecordBatchStream,
+    on_streamed: Vec<Column>,
+    on_buffered: Vec<Column>,
+    schema: Arc<Schema>,
+    /// Information of index and left / right placement of columns
+    column_indices: Vec<ColumnIndex>,
+    stream_batch: StreamingBatch,
+    buffered_batches: BufferedBatches,
+    result_sender: Sender<ArrowResult<RecordBatch>>,
+    result: SendableRecordBatchStream,
+    runtime: Arc<RuntimeEnv>,
+}
+
+impl RecordBatchStream for SMJStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+impl Stream for SMJStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+    }
+}
+
+fn make_mutable(data_type: &DataType, capacity: usize) -> Result<Box<dyn MutableArray>> {
+    Ok(match data_type.to_physical_type() {
+        PhysicalType::Boolean => Box::new(MutableBooleanArray::with_capacity(capacity))
+            as Box<dyn MutableArray>,
+        PhysicalType::Primitive(primitive) => {
+            with_match_primitive_type!(primitive, |$T| {
+                Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(data_type.clone()))
+                    as Box<dyn MutableArray>
+            })
+        }
+        PhysicalType::Binary => {
+            Box::new(MutableBinaryArray::<i32>::with_capacity(capacity))
+                as Box<dyn MutableArray>
+        }
+        PhysicalType::Utf8 => Box::new(MutableUtf8Array::<i32>::with_capacity(capacity))
+            as Box<dyn MutableArray>,
+        _ => match data_type {
+            DataType::List(inner) => {
+                let values = make_mutable(inner.data_type(), 0)?;
+                
Box::new(DynMutableListArray::::new_with_capacity( + values, capacity, + )) as Box + } + DataType::FixedSizeBinary(size) => Box::new( + MutableFixedSizeBinaryArray::with_capacity(*size as usize, capacity), + ) as Box, + other => { + return Err(DataFusionError::Execution(format!( + "making mutable of type {} is not implemented yet", + data_type + ))) + } + }, + }) +} + +impl SMJStream { + fn new( + streamed: SendableRecordBatchStream, + buffered: SendableRecordBatchStream, + on_streamed: Vec, + on_buffered: Vec, + streamed_sort: Vec, + buffered_sort: Vec, + schema: Arc, + runtime: Arc, + ) -> Self { + let (tx, rx) = tokio::sync::mpsc::channel(2); + let column_indices = vec![]; + Self { + streamed, + buffered, + on_streamed, + on_buffered, + schema, + column_indices, + stream_batch: StreamingBatch::new(on_streamed.clone(), streamed_sort), + buffered_batches: BufferedBatches::new(on_buffered.clone(), buffered_sort), + result_sender: tx, + result: RecordBatchReceiverStream::create(&schema, rx), + runtime, + } + } + + async fn inner_join_driver(&mut self) -> Result<()> { + let targe_batch_size = self.runtime.batch_size(); + + if let Err(e) = self.result_sender.send().await { + println!("ERROR batch via inner join stream: {}", e); + }; + + let mut arrays: Vec = + Vec::with_capacity(self.schema.fields().len()); + for (idx, field) in self.schema.fields().iter().enumerate() { + match field.data_type { + DataType::Int32 => { + arrays.push(MutablePrimitiveArray::::with_capacity( + targe_batch_size, + )); + } + DataType::Utf8 => { + arrays.push(MutableUtf8Array::with_capacity(targe_batch_size)); + } + _ => {} + } + } + + while self.find_next_inner_match()? {} + + Ok(()) + } + + fn find_next_inner_match(&mut self) -> Result { + if self.stream_batch.key_any_null() { + let more_stream = self.advance_streamed_key_null_free()?; + if !more_stream { + return Ok(false); + } + } + + if self.buffered_batches.key_any_null() { + let more_buffer = self.advance_buffered_key_null_free()?; + if !more_buffer { + return Ok(false); + } + } + + loop { + let current_cmp = self.compare_stream_buffer(); + match current_cmp { + Ordering::Less => { + let more_stream = self.advance_streamed_key_null_free()?; + if !more_stream { + return Ok(false); + } + } + Ordering::Equal => return Ok(true), + Ordering::Greater => { + let more_buffer = self.advance_buffered_key_null_free()?; + if !more_buffer { + return Ok(false); + } + } + } + } + } + + /// true for has next, false for ended + fn advance_streamed(&mut self) -> Result { + if self.stream_batch.is_finished() { + self.get_stream_next()?; + Ok(!self.stream_batch.is_finished()) + } else { + self.stream_batch.advance(); + Ok(true) + } + } + + /// true for has next, false for ended + fn advance_streamed_key(&mut self) -> Result { + if self.stream_batch.is_finished() || self.stream_batch.is_last_key_in_batch() { + self.get_stream_next()?; + Ok(!self.stream_batch.is_finished()) + } else { + self.stream_batch.advance_key(); + Ok(true) + } + } + + /// true for has next, false for ended + fn advance_streamed_key_null_free(&mut self) -> Result { + let mut more_stream_keys = self.advance_streamed_key()?; + loop { + if more_stream_keys && self.stream_batch.key_any_null() { + more_stream_keys = self.advance_streamed_key()?; + } else { + break; + } + } + Ok(more_stream_keys) + } + + fn advance_buffered_key_null_free(&mut self) -> Result { + let mut more_buffered_keys = self.advance_buffered_key()?; + loop { + if more_buffered_keys && self.buffered_batches.key_any_null() { + more_buffered_keys 
= self.advance_buffered_key()?; + } else { + break; + } + } + Ok(more_buffered_keys) + } + + /// true for has next, false for ended + fn advance_buffered_key(&mut self) -> Result { + if self.buffered_batches.is_finished() { + match &self.buffered_batches.next_key_batch { + None => { + let batch = self.get_buffered_next()?; + match batch { + None => return Ok(false), + Some(batch) => { + self.buffered_batches.reset_batch(&batch); + } + } + } + Some(batch) => { + self.buffered_batches.reset_batch(batch); + } + } + } else { + self.buffered_batches.advance_in_current_batch(); + } + Ok(false) + } + + /// true for has next, false for buffer side ended + fn cumulate_same_keys(&mut self) -> Result { + let batch = self.get_buffered_next()?; + match batch { + None => Ok(false), + Some(batch) => { + let more_batches = self.buffered_batches.running_key(&batch)?; + if more_batches { + self.cumulate_same_keys() + } else { + // reach end of current key, but the stream continues + Ok(true) + } + } + } + } + + fn compare_stream_buffer(&self) -> Ordering { + todo!() + } + + fn get_stream_next(&mut self) -> Result<()> { + let batch = self.streamed.next().await.transpose()?; + let prb = PartitionedRecordBatch::new(batch, &self.stream_batch.sort)?; + self.stream_batch.rest_batch(prb); + Ok(()) + } + + fn get_buffered_next(&mut self) -> Result> { + let batch = self.buffered.next().await.transpose()?; + PartitionedRecordBatch::new(batch, &self.buffered_batches.sort) + } +} + // Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value. // // Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used @@ -66,7 +588,7 @@ type LargeStringArray = Utf8Array; // As the key is a hash value, we need to check possible hash collisions in the probe stage // During this stage it might be the case that a row is contained the same hashmap value, // but the values don't match. Those are checked in the [equal_rows] macro -// TODO: speed up collission check and move away from using a hashbrown HashMap +// TODO: speed up collision check and move away from using a hashbrown HashMap // https://github.com/apache/arrow-datafusion/issues/50 struct JoinHashMap(RawTable<(u64, SmallVec<[u64; 1]>)>); @@ -92,12 +614,6 @@ pub struct SortMergeJoinExec { join_type: JoinType, /// The schema once the join is applied schema: SchemaRef, - /// Build-side - build_side: Arc>>, - /// Shares the `RandomState` for the hashing algorithm - random_state: RandomState, - /// Partitioning mode to use - mode: PartitionMode, /// Execution metrics metrics: ExecutionPlanMetricsSet, } @@ -141,15 +657,6 @@ impl SortMergeJoinMetrics { } } -#[derive(Clone, Copy, Debug, PartialEq)] -/// Partitioning mode to use for hash join -pub enum PartitionMode { - /// Left/right children are partitioned using the left and right keys - Partitioned, - /// Left side will collected into one partition - CollectLeft, -} - /// Information about the index and placement (left or right) of the columns struct ColumnIndex { /// Index of the column @@ -159,7 +666,7 @@ struct ColumnIndex { } impl SortMergeJoinExec { - /// Tries to create a new [HashJoinExec]. + /// Tries to create a new [SortMergeJoinExec]. /// # Error /// This function errors when it is not possible to join the left and right sides on keys `on`. 
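// Editorial usage sketch (assumed shape of the API after this patch; `left` and
// `right` are any sorted ExecutionPlans and `on` pairs one key column from each side):
//
//     let join = SortMergeJoinExec::try_new(left, right, on, &JoinType::Inner)?;
//     let stream = join.execute(0).await?;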
pub fn try_new( @@ -167,7 +674,6 @@ impl SortMergeJoinExec { right: Arc, on: JoinOn, join_type: &JoinType, - partition_mode: PartitionMode, ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); @@ -175,17 +681,12 @@ impl SortMergeJoinExec { let schema = Arc::new(build_join_schema(&left_schema, &right_schema, join_type)); - let random_state = RandomState::with_seeds(0, 0, 0, 0); - Ok(SortMergeJoinExec { left, right, on, join_type: *join_type, schema, - build_side: Arc::new(Mutex::new(None)), - random_state, - mode: partition_mode, metrics: ExecutionPlanMetricsSet::new(), }) } @@ -210,11 +711,6 @@ impl SortMergeJoinExec { &self.join_type } - /// The partitioning mode of this hash join - pub fn partition_mode(&self) -> &PartitionMode { - &self.mode - } - /// Calculates column indices and left/right placement on input / output schemas and jointype fn column_indices_from_schema(&self) -> ArrowResult> { let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { @@ -272,7 +768,6 @@ impl ExecutionPlan for SortMergeJoinExec { children[1].clone(), self.on.clone(), &self.join_type, - self.mode, )?)), _ => Err(DataFusionError::Internal( "HashJoinExec wrong number of children".to_string(), @@ -437,7 +932,7 @@ impl ExecutionPlan for SortMergeJoinExec { DisplayFormatType::Default => { write!( f, - "HashJoinExec: mode={:?}, join_type={:?}, on={:?}", + "SortMergeJoinExec: mode={:?}, join_type={:?}, on={:?}", self.mode, self.join_type, self.on ) } @@ -509,8 +1004,6 @@ struct SortMergeJoinStream { right: SendableRecordBatchStream, /// Information of index and left / right placement of columns column_indices: Vec, - /// Random state used for hashing initialization - random_state: RandomState, /// Keeps track of the left side rows whether they are visited visited_left_side: Vec, // TODO: use a more memory efficient data structure, https://github.com/apache/arrow-datafusion/issues/240 /// There is nothing to process anymore and left side is processed in case of left join @@ -529,7 +1022,6 @@ impl SortMergeJoinStream { left_data: JoinLeftData, right: SendableRecordBatchStream, column_indices: Vec, - random_state: RandomState, visited_left_side: Vec, join_metrics: SortMergeJoinMetrics, ) -> Self { @@ -541,7 +1033,6 @@ impl SortMergeJoinStream { left_data, right, column_indices, - random_state, visited_left_side, is_exhausted: false, join_metrics, @@ -595,11 +1086,9 @@ fn build_batch( join_type: JoinType, schema: &Schema, column_indices: &[ColumnIndex], - random_state: &RandomState, ) -> ArrowResult<(RecordBatch, UInt64Array)> { let (left_indices, right_indices) = - build_join_indexes(left_data, batch, join_type, on_left, on_right, random_state) - .unwrap(); + build_join_indexes(left_data, batch, join_type, on_left, on_right).unwrap(); if matches!(join_type, JoinType::Semi | JoinType::Anti) { return Ok(( @@ -651,13 +1140,12 @@ fn build_join_indexes( join_type: JoinType, left_on: &[Column], right_on: &[Column], - random_state: &RandomState, ) -> Result<(UInt64Array, UInt32Array)> { - let keys_values = right_on + let keys_values: Vec = right_on .iter() .map(|c| Ok(c.evaluate(right)?.into_array(right.num_rows()))) .collect::>>()?; - let left_join_values = left_on + let left_join_values: Vec = left_on .iter() .map(|c| Ok(c.evaluate(&left_data.1)?.into_array(left_data.1.num_rows()))) .collect::>>()?; @@ -679,7 +1167,7 @@ fn build_join_indexes( // This possibly contains rows with hash collisions, // So we have to check here whether rows are equal or not if let Some((_, 
indices)) = - left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) + left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) { for &i in indices { // Check hash collisions @@ -711,7 +1199,7 @@ fn build_join_indexes( // First visit all of the rows for (row, hash_value) in hash_values.iter().enumerate() { if let Some((_, indices)) = - left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) + left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) { for &i in indices { // Collision check @@ -890,7 +1378,6 @@ impl Stream for SortMergeJoinStream { self.join_type, &self.schema, &self.column_indices, - &self.random_state, ); self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); @@ -921,27 +1408,27 @@ impl Stream for SortMergeJoinStream { | JoinType::Full | JoinType::Semi | JoinType::Anti - if !self.is_exhausted => - { - let result = produce_from_matched( - &self.visited_left_side, - &self.schema, - &self.column_indices, - &self.left_data, - self.join_type != JoinType::Semi, - ); + if !self.is_exhausted => + { + let result = produce_from_matched( + &self.visited_left_side, + &self.schema, + &self.column_indices, + &self.left_data, + self.join_type != JoinType::Semi, + ); + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - if let Ok(ref batch) = result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); } - timer.done(); - self.is_exhausted = true; - return Some(result); } + timer.done(); + self.is_exhausted = true; + return Some(result); + } JoinType::Left | JoinType::Full | JoinType::Semi @@ -986,7 +1473,7 @@ mod tests { on: JoinOn, join_type: &JoinType, ) -> Result { - SortMergeJoinExec::try_new(left, right, on, join_type, PartitionMode::CollectLeft) + SortMergeJoinExec::try_new(left, right, on, join_type) } async fn join_collect( @@ -1033,7 +1520,6 @@ mod tests { )?), on, join_type, - PartitionMode::Partitioned, )?; let columns = columns(&join.schema()); @@ -1114,7 +1600,7 @@ mod tests { on.clone(), &JoinType::Inner, ) - .await?; + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1551,7 +2037,7 @@ mod tests { on.clone(), &JoinType::Left, ) - .await?; + .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); let expected = vec![ @@ -1756,47 +2242,4 @@ mod tests { Ok(()) } - - #[test] - fn join_with_hash_collision() -> Result<()> { - let mut hashmap_left = RawTable::with_capacity(2); - let left = build_table_i32( - ("a", &vec![10, 20]), - ("x", &vec![100, 200]), - ("y", &vec![200, 300]), - ); - - let random_state = RandomState::with_seeds(0, 0, 0, 0); - let hashes_buff = &mut vec![0; left.num_rows()]; - let hashes = - create_hashes(&[left.columns()[0].clone()], &random_state, hashes_buff)?; - - // Create hash collisions (same hashes) - hashmap_left.insert(hashes[0], (hashes[0], smallvec![0, 1]), |(h, _)| *h); - hashmap_left.insert(hashes[1], (hashes[1], smallvec![0, 1]), |(h, _)| *h); - - let right = build_table_i32( - ("a", &vec![10, 20]), - ("b", &vec![0, 0]), - ("c", &vec![30, 40]), - ); - - let left_data = JoinLeftData::new((JoinHashMap(hashmap_left), left)); - let (l, r) = build_join_indexes( - &left_data, - &right, - JoinType::Inner, - 
&[Column::new("a", 0)], - &[Column::new("a", 0)], - &random_state, - )?; - - let left_ids = UInt64Array::from_slice(&[0, 1]); - let right_ids = UInt32Array::from_slice(&[0, 1]); - - assert_eq!(left_ids, l); - assert_eq!(right_ids, r); - - Ok(()) - } } diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs index b11f41d0e163..5e4c33c39472 100644 --- a/datafusion/src/physical_plan/sorts/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -19,13 +19,14 @@ use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::error::{DataFusionError, Result}; -use crate::physical_plan::expressions::PhysicalSortExpr; +use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr}; use crate::physical_plan::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, }; use crate::physical_plan::{ common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, Statistics, }; +use arrow::compute::sort::SortColumn; pub use arrow::compute::sort::SortOptions; use arrow::compute::{sort::lexsort_to_indices, take}; use arrow::datatypes::SchemaRef; @@ -191,13 +192,7 @@ pub fn sort_batch( schema: SchemaRef, expr: &[PhysicalSortExpr], ) -> ArrowResult { - let columns = expr - .iter() - .map(|e| e.evaluate_to_sort_column(&batch)) - .collect::>>() - .map_err(DataFusionError::into_arrow_external_error)?; - let columns = columns.iter().map(|x| x.into()).collect::>(); - + let columns = exprs_to_sort_columns(&batch, expr)?; let indices = lexsort_to_indices::(&columns, None)?; // reorder all rows based on sorted indices From 3365684b909002ff0ce3352f70d633e59fb8df97 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Mon, 8 Nov 2021 19:09:29 +0800 Subject: [PATCH 03/15] wip --- .../physical_plan/joins/sort_merge_join.rs | 223 ++++++++++++++++-- 1 file changed, 206 insertions(+), 17 deletions(-) diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index 34bc014d472c..db8a136aeb79 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -19,6 +19,7 @@ //! into a set of partitions. use std::fmt; +use std::iter::repeat; use std::sync::Arc; use std::{any::Any, usize}; use std::{time::Instant, vec}; @@ -347,6 +348,29 @@ impl Stream for SMJStream { } } +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use datafusion::arrow::datatypes::PrimitiveType::*; + use datafusion::arrow::types::{days_ms, months_days_ns}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! 
{ f64 }, + } +})} + fn make_mutable(data_type: &DataType, capacity: usize) -> Result> { Ok(match data_type.to_physical_type() { PhysicalType::Boolean => Box::new(MutableBooleanArray::with_capacity(capacity)) @@ -383,6 +407,79 @@ fn make_mutable(data_type: &DataType, capacity: usize) -> Result, + batch_size: usize, +) -> Result>> { + let arrays: Vec> = schema + .fields() + .iter() + .map(|field| { + let dt = field.data_type.to_logical_type(); + make_mutable(dt, batch_size) + }) + .collect::>()?; + Ok(arrays) +} + +fn make_batch( + schema: Arc, + mut arrays: Vec>, +) -> ArrowResult { + let columns = arrays.iter_mut().map(|array| array.as_arc()).collect(); + RecordBatch::try_new(schema, columns) +} + +macro_rules! repeat_n { + ($TO:ty, $FROM:ty, $N:expr) => {{ + let to = to.as_mut_any().downcast_mut::<$TO>().unwrap(); + let from = from.as_any().downcast_ref::<$FROM>().unwrap(); + let repeat_iter = from.slice(idx, 1).iter().flat_map(|v| repeat(v).take($N)); + to.extend_trusted_len(repeat_iter); + }}; +} + +/// extend mutable with cell in `from` at `idx` of `rows_to_output` times. +fn repeat_n_times( + rows_to_output: usize, + to: &mut Box, + from: &Arc, + idx: usize, +) { + match to.data_type().to_physical_type() { + PhysicalType::Boolean => { + repeat_n!(MutableBooleanArray, BooleanArray, rows_to_output) + } + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int8 => repeat_n!(Int8Vec, Int8Array, rows_to_output), + PrimitiveType::Int16 => repeat_n!(Int16Vec, Int16Array, rows_to_output), + PrimitiveType::Int32 => repeat_n!(Int32Vec, Int32Array, rows_to_output), + PrimitiveType::Int64 => repeat_n!(Int64Vec, Int64Array, rows_to_output), + PrimitiveType::Float32 => repeat_n!(Float32Vec, Float32Array, rows_to_output), + PrimitiveType::Float64 => repeat_n!(Float64Vec, Float64Array, rows_to_output), + _ => todo!(), + }, + PhysicalType::Utf8 => { + repeat_n!(MutableUtf8Array, Utf8Array, rows_to_output) + } + PhysicalType::Binary => { + repeat_n!(MutableBinaryArray, BinaryArray, rows_to_output) + } + PhysicalType::FixedSizeBinary => repeat_n!( + MutableFixedSizeBinaryArray, + FixedSizeBinaryArray, + rows_to_output + ), + _ => todo!(), + } +} + +struct Pos { + batch_idx: usize, + start_idx: usize, + len: usize, +} + impl SMJStream { fn new( streamed: SendableRecordBatchStream, @@ -412,30 +509,122 @@ impl SMJStream { } async fn inner_join_driver(&mut self) -> Result<()> { - let targe_batch_size = self.runtime.batch_size(); + let target_batch_size = self.runtime.batch_size(); - if let Err(e) = self.result_sender.send().await { - println!("ERROR batch via inner join stream: {}", e); - }; + let mut batch_available = target_batch_size; + let mut arrays = new_arrays(&self.schema, target_batch_size)?; + + while self.find_next_inner_match()? 
{ + let stream_repeat = self.buffered_batches.row_num; + let stream_batch = self.stream_batch.batch.unwrap().batch; + + let buffered_batches = &self.buffered_batches.batches; + let buffered_ranges = &self.buffered_batches.ranges; + let mut idx = 0; + let mut start_indices: Vec = vec![]; + let ranges: Vec = buffered_ranges + .iter() + .map(|r| { + start_indices.push(idx); + let len = range_len(r); + idx += len; + }) + .collect(); + start_indices.push(usize::MAX); + + let mut buffer_unfinished = true; + let mut buffered_idx = 0; + let mut rows_to_output = 0; + + // until all buffered row output + while buffer_unfinished { + if batch_available >= stream_repeat { + buffer_unfinished = false; + rows_to_output = stream_repeat; + batch_available -= stream_repeat; + } else { + buffered_idx += stream_repeat - batch_available; + rows_to_output = batch_available; + batch_available = 0; + } + + // output buffered start `buffered_idx`, len `rows_to_output` + // (start_pos, len) + let mut slices: Vec = vec![]; + let mut rows_remaining = rows_to_output; + let find = start_indices + .iter() + .find_position(|&&start_idx| start_idx >= buffered_idx) + .unwrap(); + let mut batch_idx = if find.1 == buffered_idx { + find.0 + } else { + find.0 - 1 + }; + + while rows_remaining > 0 { + // TODO here - let mut arrays: Vec = - Vec::with_capacity(self.schema.fields().len()); - for (idx, field) in self.schema.fields().iter().enumerate() { - match field.data_type { - DataType::Int32 => { - arrays.push(MutablePrimitiveArray::::with_capacity( - targe_batch_size, - )); + slices.push(Pos { + batch_idx, + start_idx: x, + len: l, + }) } - DataType::Utf8 => { - arrays.push(MutableUtf8Array::with_capacity(targe_batch_size)); + + arrays + .iter_mut() + .zip(self.schema.fields().iter()) + .zip(self.column_indices.iter()) + .map(|((array, field), column_index)| { + if column_index.is_left { + // repeat streamed `rows_to_output` times + let from = stream_batch.column(column_index.index); + let from_row = self.stream_batch.cur_row; + repeat_n_times(rows_to_output, array, from, from_row); + } else { + // output buffered start `buffered_idx`, len `rows_to_output` + + match to.data_type().to_physical_type() { + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int8 => { + let to = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + for pos in slices { + let from = buffered_batches[pos.batch_idx] + .batch + .column(column_index.index) + .slice(pos.start_idx, pos.len); + let from = from + .as_any() + .downcast_ref::() + .unwrap(); + to.extend_trusted_len(from.iter()); + } + } + _ => todo!(), + }, + _ => todo!(), + } + } + }); + + if batch_available == 0 { + let result = make_batch(self.schema.clone(), arrays); + + if let Err(e) = self.result_sender.send(result).await { + println!("ERROR batch via inner join stream: {}", e); + }; + + arrays = new_arrays(&self.schema, target_batch_size)?; } - _ => {} + + rows_to_output = 0; } } - while self.find_next_inner_match()? 
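// (editorial note: each `Pos` pushed below names one contiguous slice of a buffered
// batch — batch index, start row, length. At this point in the series the concrete
// start/len computation is still a placeholder; the next commit factors it out into
// `slices_from_batches`.)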
{}
-
        Ok(())
    }
 
From a71d2a6c964d19d95b5a3333c4edc013ff29ce1e Mon Sep 17 00:00:00 2001
From: Yijie Shen
Date: Mon, 8 Nov 2021 22:13:05 +0800
Subject: [PATCH 04/15] wip

---
 .../physical_plan/joins/sort_merge_join.rs    | 130 +++++++++++-------
 1 file changed, 84 insertions(+), 46 deletions(-)

diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs
index db8a136aeb79..211cf1e09157 100644
--- a/datafusion/src/physical_plan/joins/sort_merge_join.rs
+++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs
@@ -474,6 +474,68 @@ fn repeat_n_times(
     }
 }
 
+fn range_start_indices(buffered_ranges: &VecDeque>) -> Vec {
+    let mut idx = 0;
+    let mut start_indices: Vec = vec![];
+    buffered_ranges
+        .iter()
+        .for_each(|r| {
+            start_indices.push(idx);
+            idx += range_len(r);
+        });
+    start_indices.push(usize::MAX);
+    start_indices
+}
+
+/// output buffered start `buffered_idx`, len `rows_to_output`
+/// (start_pos, len)
+fn slices_from_batches(
+    buffered_ranges: &&VecDeque>,
+    start_indices: &Vec,
+    buffered_idx: usize,
+    rows_to_output: usize,
+) -> Vec {
+    let mut buffered_idx = buffered_idx;
+    let mut slices: Vec = vec![];
+    let mut rows_remaining = rows_to_output;
+    let find = start_indices
+        .iter()
+        .find_position(|&&start_idx| start_idx >= buffered_idx)
+        .unwrap();
+    let mut batch_idx = if find.1 == buffered_idx {
+        find.0
+    } else {
+        find.0 - 1
+    };
+
+    while rows_remaining > 0 {
+        let current_range = &buffered_ranges[batch_idx];
+        let range_start_idx = start_indices[batch_idx];
+        let start_row_idx = buffered_idx - range_start_idx + current_range.start;
+        let range_available = range_len(current_range) - (buffered_idx - range_start_idx);
+
+        if range_available >= rows_remaining {
+            slices.push(Pos {
+                batch_idx,
+                start_idx: start_row_idx,
+                len: rows_remaining,
+            });
+            rows_remaining = 0;
+        } else {
+            slices.push(Pos {
+                batch_idx,
+                start_idx: start_row_idx,
+                len: range_available,
+            });
+            rows_remaining -= range_available;
+            batch_idx += 1;
+            buffered_idx += range_available;
+        }
+    }
+    slices
+}
+
 struct Pos {
     batch_idx: usize,
     start_idx: usize,
     len: usize,
@@ -511,66 +573,43 @@ impl SMJStream {
     async fn inner_join_driver(&mut self) -> Result<()> {
         let target_batch_size = self.runtime.batch_size();
 
-        let mut batch_available = target_batch_size;
+        let mut output_slots_available = target_batch_size;
         let mut arrays = new_arrays(&self.schema, target_batch_size)?;
 
         while self.find_next_inner_match()? 
{
-            let stream_repeat = self.buffered_batches.row_num;
+            let output_total = self.buffered_batches.row_num;
             let stream_batch = self.stream_batch.batch.unwrap().batch;
 
             let buffered_batches = &self.buffered_batches.batches;
             let buffered_ranges = &self.buffered_batches.ranges;
-            let mut idx = 0;
-            let mut start_indices: Vec = vec![];
-            let ranges: Vec = buffered_ranges
-                .iter()
-                .map(|r| {
-                    start_indices.push(idx);
-                    let len = range_len(r);
-                    idx += len;
-                })
-                .collect();
-            start_indices.push(usize::MAX);
-
-            let mut buffer_unfinished = true;
+
+            let start_indices = range_start_indices(buffered_ranges);
+
+            let mut unfinished = true;
             let mut buffered_idx = 0;
             let mut rows_to_output = 0;
 
             // until all buffered row output
-            while buffer_unfinished {
-                if batch_available >= stream_repeat {
-                    buffer_unfinished = false;
-                    rows_to_output = stream_repeat;
-                    batch_available -= stream_repeat;
+            while unfinished {
+                if output_slots_available >= output_total {
+                    unfinished = false;
+                    rows_to_output = output_total;
+                    output_slots_available -= output_total;
                 } else {
-                    buffered_idx += stream_repeat - batch_available;
-                    rows_to_output = batch_available;
-                    batch_available = 0;
+                    buffered_idx += output_total - output_slots_available;
+                    rows_to_output = output_slots_available;
+                    output_slots_available = 0;
                 }
 
+                // get slices for each buffered row batch for the current output
                 // output buffered start `buffered_idx`, len `rows_to_output`
                 // (start_pos, len)
-                let mut slices: Vec = vec![];
-                let mut rows_remaining = rows_to_output;
-                let find = start_indices
-                    .iter()
-                    .find_position(|&&start_idx| start_idx >= buffered_idx)
-                    .unwrap();
-                let mut batch_idx = if find.1 == buffered_idx {
-                    find.0
-                } else {
-                    find.0 - 1
-                };
-
-                while rows_remaining > 0 {
-                    // TODO here
-
-                    slices.push(Pos {
-                        batch_idx,
-                        start_idx: x,
-                        len: l,
-                    })
-                }
+                let slices = slices_from_batches(
+                    &buffered_ranges,
+                    &start_indices,
+                    buffered_idx,
+                    rows_to_output,
+                );
 
                 arrays
                     .iter_mut()
                     .zip(self.schema.fields().iter())
                     .zip(self.column_indices.iter())
                     .map(|((array, field), column_index)| {
                         if column_index.is_left {
                             // repeat streamed `rows_to_output` times
                             let from = stream_batch.column(column_index.index);
                             let from_row = self.stream_batch.cur_row;
                             repeat_n_times(rows_to_output, array, from, from_row);
                         } else {
                             // output buffered start `buffered_idx`, len `rows_to_output`
-
                             match to.data_type().to_physical_type() {
                                 PhysicalType::Primitive(primitive) => match primitive {
                                     PrimitiveType::Int8 => {
                                         let to = array
                                             .as_mut_any()
                                             .downcast_mut::()
                                             .unwrap();
                                         for pos in slices {
                                             let from = buffered_batches[pos.batch_idx]
                                                 .batch
                                                 .column(column_index.index)
                                                 .slice(pos.start_idx, pos.len);
                                             let from = from
                                                 .as_any()
                                                 .downcast_ref::()
                                                 .unwrap();
                                             to.extend_trusted_len(from.iter());
                                         }
                                     }
                                     _ => todo!(),
                                 },
                                 _ => todo!(),
                             }
                         }
                     });
 
-                if batch_available == 0 {
+                if output_slots_available == 0 {
                     let result = make_batch(self.schema.clone(), arrays);
 
                     if let Err(e) = self.result_sender.send(result).await {
                         println!("ERROR batch via inner join stream: {}", e);
                     };
 
                     arrays = new_arrays(&self.schema, target_batch_size)?;
                 }
 
                 rows_to_output = 0;
             }
         }
 
         Ok(())
     }
 
From 02de001dff5a414dc6d5153588c6aed5cad91d12 Mon Sep 17 00:00:00 2001
From: Yijie Shen
Date: Tue, 9 Nov 2021 12:42:16 +0800
Subject: [PATCH 05/15] inner join driver

---
 .../physical_plan/joins/sort_merge_join.rs    | 564 ++++++------------
 1 file changed, 186 insertions(+), 378 deletions(-)

diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs
index 211cf1e09157..289fe13b86ee 100644
--- a/datafusion/src/physical_plan/joins/sort_merge_join.rs
+++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs
@@ -164,6 +164,8 @@ impl StreamingBatch {
             self.cur_range += 1;
             self.is_new_key = true;
         }
+    } else {
+        self.batch = None;
     }
 }
 
@@ -345,6 +347,7 @@ impl Stream for SMJStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll> {
+        todo!()
     }
 }
 
@@ -439,37 +442,84 @@ macro_rules! repeat_n {
     }};
 }
 
-/// extend mutable with cell in `from` at `idx` of `rows_to_output` times.
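The hunk below swaps `repeat_n_times` for `repeat_streamed_cell`, but both expand `repeat_n!` around the same small operation: emit the current streamed cell once for every matching buffered row. A dependency-free sketch of that operation follows; the `MutableColumn` alias and `repeat_cell` helper are illustrative stand-ins, not part of the patch.

use std::iter::repeat;

// Simplified stand-in for an Arrow mutable array: a growable column of
// optional (nullable) values.
type MutableColumn<T> = Vec<Option<T>>;

// What `repeat_n!` expands to, minus the downcasts: append the single cell
// `from[idx]` to the output builder `times` times.
fn repeat_cell<T: Clone>(
    from: &[Option<T>],
    idx: usize,
    times: usize,
    to: &mut MutableColumn<T>,
) {
    to.extend(repeat(from[idx].clone()).take(times));
}

fn main() {
    let streamed = vec![Some(7), None, Some(9)];
    let mut out: MutableColumn<i32> = Vec::new();
    // One streamed cell is emitted once per matching buffered row.
    repeat_cell(&streamed, 0, 3, &mut out);
    assert_eq!(out, vec![Some(7), Some(7), Some(7)]);
}

In the real code, arrow2's `extend_trusted_len` consumes the repeat iterator with its length known up front, which lets the builder skip per-element capacity checks.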
-fn repeat_n_times( - rows_to_output: usize, - to: &mut Box, - from: &Arc, +/// repeat times of cell located by `idx` at streamed side to output +fn repeat_streamed_cell( + stream_batch: &RecordBatch, idx: usize, + times: usize, + to: &mut Box, + column_index: &ColumnIndex, ) { + let from = stream_batch.column(column_index.index); match to.data_type().to_physical_type() { PhysicalType::Boolean => { - repeat_n!(MutableBooleanArray, BooleanArray, rows_to_output) + repeat_n!(MutableBooleanArray, BooleanArray, times) } PhysicalType::Primitive(primitive) => match primitive { - PrimitiveType::Int8 => repeat_n!(Int8Vec, Int8Array, rows_to_output), - PrimitiveType::Int16 => repeat_n!(Int16Vec, Int16Array, rows_to_output), - PrimitiveType::Int32 => repeat_n!(Int32Vec, Int32Array, rows_to_output), - PrimitiveType::Int64 => repeat_n!(Int64Vec, Int64Array, rows_to_output), - PrimitiveType::Float32 => repeat_n!(Float32Vec, Float32Array, rows_to_output), - PrimitiveType::Float64 => repeat_n!(Float64Vec, Float64Array, rows_to_output), + PrimitiveType::Int8 => repeat_n!(Int8Vec, Int8Array, times), + PrimitiveType::Int16 => repeat_n!(Int16Vec, Int16Array, times), + PrimitiveType::Int32 => repeat_n!(Int32Vec, Int32Array, times), + PrimitiveType::Int64 => repeat_n!(Int64Vec, Int64Array, times), + PrimitiveType::Float32 => repeat_n!(Float32Vec, Float32Array, times), + PrimitiveType::Float64 => repeat_n!(Float64Vec, Float64Array, times), _ => todo!(), }, PhysicalType::Utf8 => { - repeat_n!(MutableUtf8Array, Utf8Array, rows_to_output) + repeat_n!(MutableUtf8Array, Utf8Array, times) } PhysicalType::Binary => { - repeat_n!(MutableBinaryArray, BinaryArray, rows_to_output) + repeat_n!(MutableBinaryArray, BinaryArray, times) + } + PhysicalType::FixedSizeBinary => { + repeat_n!(MutableFixedSizeBinaryArray, FixedSizeBinaryArray, times) + } + _ => todo!(), + } +} + +macro_rules! 
copy_slices {
+    ($TO:ty, $FROM:ty) => {{
+        let to = array.as_mut_any().downcast_mut::<$TO>().unwrap();
+        for pos in slices {
+            let from = buffered_batches[pos.batch_idx]
+                .batch
+                .column(column_index.index)
+                .slice(pos.start_idx, pos.len);
+            let from = from.as_any().downcast_ref::<$FROM>().unwrap();
+            to.extend_trusted_len(from.iter());
+        }
+    }};
+}
+
+fn copy_buffered_slices(
+    buffered_batches: &VecDeque,
+    slices: &Vec,
+    array: &mut Box,
+    column_index: &ColumnIndex,
+) {
+    // copy the buffered rows designated by `slices` into the output array
+    match array.data_type().to_physical_type() {
+        PhysicalType::Boolean => {
+            copy_slices!(MutableBooleanArray, BooleanArray)
+        }
+        PhysicalType::Primitive(primitive) => match primitive {
+            PrimitiveType::Int8 => copy_slices!(Int8Vec, Int8Array),
+            PrimitiveType::Int16 => copy_slices!(Int16Vec, Int16Array),
+            PrimitiveType::Int32 => copy_slices!(Int32Vec, Int32Array),
+            PrimitiveType::Int64 => copy_slices!(Int64Vec, Int64Array),
+            PrimitiveType::Float32 => copy_slices!(Float32Vec, Float32Array),
+            PrimitiveType::Float64 => copy_slices!(Float64Vec, Float64Array),
+            _ => todo!(),
+        },
+        PhysicalType::Utf8 => {
+            copy_slices!(MutableUtf8Array, Utf8Array)
+        }
+        PhysicalType::Binary => {
+            copy_slices!(MutableBinaryArray, BinaryArray)
+        }
+        PhysicalType::FixedSizeBinary => {
+            copy_slices!(MutableFixedSizeBinaryArray, FixedSizeBinaryArray)
+        }
-        PhysicalType::FixedSizeBinary => repeat_n!(
-            MutableFixedSizeBinaryArray,
-            FixedSizeBinaryArray,
-            rows_to_output
-        ),
         _ => todo!(),
     }
 }
@@ -488,55 +538,52 @@ fn range_start_indices(buffered_ranges: &VecDeque>) -> Vec {
     start_indices
 }
 
-/// output buffered start `buffered_idx`, len `rows_to_output`
-/// (start_pos, len)
+/// Locate the buffered records that start at `buffered_idx` and span `len` rows
+/// inside the buffered batches. 
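As a plain-`Vec` illustration of the lookup implemented below: resolve a request for `len` rows at a global buffered offset into per-batch slices, using the prefix sums that `range_start_indices` produces. All names in this sketch are illustrative, not the executor's API.

use std::ops::Range;

// A slice request resolved against one batch:
// which batch, at what row offset, for how many rows.
#[derive(Debug, PartialEq)]
struct Slice {
    batch_idx: usize,
    start_idx: usize,
    len: usize,
}

// Map `len` rows starting at global offset `from` onto the per-batch
// ranges they physically live in.
fn locate(ranges: &[Range<usize>], from: usize, len: usize) -> Vec<Slice> {
    // prefix sums: the global offset at which each batch's range begins
    let mut starts = vec![0usize];
    for r in ranges {
        let last = *starts.last().unwrap();
        starts.push(last + (r.end - r.start));
    }
    let mut out = Vec::new();
    let (mut idx, mut remaining) = (from, len);
    // last batch whose prefix start is <= idx owns the first row
    let mut batch_idx = starts.iter().rposition(|&s| s <= idx).unwrap();
    while remaining > 0 {
        let within = idx - starts[batch_idx];
        let avail = (ranges[batch_idx].end - ranges[batch_idx].start) - within;
        let take = avail.min(remaining);
        out.push(Slice {
            batch_idx,
            start_idx: ranges[batch_idx].start + within,
            len: take,
        });
        remaining -= take;
        idx += take;
        batch_idx += 1;
    }
    out
}

fn main() {
    // two buffered batches contribute rows 2..5 and 0..4 (3 + 4 = 7 rows)
    let ranges = vec![2..5, 0..4];
    // 4 rows starting at global offset 2: one from batch 0, three from batch 1
    assert_eq!(
        locate(&ranges, 2, 4),
        vec![
            Slice { batch_idx: 0, start_idx: 4, len: 1 },
            Slice { batch_idx: 1, start_idx: 0, len: 3 },
        ]
    );
}

The `usize::MAX` sentinel that `range_start_indices` appends plays the role of the bounds handling here: the search for the owning batch can never run off the end of the prefix-sum table.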
fn slices_from_batches( - buffered_ranges: &&VecDeque>, + buffered_ranges: &VecDeque>, start_indices: &Vec, buffered_idx: usize, - rows_to_output: usize, -) -> Vec { - let mut buffered_idx = buffered_idx; - let mut slices: Vec = vec![]; - let mut rows_remaining = rows_to_output; + len: usize, +) -> Vec { + let mut idx = buffered_idx; + let mut slices: Vec = vec![]; + let mut remaining = len; let find = start_indices .iter() - .find_position(|&&start_idx| start_idx >= buffered_idx) + .find_position(|&&start_idx| start_idx >= idx) .unwrap(); - let mut batch_idx = if find.1 == buffered_idx { - find.0 - } else { - find.0 - 1 - }; + let mut batch_idx = if find.1 == idx { find.0 } else { find.0 - 1 }; - while rows_remaining > 0 { + while remaining > 0 { let current_range = &buffered_ranges[batch_idx]; let range_start_idx = start_indices[batch_idx]; - let start_row_idx = buffered_idx - range_start_idx + current_range.start; - let range_available = range_len(current_range) - (buffered_idx - range_start_idx); + let start_idx = idx - range_start_idx + current_range.start; + let range_available = range_len(current_range) - (idx - range_start_idx); - if range_available >= rows_remaining { - slices.push(Pos { + if range_available >= remaining { + slices.push(Slice { batch_idx, - start_idx: start_row_idx, - len: rows_remaining, + start_idx, + len: remaining, }); - rows_remaining = 0; + remaining = 0; } else { - slices.push(Pos { + slices.push(Slice { batch_idx, - start_idx: start_row_idx, + start_idx, len: range_available, }); - rows_remaining -= range_available; + remaining -= range_available; batch_idx += 1; - buffered_idx += range_available; + idx += range_available; } } slices } -struct Pos { +/// Slice of batch at `batch_idx` inside BufferedBatches. +struct Slice { batch_idx: usize, start_idx: usize, len: usize, @@ -574,96 +621,110 @@ impl SMJStream { let target_batch_size = self.runtime.batch_size(); let mut output_slots_available = target_batch_size; - let mut arrays = new_arrays(&self.schema, target_batch_size)?; + let mut output_arrays = new_arrays(&self.schema, target_batch_size)?; while self.find_next_inner_match()? 
{ - let output_total = self.buffered_batches.row_num; - let stream_batch = self.stream_batch.batch.unwrap().batch; - - let buffered_batches = &self.buffered_batches.batches; - let buffered_ranges = &self.buffered_batches.ranges; - - let start_indices = range_start_indices(buffered_ranges); + loop { + let result = self + .join_eq_records( + target_batch_size, + output_slots_available, + output_arrays, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + + self.stream_batch.advance(); + if self.stream_batch.is_new_key { + break; + } + } + } - let mut unfinished = true; - let mut buffered_idx = 0; - let mut rows_to_output = 0; + Ok(()) + } - // until all buffered row output - while unfinished { - if output_slots_available >= output_total { - unfinished = false; - rows_to_output = output_total; - output_slots_available -= output_total; - } else { - ->>>>>>>>>>> buffered_idx += output_total - output_slots_available; - rows_to_output = output_slots_available; - output_slots_available = 0; - } + async fn join_eq_records( + &mut self, + target_batch_size: usize, + output_slots_available: usize, + mut output_arrays: Vec>, + ) -> Result<(usize, Vec>)> { + let mut output_slots_available = output_slots_available; + let mut remaining = self.buffered_batches.row_num; + let stream_batch = &self.stream_batch.batch.unwrap().batch; + let stream_row = self.stream_batch.cur_row; + + let buffered_batches = &self.buffered_batches.batches; + let buffered_ranges = &self.buffered_batches.ranges; + + let mut unfinished = true; + let mut buffered_idx = 0; + let mut rows_to_output = 0; + let start_indices = range_start_indices(buffered_ranges); + + // output each buffered matching record once + while unfinished { + if output_slots_available >= remaining { + unfinished = false; + rows_to_output = remaining; + output_slots_available -= remaining; + remaining = 0; + } else { + rows_to_output = output_slots_available; + output_slots_available = 0; + remaining -= rows_to_output; + } - // get slices for each buffered row batch for the current output - // output buffered start `buffered_idx`, len `rows_to_output` - // (start_pos, len) - let slices = slices_from_batches( - &buffered_ranges, - &start_indices, - buffered_idx, - rows_to_output, - ); - - arrays - .iter_mut() - .zip(self.schema.fields().iter()) - .zip(self.column_indices.iter()) - .map(|((array, field), column_index)| { - if column_index.is_left { - // repeat streamed `rows_to_output` times - let from = stream_batch.column(column_index.index); - let from_row = self.stream_batch.cur_row; - repeat_n_times(rows_to_output, array, from, from_row); - } else { - // output buffered start `buffered_idx`, len `rows_to_output` - match to.data_type().to_physical_type() { - PhysicalType::Primitive(primitive) => match primitive { - PrimitiveType::Int8 => { - let to = array - .as_mut_any() - .downcast_mut::() - .unwrap(); - for pos in slices { - let from = buffered_batches[pos.batch_idx] - .batch - .column(column_index.index) - .slice(pos.start_idx, pos.len); - let from = from - .as_any() - .downcast_ref::() - .unwrap(); - to.extend_trusted_len(from.iter()); - } - } - _ => todo!(), - }, - _ => todo!(), - } - } - }); + // get slices for buffered side for the current output + let slices = slices_from_batches( + buffered_ranges, + &start_indices, + buffered_idx, + rows_to_output, + ); - if output_slots_available == 0 { - let result = make_batch(self.schema.clone(), arrays); + output_arrays + .iter_mut() + .zip(self.schema.fields().iter()) + 
.zip(self.column_indices.iter()) + .map(|((array, field), column_index)| { + if column_index.is_left { + // repeat streamed `rows_to_output` times + repeat_streamed_cell( + stream_batch, + stream_row, + rows_to_output, + array, + column_index, + ); + } else { + // copy buffered start from: `buffered_idx`, len: `rows_to_output` + copy_buffered_slices( + buffered_batches, + &slices, + array, + column_index, + ) + } + }); - if let Err(e) = self.result_sender.send(result).await { - println!("ERROR batch via inner join stream: {}", e); - }; + if output_slots_available == 0 { + let result = make_batch(self.schema.clone(), output_arrays); - arrays = new_arrays(&self.schema, target_batch_size)?; - } + if let Err(e) = self.result_sender.send(result).await { + println!("ERROR batch via inner join stream: {}", e); + }; - rows_to_output = 0; + output_arrays = new_arrays(&self.schema, target_batch_size)?; + output_slots_available = target_batch_size; } - } - Ok(()) + buffered_idx += rows_to_output; + rows_to_output = 0; + } + Ok((output_slots_available, output_arrays)) } fn find_next_inner_match(&mut self) -> Result { @@ -1178,43 +1239,6 @@ impl ExecutionPlan for SortMergeJoinExec { } } -/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`, -/// assuming that the [RecordBatch] corresponds to the `index`th -fn update_hash( - on: &[Column], - batch: &RecordBatch, - hash_map: &mut JoinHashMap, - offset: usize, - random_state: &RandomState, - hashes_buffer: &mut Vec, -) -> Result<()> { - // evaluate the keys - let keys_values = on - .iter() - .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows()))) - .collect::>>()?; - - // calculate the hash values - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - - // insert hashes to key of the hashmap - for (row, hash_value) in hash_values.iter().enumerate() { - let item = hash_map - .0 - .get_mut(*hash_value, |(hash, _)| *hash_value == *hash); - if let Some((_, indices)) = item { - indices.push((row + offset) as u64); - } else { - hash_map.0.insert( - *hash_value, - (*hash_value, smallvec![(row + offset) as u64]), - |(hash, _)| *hash, - ); - } - } - Ok(()) -} - /// A stream that issues [RecordBatch]es as they arrive from the right of the join. struct SortMergeJoinStream { /// Input schema @@ -1273,222 +1297,6 @@ impl RecordBatchStream for SortMergeJoinStream { } } -/// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. -/// The resulting batch has [Schema] `schema`. -/// # Error -/// This function errors when: -/// * -fn build_batch_from_indices( - schema: &Schema, - left: &RecordBatch, - right: &RecordBatch, - left_indices: UInt64Array, - right_indices: UInt32Array, - column_indices: &[ColumnIndex], -) -> ArrowResult<(RecordBatch, UInt64Array)> { - // build the columns of the new [RecordBatch]: - // 1. pick whether the column is from the left or right - // 2. 
based on the pick, `take` items from the different RecordBatches - let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); - - for column_index in column_indices { - let array = if column_index.is_left { - let array = left.column(column_index.index); - take::take(array.as_ref(), &left_indices)?.into() - } else { - let array = right.column(column_index.index); - take::take(array.as_ref(), &right_indices)?.into() - }; - columns.push(array); - } - RecordBatch::try_new(Arc::new(schema.clone()), columns).map(|x| (x, left_indices)) -} - -#[allow(clippy::too_many_arguments)] -fn build_batch( - batch: &RecordBatch, - left_data: &JoinLeftData, - on_left: &[Column], - on_right: &[Column], - join_type: JoinType, - schema: &Schema, - column_indices: &[ColumnIndex], -) -> ArrowResult<(RecordBatch, UInt64Array)> { - let (left_indices, right_indices) = - build_join_indexes(left_data, batch, join_type, on_left, on_right).unwrap(); - - if matches!(join_type, JoinType::Semi | JoinType::Anti) { - return Ok(( - RecordBatch::new_empty(Arc::new(schema.clone())), - left_indices, - )); - } - - build_batch_from_indices( - schema, - &left_data.1, - batch, - left_indices, - right_indices, - column_indices, - ) -} - -/// returns a vector with (index from left, index from right). -/// The size of this vector corresponds to the total size of a joined batch -// For a join on column A: -// left right -// batch 1 -// A B A D -// --------------- -// 1 a 3 6 -// 2 b 1 2 -// 3 c 2 4 -// batch 2 -// A B A D -// --------------- -// 1 a 5 10 -// 2 b 2 2 -// 4 d 1 1 -// indices (batch, batch_row) -// left right -// (0, 2) (0, 0) -// (0, 0) (0, 1) -// (0, 1) (0, 2) -// (1, 0) (0, 1) -// (1, 1) (0, 2) -// (0, 1) (1, 1) -// (0, 0) (1, 2) -// (1, 1) (1, 1) -// (1, 0) (1, 2) -fn build_join_indexes( - left_data: &JoinLeftData, - right: &RecordBatch, - join_type: JoinType, - left_on: &[Column], - right_on: &[Column], -) -> Result<(UInt64Array, UInt32Array)> { - let keys_values: Vec = right_on - .iter() - .map(|c| Ok(c.evaluate(right)?.into_array(right.num_rows()))) - .collect::>>()?; - let left_join_values: Vec = left_on - .iter() - .map(|c| Ok(c.evaluate(&left_data.1)?.into_array(left_data.1.num_rows()))) - .collect::>>()?; - let hashes_buffer = &mut vec![0; keys_values[0].len()]; - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - let left = &left_data.0; - - match join_type { - JoinType::Inner | JoinType::Semi | JoinType::Anti => { - // Using a buffer builder to avoid slower normal builder - let mut left_indices = MutableBuffer::::new(); - let mut right_indices = MutableBuffer::::new(); - - // Visit all of the right rows - for (row, hash_value) in hash_values.iter().enumerate() { - // Get the hash and find it in the build index - - // For every item on the left and right we check if it matches - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - if let Some((_, indices)) = - left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - for &i in indices { - // Check hash collisions - if equal_rows(i as usize, row, &left_join_values, &keys_values)? 
{ - left_indices.push(i as u64); - right_indices.push(row as u32); - } - } - } - } - - Ok(( - PrimitiveArray::::from_data( - DataType::UInt64, - left_indices.into(), - None, - ), - PrimitiveArray::::from_data( - DataType::UInt32, - right_indices.into(), - None, - ), - )) - } - JoinType::Left => { - let mut left_indices = MutableBuffer::::new(); - let mut right_indices = MutableBuffer::::new(); - - // First visit all of the rows - for (row, hash_value) in hash_values.iter().enumerate() { - if let Some((_, indices)) = - left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - for &i in indices { - // Collision check - if equal_rows(i as usize, row, &left_join_values, &keys_values)? { - left_indices.push(i as u64); - right_indices.push(row as u32); - } - } - }; - } - Ok(( - PrimitiveArray::::from_data( - DataType::UInt64, - left_indices.into(), - None, - ), - PrimitiveArray::::from_data( - DataType::UInt32, - right_indices.into(), - None, - ), - )) - } - JoinType::Right | JoinType::Full => { - let mut left_indices = MutablePrimitiveArray::::new(); - let mut right_indices = MutablePrimitiveArray::::new(); - - for (row, hash_value) in hash_values.iter().enumerate() { - match left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) { - Some((_, indices)) => { - let mut no_match = true; - for &i in indices { - if equal_rows( - i as usize, - row, - &left_join_values, - &keys_values, - )? { - left_indices.push(Some(i as u64)); - right_indices.push(Some(row as u32)); - no_match = false; - } - } - // If no rows matched left, still must keep the right - // with all nulls for left - if no_match { - left_indices.push(None); - right_indices.push(Some(row as u32)); - } - } - None => { - // when no match, add the row with None for the left side - left_indices.push(None); - right_indices.push(Some(row as u32)); - } - } - } - Ok((left_indices.into(), right_indices.into())) - } - } -} - macro_rules! 
equal_rows_elem { ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); From 7e607ef88cd009f047259cb1eb17988e004c717c Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 9 Nov 2021 17:21:52 +0800 Subject: [PATCH 06/15] extract common --- .../src/physical_plan/joins/hash_join.rs | 41 +---- datafusion/src/physical_plan/joins/mod.rs | 38 +++++ .../physical_plan/joins/sort_merge_join.rs | 160 +----------------- 3 files changed, 54 insertions(+), 185 deletions(-) diff --git a/datafusion/src/physical_plan/joins/hash_join.rs b/datafusion/src/physical_plan/joins/hash_join.rs index 5aeb4db92068..6578af25f8fe 100644 --- a/datafusion/src/physical_plan/joins/hash_join.rs +++ b/datafusion/src/physical_plan/joins/hash_join.rs @@ -40,7 +40,7 @@ use crate::error::{DataFusionError, Result}; use crate::logical_plan::JoinType; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use crate::physical_plan::joins::{build_join_schema, check_join_is_valid, JoinOn}; +use crate::physical_plan::joins::{build_join_schema, check_join_is_valid, JoinOn, column_indices_from_schema}; use crate::physical_plan::PhysicalExpr; use crate::physical_plan::{ expressions::Column, @@ -213,38 +213,6 @@ impl HashJoinExec { pub fn partition_mode(&self) -> &PartitionMode { &self.mode } - - /// Calculates column indices and left/right placement on input / output schemas and jointype - fn column_indices_from_schema(&self) -> ArrowResult> { - let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { - JoinType::Inner - | JoinType::Left - | JoinType::Full - | JoinType::Semi - | JoinType::Anti => (true, self.left.schema(), self.right.schema()), - JoinType::Right => (false, self.right.schema(), self.left.schema()), - }; - let mut column_indices = Vec::with_capacity(self.schema.fields().len()); - for field in self.schema.fields() { - let (is_primary, index) = match primary_schema.index_of(field.name()) { - Ok(i) => Ok((true, i)), - Err(_) => { - match secondary_schema.index_of(field.name()) { - Ok(i) => Ok((false, i)), - _ => Err(DataFusionError::Internal( - format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string() - )) - } - } - }.map_err(DataFusionError::into_arrow_external_error)?; - - let is_left = - is_primary && primary_is_left || !is_primary && !primary_is_left; - column_indices.push(ColumnIndex { index, is_left }); - } - - Ok(column_indices) - } } #[async_trait] @@ -405,7 +373,12 @@ impl ExecutionPlan for HashJoinExec { let right_stream = self.right.execute(partition).await?; let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); - let column_indices = self.column_indices_from_schema()?; + let column_indices = column_indices_from_schema( + &self.join_type, + &self.left.schema(), + &self.right.schema(), + &self.schema, + )?; let num_rows = left_data.1.num_rows(); let visited_left_side = match self.join_type { JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => { diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs index cfd0a640e5a5..79ab69240c96 100644 --- a/datafusion/src/physical_plan/joins/mod.rs +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -24,6 +24,7 @@ use crate::logical_plan::JoinType; use crate::physical_plan::expressions::Column; use arrow::datatypes::{Field, Schema}; use 
std::collections::HashSet; +use std::sync::Arc; /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; @@ -102,6 +103,43 @@ pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema::new(fields) } +/// Calculates column indices and left/right placement on input / output schemas and jointype +pub fn column_indices_from_schema( + join_type: &JoinType, + left_schema: &Arc, + right_schema: &Arc, + schema: &Arc, +) -> ArrowResult> { + let (primary_is_left, primary_schema, secondary_schema) = match join_type { + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::Semi + | JoinType::Anti => (true, left_schema, right_schema), + JoinType::Right => (false, right_schema, left_schema), + }; + let mut column_indices = Vec::with_capacity(schema.fields().len()); + for field in schema.fields() { + let (is_primary, index) = match primary_schema.index_of(field.name()) { + Ok(i) => Ok((true, i)), + Err(_) => { + match secondary_schema.index_of(field.name()) { + Ok(i) => Ok((false, i)), + _ => Err(DataFusionError::Internal( + format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string() + )) + } + } + }.map_err(DataFusionError::into_arrow_external_error)?; + + let is_left = + is_primary && primary_is_left || !is_primary && !primary_is_left; + column_indices.push(ColumnIndex { index, is_left }); + } + + Ok(column_indices) +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index 289fe13b86ee..3bdb47d775b6 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -41,7 +41,7 @@ use crate::logical_plan::JoinType; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr}; -use crate::physical_plan::joins::{build_join_schema, check_join_is_valid, JoinOn}; +use crate::physical_plan::joins::{build_join_schema, check_join_is_valid, JoinOn, column_indices_from_schema}; use crate::physical_plan::stream::RecordBatchReceiverStream; use crate::physical_plan::PhysicalExpr; use crate::physical_plan::{ @@ -998,38 +998,6 @@ impl SortMergeJoinExec { pub fn join_type(&self) -> &JoinType { &self.join_type } - - /// Calculates column indices and left/right placement on input / output schemas and jointype - fn column_indices_from_schema(&self) -> ArrowResult> { - let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { - JoinType::Inner - | JoinType::Left - | JoinType::Full - | JoinType::Semi - | JoinType::Anti => (true, self.left.schema(), self.right.schema()), - JoinType::Right => (false, self.right.schema(), self.left.schema()), - }; - let mut column_indices = Vec::with_capacity(self.schema.fields().len()); - for field in self.schema.fields() { - let (is_primary, index) = match primary_schema.index_of(field.name()) { - Ok(i) => Ok((true, i)), - Err(_) => { - match secondary_schema.index_of(field.name()) { - Ok(i) => Ok((false, i)), - _ => Err(DataFusionError::Internal( - format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string() - )) - } - } - }.map_err(DataFusionError::into_arrow_external_error)?; - - let is_left = - is_primary && 
primary_is_left || !is_primary && !primary_is_left; - column_indices.push(ColumnIndex { index, is_left }); - } - - Ok(column_indices) - } } #[async_trait] @@ -1069,127 +1037,17 @@ impl ExecutionPlan for SortMergeJoinExec { async fn execute(&self, partition: usize) -> Result { let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); - // we only want to compute the build side once for PartitionMode::CollectLeft - let left_data = { - match self.mode { - PartitionMode::CollectLeft => { - let mut build_side = self.build_side.lock().await; - - match build_side.as_ref() { - Some(stream) => stream.clone(), - None => { - let start = Instant::now(); - - // merge all left parts into a single stream - let merge = CoalescePartitionsExec::new(self.left.clone()); - let stream = merge.execute(0).await?; - - // This operation performs 2 steps at once: - // 1. creates a [JoinHashMap] of all batches from the stream - // 2. stores the batches in a vector. - let initial = (0, Vec::new()); - let (num_rows, batches) = stream - .try_fold(initial, |mut acc, batch| async { - acc.0 += batch.num_rows(); - acc.1.push(batch); - Ok(acc) - }) - .await?; - let mut hashmap = - JoinHashMap(RawTable::with_capacity(num_rows)); - let mut hashes_buffer = Vec::new(); - let mut offset = 0; - for batch in batches.iter() { - hashes_buffer.clear(); - hashes_buffer.resize(batch.num_rows(), 0); - update_hash( - &on_left, - batch, - &mut hashmap, - offset, - &self.random_state, - &mut hashes_buffer, - )?; - offset += batch.num_rows(); - } - // Merge all batches into a single batch, so we - // can directly index into the arrays - let single_batch = - concat_batches(&self.left.schema(), &batches, num_rows)?; - - let left_side = Arc::new((hashmap, single_batch)); - - *build_side = Some(left_side.clone()); - - debug!( - "Built build-side of hash join containing {} rows in {} ms", - num_rows, - start.elapsed().as_millis() - ); - - left_side - } - } - } - PartitionMode::Partitioned => { - let start = Instant::now(); - - // Load 1 partition of left side in memory - let stream = self.left.execute(partition).await?; - - // This operation performs 2 steps at once: - // 1. creates a [JoinHashMap] of all batches from the stream - // 2. stores the batches in a vector. - let initial = (0, Vec::new()); - let (num_rows, batches) = stream - .try_fold(initial, |mut acc, batch| async { - acc.0 += batch.num_rows(); - acc.1.push(batch); - Ok(acc) - }) - .await?; - let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows)); - let mut hashes_buffer = Vec::new(); - let mut offset = 0; - for batch in batches.iter() { - hashes_buffer.clear(); - hashes_buffer.resize(batch.num_rows(), 0); - update_hash( - &on_left, - batch, - &mut hashmap, - offset, - &self.random_state, - &mut hashes_buffer, - )?; - offset += batch.num_rows(); - } - // Merge all batches into a single batch, so we - // can directly index into the arrays - let single_batch = - concat_batches(&self.left.schema(), &batches, num_rows)?; - - let left_side = Arc::new((hashmap, single_batch)); - - debug!( - "Built build-side {} of hash join containing {} rows in {} ms", - partition, - num_rows, - start.elapsed().as_millis() - ); - - left_side - } - } - }; + let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); + let left = self.left.execute(partition).await?; + let right = self.right.execute(partition).await?; - // we have the batches and the hash map with their keys. We can how create a stream - // over the right that uses this information to issue new batches. 
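With the hash-table build removed above, `execute` only has to record, for each output column, which input it comes from and at what position. A name-only sketch of the mapping `column_indices_from_schema` computes; the real function works on Arrow `Schema`s and flips the primary side for `Right` joins.

// Mirror of the `ColumnIndex` struct in `joins/mod.rs`.
#[derive(Debug, PartialEq)]
struct ColumnIndex {
    index: usize,
    is_left: bool,
}

// Resolve each output field against the left schema first, then the right.
fn column_indices(
    left: &[&str],
    right: &[&str],
    output: &[&str],
) -> Option<Vec<ColumnIndex>> {
    output
        .iter()
        .map(|name| {
            if let Some(index) = left.iter().position(|f| f == name) {
                Some(ColumnIndex { index, is_left: true })
            } else if let Some(index) = right.iter().position(|f| f == name) {
                Some(ColumnIndex { index, is_left: false })
            } else {
                None // in neither input: an internal error in the real code
            }
        })
        .collect()
}

fn main() {
    let left = ["a", "b"];
    let right = ["b2", "c"];
    let output = ["a", "b", "b2", "c"];
    let indices = column_indices(&left, &right, &output).unwrap();
    assert_eq!(indices[2], ColumnIndex { index: 0, is_left: false });
}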
+ let column_indices = column_indices_from_schema( + &self.join_type, + &self.left.schema(), + &self.right.schema(), + &self.schema)?; - let right_stream = self.right.execute(partition).await?; - let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); - let column_indices = self.column_indices_from_schema()?; let num_rows = left_data.1.num_rows(); let visited_left_side = match self.join_type { JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => { From 374125d6d6201ab8937a0506f75e7a59e6234b49 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 10 Nov 2021 17:19:41 +0800 Subject: [PATCH 07/15] extract more common --- .../src/physical_plan/joins/hash_join.rs | 56 +------ datafusion/src/physical_plan/joins/mod.rs | 57 ++++++- .../physical_plan/joins/sort_merge_join.rs | 152 ++---------------- 3 files changed, 70 insertions(+), 195 deletions(-) diff --git a/datafusion/src/physical_plan/joins/hash_join.rs b/datafusion/src/physical_plan/joins/hash_join.rs index 6578af25f8fe..35ec39d3749c 100644 --- a/datafusion/src/physical_plan/joins/hash_join.rs +++ b/datafusion/src/physical_plan/joins/hash_join.rs @@ -40,7 +40,10 @@ use crate::error::{DataFusionError, Result}; use crate::logical_plan::JoinType; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use crate::physical_plan::joins::{build_join_schema, check_join_is_valid, JoinOn, column_indices_from_schema}; +use crate::physical_plan::joins::{ + build_join_schema, check_join_is_valid, column_indices_from_schema, equal_rows, + JoinOn, +}; use crate::physical_plan::PhysicalExpr; use crate::physical_plan::{ expressions::Column, @@ -746,57 +749,6 @@ fn build_join_indexes( } } -macro_rules! equal_rows_elem { - ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ - let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); - let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); - - match (left_array.is_null($left), right_array.is_null($right)) { - (false, false) => left_array.value($left) == right_array.value($right), - _ => false, - } - }}; -} - -/// Left and right row have equal values -fn equal_rows( - left: usize, - right: usize, - left_arrays: &[ArrayRef], - right_arrays: &[ArrayRef], -) -> Result { - let mut err = None; - let res = left_arrays - .iter() - .zip(right_arrays) - .all(|(l, r)| match l.data_type() { - DataType::Null => true, - DataType::Boolean => equal_rows_elem!(BooleanArray, l, r, left, right), - DataType::Int8 => equal_rows_elem!(Int8Array, l, r, left, right), - DataType::Int16 => equal_rows_elem!(Int16Array, l, r, left, right), - DataType::Int32 => equal_rows_elem!(Int32Array, l, r, left, right), - DataType::Int64 => equal_rows_elem!(Int64Array, l, r, left, right), - DataType::UInt8 => equal_rows_elem!(UInt8Array, l, r, left, right), - DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right), - DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right), - DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right), - DataType::Timestamp(_, None) => { - equal_rows_elem!(Int64Array, l, r, left, right) - } - DataType::Utf8 => equal_rows_elem!(StringArray, l, r, left, right), - DataType::LargeUtf8 => equal_rows_elem!(LargeStringArray, l, r, left, right), - _ => { - // This is internal because we should have caught this before. 
- err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - }); - - err.unwrap_or(Ok(res)) -} - // Produces a batch for left-side rows that have/have not been matched during the whole join fn produce_from_matched( visited_left_side: &[bool], diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs index 79ab69240c96..04b46b175ddf 100644 --- a/datafusion/src/physical_plan/joins/mod.rs +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -22,7 +22,8 @@ pub mod sort_merge_join; use crate::error::{DataFusionError, Result}; use crate::logical_plan::JoinType; use crate::physical_plan::expressions::Column; -use arrow::datatypes::{Field, Schema}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, Schema}; use std::collections::HashSet; use std::sync::Arc; @@ -132,14 +133,64 @@ pub fn column_indices_from_schema( } }.map_err(DataFusionError::into_arrow_external_error)?; - let is_left = - is_primary && primary_is_left || !is_primary && !primary_is_left; + let is_left = is_primary && primary_is_left || !is_primary && !primary_is_left; column_indices.push(ColumnIndex { index, is_left }); } Ok(column_indices) } +macro_rules! equal_rows_elem { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => left_array.value($left) == right_array.value($right), + _ => false, + } + }}; +} + +/// Left and right row have equal values +fn equal_rows( + left: usize, + right: usize, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], +) -> Result { + let mut err = None; + let res = left_arrays + .iter() + .zip(right_arrays) + .all(|(l, r)| match l.data_type() { + DataType::Null => true, + DataType::Boolean => equal_rows_elem!(BooleanArray, l, r, left, right), + DataType::Int8 => equal_rows_elem!(Int8Array, l, r, left, right), + DataType::Int16 => equal_rows_elem!(Int16Array, l, r, left, right), + DataType::Int32 => equal_rows_elem!(Int32Array, l, r, left, right), + DataType::Int64 => equal_rows_elem!(Int64Array, l, r, left, right), + DataType::UInt8 => equal_rows_elem!(UInt8Array, l, r, left, right), + DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right), + DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right), + DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right), + DataType::Timestamp(_, None) => { + equal_rows_elem!(Int64Array, l, r, left, right) + } + DataType::Utf8 => equal_rows_elem!(StringArray, l, r, left, right), + DataType::LargeUtf8 => equal_rows_elem!(LargeStringArray, l, r, left, right), + _ => { + // This is internal because we should have caught this before. 
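+            // `all()` cannot propagate an error out of its closure, so the
+            // failure is parked in `err` and surfaced after the traversal.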
+ err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } + }); + + err.unwrap_or(Ok(res)) +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index 3bdb47d775b6..91965a2cc603 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -41,7 +41,10 @@ use crate::logical_plan::JoinType; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr}; -use crate::physical_plan::joins::{build_join_schema, check_join_is_valid, JoinOn, column_indices_from_schema}; +use crate::physical_plan::joins::{ + build_join_schema, check_join_is_valid, column_indices_from_schema, equal_rows, + JoinOn, +}; use crate::physical_plan::stream::RecordBatchReceiverStream; use crate::physical_plan::PhysicalExpr; use crate::physical_plan::{ @@ -866,28 +869,6 @@ impl SMJStream { } } -// Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value. -// -// Note that the `u64` keys are not stored in the hashmap (hence the `()` as key), but are only used -// to put the indices in a certain bucket. -// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the left side, -// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. -// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 -// As the key is a hash value, we need to check possible hash collisions in the probe stage -// During this stage it might be the case that a row is contained the same hashmap value, -// but the values don't match. Those are checked in the [equal_rows] macro -// TODO: speed up collision check and move away from using a hashbrown HashMap -// https://github.com/apache/arrow-datafusion/issues/50 -struct JoinHashMap(RawTable<(u64, SmallVec<[u64; 1]>)>); - -impl fmt::Debug for JoinHashMap { - fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { - Ok(()) - } -} - -type JoinLeftData = Arc<(JoinHashMap, RecordBatch)>; - /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. 
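For orientation before the operator plumbing below: in the inner-join case the driver reduces to the classic two-pointer merge over key-sorted inputs, replaying each equal-key run on the buffered side for every streamed row that carries that key. A self-contained sketch under those simplifying assumptions (toy key/payload tuples instead of record batches):

use std::cmp::Ordering;

// Two-pointer sort-merge inner join over key-sorted inputs.
fn merge_join<K: Ord + Copy, L: Copy, R: Copy>(
    streamed: &[(K, L)],
    buffered: &[(K, R)],
) -> Vec<(L, R)> {
    let (mut s, mut b) = (0, 0);
    let mut out = Vec::new();
    while s < streamed.len() && b < buffered.len() {
        match streamed[s].0.cmp(&buffered[b].0) {
            Ordering::Less => s += 1,
            Ordering::Greater => b += 1,
            Ordering::Equal => {
                // size of the equal-key run on the buffered side
                let run = buffered[b..]
                    .iter()
                    .take_while(|(k, _)| *k == streamed[s].0)
                    .count();
                // every streamed row with this key joins the whole run
                while s < streamed.len() && streamed[s].0 == buffered[b].0 {
                    for (_, r) in &buffered[b..b + run] {
                        out.push((streamed[s].1, *r));
                    }
                    s += 1;
                }
                b += run;
            }
        }
    }
    out
}

fn main() {
    let streamed = [(1, "a"), (2, "b"), (2, "c")];
    let buffered = [(2, 10), (2, 20), (3, 30)];
    // key 2 matches: 2 streamed rows x 2 buffered rows = 4 output rows
    assert_eq!(merge_join(&streamed, &buffered).len(), 4);
}

The real driver performs the same walk column by column, using the repeat and copy helpers above, and flushes an output batch whenever the `target_batch_size` slots fill up.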
#[derive(Debug)] @@ -1045,26 +1026,17 @@ impl ExecutionPlan for SortMergeJoinExec { &self.join_type, &self.left.schema(), &self.right.schema(), - &self.schema)?; - + &self.schema, + )?; - let num_rows = left_data.1.num_rows(); - let visited_left_side = match self.join_type { - JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => { - vec![false; num_rows] - } - JoinType::Inner | JoinType::Right => vec![], - }; Ok(Box::pin(SortMergeJoinStream::new( self.schema.clone(), on_left, on_right, self.join_type, - left_data, - right_stream, + left, + right, column_indices, - self.random_state.clone(), - visited_left_side, SortMergeJoinMetrics::new(partition, &self.metrics), ))) } @@ -1107,14 +1079,12 @@ struct SortMergeJoinStream { on_right: Vec, /// type of the join join_type: JoinType, - /// information from the left - left_data: JoinLeftData, + /// left + left: SendableRecordBatchStream, /// right right: SendableRecordBatchStream, /// Information of index and left / right placement of columns column_indices: Vec, - /// Keeps track of the left side rows whether they are visited - visited_left_side: Vec, // TODO: use a more memory efficient data structure, https://github.com/apache/arrow-datafusion/issues/240 /// There is nothing to process anymore and left side is processed in case of left join is_exhausted: bool, /// Metrics @@ -1128,10 +1098,9 @@ impl SortMergeJoinStream { on_left: Vec, on_right: Vec, join_type: JoinType, - left_data: JoinLeftData, + left: SendableRecordBatchStream, right: SendableRecordBatchStream, column_indices: Vec, - visited_left_side: Vec, join_metrics: SortMergeJoinMetrics, ) -> Self { SortMergeJoinStream { @@ -1139,10 +1108,9 @@ impl SortMergeJoinStream { on_left, on_right, join_type, - left_data, + left, right, column_indices, - visited_left_side, is_exhausted: false, join_metrics, } @@ -1155,102 +1123,6 @@ impl RecordBatchStream for SortMergeJoinStream { } } -macro_rules! 
equal_rows_elem { - ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ - let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); - let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); - - match (left_array.is_null($left), right_array.is_null($right)) { - (false, false) => left_array.value($left) == right_array.value($right), - _ => false, - } - }}; -} - -/// Left and right row have equal values -fn equal_rows( - left: usize, - right: usize, - left_arrays: &[ArrayRef], - right_arrays: &[ArrayRef], -) -> Result { - let mut err = None; - let res = left_arrays - .iter() - .zip(right_arrays) - .all(|(l, r)| match l.data_type() { - DataType::Null => true, - DataType::Boolean => equal_rows_elem!(BooleanArray, l, r, left, right), - DataType::Int8 => equal_rows_elem!(Int8Array, l, r, left, right), - DataType::Int16 => equal_rows_elem!(Int16Array, l, r, left, right), - DataType::Int32 => equal_rows_elem!(Int32Array, l, r, left, right), - DataType::Int64 => equal_rows_elem!(Int64Array, l, r, left, right), - DataType::UInt8 => equal_rows_elem!(UInt8Array, l, r, left, right), - DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right), - DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right), - DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right), - DataType::Timestamp(_, None) => { - equal_rows_elem!(Int64Array, l, r, left, right) - } - DataType::Utf8 => equal_rows_elem!(StringArray, l, r, left, right), - DataType::LargeUtf8 => equal_rows_elem!(LargeStringArray, l, r, left, right), - _ => { - // This is internal because we should have caught this before. - err = Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false - } - }); - - err.unwrap_or(Ok(res)) -} - -// Produces a batch for left-side rows that have/have not been matched during the whole join -fn produce_from_matched( - visited_left_side: &[bool], - schema: &SchemaRef, - column_indices: &[ColumnIndex], - left_data: &JoinLeftData, - unmatched: bool, -) -> ArrowResult { - // Find indices which didn't match any right row (are false) - let indices = if unmatched { - visited_left_side - .iter() - .enumerate() - .filter(|&(_, &value)| !value) - .map(|(index, _)| index as u64) - .collect::>() - } else { - // produce those that did match - visited_left_side - .iter() - .enumerate() - .filter(|&(_, &value)| value) - .map(|(index, _)| index as u64) - .collect::>() - }; - - // generate batches by taking values from the left side and generating columns filled with null on the right side - let indices = UInt64Array::from_data(DataType::UInt64, indices.into(), None); - - let num_rows = indices.len(); - let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); - for (idx, column_index) in column_indices.iter().enumerate() { - let array = if column_index.is_left { - let array = left_data.1.column(column_index.index); - take::take(array.as_ref(), &indices)?.into() - } else { - let datatype = schema.field(idx).data_type().clone(); - new_null_array(datatype, num_rows).into() - }; - - columns.push(array); - } - RecordBatch::try_new(schema.clone(), columns) -} - impl Stream for SortMergeJoinStream { type Item = ArrowResult; From 3393f528ddb9c451aca1926e66a21c5e64ab4142 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 10 Nov 2021 18:12:44 +0800 Subject: [PATCH 08/15] Add cmp --- datafusion/src/physical_plan/joins/mod.rs | 58 +++++++++++++++++++ .../physical_plan/joins/sort_merge_join.rs | 19 ++++-- 2 files 
changed, 72 insertions(+), 5 deletions(-) diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs index 04b46b175ddf..ec788e366d76 100644 --- a/datafusion/src/physical_plan/joins/mod.rs +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -24,6 +24,7 @@ use crate::logical_plan::JoinType; use crate::physical_plan::expressions::Column; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field, Schema}; +use std::cmp::Ordering; use std::collections::HashSet; use std::sync::Arc; @@ -191,6 +192,63 @@ fn equal_rows( err.unwrap_or(Ok(res)) } +macro_rules! cmp_rows_elem { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => { + let cmp = left_array + .value($left) + .partial_cmp(&right_array.value($right))?; + if cmp != Ordering::Equal { + res = cmp; + break; + } + } + _ => unreachable!(), + } + }}; +} + +/// compare left row with right row +fn comp_rows( + left: usize, + right: usize, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], +) -> Result { + let mut res = Ordering::Equal; + for (l, r) in left_arrays.iter().zip(right_arrays) { + match l.data_type() { + DataType::Null => {} + DataType::Boolean => cmp_rows_elem!(BooleanArray, l, r, left, right), + DataType::Int8 => cmp_rows_elem!(Int8Array, l, r, left, right), + DataType::Int16 => cmp_rows_elem!(Int16Array, l, r, left, right), + DataType::Int32 => cmp_rows_elem!(Int32Array, l, r, left, right), + DataType::Int64 => cmp_rows_elem!(Int64Array, l, r, left, right), + DataType::UInt8 => cmp_rows_elem!(UInt8Array, l, r, left, right), + DataType::UInt16 => cmp_rows_elem!(UInt16Array, l, r, left, right), + DataType::UInt32 => cmp_rows_elem!(UInt32Array, l, r, left, right), + DataType::UInt64 => cmp_rows_elem!(UInt64Array, l, r, left, right), + DataType::Timestamp(_, None) => { + cmp_rows_elem!(Int64Array, l, r, left, right) + } + DataType::Utf8 => cmp_rows_elem!(StringArray, l, r, left, right), + DataType::LargeUtf8 => cmp_rows_elem!(LargeStringArray, l, r, left, right), + _ => { + // This is internal because we should have caught this before. 
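+                // Unlike `equal_rows`, this is a plain `for` loop, so the
+                // unsupported-type arm can bail out with `return` directly.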
+ return Err(DataFusionError::Internal( + "Unsupported data type in sort merge join comparator".to_string(), + )); + } + } + } + + Ok(res) +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index 91965a2cc603..88233392e7ba 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -42,8 +42,8 @@ use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr}; use crate::physical_plan::joins::{ - build_join_schema, check_join_is_valid, column_indices_from_schema, equal_rows, - JoinOn, + build_join_schema, check_join_is_valid, column_indices_from_schema, comp_rows, + equal_rows, JoinOn, }; use crate::physical_plan::stream::RecordBatchReceiverStream; use crate::physical_plan::PhysicalExpr; @@ -746,7 +746,7 @@ impl SMJStream { } loop { - let current_cmp = self.compare_stream_buffer(); + let current_cmp = self.compare_stream_buffer()?; match current_cmp { Ordering::Less => { let more_stream = self.advance_streamed_key_null_free()?; @@ -852,8 +852,17 @@ impl SMJStream { } } - fn compare_stream_buffer(&self) -> Ordering { - todo!() + fn compare_stream_buffer(&self) -> Result { + let stream_arrays = + join_arrays(&self.stream_batch.batch.unwrap().batch, &self.on_streamed); + let buffer_arrays = + join_arrays(&self.buffered_batches.batches[0].batch, &self.on_buffered); + comp_rows( + self.stream_batch.cur_row, + self.buffered_batches.key_idx.unwrap(), + &stream_arrays, + &buffer_arrays, + ) } fn get_stream_next(&mut self) -> Result<()> { From 27ef41cf1fc6ef0ed95613010091a3a2aa1f958d Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 12 Nov 2021 19:45:42 +0800 Subject: [PATCH 09/15] utilize driver --- .../physical_plan/joins/sort_merge_join.rs | 260 ++++++------------ 1 file changed, 80 insertions(+), 180 deletions(-) diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index 88233392e7ba..2e811a0b5065 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -45,6 +45,7 @@ use crate::physical_plan::joins::{ build_join_schema, check_join_is_valid, column_indices_from_schema, comp_rows, equal_rows, JoinOn, }; +use crate::physical_plan::sorts::external_sort::ExternalSortExec; use crate::physical_plan::stream::RecordBatchReceiverStream; use crate::physical_plan::PhysicalExpr; use crate::physical_plan::{ @@ -322,7 +323,7 @@ fn join_arrays(rb: &RecordBatch, on_column: &Vec) -> Vec { on_column.iter().map(|c| rb.column(c.index())).collect() } -struct SMJStream { +struct SortMergeJoinDriver { streamed: SendableRecordBatchStream, buffered: SendableRecordBatchStream, on_streamed: Vec, @@ -332,34 +333,15 @@ struct SMJStream { column_indices: Vec, stream_batch: StreamingBatch, buffered_batches: BufferedBatches, - result_sender: Sender>, - result: SendableRecordBatchStream, runtime: Arc, } -impl RecordBatchStream for SMJStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -impl Stream for SMJStream { - type Item = ArrowResult; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - todo!() - } -} - macro_rules! 
with_match_primitive_type {( $key_type:expr, | $_:tt $T:ident | $($body:tt)* ) => ({ macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} - use datafusion::arrow::datatypes::PrimitiveType::*; - use datafusion::arrow::types::{days_ms, months_days_ns}; + use arrow::datatypes::PrimitiveType::*; + use arrow::types::{days_ms, months_days_ns}; match $key_type { Int8 => __with_ty__! { i8 }, Int16 => __with_ty__! { i16 }, @@ -592,7 +574,7 @@ struct Slice { len: usize, } -impl SMJStream { +impl SortMergeJoinDriver { fn new( streamed: SendableRecordBatchStream, buffered: SendableRecordBatchStream, @@ -600,11 +582,10 @@ impl SMJStream { on_buffered: Vec, streamed_sort: Vec, buffered_sort: Vec, + column_indices: Vec, schema: Arc, runtime: Arc, ) -> Self { - let (tx, rx) = tokio::sync::mpsc::channel(2); - let column_indices = vec![]; Self { streamed, buffered, @@ -614,13 +595,14 @@ impl SMJStream { column_indices, stream_batch: StreamingBatch::new(on_streamed.clone(), streamed_sort), buffered_batches: BufferedBatches::new(on_buffered.clone(), buffered_sort), - result_sender: tx, - result: RecordBatchReceiverStream::create(&schema, rx), runtime, } } - async fn inner_join_driver(&mut self) -> Result<()> { + async fn inner_join_driver( + &mut self, + sender: &Sender>, + ) -> Result<()> { let target_batch_size = self.runtime.batch_size(); let mut output_slots_available = target_batch_size; @@ -633,6 +615,7 @@ impl SMJStream { target_batch_size, output_slots_available, output_arrays, + sender, ) .await?; output_slots_available = result.0; @@ -653,6 +636,7 @@ impl SMJStream { target_batch_size: usize, output_slots_available: usize, mut output_arrays: Vec>, + sender: &Sender>, ) -> Result<(usize, Vec>)> { let mut output_slots_available = output_slots_available; let mut remaining = self.buffered_batches.row_num; @@ -716,7 +700,7 @@ impl SMJStream { if output_slots_available == 0 { let result = make_batch(self.schema.clone(), output_arrays); - if let Err(e) = self.result_sender.send(result).await { + if let Err(e) = sender.send(result).await { println!("ERROR batch via inner join stream: {}", e); }; @@ -1038,16 +1022,71 @@ impl ExecutionPlan for SortMergeJoinExec { &self.schema, )?; - Ok(Box::pin(SortMergeJoinStream::new( - self.schema.clone(), - on_left, - on_right, - self.join_type, - left, - right, - column_indices, - SortMergeJoinMetrics::new(partition, &self.metrics), - ))) + let (tx, rx): ( + Sender>, + Receiver>, + ) = tokio::sync::mpsc::channel(2); + + let left_sort = self + .left + .as_any() + .downcast_ref::() + .unwrap() + .expr() + .iter() + .map(|s| s.clone()) + .collect::>(); + let right_sort = self + .right + .as_any() + .downcast_ref::() + .unwrap() + .expr() + .iter() + .map(|s| s.clone()) + .collect::>(); + + let mut driver = match self.join_type { + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::Semi + | JoinType::Anti => SortMergeJoinDriver::new( + left, + right, + on_left, + on_right, + left_sort, + right_sort, + column_indices, + self.schema.clone(), + RUNTIME_ENV.clone(), + ), + JoinType::Right => SortMergeJoinDriver::new( + right, + left, + on_right, + on_left, + right_sort, + left_sort, + column_indices, + self.schema.clone(), + RUNTIME_ENV.clone(), + ), + }; + + match self.join_type { + JoinType::Inner => driver.inner_join_driver(&tx).await?, + JoinType::Left => {} + JoinType::Right => {} + JoinType::Full => {} + JoinType::Semi => {} + JoinType::Anti => {} + } + + let result = RecordBatchReceiverStream::create(&schema, rx); + + Ok(Box::pin(result)) } fn 
fmt_as( @@ -1059,8 +1098,8 @@ impl ExecutionPlan for SortMergeJoinExec { DisplayFormatType::Default => { write!( f, - "SortMergeJoinExec: mode={:?}, join_type={:?}, on={:?}", - self.mode, self.join_type, self.on + "SortMergeJoinExec: join_type={:?}, on={:?}", + self.join_type, self.on ) } } @@ -1078,145 +1117,6 @@ impl ExecutionPlan for SortMergeJoinExec { } } -/// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct SortMergeJoinStream { - /// Input schema - schema: Arc, - /// columns from the left - on_left: Vec, - /// columns from the right used to compute the hash - on_right: Vec, - /// type of the join - join_type: JoinType, - /// left - left: SendableRecordBatchStream, - /// right - right: SendableRecordBatchStream, - /// Information of index and left / right placement of columns - column_indices: Vec, - /// There is nothing to process anymore and left side is processed in case of left join - is_exhausted: bool, - /// Metrics - join_metrics: SortMergeJoinMetrics, -} - -#[allow(clippy::too_many_arguments)] -impl SortMergeJoinStream { - fn new( - schema: Arc, - on_left: Vec, - on_right: Vec, - join_type: JoinType, - left: SendableRecordBatchStream, - right: SendableRecordBatchStream, - column_indices: Vec, - join_metrics: SortMergeJoinMetrics, - ) -> Self { - SortMergeJoinStream { - schema, - on_left, - on_right, - join_type, - left, - right, - column_indices, - is_exhausted: false, - join_metrics, - } - } -} - -impl RecordBatchStream for SortMergeJoinStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -impl Stream for SortMergeJoinStream { - type Item = ArrowResult; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.right - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - Some(Ok(batch)) => { - let timer = self.join_metrics.join_time.timer(); - let result = build_batch( - &batch, - &self.left_data, - &self.on_left, - &self.on_right, - self.join_type, - &self.schema, - &self.column_indices, - ); - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - if let Ok((ref batch, ref left_side)) = result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - - match self.join_type { - JoinType::Left - | JoinType::Full - | JoinType::Semi - | JoinType::Anti => { - left_side.iter().flatten().for_each(|x| { - self.visited_left_side[*x as usize] = true; - }); - } - JoinType::Inner | JoinType::Right => {} - } - } - Some(result.map(|x| x.0)) - } - other => { - let timer = self.join_metrics.join_time.timer(); - // For the left join, produce rows for unmatched rows - match self.join_type { - JoinType::Left - | JoinType::Full - | JoinType::Semi - | JoinType::Anti - if !self.is_exhausted => - { - let result = produce_from_matched( - &self.visited_left_side, - &self.schema, - &self.column_indices, - &self.left_data, - self.join_type != JoinType::Semi, - ); - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - if let Ok(ref batch) = result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - } - timer.done(); - self.is_exhausted = true; - return Some(result); - } - JoinType::Left - | JoinType::Full - | JoinType::Semi - | JoinType::Anti - | JoinType::Inner - | JoinType::Right => {} - } - - other - } - }) - } -} - #[cfg(test)] mod 
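// [editor's note] The SortMergeJoinStream removed above was a pull-based
// adapter carried over from the hash join: its poll_next still referenced
// hash-join state (`left_data`, `visited_left_side`, `build_batch`,
// `produce_from_matched`) that this operator never owned, so it could not have
// compiled here. The driver-plus-channel pair from this commit replaces it
// outright.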
tests { use std::sync::Arc; From f71cb1506ef224694e40277c895bf0cc2500a783 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 16 Nov 2021 13:03:01 +0800 Subject: [PATCH 10/15] smj v1 --- .../physical_plan/joins/sort_merge_join.rs | 575 +++++++++++++++++- 1 file changed, 559 insertions(+), 16 deletions(-) diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index 2e811a0b5065..e3ebc3a08b0e 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -466,8 +466,7 @@ macro_rules! copy_slices { ($TO:ty, $FROM:ty) => {{ let to = array.as_mut_any().downcast_mut::<$TO>().unwrap(); for pos in slices { - let from = buffered_batches[pos.batch_idx] - .batch + let from = batches[pos.batch_idx] .column(column_index.index) .slice(pos.start_idx, pos.len); let from = from.as_any().downcast_ref::<$FROM>().unwrap(); @@ -476,8 +475,8 @@ macro_rules! copy_slices { }}; } -fn copy_buffered_slices( - buffered_batches: &VecDeque, +fn copy_slices( + batches: &Vec<&RecordBatch>, slices: &Vec, array: &mut Box, column_index: &ColumnIndex, @@ -631,6 +630,263 @@ impl SortMergeJoinDriver { Ok(()) } + async fn semi_join_driver( + &mut self, + sender: &Sender>, + ) -> Result<()> { + let target_batch_size = self.runtime.batch_size(); + + let mut output_slots_available = target_batch_size; + let mut output_arrays = new_arrays(&self.schema, target_batch_size)?; + + while self.find_next_inner_match()? { + let result = self.stream_copy_buffer_omit( + target_batch_size, output_slots_available, output_arrays, sender).await?; + output_slots_available = result.0; + output_arrays = result.1; + } + + Ok(()) + } + + async fn outer_join_driver( + &mut self, + sender: &Sender>, + ) -> Result<()> { + let target_batch_size = self.runtime.batch_size(); + + let mut output_slots_available = target_batch_size; + let mut output_arrays = new_arrays(&self.schema, target_batch_size)?; + let mut buffer_ends = false; + + loop { + let OuterMatchResult { + get_match, + buffered_ended, + more_output, + } = self.find_next_outer(buffer_ends)?; + if !more_output { + break; + } + buffer_ends = buffered_ended; + if get_match { + loop { + let result = self + .join_eq_records( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + + self.stream_batch.advance(); + if self.stream_batch.is_new_key { + break; + } + } + } else { + let result = self + .stream_copy_buffer_null( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + } + } + + Ok(()) + } + + async fn anti_join_driver( + &mut self, + sender: &Sender>, + ) -> Result<()> { + let target_batch_size = self.runtime.batch_size(); + + let mut output_slots_available = target_batch_size; + let mut output_arrays = new_arrays(&self.schema, target_batch_size)?; + let mut buffer_ends = false; + + loop { + let OuterMatchResult { + get_match, + buffered_ended, + more_output, + } = self.find_next_outer(buffer_ends)?; + if !more_output { + break; + } + buffer_ends = buffered_ended; + if get_match { + // do nothing + } else { + let result = self + .stream_copy_buffer_omit( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + } + } + + Ok(()) + } + + async fn full_outer_driver( + 
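// [editor's note] One driver per join flavor, as introduced in this commit:
// semi emits each matching streamed key range once and omits the buffered
// columns; anti emits a streamed range only when find_next_outer reports no
// match; left/right outer emit equal ranges joined and otherwise null-pad the
// buffered side; and full outer, below, runs a two-sided merge whose
// advance_stream / advance_buffer flags record which input owes the next key.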
&mut self, + sender: &Sender>, + ) -> Result<()> { + let target_batch_size = self.runtime.batch_size(); + + let mut output_slots_available = target_batch_size; + let mut output_arrays = new_arrays(&self.schema, target_batch_size)?; + let mut stream_ends = false; + let mut buffer_ends = false; + let mut advance_stream = true; + let mut advance_buffer = true; + + loop { + if advance_buffer { + buffer_ends = !self.advance_buffered_key()?; + } + if advance_stream { + stream_ends = !self.advance_streamed_key()?; + } + + if stream_ends && buffer_ends { + break; + } else if stream_ends { + let result = self + .stream_null_buffer_copy( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + + advance_buffer = true; + advance_stream = false; + } else if buffer_ends { + let result = self + .stream_copy_buffer_null( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + + advance_stream = true; + advance_buffer = false; + } else { + if self.stream_batch.key_any_null() { + let result = self + .stream_copy_buffer_null( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + + advance_stream = true; + advance_buffer = false; + continue; + } + if self.buffered_batches.key_any_null() { + let result = self + .stream_null_buffer_copy( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + + advance_buffer = true; + advance_stream = false; + continue; + } + + let current_cmp = self.compare_stream_buffer()?; + match current_cmp { + Ordering::Less => { + let result = self + .stream_copy_buffer_null( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + + advance_stream = true; + advance_buffer = false; + } + Ordering::Equal => { + loop { + let result = self + .join_eq_records( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + + self.stream_batch.advance(); + if self.stream_batch.is_new_key { + break; + } + } + advance_stream = false; // we already reach the next key of stream + advance_buffer = true; + } + Ordering::Greater => { + let result = self + .stream_null_buffer_copy( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; + output_slots_available = result.0; + output_arrays = result.1; + + advance_buffer = true; + advance_stream = false; + } + } + } + } + Ok(()) + } + async fn join_eq_records( &mut self, target_batch_size: usize, @@ -643,7 +899,12 @@ impl SortMergeJoinDriver { let stream_batch = &self.stream_batch.batch.unwrap().batch; let stream_row = self.stream_batch.cur_row; - let buffered_batches = &self.buffered_batches.batches; + let batches = self + .buffered_batches + .batches + .iter() + .map(|prb| &prb.batch) + .collect::>(); let buffered_ranges = &self.buffered_batches.ranges; let mut unfinished = true; @@ -688,12 +949,209 @@ impl SortMergeJoinDriver { ); } else { // copy buffered start from: `buffered_idx`, len: `rows_to_output` - copy_buffered_slices( - buffered_batches, - &slices, - array, - column_index, - ) + copy_slices(&batches, &slices, array, 
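// [editor's note] Buffered rows for one join key can span several retained
// batches, so the copy is described as (batch_idx, start_idx, len) Slice
// values rather than one contiguous range; copy_slices downcasts each source
// column once per slice and bulk-extends the mutable output array via
// extend_trusted_len.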
column_index); + } + }); + + if output_slots_available == 0 { + let result = make_batch(self.schema.clone(), output_arrays); + + if let Err(e) = sender.send(result).await { + println!("ERROR batch via inner join stream: {}", e); + }; + + output_arrays = new_arrays(&self.schema, target_batch_size)?; + output_slots_available = target_batch_size; + } + + buffered_idx += rows_to_output; + rows_to_output = 0; + } + Ok((output_slots_available, output_arrays)) + } + + async fn stream_copy_buffer_null( + &mut self, + target_batch_size: usize, + output_slots_available: usize, + mut output_arrays: Vec>, + sender: &Sender>, + ) -> Result<(usize, Vec>)> { + let mut output_slots_available = output_slots_available; + let stream_batch = &self.stream_batch.batch.unwrap().batch; + let batch = vec![stream_batch]; + let stream_range = + &self.stream_batch.batch.unwrap().ranges[&self.stream_batch.cur_range]; + let mut remaining = range_len(stream_range); + + let mut unfinished = true; + let mut streamed_idx = self.stream_batch.cur_row; + let mut rows_to_output = 0; + + // output each buffered matching record once + while unfinished { + if output_slots_available >= remaining { + unfinished = false; + rows_to_output = remaining; + output_slots_available -= remaining; + remaining = 0; + } else { + rows_to_output = output_slots_available; + output_slots_available = 0; + remaining -= rows_to_output; + } + + let slice = vec![Slice { + batch_idx: 0, + start_idx: streamed_idx, + len: rows_to_output, + }]; + + output_arrays + .iter_mut() + .zip(self.schema.fields().iter()) + .zip(self.column_indices.iter()) + .map(|((array, field), column_index)| { + if column_index.is_left { + copy_slices(&batch, &slice, array, column_index); + } else { + (0..rows_to_output).for_each(array.push_null()); + } + }); + + if output_slots_available == 0 { + let result = make_batch(self.schema.clone(), output_arrays); + + if let Err(e) = sender.send(result).await { + println!("ERROR batch via inner join stream: {}", e); + }; + + output_arrays = new_arrays(&self.schema, target_batch_size)?; + output_slots_available = target_batch_size; + } + + streamed_idx += rows_to_output; + rows_to_output = 0; + } + Ok((output_slots_available, output_arrays)) + } + + async fn stream_copy_buffer_omit( + &mut self, + target_batch_size: usize, + output_slots_available: usize, + mut output_arrays: Vec>, + sender: &Sender>, + ) -> Result<(usize, Vec>)> { + let mut output_slots_available = output_slots_available; + let stream_batch = &self.stream_batch.batch.unwrap().batch; + let batch = vec![stream_batch]; + let stream_range = + &self.stream_batch.batch.unwrap().ranges[&self.stream_batch.cur_range]; + let mut remaining = range_len(stream_range); + + let mut unfinished = true; + let mut streamed_idx = self.stream_batch.cur_row; + let mut rows_to_output = 0; + + // output each buffered matching record once + while unfinished { + if output_slots_available >= remaining { + unfinished = false; + rows_to_output = remaining; + output_slots_available -= remaining; + remaining = 0; + } else { + rows_to_output = output_slots_available; + output_slots_available = 0; + remaining -= rows_to_output; + } + + let slice = vec![Slice { + batch_idx: 0, + start_idx: streamed_idx, + len: rows_to_output, + }]; + + output_arrays + .iter_mut() + .zip(self.schema.fields().iter()) + .zip(self.column_indices.iter()) + .map(|((array, field), column_index)| { + copy_slices(&batch, &slice, array, column_index); + }); + + if output_slots_available == 0 { + let result = 
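// [editor's note] This flush shows the slot accounting every driver shares:
// start with target_batch_size free slots, deduct rows_to_output per copy,
// and emit a batch exactly when output_slots_available hits 0. With a
// 4096-row target, for example, emitting 10_000 rows yields two full batches
// plus an 1_808-row remainder that stays in the reused mutable arrays until
// the next flush or end of input.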
make_batch(self.schema.clone(), output_arrays); + + if let Err(e) = sender.send(result).await { + println!("ERROR batch via inner join stream: {}", e); + }; + + output_arrays = new_arrays(&self.schema, target_batch_size)?; + output_slots_available = target_batch_size; + } + + streamed_idx += rows_to_output; + rows_to_output = 0; + } + Ok((output_slots_available, output_arrays)) + } + + async fn stream_null_buffer_copy( + &mut self, + target_batch_size: usize, + output_slots_available: usize, + mut output_arrays: Vec>, + sender: &Sender>, + ) -> Result<(usize, Vec>)> { + let mut output_slots_available = output_slots_available; + let mut remaining = self.buffered_batches.row_num; + + let batches = self + .buffered_batches + .batches + .iter() + .map(|prb| &prb.batch) + .collect::>(); + let buffered_ranges = &self.buffered_batches.ranges; + + let mut unfinished = true; + let mut buffered_idx = 0; + let mut rows_to_output = 0; + let start_indices = range_start_indices(buffered_ranges); + + // output each buffered matching record once + while unfinished { + if output_slots_available >= remaining { + unfinished = false; + rows_to_output = remaining; + output_slots_available -= remaining; + remaining = 0; + } else { + rows_to_output = output_slots_available; + output_slots_available = 0; + remaining -= rows_to_output; + } + + // get slices for buffered side for the current output + let slices = slices_from_batches( + buffered_ranges, + &start_indices, + buffered_idx, + rows_to_output, + ); + + output_arrays + .iter_mut() + .zip(self.schema.fields().iter()) + .zip(self.column_indices.iter()) + .map(|((array, field), column_index)| { + if column_index.is_left { + (0..rows_to_output).for_each(array.push_null()); + } else { + // copy buffered start from: `buffered_idx`, len: `rows_to_output` + copy_slices(&batches, &slices, array, column_index); } }); @@ -749,6 +1207,74 @@ impl SortMergeJoinDriver { } } + fn find_next_outer(&mut self, buffer_ends: bool) -> Result { + let more_stream = self.advance_streamed_key()?; + if buffer_ends { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: true, + more_output: more_stream, + }); + } else { + if !more_stream { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: false, + more_output: false, + }); + } + + if self.buffered_batches.key_any_null() { + let more_buffer = self.advance_buffered_key_null_free()?; + if !more_buffer { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: true, + more_output: true, + }); + } + } + + loop { + if self.stream_batch.key_any_null() { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: false, + more_output: true, + }); + } + + let current_cmp = self.compare_stream_buffer()?; + match current_cmp { + Ordering::Less => { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: false, + more_output: true, + }) + } + Ordering::Equal => { + return Ok(OuterMatchResult { + get_match: true, + buffered_ended: false, + more_output: true, + }) + } + Ordering::Greater => { + let more_buffer = self.advance_buffered_key_null_free()?; + if !more_buffer { + return Ok(OuterMatchResult { + get_match: false, + buffered_ended: true, + more_output: true, + }); + } + } + } + } + } + } + /// true for has next, false for ended fn advance_streamed(&mut self) -> Result { if self.stream_batch.is_finished() { @@ -806,15 +1332,26 @@ impl SortMergeJoinDriver { None => return Ok(false), Some(batch) => { self.buffered_batches.reset_batch(&batch); + if &batch.ranges.len() == 1 { + 
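// [editor's note] A freshly loaded batch with exactly one partition range
// means every row in it shares the current join key, so the same key may
// continue into the next input batch; cumulate_same_keys keeps fetching and
// appending batches until it sees a row whose key differs (or the buffered
// input ends), guaranteeing the whole key group is resident before any
// equal-range join output is produced.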
self.cumulate_same_keys()?; + } } } } Some(batch) => { self.buffered_batches.reset_batch(batch); + if batch.ranges.len() == 1 { + self.cumulate_same_keys()?; + } } } } else { self.buffered_batches.advance_in_current_batch(); + if self.buffered_batches.batches[0] + .is_last_range(&self.buffered_batches.ranges[0]) + { + self.cumulate_same_keys()?; + } } Ok(false) } @@ -880,6 +1417,12 @@ pub struct SortMergeJoinExec { metrics: ExecutionPlanMetricsSet, } +struct OuterMatchResult { + get_match: bool, + buffered_ended: bool, + more_output: bool, +} + /// Metrics for SortMergeJoinExec #[derive(Debug)] struct SortMergeJoinMetrics { @@ -1077,11 +1620,11 @@ impl ExecutionPlan for SortMergeJoinExec { match self.join_type { JoinType::Inner => driver.inner_join_driver(&tx).await?, - JoinType::Left => {} - JoinType::Right => {} - JoinType::Full => {} - JoinType::Semi => {} - JoinType::Anti => {} + JoinType::Left => driver.outer_join_driver(&tx).await?, + JoinType::Right => driver.outer_join_driver(&tx).await?, + JoinType::Full => driver.full_outer_driver(&tx).await?, + JoinType::Semi => driver.semi_join_driver(&tx).await?, + JoinType::Anti => driver.anti_join_driver(&tx).await?, } let result = RecordBatchReceiverStream::create(&schema, rx); From 4ba86bc34d70ed5497e7719649253c403c3ecbf2 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 16 Nov 2021 16:04:24 +0800 Subject: [PATCH 11/15] wip --- datafusion/src/arrow_dyn_list_array.rs | 149 +++++++++++ datafusion/src/lib.rs | 1 + .../src/physical_plan/expressions/mod.rs | 3 +- datafusion/src/physical_plan/hash_utils.rs | 5 +- .../src/physical_plan/joins/hash_join.rs | 10 +- datafusion/src/physical_plan/joins/mod.rs | 43 ++-- .../physical_plan/joins/sort_merge_join.rs | 237 +++++++++++------- datafusion/src/physical_plan/planner.rs | 2 +- datafusion/src/physical_plan/sorts/sort.rs | 6 +- 9 files changed, 334 insertions(+), 122 deletions(-) create mode 100644 datafusion/src/arrow_dyn_list_array.rs diff --git a/datafusion/src/arrow_dyn_list_array.rs b/datafusion/src/arrow_dyn_list_array.rs new file mode 100644 index 000000000000..4f2d9b77e236 --- /dev/null +++ b/datafusion/src/arrow_dyn_list_array.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
DynMutableListArray from arrow/io/avro/read/nested.rs + +use arrow::array::{Array, ListArray, MutableArray, Offset}; +use arrow::bitmap::MutableBitmap; +use arrow::buffer::MutableBuffer; +use arrow::datatypes::DataType; +use arrow::error::ArrowError; +use std::sync::Arc; + +/// Auxiliary struct +#[derive(Debug)] +pub struct DynMutableListArray { + data_type: DataType, + offsets: MutableBuffer, + values: Box, + validity: Option, +} + +impl DynMutableListArray { + pub fn new_from( + values: Box, + data_type: DataType, + capacity: usize, + ) -> Self { + let mut offsets = MutableBuffer::::with_capacity(capacity + 1); + offsets.push(O::default()); + assert_eq!(values.len(), 0); + ListArray::::get_child_field(&data_type); + Self { + data_type, + offsets, + values, + validity: None, + } + } + + /// Creates a new [`MutableListArray`] from a [`MutableArray`] and capacity. + pub fn new_with_capacity(values: Box, capacity: usize) -> Self { + let data_type = ListArray::::default_datatype(values.data_type().clone()); + Self::new_from(values, data_type, capacity) + } + + /// The values + pub fn mut_values(&mut self) -> &mut dyn MutableArray { + self.values.as_mut() + } + + #[inline] + pub fn try_push_valid(&mut self) -> Result<(), ArrowError> { + let size = self.values.len(); + let size = O::from_usize(size).ok_or(ArrowError::KeyOverflowError)?; // todo: make this error + assert!(size >= *self.offsets.last().unwrap()); + + self.offsets.push(size); + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + fn push_null(&mut self) { + self.offsets.push(self.last_offset()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + #[inline] + fn last_offset(&self) -> O { + *self.offsets.last().unwrap() + } + + fn init_validity(&mut self) { + let len = self.offsets.len() - 1; + + let mut validity = MutableBitmap::new(); + validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } +} + +impl MutableArray for DynMutableListArray { + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + Box::new(ListArray::from_data( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_arc(), + std::mem::take(&mut self.validity).map(|x| x.into()), + )) + } + + fn as_arc(&mut self) -> Arc { + Arc::new(ListArray::from_data( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_arc(), + std::mem::take(&mut self.validity).map(|x| x.into()), + )) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn shrink_to_fit(&mut self) { + todo!(); + } +} diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index bb90b1703931..2539715a719a 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -230,6 +230,7 @@ pub mod variable; // re-export dependencies from arrow-rs to minimise version maintenance for crate users pub use arrow; +mod arrow_dyn_list_array; mod arrow_temporal_util; #[cfg(test)] diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 98950b9cd0fa..59d2a4a34e97 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ 
b/datafusion/src/physical_plan/expressions/mod.rs @@ -146,13 +146,12 @@ impl PhysicalSortExpr { pub fn exprs_to_sort_columns( batch: &RecordBatch, expr: &[PhysicalSortExpr], -) -> Result> { +) -> Result> { let columns = expr .iter() .map(|e| e.evaluate_to_sort_column(&batch)) .collect::>>() .map_err(DataFusionError::into_arrow_external_error)?; - let columns = columns.iter().map(|x| x.into()).collect::>(); Ok(columns) } diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 48cd5bcada7a..f5d20202a8c2 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -17,7 +17,6 @@ //! Functionality used both on logical and physical plans -use std::collections::HashSet; use std::sync::Arc; pub use ahash::{CallHasher, RandomState}; @@ -26,11 +25,9 @@ use arrow::array::{ Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, }; -use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow::datatypes::{DataType, TimeUnit}; use crate::error::{DataFusionError, Result}; -use crate::logical_plan::JoinType; -use crate::physical_plan::expressions::Column; // Combines two hashes into one hash #[inline] diff --git a/datafusion/src/physical_plan/joins/hash_join.rs b/datafusion/src/physical_plan/joins/hash_join.rs index 35ec39d3749c..9b2508ab4218 100644 --- a/datafusion/src/physical_plan/joins/hash_join.rs +++ b/datafusion/src/physical_plan/joins/hash_join.rs @@ -42,7 +42,7 @@ use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::joins::{ build_join_schema, check_join_is_valid, column_indices_from_schema, equal_rows, - JoinOn, + ColumnIndex, JoinOn, }; use crate::physical_plan::PhysicalExpr; use crate::physical_plan::{ @@ -152,14 +152,6 @@ pub enum PartitionMode { CollectLeft, } -/// Information about the index and placement (left or right) of the columns -struct ColumnIndex { - /// Index of the column - index: usize, - /// Whether the column is at the left or right side - is_left: bool, -} - impl HashJoinExec { /// Tries to create a new [HashJoinExec]. /// # Error diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs index ec788e366d76..aed2da6e4453 100644 --- a/datafusion/src/physical_plan/joins/mod.rs +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -23,11 +23,16 @@ use crate::error::{DataFusionError, Result}; use crate::logical_plan::JoinType; use crate::physical_plan::expressions::Column; use arrow::array::ArrayRef; +use arrow::array::*; use arrow::datatypes::{DataType, Field, Schema}; +use arrow::error::Result as ArrowResult; use std::cmp::Ordering; use std::collections::HashSet; use std::sync::Arc; +type StringArray = Utf8Array; +type LargeStringArray = Utf8Array; + /// The on clause of the join, as vector of (left, right) columns. pub type JoinOn = Vec<(Column, Column)>; /// Reference for JoinOn. 
@@ -105,6 +110,14 @@ pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema::new(fields) } +/// Information about the index and placement (left or right) of the columns +struct ColumnIndex { + /// Index of the column + index: usize, + /// Whether the column is at the left or right side + is_left: bool, +} + /// Calculates column indices and left/right placement on input / output schemas and jointype pub fn column_indices_from_schema( join_type: &JoinType, @@ -193,7 +206,7 @@ fn equal_rows( } macro_rules! cmp_rows_elem { - ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{ + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $res: ident) => {{ let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); @@ -203,7 +216,7 @@ macro_rules! cmp_rows_elem { .value($left) .partial_cmp(&right_array.value($right))?; if cmp != Ordering::Equal { - res = cmp; + $res = cmp; break; } } @@ -223,20 +236,22 @@ fn comp_rows( for (l, r) in left_arrays.iter().zip(right_arrays) { match l.data_type() { DataType::Null => {} - DataType::Boolean => cmp_rows_elem!(BooleanArray, l, r, left, right), - DataType::Int8 => cmp_rows_elem!(Int8Array, l, r, left, right), - DataType::Int16 => cmp_rows_elem!(Int16Array, l, r, left, right), - DataType::Int32 => cmp_rows_elem!(Int32Array, l, r, left, right), - DataType::Int64 => cmp_rows_elem!(Int64Array, l, r, left, right), - DataType::UInt8 => cmp_rows_elem!(UInt8Array, l, r, left, right), - DataType::UInt16 => cmp_rows_elem!(UInt16Array, l, r, left, right), - DataType::UInt32 => cmp_rows_elem!(UInt32Array, l, r, left, right), - DataType::UInt64 => cmp_rows_elem!(UInt64Array, l, r, left, right), + DataType::Boolean => cmp_rows_elem!(BooleanArray, l, r, left, right, res), + DataType::Int8 => cmp_rows_elem!(Int8Array, l, r, left, right, res), + DataType::Int16 => cmp_rows_elem!(Int16Array, l, r, left, right, res), + DataType::Int32 => cmp_rows_elem!(Int32Array, l, r, left, right, res), + DataType::Int64 => cmp_rows_elem!(Int64Array, l, r, left, right, res), + DataType::UInt8 => cmp_rows_elem!(UInt8Array, l, r, left, right, res), + DataType::UInt16 => cmp_rows_elem!(UInt16Array, l, r, left, right, res), + DataType::UInt32 => cmp_rows_elem!(UInt32Array, l, r, left, right, res), + DataType::UInt64 => cmp_rows_elem!(UInt64Array, l, r, left, right, res), DataType::Timestamp(_, None) => { - cmp_rows_elem!(Int64Array, l, r, left, right) + cmp_rows_elem!(Int64Array, l, r, left, right, res) + } + DataType::Utf8 => cmp_rows_elem!(StringArray, l, r, left, right, res), + DataType::LargeUtf8 => { + cmp_rows_elem!(LargeStringArray, l, r, left, right, res) } - DataType::Utf8 => cmp_rows_elem!(StringArray, l, r, left, right), - DataType::LargeUtf8 => cmp_rows_elem!(LargeStringArray, l, r, left, right), _ => { // This is internal because we should have caught this before. return Err(DataFusionError::Internal( diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index e3ebc3a08b0e..05a3f5eec82a 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -18,53 +18,42 @@ //! Defines the join plan for executing partitions in parallel and then joining the results //! into a set of partitions. 
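//!
//! [editor's note] Shape of the algorithm as implemented below, sketched for
//! orientation: both inputs must arrive sorted on the join keys; each incoming
//! RecordBatch is pre-split into equal-key ranges with
//! lexicographical_partition_ranges; the streamed side walks one key range at
//! a time while the buffered side retains every batch belonging to the current
//! key (a key group may straddle batches); keys are ordered with comp_rows, so
//! each row is visited once per side plus the cost of materializing the
//! joined output.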
-use std::fmt; use std::iter::repeat; use std::sync::Arc; +use std::vec; use std::{any::Any, usize}; -use std::{time::Instant, vec}; -use arrow::compute::take; +use arrow::array::*; use arrow::datatypes::*; -use arrow::error::{ArrowError, Result as ArrowResult}; +use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; -use arrow::{array::*, buffer::MutableBuffer}; use async_trait::async_trait; -use futures::{Stream, StreamExt, TryStreamExt}; -use log::debug; -use smallvec::{smallvec, SmallVec}; -use tokio::sync::Mutex; +use futures::StreamExt; +use crate::arrow_dyn_list_array::DynMutableListArray; use crate::error::{DataFusionError, Result}; use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::runtime_env::RUNTIME_ENV; use crate::logical_plan::JoinType; -use crate::physical_plan::coalesce_batches::concat_batches; -use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr}; use crate::physical_plan::joins::{ build_join_schema, check_join_is_valid, column_indices_from_schema, comp_rows, - equal_rows, JoinOn, + equal_rows, ColumnIndex, JoinOn, }; use crate::physical_plan::sorts::external_sort::ExternalSortExec; use crate::physical_plan::stream::RecordBatchReceiverStream; -use crate::physical_plan::PhysicalExpr; +use crate::physical_plan::Statistics; use crate::physical_plan::{ expressions::Column, metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, }; -use crate::physical_plan::{hash_utils::create_hashes, Statistics}; use crate::physical_plan::{ - DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, + DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, }; -use arrow::array::growable::GrowablePrimitive; use arrow::compute::partition::lexicographical_partition_ranges; -use arrow::compute::sort::SortOptions; use std::cmp::Ordering; use std::collections::VecDeque; use std::ops::Range; -use std::pin::Pin; -use std::task::{Context, Poll}; use tokio::sync::mpsc::{Receiver, Sender}; type StringArray = Utf8Array; @@ -84,8 +73,10 @@ impl PartitionedRecordBatch { match batch { Some(batch) => { let columns = exprs_to_sort_columns(&batch, expr)?; - let ranges = - lexicographical_partition_ranges(&columns)?.collect::>(); + let ranges = lexicographical_partition_ranges( + &columns.iter().map(|x| x.into()).collect::>(), + )? + .collect::>(); Ok(Some(Self { batch, ranges })) } None => Ok(None), @@ -103,7 +94,7 @@ struct StreamingBatch { cur_row: usize, cur_range: usize, num_rows: usize, - num_ranges: uszie, + num_ranges: usize, is_new_key: bool, on_column: Vec, sort: Vec, @@ -419,10 +410,10 @@ fn make_batch( } macro_rules! 
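// [editor's note] repeat_n below duplicates a single streamed cell N times so
// one streamed row can be laid against N buffered matches. The rewrite in this
// commit threads $to/$from/$idx through as explicit metavariables because
// declarative macros are hygienic and cannot capture identifiers from the call
// site. Expansion sketch for an Int32 column:
//
//     let to = to.as_mut_any().downcast_mut::<Int32Vec>().unwrap();
//     let from = from.as_any().downcast_ref::<Int32Array>().unwrap();
//     to.extend_trusted_len(
//         from.slice(idx, 1).iter().flat_map(|v| repeat(v).take(n)));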
repeat_n { - ($TO:ty, $FROM:ty, $N:expr) => {{ - let to = to.as_mut_any().downcast_mut::<$TO>().unwrap(); - let from = from.as_any().downcast_ref::<$FROM>().unwrap(); - let repeat_iter = from.slice(idx, 1).iter().flat_map(|v| repeat(v).take($N)); + ($TO:ty, $FROM:ty, $N:expr, $to: ident, $from: ident, $idx: ident) => {{ + let to = $to.as_mut_any().downcast_mut::<$TO>().unwrap(); + let from = $from.as_any().downcast_ref::<$FROM>().unwrap(); + let repeat_iter = $from.slice($idx, 1).iter().flat_map(|v| repeat(v).take($N)); to.extend_trusted_len(repeat_iter); }}; } @@ -438,36 +429,54 @@ fn repeat_streamed_cell( let from = stream_batch.column(column_index.index); match to.data_type().to_physical_type() { PhysicalType::Boolean => { - repeat_n!(MutableBooleanArray, BooleanArray, times) + repeat_n!(MutableBooleanArray, BooleanArray, times, to, from, idx) } PhysicalType::Primitive(primitive) => match primitive { - PrimitiveType::Int8 => repeat_n!(Int8Vec, Int8Array, times), - PrimitiveType::Int16 => repeat_n!(Int16Vec, Int16Array, times), - PrimitiveType::Int32 => repeat_n!(Int32Vec, Int32Array, times), - PrimitiveType::Int64 => repeat_n!(Int64Vec, Int64Array, times), - PrimitiveType::Float32 => repeat_n!(Float32Vec, Float32Array, times), - PrimitiveType::Float64 => repeat_n!(Float64Vec, Float64Array, times), + PrimitiveType::Int8 => repeat_n!(Int8Vec, Int8Array, times, to, from, idx), + PrimitiveType::Int16 => repeat_n!(Int16Vec, Int16Array, times, to, from, idx), + PrimitiveType::Int32 => repeat_n!(Int32Vec, Int32Array, times, to, from, idx), + PrimitiveType::Int64 => repeat_n!(Int64Vec, Int64Array, times, to, from, idx), + PrimitiveType::Float32 => { + repeat_n!(Float32Vec, Float32Array, times, to, from, idx) + } + PrimitiveType::Float64 => { + repeat_n!(Float64Vec, Float64Array, times, to, from, idx) + } _ => todo!(), }, PhysicalType::Utf8 => { - repeat_n!(MutableUtf8Array, Utf8Array, times) + repeat_n!(MutableUtf8Array, Utf8Array, times, to, from, idx) } PhysicalType::Binary => { - repeat_n!(MutableBinaryArray, BinaryArray, times) + repeat_n!( + MutableBinaryArray, + BinaryArray, + times, + to, + from, + idx + ) } PhysicalType::FixedSizeBinary => { - repeat_n!(MutableFixedSizeBinaryArray, FixedSizeBinaryArray, times) + repeat_n!( + MutableFixedSizeBinaryArray, + FixedSizeBinaryArray, + times, + to, + from, + idx + ) } _ => todo!(), } } macro_rules! 
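// [editor's note] Both repeat_streamed_cell above and copy_slices below
// dispatch on PhysicalType rather than DataType, so logical types that share a
// layout (Date32 and Int32, for instance) can reuse one arm; arms not listed
// fall through to todo!(), which panics at runtime instead of returning
// DataFusionError::Internal and is worth tightening before this leaves draft.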
copy_slices { - ($TO:ty, $FROM:ty) => {{ - let to = array.as_mut_any().downcast_mut::<$TO>().unwrap(); - for pos in slices { - let from = batches[pos.batch_idx] - .column(column_index.index) + ($TO:ty, $FROM:ty, $array: ident, $batches: ident, $slices: ident, $column_index: ident) => {{ + let to = $array.as_mut_any().downcast_mut::<$TO>().unwrap(); + for pos in $slices { + let from = $batches[pos.batch_idx] + .column($column_index.index) .slice(pos.start_idx, pos.len); let from = from.as_any().downcast_ref::<$FROM>().unwrap(); to.extend_trusted_len(from.iter()); @@ -482,27 +491,77 @@ fn copy_slices( column_index: &ColumnIndex, ) { // output buffered start `buffered_idx`, len `rows_to_output` - match to.data_type().to_physical_type() { + match array.data_type().to_physical_type() { PhysicalType::Boolean => { - copy_slices!(MutableBooleanArray, BooleanArray) + copy_slices!( + MutableBooleanArray, + BooleanArray, + array, + batches, + slices, + column_index + ) } PhysicalType::Primitive(primitive) => match primitive { - PrimitiveType::Int8 => copy_slices!(Int8Vec, Int8Array), - PrimitiveType::Int16 => copy_slices!(Int16Vec, Int16Array), - PrimitiveType::Int32 => copy_slices!(Int32Vec, Int32Array), - PrimitiveType::Int64 => copy_slices!(Int64Vec, Int64Array), - PrimitiveType::Float32 => copy_slices!(Float32Vec, Float32Array), - PrimitiveType::Float64 => copy_slices!(Float64Vec, Float64Array), + PrimitiveType::Int8 => { + copy_slices!(Int8Vec, Int8Array, array, batches, slices, column_index) + } + PrimitiveType::Int16 => { + copy_slices!(Int16Vec, Int16Array, array, batches, slices, column_index) + } + PrimitiveType::Int32 => { + copy_slices!(Int32Vec, Int32Array, array, batches, slices, column_index) + } + PrimitiveType::Int64 => { + copy_slices!(Int64Vec, Int64Array, array, batches, slices, column_index) + } + PrimitiveType::Float32 => copy_slices!( + Float32Vec, + Float32Array, + array, + batches, + slices, + column_index + ), + PrimitiveType::Float64 => copy_slices!( + Float64Vec, + Float64Array, + array, + batches, + slices, + column_index + ), _ => todo!(), }, PhysicalType::Utf8 => { - copy_slices!(MutableUtf8Array, Utf8Array) + copy_slices!( + MutableUtf8Array, + Utf8Array, + array, + batches, + slices, + column_index + ) } PhysicalType::Binary => { - copy_slices!(MutableBinaryArray, BinaryArray) + copy_slices!( + MutableBinaryArray, + BinaryArray, + array, + batches, + slices, + column_index + ) } PhysicalType::FixedSizeBinary => { - copy_slices!(MutableFixedSizeBinaryArray, FixedSizeBinaryArray) + copy_slices!( + MutableFixedSizeBinaryArray, + FixedSizeBinaryArray, + array, + batches, + slices, + column_index + ) } _ => todo!(), } @@ -607,7 +666,7 @@ impl SortMergeJoinDriver { let mut output_slots_available = target_batch_size; let mut output_arrays = new_arrays(&self.schema, target_batch_size)?; - while self.find_next_inner_match()? { + while self.find_next_inner_match().await? { loop { let result = self .join_eq_records( @@ -639,9 +698,15 @@ impl SortMergeJoinDriver { let mut output_slots_available = target_batch_size; let mut output_arrays = new_arrays(&self.schema, target_batch_size)?; - while self.find_next_inner_match()? { - let result = self.stream_copy_buffer_omit( - target_batch_size, output_slots_available, output_arrays, sender).await?; + while self.find_next_inner_match().await? 
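// [editor's note] find_next_inner_match (async since this commit, because
// skipping keys may need to await fresh input batches) first discards
// null-key rows on both sides (SQL equality never matches NULL), then
// repeatedly advances whichever side compares lower until comp_rows reports
// Equal or either input is exhausted.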
{ + let result = self + .stream_copy_buffer_omit( + target_batch_size, + output_slots_available, + output_arrays, + sender, + ) + .await?; output_slots_available = result.0; output_arrays = result.1; } @@ -758,10 +823,10 @@ impl SortMergeJoinDriver { loop { if advance_buffer { - buffer_ends = !self.advance_buffered_key()?; + buffer_ends = !self.advance_buffered_key().await?; } if advance_stream { - stream_ends = !self.advance_streamed_key()?; + stream_ends = !self.advance_streamed_key().await?; } if stream_ends && buffer_ends { @@ -1172,16 +1237,16 @@ impl SortMergeJoinDriver { Ok((output_slots_available, output_arrays)) } - fn find_next_inner_match(&mut self) -> Result { + async fn find_next_inner_match(&mut self) -> Result { if self.stream_batch.key_any_null() { - let more_stream = self.advance_streamed_key_null_free()?; + let more_stream = self.advance_streamed_key_null_free().await?; if !more_stream { return Ok(false); } } if self.buffered_batches.key_any_null() { - let more_buffer = self.advance_buffered_key_null_free()?; + let more_buffer = self.advance_buffered_key_null_free().await?; if !more_buffer { return Ok(false); } @@ -1191,14 +1256,14 @@ impl SortMergeJoinDriver { let current_cmp = self.compare_stream_buffer()?; match current_cmp { Ordering::Less => { - let more_stream = self.advance_streamed_key_null_free()?; + let more_stream = self.advance_streamed_key_null_free().await?; if !more_stream { return Ok(false); } } Ordering::Equal => return Ok(true), Ordering::Greater => { - let more_buffer = self.advance_buffered_key_null_free()?; + let more_buffer = self.advance_buffered_key_null_free().await?; if !more_buffer { return Ok(false); } @@ -1276,9 +1341,9 @@ impl SortMergeJoinDriver { } /// true for has next, false for ended - fn advance_streamed(&mut self) -> Result { + async fn advance_streamed(&mut self) -> Result { if self.stream_batch.is_finished() { - self.get_stream_next()?; + self.get_stream_next().await?; Ok(!self.stream_batch.is_finished()) } else { self.stream_batch.advance(); @@ -1287,9 +1352,9 @@ impl SortMergeJoinDriver { } /// true for has next, false for ended - fn advance_streamed_key(&mut self) -> Result { + async fn advance_streamed_key(&mut self) -> Result { if self.stream_batch.is_finished() || self.stream_batch.is_last_key_in_batch() { - self.get_stream_next()?; + self.get_stream_next().await?; Ok(!self.stream_batch.is_finished()) } else { self.stream_batch.advance_key(); @@ -1298,11 +1363,11 @@ impl SortMergeJoinDriver { } /// true for has next, false for ended - fn advance_streamed_key_null_free(&mut self) -> Result { - let mut more_stream_keys = self.advance_streamed_key()?; + async fn advance_streamed_key_null_free(&mut self) -> Result { + let mut more_stream_keys = self.advance_streamed_key().await?; loop { if more_stream_keys && self.stream_batch.key_any_null() { - more_stream_keys = self.advance_streamed_key()?; + more_stream_keys = self.advance_streamed_key().await?; } else { break; } @@ -1310,11 +1375,11 @@ impl SortMergeJoinDriver { Ok(more_stream_keys) } - fn advance_buffered_key_null_free(&mut self) -> Result { - let mut more_buffered_keys = self.advance_buffered_key()?; + async fn advance_buffered_key_null_free(&mut self) -> Result { + let mut more_buffered_keys = self.advance_buffered_key().await?; loop { if more_buffered_keys && self.buffered_batches.key_any_null() { - more_buffered_keys = self.advance_buffered_key()?; + more_buffered_keys = self.advance_buffered_key().await?; } else { break; } @@ -1323,11 +1388,11 @@ impl SortMergeJoinDriver 
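// [editor's note] Everything that can pull from an input stream becomes async
// in this commit (advance_streamed*, advance_buffered*, cumulate_same_keys,
// get_stream_next, get_buffered_next): a key group may straddle batch
// boundaries, so advancing a cursor can require awaiting the next RecordBatch
// from the underlying SendableRecordBatchStream.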
{ } /// true for has next, false for ended - fn advance_buffered_key(&mut self) -> Result { + async fn advance_buffered_key(&mut self) -> Result { if self.buffered_batches.is_finished() { match &self.buffered_batches.next_key_batch { None => { - let batch = self.get_buffered_next()?; + let batch = self.get_buffered_next().await?; match batch { None => return Ok(false), Some(batch) => { @@ -1350,15 +1415,15 @@ impl SortMergeJoinDriver { if self.buffered_batches.batches[0] .is_last_range(&self.buffered_batches.ranges[0]) { - self.cumulate_same_keys()?; + self.cumulate_same_keys().await?; } } Ok(false) } /// true for has next, false for buffer side ended - fn cumulate_same_keys(&mut self) -> Result { - let batch = self.get_buffered_next()?; + async fn cumulate_same_keys(&mut self) -> Result { + let batch = self.get_buffered_next().await?; match batch { None => Ok(false), Some(batch) => { @@ -1386,14 +1451,14 @@ impl SortMergeJoinDriver { ) } - fn get_stream_next(&mut self) -> Result<()> { + async fn get_stream_next(&mut self) -> Result<()> { let batch = self.streamed.next().await.transpose()?; let prb = PartitionedRecordBatch::new(batch, &self.stream_batch.sort)?; self.stream_batch.rest_batch(prb); Ok(()) } - fn get_buffered_next(&mut self) -> Result> { + async fn get_buffered_next(&mut self) -> Result> { let batch = self.buffered.next().await.transpose()?; PartitionedRecordBatch::new(batch, &self.buffered_batches.sort) } @@ -1462,14 +1527,6 @@ impl SortMergeJoinMetrics { } } -/// Information about the index and placement (left or right) of the columns -struct ColumnIndex { - /// Index of the column - index: usize, - /// Whether the column is at the left or right side - is_left: bool, -} - impl SortMergeJoinExec { /// Tries to create a new [SortMergeJoinExec]. 
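    /// [editor's note] Both inputs are expected to arrive already sorted on
    /// the join keys; execute() recovers the sort expressions by downcasting
    /// each child to ExternalSortExec, and `on` must satisfy
    /// check_join_is_valid against the two input schemas.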
/// # Error @@ -1627,7 +1684,7 @@ impl ExecutionPlan for SortMergeJoinExec { JoinType::Anti => driver.anti_join_driver(&tx).await?, } - let result = RecordBatchReceiverStream::create(&schema, rx); + let result = RecordBatchReceiverStream::create(&self.schema, rx); Ok(Box::pin(result)) } diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index d5579c1ebb24..3088d787d4dd 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -47,7 +47,7 @@ use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::udf; use crate::physical_plan::windows::WindowAggExec; -use crate::physical_plan::{hash_utils, joins, Partitioning}; +use crate::physical_plan::{joins, Partitioning}; use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr}; use crate::scalar::ScalarValue; use crate::sql::utils::{generate_sort_key, window_expr_common_partition_keys}; diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs index 5e4c33c39472..41613920bfc5 100644 --- a/datafusion/src/physical_plan/sorts/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -26,7 +26,6 @@ use crate::physical_plan::metrics::{ use crate::physical_plan::{ common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, Statistics, }; -use arrow::compute::sort::SortColumn; pub use arrow::compute::sort::SortOptions; use arrow::compute::{sort::lexsort_to_indices, take}; use arrow::datatypes::SchemaRef; @@ -193,7 +192,10 @@ pub fn sort_batch( expr: &[PhysicalSortExpr], ) -> ArrowResult { let columns = exprs_to_sort_columns(&batch, expr)?; - let indices = lexsort_to_indices::(&columns, None)?; + let indices = lexsort_to_indices::( + &columns.iter().map(|x| x.into()).collect::>(), + None, + )?; // reorder all rows based on sorted indices RecordBatch::try_new( From aaabc2c8abfbe76dd33826a69344dff39572538c Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 16 Nov 2021 21:07:47 +0800 Subject: [PATCH 12/15] v1 --- .../src/physical_plan/expressions/mod.rs | 5 ++- datafusion/src/physical_plan/joins/mod.rs | 2 +- .../physical_plan/joins/sort_merge_join.rs | 36 +++++++++---------- datafusion/src/physical_plan/sorts/sort.rs | 3 +- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 59d2a4a34e97..6607144657d8 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -150,9 +150,8 @@ pub fn exprs_to_sort_columns( let columns = expr .iter() .map(|e| e.evaluate_to_sort_column(&batch)) - .collect::>>() - .map_err(DataFusionError::into_arrow_external_error)?; - Ok(columns) + .collect::>>(); + columns } #[cfg(test)] diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs index aed2da6e4453..175cf8c88009 100644 --- a/datafusion/src/physical_plan/joins/mod.rs +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -214,7 +214,7 @@ macro_rules! 
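// [editor's note] In the comparator below, the earlier `?` on partial_cmp
// becomes unwrap(): `?` on an Option does not compile inside a function
// returning Result, and None is unreachable here anyway, since partial_cmp on
// the accepted key types (bool, integers, timestamps, strings) is total; only
// floats could produce None, and they are rejected by the catch-all
// Internal-error arm.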
cmp_rows_elem { (false, false) => { let cmp = left_array .value($left) - .partial_cmp(&right_array.value($right))?; + .partial_cmp(&right_array.value($right)).unwrap(); if cmp != Ordering::Equal { $res = cmp; break; diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index 05a3f5eec82a..ca72cd28757b 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -311,7 +311,7 @@ impl BufferedBatches { } fn join_arrays(rb: &RecordBatch, on_column: &Vec) -> Vec { - on_column.iter().map(|c| rb.column(c.index())).collect() + on_column.iter().map(|c| rb.column(c.index()).clone()).collect() } struct SortMergeJoinDriver { @@ -413,8 +413,8 @@ macro_rules! repeat_n { ($TO:ty, $FROM:ty, $N:expr, $to: ident, $from: ident, $idx: ident) => {{ let to = $to.as_mut_any().downcast_mut::<$TO>().unwrap(); let from = $from.as_any().downcast_ref::<$FROM>().unwrap(); - let repeat_iter = $from.slice($idx, 1).iter().flat_map(|v| repeat(v).take($N)); - to.extend_trusted_len(repeat_iter); + let repeat_iter = $from.slice($idx, 1).iter().flat_map(|v| repeat(v).take($N)).collect::>(); + to.extend_trusted_len(repeat_iter.into_iter()); }}; } @@ -575,8 +575,7 @@ fn range_start_indices(buffered_ranges: &VecDeque>) -> Vec { .for_each(|r| { start_indices.push(idx); idx += range_len(r); - }) - .collect(); + }); start_indices.push(usize::MAX); start_indices } @@ -729,7 +728,7 @@ impl SortMergeJoinDriver { get_match, buffered_ended, more_output, - } = self.find_next_outer(buffer_ends)?; + } = self.find_next_outer(buffer_ends).await?; if !more_output { break; } @@ -1080,7 +1079,7 @@ impl SortMergeJoinDriver { if column_index.is_left { copy_slices(&batch, &slice, array, column_index); } else { - (0..rows_to_output).for_each(array.push_null()); + (0..rows_to_output).for_each(|_| array.push_null()); } }); @@ -1213,7 +1212,7 @@ impl SortMergeJoinDriver { .zip(self.column_indices.iter()) .map(|((array, field), column_index)| { if column_index.is_left { - (0..rows_to_output).for_each(array.push_null()); + (0..rows_to_output).for_each(|_| array.push_null()); } else { // copy buffered start from: `buffered_idx`, len: `rows_to_output` copy_slices(&batches, &slices, array, column_index); @@ -1272,8 +1271,8 @@ impl SortMergeJoinDriver { } } - fn find_next_outer(&mut self, buffer_ends: bool) -> Result { - let more_stream = self.advance_streamed_key()?; + async fn find_next_outer(&mut self, buffer_ends: bool) -> Result { + let more_stream = self.advance_streamed_key().await?; if buffer_ends { return Ok(OuterMatchResult { get_match: false, @@ -1290,7 +1289,7 @@ impl SortMergeJoinDriver { } if self.buffered_batches.key_any_null() { - let more_buffer = self.advance_buffered_key_null_free()?; + let more_buffer = self.advance_buffered_key_null_free().await?; if !more_buffer { return Ok(OuterMatchResult { get_match: false, @@ -1326,7 +1325,7 @@ impl SortMergeJoinDriver { }) } Ordering::Greater => { - let more_buffer = self.advance_buffered_key_null_free()?; + let more_buffer = self.advance_buffered_key_null_free().await?; if !more_buffer { return Ok(OuterMatchResult { get_match: false, @@ -1389,7 +1388,7 @@ impl SortMergeJoinDriver { /// true for has next, false for ended async fn advance_buffered_key(&mut self) -> Result { - if self.buffered_batches.is_finished() { + if self.buffered_batches.is_finished()? 
{ match &self.buffered_batches.next_key_batch { None => { let batch = self.get_buffered_next().await?; @@ -1397,8 +1396,8 @@ impl SortMergeJoinDriver { None => return Ok(false), Some(batch) => { self.buffered_batches.reset_batch(&batch); - if &batch.ranges.len() == 1 { - self.cumulate_same_keys()?; + if batch.ranges.len() == 1 { + self.cumulate_same_keys().await?; } } } @@ -1406,7 +1405,7 @@ impl SortMergeJoinDriver { Some(batch) => { self.buffered_batches.reset_batch(batch); if batch.ranges.len() == 1 { - self.cumulate_same_keys()?; + self.cumulate_same_keys().await?; } } } @@ -1429,7 +1428,7 @@ impl SortMergeJoinDriver { Some(batch) => { let more_batches = self.buffered_batches.running_key(&batch)?; if more_batches { - self.cumulate_same_keys() + self.cumulate_same_keys().await } else { // reach end of current key, but the stream continues Ok(true) @@ -1686,7 +1685,7 @@ impl ExecutionPlan for SortMergeJoinExec { let result = RecordBatchReceiverStream::create(&self.schema, rx); - Ok(Box::pin(result)) + Ok(result) } fn fmt_as( @@ -1730,6 +1729,7 @@ mod tests { }; use super::*; + use crate::physical_plan::PhysicalExpr; fn build_table( a: (&str, &Vec), diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs index 41613920bfc5..48c72fb0026d 100644 --- a/datafusion/src/physical_plan/sorts/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ -191,7 +191,8 @@ pub fn sort_batch( schema: SchemaRef, expr: &[PhysicalSortExpr], ) -> ArrowResult { - let columns = exprs_to_sort_columns(&batch, expr)?; + let columns = exprs_to_sort_columns(&batch, expr) + .map_err(DataFusionError::into_arrow_external_error)?; let indices = lexsort_to_indices::( &columns.iter().map(|x| x.into()).collect::>(), None, From faaa7f5868349a4087763bd598d7344cd4951abb Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 17 Nov 2021 14:18:25 +0800 Subject: [PATCH 13/15] minor --- datafusion/src/physical_plan/joins/mod.rs | 28 ++++- .../physical_plan/joins/sort_merge_join.rs | 106 +++++++++--------- 2 files changed, 78 insertions(+), 56 deletions(-) diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs index 175cf8c88009..c448205a182d 100644 --- a/datafusion/src/physical_plan/joins/mod.rs +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -214,7 +214,29 @@ macro_rules! cmp_rows_elem { (false, false) => { let cmp = left_array .value($left) - .partial_cmp(&right_array.value($right)).unwrap(); + .partial_cmp(&right_array.value($right)) + .unwrap(); + if cmp != Ordering::Equal { + $res = cmp; + break; + } + } + _ => unreachable!(), + } + }}; +} + +macro_rules! 
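// [editor's note] Strings get the dedicated cmp_rows_elem_str arm below
// because Utf8Array::value returns a &str reference rather than a Copy
// scalar: the string variant passes that reference to partial_cmp directly,
// where the generic arm takes a second borrow of the right-hand value, and
// the two spellings are not interchangeable across every key type the
// comparator accepts.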
cmp_rows_elem_str { + ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $res: ident) => {{ + let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap(); + let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap(); + + match (left_array.is_null($left), right_array.is_null($right)) { + (false, false) => { + let cmp = left_array + .value($left) + .partial_cmp(right_array.value($right)) + .unwrap(); if cmp != Ordering::Equal { $res = cmp; break; @@ -248,9 +270,9 @@ fn comp_rows( DataType::Timestamp(_, None) => { cmp_rows_elem!(Int64Array, l, r, left, right, res) } - DataType::Utf8 => cmp_rows_elem!(StringArray, l, r, left, right, res), + DataType::Utf8 => cmp_rows_elem_str!(StringArray, l, r, left, right, res), DataType::LargeUtf8 => { - cmp_rows_elem!(LargeStringArray, l, r, left, right, res) + cmp_rows_elem_str!(LargeStringArray, l, r, left, right, res) } _ => { // This is internal because we should have caught this before. diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index ca72cd28757b..28ee9dd5025f 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -311,7 +311,50 @@ impl BufferedBatches { } fn join_arrays(rb: &RecordBatch, on_column: &Vec) -> Vec { - on_column.iter().map(|c| rb.column(c.index()).clone()).collect() + on_column + .iter() + .map(|c| rb.column(c.index()).clone()) + .collect() +} + +struct OutputBuffer { + arrays: Vec>, + target_batch_size: usize, + slots_available: usize, + schema: Arc, +} + +impl OutputBuffer { + fn new(target_batch_size: usize, schema: Arc) -> Result { + let arrays = new_arrays(&schema, target_batch_size)?; + Ok(Self { + arrays, + target_batch_size, + slots_available: target_batch_size, + schema, + }) + } + + fn output_and_reset(&mut self) -> Option> { + if self.is_full() { + let result = make_batch(self.schema.clone(), output_arrays); + self.arrays = new_arrays(&self.schema, self.target_batch_size)?; + self.slots_available = self.target_batch_size; + Some(result) + } else { + None + } + } + + fn append(&mut self, size: usize) { + assert!(size <= self.slots_available); + self.slots_available -= size; + } + + #[inline] + fn is_full(&self) -> bool { + self.slots_available == 0 + } } struct SortMergeJoinDriver { @@ -413,7 +456,11 @@ macro_rules! 
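// [editor's note] OutputBuffer above starts factoring the repeated
// (arrays, slots_available, flush) triple out of the drivers, but as
// committed it is still wip: output_and_reset names an output_arrays local
// that does not exist (it should consume self.arrays) and applies `?` inside
// a function returning Option. A compiling shape, sketched under those
// assumptions, would be roughly:
//
//     fn output_and_reset(&mut self) -> Option<ArrowResult<RecordBatch>> {
//         if !self.is_full() { return None; }
//         let fresh = new_arrays(&self.schema, self.target_batch_size).ok()?;
//         let full = std::mem::replace(&mut self.arrays, fresh);
//         self.slots_available = self.target_batch_size;
//         Some(make_batch(self.schema.clone(), full))
//     }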
repeat_n { ($TO:ty, $FROM:ty, $N:expr, $to: ident, $from: ident, $idx: ident) => {{ let to = $to.as_mut_any().downcast_mut::<$TO>().unwrap(); let from = $from.as_any().downcast_ref::<$FROM>().unwrap(); - let repeat_iter = $from.slice($idx, 1).iter().flat_map(|v| repeat(v).take($N)).collect::>(); + let repeat_iter = from + .slice($idx, 1) + .iter() + .flat_map(|v| repeat(v).take($N)) + .collect::>(); to.extend_trusted_len(repeat_iter.into_iter()); }}; } @@ -447,26 +494,6 @@ fn repeat_streamed_cell( PhysicalType::Utf8 => { repeat_n!(MutableUtf8Array, Utf8Array, times, to, from, idx) } - PhysicalType::Binary => { - repeat_n!( - MutableBinaryArray, - BinaryArray, - times, - to, - from, - idx - ) - } - PhysicalType::FixedSizeBinary => { - repeat_n!( - MutableFixedSizeBinaryArray, - FixedSizeBinaryArray, - times, - to, - from, - idx - ) - } _ => todo!(), } } @@ -543,26 +570,6 @@ fn copy_slices( column_index ) } - PhysicalType::Binary => { - copy_slices!( - MutableBinaryArray, - BinaryArray, - array, - batches, - slices, - column_index - ) - } - PhysicalType::FixedSizeBinary => { - copy_slices!( - MutableFixedSizeBinaryArray, - FixedSizeBinaryArray, - array, - batches, - slices, - column_index - ) - } _ => todo!(), } } @@ -570,12 +577,10 @@ fn copy_slices( fn range_start_indices(buffered_ranges: &VecDeque>) -> Vec { let mut idx = 0; let mut start_indices: Vec = vec![]; - buffered_ranges - .iter() - .for_each(|r| { - start_indices.push(idx); - idx += range_len(r); - }); + buffered_ranges.iter().for_each(|r| { + start_indices.push(idx); + idx += range_len(r); + }); start_indices.push(usize::MAX); start_indices } @@ -660,11 +665,6 @@ impl SortMergeJoinDriver { &mut self, sender: &Sender>, ) -> Result<()> { - let target_batch_size = self.runtime.batch_size(); - - let mut output_slots_available = target_batch_size; - let mut output_arrays = new_arrays(&self.schema, target_batch_size)?; - while self.find_next_inner_match().await? { loop { let result = self From e41224c75e62bcc5cfaf710d378472c80de156dc Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 17 Nov 2021 19:25:26 +0800 Subject: [PATCH 14/15] v2 wip --- datafusion/src/physical_plan/joins/mod.rs | 1 + .../src/physical_plan/joins/smj_utils.rs | 553 ++++++++++++++++++ .../physical_plan/joins/sort_merge_join.rs | 369 +----------- 3 files changed, 555 insertions(+), 368 deletions(-) create mode 100644 datafusion/src/physical_plan/joins/smj_utils.rs diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs index c448205a182d..73e0991dfd90 100644 --- a/datafusion/src/physical_plan/joins/mod.rs +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +mod smj_utils; pub mod cross_join; pub mod hash_join; pub mod sort_merge_join; diff --git a/datafusion/src/physical_plan/joins/smj_utils.rs b/datafusion/src/physical_plan/joins/smj_utils.rs new file mode 100644 index 000000000000..ad167e853820 --- /dev/null +++ b/datafusion/src/physical_plan/joins/smj_utils.rs @@ -0,0 +1,553 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::DataFusionError; +use crate::physical_plan::expressions::{ + exprs_to_sort_columns, Column, PhysicalSortExpr, +}; +use crate::physical_plan::joins::equal_rows; +use crate::physical_plan::SendableRecordBatchStream; +use arrow::array::ArrayRef; +use arrow::compute::partition::lexicographical_partition_ranges; +use arrow::error::ArrowError; +use arrow::error::Result as ArrowResult; +use arrow::record_batch::RecordBatch; +use futures::StreamExt; +use std::collections::VecDeque; +use std::ops::Range; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub(crate) fn join_arrays(rb: &RecordBatch, on_column: &Vec) -> Vec { + on_column + .iter() + .map(|c| rb.column(c.index()).clone()) + .collect() +} + +#[derive(Clone)] +struct PartitionedRecordBatch { + batch: RecordBatch, + ranges: Vec>, +} + +impl PartitionedRecordBatch { + fn new( + batch: Option, + expr: &[PhysicalSortExpr], + ) -> ArrowResult> { + match batch { + Some(batch) => { + let columns = exprs_to_sort_columns(&batch, expr) + .map_err(DataFusionError::into_arrow_external_error)?; + let ranges = lexicographical_partition_ranges( + &columns.iter().map(|x| x.into()).collect::>(), + )? + .collect::>(); + Ok(Some(Self { batch, ranges })) + } + None => Ok(None), + } + } + + #[inline] + pub fn is_last_range(&self, range: &Range) -> bool { + range.end == self.batch.num_rows() + } +} + +struct StreamingSideBuffer { + batch: Option, + cur_row: usize, + cur_range: usize, + num_rows: usize, + num_ranges: usize, + is_new_key: bool, + on_column: Vec, + sort: Vec, +} + +impl StreamingSideBuffer { + fn new(on_column: Vec, sort: Vec) -> Self { + Self { + batch: None, + cur_row: 0, + cur_range: 0, + num_rows: 0, + num_ranges: 0, + is_new_key: true, + on_column, + sort, + } + } + + fn reset(&mut self, prb: Option) { + self.batch = prb; + if let Some(prb) = &self.batch { + self.cur_row = 0; + self.cur_range = 0; + self.num_rows = prb.batch.num_rows(); + self.num_ranges = prb.ranges.len(); + self.is_new_key = true; + }; + } + + fn key_any_null(&self) -> bool { + match &self.batch { + None => return true, + Some(batch) => { + for c in self.on_column { + let array = batch.batch.column(c.index()); + if array.is_null(self.cur_row) { + return true; + } + } + false + } + } + } + + #[inline] + fn is_finished(&self) -> bool { + self.batch.is_none() || self.num_rows == self.cur_row + 1 + } + + #[inline] + fn is_last_key_in_batch(&self) -> bool { + self.batch.is_none() || self.num_ranges == self.cur_range + 1 + } + + fn advance(&mut self) { + self.cur_row += 1; + self.is_new_key = false; + if !self.is_last_key_in_batch() { + let ranges = self.batch.unwrap().ranges; + if self.cur_row == ranges[self.cur_range + 1].start { + self.cur_range += 1; + self.is_new_key = true; + } + } + } + + fn advance_key(&mut self) { + let ranges = self.batch.unwrap().ranges; + self.cur_range += 1; + self.cur_row = ranges[self.cur_range].start; + self.is_new_key = true; + } +} + +struct StreamingSideStream { + input: SendableRecordBatchStream, + output: StreamingSideBuffer, + input_is_finished: bool, + sort: Vec, +} + +impl StreamingSideStream { 
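// Annotation: `lexicographical_partition_ranges` is what turns an already
// sorted batch into the `ranges` of `PartitionedRecordBatch` above: one
// `Range<usize>` per run of equal sort-key values. Illustrative behaviour
// (hypothetical data):
//
//     sort column: [1, 1, 2, 2, 2, 5]
//     ranges:      [0..2, 2..5, 5..6]
//
// which is roughly:
//
//     let ranges: Vec<std::ops::Range<usize>> =
//         lexicographical_partition_ranges(&sort_columns)?.collect();
//
// `is_last_range` then reduces to `range.end == batch.num_rows()`: a run that
// touches the end of its batch may continue into the next batch, which is
// exactly the case the buffering side has to accumulate across batches.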
+ fn inner_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.input_is_finished { + Poll::Ready(None) + } else { + match self.input.poll_next_unpin(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + None => { + self.input_is_finished = true; + Poll::Ready(None) + } + batch => { + let batch = batch.transpose()?; + let prb = PartitionedRecordBatch::new(batch, &self.sort)?; + self.output.reset(prb); + Poll::Ready(Some(Ok(()))) + } + }, + } + } + } + + fn advance( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.output.is_finished() && self.input_is_finished { + Poll::Ready(None) + } else { + if self.output.is_finished() { + match self.inner_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + None => Poll::Ready(None), + Some(x) => { + x?; + if self.output.is_finished() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(()))) + } + } + }, + } + } else { + self.output.advance(); + Poll::Ready(Some(Ok(()))) + } + } + } + + fn advance_key( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.output.is_finished() && self.input_is_finished { + Poll::Ready(None) + } else { + if self.output.is_finished() || self.output.is_last_key_in_batch() { + match self.inner_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + None => Poll::Ready(None), + Some(x) => { + x?; + if self.output.is_finished() { + Poll::Ready(None) + } else { + Poll::Ready(Some(Ok(()))) + } + } + }, + } + } else { + self.output.advance_key(); + Poll::Ready(Some(Ok(()))) + } + } + } + + fn advance_key_skip_null( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.output.is_finished() && self.input_is_finished { + Poll::Ready(None) + } else { + loop { + match self.advance_key(cx) { + Poll::Ready(x) => match x { + None => return Poll::Ready(None), + Some(x) => { + x?; + if !self.output.key_any_null() { + return Poll::Ready(Some(Ok(()))); + } + } + }, + Poll::Pending => return Poll::Pending, + } + } + } + } +} + +/// Holding ranges for same key over several bathes +struct BufferingSideBuffer { + /// batches that contains the current key + /// TODO: make this spillable as well for skew on join key at buffer side + batches: VecDeque, + /// ranges in each PartitionedRecordBatch that contains the current key + ranges: VecDeque>, + /// row index in first batch to the record that starts this batch + key_idx: Option, + /// total number of rows for the current key + row_num: usize, + /// hold found but not currently used batch, to continue iteration + next_key_batch: Option, + /// Join on column + on_column: Vec, + sort: Vec, +} + +#[inline] +pub(crate) fn range_len(range: &Range) -> usize { + range.end - range.start +} + +impl BufferingSideBuffer { + fn new(on_column: Vec, sort: Vec) -> Self { + Self { + batches: VecDeque::new(), + ranges: VecDeque::new(), + key_idx: None, + row_num: 0, + next_key_batch: None, + on_column, + sort, + } + } + + fn key_any_null(&self) -> bool { + match &self.key_idx { + None => return true, + Some(key_idx) => { + let first_batch = &self.batches[0].batch; + for c in self.on_column { + let array = first_batch.column(c.index()); + if array.is_null(*key_idx) { + return true; + } + } + false + } + } + } + + fn is_finished(&self) -> ArrowResult { + match self.key_idx { + None => Ok(true), + Some(_) => match (self.batches.back(), self.ranges.back()) { + (Some(batch), Some(range)) => Ok(batch.is_last_range(range)), + _ => Err(ArrowError::Other(format!( + 
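// Annotation: the `Poll<Option<ArrowResult<()>>>` returned by the `advance*`
// methods above is a hand-rolled tri-state: `Pending` means the upstream
// stream is not ready, `Ready(None)` means this side is exhausted, and
// `Ready(Some(Ok(())))` means the cursor moved one step. A hedged sketch of
// how a caller is expected to consume it (names illustrative):
//
//     match Pin::new(&mut streamed).advance(cx) {
//         Poll::Pending => return Poll::Pending,      // bubble readiness up
//         Poll::Ready(None) => { /* streamed side fully drained */ }
//         Poll::Ready(Some(res)) => { res?; /* cursor advanced; inspect */ }
//     }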
"Batches length {} not equal to ranges length {}", + self.batches.len(), + self.ranges.len() + ))), + }, + } + } + + /// Whether the running key ends at the current batch `prb`, true for continues, false for ends. + fn running_key(&mut self, prb: &PartitionedRecordBatch) -> ArrowResult { + let first_range = &prb.ranges[0]; + let range_len = range_len(first_range); + let current_batch = &prb.batch; + let single_range = prb.ranges.len() == 1; + + // compare the first record in batch with the current key pointed by key_idx + match self.key_idx { + None => { + self.batches.push_back(prb.clone()); + self.ranges.push_back(first_range.clone()); + self.key_idx = Some(0); + self.row_num += range_len; + Ok(single_range) + } + Some(key_idx) => { + let key_arrays = join_arrays(&self.batches[0].batch, &self.on_column); + let current_arrays = join_arrays(current_batch, &self.on_column); + let equal = equal_rows(key_idx, 0, &key_arrays, ¤t_arrays) + .map_err(DataFusionError::into_arrow_external_error)?; + if equal { + self.batches.push_back(prb.clone()); + self.ranges.push_back(first_range.clone()); + self.row_num += range_len; + Ok(single_range) + } else { + self.next_key_batch = Some(prb.clone()); + Ok(false) // running key ends + } + } + } + } + + fn cleanup(&mut self) { + self.batches.drain(..); + self.ranges.drain(..); + self.next_key_batch = None; + } + + fn reset(&mut self, prb: &PartitionedRecordBatch) { + self.cleanup(); + self.batches.push_back(prb.clone()); + let first_range = &prb.ranges[0]; + self.ranges.push_back(first_range.clone()); + self.key_idx = Some(0); + self.row_num = range_len(first_range); + } + + /// Advance the cursor to the next key seen by this buffer + fn advance_key(&mut self) { + assert_eq!(self.batches.len(), self.ranges.len()); + if self.batches.len() > 1 { + self.batches.drain(0..(self.batches.len() - 1)); + self.ranges.drain(0..(self.batches.len() - 1)); + } + + if let Some(batch) = self.batches.pop_back() { + let tail_range = self.ranges.pop_back().unwrap(); + self.batches.push_back(batch); + let next_range_idx = batch + .ranges + .iter() + .enumerate() + .find(|(idx, range)| range.start == tail_range.start) + .unwrap() + .0; + self.key_idx = Some(tail_range.end); + self.ranges.push_back(batch.ranges[next_range_idx].clone()); + self.row_num = range_len(&batch.ranges[next_range_idx]); + } + } +} + +struct BufferingSideStream { + input: SendableRecordBatchStream, + output: BufferingSideBuffer, + input_is_finished: bool, + cumulating: bool, + sort: Vec, +} + +impl BufferingSideStream { + fn inner_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.input_is_finished { + Poll::Ready(None) + } else { + match self.input.poll_next_unpin(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + None => { + self.input_is_finished = true; + Poll::Ready(None) + } + batch => { + let batch = batch.transpose()?; + let prb = + PartitionedRecordBatch::new(batch, &self.sort).transpose(); + Poll::Ready(prb) + } + }, + } + } + } + + fn advance_key( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.output.is_finished()? && self.input_is_finished { + return Poll::Ready(None); + } else { + if self.cumulating { + match self.cumulate_same_keys(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(x) => match x { + Some(x) => { + x?; + return Poll::Ready(Some(Ok(()))); + } + None => unreachable!(), + }, + } + } + + if self.output.is_finished()? 
{ + return match &self.output.next_key_batch { + None => match self.inner_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(x) => match x { + None => Poll::Ready(None), + Some(x) => { + let prb = x?; + self.output.reset(&prb); + if prb.ranges.len() == 1 { + self.cumulating = true; + Poll::Pending + } else { + Poll::Ready(Some(Ok(()))) + } + } + }, + }, + Some(batch) => { + self.output.reset(batch); + if batch.ranges.len() == 1 { + self.cumulating = true; + Poll::Pending + } else { + Poll::Ready(Some(Ok(()))) + } + } + }; + } else { + self.output.advance_key(); + if self.output.batches[0].is_last_range(&self.output.ranges[0]) { + self.cumulating = true; + } else { + return Poll::Ready(Some(Ok(()))); + } + } + } + + unreachable!() + } + + fn advance_key_skip_null( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.output.is_finished()? && self.input_is_finished { + Poll::Ready(None) + } else { + loop { + match self.advance_key(cx) { + Poll::Ready(x) => match x { + None => return Poll::Ready(None), + Some(x) => { + x?; + if !self.output.key_any_null() { + return Poll::Ready(Some(Ok(()))); + } + } + }, + Poll::Pending => return Poll::Pending, + } + } + } + } + + fn cumulate_same_keys( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + match self.inner_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(x) => match x { + None => { + self.cumulating = false; + return Poll::Ready(Some(Ok(()))); + } + Some(x) => { + let prb = x?; + let buffer_more = self.output.running_key(&prb)?; + if !buffer_more { + self.cumulating = false; + return Poll::Ready(Some(Ok(()))); + } + } + }, + } + } + } +} diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index 28ee9dd5025f..82c26c1acce1 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -36,6 +36,7 @@ use crate::execution::runtime_env::RuntimeEnv; use crate::execution::runtime_env::RUNTIME_ENV; use crate::logical_plan::JoinType; use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr}; +use crate::physical_plan::joins::smj_utils::{join_arrays, range_len}; use crate::physical_plan::joins::{ build_join_schema, check_join_is_valid, column_indices_from_schema, comp_rows, equal_rows, ColumnIndex, JoinOn, @@ -59,264 +60,6 @@ use tokio::sync::mpsc::{Receiver, Sender}; type StringArray = Utf8Array; type LargeStringArray = Utf8Array; -#[derive(Clone)] -struct PartitionedRecordBatch { - batch: RecordBatch, - ranges: Vec>, -} - -impl PartitionedRecordBatch { - fn new( - batch: Option, - expr: &[PhysicalSortExpr], - ) -> Result> { - match batch { - Some(batch) => { - let columns = exprs_to_sort_columns(&batch, expr)?; - let ranges = lexicographical_partition_ranges( - &columns.iter().map(|x| x.into()).collect::>(), - )? 
- .collect::>(); - Ok(Some(Self { batch, ranges })) - } - None => Ok(None), - } - } - - #[inline] - fn is_last_range(&self, range: &Range) -> bool { - range.end == self.batch.num_rows() - } -} - -struct StreamingBatch { - batch: Option, - cur_row: usize, - cur_range: usize, - num_rows: usize, - num_ranges: usize, - is_new_key: bool, - on_column: Vec, - sort: Vec, -} - -impl StreamingBatch { - fn new(on_column: Vec, sort: Vec) -> Self { - Self { - batch: None, - cur_row: 0, - cur_range: 0, - num_rows: 0, - num_ranges: 0, - is_new_key: true, - on_column, - sort, - } - } - - fn rest_batch(&mut self, prb: Option) { - self.batch = prb; - if let Some(prb) = &self.batch { - self.cur_row = 0; - self.cur_range = 0; - self.num_rows = prb.batch.num_rows(); - self.num_ranges = prb.ranges.len(); - self.is_new_key = true; - }; - } - - fn key_any_null(&self) -> bool { - match &self.batch { - None => return true, - Some(batch) => { - for c in self.on_column { - let array = batch.batch.column(c.index()); - if array.is_null(self.cur_row) { - return true; - } - } - false - } - } - } - - #[inline] - fn is_finished(&self) -> bool { - self.batch.is_none() || self.num_rows == self.cur_row + 1 - } - - #[inline] - fn is_last_key_in_batch(&self) -> bool { - self.batch.is_none() || self.num_ranges == self.cur_range + 1 - } - - fn advance(&mut self) { - self.cur_row += 1; - self.is_new_key = false; - if !self.is_last_key_in_batch() { - let ranges = self.batch.unwrap().ranges; - if self.cur_row == ranges[self.cur_range + 1].start { - self.cur_range += 1; - self.is_new_key = true; - } - } else { - self.batch = None; - } - } - - fn advance_key(&mut self) { - let ranges = self.batch.unwrap().ranges; - self.cur_range += 1; - self.cur_row = ranges[self.cur_range].start; - self.is_new_key = true; - } -} - -/// Holding ranges for same key over several bathes -struct BufferedBatches { - /// batches that contains the current key - /// TODO: make this spillable as well for skew on join key at buffer side - batches: VecDeque, - /// ranges in each PartitionedRecordBatch that contains the current key - ranges: VecDeque>, - /// row index in first batch to the record that starts this batch - key_idx: Option, - /// total number of rows for the current key - row_num: usize, - /// hold found but not currently used batch, to continue iteration - next_key_batch: Option, - /// Join on column - on_column: Vec, - sort: Vec, -} - -#[inline] -fn range_len(range: &Range) -> usize { - range.end - range.start -} - -impl BufferedBatches { - fn new(on_column: Vec, sort: Vec) -> Self { - Self { - batches: VecDeque::new(), - ranges: VecDeque::new(), - key_idx: None, - row_num: 0, - next_key_batch: None, - on_column, - sort, - } - } - - fn key_any_null(&self) -> bool { - match &self.key_idx { - None => return true, - Some(key_idx) => { - let first_batch = &self.batches[0].batch; - for c in self.on_column { - let array = first_batch.column(c.index()); - if array.is_null(*key_idx) { - return true; - } - } - false - } - } - } - - fn is_finished(&self) -> Result { - match self.key_idx { - None => Ok(true), - Some(_) => match (self.batches.back(), self.ranges.back()) { - (Some(batch), Some(range)) => Ok(batch.is_last_range(range)), - _ => Err(DataFusionError::Execution(format!( - "Batches length {} not equal to ranges length {}", - self.batches.len(), - self.ranges.len() - ))), - }, - } - } - - /// Whether the running key ends at the current batch `prb`, true for continues, false for ends. 
- fn running_key(&mut self, prb: &PartitionedRecordBatch) -> Result { - let first_range = &prb.ranges[0]; - let range_len = range_len(first_range); - let current_batch = &prb.batch; - let single_range = prb.ranges.len() == 1; - - // compare the first record in batch with the current key pointed by key_idx - match self.key_idx { - None => { - self.batches.push_back(prb.clone()); - self.ranges.push_back(first_range.clone()); - self.key_idx = Some(0); - self.row_num += range_len; - Ok(single_range) - } - Some(key_idx) => { - let key_arrays = join_arrays(&self.batches[0].batch, &self.on_column); - let current_arrays = join_arrays(current_batch, &self.on_column); - let equal = equal_rows(key_idx, 0, &key_arrays, ¤t_arrays)?; - if equal { - self.batches.push_back(prb.clone()); - self.ranges.push_back(first_range.clone()); - self.row_num += range_len; - Ok(single_range) - } else { - self.next_key_batch = Some(prb.clone()); - Ok(false) // running key ends - } - } - } - } - - fn cleanup(&mut self) { - self.batches.drain(..); - self.ranges.drain(..); - self.next_key_batch = None; - } - - fn reset_batch(&mut self, prb: &PartitionedRecordBatch) { - self.cleanup(); - self.batches.push_back(prb.clone()); - let first_range = &prb.ranges[0]; - self.ranges.push_back(first_range.clone()); - self.key_idx = Some(0); - self.row_num = range_len(first_range); - } - - /// Advance the cursor to the next key seen by this buffer - fn advance_in_current_batch(&mut self) { - assert_eq!(self.batches.len(), self.ranges.len()); - if self.batches.len() > 1 { - self.batches.drain(0..(self.batches.len() - 1)); - self.ranges.drain(0..(self.batches.len() - 1)); - } - - if let Some(batch) = self.batches.pop_back() { - let tail_range = self.ranges.pop_back().unwrap(); - self.batches.push_back(batch); - let next_range_idx = batch - .ranges - .iter() - .find_position(|x| x.start == tail_range.start) - .unwrap() - .0; - self.key_idx = Some(tail_range.end); - self.ranges.push_back(batch.ranges[next_range_idx].clone()); - self.row_num = range_len(&batch.ranges[next_range_idx]); - } - } -} - -fn join_arrays(rb: &RecordBatch, on_column: &Vec) -> Vec { - on_column - .iter() - .map(|c| rb.column(c.index()).clone()) - .collect() -} - struct OutputBuffer { arrays: Vec>, target_batch_size: usize, @@ -1339,104 +1082,6 @@ impl SortMergeJoinDriver { } } - /// true for has next, false for ended - async fn advance_streamed(&mut self) -> Result { - if self.stream_batch.is_finished() { - self.get_stream_next().await?; - Ok(!self.stream_batch.is_finished()) - } else { - self.stream_batch.advance(); - Ok(true) - } - } - - /// true for has next, false for ended - async fn advance_streamed_key(&mut self) -> Result { - if self.stream_batch.is_finished() || self.stream_batch.is_last_key_in_batch() { - self.get_stream_next().await?; - Ok(!self.stream_batch.is_finished()) - } else { - self.stream_batch.advance_key(); - Ok(true) - } - } - - /// true for has next, false for ended - async fn advance_streamed_key_null_free(&mut self) -> Result { - let mut more_stream_keys = self.advance_streamed_key().await?; - loop { - if more_stream_keys && self.stream_batch.key_any_null() { - more_stream_keys = self.advance_streamed_key().await?; - } else { - break; - } - } - Ok(more_stream_keys) - } - - async fn advance_buffered_key_null_free(&mut self) -> Result { - let mut more_buffered_keys = self.advance_buffered_key().await?; - loop { - if more_buffered_keys && self.buffered_batches.key_any_null() { - more_buffered_keys = self.advance_buffered_key().await?; - } else 
{ - break; - } - } - Ok(more_buffered_keys) - } - - /// true for has next, false for ended - async fn advance_buffered_key(&mut self) -> Result { - if self.buffered_batches.is_finished()? { - match &self.buffered_batches.next_key_batch { - None => { - let batch = self.get_buffered_next().await?; - match batch { - None => return Ok(false), - Some(batch) => { - self.buffered_batches.reset_batch(&batch); - if batch.ranges.len() == 1 { - self.cumulate_same_keys().await?; - } - } - } - } - Some(batch) => { - self.buffered_batches.reset_batch(batch); - if batch.ranges.len() == 1 { - self.cumulate_same_keys().await?; - } - } - } - } else { - self.buffered_batches.advance_in_current_batch(); - if self.buffered_batches.batches[0] - .is_last_range(&self.buffered_batches.ranges[0]) - { - self.cumulate_same_keys().await?; - } - } - Ok(false) - } - - /// true for has next, false for buffer side ended - async fn cumulate_same_keys(&mut self) -> Result { - let batch = self.get_buffered_next().await?; - match batch { - None => Ok(false), - Some(batch) => { - let more_batches = self.buffered_batches.running_key(&batch)?; - if more_batches { - self.cumulate_same_keys().await - } else { - // reach end of current key, but the stream continues - Ok(true) - } - } - } - } - fn compare_stream_buffer(&self) -> Result { let stream_arrays = join_arrays(&self.stream_batch.batch.unwrap().batch, &self.on_streamed); @@ -1449,18 +1094,6 @@ impl SortMergeJoinDriver { &buffer_arrays, ) } - - async fn get_stream_next(&mut self) -> Result<()> { - let batch = self.streamed.next().await.transpose()?; - let prb = PartitionedRecordBatch::new(batch, &self.stream_batch.sort)?; - self.stream_batch.rest_batch(prb); - Ok(()) - } - - async fn get_buffered_next(&mut self) -> Result> { - let batch = self.buffered.next().await.transpose()?; - PartitionedRecordBatch::new(batch, &self.buffered_batches.sort) - } } /// join execution plan executes partitions in parallel and combines them into a set of From de9b1981519dfc3eaae2a897e20a43736f3b8f38 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 19 Nov 2021 18:00:55 +0800 Subject: [PATCH 15/15] v2 wip --- datafusion/src/physical_plan/joins/mod.rs | 2 +- .../src/physical_plan/joins/smj_utils.rs | 733 ++++++++++- .../physical_plan/joins/sort_merge_join.rs | 1084 +---------------- 3 files changed, 708 insertions(+), 1111 deletions(-) diff --git a/datafusion/src/physical_plan/joins/mod.rs b/datafusion/src/physical_plan/joins/mod.rs index 73e0991dfd90..bbf47cb21625 100644 --- a/datafusion/src/physical_plan/joins/mod.rs +++ b/datafusion/src/physical_plan/joins/mod.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -mod smj_utils; pub mod cross_join; pub mod hash_join; +mod smj_utils; pub mod sort_merge_join; use crate::error::{DataFusionError, Result}; diff --git a/datafusion/src/physical_plan/joins/smj_utils.rs b/datafusion/src/physical_plan/joins/smj_utils.rs index ad167e853820..4614f9442dca 100644 --- a/datafusion/src/physical_plan/joins/smj_utils.rs +++ b/datafusion/src/physical_plan/joins/smj_utils.rs @@ -15,21 +15,29 @@ // specific language governing permissions and limitations // under the License. 
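// Annotation: from this patch on, the async driver methods
// (`advance_streamed`, `advance_buffered_key`, `get_stream_next`, ...) are
// rewritten as poll-based methods so the join itself can be exposed as a
// `futures::Stream` (see `InnerJoiner` at the end of this file). A minimal,
// hedged sketch of that inversion of control, assuming only the `futures`
// crate; `Driver` is illustrative and not part of the patch:

use futures::stream::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};

struct Driver<S: Stream + Unpin> {
    input: S,
}

impl<S: Stream + Unpin> Stream for Driver<S> {
    type Item = S::Item;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Instead of `self.input.next().await` inside a driver loop, each
        // step polls the child and propagates `Pending` upward, keeping all
        // join state (cursors, buffered key ranges) in `self` between polls.
        Pin::new(&mut self.input).poll_next(cx)
    }
}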
-use crate::error::DataFusionError; +use crate::arrow_dyn_list_array::DynMutableListArray; +use crate::error::{DataFusionError, Result}; +use crate::execution::runtime_env::RuntimeEnv; +use crate::logical_plan::JoinType; use crate::physical_plan::expressions::{ exprs_to_sort_columns, Column, PhysicalSortExpr, }; -use crate::physical_plan::joins::equal_rows; -use crate::physical_plan::SendableRecordBatchStream; -use arrow::array::ArrayRef; +use crate::physical_plan::joins::{comp_rows, equal_rows, ColumnIndex}; +use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; +use arrow::array::*; +use arrow::array::{ArrayRef, MutableArray, MutableBooleanArray}; use arrow::compute::partition::lexicographical_partition_ranges; +use arrow::datatypes::*; use arrow::error::ArrowError; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; -use futures::StreamExt; +use futures::{Stream, StreamExt}; +use std::cmp::Ordering; use std::collections::VecDeque; +use std::iter::repeat; use std::ops::Range; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; pub(crate) fn join_arrays(rb: &RecordBatch, on_column: &Vec) -> Vec { @@ -70,6 +78,10 @@ impl PartitionedRecordBatch { } } +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Streaming Side +/////////////////////////////////////////////////////////////////////////////////////////////////// + struct StreamingSideBuffer { batch: Option, cur_row: usize, @@ -78,11 +90,10 @@ struct StreamingSideBuffer { num_ranges: usize, is_new_key: bool, on_column: Vec, - sort: Vec, } impl StreamingSideBuffer { - fn new(on_column: Vec, sort: Vec) -> Self { + fn new(on_column: Vec) -> Self { Self { batch: None, cur_row: 0, @@ -91,10 +102,13 @@ impl StreamingSideBuffer { num_ranges: 0, is_new_key: true, on_column, - sort, } } + fn join_arrays(&self) -> Vec { + join_arrays(&self.batch.unwrap().batch, &self.on_column) + } + fn reset(&mut self, prb: Option) { self.batch = prb; if let Some(prb) = &self.batch { @@ -149,17 +163,57 @@ impl StreamingSideBuffer { self.cur_row = ranges[self.cur_range].start; self.is_new_key = true; } + + fn repeat_cell( + &self, + times: usize, + to: &mut Box, + column_index: &ColumnIndex, + ) { + repeat_cell( + &self.batch.unwrap().batch, + self.cur_row, + times, + to, + column_index, + ); + } + + fn copy_slices( + &self, + slice: &Slice, + array: &mut Box, + column_index: &ColumnIndex, + ) { + let batches = vec![&self.batch.unwrap().batch]; + let slices = vec![slice.clone()]; + copy_slices(&batches, &slices, array, column_index); + } } struct StreamingSideStream { input: SendableRecordBatchStream, - output: StreamingSideBuffer, + buffer: StreamingSideBuffer, input_is_finished: bool, sort: Vec, } impl StreamingSideStream { - fn inner_next( + fn new( + input: SendableRecordBatchStream, + on: Vec, + sort: Vec, + ) -> Self { + let buffer = StreamingSideBuffer::new(on); + Self { + input, + buffer, + input_is_finished: false, + sort, + } + } + + fn input_next( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { @@ -176,7 +230,7 @@ impl StreamingSideStream { batch => { let batch = batch.transpose()?; let prb = PartitionedRecordBatch::new(batch, &self.sort)?; - self.output.reset(prb); + self.buffer.reset(prb); Poll::Ready(Some(Ok(()))) } }, @@ -188,17 +242,17 @@ impl StreamingSideStream { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - if self.output.is_finished() && self.input_is_finished { + if self.buffer.is_finished() && 
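// Annotation: several accessors in this file unwrap `self.batch` through a
// shared reference, e.g. `join_arrays` above; `Option::unwrap` takes the
// `Option` by value, so as written these would not compile. The borrowing
// form is presumably intended:
//
//     fn join_arrays(&self) -> Vec<ArrayRef> {
//         join_arrays(&self.batch.as_ref().unwrap().batch, &self.on_column)
//     }
//
// Likewise `for c in self.on_column` in `key_any_null` would need
// `&self.on_column`, and the `ranges` accesses in `advance` / `advance_key`
// need `as_ref()` before `unwrap()`.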
self.input_is_finished { Poll::Ready(None) } else { - if self.output.is_finished() { - match self.inner_next(cx) { + if self.buffer.is_finished() { + match self.input_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(x) => match x { None => Poll::Ready(None), Some(x) => { x?; - if self.output.is_finished() { + if self.buffer.is_finished() { Poll::Ready(None) } else { Poll::Ready(Some(Ok(()))) @@ -207,7 +261,7 @@ impl StreamingSideStream { }, } } else { - self.output.advance(); + self.buffer.advance(); Poll::Ready(Some(Ok(()))) } } @@ -217,17 +271,17 @@ impl StreamingSideStream { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - if self.output.is_finished() && self.input_is_finished { + if self.buffer.is_finished() && self.input_is_finished { Poll::Ready(None) } else { - if self.output.is_finished() || self.output.is_last_key_in_batch() { - match self.inner_next(cx) { + if self.buffer.is_finished() || self.buffer.is_last_key_in_batch() { + match self.input_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(x) => match x { None => Poll::Ready(None), Some(x) => { x?; - if self.output.is_finished() { + if self.buffer.is_finished() { Poll::Ready(None) } else { Poll::Ready(Some(Ok(()))) @@ -236,7 +290,7 @@ impl StreamingSideStream { }, } } else { - self.output.advance_key(); + self.buffer.advance_key(); Poll::Ready(Some(Ok(()))) } } @@ -246,7 +300,7 @@ impl StreamingSideStream { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - if self.output.is_finished() && self.input_is_finished { + if self.buffer.is_finished() && self.input_is_finished { Poll::Ready(None) } else { loop { @@ -255,7 +309,7 @@ impl StreamingSideStream { None => return Poll::Ready(None), Some(x) => { x?; - if !self.output.key_any_null() { + if !self.buffer.key_any_null() { return Poll::Ready(Some(Ok(()))); } } @@ -267,6 +321,10 @@ impl StreamingSideStream { } } +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Buffering Side +/////////////////////////////////////////////////////////////////////////////////////////////////// + /// Holding ranges for same key over several bathes struct BufferingSideBuffer { /// batches that contains the current key @@ -282,16 +340,10 @@ struct BufferingSideBuffer { next_key_batch: Option, /// Join on column on_column: Vec, - sort: Vec, -} - -#[inline] -pub(crate) fn range_len(range: &Range) -> usize { - range.end - range.start } impl BufferingSideBuffer { - fn new(on_column: Vec, sort: Vec) -> Self { + fn new(on_column: Vec) -> Self { Self { batches: VecDeque::new(), ranges: VecDeque::new(), @@ -299,10 +351,13 @@ impl BufferingSideBuffer { row_num: 0, next_key_batch: None, on_column, - sort, } } + fn join_arrays(&self) -> Vec { + join_arrays(&self.batches[0].batch, &self.on_column) + } + fn key_any_null(&self) -> bool { match &self.key_idx { None => return true, @@ -336,7 +391,7 @@ impl BufferingSideBuffer { /// Whether the running key ends at the current batch `prb`, true for continues, false for ends. 
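// Annotation: both sides expose an `advance_key_skip_null` wrapper around
// `advance_key`: under SQL equi-join semantics NULL compares equal to
// nothing, so a key containing any null can never produce a match and is
// skipped wholesale. The loop above is the poll-shaped version of this
// synchronous pseudocode (hedged):
//
//     while advance_key()? {        // move to the next distinct key
//         if !key_any_null() {
//             return Ok(true);      // usable key now under the cursor
//         }                         // otherwise keep skipping
//     }
//     Ok(false)                     // this side is exhausted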
fn running_key(&mut self, prb: &PartitionedRecordBatch) -> ArrowResult { let first_range = &prb.ranges[0]; - let range_len = range_len(first_range); + let range_len = first_range.len(); let current_batch = &prb.batch; let single_range = prb.ranges.len() == 1; @@ -379,7 +434,7 @@ impl BufferingSideBuffer { let first_range = &prb.ranges[0]; self.ranges.push_back(first_range.clone()); self.key_idx = Some(0); - self.row_num = range_len(first_range); + self.row_num = first_range.len(); } /// Advance the cursor to the next key seen by this buffer @@ -402,21 +457,115 @@ impl BufferingSideBuffer { .0; self.key_idx = Some(tail_range.end); self.ranges.push_back(batch.ranges[next_range_idx].clone()); - self.row_num = range_len(&batch.ranges[next_range_idx]); + self.row_num = batch.ranges[next_range_idx].len(); } } + + /// Locate the starting idx for each of the ranges in the current buffer. + fn range_start_indices(&self) -> Vec { + let mut idx = 0; + let mut start_indices: Vec = vec![]; + self.ranges.iter().for_each(|r| { + start_indices.push(idx); + idx += r.len(); + }); + start_indices.push(usize::MAX); + start_indices + } + + /// Locate buffered records start from `buffered_idx` of `len`gth + fn slices_from_batches( + &self, + start_indices: &Vec, + buffered_idx: usize, + len: usize, + ) -> Vec { + let ranges = &self.ranges; + let mut idx = buffered_idx; + let mut slices: Vec = vec![]; + let mut remaining = len; + let find = start_indices + .iter() + .enumerate() + .find(|(i, start_idx)| **start_idx >= idx) + .unwrap(); + let mut batch_idx = if *find.1 == idx { find.0 } else { find.0 - 1 }; + + while remaining > 0 { + let current_range = &ranges[batch_idx]; + let range_start_idx = start_indices[batch_idx]; + let start_idx = idx - range_start_idx + current_range.start; + let range_available = current_range.len() - (idx - range_start_idx); + + if range_available >= remaining { + slices.push(Slice { + batch_idx, + start_idx, + len: remaining, + }); + remaining = 0; + } else { + slices.push(Slice { + batch_idx, + start_idx, + len: range_available, + }); + remaining -= range_available; + batch_idx += 1; + idx += range_available; + } + } + slices + } + + fn copy_slices( + &self, + slices: &Vec, + array: &mut Box, + column_index: &ColumnIndex, + ) { + let batches = self + .batches + .iter() + .map(|prb| &prb.batch) + .collect::>(); + copy_slices(&batches, slices, array, column_index); + } +} + +/// Slice of batch at `batch_idx` inside BufferingSideBuffer. +#[derive(Copy, Clone)] +struct Slice { + batch_idx: usize, + start_idx: usize, + len: usize, } struct BufferingSideStream { input: SendableRecordBatchStream, - output: BufferingSideBuffer, + buffer: BufferingSideBuffer, input_is_finished: bool, cumulating: bool, sort: Vec, } impl BufferingSideStream { - fn inner_next( + fn new( + input: SendableRecordBatchStream, + on: Vec, + sort: Vec, + ) -> Self { + let buffer = BufferingSideBuffer::new(on); + Self { + input, + buffer, + input_is_finished: false, + cumulating: false, + sort, + } + } + + fn input_next( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { @@ -445,7 +594,7 @@ impl BufferingSideStream { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - if self.output.is_finished()? && self.input_is_finished { + if self.buffer.is_finished()? && self.input_is_finished { return Poll::Ready(None); } else { if self.cumulating { @@ -461,15 +610,15 @@ impl BufferingSideStream { } } - if self.output.is_finished()? 
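// Annotation: `slices_from_batches` above maps a logical window
// (`buffered_idx`, `len`) over the concatenated rows of the current key back
// onto physical (batch, offset, len) triples, using the prefix sums from
// `range_start_indices`. Hypothetical walk-through:
//
//     ranges:        batch 0 -> 2..5 (3 rows), batch 1 -> 0..4 (4 rows)
//     start_indices: [0, 3, usize::MAX]
//     request:       buffered_idx = 2, len = 4
//
//     yields  Slice { batch_idx: 0, start_idx: 4, len: 1 }   // tail of b0
//             Slice { batch_idx: 1, start_idx: 0, len: 3 }   // head of b1
//
// The `usize::MAX` sentinel guarantees the `find` over `start_indices`
// always succeeds, even when the window starts past the last real prefix.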
{ - return match &self.output.next_key_batch { - None => match self.inner_next(cx) { + if self.buffer.is_finished()? { + return match &self.buffer.next_key_batch { + None => match self.input_next(cx) { Poll::Pending => Poll::Pending, Poll::Ready(x) => match x { None => Poll::Ready(None), Some(x) => { let prb = x?; - self.output.reset(&prb); + self.buffer.reset(&prb); if prb.ranges.len() == 1 { self.cumulating = true; Poll::Pending @@ -480,7 +629,7 @@ impl BufferingSideStream { }, }, Some(batch) => { - self.output.reset(batch); + self.buffer.reset(batch); if batch.ranges.len() == 1 { self.cumulating = true; Poll::Pending @@ -490,8 +639,8 @@ impl BufferingSideStream { } }; } else { - self.output.advance_key(); - if self.output.batches[0].is_last_range(&self.output.ranges[0]) { + self.buffer.advance_key(); + if self.buffer.batches[0].is_last_range(&self.buffer.ranges[0]) { self.cumulating = true; } else { return Poll::Ready(Some(Ok(()))); @@ -506,7 +655,7 @@ impl BufferingSideStream { self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - if self.output.is_finished()? && self.input_is_finished { + if self.buffer.is_finished()? && self.input_is_finished { Poll::Ready(None) } else { loop { @@ -515,7 +664,7 @@ impl BufferingSideStream { None => return Poll::Ready(None), Some(x) => { x?; - if !self.output.key_any_null() { + if !self.buffer.key_any_null() { return Poll::Ready(Some(Ok(()))); } } @@ -531,7 +680,7 @@ impl BufferingSideStream { cx: &mut Context<'_>, ) -> Poll>> { loop { - match self.inner_next(cx) { + match self.input_next(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(x) => match x { None => { @@ -540,7 +689,7 @@ impl BufferingSideStream { } Some(x) => { let prb = x?; - let buffer_more = self.output.running_key(&prb)?; + let buffer_more = self.buffer.running_key(&prb)?; if !buffer_more { self.cumulating = false; return Poll::Ready(Some(Ok(()))); @@ -551,3 +700,491 @@ impl BufferingSideStream { } } } + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Output +/////////////////////////////////////////////////////////////////////////////////////////////////// + +struct OutputBuffer { + arrays: Vec>, + target_batch_size: usize, + slots_available: usize, + schema: Arc, +} + +impl OutputBuffer { + fn new(target_batch_size: usize, schema: Arc) -> ArrowResult { + let arrays = new_arrays(&schema, target_batch_size)?; + Ok(Self { + arrays, + target_batch_size, + slots_available: target_batch_size, + schema, + }) + } + + fn output_and_reset(mut self) -> ArrowResult { + let result = make_batch(self.schema.clone(), self.arrays); + self.arrays = new_arrays(&self.schema, self.target_batch_size)?; + self.slots_available = self.target_batch_size; + result + } + + fn append(&mut self, size: usize) { + assert!(size <= self.slots_available); + self.slots_available -= size; + } + + #[inline] + fn is_full(&self) -> bool { + self.slots_available == 0 + } +} + +fn new_arrays( + schema: &Arc, + batch_size: usize, +) -> ArrowResult>> { + let arrays: Vec> = schema + .fields() + .iter() + .map(|field| { + let dt = field.data_type.to_logical_type(); + make_mutable(dt, batch_size) + }) + .collect::>()?; + Ok(arrays) +} + +fn make_batch( + schema: Arc, + mut arrays: Vec>, +) -> ArrowResult { + let columns = arrays.iter_mut().map(|array| array.as_arc()).collect(); + RecordBatch::try_new(schema, columns) +} + +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! 
__with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use arrow::datatypes::PrimitiveType::*; + use arrow::types::{days_ms, months_days_ns}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + +fn make_mutable( + data_type: &DataType, + capacity: usize, +) -> ArrowResult> { + Ok(match data_type.to_physical_type() { + PhysicalType::Boolean => Box::new(MutableBooleanArray::with_capacity(capacity)) + as Box, + PhysicalType::Primitive(primitive) => { + with_match_primitive_type!(primitive, |$T| { + Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(data_type.clone())) + as Box + }) + } + PhysicalType::Binary => { + Box::new(MutableBinaryArray::::with_capacity(capacity)) + as Box + } + PhysicalType::Utf8 => Box::new(MutableUtf8Array::::with_capacity(capacity)) + as Box, + _ => match data_type { + DataType::List(inner) => { + let values = make_mutable(inner.data_type(), 0)?; + Box::new(DynMutableListArray::::new_with_capacity( + values, capacity, + )) as Box + } + DataType::FixedSizeBinary(size) => Box::new( + MutableFixedSizeBinaryArray::with_capacity(*size as usize, capacity), + ) as Box, + other => { + return Err(ArrowError::NotYetImplemented(format!( + "making mutable of type {} is not implemented yet", + data_type + ))) + } + }, + }) +} + +macro_rules! repeat_n { + ($TO:ty, $FROM:ty, $N:expr, $to: ident, $from: ident, $idx: ident) => {{ + let to = $to.as_mut_any().downcast_mut::<$TO>().unwrap(); + let from = $from.as_any().downcast_ref::<$FROM>().unwrap(); + let repeat_iter = from + .slice($idx, 1) + .iter() + .flat_map(|v| repeat(v).take($N)) + .collect::>(); + to.extend_trusted_len(repeat_iter.into_iter()); + }}; +} + +/// repeat times of cell located by `idx` at streamed side to output +fn repeat_cell( + batch: &RecordBatch, + idx: usize, + times: usize, + to: &mut Box, + column_index: &ColumnIndex, +) { + let from = batch.column(column_index.index); + match to.data_type().to_physical_type() { + PhysicalType::Boolean => { + repeat_n!(MutableBooleanArray, BooleanArray, times, to, from, idx) + } + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int8 => repeat_n!(Int8Vec, Int8Array, times, to, from, idx), + PrimitiveType::Int16 => repeat_n!(Int16Vec, Int16Array, times, to, from, idx), + PrimitiveType::Int32 => repeat_n!(Int32Vec, Int32Array, times, to, from, idx), + PrimitiveType::Int64 => repeat_n!(Int64Vec, Int64Array, times, to, from, idx), + PrimitiveType::Float32 => { + repeat_n!(Float32Vec, Float32Array, times, to, from, idx) + } + PrimitiveType::Float64 => { + repeat_n!(Float64Vec, Float64Array, times, to, from, idx) + } + _ => todo!(), + }, + PhysicalType::Utf8 => { + repeat_n!(MutableUtf8Array, Utf8Array, times, to, from, idx) + } + _ => todo!(), + } +} + +macro_rules! 
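// Annotation: `repeat_cell` above materializes the streamed side of a match:
// one streamed row joined against N buffered rows contributes the same cell
// N times to the output column. Illustrative behaviour:
//
//     streamed column ["a", "b"], cursor idx = 1, times = 3
//     output column receives "b", "b", "b"
//
// The `slice(idx, 1).iter().flat_map(|v| repeat(v).take(times))` chain also
// carries validity along, so a null cell repeats as N nulls.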
copy_slices { + ($TO:ty, $FROM:ty, $array: ident, $batches: ident, $slices: ident, $column_index: ident) => {{ + let to = $array.as_mut_any().downcast_mut::<$TO>().unwrap(); + for pos in $slices { + let from = $batches[pos.batch_idx] + .column($column_index.index) + .slice(pos.start_idx, pos.len); + let from = from.as_any().downcast_ref::<$FROM>().unwrap(); + to.extend_trusted_len(from.iter()); + } + }}; +} + +fn copy_slices( + batches: &Vec<&RecordBatch>, + slices: &Vec, + array: &mut Box, + column_index: &ColumnIndex, +) { + // output buffered start `buffered_idx`, len `rows_to_output` + match array.data_type().to_physical_type() { + PhysicalType::Boolean => { + copy_slices!( + MutableBooleanArray, + BooleanArray, + array, + batches, + slices, + column_index + ) + } + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int8 => { + copy_slices!(Int8Vec, Int8Array, array, batches, slices, column_index) + } + PrimitiveType::Int16 => { + copy_slices!(Int16Vec, Int16Array, array, batches, slices, column_index) + } + PrimitiveType::Int32 => { + copy_slices!(Int32Vec, Int32Array, array, batches, slices, column_index) + } + PrimitiveType::Int64 => { + copy_slices!(Int64Vec, Int64Array, array, batches, slices, column_index) + } + PrimitiveType::Float32 => copy_slices!( + Float32Vec, + Float32Array, + array, + batches, + slices, + column_index + ), + PrimitiveType::Float64 => copy_slices!( + Float64Vec, + Float64Array, + array, + batches, + slices, + column_index + ), + _ => todo!(), + }, + PhysicalType::Utf8 => { + copy_slices!( + MutableUtf8Array, + Utf8Array, + array, + batches, + slices, + column_index + ) + } + _ => todo!(), + } +} + +pub struct SortMergeJoinCommon { + streamed: StreamingSideStream, + buffered: BufferingSideStream, + schema: Arc, + /// Information of index and left / right placement of columns + column_indices: Vec, + join_type: JoinType, + runtime: Arc, + output: OutputBuffer, +} + +impl SortMergeJoinCommon { + pub fn new( + streamed: SendableRecordBatchStream, + buffered: SendableRecordBatchStream, + on_streamed: Vec, + on_buffered: Vec, + streamed_sort: Vec, + buffered_sort: Vec, + column_indices: Vec, + schema: Arc, + join_type: JoinType, + runtime: Arc, + ) -> Result { + let streamed = StreamingSideStream::new(streamed, on_streamed, streamed_sort); + let buffered = BufferingSideStream::new(buffered, on_buffered, buffered_sort); + let output = OutputBuffer::new(runtime.batch_size(), schema.clone()) + .map_err(DataFusionError::ArrowError)?; + Ok(Self { + streamed, + buffered, + schema, + column_indices, + join_type, + runtime, + output, + }) + } + + fn compare_stream_buffer(&self) -> Result { + let stream_arrays = &self.streamed.buffer.join_arrays(); + let buffer_arrays = &self.buffered.buffer.join_arrays(); + comp_rows( + self.streamed.buffer.cur_row, + self.buffered.buffer.key_idx.unwrap(), + stream_arrays, + buffer_arrays, + ) + } +} + +pub struct InnerJoiner { + inner: SortMergeJoinCommon, + matched: bool, + buffered_idx: usize, + start_indices: Vec, + buffer_remaining: usize, + advance_stream: bool, + advance_buffer: bool, + continues_match: bool, +} + +impl InnerJoiner { + pub fn new(inner: SortMergeJoinCommon) -> Self { + Self { + inner, + matched: false, + buffered_idx: 0, + start_indices: vec![], + buffer_remaining: 0, + advance_stream: true, + advance_buffer: true, + continues_match: false, + } + } + + fn find_next( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + if self.continues_match { + match Pin::new(&mut 
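// Annotation: `find_next` here is the poll-shaped replacement for the old
// `find_next_inner_match` loop. `continues_match` distinguishes "same
// streamed key again, re-emit against the already buffered ranges" from a
// fresh key comparison, and the `advance_*` flags record which side must
// move before the next comparison. The resulting transitions (read from the
// comparison arms that follow) are:
//
//     cmp == Less    -> advance the streamed side only, retry
//     cmp == Equal   -> snapshot start_indices / row_num, start emitting
//     cmp == Greater -> advance the buffered side only, retry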
self.inner.streamed).advance(cx) { + Poll::Ready(x) => match x { + None => Poll::Ready(None), + Some(y) => { + y?; + if self.inner.streamed.buffer.is_new_key { + self.continues_match = false; + self.advance_stream = false; + self.advance_buffer = true; + self.matched = false; + self.find_next(cx) + } else { + self.continues_match = true; + self.advance_stream = false; + self.advance_buffer = false; + self.matched = false; + Poll::Ready(Some(Ok(()))) + } + } + }, + Poll::Pending => Poll::Pending, + } + } else { + if self.advance_stream { + match Pin::new(&mut self.inner.streamed).advance_key_skip_null(cx) { + Poll::Ready(x) => match x { + None => return Poll::Ready(None), + Some(y) => { + y?; + self.continues_match = true; + } + }, + Poll::Pending => return Poll::Pending, + } + } + + if self.advance_buffer { + match Pin::new(&mut self.inner.streamed).advance_key_skip_null(cx) { + Poll::Ready(x) => match x { + None => return Poll::Ready(None), + Some(y) => { + y?; + } + }, + Poll::Pending => return Poll::Pending, + } + } + + let cmp = self + .inner + .compare_stream_buffer() + .map_err(DataFusionError::into_arrow_external_error)?; + + match cmp { + Ordering::Less => { + self.advance_stream = true; + self.advance_buffer = false; + self.find_next(cx) + } + Ordering::Equal => { + self.advance_stream = true; + self.advance_buffer = true; + self.matched = true; + self.buffered_idx = 0; + self.start_indices = self.inner.buffered.buffer.range_start_indices(); + self.buffer_remaining = self.inner.buffered.buffer.row_num; + Poll::Ready(Some(Ok(()))) + } + Ordering::Greater => { + self.advance_stream = false; + self.advance_buffer = true; + self.find_next(cx) + } + } + } + } + + fn fill_output( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let output = &mut self.inner.output; + let streamed = &self.inner.streamed.buffer; + let buffered = &self.inner.buffered.buffer; + + let slots_available = output.slots_available; + let mut rows_to_output = 0; + if slots_available >= self.buffer_remaining { + self.matched = false; + rows_to_output = self.buffer_remaining; + self.buffer_remaining = 0; + } else { + rows_to_output = slots_available; + self.buffer_remaining -= rows_to_output; + } + + let slices = buffered.slices_from_batches( + &self.start_indices, + self.buffered_idx, + rows_to_output, + ); + + output + .arrays + .iter_mut() + .zip(self.inner.schema.fields().iter()) + .zip(self.inner.column_indices.iter()) + .map(|((array, field), column_index)| { + if column_index.is_left { + // repeat streamed `rows_to_output` times + streamed.repeat_cell(rows_to_output, array, column_index); + } else { + // copy buffered start from: `buffered_idx`, len: `rows_to_output` + buffered.copy_slices(&slices, array, column_index); + } + }); + + self.inner.output.append(rows_to_output); + self.buffered_idx += rows_to_output; + + if self.inner.output.is_full() { + let result = output.output_and_reset(); + Poll::Ready(Some(result)) + } else { + self.poll_next(cx) + } + } +} + +impl Stream for InnerJoiner { + type Item = ArrowResult; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if !self.matched { + let find = self.find_next(cx); + match find { + Poll::Ready(x) => match x { + None => Poll::Ready(None), + Some(y) => match y { + Ok(_) => self.fill_output(cx), + Err(err) => Poll::Ready(Some(Err(ArrowError::External( + "Failed while finding next match for inner join".to_owned(), + err.into(), + )))), + }, + }, + Poll::Pending => Poll::Pending, + } + } else { + self.fill_output(cx) + } 
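// Annotation: in `find_next` above, the `advance_buffer` branch polls
// `self.inner.streamed` a second time; presumably the buffered side was
// intended, along the lines of:
//
//     if self.advance_buffer {
//         match Pin::new(&mut self.inner.buffered).advance_key_skip_null(cx) {
//             Poll::Ready(None) => return Poll::Ready(None),
//             Poll::Ready(Some(y)) => y?,
//             Poll::Pending => return Poll::Pending,
//         }
//     }
//
// Similarly, `fill_output` calls `output.output_and_reset()` although that
// method takes `self` by value; a `&mut self` signature (or a `take`-style
// swap) would be needed for the buffer to be reused across output batches.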
+ } +} + +impl RecordBatchStream for InnerJoiner { + fn schema(&self) -> SchemaRef { + self.inner.schema.clone() + } +} diff --git a/datafusion/src/physical_plan/joins/sort_merge_join.rs b/datafusion/src/physical_plan/joins/sort_merge_join.rs index 82c26c1acce1..b1b0292c3afe 100644 --- a/datafusion/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/src/physical_plan/joins/sort_merge_join.rs @@ -25,10 +25,10 @@ use std::{any::Any, usize}; use arrow::array::*; use arrow::datatypes::*; -use arrow::error::Result as ArrowResult; +use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use futures::StreamExt; +use futures::{Stream, StreamExt}; use crate::arrow_dyn_list_array::DynMutableListArray; use crate::error::{DataFusionError, Result}; @@ -36,7 +36,9 @@ use crate::execution::runtime_env::RuntimeEnv; use crate::execution::runtime_env::RUNTIME_ENV; use crate::logical_plan::JoinType; use crate::physical_plan::expressions::{exprs_to_sort_columns, PhysicalSortExpr}; -use crate::physical_plan::joins::smj_utils::{join_arrays, range_len}; +use crate::physical_plan::joins::smj_utils::{ + join_arrays, InnerJoiner, SortMergeJoinCommon, +}; use crate::physical_plan::joins::{ build_join_schema, check_join_is_valid, column_indices_from_schema, comp_rows, equal_rows, ColumnIndex, JoinOn, @@ -55,1047 +57,13 @@ use arrow::compute::partition::lexicographical_partition_ranges; use std::cmp::Ordering; use std::collections::VecDeque; use std::ops::Range; +use std::pin::Pin; +use std::task::{Context, Poll}; use tokio::sync::mpsc::{Receiver, Sender}; type StringArray = Utf8Array; type LargeStringArray = Utf8Array; -struct OutputBuffer { - arrays: Vec>, - target_batch_size: usize, - slots_available: usize, - schema: Arc, -} - -impl OutputBuffer { - fn new(target_batch_size: usize, schema: Arc) -> Result { - let arrays = new_arrays(&schema, target_batch_size)?; - Ok(Self { - arrays, - target_batch_size, - slots_available: target_batch_size, - schema, - }) - } - - fn output_and_reset(&mut self) -> Option> { - if self.is_full() { - let result = make_batch(self.schema.clone(), output_arrays); - self.arrays = new_arrays(&self.schema, self.target_batch_size)?; - self.slots_available = self.target_batch_size; - Some(result) - } else { - None - } - } - - fn append(&mut self, size: usize) { - assert!(size <= self.slots_available); - self.slots_available -= size; - } - - #[inline] - fn is_full(&self) -> bool { - self.slots_available == 0 - } -} - -struct SortMergeJoinDriver { - streamed: SendableRecordBatchStream, - buffered: SendableRecordBatchStream, - on_streamed: Vec, - on_buffered: Vec, - schema: Arc, - /// Information of index and left / right placement of columns - column_indices: Vec, - stream_batch: StreamingBatch, - buffered_batches: BufferedBatches, - runtime: Arc, -} - -macro_rules! with_match_primitive_type {( - $key_type:expr, | $_:tt $T:ident | $($body:tt)* -) => ({ - macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} - use arrow::datatypes::PrimitiveType::*; - use arrow::types::{days_ms, months_days_ns}; - match $key_type { - Int8 => __with_ty__! { i8 }, - Int16 => __with_ty__! { i16 }, - Int32 => __with_ty__! { i32 }, - Int64 => __with_ty__! { i64 }, - Int128 => __with_ty__! { i128 }, - DaysMs => __with_ty__! { days_ms }, - MonthDayNano => __with_ty__! { months_days_ns }, - UInt8 => __with_ty__! { u8 }, - UInt16 => __with_ty__! { u16 }, - UInt32 => __with_ty__! { u32 }, - UInt64 => __with_ty__! 
{ u64 }, - Float32 => __with_ty__! { f32 }, - Float64 => __with_ty__! { f64 }, - } -})} - -fn make_mutable(data_type: &DataType, capacity: usize) -> Result> { - Ok(match data_type.to_physical_type() { - PhysicalType::Boolean => Box::new(MutableBooleanArray::with_capacity(capacity)) - as Box, - PhysicalType::Primitive(primitive) => { - with_match_primitive_type!(primitive, |$T| { - Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(data_type.clone())) - as Box - }) - } - PhysicalType::Binary => { - Box::new(MutableBinaryArray::::with_capacity(capacity)) - as Box - } - PhysicalType::Utf8 => Box::new(MutableUtf8Array::::with_capacity(capacity)) - as Box, - _ => match data_type { - DataType::List(inner) => { - let values = make_mutable(inner.data_type(), 0)?; - Box::new(DynMutableListArray::::new_with_capacity( - values, capacity, - )) as Box - } - DataType::FixedSizeBinary(size) => Box::new( - MutableFixedSizeBinaryArray::with_capacity(*size as usize, capacity), - ) as Box, - other => { - return Err(DataFusionError::Execution(format!( - "making mutable of type {} is not implemented yet", - data_type - ))) - } - }, - }) -} - -fn new_arrays( - schema: &Arc, - batch_size: usize, -) -> Result>> { - let arrays: Vec> = schema - .fields() - .iter() - .map(|field| { - let dt = field.data_type.to_logical_type(); - make_mutable(dt, batch_size) - }) - .collect::>()?; - Ok(arrays) -} - -fn make_batch( - schema: Arc, - mut arrays: Vec>, -) -> ArrowResult { - let columns = arrays.iter_mut().map(|array| array.as_arc()).collect(); - RecordBatch::try_new(schema, columns) -} - -macro_rules! repeat_n { - ($TO:ty, $FROM:ty, $N:expr, $to: ident, $from: ident, $idx: ident) => {{ - let to = $to.as_mut_any().downcast_mut::<$TO>().unwrap(); - let from = $from.as_any().downcast_ref::<$FROM>().unwrap(); - let repeat_iter = from - .slice($idx, 1) - .iter() - .flat_map(|v| repeat(v).take($N)) - .collect::>(); - to.extend_trusted_len(repeat_iter.into_iter()); - }}; -} - -/// repeat times of cell located by `idx` at streamed side to output -fn repeat_streamed_cell( - stream_batch: &RecordBatch, - idx: usize, - times: usize, - to: &mut Box, - column_index: &ColumnIndex, -) { - let from = stream_batch.column(column_index.index); - match to.data_type().to_physical_type() { - PhysicalType::Boolean => { - repeat_n!(MutableBooleanArray, BooleanArray, times, to, from, idx) - } - PhysicalType::Primitive(primitive) => match primitive { - PrimitiveType::Int8 => repeat_n!(Int8Vec, Int8Array, times, to, from, idx), - PrimitiveType::Int16 => repeat_n!(Int16Vec, Int16Array, times, to, from, idx), - PrimitiveType::Int32 => repeat_n!(Int32Vec, Int32Array, times, to, from, idx), - PrimitiveType::Int64 => repeat_n!(Int64Vec, Int64Array, times, to, from, idx), - PrimitiveType::Float32 => { - repeat_n!(Float32Vec, Float32Array, times, to, from, idx) - } - PrimitiveType::Float64 => { - repeat_n!(Float64Vec, Float64Array, times, to, from, idx) - } - _ => todo!(), - }, - PhysicalType::Utf8 => { - repeat_n!(MutableUtf8Array, Utf8Array, times, to, from, idx) - } - _ => todo!(), - } -} - -macro_rules! 
copy_slices { - ($TO:ty, $FROM:ty, $array: ident, $batches: ident, $slices: ident, $column_index: ident) => {{ - let to = $array.as_mut_any().downcast_mut::<$TO>().unwrap(); - for pos in $slices { - let from = $batches[pos.batch_idx] - .column($column_index.index) - .slice(pos.start_idx, pos.len); - let from = from.as_any().downcast_ref::<$FROM>().unwrap(); - to.extend_trusted_len(from.iter()); - } - }}; -} - -fn copy_slices( - batches: &Vec<&RecordBatch>, - slices: &Vec, - array: &mut Box, - column_index: &ColumnIndex, -) { - // output buffered start `buffered_idx`, len `rows_to_output` - match array.data_type().to_physical_type() { - PhysicalType::Boolean => { - copy_slices!( - MutableBooleanArray, - BooleanArray, - array, - batches, - slices, - column_index - ) - } - PhysicalType::Primitive(primitive) => match primitive { - PrimitiveType::Int8 => { - copy_slices!(Int8Vec, Int8Array, array, batches, slices, column_index) - } - PrimitiveType::Int16 => { - copy_slices!(Int16Vec, Int16Array, array, batches, slices, column_index) - } - PrimitiveType::Int32 => { - copy_slices!(Int32Vec, Int32Array, array, batches, slices, column_index) - } - PrimitiveType::Int64 => { - copy_slices!(Int64Vec, Int64Array, array, batches, slices, column_index) - } - PrimitiveType::Float32 => copy_slices!( - Float32Vec, - Float32Array, - array, - batches, - slices, - column_index - ), - PrimitiveType::Float64 => copy_slices!( - Float64Vec, - Float64Array, - array, - batches, - slices, - column_index - ), - _ => todo!(), - }, - PhysicalType::Utf8 => { - copy_slices!( - MutableUtf8Array, - Utf8Array, - array, - batches, - slices, - column_index - ) - } - _ => todo!(), - } -} - -fn range_start_indices(buffered_ranges: &VecDeque>) -> Vec { - let mut idx = 0; - let mut start_indices: Vec = vec![]; - buffered_ranges.iter().for_each(|r| { - start_indices.push(idx); - idx += range_len(r); - }); - start_indices.push(usize::MAX); - start_indices -} - -/// Locate buffered records start from `buffered_idx` of `len`gth -/// inside buffered batches. -fn slices_from_batches( - buffered_ranges: &VecDeque>, - start_indices: &Vec, - buffered_idx: usize, - len: usize, -) -> Vec { - let mut idx = buffered_idx; - let mut slices: Vec = vec![]; - let mut remaining = len; - let find = start_indices - .iter() - .find_position(|&&start_idx| start_idx >= idx) - .unwrap(); - let mut batch_idx = if find.1 == idx { find.0 } else { find.0 - 1 }; - - while remaining > 0 { - let current_range = &buffered_ranges[batch_idx]; - let range_start_idx = start_indices[batch_idx]; - let start_idx = idx - range_start_idx + current_range.start; - let range_available = range_len(current_range) - (idx - range_start_idx); - - if range_available >= remaining { - slices.push(Slice { - batch_idx, - start_idx, - len: remaining, - }); - remaining = 0; - } else { - slices.push(Slice { - batch_idx, - start_idx, - len: range_available, - }); - remaining -= range_available; - batch_idx += 1; - idx += range_available; - } - } - slices -} - -/// Slice of batch at `batch_idx` inside BufferedBatches. 
-/// Slice of the batch at `batch_idx` inside BufferedBatches.
-struct Slice {
-    batch_idx: usize,
-    start_idx: usize,
-    len: usize,
-}
-
-impl SortMergeJoinDriver {
-    fn new(
-        streamed: SendableRecordBatchStream,
-        buffered: SendableRecordBatchStream,
-        on_streamed: Vec<Column>,
-        on_buffered: Vec<Column>,
-        streamed_sort: Vec<PhysicalSortExpr>,
-        buffered_sort: Vec<PhysicalSortExpr>,
-        column_indices: Vec<ColumnIndex>,
-        schema: Arc<Schema>,
-        runtime: Arc<RuntimeEnv>,
-    ) -> Self {
-        Self {
-            streamed,
-            buffered,
-            on_streamed: on_streamed.clone(),
-            on_buffered: on_buffered.clone(),
-            schema,
-            column_indices,
-            stream_batch: StreamingBatch::new(on_streamed, streamed_sort),
-            buffered_batches: BufferedBatches::new(on_buffered, buffered_sort),
-            runtime,
-        }
-    }
-
-    async fn inner_join_driver(
-        &mut self,
-        sender: &Sender<ArrowResult<RecordBatch>>,
-    ) -> Result<()> {
-        let target_batch_size = self.runtime.batch_size();
-
-        let mut output_slots_available = target_batch_size;
-        let mut output_arrays = new_arrays(&self.schema, target_batch_size)?;
-
-        while self.find_next_inner_match().await? {
-            loop {
-                let result = self
-                    .join_eq_records(
-                        target_batch_size,
-                        output_slots_available,
-                        output_arrays,
-                        sender,
-                    )
-                    .await?;
-                output_slots_available = result.0;
-                output_arrays = result.1;
-
-                self.stream_batch.advance();
-                if self.stream_batch.is_new_key {
-                    break;
-                }
-            }
-        }
-
-        Ok(())
-    }
-
-    async fn semi_join_driver(
-        &mut self,
-        sender: &Sender<ArrowResult<RecordBatch>>,
-    ) -> Result<()> {
-        let target_batch_size = self.runtime.batch_size();
-
-        let mut output_slots_available = target_batch_size;
-        let mut output_arrays = new_arrays(&self.schema, target_batch_size)?;
-
-        while self.find_next_inner_match().await? {
-            let result = self
-                .stream_copy_buffer_omit(
-                    target_batch_size,
-                    output_slots_available,
-                    output_arrays,
-                    sender,
-                )
-                .await?;
-            output_slots_available = result.0;
-            output_arrays = result.1;
-        }
-
-        Ok(())
-    }
-
-    async fn outer_join_driver(
-        &mut self,
-        sender: &Sender<ArrowResult<RecordBatch>>,
-    ) -> Result<()> {
-        let target_batch_size = self.runtime.batch_size();
-
-        let mut output_slots_available = target_batch_size;
-        let mut output_arrays = new_arrays(&self.schema, target_batch_size)?;
-        let mut buffer_ends = false;
-
-        loop {
-            let OuterMatchResult {
-                get_match,
-                buffered_ended,
-                more_output,
-            } = self.find_next_outer(buffer_ends).await?;
-            if !more_output {
-                break;
-            }
-            buffer_ends = buffered_ended;
-            if get_match {
-                loop {
-                    let result = self
-                        .join_eq_records(
-                            target_batch_size,
-                            output_slots_available,
-                            output_arrays,
-                            sender,
-                        )
-                        .await?;
-                    output_slots_available = result.0;
-                    output_arrays = result.1;
-
-                    self.stream_batch.advance();
-                    if self.stream_batch.is_new_key {
-                        break;
-                    }
-                }
-            } else {
-                let result = self
-                    .stream_copy_buffer_null(
-                        target_batch_size,
-                        output_slots_available,
-                        output_arrays,
-                        sender,
-                    )
-                    .await?;
-                output_slots_available = result.0;
-                output_arrays = result.1;
-            }
-        }
-
-        Ok(())
-    }
-
-    async fn anti_join_driver(
-        &mut self,
-        sender: &Sender<ArrowResult<RecordBatch>>,
-    ) -> Result<()> {
-        let target_batch_size = self.runtime.batch_size();
-
-        let mut output_slots_available = target_batch_size;
-        let mut output_arrays = new_arrays(&self.schema, target_batch_size)?;
-        let mut buffer_ends = false;
-
-        loop {
-            let OuterMatchResult {
-                get_match,
-                buffered_ended,
-                more_output,
-            } = self.find_next_outer(buffer_ends).await?;
-            if !more_output {
-                break;
-            }
-            buffer_ends = buffered_ended;
-            if get_match {
-                // do nothing
-            } else {
-                let result = self
-                    .stream_copy_buffer_omit(
-                        target_batch_size,
-                        output_slots_available,
-                        output_arrays,
-                        sender,
-                    )
-                    .await?;
-                output_slots_available = result.0;
-                output_arrays = result.1;
-            }
-        }
-
-        Ok(())
-    }
-
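The outer-join driver above reduces to one decision per streamed key: join it when a buffered match exists, otherwise emit it once with the buffered side null-padded. A simplified std-only sketch of that decision over integer keys (single match per key; the real driver handles ranges of duplicates and nulls):

/// For each streamed key of a LEFT OUTER join, either join it with the
/// matching buffered key or emit it once padded with a null.
fn left_outer(streamed: &[i32], buffered: &[i32]) -> Vec<(i32, Option<i32>)> {
    let mut out = Vec::new();
    let mut b = 0;
    for &s in streamed {
        // advance the buffered cursor past smaller keys
        while b < buffered.len() && buffered[b] < s {
            b += 1;
        }
        if b < buffered.len() && buffered[b] == s {
            out.push((s, Some(buffered[b]))); // get_match == true
        } else {
            out.push((s, None)); // no match: buffered side is null-padded
        }
    }
    out
}

fn main() {
    assert_eq!(
        left_outer(&[1, 2, 4], &[2, 3, 4]),
        vec![(1, None), (2, Some(2)), (4, Some(4))]
    );
}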
-    async fn full_outer_driver(
-        &mut self,
-        sender: &Sender<ArrowResult<RecordBatch>>,
-    ) -> Result<()> {
-        let target_batch_size = self.runtime.batch_size();
-
-        let mut output_slots_available = target_batch_size;
-        let mut output_arrays = new_arrays(&self.schema, target_batch_size)?;
-        let mut stream_ends = false;
-        let mut buffer_ends = false;
-        let mut advance_stream = true;
-        let mut advance_buffer = true;
-
-        loop {
-            if advance_buffer {
-                buffer_ends = !self.advance_buffered_key().await?;
-            }
-            if advance_stream {
-                stream_ends = !self.advance_streamed_key().await?;
-            }
-
-            if stream_ends && buffer_ends {
-                break;
-            } else if stream_ends {
-                let result = self
-                    .stream_null_buffer_copy(
-                        target_batch_size,
-                        output_slots_available,
-                        output_arrays,
-                        sender,
-                    )
-                    .await?;
-                output_slots_available = result.0;
-                output_arrays = result.1;
-
-                advance_buffer = true;
-                advance_stream = false;
-            } else if buffer_ends {
-                let result = self
-                    .stream_copy_buffer_null(
-                        target_batch_size,
-                        output_slots_available,
-                        output_arrays,
-                        sender,
-                    )
-                    .await?;
-                output_slots_available = result.0;
-                output_arrays = result.1;
-
-                advance_stream = true;
-                advance_buffer = false;
-            } else {
-                if self.stream_batch.key_any_null() {
-                    let result = self
-                        .stream_copy_buffer_null(
-                            target_batch_size,
-                            output_slots_available,
-                            output_arrays,
-                            sender,
-                        )
-                        .await?;
-                    output_slots_available = result.0;
-                    output_arrays = result.1;
-
-                    advance_stream = true;
-                    advance_buffer = false;
-                    continue;
-                }
-                if self.buffered_batches.key_any_null() {
-                    let result = self
-                        .stream_null_buffer_copy(
-                            target_batch_size,
-                            output_slots_available,
-                            output_arrays,
-                            sender,
-                        )
-                        .await?;
-                    output_slots_available = result.0;
-                    output_arrays = result.1;
-
-                    advance_buffer = true;
-                    advance_stream = false;
-                    continue;
-                }
-
-                let current_cmp = self.compare_stream_buffer()?;
-                match current_cmp {
-                    Ordering::Less => {
-                        let result = self
-                            .stream_copy_buffer_null(
-                                target_batch_size,
-                                output_slots_available,
-                                output_arrays,
-                                sender,
-                            )
-                            .await?;
-                        output_slots_available = result.0;
-                        output_arrays = result.1;
-
-                        advance_stream = true;
-                        advance_buffer = false;
-                    }
-                    Ordering::Equal => {
-                        loop {
-                            let result = self
-                                .join_eq_records(
-                                    target_batch_size,
-                                    output_slots_available,
-                                    output_arrays,
-                                    sender,
-                                )
-                                .await?;
-                            output_slots_available = result.0;
-                            output_arrays = result.1;
-
-                            self.stream_batch.advance();
-                            if self.stream_batch.is_new_key {
-                                break;
-                            }
-                        }
-                        advance_stream = false; // we have already reached the next streamed key
-                        advance_buffer = true;
-                    }
-                    Ordering::Greater => {
-                        let result = self
-                            .stream_null_buffer_copy(
-                                target_batch_size,
-                                output_slots_available,
-                                output_arrays,
-                                sender,
-                            )
-                            .await?;
-                        output_slots_available = result.0;
-                        output_arrays = result.1;
-
-                        advance_buffer = true;
-                        advance_stream = false;
-                    }
-                }
-            }
-        }
-        Ok(())
-    }
-
-    async fn join_eq_records(
-        &mut self,
-        target_batch_size: usize,
-        output_slots_available: usize,
-        mut output_arrays: Vec<Box<dyn MutableArray>>,
-        sender: &Sender<ArrowResult<RecordBatch>>,
-    ) -> Result<(usize, Vec<Box<dyn MutableArray>>)> {
-        let mut output_slots_available = output_slots_available;
-        let mut remaining = self.buffered_batches.row_num;
-        let stream_batch = &self.stream_batch.batch.as_ref().unwrap().batch;
-        let stream_row = self.stream_batch.cur_row;
-
-        let batches = self
-            .buffered_batches
-            .batches
-            .iter()
-            .map(|prb| &prb.batch)
-            .collect::<Vec<_>>();
-        let buffered_ranges = &self.buffered_batches.ranges;
-
-        let mut unfinished = true;
-        let mut buffered_idx = 0;
-        let mut rows_to_output = 0;
-        let start_indices = range_start_indices(buffered_ranges);
-
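Every emit loop below follows the same slot-accounting pattern: fill the current output batch up to `target_batch_size` rows, flush it, and carry any leftover capacity into the next call. A std-only sketch of that bookkeeping (hypothetical `emit_in_chunks` helper, plain row counts in place of mutable arrays):

/// Emit `total` rows in chunks that never exceed `target_batch_size`,
/// mirroring the `output_slots_available` bookkeeping in the drivers above.
fn emit_in_chunks(total: usize, target_batch_size: usize, mut emit: impl FnMut(usize)) {
    let mut remaining = total;
    let mut slots = target_batch_size;
    while remaining > 0 {
        let rows = remaining.min(slots);
        emit(rows);
        slots -= rows;
        remaining -= rows;
        if slots == 0 {
            // a full output batch was produced; start a fresh one
            slots = target_batch_size;
        }
    }
}

fn main() {
    let mut chunks = Vec::new();
    emit_in_chunks(10, 4, |n| chunks.push(n));
    assert_eq!(chunks, vec![4, 4, 2]);
}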
-        // output each buffered matching record once
-        while unfinished {
-            if output_slots_available >= remaining {
-                unfinished = false;
-                rows_to_output = remaining;
-                output_slots_available -= remaining;
-                remaining = 0;
-            } else {
-                rows_to_output = output_slots_available;
-                output_slots_available = 0;
-                remaining -= rows_to_output;
-            }
-
-            // get slices for the buffered side for the current output
-            let slices = slices_from_batches(
-                buffered_ranges,
-                &start_indices,
-                buffered_idx,
-                rows_to_output,
-            );
-
-            output_arrays
-                .iter_mut()
-                .zip(self.schema.fields().iter())
-                .zip(self.column_indices.iter())
-                .for_each(|((array, _field), column_index)| {
-                    if column_index.is_left {
-                        // repeat the streamed cell `rows_to_output` times
-                        repeat_streamed_cell(
-                            stream_batch,
-                            stream_row,
-                            rows_to_output,
-                            array,
-                            column_index,
-                        );
-                    } else {
-                        // copy the buffered side starting at `buffered_idx`, len `rows_to_output`
-                        copy_slices(&batches, &slices, array, column_index);
-                    }
-                });
-
-            if output_slots_available == 0 {
-                let result = make_batch(self.schema.clone(), output_arrays);
-
-                if let Err(e) = sender.send(result).await {
-                    println!("ERROR sending batch via join output stream: {}", e);
-                };
-
-                output_arrays = new_arrays(&self.schema, target_batch_size)?;
-                output_slots_available = target_batch_size;
-            }
-
-            buffered_idx += rows_to_output;
-            rows_to_output = 0;
-        }
-        Ok((output_slots_available, output_arrays))
-    }
-
-    async fn stream_copy_buffer_null(
-        &mut self,
-        target_batch_size: usize,
-        output_slots_available: usize,
-        mut output_arrays: Vec<Box<dyn MutableArray>>,
-        sender: &Sender<ArrowResult<RecordBatch>>,
-    ) -> Result<(usize, Vec<Box<dyn MutableArray>>)> {
-        let mut output_slots_available = output_slots_available;
-        let stream_batch = &self.stream_batch.batch.as_ref().unwrap().batch;
-        let batch = vec![stream_batch];
-        let stream_range =
-            &self.stream_batch.batch.as_ref().unwrap().ranges[self.stream_batch.cur_range];
-        let mut remaining = range_len(stream_range);
-
-        let mut unfinished = true;
-        let mut streamed_idx = self.stream_batch.cur_row;
-        let mut rows_to_output = 0;
-
-        // output each streamed record once, null-padded on the buffered side
-        while unfinished {
-            if output_slots_available >= remaining {
-                unfinished = false;
-                rows_to_output = remaining;
-                output_slots_available -= remaining;
-                remaining = 0;
-            } else {
-                rows_to_output = output_slots_available;
-                output_slots_available = 0;
-                remaining -= rows_to_output;
-            }
-
-            let slice = vec![Slice {
-                batch_idx: 0,
-                start_idx: streamed_idx,
-                len: rows_to_output,
-            }];
-
-            output_arrays
-                .iter_mut()
-                .zip(self.schema.fields().iter())
-                .zip(self.column_indices.iter())
-                .for_each(|((array, _field), column_index)| {
-                    if column_index.is_left {
-                        copy_slices(&batch, &slice, array, column_index);
-                    } else {
-                        (0..rows_to_output).for_each(|_| array.push_null());
-                    }
-                });
-
-            if output_slots_available == 0 {
-                let result = make_batch(self.schema.clone(), output_arrays);
-
-                if let Err(e) = sender.send(result).await {
-                    println!("ERROR sending batch via join output stream: {}", e);
-                };
-
-                output_arrays = new_arrays(&self.schema, target_batch_size)?;
-                output_slots_available = target_batch_size;
-            }
-
-            streamed_idx += rows_to_output;
-            rows_to_output = 0;
-        }
-        Ok((output_slots_available, output_arrays))
-    }
-
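The flush step hands each finished batch to a bounded tokio channel whose receiving half is wrapped as a record-batch stream. A minimal sketch of that producer/consumer handoff, assuming plain i32s stand in for RecordBatch results:

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // capacity 2: the producer is backpressured if the consumer lags
    let (tx, mut rx) = mpsc::channel::<i32>(2);

    tokio::spawn(async move {
        for batch in 0..5 {
            if let Err(e) = tx.send(batch).await {
                eprintln!("receiver dropped: {}", e);
                return;
            }
        }
    });

    let mut received = Vec::new();
    while let Some(b) = rx.recv().await {
        received.push(b);
    }
    assert_eq!(received, vec![0, 1, 2, 3, 4]);
}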
-    async fn stream_copy_buffer_omit(
-        &mut self,
-        target_batch_size: usize,
-        output_slots_available: usize,
-        mut output_arrays: Vec<Box<dyn MutableArray>>,
-        sender: &Sender<ArrowResult<RecordBatch>>,
-    ) -> Result<(usize, Vec<Box<dyn MutableArray>>)> {
-        let mut output_slots_available = output_slots_available;
-        let stream_batch = &self.stream_batch.batch.as_ref().unwrap().batch;
-        let batch = vec![stream_batch];
-        let stream_range =
-            &self.stream_batch.batch.as_ref().unwrap().ranges[self.stream_batch.cur_range];
-        let mut remaining = range_len(stream_range);
-
-        let mut unfinished = true;
-        let mut streamed_idx = self.stream_batch.cur_row;
-        let mut rows_to_output = 0;
-
-        // output each streamed record once, omitting the buffered side
-        while unfinished {
-            if output_slots_available >= remaining {
-                unfinished = false;
-                rows_to_output = remaining;
-                output_slots_available -= remaining;
-                remaining = 0;
-            } else {
-                rows_to_output = output_slots_available;
-                output_slots_available = 0;
-                remaining -= rows_to_output;
-            }
-
-            let slice = vec![Slice {
-                batch_idx: 0,
-                start_idx: streamed_idx,
-                len: rows_to_output,
-            }];
-
-            output_arrays
-                .iter_mut()
-                .zip(self.schema.fields().iter())
-                .zip(self.column_indices.iter())
-                .for_each(|((array, _field), column_index)| {
-                    copy_slices(&batch, &slice, array, column_index);
-                });
-
-            if output_slots_available == 0 {
-                let result = make_batch(self.schema.clone(), output_arrays);
-
-                if let Err(e) = sender.send(result).await {
-                    println!("ERROR sending batch via join output stream: {}", e);
-                };
-
-                output_arrays = new_arrays(&self.schema, target_batch_size)?;
-                output_slots_available = target_batch_size;
-            }
-
-            streamed_idx += rows_to_output;
-            rows_to_output = 0;
-        }
-        Ok((output_slots_available, output_arrays))
-    }
-
-    async fn stream_null_buffer_copy(
-        &mut self,
-        target_batch_size: usize,
-        output_slots_available: usize,
-        mut output_arrays: Vec<Box<dyn MutableArray>>,
-        sender: &Sender<ArrowResult<RecordBatch>>,
-    ) -> Result<(usize, Vec<Box<dyn MutableArray>>)> {
-        let mut output_slots_available = output_slots_available;
-        let mut remaining = self.buffered_batches.row_num;
-
-        let batches = self
-            .buffered_batches
-            .batches
-            .iter()
-            .map(|prb| &prb.batch)
-            .collect::<Vec<_>>();
-        let buffered_ranges = &self.buffered_batches.ranges;
-
-        let mut unfinished = true;
-        let mut buffered_idx = 0;
-        let mut rows_to_output = 0;
-        let start_indices = range_start_indices(buffered_ranges);
-
-        // output each buffered record once, null-padded on the streamed side
-        while unfinished {
-            if output_slots_available >= remaining {
-                unfinished = false;
-                rows_to_output = remaining;
-                output_slots_available -= remaining;
-                remaining = 0;
-            } else {
-                rows_to_output = output_slots_available;
-                output_slots_available = 0;
-                remaining -= rows_to_output;
-            }
-
-            // get slices for the buffered side for the current output
-            let slices = slices_from_batches(
-                buffered_ranges,
-                &start_indices,
-                buffered_idx,
-                rows_to_output,
-            );
-
-            output_arrays
-                .iter_mut()
-                .zip(self.schema.fields().iter())
-                .zip(self.column_indices.iter())
-                .for_each(|((array, _field), column_index)| {
-                    if column_index.is_left {
-                        (0..rows_to_output).for_each(|_| array.push_null());
-                    } else {
-                        // copy the buffered side starting at `buffered_idx`, len `rows_to_output`
-                        copy_slices(&batches, &slices, array, column_index);
-                    }
-                });
-
-            if output_slots_available == 0 {
-                let result = make_batch(self.schema.clone(), output_arrays);
-
-                if let Err(e) = sender.send(result).await {
-                    println!("ERROR sending batch via join output stream: {}", e);
-                };
-
-                output_arrays = new_arrays(&self.schema, target_batch_size)?;
-                output_slots_available = target_batch_size;
-            }
-
-            buffered_idx += rows_to_output;
-            rows_to_output = 0;
-        }
-        Ok((output_slots_available, output_arrays))
-    }
-
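The match-finding routine that follows is the classic merge step: compare the current streamed key against the current buffered key and advance whichever side is smaller until they agree or one side ends. A std-only sketch over integer keys (hypothetical `find_next_match`; nulls are not modeled here):

use std::cmp::Ordering;

/// Advance two sorted cursors until their keys are equal, mirroring the
/// Less/Equal/Greater arms of find_next_inner_match.
fn find_next_match(stream: &[i32], buffer: &[i32], s: &mut usize, b: &mut usize) -> bool {
    while *s < stream.len() && *b < buffer.len() {
        match stream[*s].cmp(&buffer[*b]) {
            Ordering::Less => *s += 1,    // streamed key too small: advance it
            Ordering::Greater => *b += 1, // buffered key too small: advance it
            Ordering::Equal => return true,
        }
    }
    false // one side is exhausted: no more inner matches
}

fn main() {
    let (mut s, mut b) = (0, 0);
    assert!(find_next_match(&[1, 3, 5], &[2, 3, 4], &mut s, &mut b));
    assert_eq!((s, b), (1, 1)); // matched on key 3
}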
-    async fn find_next_inner_match(&mut self) -> Result<bool> {
-        if self.stream_batch.key_any_null() {
-            let more_stream = self.advance_streamed_key_null_free().await?;
-            if !more_stream {
-                return Ok(false);
-            }
-        }
-
-        if self.buffered_batches.key_any_null() {
-            let more_buffer = self.advance_buffered_key_null_free().await?;
-            if !more_buffer {
-                return Ok(false);
-            }
-        }
-
-        loop {
-            let current_cmp = self.compare_stream_buffer()?;
-            match current_cmp {
-                Ordering::Less => {
-                    let more_stream = self.advance_streamed_key_null_free().await?;
-                    if !more_stream {
-                        return Ok(false);
-                    }
-                }
-                Ordering::Equal => return Ok(true),
-                Ordering::Greater => {
-                    let more_buffer = self.advance_buffered_key_null_free().await?;
-                    if !more_buffer {
-                        return Ok(false);
-                    }
-                }
-            }
-        }
-    }
-
-    async fn find_next_outer(&mut self, buffer_ends: bool) -> Result<OuterMatchResult> {
-        let more_stream = self.advance_streamed_key().await?;
-        if buffer_ends {
-            return Ok(OuterMatchResult {
-                get_match: false,
-                buffered_ended: true,
-                more_output: more_stream,
-            });
-        } else {
-            if !more_stream {
-                return Ok(OuterMatchResult {
-                    get_match: false,
-                    buffered_ended: false,
-                    more_output: false,
-                });
-            }
-
-            if self.buffered_batches.key_any_null() {
-                let more_buffer = self.advance_buffered_key_null_free().await?;
-                if !more_buffer {
-                    return Ok(OuterMatchResult {
-                        get_match: false,
-                        buffered_ended: true,
-                        more_output: true,
-                    });
-                }
-            }
-
-            loop {
-                if self.stream_batch.key_any_null() {
-                    return Ok(OuterMatchResult {
-                        get_match: false,
-                        buffered_ended: false,
-                        more_output: true,
-                    });
-                }
-
-                let current_cmp = self.compare_stream_buffer()?;
-                match current_cmp {
-                    Ordering::Less => {
-                        return Ok(OuterMatchResult {
-                            get_match: false,
-                            buffered_ended: false,
-                            more_output: true,
-                        })
-                    }
-                    Ordering::Equal => {
-                        return Ok(OuterMatchResult {
-                            get_match: true,
-                            buffered_ended: false,
-                            more_output: true,
-                        })
-                    }
-                    Ordering::Greater => {
-                        let more_buffer = self.advance_buffered_key_null_free().await?;
-                        if !more_buffer {
-                            return Ok(OuterMatchResult {
-                                get_match: false,
-                                buffered_ended: true,
-                                more_output: true,
-                            });
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    fn compare_stream_buffer(&self) -> Result<Ordering> {
-        let stream_arrays =
-            join_arrays(&self.stream_batch.batch.as_ref().unwrap().batch, &self.on_streamed);
-        let buffer_arrays =
-            join_arrays(&self.buffered_batches.batches[0].batch, &self.on_buffered);
-        comp_rows(
-            self.stream_batch.cur_row,
-            self.buffered_batches.key_idx.unwrap(),
-            &stream_arrays,
-            &buffer_arrays,
-        )
-    }
-}
-
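compare_stream_buffer delegates to comp_rows, which compares one row of the streamed key columns against one row of the buffered key columns, column by column. A simplified std-only analog over plain i32 key columns (the real comp_rows works over arrow arrays and handles nulls and sort options):

use std::cmp::Ordering;

/// Compare row `l` of `left` against row `r` of `right`, column by column,
/// falling through to the next key column on equality.
fn comp_rows(l: usize, r: usize, left: &[Vec<i32>], right: &[Vec<i32>]) -> Ordering {
    for (lc, rc) in left.iter().zip(right.iter()) {
        match lc[l].cmp(&rc[r]) {
            Ordering::Equal => continue,
            other => return other,
        }
    }
    Ordering::Equal
}

fn main() {
    // two key columns; rows (1, 9) vs (1, 7) compare Greater on the second key
    let left = vec![vec![1], vec![9]];
    let right = vec![vec![1], vec![7]];
    assert_eq!(comp_rows(0, 0, &left, &right), Ordering::Greater);
}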
 /// Join execution plan: executes partitions in parallel and combines them into a set of
 /// partitions.
 #[derive(Debug)]
@@ -1254,11 +222,6 @@ impl ExecutionPlan for SortMergeJoinExec {
             &self.schema,
         )?;
 
-        let (tx, rx): (
-            Sender<ArrowResult<RecordBatch>>,
-            Receiver<ArrowResult<RecordBatch>>,
-        ) = tokio::sync::mpsc::channel(2);
-
         let left_sort = self
             .left
             .as_any()
@@ -1278,12 +241,12 @@ impl ExecutionPlan for SortMergeJoinExec {
             .map(|s| s.clone())
             .collect::<Vec<_>>();
 
-        let mut driver = match self.join_type {
+        let join = match self.join_type {
             JoinType::Inner
             | JoinType::Left
             | JoinType::Full
             | JoinType::Semi
-            | JoinType::Anti => SortMergeJoinDriver::new(
+            | JoinType::Anti => SortMergeJoinCommon::new(
                 left,
                 right,
                 on_left,
@@ -1292,9 +255,10 @@ impl ExecutionPlan for SortMergeJoinExec {
                 right_sort,
                 column_indices,
                 self.schema.clone(),
+                self.join_type,
                 RUNTIME_ENV.clone(),
-            ),
-            JoinType::Right => SortMergeJoinDriver::new(
+            )?,
+            JoinType::Right => SortMergeJoinCommon::new(
                 right,
                 left,
                 on_right,
@@ -1303,22 +267,18 @@ impl ExecutionPlan for SortMergeJoinExec {
                 left_sort,
                 column_indices,
                 self.schema.clone(),
+                self.join_type,
                 RUNTIME_ENV.clone(),
-            ),
+            )?,
         };
-
-        match self.join_type {
-            JoinType::Inner => driver.inner_join_driver(&tx).await?,
-            JoinType::Left => driver.outer_join_driver(&tx).await?,
-            JoinType::Right => driver.outer_join_driver(&tx).await?,
-            JoinType::Full => driver.full_outer_driver(&tx).await?,
-            JoinType::Semi => driver.semi_join_driver(&tx).await?,
-            JoinType::Anti => driver.anti_join_driver(&tx).await?,
-        }
-
-        let result = RecordBatchReceiverStream::create(&self.schema, rx);
-
-        Ok(result)
+        Ok(match self.join_type {
+            JoinType::Inner => Box::pin(InnerJoiner::new(join)),
+            JoinType::Left => todo!(),
+            JoinType::Right => todo!(),
+            JoinType::Full => todo!(),
+            JoinType::Semi => todo!(),
+            JoinType::Anti => todo!(),
+        })
     }
 
     fn fmt_as(