From ced7e6ec218a6cda25d3d5ecbe632ebba021df73 Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Wed, 20 Nov 2024 02:33:17 -0800
Subject: [PATCH] [FEAT] connect: read/write -> csv, write -> json

---
 src/daft-connect/src/op/execute/write.rs      | 20 +++++------
 .../logical_plan/read/data_source.rs          | 30 ++++++++++------
 tests/connect/test_csv.py                     | 36 +++++++++++++++++++
 3 files changed, 66 insertions(+), 20 deletions(-)
 create mode 100644 tests/connect/test_csv.py

diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs
index 7566b33273..27755309f0 100644
--- a/src/daft-connect/src/op/execute/write.rs
+++ b/src/daft-connect/src/op/execute/write.rs
@@ -6,14 +6,13 @@ use eyre::{bail, WrapErr};
 use futures::stream;
 use spark_connect::{
     write_operation::{SaveMode, SaveType},
-    ExecutePlanResponse, Relation, WriteOperation,
+    ExecutePlanResponse, WriteOperation,
 };
 use tokio_util::sync::CancellationToken;
 use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
 use tracing::warn;
 
 use crate::{
-    invalid_argument_err,
     op::execute::{ExecuteStream, PlanIds},
     session::Session,
     translation,
@@ -100,9 +99,12 @@ impl Session {
             bail!("Source is required");
         };
 
-        if source != "parquet" {
-            bail!("Unsupported source: {source}; only parquet is supported");
-        }
+        let file_format = match &*source {
+            "parquet" => FileFormat::Parquet,
+            "csv" => FileFormat::Csv,
+            "json" => FileFormat::Json,
+            _ => bail!("Unsupported source: {source}; only parquet, csv, and json are supported"),
+        };
 
         let Ok(mode) = SaveMode::try_from(mode) else {
             bail!("Invalid save mode: {mode}");
@@ -146,7 +148,7 @@ impl Session {
         }
 
         let Some(save_type) = save_type else {
-            return bail!("Save type is required");
+            bail!("Save type is required");
         };
 
         let path = match save_type {
@@ -160,7 +162,7 @@ impl Session {
         let plan = translation::to_logical_plan(input)?;
 
         let plan = plan
-            .table_write(&path, FileFormat::Parquet, None, None, None)
+            .table_write(&path, file_format, None, None, None)
             .wrap_err("Failed to create table write plan")?;
 
         let logical_plan = plan.build();
@@ -177,9 +179,7 @@ impl Session {
                 CancellationToken::new(), // todo: maybe implement cancelling
             )?;
 
-            for _ignored in iterator {
-
-            }
+            for _ignored in iterator {}
 
             // this is so we make sure the operation is actually done
             // before we return
diff --git a/src/daft-connect/src/translation/logical_plan/read/data_source.rs b/src/daft-connect/src/translation/logical_plan/read/data_source.rs
index 0a9a14c494..a878ec1c1f 100644
--- a/src/daft-connect/src/translation/logical_plan/read/data_source.rs
+++ b/src/daft-connect/src/translation/logical_plan/read/data_source.rs
@@ -1,9 +1,11 @@
 use daft_logical_plan::LogicalPlanBuilder;
-use daft_scan::builder::ParquetScanBuilder;
+use daft_scan::builder::{CsvScanBuilder, ParquetScanBuilder};
 use eyre::{bail, ensure, WrapErr};
 use tracing::warn;
 
-pub fn data_source(data_source: spark_connect::read::DataSource) -> eyre::Result<LogicalPlanBuilder> {
+pub fn data_source(
+    data_source: spark_connect::read::DataSource,
+) -> eyre::Result<LogicalPlanBuilder> {
     let spark_connect::read::DataSource {
         format,
         schema,
@@ -16,10 +18,6 @@ pub fn data_source(data_source: spark_connect::read::DataSource) -> eyre::Result<LogicalPlanBuilder> {
         bail!("Format is required");
     };
 
-    if format != "parquet" {
-        bail!("Unsupported format: {format}; only parquet is supported");
-    }
-
     ensure!(!paths.is_empty(), "Paths are required");
 
     if let Some(schema) = schema {
@@ -34,9 +32,21 @@ pub fn data_source(data_source: spark_connect::read::DataSource) -> eyre::Result<LogicalPlanBuilder> {
         warn!("Ignoring predicates: {predicates:?}; not yet implemented");
     }
 
-    let builder = ParquetScanBuilder::new(paths)
-        .finish()
-        .wrap_err("Failed to create parquet scan builder")?;
+    let plan = match &*format {
+        "parquet" => ParquetScanBuilder::new(paths)
+            .finish()
+            .wrap_err("Failed to create parquet scan builder")?,
+        "csv" => CsvScanBuilder::new(paths)
+            .finish()
+            .wrap_err("Failed to create csv scan builder")?,
+        "json" => {
+            // todo(completeness): implement json reading
+            bail!("json reading is not yet implemented");
+        }
+        other => {
+            bail!("Unsupported format: {other}; only parquet and csv are supported");
+        }
+    };
 
-    Ok(builder)
+    Ok(plan)
 }
diff --git a/tests/connect/test_csv.py b/tests/connect/test_csv.py
new file mode 100644
index 0000000000..18e986ccdf
--- /dev/null
+++ b/tests/connect/test_csv.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+import os
+import shutil
+import tempfile
+
+
+def test_write_csv(spark_session):
+    # Create a temporary directory
+    temp_dir = tempfile.mkdtemp()
+    try:
+        # Create DataFrame from range(10)
+        df = spark_session.range(10)
+
+        # Write DataFrame to CSV directory
+        csv_dir = os.path.join(temp_dir, "test.csv")
+        df.write.csv(csv_dir)
+
+        # List all files in the CSV directory
+        csv_files = [f for f in os.listdir(csv_dir) if f.endswith(".csv")]
+        print(f"CSV files in directory: {csv_files}")
+
+        # Assert there is at least one CSV file
+        assert len(csv_files) > 0, "Expected at least one CSV file to be written"
+
+        # Read back from the CSV directory (not a specific file)
+        df_read = spark_session.read.csv(csv_dir)
+
+        # Verify the data is unchanged
+        df_pandas = df.toPandas()
+        df_read_pandas = df_read.toPandas()
+        assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read"
+
+    finally:
+        # Clean up temp directory
+        shutil.rmtree(temp_dir)
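
The patch also wires up JSON for the write path (FileFormat::Json in write.rs) but ships only a CSV test, and JSON reads still bail as unimplemented in data_source.rs. A companion test for the JSON write path might look like the sketch below. This is an assumption, not part of the patch: it mirrors the structure of test_write_csv, uses the standard PySpark df.write.json API, and assumes the written part files carry a .json extension; it skips the read-back round trip because spark_session.read cannot read JSON through this connector yet.

from __future__ import annotations

import os
import shutil
import tempfile


def test_write_json(spark_session):
    # Sketch only: exercise the new JSON write path added in write.rs.
    temp_dir = tempfile.mkdtemp()
    try:
        # Create DataFrame from range(10), mirroring test_write_csv
        df = spark_session.range(10)

        # Write DataFrame to a JSON directory (hypothetical path name)
        json_dir = os.path.join(temp_dir, "test.json")
        df.write.json(json_dir)

        # JSON reads are not implemented in this patch (data_source.rs bails),
        # so only assert that at least one part file was produced.
        # Assumption: the writer emits part files with a .json extension.
        json_files = [f for f in os.listdir(json_dir) if f.endswith(".json")]
        assert len(json_files) > 0, "Expected at least one JSON file to be written"
    finally:
        # Clean up temp directory
        shutil.rmtree(temp_dir)

Once the todo(completeness) for JSON reading in data_source.rs lands, the round-trip verification from test_write_csv could be mirrored here as well.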