Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Nov 7, 2023
1 parent 97e09dc commit 02651f6
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 95 deletions.
35 changes: 17 additions & 18 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ impl ExecutionPlan for HashJoinExec {
reservation,
batch_size,
probe_batch: None,
state: HashJoinStreamState::default(),
output_state: HashJoinOutputState::default(),
}))
}

Expand Down Expand Up @@ -756,20 +756,20 @@ where
// State storing the left/right side indices used for partial batch output
// and for producing the ranges used to adjust indices
#[derive(Debug, Default)]
pub(crate) struct HashJoinStreamState {
pub(crate) struct HashJoinOutputState {
// total rows in current probe batch
probe_rows: usize,
// saved probe-build indices to resume matching from
last_matched_indices: Option<(usize, usize)>,
// whether the matched indices have been updated during the current iteration
matched_indices_updated: bool,
// tracking last joined probe side index seen for further indices adjustment
// last probe side index, joined during current iteration
last_joined_probe_index: Option<usize>,
// tracking last joined probe side index seen for further indices adjustment
// last probe side index, joined during previous iteration
prev_joined_probe_index: Option<usize>,
}

impl HashJoinStreamState {
impl HashJoinOutputState {
// set total probe rows to process
pub(crate) fn set_probe_rows(&mut self, probe_rows: usize) {
self.probe_rows = probe_rows;
Expand Down Expand Up @@ -882,7 +882,7 @@ struct HashJoinStream {
/// (cross-join due to key duplication on build side) `HashJoinStream` saves its state
/// and emits result batch to upstream operator.
/// On next poll these indices are used to skip already matched rows and adjusted probe-side indices.
state: HashJoinStreamState,
output_state: HashJoinOutputState,
}

impl RecordBatchStream for HashJoinStream {
Expand Down Expand Up @@ -936,7 +936,7 @@ pub(crate) fn build_equal_condition_join_indices<T: JoinHashMapType>(
build_side: JoinSide,
deleted_offset: Option<usize>,
output_limit: usize,
state: &mut HashJoinStreamState,
state: &mut HashJoinOutputState,
) -> Result<(UInt64Array, UInt32Array)> {
let keys_values = probe_on
.iter()
Expand Down Expand Up @@ -1177,16 +1177,15 @@ impl HashJoinStream {

// Fetch next probe batch
if self.probe_batch.is_none() {
match self.right.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(batch))) => {
self.state.set_probe_rows(batch.num_rows());
match ready!(self.right.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
self.output_state.set_probe_rows(batch.num_rows());
self.probe_batch = Some(batch);
}
Poll::Ready(None) => {
None => {
self.probe_batch = None;
}
Poll::Ready(Some(err)) => return Poll::Ready(Some(err)),
Poll::Pending => return Poll::Pending,
Some(err) => return Poll::Ready(Some(err)),
}
}

Expand All @@ -1211,7 +1210,7 @@ impl HashJoinStream {
JoinSide::Left,
None,
self.batch_size,
&mut self.state,
&mut self.output_state,
);

let result = match left_right_indices {
Expand All @@ -1229,7 +1228,7 @@ impl HashJoinStream {
let (left_side, right_side) = adjust_indices_by_join_type(
left_side,
right_side,
self.state.adjust_range(),
self.output_state.adjust_range(),
self.join_type,
);

Expand All @@ -1245,9 +1244,9 @@ impl HashJoinStream {
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());

if self.state.is_completed() {
if self.output_state.is_completed() {
self.probe_batch = None;
self.state.reset_state();
self.output_state.reset_state();
}

Some(result)
Expand Down Expand Up @@ -2732,7 +2731,7 @@ mod tests {
JoinSide::Left,
None,
64,
&mut HashJoinStreamState::default(),
&mut HashJoinOutputState::default(),
)?;

let mut left_ids = UInt64Builder::with_capacity(0);
Expand Down
12 changes: 6 additions & 6 deletions datafusion/physical-plan/src/joins/nested_loop_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ use crate::coalesce_batches::concat_batches;
use crate::joins::utils::{
append_right_indices, apply_join_filter_to_indices, build_batch_from_indices,
build_join_schema, check_join_is_valid, estimate_join_statistics, get_anti_indices,
get_anti_u64_indices, get_final_indices_from_bit_map, get_semi_indices,
get_semi_u64_indices, partitioned_join_output_partitioning, BuildProbeJoinMetrics,
ColumnIndex, JoinFilter, OnceAsync, OnceFut,
get_final_indices_from_bit_map, get_semi_indices,
partitioned_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
OnceAsync, OnceFut,
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::{
Expand Down Expand Up @@ -649,20 +649,20 @@ fn adjust_indices_by_join_type(
// matched
// unmatched left row will be produced in this batch
let left_unmatched_indices =
get_anti_u64_indices(0..count_left_batch, &left_indices);
get_anti_indices(0..count_left_batch, &left_indices);
// combine the matched and unmatched left result together
append_left_indices(left_indices, right_indices, left_unmatched_indices)
}
JoinType::LeftSemi => {
// need to remove the duplicated records on the left side
let left_indices = get_semi_u64_indices(0..count_left_batch, &left_indices);
let left_indices = get_semi_indices(0..count_left_batch, &left_indices);
// the right_indices will not be used later for the `left semi` join
(left_indices, right_indices)
}
JoinType::LeftAnti => {
// need to remove the duplicated records on the left side
// get the anti index for the left side
let left_indices = get_anti_u64_indices(0..count_left_batch, &left_indices);
let left_indices = get_anti_indices(0..count_left_batch, &left_indices);
// the right_indices will not be used later for the `left anti` join
(left_indices, right_indices)
}
Expand Down
18 changes: 9 additions & 9 deletions datafusion/physical-plan/src/joins/symmetric_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ use std::vec;
use std::{any::Any, usize};

use crate::common::SharedMemoryReservation;
use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash};
use crate::joins::hash_join::{
build_equal_condition_join_indices, update_hash, HashJoinOutputState,
};
use crate::joins::hash_join_utils::{
calculate_filter_expr_intervals, combine_two_batches,
convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
Expand Down Expand Up @@ -72,8 +74,6 @@ use futures::{Stream, StreamExt};
use hashbrown::HashSet;
use parking_lot::Mutex;

use super::hash_join::HashJoinStreamState;

const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;

/// A symmetric hash join with range conditions is when both streams are hashed on the
Expand Down Expand Up @@ -553,7 +553,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
null_equals_null: self.null_equals_null,
final_result: false,
reservation,
state: HashJoinStreamState::default(),
output_state: HashJoinOutputState::default(),
}))
}
}
Expand Down Expand Up @@ -591,7 +591,7 @@ struct SymmetricHashJoinStream {
/// Flag indicating whether there is nothing to process anymore
final_result: bool,
/// Stream state for compatibility with HashJoinExec
state: HashJoinStreamState,
output_state: HashJoinOutputState,
}

impl RecordBatchStream for SymmetricHashJoinStream {
Expand Down Expand Up @@ -820,7 +820,7 @@ pub(crate) fn join_with_probe_batch(
column_indices: &[ColumnIndex],
random_state: &RandomState,
null_equals_null: bool,
hash_join_stream_state: &mut HashJoinStreamState,
output_state: &mut HashJoinOutputState,
) -> Result<Option<RecordBatch>> {
if build_hash_joiner.input_buffer.num_rows() == 0 || probe_batch.num_rows() == 0 {
return Ok(None);
Expand All @@ -838,11 +838,11 @@ pub(crate) fn join_with_probe_batch(
build_hash_joiner.build_side,
Some(build_hash_joiner.deleted_offset),
usize::MAX,
hash_join_stream_state,
output_state,
)?;

// Resetting state to avoid potential overflows
hash_join_stream_state.reset_state();
output_state.reset_state();

if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
record_visited_indices(
Expand Down Expand Up @@ -1110,7 +1110,7 @@ impl SymmetricHashJoinStream {
&self.column_indices,
&self.random_state,
self.null_equals_null,
&mut self.state,
&mut self.output_state,
)?;
// Increment the offset for the probe hash joiner:
probe_hash_joiner.offset += probe_batch.num_rows();
Expand Down
88 changes: 26 additions & 62 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ use arrow::array::{
use arrow::compute;
use arrow::datatypes::{Field, Schema, SchemaBuilder};
use arrow::record_batch::{RecordBatch, RecordBatchOptions};
use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray};
use arrow_buffer::ArrowNativeType;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::stats::Precision;
use datafusion_common::{
Expand Down Expand Up @@ -920,17 +922,20 @@ pub(crate) fn append_right_indices(
}
}

/// Get unmatched and deduplicated indices for specified range of indices
pub(crate) fn get_anti_indices(
/// Returns `range` indices which are not present in `input_indices`
pub(crate) fn get_anti_indices<T: ArrowPrimitiveType>(
range: Range<usize>,
input_indices: &UInt32Array,
) -> UInt32Array {
input_indices: &PrimitiveArray<T>,
) -> PrimitiveArray<T>
where
NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
let mut bitmap = BooleanBufferBuilder::new(range.len());
bitmap.append_n(range.len(), false);
input_indices
.iter()
.flatten()
.map(|v| v as usize)
.map(|v| v.as_usize())
.filter(|v| range.contains(v))
.for_each(|v| {
bitmap.set_bit(v - range.start, true);
Expand All @@ -940,69 +945,26 @@ pub(crate) fn get_anti_indices(

// get the anti index
(range)
.filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u32))
.collect::<UInt32Array>()
.filter_map(|idx| {
(!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
})
.collect::<PrimitiveArray<T>>()
}

/// Get unmatched and deduplicated indices
pub(crate) fn get_anti_u64_indices(
/// Returns intersection of `range` and `input_indices` omitting duplicates
pub(crate) fn get_semi_indices<T: ArrowPrimitiveType>(
range: Range<usize>,
input_indices: &UInt64Array,
) -> UInt64Array {
input_indices: &PrimitiveArray<T>,
) -> PrimitiveArray<T>
where
NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
let mut bitmap = BooleanBufferBuilder::new(range.len());
bitmap.append_n(range.len(), false);
input_indices
.iter()
.flatten()
.map(|v| v as usize)
.filter(|v| range.contains(v))
.for_each(|v| {
bitmap.set_bit(v - range.start, true);
});

let offset = range.start;

// get the anti index
(range)
.filter_map(|idx| (!bitmap.get_bit(idx - offset)).then_some(idx as u64))
.collect::<UInt64Array>()
}

/// Get matched and deduplicated indices for specified range of indices
pub(crate) fn get_semi_indices(
range: Range<usize>,
input_indices: &UInt32Array,
) -> UInt32Array {
let mut bitmap = BooleanBufferBuilder::new(range.len());
bitmap.append_n(range.len(), false);
input_indices
.iter()
.flatten()
.map(|v| v as usize)
.filter(|v| range.contains(v))
.for_each(|v| {
bitmap.set_bit(v - range.start, true);
});

let offset = range.start;

// get the semi index
(range)
.filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u32))
.collect::<UInt32Array>()
}

/// Get matched and deduplicated indices
pub(crate) fn get_semi_u64_indices(
range: Range<usize>,
input_indices: &UInt64Array,
) -> UInt64Array {
let mut bitmap = BooleanBufferBuilder::new(range.len());
bitmap.append_n(range.len(), false);
input_indices
.iter()
.flatten()
.map(|v| v as usize)
.map(|v| v.as_usize())
.filter(|v| range.contains(v))
.for_each(|v| {
bitmap.set_bit(v - range.start, true);
Expand All @@ -1012,8 +974,10 @@ pub(crate) fn get_semi_u64_indices(

// get the semi index
(range)
.filter_map(|idx| (bitmap.get_bit(idx - offset)).then_some(idx as u64))
.collect::<UInt64Array>()
.filter_map(|idx| {
(bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx))
})
.collect::<PrimitiveArray<T>>()
}

/// Metrics for build & probe joins
Expand Down

0 comments on commit 02651f6

Please sign in to comment.