Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Commit

Permalink
some refactoring + hashing
Browse files Browse the repository at this point in the history
  • Loading branch information
connortsui20 committed Feb 25, 2024
1 parent 7af6454 commit 6066b2d
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 25 deletions.
99 changes: 77 additions & 22 deletions eggstrain/src/execution/operators/hash_join.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,89 @@
use super::{BinaryOperator, Operator};
use crate::execution::record_table::RecordTable;
use arrow::array::ArrayRef;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExprRef;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::Result;
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::{DataFusionError, Result};
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;

/// TODO docs
pub struct HashJoin {
_todo: bool,
children: Vec<Arc<dyn ExecutionPlan>>,
schema: SchemaRef,
left_schema: SchemaRef,
right_schema: SchemaRef,
equate_on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
children: Vec<Arc<dyn ExecutionPlan>>,
}

/// TODO docs
impl HashJoin {
/// Creates a hash join over the given children.
///
/// * `left_schema` / `right_schema` - schemas of the left and right inputs.
/// * `equate_on` - `(left, right)` expression pairs whose values must match for rows to join.
/// * `children` - the child execution plans, stored as given.
pub(crate) fn new(
    left_schema: SchemaRef,
    right_schema: SchemaRef,
    equate_on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
    children: Vec<Arc<dyn ExecutionPlan>>,
) -> Self {
    Self {
        left_schema,
        right_schema,
        equate_on,
        children,
    }
}

/// Given a [`RecordBatch`]`, hashes based on the input physical expressions
/// Given a [`RecordBatch`]`, hashes based on the input physical expressions.
///
/// TODO docs
fn hash_batch(&self, batch: &RecordBatch) -> Vec<usize> {
todo!("use self.equate_on to hash each of the tuple keys in the batch")
fn hash_batch<const LEFT: bool>(&self, batch: &RecordBatch) -> Result<Vec<u64>> {
let rows = batch.num_rows();

// A vector of columns, horizontally these are the join keys
let mut column_vals = Vec::with_capacity(self.equate_on.len());

for (left_eq, right_eq) in self.equate_on.iter() {
let eq = if LEFT { left_eq } else { right_eq };

// Extract a single column
let col_val = eq.evaluate(batch)?;
match col_val {
ColumnarValue::Array(arr) => column_vals.push(arr),
ColumnarValue::Scalar(s) => {
return Err(DataFusionError::NotImplemented(format!(
"Join physical expression scalar condition on {:#?} not implemented",
s
)));
}
}
}

let mut hashes = Vec::with_capacity(column_vals.len());
create_hashes(&column_vals, &Default::default(), &mut hashes)?;
assert_eq!(hashes.len(), rows);

Ok(hashes)
}

/// Builds the Hash Table from the [`RecordBatch`]es coming from the left child.
///
/// TODO docs
async fn build_table(&self, mut rx_left: broadcast::Receiver<RecordBatch>) -> RecordTable {
async fn build_table(
&self,
mut rx_left: broadcast::Receiver<RecordBatch>,
) -> Result<RecordTable> {
// Take in all of the record batches from the left and create a hash table
let mut record_table = RecordTable::new(self.schema.clone());
let mut record_table = RecordTable::new(self.left_schema.clone());

loop {
match rx_left.recv().await {
Ok(batch) => {
// TODO gather N batches and use rayon to insert all at once
let hashes = self.hash_batch(&batch);
let hashes = self.hash_batch::<true>(&batch)?;
record_table.insert_batch(batch, hashes);
}
Err(e) => match e {
Expand All @@ -60,38 +93,55 @@ impl HashJoin {
}
}

record_table
Ok(record_table)
}

/// Given a single batch (coming from the right child), probes the hash table and outputs a
/// [`RecordBatch`] for every tuple on the right that gets matched with a tuple in the hash table.
///
/// Note: This is super inefficient since its possible that we could emit a bunch of
/// [`RecordBatch`]es that have just 1 tuple in them. This is a place for easy optimization.
/// [`RecordBatch`]es that have just 1 tuple in them.
///
/// TODO This is a place for easy optimization.
///
/// TODO only implements an inner join
async fn probe_table(
&self,
table: &RecordTable,
right_batch: RecordBatch,
tx: &broadcast::Sender<RecordBatch>,
) {
let hashes = self.hash_batch(&right_batch);
) -> Result<()> {
let hashes = self.hash_batch::<false>(&right_batch)?;

let left_column_count = self.left_schema.fields().size();
let right_column_count = self.right_schema.fields().size();
let output_columns = left_column_count + right_column_count - self.equate_on.len();

for (right_row, &hash) in hashes.iter().enumerate() {
// Construct a RecordBatch for each tuple that might get joined with tuples in the hash table
let mut out_columns: Vec<(String, ArrayRef)> = Vec::with_capacity(output_columns);

// For each of these hashes, check if it is in the table
let Some(records) = table.get_records(hash) else {
return;
return Ok(());
};
assert!(!records.is_empty());

// There are records associated with this hash value, so we need to emit things
for &record in records {
let (left_batch, left_row) = table.get(record).unwrap();

todo!("Join tuples together and then send through `tx`");
// Left tuple is in `left_batch` at `left_row` offset
// Right tuple is in `right_batch` at `right_row` offset
}

let joined_batch = RecordBatch::try_from_iter(out_columns)?;

tx.send(joined_batch)
.expect("Unable to send the projected batch");
}

Ok(())
}
}

Expand Down Expand Up @@ -124,13 +174,18 @@ impl BinaryOperator for HashJoin {
) {
// Phase 1: Build Phase
// TODO assign to its own tokio task
let record_table = self.build_table(rx_left).await;
let record_table = self
.build_table(rx_left)
.await
.expect("Unable to build hash table");

// Phase 2: Probe Phase
loop {
match rx_right.recv().await {
Ok(batch) => {
self.probe_table(&record_table, batch, &tx).await;
self.probe_table(&record_table, batch, &tx)
.await
.expect("Unable to probe hash table");
}
Err(e) => match e {
RecvError::Closed => break,
Expand Down
6 changes: 3 additions & 3 deletions eggstrain/src/execution/record_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::collections::HashMap; // TODO replace with a raw table

pub struct RecordTable {
/// Maps a Hash value to a `RecordIndex` into the `RecordBuffer`
inner: HashMap<usize, Vec<RecordIndex>>,
inner: HashMap<u64, Vec<RecordIndex>>,
buffer: RecordBuffer,
}

Expand All @@ -23,7 +23,7 @@ impl RecordTable {
}
}

pub fn insert_batch(&mut self, batch: RecordBatch, hashes: Vec<usize>) {
pub fn insert_batch(&mut self, batch: RecordBatch, hashes: Vec<u64>) {
assert_eq!(
batch.num_rows(),
hashes.len(),
Expand All @@ -44,7 +44,7 @@ impl RecordTable {
}
}

pub fn get_records(&self, hash: usize) -> Option<&Vec<RecordIndex>> {
pub fn get_records(&self, hash: u64) -> Option<&Vec<RecordIndex>> {
self.inner.get(&hash)
}

Expand Down

0 comments on commit 6066b2d

Please sign in to comment.