diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 3511c6da2c..0c75de9be6 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -334,44 +334,80 @@ pub fn physical_plan_to_pipeline( // 2. Join type. Different join types have different requirements for which side can build the probe table. let left_stats_state = left.get_stats_state(); let right_stats_state = right.get_stats_state(); - let build_on_left = match (left_stats_state, right_stats_state) { - (StatsState::Materialized(left_stats), StatsState::Materialized(right_stats)) => { - left_stats.approx_stats.upper_bound_bytes - <= right_stats.approx_stats.upper_bound_bytes - } - // If stats are only available on the right side of the join, and the upper bound bytes on the - // right are under the broadcast join size threshold, we build on the right instead of the left. - (StatsState::NotMaterialized, StatsState::Materialized(right_stats)) => right_stats - .approx_stats - .upper_bound_bytes - .map_or(true, |size| size > cfg.broadcast_join_size_bytes_threshold), - // If stats are not available, we fall back and build on the left by default. - _ => true, - }; - - // TODO(desmond): We might potentially want to flip the probe table side for - // left/right outer joins if one side is significantly larger. Needs to be tuned. - // - // In greater detail, consider a right outer join where the left side is several orders - // of magnitude larger than the right. An extreme example might have 1B rows on the left, - // and 10 rows on the right. - // - // Typically we would build the probe table on the left, then stream rows from the right - // to match against the probe table. But in this case we would have a giant intermediate - // probe table. - // - // An alternative 2-pass algorithm would be to: - // 1. Build the probe table on the right, but add a second data structure to keep track of - // which rows on the right have been matched. - // 2. Stream rows on the left until all rows have been seen. - // 3. Finally, emit all unmatched rows from the right. let build_on_left = match join_type { - JoinType::Inner => build_on_left, - JoinType::Outer => build_on_left, - // For left outer joins, we build on right so we can stream the left side. - JoinType::Left => false, - // For right outer joins, we build on left so we can stream the right side. - JoinType::Right => true, + // Inner and outer joins can build on either side. If stats are available, choose the smaller side. + // Else, default to building on the left. + JoinType::Inner | JoinType::Outer => match (left_stats_state, right_stats_state) { + ( + StatsState::Materialized(left_stats), + StatsState::Materialized(right_stats), + ) => { + let left_size = left_stats.approx_stats.upper_bound_bytes; + let right_size = right_stats.approx_stats.upper_bound_bytes; + left_size.zip(right_size).map_or(true, |(l, r)| l <= r) + } + // If stats are only available on the right side of the join, and the upper bound bytes on the + // right are under the broadcast join size threshold, we build on the right instead of the left. + (StatsState::NotMaterialized, StatsState::Materialized(right_stats)) => { + right_stats + .approx_stats + .upper_bound_bytes + .map_or(true, |size| size > cfg.broadcast_join_size_bytes_threshold) + } + _ => true, + }, + // Left joins can build on the left side, but prefer building on the right because building on left requires keeping track + // of used indices in a bitmap. If stats are available, only select the left side if its smaller than the right side by a factor of 1.5. + JoinType::Left => match (left_stats_state, right_stats_state) { + ( + StatsState::Materialized(left_stats), + StatsState::Materialized(right_stats), + ) => { + let left_size = left_stats.approx_stats.upper_bound_bytes; + let right_size = right_stats.approx_stats.upper_bound_bytes; + left_size + .zip(right_size) + .map_or(false, |(l, r)| (r as f64) >= ((l as f64) * 1.5)) + } + // If stats are only available on the left side of the join, and the upper bound bytes on the left + // are under the broadcast join size threshold, we build on the left instead of the right. + (StatsState::Materialized(left_stats), StatsState::NotMaterialized) => { + left_stats + .approx_stats + .upper_bound_bytes + .map_or(false, |size| { + size <= cfg.broadcast_join_size_bytes_threshold + }) + } + _ => false, + }, + // Right joins can build on the right side, but prefer building on the left because building on right requires keeping track + // of used indices in a bitmap. If stats are available, only select the right side if its smaller than the left side by a factor of 1.5. + JoinType::Right => match (left_stats_state, right_stats_state) { + ( + StatsState::Materialized(left_stats), + StatsState::Materialized(right_stats), + ) => { + let left_size = left_stats.approx_stats.upper_bound_bytes; + let right_size = right_stats.approx_stats.upper_bound_bytes; + left_size + .zip(right_size) + .map_or(true, |(l, r)| (r as f64) < ((l as f64) * 1.5)) + } + // If stats are only available on the right side of the join, and the upper bound bytes on the + // right are under the broadcast join size threshold, we build on the right instead of the left. + (StatsState::NotMaterialized, StatsState::Materialized(right_stats)) => { + right_stats + .approx_stats + .upper_bound_bytes + .map_or(false, |size| { + size <= cfg.broadcast_join_size_bytes_threshold + }) + } + _ => false, + }, + + // Anti and semi joins always build on the right JoinType::Anti | JoinType::Semi => false, }; let (build_on, probe_on, build_child, probe_child) = match build_on_left { diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 4c3729c150..fd7e381ee6 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -122,6 +122,7 @@ struct OuterHashJoinParams { common_join_keys: Vec, left_non_join_columns: Vec, right_non_join_columns: Vec, + left_non_join_schema: SchemaRef, right_non_join_schema: SchemaRef, join_type: JoinType, build_on_left: bool, @@ -129,6 +130,7 @@ struct OuterHashJoinParams { pub(crate) struct OuterHashJoinProbeSink { params: Arc, + needs_bitmap: bool, output_schema: SchemaRef, probe_state_bridge: BroadcastStateBridgeRef, } @@ -145,17 +147,23 @@ impl OuterHashJoinProbeSink { output_schema: &SchemaRef, probe_state_bridge: BroadcastStateBridgeRef, ) -> Self { + let needs_bitmap = join_type == JoinType::Outer + || join_type == JoinType::Right && !build_on_left + || join_type == JoinType::Left && build_on_left; // For outer joins, we need to swap the left and right schemas if we are building on the right. let (left_schema, right_schema) = match (join_type, build_on_left) { (JoinType::Outer, false) => (right_schema, left_schema), _ => (left_schema, right_schema), }; - let left_non_join_columns = left_schema + let left_non_join_fields = left_schema .fields - .keys() - .filter(|c| !common_join_keys.contains(*c)) + .values() + .filter(|f| !common_join_keys.contains(&f.name)) .cloned() .collect(); + let left_non_join_schema = + Arc::new(Schema::new(left_non_join_fields).expect("left schema should be valid")); + let left_non_join_columns = left_non_join_schema.fields.keys().cloned().collect(); let right_non_join_fields = right_schema .fields .values() @@ -172,15 +180,85 @@ impl OuterHashJoinProbeSink { common_join_keys, left_non_join_columns, right_non_join_columns, + left_non_join_schema, right_non_join_schema, join_type, build_on_left, }), + needs_bitmap, output_schema: output_schema.clone(), probe_state_bridge, } } + fn probe_left_right_with_bitmap( + input: &Arc, + bitmap_builder: &mut IndexBitmapBuilder, + probe_state: &ProbeState, + join_type: JoinType, + probe_on: &[ExprRef], + common_join_keys: &[String], + left_non_join_columns: &[String], + right_non_join_columns: &[String], + ) -> DaftResult> { + let probe_table = probe_state.get_probeable(); + let tables = probe_state.get_tables(); + + let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + false, + tables.iter().map(|t| t.len()).sum(), + )?; + + let input_tables = input.get_tables()?; + let mut probe_side_growable = + GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; + + drop(_growables); + { + let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); + for (probe_side_table_idx, table) in input_tables.iter().enumerate() { + let join_keys = table.eval_expression_list(probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + bitmap_builder + .mark_used(build_side_table_idx as usize, build_row_idx as usize); + build_side_growable.extend( + build_side_table_idx as usize, + build_row_idx as usize, + 1, + ); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } + } + } + } + let build_side_table = build_side_growable.build()?; + let probe_side_table = probe_side_growable.build()?; + + let final_table = if join_type == JoinType::Left { + let join_table = build_side_table.get_columns(common_join_keys)?; + let left = build_side_table.get_columns(left_non_join_columns)?; + let right = probe_side_table.get_columns(right_non_join_columns)?; + join_table.union(&left)?.union(&right)? + } else { + let join_table = build_side_table.get_columns(common_join_keys)?; + let left = probe_side_table.get_columns(left_non_join_columns)?; + let right = build_side_table.get_columns(right_non_join_columns)?; + join_table.union(&left)?.union(&right)? + }; + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } + fn probe_left_right( input: &Arc, probe_state: &ProbeState, @@ -190,8 +268,8 @@ impl OuterHashJoinProbeSink { left_non_join_columns: &[String], right_non_join_columns: &[String], ) -> DaftResult> { - let probe_table = probe_state.get_probeable().clone(); - let tables = probe_state.get_tables().clone(); + let probe_table = probe_state.get_probeable(); + let tables = probe_state.get_tables(); let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); let mut build_side_growable = GrowableTable::new( @@ -261,8 +339,8 @@ impl OuterHashJoinProbeSink { right_non_join_columns: &[String], build_on_left: bool, ) -> DaftResult> { - let probe_table = probe_state.get_probeable().clone(); - let tables = probe_state.get_tables().clone(); + let probe_table = probe_state.get_probeable(); + let tables = probe_state.get_tables(); let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); // Need to set use_validity to true here because we add nulls to the build side @@ -320,13 +398,9 @@ impl OuterHashJoinProbeSink { ))) } - async fn finalize_outer( + async fn merge_bitmaps_and_construct_null_table( mut states: Vec>, - common_join_keys: &[String], - left_non_join_columns: &[String], - right_non_join_schema: &SchemaRef, - build_on_left: bool, - ) -> DaftResult>> { + ) -> DaftResult { let mut states_iter = states.iter_mut(); let first_state = states_iter .next() @@ -376,8 +450,17 @@ impl OuterHashJoinProbeSink { .map(|(bitmap, table)| table.mask_filter(&bitmap.into_series())) .collect::>>()?; - let build_side_table = Table::concat(&leftovers)?; + Table::concat(&leftovers) + } + async fn finalize_outer( + states: Vec>, + common_join_keys: &[String], + left_non_join_columns: &[String], + right_non_join_schema: &SchemaRef, + build_on_left: bool, + ) -> DaftResult>> { + let build_side_table = Self::merge_bitmaps_and_construct_null_table(states).await?; let join_table = build_side_table.get_columns(common_join_keys)?; let left = build_side_table.get_columns(left_non_join_columns)?; let right = { @@ -401,6 +484,60 @@ impl OuterHashJoinProbeSink { None, )))) } + + async fn finalize_left( + states: Vec>, + common_join_keys: &[String], + left_non_join_columns: &[String], + right_non_join_schema: &SchemaRef, + ) -> DaftResult>> { + let build_side_table = Self::merge_bitmaps_and_construct_null_table(states).await?; + let join_table = build_side_table.get_columns(common_join_keys)?; + let left = build_side_table.get_columns(left_non_join_columns)?; + let right = { + let columns = right_non_join_schema + .fields + .values() + .map(|field| Series::full_null(&field.name, &field.dtype, left.len())) + .collect::>(); + Table::new_unchecked(right_non_join_schema.clone(), columns, left.len()) + }; + let final_table = join_table.union(&left)?.union(&right)?; + Ok(Some(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + )))) + } + + async fn finalize_right( + states: Vec>, + common_join_keys: &[String], + right_non_join_columns: &[String], + left_non_join_schema: &SchemaRef, + ) -> DaftResult>> { + let build_side_table = Self::merge_bitmaps_and_construct_null_table(states).await?; + let join_table = build_side_table.get_columns(common_join_keys)?; + let left = { + let columns = left_non_join_schema + .fields + .values() + .map(|field| Series::full_null(&field.name, &field.dtype, build_side_table.len())) + .collect::>(); + Table::new_unchecked( + left_non_join_schema.clone(), + columns, + build_side_table.len(), + ) + }; + let right = build_side_table.get_columns(right_non_join_columns)?; + let final_table = join_table.union(&left)?.union(&right)?; + Ok(Some(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + )))) + } } impl StreamingSink for OuterHashJoinProbeSink { @@ -416,6 +553,7 @@ impl StreamingSink for OuterHashJoinProbeSink { return Ok((state, StreamingSinkOutput::NeedMoreInput(Some(empty)))).into(); } + let needs_bitmap = self.needs_bitmap; let params = self.params.clone(); runtime_ref .spawn(async move { @@ -425,6 +563,22 @@ impl StreamingSink for OuterHashJoinProbeSink { .expect("OuterHashJoinProbeSink should have OuterHashJoinProbeState"); let probe_state = outer_join_state.get_or_build_probe_state().await; let out = match params.join_type { + JoinType::Left | JoinType::Right if needs_bitmap => { + Self::probe_left_right_with_bitmap( + &input, + outer_join_state + .get_or_build_bitmap() + .await + .as_mut() + .expect("bitmap should be set"), + &probe_state, + params.join_type, + ¶ms.probe_on, + ¶ms.common_join_keys, + ¶ms.left_non_join_columns, + ¶ms.right_non_join_columns, + ) + } JoinType::Left | JoinType::Right => Self::probe_left_right( &input, &probe_state, @@ -467,7 +621,7 @@ impl StreamingSink for OuterHashJoinProbeSink { fn make_state(&self) -> Box { Box::new(OuterHashJoinState::Building( self.probe_state_bridge.clone(), - self.params.join_type == JoinType::Outer, + self.needs_bitmap, )) } @@ -476,18 +630,37 @@ impl StreamingSink for OuterHashJoinProbeSink { states: Vec>, runtime_ref: &RuntimeRef, ) -> StreamingSinkFinalizeResult { - if self.params.join_type == JoinType::Outer { + if self.needs_bitmap { let params = self.params.clone(); runtime_ref .spawn(async move { - Self::finalize_outer( - states, - ¶ms.common_join_keys, - ¶ms.left_non_join_columns, - ¶ms.right_non_join_schema, - params.build_on_left, - ) - .await + match params.join_type { + JoinType::Left => Self::finalize_left( + states, + ¶ms.common_join_keys, + ¶ms.left_non_join_columns, + ¶ms.right_non_join_schema, + ) + .await, + JoinType::Right => Self::finalize_right( + states, + ¶ms.common_join_keys, + ¶ms.right_non_join_columns, + ¶ms.left_non_join_schema, + ) + .await, + JoinType::Outer => Self::finalize_outer( + states, + ¶ms.common_join_keys, + ¶ms.left_non_join_columns, + ¶ms.right_non_join_schema, + params.build_on_left, + ) + .await, + _ => unreachable!( + "Only Left, Right, and Outer joins are supported in OuterHashJoinProbeSink" + ), + } }) .into() } else { diff --git a/tests/sql/test_joins.py b/tests/sql/test_joins.py index b5c9d29ee3..590d450522 100644 --- a/tests/sql/test_joins.py +++ b/tests/sql/test_joins.py @@ -98,7 +98,8 @@ def test_joins_with_duplicate_columns(): """ SELECT * FROM table1 t1 - LEFT JOIN table2 t2 on t2.id = t1.id; + LEFT JOIN table2 t2 on t2.id = t1.id + ORDER BY t1.id; """, catalog, ).collect()