diff --git a/crates/benchmarks/src/bin/merge.rs b/crates/benchmarks/src/bin/merge.rs index d3acb80c0a..affae8b7dd 100644 --- a/crates/benchmarks/src/bin/merge.rs +++ b/crates/benchmarks/src/bin/merge.rs @@ -265,6 +265,14 @@ async fn benchmark_merge_tpcds( .object_store() .delete(&Path::parse("_delta_log/00000000000000000002.json")?) .await?; + table + .object_store() + .delete(&Path::parse("_delta_log/00000000000000000003.json")?) + .await?; + let _ = table + .object_store() + .delete(&Path::parse("_delta_log/00000000000000000004.json")?) + .await; Ok((duration, metrics)) } diff --git a/crates/deltalake-core/src/delta_datafusion/logical.rs b/crates/deltalake-core/src/delta_datafusion/logical.rs index 7b05dd57d9..75ed53d1b1 100644 --- a/crates/deltalake-core/src/delta_datafusion/logical.rs +++ b/crates/deltalake-core/src/delta_datafusion/logical.rs @@ -1,5 +1,7 @@ //! Logical Operations for DataFusion +use std::collections::HashSet; + use datafusion_expr::{LogicalPlan, UserDefinedLogicalNodeCore}; // Metric Observer is used to update DataFusion metrics from a record batch. @@ -10,6 +12,7 @@ pub(crate) struct MetricObserver { // id is preserved during conversion to physical node pub id: String, pub input: LogicalPlan, + pub enable_pushdown: bool, } impl UserDefinedLogicalNodeCore for MetricObserver { @@ -35,6 +38,18 @@ impl UserDefinedLogicalNodeCore for MetricObserver { write!(f, "MetricObserver id={}", &self.id) } + fn prevent_predicate_push_down_columns(&self) -> HashSet { + if self.enable_pushdown { + HashSet::new() + } else { + self.schema() + .fields() + .iter() + .map(|f| f.name().clone()) + .collect() + } + } + fn from_template( &self, _exprs: &[datafusion_expr::Expr], @@ -43,6 +58,7 @@ impl UserDefinedLogicalNodeCore for MetricObserver { MetricObserver { id: self.id.clone(), input: inputs[0].clone(), + enable_pushdown: self.enable_pushdown, } } } diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs index 5890401e67..17d04c692a 100644 --- a/crates/deltalake-core/src/delta_datafusion/mod.rs +++ b/crates/deltalake-core/src/delta_datafusion/mod.rs @@ -32,7 +32,7 @@ use arrow::datatypes::{DataType as ArrowDataType, Schema as ArrowSchema, SchemaR use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use arrow_array::types::UInt16Type; -use arrow_array::{Array, DictionaryArray, StringArray}; +use arrow_array::{Array, DictionaryArray, StringArray, TypedDictionaryArray}; use arrow_cast::display::array_value_to_string; use arrow_schema::Field; @@ -132,6 +132,21 @@ fn get_scalar_value(value: Option<&ColumnValueStat>, field: &Arc) -> Prec } } +pub(crate) fn get_path_column<'a>( + batch: &'a RecordBatch, + path_column: &str, +) -> DeltaResult> { + let err = || DeltaTableError::Generic("Unable to obtain Delta-rs path column".to_string()); + batch + .column_by_name(path_column) + .unwrap() + .as_any() + .downcast_ref::>() + .ok_or_else(err)? 
+ .downcast_dict::() + .ok_or_else(err) +} + impl DeltaTableState { /// Provide table level statistics to Datafusion pub fn datafusion_table_statistics(&self) -> DataFusionResult { @@ -1362,31 +1377,20 @@ fn join_batches_with_add_actions( let mut files = Vec::with_capacity(batches.iter().map(|batch| batch.num_rows()).sum()); for batch in batches { - let array = batch.column_by_name(path_column).ok_or_else(|| { - DeltaTableError::Generic(format!("Unable to find column {}", path_column)) - })?; - - let iter: Box>> = - if dict_array { - let array = array - .as_any() - .downcast_ref::>() - .ok_or(DeltaTableError::Generic(format!( - "Unable to downcast column {}", - path_column - )))? - .downcast_dict::() - .ok_or(DeltaTableError::Generic(format!( - "Unable to downcast column {}", - path_column - )))?; - Box::new(array.into_iter()) - } else { - let array = array.as_any().downcast_ref::().ok_or( - DeltaTableError::Generic(format!("Unable to downcast column {}", path_column)), - )?; - Box::new(array.into_iter()) - }; + let err = || DeltaTableError::Generic("Unable to obtain Delta-rs path column".to_string()); + + let iter: Box>> = if dict_array { + let array = get_path_column(&batch, path_column)?; + Box::new(array.into_iter()) + } else { + let array = batch + .column_by_name(path_column) + .ok_or_else(err)? + .as_any() + .downcast_ref::() + .ok_or_else(err)?; + Box::new(array.into_iter()) + }; for path in iter { let path = path.ok_or(DeltaTableError::Generic(format!( diff --git a/crates/deltalake-core/src/operations/merge/barrier.rs b/crates/deltalake-core/src/operations/merge/barrier.rs new file mode 100644 index 0000000000..6883f61253 --- /dev/null +++ b/crates/deltalake-core/src/operations/merge/barrier.rs @@ -0,0 +1,675 @@ +//! Merge Barrier determines which files have modifications during the merge operation +//! +//! For every unique path in the input stream, a barrier is established. If any +//! single record for a file contains any delete, update, or insert operations +//! then the barrier for the file is opened and can be sent downstream. +//! To determine if a file contains zero changes, the input stream is +//! exhausted. Afterwards, records are then dropped. +//! +//! Bookkeeping is maintained to determine which files have modifications so +//! they can be removed from the delta log. + +use std::{ + collections::{HashMap, HashSet}, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use arrow_array::{builder::UInt64Builder, ArrayRef, RecordBatch}; +use arrow_schema::SchemaRef; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, +}; +use datafusion_common::{DataFusionError, Result as DataFusionResult}; +use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNodeCore}; +use datafusion_physical_expr::{Distribution, PhysicalExpr}; +use futures::{Stream, StreamExt}; + +use crate::{ + delta_datafusion::get_path_column, + operations::merge::{TARGET_DELETE_COLUMN, TARGET_INSERT_COLUMN, TARGET_UPDATE_COLUMN}, + DeltaTableError, +}; + +pub(crate) type BarrierSurvivorSet = Arc>>; + +#[derive(Debug)] +/// Physical Node for the MergeBarrier +/// Batches to this node must be repartitioned on col('deleta_rs_path'). 
+/// Each record batch then undergoes further partitioning based on the file column to it's corresponding barrier +pub struct MergeBarrierExec { + input: Arc, + file_column: Arc, + survivors: BarrierSurvivorSet, + expr: Arc, +} + +impl MergeBarrierExec { + /// Create a new MergeBarrierExec Node + pub fn new( + input: Arc, + file_column: Arc, + expr: Arc, + ) -> Self { + MergeBarrierExec { + input, + file_column, + survivors: Arc::new(Mutex::new(HashSet::new())), + expr, + } + } + + /// Files that have modifications to them and need to removed from the delta log + pub fn survivors(&self) -> BarrierSurvivorSet { + self.survivors.clone() + } +} + +impl ExecutionPlan for MergeBarrierExec { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> arrow_schema::SchemaRef { + self.input.schema() + } + + fn output_partitioning(&self) -> datafusion_physical_expr::Partitioning { + self.input.output_partitioning() + } + + fn required_input_distribution(&self) -> Vec { + vec![Distribution::HashPartitioned(vec![self.expr.clone()]); 1] + } + + fn output_ordering(&self) -> Option<&[datafusion_physical_expr::PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec> { + vec![self.input.clone()] + } + + fn with_new_children( + self: std::sync::Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(MergeBarrierExec::new( + children[0].clone(), + self.file_column.clone(), + self.expr.clone(), + ))) + } + + fn execute( + &self, + partition: usize, + context: std::sync::Arc, + ) -> datafusion_common::Result { + let input = self.input.execute(partition, context)?; + Ok(Box::pin(MergeBarrierStream::new( + input, + self.schema(), + self.survivors.clone(), + self.file_column.clone(), + ))) + } +} + +impl DisplayAs for MergeBarrierExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "MergeBarrier",)?; + Ok(()) + } + } + } +} + +#[derive(Debug)] +enum State { + Feed, + Drain, + Finalize, + Abort, + Done, +} + +#[derive(Debug)] +enum PartitionBarrierState { + Closed, + Open, +} + +#[derive(Debug)] +struct MergeBarrierPartition { + state: PartitionBarrierState, + buffer: Vec, + file_name: Option, +} + +impl MergeBarrierPartition { + pub fn new(file_name: Option) -> Self { + MergeBarrierPartition { + state: PartitionBarrierState::Closed, + buffer: Vec::new(), + file_name, + } + } + + pub fn feed(&mut self, batch: RecordBatch) -> DataFusionResult<()> { + match self.state { + PartitionBarrierState::Closed => { + let delete_count = get_count(&batch, TARGET_DELETE_COLUMN)?; + let update_count = get_count(&batch, TARGET_UPDATE_COLUMN)?; + let insert_count = get_count(&batch, TARGET_INSERT_COLUMN)?; + self.buffer.push(batch); + + if insert_count > 0 || update_count > 0 || delete_count > 0 { + self.state = PartitionBarrierState::Open; + } + } + PartitionBarrierState::Open => { + self.buffer.push(batch); + } + } + Ok(()) + } + + pub fn drain(&mut self) -> Option { + match self.state { + PartitionBarrierState::Closed => None, + PartitionBarrierState::Open => self.buffer.pop(), + } + } +} + +struct MergeBarrierStream { + schema: SchemaRef, + state: State, + input: SendableRecordBatchStream, + file_column: Arc, + survivors: BarrierSurvivorSet, + map: HashMap, + file_partitions: Vec, +} + +impl MergeBarrierStream { + pub fn new( + input: SendableRecordBatchStream, + schema: SchemaRef, + survivors: BarrierSurvivorSet, + file_column: Arc, + ) -> Self { + // 
Always allocate for a null bucket at index 0; + let file_partitions = vec![MergeBarrierPartition::new(None)]; + + MergeBarrierStream { + schema, + state: State::Feed, + input, + file_column, + survivors, + file_partitions, + map: HashMap::new(), + } + } +} + +fn get_count(batch: &RecordBatch, column: &str) -> DataFusionResult { + batch + .column_by_name(column) + .map(|array| array.null_count()) + .ok_or_else(|| { + DataFusionError::External(Box::new(DeltaTableError::Generic( + "Required operation column is missing".to_string(), + ))) + }) +} + +impl Stream for MergeBarrierStream { + type Item = DataFusionResult; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.state { + State::Feed => { + match self.input.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + let file_dictionary = get_path_column(&batch, &self.file_column)?; + + // For each record batch, the key for a file path is not stable. + // We can iterate through the dictionary and lookup the correspond string for each record and then lookup the correct `file_partition` for that value. + // However this approach exposes the cost of hashing so we want to minimize that as much as possible. + // A map from an arrow dictionary key to the correct index of `file_partition` is created for each batch that's processed. + // This ensures we only need to hash each file path at most once per batch. + let mut key_map = Vec::new(); + + for file_name in file_dictionary.values().into_iter() { + let key = match file_name { + Some(name) => { + if !self.map.contains_key(name) { + let key = self.file_partitions.len(); + let part_stream = + MergeBarrierPartition::new(Some(name.to_string())); + self.file_partitions.push(part_stream); + self.map.insert(name.to_string(), key); + } + // Safe unwrap due to the above + *self.map.get(name).unwrap() + } + None => 0, + }; + key_map.push(key) + } + + let mut indices: Vec<_> = (0..(self.file_partitions.len())) + .map(|_| UInt64Builder::with_capacity(batch.num_rows())) + .collect(); + + for (idx, key) in file_dictionary.keys().iter().enumerate() { + match key { + Some(value) => { + indices[key_map[value as usize]].append_value(idx as u64) + } + None => indices[0].append_value(idx as u64), + } + } + + let batches: Vec> = + indices + .into_iter() + .enumerate() + .filter_map(|(partition, mut indices)| { + let indices = indices.finish(); + (!indices.is_empty()).then_some((partition, indices)) + }) + .map(move |(partition, indices)| { + // Produce batches based on indices + let columns = batch + .columns() + .iter() + .map(|c| { + arrow::compute::take(c.as_ref(), &indices, None) + .map_err(DataFusionError::ArrowError) + }) + .collect::>>()?; + + // This unwrap is safe since the processed batched has the same schema + let batch = + RecordBatch::try_new(batch.schema(), columns).unwrap(); + + Ok((partition, batch)) + }) + .collect(); + + for batch in batches { + match batch { + Ok((partition, batch)) => { + self.file_partitions[partition].feed(batch)?; + } + Err(err) => { + self.state = State::Abort; + return Poll::Ready(Some(Err(err))); + } + } + } + + self.state = State::Drain; + continue; + } + Poll::Ready(Some(Err(err))) => { + self.state = State::Abort; + return Poll::Ready(Some(Err(err))); + } + Poll::Ready(None) => { + self.state = State::Finalize; + continue; + } + Poll::Pending => return Poll::Pending, + } + } + State::Drain => { + for part in &mut self.file_partitions { + if let Some(batch) = part.drain() { + return Poll::Ready(Some(Ok(batch))); + } + } + + 
self.state = State::Feed; + continue; + } + State::Finalize => { + for part in &mut self.file_partitions { + if let Some(batch) = part.drain() { + return Poll::Ready(Some(Ok(batch))); + } + } + + { + let mut lock = self.survivors.lock().map_err(|_| { + DataFusionError::External(Box::new(DeltaTableError::Generic( + "MergeBarrier mutex is poisoned".to_string(), + ))) + })?; + for part in &self.file_partitions { + match part.state { + PartitionBarrierState::Closed => {} + PartitionBarrierState::Open => { + if let Some(file_name) = &part.file_name { + lock.insert(file_name.to_owned()); + } + } + } + } + } + + self.state = State::Done; + continue; + } + State::Abort => return Poll::Ready(None), + State::Done => return Poll::Ready(None), + } + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, self.input.size_hint().1) + } +} + +impl RecordBatchStream for MergeBarrierStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[derive(Debug, Hash, Eq, PartialEq)] +pub(crate) struct MergeBarrier { + pub input: LogicalPlan, + pub expr: Expr, + pub file_column: Arc, +} + +impl UserDefinedLogicalNodeCore for MergeBarrier { + fn name(&self) -> &str { + "MergeBarrier" + } + + fn inputs(&self) -> Vec<&datafusion_expr::LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &datafusion_common::DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![self.expr.clone()] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "MergeBarrier") + } + + fn from_template( + &self, + exprs: &[datafusion_expr::Expr], + inputs: &[datafusion_expr::LogicalPlan], + ) -> Self { + MergeBarrier { + input: inputs[0].clone(), + file_column: self.file_column.clone(), + expr: exprs[0].clone(), + } + } +} + +pub(crate) fn find_barrier_node(parent: &Arc) -> Option> { + //! 
Used to locate the physical Barrier Node after the planner converts the logical node + if parent.as_any().downcast_ref::().is_some() { + return Some(parent.to_owned()); + } + + for child in &parent.children() { + let res = find_barrier_node(child); + if res.is_some() { + return res; + } + } + + None +} + +#[cfg(test)] +mod tests { + use crate::operations::merge::MergeBarrierExec; + use crate::operations::merge::{ + TARGET_DELETE_COLUMN, TARGET_INSERT_COLUMN, TARGET_UPDATE_COLUMN, + }; + use arrow::datatypes::Schema as ArrowSchema; + use arrow_array::RecordBatch; + use arrow_array::StringArray; + use arrow_array::{DictionaryArray, UInt16Array}; + use arrow_schema::DataType as ArrowDataType; + use arrow_schema::Field; + use datafusion::assert_batches_sorted_eq; + use datafusion::execution::TaskContext; + use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; + use datafusion::physical_plan::memory::MemoryExec; + use datafusion::physical_plan::ExecutionPlan; + use datafusion_physical_expr::expressions::Column; + use futures::StreamExt; + use std::sync::Arc; + + use super::BarrierSurvivorSet; + + #[tokio::test] + async fn test_barrier() { + // Validate that files without modifications are dropped and that files with changes passthrough + // File 0: No Changes + // File 1: Contains an update + // File 2: Contains a delete + // null (id: 3): is a insert + + let schema = get_schema(); + let keys = UInt16Array::from(vec![Some(0), Some(1), Some(2), None]); + let values = StringArray::from(vec![Some("file0"), Some("file1"), Some("file2")]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["0", "1", "2", "3"])), + Arc::new(dict), + //insert column + Arc::new(arrow::array::BooleanArray::from(vec![ + Some(false), + Some(false), + Some(false), + None, + ])), + //update column + Arc::new(arrow::array::BooleanArray::from(vec![ + Some(false), + None, + Some(false), + Some(false), + ])), + //delete column + Arc::new(arrow::array::BooleanArray::from(vec![ + Some(false), + Some(false), + None, + Some(false), + ])), + ], + ) + .unwrap(); + + let (actual, survivors) = execute(vec![batch]).await; + let expected = vec![ + "+----+-----------------+--------------------------+--------------------------+--------------------------+", + "| id | __delta_rs_path | __delta_rs_target_insert | __delta_rs_target_update | __delta_rs_target_delete |", + "+----+-----------------+--------------------------+--------------------------+--------------------------+", + "| 1 | file1 | false | | false |", + "| 2 | file2 | false | false | |", + "| 3 | | | false | false |", + "+----+-----------------+--------------------------+--------------------------+--------------------------+", + ]; + assert_batches_sorted_eq!(&expected, &actual); + + let s = survivors.lock().unwrap(); + assert!(!s.contains(&"file0".to_string())); + assert!(s.contains(&"file1".to_string())); + assert!(s.contains(&"file2".to_string())); + assert_eq!(s.len(), 2); + } + + #[tokio::test] + async fn test_barrier_changing_indicies() { + // Validate implementation can handle different dictionary indicies between batches + + let schema = get_schema(); + let mut batches = vec![]; + + // Batch 1 + let keys = UInt16Array::from(vec![Some(0), Some(1)]); + let values = StringArray::from(vec![Some("file0"), Some("file1")]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + 
Arc::new(arrow::array::StringArray::from(vec!["0", "1"])), + Arc::new(dict), + //insert column + Arc::new(arrow::array::BooleanArray::from(vec![ + Some(false), + Some(false), + ])), + //update column + Arc::new(arrow::array::BooleanArray::from(vec![ + Some(false), + Some(false), + ])), + //delete column + Arc::new(arrow::array::BooleanArray::from(vec![ + Some(false), + Some(false), + ])), + ], + ) + .unwrap(); + batches.push(batch); + // Batch 2 + + let keys = UInt16Array::from(vec![Some(0), Some(1)]); + let values = StringArray::from(vec![Some("file1"), Some("file0")]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["2", "3"])), + Arc::new(dict), + //insert column + Arc::new(arrow::array::BooleanArray::from(vec![ + Some(false), + Some(false), + ])), + //update column + Arc::new(arrow::array::BooleanArray::from(vec![None, Some(false)])), + //delete column + Arc::new(arrow::array::BooleanArray::from(vec![Some(false), None])), + ], + ) + .unwrap(); + batches.push(batch); + + let (actual, _survivors) = execute(batches).await; + let expected = vec! + [ + "+----+-----------------+--------------------------+--------------------------+--------------------------+", + "| id | __delta_rs_path | __delta_rs_target_insert | __delta_rs_target_update | __delta_rs_target_delete |", + "+----+-----------------+--------------------------+--------------------------+--------------------------+", + "| 0 | file0 | false | false | false |", + "| 1 | file1 | false | false | false |", + "| 2 | file1 | false | | false |", + "| 3 | file0 | false | false | |", + "+----+-----------------+--------------------------+--------------------------+--------------------------+", + ]; + assert_batches_sorted_eq!(&expected, &actual); + } + + #[tokio::test] + async fn test_barrier_null_paths() { + // Arrow dictionaries are interesting since a null value can be either in the keys of the dict or in the values. 
+ // Validate they can be processed without issue + + let schema = get_schema(); + let keys = UInt16Array::from(vec![Some(0), None, Some(1)]); + let values = StringArray::from(vec![Some("file1"), None]); + let dict = DictionaryArray::new(keys, Arc::new(values)); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["1", "2", "3"])), + Arc::new(dict), + Arc::new(arrow::array::BooleanArray::from(vec![ + Some(false), + None, + None, + ])), + Arc::new(arrow::array::BooleanArray::from(vec![false, false, false])), + Arc::new(arrow::array::BooleanArray::from(vec![false, false, false])), + ], + ) + .unwrap(); + + let (actual, _) = execute(vec![batch]).await; + let expected = vec![ + "+----+-----------------+--------------------------+--------------------------+--------------------------+", + "| id | __delta_rs_path | __delta_rs_target_insert | __delta_rs_target_update | __delta_rs_target_delete |", + "+----+-----------------+--------------------------+--------------------------+--------------------------+", + "| 2 | | | false | false |", + "| 3 | | | false | false |", + "+----+-----------------+--------------------------+--------------------------+--------------------------+", + ]; + assert_batches_sorted_eq!(&expected, &actual); + } + + async fn execute(input: Vec) -> (Vec, BarrierSurvivorSet) { + let schema = get_schema(); + let repartition = Arc::new(Column::new("__delta_rs_path", 2)); + let exec = Arc::new(MemoryExec::try_new(&[input], schema.clone(), None).unwrap()); + + let task_ctx = Arc::new(TaskContext::default()); + let merge = + MergeBarrierExec::new(exec, Arc::new("__delta_rs_path".to_string()), repartition); + + let survivors = merge.survivors(); + let coalsece = CoalesceBatchesExec::new(Arc::new(merge), 100); + let mut stream = coalsece.execute(0, task_ctx).unwrap(); + (vec![stream.next().await.unwrap().unwrap()], survivors) + } + + fn get_schema() -> Arc { + Arc::new(ArrowSchema::new(vec![ + Field::new("id", ArrowDataType::Utf8, true), + Field::new( + "__delta_rs_path", + ArrowDataType::Dictionary( + Box::new(ArrowDataType::UInt16), + Box::new(ArrowDataType::Utf8), + ), + true, + ), + Field::new(TARGET_INSERT_COLUMN, ArrowDataType::Boolean, true), + Field::new(TARGET_UPDATE_COLUMN, ArrowDataType::Boolean, true), + Field::new(TARGET_DELETE_COLUMN, ArrowDataType::Boolean, true), + ])) + } +} diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge/mod.rs similarity index 96% rename from crates/deltalake-core/src/operations/merge.rs rename to crates/deltalake-core/src/operations/merge/mod.rs index 0f0da1c21f..7cb752dc21 100644 --- a/crates/deltalake-core/src/operations/merge.rs +++ b/crates/deltalake-core/src/operations/merge/mod.rs @@ -36,6 +36,7 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::{Instant, SystemTime, UNIX_EPOCH}; +use arrow_schema::Schema as ArrowSchema; use async_trait::async_trait; use datafusion::datasource::provider_as_source; use datafusion::error::Result as DataFusionResult; @@ -64,31 +65,36 @@ use parquet::file::properties::WriterProperties; use serde::Serialize; use serde_json::Value; +use self::barrier::{MergeBarrier, MergeBarrierExec}; + use super::datafusion_utils::{into_expr, maybe_into_expr, Expression}; use super::transaction::{commit, PROTOCOL}; use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression}; use crate::delta_datafusion::logical::MetricObserver; use 
crate::delta_datafusion::physical::{find_metric_node, MetricObserverExec}; use crate::delta_datafusion::{ - execute_plan_to_batch, register_store, DeltaColumn, DeltaScanConfig, DeltaSessionConfig, + execute_plan_to_batch, register_store, DeltaColumn, DeltaScanConfigBuilder, DeltaSessionConfig, DeltaTableProvider, }; use crate::kernel::{Action, Remove}; use crate::logstore::LogStoreRef; +use crate::operations::merge::barrier::find_barrier_node; use crate::operations::write::write_execution_plan; use crate::protocol::{DeltaOperation, MergePredicate}; use crate::table::state::DeltaTableState; use crate::{DeltaResult, DeltaTable, DeltaTableError}; +mod barrier; + const SOURCE_COLUMN: &str = "__delta_rs_source"; const TARGET_COLUMN: &str = "__delta_rs_target"; const OPERATION_COLUMN: &str = "__delta_rs_operation"; const DELETE_COLUMN: &str = "__delta_rs_delete"; -const TARGET_INSERT_COLUMN: &str = "__delta_rs_target_insert"; -const TARGET_UPDATE_COLUMN: &str = "__delta_rs_target_update"; -const TARGET_DELETE_COLUMN: &str = "__delta_rs_target_delete"; -const TARGET_COPY_COLUMN: &str = "__delta_rs_target_copy"; +pub(crate) const TARGET_INSERT_COLUMN: &str = "__delta_rs_target_insert"; +pub(crate) const TARGET_UPDATE_COLUMN: &str = "__delta_rs_target_update"; +pub(crate) const TARGET_DELETE_COLUMN: &str = "__delta_rs_target_delete"; +pub(crate) const TARGET_COPY_COLUMN: &str = "__delta_rs_target_copy"; const SOURCE_COUNT_METRIC: &str = "num_source_rows"; const TARGET_COUNT_METRIC: &str = "num_target_rows"; @@ -580,11 +586,11 @@ struct MergeMetricExtensionPlanner {} impl ExtensionPlanner for MergeMetricExtensionPlanner { async fn plan_extension( &self, - _planner: &dyn PhysicalPlanner, + planner: &dyn PhysicalPlanner, node: &dyn UserDefinedLogicalNode, _logical_inputs: &[&LogicalPlan], physical_inputs: &[Arc], - _session_state: &SessionState, + session_state: &SessionState, ) -> DataFusionResult>> { if let Some(metric_observer) = node.as_any().downcast_ref::() { if metric_observer.id.eq(SOURCE_COUNT_ID) { @@ -653,6 +659,16 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner { } } + if let Some(barrier) = node.as_any().downcast_ref::() { + let schema = barrier.input.schema(); + let exec_schema: ArrowSchema = schema.as_ref().to_owned().into(); + return Ok(Some(Arc::new(MergeBarrierExec::new( + physical_inputs.get(0).unwrap().clone(), + barrier.file_column.clone(), + planner.create_physical_expr(&barrier.expr, schema, &exec_schema, session_state)?, + )))); + } + Ok(None) } } @@ -945,13 +961,20 @@ async fn execute( node: Arc::new(MetricObserver { id: SOURCE_COUNT_ID.into(), input: source, + enable_pushdown: false, }), }); + let scan_config = DeltaScanConfigBuilder::default() + .with_file_column(true) + .build(snapshot)?; + + let file_column = Arc::new(scan_config.file_column_name.clone().unwrap()); + let target_provider = Arc::new(DeltaTableProvider::try_new( snapshot.clone(), log_store.clone(), - DeltaScanConfig::default(), + scan_config, )?); let target_provider = provider_as_source(target_provider); @@ -968,7 +991,7 @@ async fn execute( let state = state.with_query_planner(Arc::new(MergePlanner {})); - let (target, files) = { + let target = { // Attempt to construct an early filter that we can apply to the Add action list and the delta scan. // In the case where there are partition columns in the join predicate, we can scan the source table // to get the distinct list of partitions affected and constrain the search to those. 
@@ -976,7 +999,7 @@ async fn execute( if !not_match_source_operations.is_empty() { // It's only worth trying to create an early filter where there are no `when_not_matched_source` operators, since // that implies a full scan - (target, snapshot.files().iter().collect_vec()) + target } else if let Some(filter) = try_construct_early_filter( predicate.clone(), snapshot, @@ -987,35 +1010,23 @@ async fn execute( ) .await? { - let file_filter = filter - .clone() - .transform(&|expr| match expr { - Expr::Column(c) => Ok(Transformed::Yes(Expr::Column(Column { - relation: None, // the file filter won't be looking at columns like `target.partition`, it'll just be `partition` - name: c.name, - }))), - expr => Ok(Transformed::No(expr)), - }) - .unwrap(); - let files = snapshot - .files_matching_predicate(&[file_filter])? - .collect_vec(); - - let new_target = LogicalPlan::Filter(Filter::try_new(filter, target.into())?); - (new_target, files) + LogicalPlan::Filter(Filter::try_new(filter, target.into())?) } else { - (target, snapshot.files().iter().collect_vec()) + target } }; let source = DataFrame::new(state.clone(), source); let source = source.with_column(SOURCE_COLUMN, lit(true))?; - // TODO: This is here to prevent predicate pushdowns. In the future we can replace this node to allow pushdowns depending on which operations are being used. + // Not match operations imply a full scan of the target table is required + let enable_pushdown = + not_match_source_operations.is_empty() && not_match_target_operations.is_empty(); let target = LogicalPlan::Extension(Extension { node: Arc::new(MetricObserver { id: TARGET_COUNT_ID.into(), input: target, + enable_pushdown, }), }); let target = DataFrame::new(state.clone(), target); @@ -1272,11 +1283,23 @@ async fn execute( )?; new_columns = new_columns.with_column(TARGET_COPY_COLUMN, build_case(copy_when, copy_then)?)?; - let new_columns = new_columns.into_optimized_plan()?; + let new_columns = new_columns.into_unoptimized_plan(); + + let distrbute_expr = col(file_column.as_str()); + + let merge_barrier = LogicalPlan::Extension(Extension { + node: Arc::new(MergeBarrier { + input: new_columns, + expr: distrbute_expr, + file_column, + }), + }); + let operation_count = LogicalPlan::Extension(Extension { node: Arc::new(MetricObserver { id: OUTPUT_COUNT_ID.into(), - input: new_columns, + input: merge_barrier, + enable_pushdown: false, }), }); @@ -1284,13 +1307,14 @@ async fn execute( let filtered = operation_count.filter(col(DELETE_COLUMN).is_false())?; let project = filtered.select(write_projection)?; - let optimized = &project.into_optimized_plan()?; + let merge_final = &project.into_unoptimized_plan(); - let write = state.create_physical_plan(optimized).await?; + let write = state.create_physical_plan(merge_final).await?; let err = || DeltaTableError::Generic("Unable to locate expected metric node".into()); let source_count = find_metric_node(SOURCE_COUNT_ID, &write).ok_or_else(err)?; let op_count = find_metric_node(OUTPUT_COUNT_ID, &write).ok_or_else(err)?; + let barrier = find_barrier_node(&write).ok_or_else(err)?; // write projected records let table_partition_cols = current_metadata.partition_columns.clone(); @@ -1320,20 +1344,31 @@ async fn execute( let mut actions: Vec = add_actions.into_iter().map(Action::Add).collect(); metrics.num_target_files_added = actions.len(); - for action in files { - metrics.num_target_files_removed += 1; - actions.push(Action::Remove(Remove { - path: action.path.clone(), - deletion_timestamp: Some(deletion_timestamp), - 
data_change: true, - extended_file_metadata: Some(true), - partition_values: Some(action.partition_values.clone()), - deletion_vector: action.deletion_vector.clone(), - size: Some(action.size), - tags: None, - base_row_id: action.base_row_id, - default_row_commit_version: action.default_row_commit_version, - })) + let survivors = barrier + .as_any() + .downcast_ref::() + .unwrap() + .survivors(); + + { + let lock = survivors.lock().unwrap(); + for action in snapshot.files() { + if lock.contains(&action.path) { + metrics.num_target_files_removed += 1; + actions.push(Action::Remove(Remove { + path: action.path.clone(), + deletion_timestamp: Some(deletion_timestamp), + data_change: true, + extended_file_metadata: Some(true), + partition_values: Some(action.partition_values.clone()), + deletion_vector: action.deletion_vector.clone(), + size: Some(action.size), + tags: None, + base_row_id: action.base_row_id, + default_row_commit_version: action.default_row_commit_version, + })) + } + } } let mut version = snapshot.version(); @@ -1506,6 +1541,8 @@ mod tests { .merge(merge_source(schema), col("target.id").eq(col("source.id"))) .with_source_alias("source") .with_target_alias("target") + .when_not_matched_by_source_delete(|delete| delete) + .unwrap() .await .expect_err("Remove action is included when Delta table is append-only. Should error"); } @@ -2004,7 +2041,7 @@ mod tests { assert_eq!(table.version(), 2); assert!(table.get_file_uris().count() >= 2); - assert!(metrics.num_target_files_added >= 2); + assert_eq!(metrics.num_target_files_added, 2); assert_eq!(metrics.num_target_files_removed, 2); assert_eq!(metrics.num_target_rows_copied, 2); assert_eq!(metrics.num_target_rows_updated, 0); @@ -2068,13 +2105,13 @@ mod tests { assert_eq!(table.version(), 2); assert!(table.get_file_uris().count() >= 2); - assert!(metrics.num_target_files_added >= 2); - assert_eq!(metrics.num_target_files_removed, 2); - assert_eq!(metrics.num_target_rows_copied, 3); + assert_eq!(metrics.num_target_files_added, 1); + assert_eq!(metrics.num_target_files_removed, 1); + assert_eq!(metrics.num_target_rows_copied, 1); assert_eq!(metrics.num_target_rows_updated, 0); assert_eq!(metrics.num_target_rows_inserted, 0); assert_eq!(metrics.num_target_rows_deleted, 1); - assert_eq!(metrics.num_output_rows, 3); + assert_eq!(metrics.num_output_rows, 1); assert_eq!(metrics.num_source_rows, 3); let commit_info = table.history(None).await.unwrap(); @@ -2201,13 +2238,13 @@ mod tests { .unwrap(); assert_eq!(table.version(), 2); - assert!(metrics.num_target_files_added >= 2); - assert_eq!(metrics.num_target_files_removed, 2); - assert_eq!(metrics.num_target_rows_copied, 3); + assert!(metrics.num_target_files_added == 1); + assert_eq!(metrics.num_target_files_removed, 1); + assert_eq!(metrics.num_target_rows_copied, 1); assert_eq!(metrics.num_target_rows_updated, 0); assert_eq!(metrics.num_target_rows_inserted, 0); assert_eq!(metrics.num_target_rows_deleted, 1); - assert_eq!(metrics.num_output_rows, 3); + assert_eq!(metrics.num_output_rows, 1); assert_eq!(metrics.num_source_rows, 3); let commit_info = table.history(None).await.unwrap();
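
The core of the new barrier.rs module is a small per-file state machine: every target file gets a partition that starts Closed and buffers incoming batches; the first batch that carries any insert, update, or delete flag for that file (detected in the real stream by counting nulls in the __delta_rs_target_* columns via get_count) flips it to Open, after which its batches are released downstream and the file is recorded as a survivor. Closed partitions never emit anything, so unmodified files produce no rewritten data and, as the commit code above shows, no Remove action. Below is a minimal sketch of that bookkeeping in plain Rust, without Arrow or DataFusion; MiniBatch, FilePartition, and the has_change flag are illustrative stand-ins, not the crate's API.

use std::collections::HashSet;

// Stand-in for a record batch slice routed to one file's partition.
struct MiniBatch {
    rows: Vec<String>,
    // true if any insert/update/delete flag is set for this file's rows
    has_change: bool,
}

#[derive(PartialEq)]
enum BarrierState {
    Closed,
    Open,
}

struct FilePartition {
    state: BarrierState,
    buffer: Vec<MiniBatch>,
    file_name: Option<String>,
}

impl FilePartition {
    fn new(file_name: Option<String>) -> Self {
        Self { state: BarrierState::Closed, buffer: Vec::new(), file_name }
    }

    // Mirror of `feed`: buffer everything; open the barrier once a change is seen.
    fn feed(&mut self, batch: MiniBatch) {
        if batch.has_change {
            self.state = BarrierState::Open;
        }
        self.buffer.push(batch);
    }

    // Mirror of `drain`: only open partitions release their buffered batches.
    fn drain(&mut self) -> Option<MiniBatch> {
        match self.state {
            BarrierState::Closed => None,
            BarrierState::Open => self.buffer.pop(),
        }
    }
}

fn main() {
    let mut unchanged = FilePartition::new(Some("file0".into()));
    let mut modified = FilePartition::new(Some("file1".into()));

    unchanged.feed(MiniBatch { rows: vec!["copy-only row".into()], has_change: false });
    modified.feed(MiniBatch { rows: vec!["updated row".into()], has_change: true });

    // Finalize: closed partitions emit nothing and are not survivors;
    // open partitions emit their batches and their file is slated for removal.
    let mut survivors = HashSet::new();
    for part in [&mut unchanged, &mut modified] {
        while let Some(batch) = part.drain() {
            println!("emit {} row(s) for {:?}", batch.rows.len(), part.file_name);
        }
        if part.state == BarrierState::Open {
            if let Some(name) = &part.file_name {
                survivors.insert(name.clone());
            }
        }
    }
    assert_eq!(survivors, HashSet::from(["file1".to_string()]));
}

At Finalize time the real operator drains every open partition and copies the open file names into the shared survivors set, which the commit path then intersects with snapshot.files() to decide which Remove actions to write.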
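
The other idea worth calling out is the per-batch key remapping described in the comment inside MergeBarrierStream::poll_next: dictionary keys for the __delta_rs_path column are not stable across record batches, so each batch builds a key_map from dictionary key to a stable partition index, hashing every distinct path at most once per batch and then routing rows with integer lookups only. Below is a simplified model of that remapping, with plain slices standing in for the dictionary's keys and values arrays and a hypothetical route_rows helper, assuming partition 0 is the null bucket.

use std::collections::HashMap;

// Simplified model of an Arrow dictionary column: `keys[i]` indexes into
// `values`, and a None key represents a null path (rows with no target file).
fn route_rows(
    keys: &[Option<u16>],
    values: &[Option<&str>],
    path_to_partition: &mut HashMap<String, usize>,
    partition_count: &mut usize, // partition 0 is reserved for null paths
) -> Vec<Vec<usize>> {
    // Per-batch key_map: dictionary key -> stable partition index.
    // Each distinct path is hashed at most once per batch.
    let mut key_map = Vec::with_capacity(values.len());
    for path in values {
        let idx = match path {
            Some(name) => {
                if !path_to_partition.contains_key(*name) {
                    path_to_partition.insert(name.to_string(), *partition_count);
                    *partition_count += 1;
                }
                path_to_partition[*name]
            }
            None => 0,
        };
        key_map.push(idx);
    }

    // Scatter row indices to their partitions using only integer lookups.
    let mut indices = vec![Vec::new(); *partition_count];
    for (row, key) in keys.iter().enumerate() {
        match key {
            Some(k) => indices[key_map[*k as usize]].push(row),
            None => indices[0].push(row),
        }
    }
    indices
}

fn main() {
    let mut path_to_partition = HashMap::new();
    let mut partition_count = 1; // index 0 is the null bucket

    // Batch 1: dictionary key 0 -> "file0", key 1 -> "file1".
    let b1 = route_rows(
        &[Some(0), Some(1)],
        &[Some("file0"), Some("file1")],
        &mut path_to_partition,
        &mut partition_count,
    );
    // Batch 2: same files, but the dictionary keys are swapped.
    let b2 = route_rows(
        &[Some(0), Some(1), None],
        &[Some("file1"), Some("file0")],
        &mut path_to_partition,
        &mut partition_count,
    );

    let expected1: Vec<Vec<usize>> = vec![vec![], vec![0], vec![1]];
    let expected2: Vec<Vec<usize>> = vec![vec![2], vec![1], vec![0]];
    assert_eq!(b1, expected1); // null bucket, file0, file1
    assert_eq!(b2, expected2); // rows still land in their stable partitions
}

The design choice mirrors the operator: each path string is hashed once per distinct dictionary value per batch, while the per-row work stays a pair of array index lookups plus an append into the matching take-index builder.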