test: add unit test and documentation
jitingxu1 committed Sep 24, 2024
1 parent e62925b commit 96ff701
Showing 2 changed files with 82 additions and 3 deletions.
15 changes: 12 additions & 3 deletions ibis/backends/__init__.py
@@ -1277,8 +1277,7 @@ def read_csv(
"""Register a CSV file as a table in the current backend.
This function reads a CSV file and registers it as a table in the current
-backend. Note that for Impala and Trino backends, CSV read performance
-may be suboptimal.
+backend. Note that for Impala and Trino backends, the performance may be suboptimal.
Parameters
----------
@@ -1289,6 +1288,11 @@ def read_csv(
a sequentially generated name.
**kwargs
Additional keyword arguments passed to the backend loading function.
Common options are skip_rows, column_names, delimiter, and include_columns.
More details can be found at:
https://arrow.apache.org/docs/python/generated/pyarrow.csv.ReadOptions.html
https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html
https://arrow.apache.org/docs/python/generated/pyarrow.csv.ConvertOptions.html
Returns
-------
@@ -1317,6 +1321,11 @@ def read_csv(
>>> table = con.read_csv("s3://bucket/path/to/file.csv")
Read a CSV file with custom pyarrow options:
>>> table = con.read_csv(
... "path/to/file.csv", delimiter=",", include_columns=["col1", "col3"]
... )
"""
        pa = self._import_pyarrow()
        import pyarrow.csv as pcsv
@@ -1342,7 +1351,7 @@ def read_csv(
        read_options = pcsv.ReadOptions(**read_options_args)
        parse_options = pcsv.ParseOptions(**parse_options_args)
        convert_options = pcsv.ConvertOptions(**convert_options_args)
-        if memory_pool:
+        if not memory_pool:
            memory_pool = pa.default_memory_pool()

        path = str(path)
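
For context, here is a rough standalone sketch of the mechanism the hunks above rely on: the keyword arguments accepted by read_csv are split into the fields understood by pyarrow's ReadOptions, ParseOptions and ConvertOptions, and the result is handed to pyarrow.csv.read_csv together with a memory pool. This is not the ibis implementation; the helper names and the grouping-by-dir() trick are illustrative assumptions only.

# Illustrative sketch only -- the real backend code builds read_options_args,
# parse_options_args and convert_options_args before the hunk shown above.
import pyarrow as pa
import pyarrow.csv as pcsv


def _group_csv_kwargs(**kwargs):
    # hypothetical helper: assign each kwarg to the pyarrow options class
    # that exposes a field of the same name
    groups = {cls: {} for cls in (pcsv.ReadOptions, pcsv.ParseOptions, pcsv.ConvertOptions)}
    for key, value in kwargs.items():
        for cls in groups:
            if key in dir(cls):
                groups[cls][key] = value
                break
    return groups


def read_csv_with_pyarrow(path, memory_pool=None, **kwargs):
    groups = _group_csv_kwargs(**kwargs)
    if not memory_pool:
        memory_pool = pa.default_memory_pool()
    return pcsv.read_csv(
        path,
        read_options=pcsv.ReadOptions(**groups[pcsv.ReadOptions]),
        parse_options=pcsv.ParseOptions(**groups[pcsv.ParseOptions]),
        convert_options=pcsv.ConvertOptions(**groups[pcsv.ConvertOptions]),
        memory_pool=memory_pool,
    )


# e.g. read_csv_with_pyarrow("data.csv", skip_rows=1, column_names=["a", "b"], delimiter="|")
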
70 changes: 70 additions & 0 deletions ibis/backends/tests/test_register.py
@@ -614,3 +614,73 @@ def test_read_csv(con, data_dir, in_table_name, num_diamonds):
}
)
assert table.count().execute() == num_diamonds
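
A note for readers of the new test below: param is pytest.param, the *Error classes are the driver exceptions imported at the top of this module, and ft_data is a pyarrow Table fixture defined elsewhere in the file. A minimal sketch of what such a fixture could look like, assuming the functional_alltypes parquet file used by the rest of the module (the actual fixture is not part of this diff):

@pytest.fixture
def ft_data(data_dir):
    # assumed fixture: a small slice of functional_alltypes as a pyarrow Table
    pq = pytest.importorskip("pyarrow.parquet")
    table = pq.read_table(data_dir / "parquet" / "functional_alltypes.parquet")
    return table.slice(0, 5)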


@pytest.mark.parametrize(
    ("skip_rows", "new_column_names", "delimiter", "include_columns"),
    [
        param(True, True, False, False, id="skip_rows_with_column_names"),
        param(False, False, False, True, id="include_columns"),
        param(False, False, True, False, id="delimiter"),
    ],
)
@pytest.mark.notyet(["flink"])
@pytest.mark.notimpl(["druid"])
@pytest.mark.never(
    [
        "duckdb",
        "polars",
        "bigquery",
        "clickhouse",
        "datafusion",
        "snowflake",
        "pyspark",
    ],
    reason="backend implements its own read_csv",
)
@pytest.mark.notimpl(["mssql"], raises=PyODBCProgrammingError)
@pytest.mark.notimpl(["mysql"], raises=MySQLOperationalError)
def test_read_csv_pyarrow_options(
    con, tmp_path, ft_data, skip_rows, new_column_names, delimiter, include_columns
):
    pc = pytest.importorskip("pyarrow.csv")

    if con.name in (
        "duckdb",
        "polars",
        "bigquery",
        "clickhouse",
        "datafusion",
        "snowflake",
        "pyspark",
    ):
        pytest.skip(f"{con.name} implements its own `read_csv`")

    column_names = ft_data.column_names
    num_rows = ft_data.num_rows
    fname = "tmp.csv"
    pc.write_csv(ft_data, tmp_path / fname)

    options = {}
    if skip_rows:
        options["skip_rows"] = 2
        num_rows = num_rows - options["skip_rows"] + 1
    if new_column_names:
        column_names = [f"col_{i}" for i in range(ft_data.num_columns)]
        options["column_names"] = column_names
    if delimiter:
        new_delimiter = "*"
        options["delimiter"] = new_delimiter
        pc.write_csv(
            ft_data, tmp_path / fname, pc.WriteOptions(delimiter=new_delimiter)
        )
    if include_columns:
        # pick the first 12 columns to cover a variety of column types
        column_names = column_names[:12]
        options["include_columns"] = column_names

    table = con.read_csv(tmp_path / fname, **options)

    assert set(table.columns) == set(column_names)
    assert table.count().execute() == num_rows
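
A rough end-to-end illustration of the delimiter case this test exercises, assuming a backend that uses the base-class read_csv (SQLite is used here only as an example, and the file name is a placeholder):

import ibis
import pyarrow as pa
import pyarrow.csv as pcsv

# write a small CSV with a non-default delimiter
data = pa.table({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
pcsv.write_csv(data, "star_delimited.csv", pcsv.WriteOptions(delimiter="*"))

# read it back through ibis, forwarding the delimiter to pyarrow's ParseOptions
con = ibis.sqlite.connect()
t = con.read_csv("star_delimited.csv", delimiter="*")
assert set(t.columns) == {"col1", "col2"}
assert t.count().execute() == 3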
