From 427a4fe0ab1cf930a68ef771f9f123159e9ee49d Mon Sep 17 00:00:00 2001
From: Jay Chia <17691182+jaychia@users.noreply.github.com>
Date: Thu, 16 Nov 2023 13:44:45 -0800
Subject: [PATCH] [BUG] Roundtrip tests for CSVs and Parquet (#1616)

Tests round-trip code (Daft write into a Daft read) for CSV and Parquet

---------

Co-authored-by: Jay Chia
---
 tests/io/test_csv_roundtrip.py     | 56 +++++++++++++++++++
 tests/io/test_parquet_roundtrip.py | 89 ++++++++++++++++++++++++++++++
 2 files changed, 145 insertions(+)
 create mode 100644 tests/io/test_csv_roundtrip.py
 create mode 100644 tests/io/test_parquet_roundtrip.py

diff --git a/tests/io/test_csv_roundtrip.py b/tests/io/test_csv_roundtrip.py
new file mode 100644
index 0000000000..364ce40c91
--- /dev/null
+++ b/tests/io/test_csv_roundtrip.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+import datetime
+
+import pyarrow as pa
+import pytest
+
+import daft
+from daft import DataType, TimeUnit
+
+PYARROW_GE_11_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (11, 0, 0)
+
+
+@pytest.mark.skipif(
+    not PYARROW_GE_11_0_0,
+    reason="PyArrow writing to CSV does not have good coverage for all types for versions <11.0.0",
+)
+@pytest.mark.parametrize(
+    ["data", "pa_type", "expected_dtype", "expected_inferred_dtype"],
+    [
+        ([1, 2, None], pa.int64(), DataType.int64(), DataType.int64()),
+        (["a", "b", ""], pa.large_string(), DataType.string(), DataType.string()),
+        ([b"a", b"b", b""], pa.large_binary(), DataType.binary(), DataType.string()),
+        ([True, False, None], pa.bool_(), DataType.bool(), DataType.bool()),
+        ([None, None, None], pa.null(), DataType.null(), DataType.null()),
+        # TODO: This is broken, needs more investigation into why
+        # ([decimal.Decimal("1.23"), decimal.Decimal("1.24"), None], pa.decimal128(16, 8), DataType.decimal128(16, 8), DataType.float64()),
+        ([datetime.date(1994, 1, 1), datetime.date(1995, 1, 1), None], pa.date32(), DataType.date(), DataType.date()),
+        (
+            [datetime.datetime(1994, 1, 1), datetime.datetime(1995, 1, 1), None],
+            pa.timestamp("ms"),
+            DataType.timestamp(TimeUnit.ms()),
+            # NOTE: Seems like the inferred type is seconds because it's written with seconds resolution
+            DataType.timestamp(TimeUnit.s()),
+        ),
+        (
+            [datetime.timedelta(days=1), datetime.timedelta(days=2), None],
+            pa.duration("ms"),
+            DataType.duration(TimeUnit.ms()),
+            # NOTE: Duration ends up being written as int64
+            DataType.int64(),
+        ),
+        # TODO: Verify that these types throw an error when we write dataframes with them
+        # ([[1, 2, 3], [], None], pa.large_list(pa.int64()), DataType.list(DataType.int64())),
+        # ([[1, 2, 3], [4, 5, 6], None], pa.list_(pa.int64(), list_size=3), DataType.fixed_size_list(DataType.int64(), 3)),
+        # ([{"bar": 1}, {"bar": None}, None], pa.struct({"bar": pa.int64()}), DataType.struct({"bar": DataType.int64()})),
+    ],
+)
+def test_roundtrip_simple_arrow_types(tmp_path, data, pa_type, expected_dtype, expected_inferred_dtype):
+    before = daft.from_arrow(pa.table({"id": pa.array(range(3)), "foo": pa.array(data, type=pa_type)}))
+    before = before.concat(before)
+    before.write_csv(str(tmp_path))
+    after = daft.read_csv(str(tmp_path))
+    assert before.schema()["foo"].dtype == expected_dtype
+    assert after.schema()["foo"].dtype == expected_inferred_dtype
+    assert before.to_arrow() == after.with_column("foo", after["foo"].cast(expected_dtype)).to_arrow()
diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py
new file mode 100644
index 0000000000..c625dd9c70
--- /dev/null
+++ b/tests/io/test_parquet_roundtrip.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import datetime
+import decimal
+
+import numpy as np
+import pyarrow as pa
+import pytest
+
+import daft
+from daft import DataType, Series, TimeUnit
+
+PYARROW_GE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (8, 0, 0)
+
+
+@pytest.mark.skipif(
+    not PYARROW_GE_8_0_0,
+    reason="PyArrow writing to Parquet does not have good coverage for all types for versions <8.0.0",
+)
+@pytest.mark.parametrize(
+    ["data", "pa_type", "expected_dtype"],
+    [
+        ([1, 2, None], pa.int64(), DataType.int64()),
+        (["a", "b", None], pa.large_string(), DataType.string()),
+        ([True, False, None], pa.bool_(), DataType.bool()),
+        ([b"a", b"b", None], pa.large_binary(), DataType.binary()),
+        ([None, None, None], pa.null(), DataType.null()),
+        ([decimal.Decimal("1.23"), decimal.Decimal("1.24"), None], pa.decimal128(16, 8), DataType.decimal128(16, 8)),
+        ([datetime.date(1994, 1, 1), datetime.date(1995, 1, 1), None], pa.date32(), DataType.date()),
+        (
+            [datetime.datetime(1994, 1, 1), datetime.datetime(1995, 1, 1), None],
+            pa.timestamp("ms"),
+            DataType.timestamp(TimeUnit.ms()),
+        ),
+        (
+            [datetime.timedelta(days=1), datetime.timedelta(days=2), None],
+            pa.duration("ms"),
+            DataType.duration(TimeUnit.ms()),
+        ),
+        ([[1, 2, 3], [], None], pa.large_list(pa.int64()), DataType.list(DataType.int64())),
+        # TODO: Crashes when parsing fixed size lists
+        # ([[1, 2, 3], [4, 5, 6], None], pa.list_(pa.int64(), list_size=3), DataType.fixed_size_list(DataType.int64(), 3)),
+        ([{"bar": 1}, {"bar": None}, None], pa.struct({"bar": pa.int64()}), DataType.struct({"bar": DataType.int64()})),
+    ],
+)
+def test_roundtrip_simple_arrow_types(tmp_path, data, pa_type, expected_dtype):
+    before = daft.from_arrow(pa.table({"foo": pa.array(data, type=pa_type)}))
+    before = before.concat(before)
+    before.write_parquet(str(tmp_path))
+    after = daft.read_parquet(str(tmp_path))
+    assert before.schema()["foo"].dtype == expected_dtype
+    assert after.schema()["foo"].dtype == expected_dtype
+    assert before.to_arrow() == after.to_arrow()
+
+
+@pytest.mark.parametrize(
+    ["data", "pa_type", "expected_dtype"],
+    [
+        # TODO: Fails and seems to just fall-back onto +00:00
+        # ([datetime.datetime(1994, 1, 1), datetime.datetime(1995, 1, 1), None], pa.timestamp("ms", "UTC"), DataType.timestamp(TimeUnit.ms(), "UTC")),
+    ],
+)
+def test_roundtrip_temporal_arrow_types(tmp_path, data, pa_type, expected_dtype):
+    before = daft.from_arrow(pa.table({"foo": pa.array(data, type=pa_type)}))
+    before = before.concat(before)
+    before.write_parquet(str(tmp_path))
+    after = daft.read_parquet(str(tmp_path))
+    assert before.schema()["foo"].dtype == expected_dtype
+    assert after.schema()["foo"].dtype == expected_dtype
+    assert before.to_arrow() == after.to_arrow()
+
+
+@pytest.mark.skip(reason="Currently fails when reading multiple Parquet files with tensor types")
+def test_roundtrip_tensor_types(tmp_path):
+    expected_dtype = DataType.tensor(DataType.int64())
+    data = [np.array([[1, 2], [3, 4]]), None, None]
+    before = daft.from_pydict({"foo": Series.from_pylist(data)})
+    before = before.concat(before)
+    before.write_parquet(str(tmp_path))
+    after = daft.read_parquet(str(tmp_path))
+    assert before.schema()["foo"].dtype == expected_dtype
+    assert after.schema()["foo"].dtype == expected_dtype
+    assert before.to_arrow() == after.to_arrow()
+
+
+# TODO: reading/writing:
+# 1. Embedding type
+# 2. Image type
+# 3. Extension type?
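
Every test in this patch follows the same write-then-read pattern: materialize a DataFrame, write it out, read it back, and compare both the schema and the Arrow contents. A minimal standalone sketch of that pattern, using only the Daft calls exercised above (the output directory path is illustrative):

    import pyarrow as pa
    import daft

    # Build a small DataFrame from an Arrow table with a known type.
    before = daft.from_arrow(pa.table({"foo": pa.array([1, 2, None], type=pa.int64())}))

    # Write to a directory of Parquet files, then read it back.
    before.write_parquet("/tmp/roundtrip_example")
    after = daft.read_parquet("/tmp/roundtrip_example")

    # A lossless roundtrip should preserve both the dtype and the data.
    assert before.schema()["foo"].dtype == after.schema()["foo"].dtype
    assert before.to_arrow() == after.to_arrow()

The CSV tests differ only in that they track a separate expected_inferred_dtype, since CSV is untyped on disk and the reader must infer types (e.g. binary columns come back as strings, millisecond timestamps as seconds).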