Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable passing Datafusion session state to WriteBuilder #1187

Merged
merged 1 commit into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions rust/src/operations/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::writer::utils::PartitionPath;

use arrow::datatypes::{DataType, SchemaRef as ArrowSchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion::execution::context::{SessionContext, TaskContext};
use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
use datafusion::physical_plan::{memory::MemoryExec, ExecutionPlan};
use futures::future::BoxFuture;
use futures::StreamExt;
Expand Down Expand Up @@ -78,6 +78,8 @@ impl From<WriteError> for DeltaTableError {
pub struct WriteBuilder {
/// The input plan
input: Option<Arc<dyn ExecutionPlan>>,
/// Datafusion session state relevant for executing the input plan
state: Option<SessionState>,
/// Location where the table is stored
location: Option<String>,
/// SaveMode defines how to treat data already written to table location
Expand Down Expand Up @@ -109,6 +111,7 @@ impl WriteBuilder {
pub fn new() -> Self {
Self {
input: None,
state: None,
location: None,
mode: SaveMode::Append,
partition_columns: None,
Expand Down Expand Up @@ -156,6 +159,12 @@ impl WriteBuilder {
self
}

/// A session state accompanying a given input plan, containing e.g. registered object stores
pub fn with_input_session_state(mut self, state: SessionState) -> Self {
self.state = Some(state);
self
}

/// Record batches containing the data to be written to the delta table
pub fn with_input_batches(mut self, batches: impl IntoIterator<Item = RecordBatch>) -> Self {
self.batches = Some(batches.into_iter().collect());
Expand Down Expand Up @@ -346,8 +355,12 @@ impl std::future::IntoFuture for WriteBuilder {
let mut tasks = vec![];
for i in 0..plan.output_partitioning().partition_count() {
let inner_plan = plan.clone();
let state = SessionContext::new();
let task_ctx = Arc::new(TaskContext::from(&state));
let task_ctx = Arc::from(if let Some(state) = this.state.clone() {
TaskContext::from(&state)
} else {
let ctx = SessionContext::new();
TaskContext::from(&ctx)
});
let config = WriterConfig::new(
inner_plan.schema(),
partition_columns.clone(),
Expand Down
58 changes: 57 additions & 1 deletion rust/tests/datafusion_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ use datafusion_common::{Column, DataFusionError, Result};
use datafusion_expr::Expr;

use deltalake::action::SaveMode;
use deltalake::{operations::DeltaOps, DeltaTable, Schema};
use deltalake::operations::create::CreateBuilder;
use deltalake::{
operations::{write::WriteBuilder, DeltaOps},
DeltaTable, Schema,
};

mod common;

Expand Down Expand Up @@ -138,6 +142,58 @@ async fn test_datafusion_simple_query_partitioned() -> Result<()> {
Ok(())
}

/// Writes the output of a Delta table scan into a second, in-memory Delta table.
///
/// The test covers both sides of the `WriteBuilder` session-state handling:
/// 1. without the session state that produced the scan plan, the write must fail because
///    the source table's object store cannot be resolved;
/// 2. with the session state supplied, the write succeeds and the target table contains
///    exactly the rows of the source table.
#[tokio::test]
async fn test_datafusion_write_from_delta_scan() -> Result<()> {
    let ctx = SessionContext::new();
    let state = ctx.state();

    // Build an execution plan for scanning a DeltaTable
    let source_table = deltalake::open_table("./tests/data/delta-0.8.0-date").await?;
    let source_scan = source_table.scan(&state, None, &[], None).await?;

    // Create target Delta Table with the same columns as the source
    let target_table = CreateBuilder::new()
        .with_location("memory://target")
        .with_columns(source_table.schema().unwrap().get_fields().clone())
        .with_table_name("target")
        .await?;

    // Trying to execute the write by providing only the Datafusion plan and not the session state
    // results in an error due to missing object store in the runtime registry.
    assert!(WriteBuilder::new()
        .with_input_execution_plan(source_scan.clone())
        .with_object_store(target_table.object_store())
        .await
        .unwrap_err()
        .to_string()
        .contains("No suitable object store found for delta-rs://"));

    // Execute write to the target table with the proper state
    let target_table = WriteBuilder::new()
        .with_input_execution_plan(source_scan)
        .with_input_session_state(state)
        .with_object_store(target_table.object_store())
        .await?;
    ctx.register_table("target", Arc::new(target_table))?;

    let batches = ctx.sql("SELECT * FROM target").await?.collect().await?;

    // Cell padding must match the border widths exactly: the assertion compares against
    // the pretty-printed batch output line by line.
    let expected = vec![
        "+------------+-----------+",
        "| date       | dayOfYear |",
        "+------------+-----------+",
        "| 2021-01-01 | 1         |",
        "| 2021-01-02 | 2         |",
        "| 2021-01-03 | 3         |",
        "| 2021-01-04 | 4         |",
        "| 2021-01-05 | 5         |",
        "+------------+-----------+",
    ];
    assert_batches_sorted_eq!(expected, &batches);

    Ok(())
}

#[tokio::test]
async fn test_datafusion_date_column() -> Result<()> {
let ctx = SessionContext::new();
Expand Down