From 8351668689f548621fb65fca483b12d053469c42 Mon Sep 17 00:00:00 2001 From: Marko Grujic Date: Mon, 27 Feb 2023 16:11:53 +0100 Subject: [PATCH] Enable passing Datafusion session state to WriteBuilder to accompany the input plan --- rust/src/operations/write.rs | 19 ++++++++++-- rust/tests/datafusion_test.rs | 58 ++++++++++++++++++++++++++++++++++- 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/rust/src/operations/write.rs b/rust/src/operations/write.rs index 15ca82d6fd..ad8094b385 100644 --- a/rust/src/operations/write.rs +++ b/rust/src/operations/write.rs @@ -35,7 +35,7 @@ use crate::writer::utils::PartitionPath; use arrow::datatypes::{DataType, SchemaRef as ArrowSchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion::execution::context::{SessionContext, TaskContext}; +use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::physical_plan::{memory::MemoryExec, ExecutionPlan}; use futures::future::BoxFuture; use futures::StreamExt; @@ -78,6 +78,8 @@ impl From for DeltaTableError { pub struct WriteBuilder { /// The input plan input: Option>, + /// Datafusion session state relevant for executing the input plan + state: Option, /// Location where the table is stored location: Option, /// SaveMode defines how to treat data already written to table location @@ -109,6 +111,7 @@ impl WriteBuilder { pub fn new() -> Self { Self { input: None, + state: None, location: None, mode: SaveMode::Append, partition_columns: None, @@ -156,6 +159,12 @@ impl WriteBuilder { self } + /// A session state accompanying a given input plan, containing e.g. registered object stores + pub fn with_input_session_state(mut self, state: SessionState) -> Self { + self.state = Some(state); + self + } + /// Execution plan that produces the data to be written to the delta table pub fn with_input_batches(mut self, batches: impl IntoIterator) -> Self { self.batches = Some(batches.into_iter().collect()); @@ -346,8 +355,12 @@ impl std::future::IntoFuture for WriteBuilder { let mut tasks = vec![]; for i in 0..plan.output_partitioning().partition_count() { let inner_plan = plan.clone(); - let state = SessionContext::new(); - let task_ctx = Arc::new(TaskContext::from(&state)); + let task_ctx = Arc::from(if let Some(state) = this.state.clone() { + TaskContext::from(&state) + } else { + let ctx = SessionContext::new(); + TaskContext::from(&ctx) + }); let config = WriterConfig::new( inner_plan.schema(), partition_columns.clone(), diff --git a/rust/tests/datafusion_test.rs b/rust/tests/datafusion_test.rs index 894a2dd86c..e3092aee67 100644 --- a/rust/tests/datafusion_test.rs +++ b/rust/tests/datafusion_test.rs @@ -19,7 +19,11 @@ use datafusion_common::{Column, DataFusionError, Result}; use datafusion_expr::Expr; use deltalake::action::SaveMode; -use deltalake::{operations::DeltaOps, DeltaTable, Schema}; +use deltalake::operations::create::CreateBuilder; +use deltalake::{ + operations::{write::WriteBuilder, DeltaOps}, + DeltaTable, Schema, +}; mod common; @@ -138,6 +142,58 @@ async fn test_datafusion_simple_query_partitioned() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_datafusion_write_from_delta_scan() -> Result<()> { + let ctx = SessionContext::new(); + let state = ctx.state(); + + // Build an execution plan for scanning a DeltaTable + let source_table = deltalake::open_table("./tests/data/delta-0.8.0-date").await?; + let source_scan = source_table.scan(&state, None, &[], None).await?; + + // Create target Delta Table + let target_table = CreateBuilder::new() + .with_location("memory://target") + .with_columns(source_table.schema().unwrap().get_fields().clone()) + .with_table_name("target") + .await?; + + // Trying to execute the write by providing only the Datafusion plan and not the session state + // results in an error due to missing object store in the runtime registry. + assert!(WriteBuilder::new() + .with_input_execution_plan(source_scan.clone()) + .with_object_store(target_table.object_store()) + .await + .unwrap_err() + .to_string() + .contains("No suitable object store found for delta-rs://")); + + // Execute write to the target table with the proper state + let target_table = WriteBuilder::new() + .with_input_execution_plan(source_scan) + .with_input_session_state(state) + .with_object_store(target_table.object_store()) + .await?; + ctx.register_table("target", Arc::new(target_table))?; + + let batches = ctx.sql("SELECT * FROM target").await?.collect().await?; + + let expected = vec![ + "+------------+-----------+", + "| date | dayOfYear |", + "+------------+-----------+", + "| 2021-01-01 | 1 |", + "| 2021-01-02 | 2 |", + "| 2021-01-03 | 3 |", + "| 2021-01-04 | 4 |", + "| 2021-01-05 | 5 |", + "+------------+-----------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) +} + #[tokio::test] async fn test_datafusion_date_column() -> Result<()> { let ctx = SessionContext::new();