Skip to content

Commit

Permalink
[FEAT] connect: read/write -> csv, write -> json
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 21, 2024
1 parent 7827d10 commit 1036a19
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 20 deletions.
20 changes: 10 additions & 10 deletions src/daft-connect/src/op/execute/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@ use eyre::{bail, WrapErr};
use futures::stream;
use spark_connect::{
write_operation::{SaveMode, SaveType},
ExecutePlanResponse, Relation, WriteOperation,
ExecutePlanResponse, WriteOperation,
};
use tokio_util::sync::CancellationToken;
use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
use tracing::warn;

use crate::{
invalid_argument_err,
op::execute::{ExecuteStream, PlanIds},
session::Session,
translation,
Expand Down Expand Up @@ -100,9 +99,12 @@ impl Session {
bail!("Source is required");
};

if source != "parquet" {
bail!("Unsupported source: {source}; only parquet is supported");
}
// Map the Spark `source` string onto the file format used for the write.
// Parquet, CSV and JSON are accepted; anything else is rejected with an
// error that names the offending source and lists every accepted format.
let file_format = match &*source {
    "parquet" => FileFormat::Parquet,
    "csv" => FileFormat::Csv,
    "json" => FileFormat::Json,
    _ => bail!("Unsupported source: {source}; only parquet, csv, and json are supported"),
};

let Ok(mode) = SaveMode::try_from(mode) else {
bail!("Invalid save mode: {mode}");
Expand Down Expand Up @@ -146,7 +148,7 @@ impl Session {
}

let Some(save_type) = save_type else {
return bail!("Save type is required");
bail!("Save type is required");
};

let path = match save_type {
Expand All @@ -160,7 +162,7 @@ impl Session {
let plan = translation::to_logical_plan(input)?;

let plan = plan
.table_write(&path, FileFormat::Parquet, None, None, None)
.table_write(&path, file_format, None, None, None)
.wrap_err("Failed to create table write plan")?;

let logical_plan = plan.build();
Expand All @@ -177,9 +179,7 @@ impl Session {
CancellationToken::new(), // todo: maybe implement cancelling
)?;

for _ignored in iterator {

}
for _ignored in iterator {}

// this is so we make sure the operation is actually done
// before we return
Expand Down
30 changes: 20 additions & 10 deletions src/daft-connect/src/translation/logical_plan/read/data_source.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use daft_logical_plan::LogicalPlanBuilder;
use daft_scan::builder::ParquetScanBuilder;
use daft_scan::builder::{CsvScanBuilder, ParquetScanBuilder};
use eyre::{bail, ensure, WrapErr};
use tracing::warn;

pub fn data_source(data_source: spark_connect::read::DataSource) -> eyre::Result<LogicalPlanBuilder> {
pub fn data_source(
data_source: spark_connect::read::DataSource,
) -> eyre::Result<LogicalPlanBuilder> {
let spark_connect::read::DataSource {
format,
schema,
Expand All @@ -16,10 +18,6 @@ pub fn data_source(data_source: spark_connect::read::DataSource) -> eyre::Result
bail!("Format is required");
};

if format != "parquet" {
bail!("Unsupported format: {format}; only parquet is supported");
}

ensure!(!paths.is_empty(), "Paths are required");

if let Some(schema) = schema {
Expand All @@ -34,9 +32,21 @@ pub fn data_source(data_source: spark_connect::read::DataSource) -> eyre::Result
warn!("Ignoring predicates: {predicates:?}; not yet implemented");
}

let builder = ParquetScanBuilder::new(paths)
.finish()
.wrap_err("Failed to create parquet scan builder")?;
// Choose the scan builder that matches the requested format; the resulting
// value is what this function returns on success.
let plan = match &*format {
// Parquet and CSV are each built from `paths`; a failure from `finish()`
// is wrapped with a message naming the format.
"parquet" => ParquetScanBuilder::new(paths)
.finish()
.wrap_err("Failed to create parquet scan builder")?,
"csv" => CsvScanBuilder::new(paths)
.finish()
.wrap_err("Failed to create csv scan builder")?,
// JSON is recognised as a format name but reading it is not implemented,
// so it fails with its own message instead of the generic one below.
"json" => {
// todo(completeness): implement json reading
bail!("json reading is not yet implemented");
}
// Any other format string is rejected, echoing the value that was received.
other => {
bail!("Unsupported format: {other}; only parquet and csv are supported");
}
};

Ok(builder)
Ok(plan)
}
36 changes: 36 additions & 0 deletions tests/connect/test_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

import tempfile
import shutil
import os


def test_write_csv(spark_session):
# Create a temporary directory
temp_dir = tempfile.mkdtemp()
try:
# Create DataFrame from range(10)
df = spark_session.range(10)

# Write DataFrame to CSV directory
csv_dir = os.path.join(temp_dir, "test.csv")
df.write.csv(csv_dir)

# List all files in the CSV directory
csv_files = [f for f in os.listdir(csv_dir) if f.endswith('.csv')]
print(f"CSV files in directory: {csv_files}")

# Assert there is at least one CSV file
assert len(csv_files) > 0, "Expected at least one CSV file to be written"

# Read back from the CSV directory (not specific file)
df_read = spark_session.read.csv(csv_dir)

# Verify the data is unchanged
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read"

finally:
# Clean up temp directory
shutil.rmtree(temp_dir)

0 comments on commit 1036a19

Please sign in to comment.