add more tests
andrewgazelka committed Dec 19, 2024
1 parent 8c02c10 commit 87c8e55
Showing 1 changed file with 72 additions and 11 deletions.
83 changes: 72 additions & 11 deletions tests/connect/test_csv.py
@@ -2,26 +2,87 @@

import os

import pytest

def test_write_csv_basic(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
df.write.csv(csv_dir)

# List all files in the CSV directory
csv_files = [f for f in os.listdir(csv_dir) if f.endswith(".csv")]
print(f"CSV files in directory: {csv_files}")

# Assert there is at least one CSV file
assert len(csv_files) > 0, "Expected at least one CSV file to be written"

# Read back from the CSV directory (not specific file)
df_read = spark_session.read.csv(str(csv_dir))

# Verify the data is unchanged
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read"


def test_write_csv_with_header(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
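    # header=True writes a column-name row; the reader must opt in with the same option to use it.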
df.write.option("header", True).csv(csv_dir)

df_read = spark_session.read.option("header", True).csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"])


def test_write_csv_with_delimiter(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
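    # Use "|" instead of the default comma as the field separator; the reader must match it.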
df.write.option("sep", "|").csv(csv_dir)

df_read = spark_session.read.option("sep", "|").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"])


def test_write_csv_with_quote(spark_session, tmp_path):
df = spark_session.createDataFrame([("a,b",), ("c'd",)], ["text"])
csv_dir = os.path.join(tmp_path, "csv")
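    # "a,b" contains the default delimiter, so it must be quoted, here with a custom quote character.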
df.write.option("quote", "'").csv(csv_dir)

df_read = spark_session.read.option("quote", "'").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["text"].equals(df_read_pandas["text"])


def test_write_csv_with_escape(spark_session, tmp_path):
df = spark_session.createDataFrame([("a'b",), ("c'd",)], ["text"])
csv_dir = os.path.join(tmp_path, "csv")
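    # The escape character protects quote characters that appear inside quoted fields.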
df.write.option("escape", "\\").csv(csv_dir)

df_read = spark_session.read.option("escape", "\\").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["text"].equals(df_read_pandas["text"])


@pytest.mark.skip(
    reason="https://github.com/Eventual-Inc/Daft/issues/3609: CSV null value handling not yet implemented"
)
def test_write_csv_with_null_value(spark_session, tmp_path):
df = spark_session.createDataFrame([(1, None), (2, "test")], ["id", "value"])
csv_dir = os.path.join(tmp_path, "csv")
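    # Once supported, "NULL" should round-trip: written in place of None, parsed back to null on read.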
df.write.option("nullValue", "NULL").csv(csv_dir)

df_read = spark_session.read.option("nullValue", "NULL").csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["value"].isna().equals(df_read_pandas["value"].isna())


def test_write_csv_with_compression(spark_session, tmp_path):
df = spark_session.range(10)
csv_dir = os.path.join(tmp_path, "csv")
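    # Write gzip-compressed part files; the read below passes no compression
    # option, so the reader is expected to detect gzip on its own.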
df.write.option("compression", "gzip").csv(csv_dir)

df_read = spark_session.read.csv(str(csv_dir))
df_pandas = df.toPandas()
df_read_pandas = df_read.toPandas()
assert df_pandas["id"].equals(df_read_pandas["id"])
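
For context (not part of this commit): every test above takes a spark_session fixture, which comes from the suite's conftest.py and is not shown in this diff. Below is a minimal sketch of such a fixture, assuming the PySpark Spark Connect client and a Connect server already listening locally; the URL, port, and lifecycle handling are illustrative guesses, not Daft's actual wiring.

# conftest.py (illustrative sketch, not from this repository)
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark_session():
    # builder.remote() creates a client-side session that speaks the
    # Spark Connect protocol; 15002 is Spark Connect's default port.
    session = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    yield session
    session.stop()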
