Skip to content

Commit

Permalink
[FEAT] connect: read/write -> csv, write -> json
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 19, 2024
1 parent ea8f8bd commit 59c7803
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 14 deletions.
6 changes: 2 additions & 4 deletions src/daft-connect/src/op/execute/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ impl Session {
bail!("Source is required");
};

if source != "parquet" {
bail!("Unsupported source: {source}; only parquet is supported");
}
let file_format: FileFormat = source.parse()?;

let Ok(mode) = SaveMode::try_from(mode) else {
bail!("Invalid save mode: {mode}");
Expand Down Expand Up @@ -115,7 +113,7 @@ impl Session {
let plan = translator.to_logical_plan(input).await?;

let plan = plan
.table_write(&path, FileFormat::Parquet, None, None, None)
.table_write(&path, file_format, None, None, None)
.wrap_err("Failed to create table write plan")?;

let optimized_plan = plan.optimize()?;
Expand Down
29 changes: 19 additions & 10 deletions src/daft-connect/src/translation/logical_plan/read/data_source.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use daft_logical_plan::LogicalPlanBuilder;
use daft_scan::builder::ParquetScanBuilder;
use daft_scan::builder::{CsvScanBuilder, ParquetScanBuilder};
use eyre::{bail, ensure, WrapErr};
use tracing::warn;

Expand All @@ -18,10 +18,6 @@ pub async fn data_source(
bail!("Format is required");
};

if format != "parquet" {
bail!("Unsupported format: {format}; only parquet is supported");
}

ensure!(!paths.is_empty(), "Paths are required");

if let Some(schema) = schema {
Expand All @@ -36,10 +32,23 @@ pub async fn data_source(
warn!("Ignoring predicates: {predicates:?}; not yet implemented");
}

let builder = ParquetScanBuilder::new(paths)
.finish()
.await
.wrap_err("Failed to create parquet scan builder")?;
let plan = match &*format {
"parquet" => ParquetScanBuilder::new(paths)
.finish()
.await
.wrap_err("Failed to create parquet scan builder")?,
"csv" => CsvScanBuilder::new(paths)
.finish()
.await
.wrap_err("Failed to create csv scan builder")?,
"json" => {

Check warning on line 44 in src/daft-connect/src/translation/logical_plan/read/data_source.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/read/data_source.rs#L44

Added line #L44 was not covered by tests
// todo(completeness): implement json reading
bail!("json reading is not yet implemented");

Check warning on line 46 in src/daft-connect/src/translation/logical_plan/read/data_source.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/read/data_source.rs#L46

Added line #L46 was not covered by tests
}
other => {
bail!("Unsupported format: {other}; only parquet and csv are supported");

Check warning on line 49 in src/daft-connect/src/translation/logical_plan/read/data_source.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/read/data_source.rs#L48-L49

Added lines #L48 - L49 were not covered by tests
}
};

Ok(builder)
Ok(plan)
}
88 changes: 88 additions & 0 deletions tests/connect/test_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

import os

import pytest


def test_write_csv_basic(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
df.write.csv(csv_dir)

csv_files = [f for f in os.listdir(csv_dir) if f.endswith(".csv")]
assert len(csv_files) > 0, "Expected at least one CSV file to be written"

df_read = spark_session.read.csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read"


def test_write_csv_with_header(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("header", True).csv(csv_dir)

df_read = spark_session.read.option("header", True).csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"])


def test_write_csv_with_delimiter(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("sep", "|").csv(csv_dir)

df_read = spark_session.read.option("sep", "|").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"])


def test_write_csv_with_quote(spark_session, tmp_path):
df = spark_session.createDataFrame([("a,b",), ("c'd",)], ["text"])
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("quote", "'").csv(csv_dir)

df_read = spark_session.read.option("quote", "'").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["text"].equals(df_read_pandas["text"])


def test_write_csv_with_escape(spark_session, tmp_path):
df = spark_session.createDataFrame([("a'b",), ("c'd",)], ["text"])
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("escape", "\\").csv(csv_dir)

df_read = spark_session.read.option("escape", "\\").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["text"].equals(df_read_pandas["text"])


@pytest.mark.skip(
reason="https://github.com/Eventual-Inc/Daft/issues/3609: CSV null value handling not yet implemented"
)
def test_write_csv_with_null_value(spark_session, tmp_path):
df = spark_session.createDataFrame([(1, None), (2, "test")], ["id", "value"])
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("nullValue", "NULL").csv(csv_dir)

df_read = spark_session.read.option("nullValue", "NULL").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["value"].isna().equals(df_read_pandas["value"].isna())


def test_write_csv_with_compression(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
df.write.option("compression", "gzip").csv(csv_dir)

df_read = spark_session.read.csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"])

0 comments on commit 59c7803

Please sign in to comment.