-
Notifications
You must be signed in to change notification settings - Fork 604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support read_csv for backends with no native support #9908
Changes from all commits
8b87686
fedd4de
38f91dd
773cfb5
c7aea6e
6547ae3
69b4e39
6152533
0214160
7bb6f96
520fe5f
e023025
f1b42f8
902bb47
79885a9
ce2c8cf
9acda5c
e62925b
96ff701
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4,6 +4,7 @@ | |||||||||||||||||||||||||||||||
import collections.abc | ||||||||||||||||||||||||||||||||
import contextlib | ||||||||||||||||||||||||||||||||
import functools | ||||||||||||||||||||||||||||||||
import glob | ||||||||||||||||||||||||||||||||
import importlib.metadata | ||||||||||||||||||||||||||||||||
import keyword | ||||||||||||||||||||||||||||||||
import re | ||||||||||||||||||||||||||||||||
|
@@ -1269,6 +1270,117 @@ | |||||||||||||||||||||||||||||||
f"{cls.name} backend has not implemented `has_operation` API" | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
@util.experimental | ||||||||||||||||||||||||||||||||
def read_csv( | ||||||||||||||||||||||||||||||||
self, path: str | Path, table_name: str | None = None, **kwargs: Any | ||||||||||||||||||||||||||||||||
) -> ir.Table: | ||||||||||||||||||||||||||||||||
"""Register a CSV file as a table in the current backend. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
This function reads a CSV file and registers it as a table in the current | ||||||||||||||||||||||||||||||||
backend. Note that for Impala and Trino backends, the performance may be suboptimal. | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Parameters | ||||||||||||||||||||||||||||||||
---------- | ||||||||||||||||||||||||||||||||
path | ||||||||||||||||||||||||||||||||
The data source. A string or Path to the CSV file. | ||||||||||||||||||||||||||||||||
table_name | ||||||||||||||||||||||||||||||||
An optional name to use for the created table. This defaults to | ||||||||||||||||||||||||||||||||
a sequentially generated name. | ||||||||||||||||||||||||||||||||
**kwargs | ||||||||||||||||||||||||||||||||
Additional keyword arguments passed to the backend loading function. | ||||||||||||||||||||||||||||||||
Common options are skip_rows, column_names, delimiter, and include_columns. | ||||||||||||||||||||||||||||||||
More details could be found: | ||||||||||||||||||||||||||||||||
https://arrow.apache.org/docs/python/generated/pyarrow.csv.ReadOptions.html | ||||||||||||||||||||||||||||||||
https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html | ||||||||||||||||||||||||||||||||
https://arrow.apache.org/docs/python/generated/pyarrow.csv.ConvertOptions.html | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Returns | ||||||||||||||||||||||||||||||||
------- | ||||||||||||||||||||||||||||||||
ir.Table | ||||||||||||||||||||||||||||||||
The just-registered table | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Examples | ||||||||||||||||||||||||||||||||
-------- | ||||||||||||||||||||||||||||||||
Connect to a SQLite database: | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
>>> con = ibis.sqlite.connect() | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Read a single csv file: | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
>>> table = con.read_csv("path/to/file.csv") | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Comment on lines
+1308
to
+1311
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
Read all csv files in a directory: | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
>>> table = con.read_parquet("path/to/csv_directory/*") | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Read all csv files with a glob pattern: | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
>>> table = con.read_csv("path/to/csv_directory/test_*.csv") | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Read csv file from s3: | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
>>> table = con.read_csv("s3://bucket/path/to/file.csv") | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
Read csv file with custom pyarrow options: | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
>>> table = con.read_csv( | ||||||||||||||||||||||||||||||||
... "path/to/file.csv", delimiter=",", include_columns=["col1", "col3"] | ||||||||||||||||||||||||||||||||
... ) | ||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||
pa = self._import_pyarrow() | ||||||||||||||||||||||||||||||||
import pyarrow.csv as pcsv | ||||||||||||||||||||||||||||||||
from pyarrow import fs | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
read_options_args = {} | ||||||||||||||||||||||||||||||||
parse_options_args = {} | ||||||||||||||||||||||||||||||||
convert_options_args = {} | ||||||||||||||||||||||||||||||||
memory_pool = None | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
for key, value in kwargs.items(): | ||||||||||||||||||||||||||||||||
if hasattr(pcsv.ReadOptions, key): | ||||||||||||||||||||||||||||||||
read_options_args[key] = value | ||||||||||||||||||||||||||||||||
elif hasattr(pcsv.ParseOptions, key): | ||||||||||||||||||||||||||||||||
parse_options_args[key] = value | ||||||||||||||||||||||||||||||||
elif hasattr(pcsv.ConvertOptions, key): | ||||||||||||||||||||||||||||||||
convert_options_args[key] = value | ||||||||||||||||||||||||||||||||
elif key == "memory_pool": | ||||||||||||||||||||||||||||||||
memory_pool = value | ||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||
raise ValueError(f"Invalid args: {key!r}") | ||||||||||||||||||||||||||||||||
jitingxu1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
read_options = pcsv.ReadOptions(**read_options_args) | ||||||||||||||||||||||||||||||||
parse_options = pcsv.ParseOptions(**parse_options_args) | ||||||||||||||||||||||||||||||||
convert_options = pcsv.ConvertOptions(**convert_options_args) | ||||||||||||||||||||||||||||||||
if not memory_pool: | ||||||||||||||||||||||||||||||||
memory_pool = pa.default_memory_pool() | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
path = str(path) | ||||||||||||||||||||||||||||||||
file_system, path = fs.FileSystem.from_uri(path) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if isinstance(file_system, fs.LocalFileSystem): | ||||||||||||||||||||||||||||||||
paths = glob.glob(path) | ||||||||||||||||||||||||||||||||
if not paths: | ||||||||||||||||||||||||||||||||
raise FileNotFoundError(f"No files found at {path!r}") | ||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||
paths = [path] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
pyarrow_tables = [] | ||||||||||||||||||||||||||||||||
for path in paths: | ||||||||||||||||||||||||||||||||
with file_system.open_input_file(path) as f: | ||||||||||||||||||||||||||||||||
pyarrow_table = pcsv.read_csv( | ||||||||||||||||||||||||||||||||
f, | ||||||||||||||||||||||||||||||||
read_options=read_options, | ||||||||||||||||||||||||||||||||
parse_options=parse_options, | ||||||||||||||||||||||||||||||||
convert_options=convert_options, | ||||||||||||||||||||||||||||||||
memory_pool=memory_pool, | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
pyarrow_tables.append(pyarrow_table) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
pyarrow_table = pa.concat_tables(pyarrow_tables) | ||||||||||||||||||||||||||||||||
table_name = table_name or util.gen_name("read_csv") | ||||||||||||||||||||||||||||||||
self.create_table(table_name, pyarrow_table) | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, I think this should probably be a temp table or a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||||||||||
return self.table(table_name) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def _transpile_sql(self, query: str, *, dialect: str | None = None) -> str: | ||||||||||||||||||||||||||||||||
# only transpile if dialect was passed | ||||||||||||||||||||||||||||||||
if dialect is None: | ||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -12,7 +12,11 @@ | |||||||||||||||||||||||||
|
||||||||||||||||||||||||||
import ibis | ||||||||||||||||||||||||||
from ibis.backends.conftest import TEST_TABLES | ||||||||||||||||||||||||||
from ibis.backends.tests.errors import PySparkAnalysisException | ||||||||||||||||||||||||||
from ibis.backends.tests.errors import ( | ||||||||||||||||||||||||||
MySQLOperationalError, | ||||||||||||||||||||||||||
PyODBCProgrammingError, | ||||||||||||||||||||||||||
PySparkAnalysisException, | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
from ibis.conftest import IS_SPARK_REMOTE | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if TYPE_CHECKING: | ||||||||||||||||||||||||||
|
@@ -21,9 +25,10 @@ | |||||||||||||||||||||||||
import pyarrow as pa | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
pytestmark = [ | ||||||||||||||||||||||||||
pytest.mark.notimpl(["druid", "exasol", "oracle"]), | ||||||||||||||||||||||||||
pytest.mark.notyet( | ||||||||||||||||||||||||||
["pyspark"], condition=IS_SPARK_REMOTE, raises=PySparkAnalysisException | ||||||||||||||||||||||||||
["pyspark"], | ||||||||||||||||||||||||||
condition=IS_SPARK_REMOTE, | ||||||||||||||||||||||||||
raises=PySparkAnalysisException, | ||||||||||||||||||||||||||
Comment on lines
+29
to
+31
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like an unrelated formatting change |
||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
@@ -103,6 +108,7 @@ def gzip_csv(data_dir, tmp_path): | |||||||||||||||||||||||||
"trino", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
def test_register_csv(con, data_dir, fname, in_table_name, out_table_name): | ||||||||||||||||||||||||||
with pushd(data_dir / "csv"): | ||||||||||||||||||||||||||
with pytest.warns(FutureWarning, match="v9.1"): | ||||||||||||||||||||||||||
|
@@ -114,7 +120,7 @@ def test_register_csv(con, data_dir, fname, in_table_name, out_table_name): | |||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# TODO: rewrite or delete test when register api is removed | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["datafusion"]) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["datafusion", "druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
@pytest.mark.notyet( | ||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||
"bigquery", | ||||||||||||||||||||||||||
|
@@ -154,6 +160,7 @@ def test_register_csv_gz(con, data_dir, gzip_csv): | |||||||||||||||||||||||||
"trino", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
def test_register_with_dotted_name(con, data_dir, tmp_path): | ||||||||||||||||||||||||||
basename = "foo.bar.baz/diamonds.csv" | ||||||||||||||||||||||||||
f = tmp_path.joinpath(basename) | ||||||||||||||||||||||||||
|
@@ -211,6 +218,7 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]: | |||||||||||||||||||||||||
"trino", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
def test_register_parquet( | ||||||||||||||||||||||||||
con, tmp_path, data_dir, fname, in_table_name, out_table_name | ||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||
|
@@ -249,6 +257,7 @@ def test_register_parquet( | |||||||||||||||||||||||||
"trino", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
def test_register_iterator_parquet( | ||||||||||||||||||||||||||
con, | ||||||||||||||||||||||||||
tmp_path, | ||||||||||||||||||||||||||
|
@@ -277,7 +286,7 @@ def test_register_iterator_parquet( | |||||||||||||||||||||||||
# TODO: remove entirely when `register` is removed | ||||||||||||||||||||||||||
# This same functionality is implemented across all backends | ||||||||||||||||||||||||||
# via `create_table` and tested in `test_client.py` | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["datafusion"]) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["datafusion", "druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
@pytest.mark.notyet( | ||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||
"bigquery", | ||||||||||||||||||||||||||
|
@@ -311,7 +320,7 @@ def test_register_pandas(con): | |||||||||||||||||||||||||
# TODO: remove entirely when `register` is removed | ||||||||||||||||||||||||||
# This same functionality is implemented across all backends | ||||||||||||||||||||||||||
# via `create_table` and tested in `test_client.py` | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["datafusion", "polars"]) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["datafusion", "polars", "druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
@pytest.mark.notyet( | ||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||
"bigquery", | ||||||||||||||||||||||||||
|
@@ -352,6 +361,7 @@ def test_register_pyarrow_tables(con): | |||||||||||||||||||||||||
"trino", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
def test_csv_reregister_schema(con, tmp_path): | ||||||||||||||||||||||||||
foo = tmp_path.joinpath("foo.csv") | ||||||||||||||||||||||||||
with foo.open("w", newline="") as csvfile: | ||||||||||||||||||||||||||
|
@@ -380,10 +390,13 @@ def test_csv_reregister_schema(con, tmp_path): | |||||||||||||||||||||||||
"bigquery", | ||||||||||||||||||||||||||
"clickhouse", | ||||||||||||||||||||||||||
"datafusion", | ||||||||||||||||||||||||||
"druid", | ||||||||||||||||||||||||||
"exasol", | ||||||||||||||||||||||||||
"flink", | ||||||||||||||||||||||||||
"impala", | ||||||||||||||||||||||||||
"mysql", | ||||||||||||||||||||||||||
"mssql", | ||||||||||||||||||||||||||
"oracle", | ||||||||||||||||||||||||||
"polars", | ||||||||||||||||||||||||||
"postgres", | ||||||||||||||||||||||||||
"risingwave", | ||||||||||||||||||||||||||
|
@@ -417,6 +430,7 @@ def test_register_garbage(con, monkeypatch): | |||||||||||||||||||||||||
@pytest.mark.notyet( | ||||||||||||||||||||||||||
["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name): | ||||||||||||||||||||||||||
pq = pytest.importorskip("pyarrow.parquet") | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
@@ -457,6 +471,7 @@ def ft_data(data_dir): | |||||||||||||||||||||||||
"trino", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
def test_read_parquet_glob(con, tmp_path, ft_data): | ||||||||||||||||||||||||||
pq = pytest.importorskip("pyarrow.parquet") | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
@@ -473,18 +488,10 @@ def test_read_parquet_glob(con, tmp_path, ft_data): | |||||||||||||||||||||||||
assert table.count().execute() == nrows * ntables | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@pytest.mark.notyet( | ||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||
"flink", | ||||||||||||||||||||||||||
"impala", | ||||||||||||||||||||||||||
"mssql", | ||||||||||||||||||||||||||
"mysql", | ||||||||||||||||||||||||||
"postgres", | ||||||||||||||||||||||||||
"risingwave", | ||||||||||||||||||||||||||
"sqlite", | ||||||||||||||||||||||||||
"trino", | ||||||||||||||||||||||||||
] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notyet(["flink"]) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid"]) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["mssql"], raises=PyODBCProgrammingError) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["mysql"], raises=MySQLOperationalError) | ||||||||||||||||||||||||||
def test_read_csv_glob(con, tmp_path, ft_data): | ||||||||||||||||||||||||||
pc = pytest.importorskip("pyarrow.csv") | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
@@ -519,6 +526,7 @@ def test_read_csv_glob(con, tmp_path, ft_data): | |||||||||||||||||||||||||
raises=ValueError, | ||||||||||||||||||||||||||
reason="read_json() missing required argument: 'schema'", | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
def test_read_json_glob(con, tmp_path, ft_data): | ||||||||||||||||||||||||||
nrows = len(ft_data) | ||||||||||||||||||||||||||
ntables = 2 | ||||||||||||||||||||||||||
|
@@ -562,14 +570,26 @@ def num_diamonds(data_dir): | |||||||||||||||||||||||||
"in_table_name", | ||||||||||||||||||||||||||
[param(None, id="default"), param("fancy_stones", id="file_name")], | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notyet( | ||||||||||||||||||||||||||
["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notyet(["flink"]) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid", "exasol", "oracle"]) | ||||||||||||||||||||||||||
def test_read_csv(con, data_dir, in_table_name, num_diamonds): | ||||||||||||||||||||||||||
if con.name in ("trino", "impala"): | ||||||||||||||||||||||||||
# TODO: remove after trino and impala have efficient insertion | ||||||||||||||||||||||||||
pytest.skip( | ||||||||||||||||||||||||||
"Both Impala and Trino lack efficient data insertion methods from Python." | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
fname = "diamonds.csv" | ||||||||||||||||||||||||||
with pushd(data_dir / "csv"): | ||||||||||||||||||||||||||
if con.name == "pyspark": | ||||||||||||||||||||||||||
# pyspark doesn't respect CWD | ||||||||||||||||||||||||||
if con.name in ( | ||||||||||||||||||||||||||
"pyspark", | ||||||||||||||||||||||||||
"sqlite", | ||||||||||||||||||||||||||
"mysql", | ||||||||||||||||||||||||||
"postgres", | ||||||||||||||||||||||||||
"risingwave", | ||||||||||||||||||||||||||
"mssql", | ||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||
jitingxu1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||
# pyspark backend doesn't respect CWD | ||||||||||||||||||||||||||
# backends using pyarrow implementation need absolute path | ||||||||||||||||||||||||||
fname = str(Path(fname).absolute()) | ||||||||||||||||||||||||||
table = con.read_csv(fname, table_name=in_table_name) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
@@ -594,3 +614,73 @@ def test_read_csv(con, data_dir, in_table_name, num_diamonds): | |||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
assert table.count().execute() == num_diamonds | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
@pytest.mark.parametrize( | ||||||||||||||||||||||||||
("skip_rows", "new_column_names", "delimiter", "include_columns"), | ||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||
param(True, True, False, False, id="skip_rows_with_column_names"), | ||||||||||||||||||||||||||
param(False, False, False, True, id="include_columns"), | ||||||||||||||||||||||||||
param(False, False, True, False, id="delimiter"), | ||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
@pytest.mark.notyet(["flink"]) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["druid"]) | ||||||||||||||||||||||||||
@pytest.mark.never( | ||||||||||||||||||||||||||
[ | ||||||||||||||||||||||||||
"duckdb", | ||||||||||||||||||||||||||
"polars", | ||||||||||||||||||||||||||
"bigquery", | ||||||||||||||||||||||||||
"clickhouse", | ||||||||||||||||||||||||||
"datafusion", | ||||||||||||||||||||||||||
"snowflake", | ||||||||||||||||||||||||||
"pyspark", | ||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||
reason="backend implements its own read_csv", | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
Comment on lines
+629
to
+640
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
You can remove this since you are skipping them inside the test body |
||||||||||||||||||||||||||
@pytest.mark.notimpl(["mssql"], raises=PyODBCProgrammingError) | ||||||||||||||||||||||||||
@pytest.mark.notimpl(["mysql"], raises=MySQLOperationalError) | ||||||||||||||||||||||||||
def test_read_csv_pyarrow_options( | ||||||||||||||||||||||||||
con, tmp_path, ft_data, skip_rows, new_column_names, delimiter, include_columns | ||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||
pc = pytest.importorskip("pyarrow.csv") | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if con.name in ( | ||||||||||||||||||||||||||
"duckdb", | ||||||||||||||||||||||||||
"polars", | ||||||||||||||||||||||||||
"bigquery", | ||||||||||||||||||||||||||
"clickhouse", | ||||||||||||||||||||||||||
"datafusion", | ||||||||||||||||||||||||||
"snowflake", | ||||||||||||||||||||||||||
"pyspark", | ||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||
pytest.skip(f"{con.name} implements its own `read_parquet`") | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these backends have their own implementation, some of these options still could pass this test, so I skip these backends. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
column_names = ft_data.column_names | ||||||||||||||||||||||||||
num_rows = ft_data.num_rows | ||||||||||||||||||||||||||
fname = "tmp.csv" | ||||||||||||||||||||||||||
pc.write_csv(ft_data, tmp_path / fname) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
options = {} | ||||||||||||||||||||||||||
if skip_rows: | ||||||||||||||||||||||||||
options["skip_rows"] = 2 | ||||||||||||||||||||||||||
num_rows = num_rows - options["skip_rows"] + 1 | ||||||||||||||||||||||||||
if new_column_names: | ||||||||||||||||||||||||||
column_names = [f"col_{i}" for i in range(ft_data.num_columns)] | ||||||||||||||||||||||||||
options["column_names"] = column_names | ||||||||||||||||||||||||||
if delimiter: | ||||||||||||||||||||||||||
new_delimiter = "*" | ||||||||||||||||||||||||||
options["delimiter"] = new_delimiter | ||||||||||||||||||||||||||
pc.write_csv( | ||||||||||||||||||||||||||
ft_data, tmp_path / fname, pc.WriteOptions(delimiter=new_delimiter) | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
if include_columns: | ||||||||||||||||||||||||||
# try to include all types here | ||||||||||||||||||||||||||
# pick the first 12 columns | ||||||||||||||||||||||||||
column_names = column_names[:12] | ||||||||||||||||||||||||||
options["include_columns"] = column_names | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
table = con.read_csv(tmp_path / fname, **options) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
assert set(table.columns) == set(column_names) | ||||||||||||||||||||||||||
assert table.count().execute() == num_rows |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.