From aa6058433f8ff35d837415e6368a6436fe96da26 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 17 Dec 2024 06:45:45 -0500 Subject: [PATCH] refactor(register): remove deprecated register method (#10545) BREAKING CHANGE: The deprecated `register` method has been removed. Please use the file-specific `read_*` methods instead. For in-memory objects, pass them to `ibis.memtable` or `create_table`. --- ibis/backends/datafusion/__init__.py | 147 +++--- .../backends/datafusion/tests/test_connect.py | 6 +- .../datafusion/tests/test_register.py | 72 --- ibis/backends/duckdb/__init__.py | 72 --- .../tests/{test_register.py => test_io.py} | 14 - ibis/backends/polars/__init__.py | 69 +-- ibis/backends/pyspark/__init__.py | 60 --- .../tests/{test_register.py => test_io.py} | 483 ++++++------------ 8 files changed, 215 insertions(+), 708 deletions(-) delete mode 100644 ibis/backends/datafusion/tests/test_register.py rename ibis/backends/duckdb/tests/{test_register.py => test_io.py} (96%) rename ibis/backends/tests/{test_register.py => test_io.py} (60%) diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 733d6a772b48..1a3af336e534 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -9,7 +9,6 @@ import datafusion as df import pyarrow as pa -import pyarrow.dataset as ds import pyarrow_hotfix # noqa: F401 import sqlglot as sg import sqlglot.expressions as sge @@ -28,7 +27,7 @@ from ibis.common.dispatch import lazy_singledispatch from ibis.expr.operations.udf import InputType from ibis.formats.pyarrow import PyArrowSchema, PyArrowType -from ibis.util import deprecated, gen_name, normalize_filename, normalize_filenames +from ibis.util import gen_name, normalize_filename, normalize_filenames, warn_deprecated try: from datafusion import ExecutionContext as SessionContext @@ -88,37 +87,30 @@ def do_connect( Parameters ---------- config - Mapping of table names to files or a `SessionContext` + 
Mapping of table names to files (deprecated in 10.0) or a `SessionContext` instance. Examples -------- + >>> from datafusion import SessionContext + >>> ctx = SessionContext() + >>> _ = ctx.from_pydict({"a": [1, 2, 3]}, "mytable") >>> import ibis - >>> config = { - ... "astronauts": "ci/ibis-testing-data/parquet/astronauts.parquet", - ... "diamonds": "ci/ibis-testing-data/csv/diamonds.csv", - ... } - >>> con = ibis.datafusion.connect(config) + >>> con = ibis.datafusion.connect(ctx) >>> con.list_tables() - ['astronauts', 'diamonds'] - >>> con.table("diamonds") - DatabaseTable: diamonds - carat float64 - cut string - color string - clarity string - depth float64 - table float64 - price int64 - x float64 - y float64 - z float64 + ['mytable'] """ if isinstance(config, SessionContext): (self.con, config) = (config, None) else: if config is not None and not isinstance(config, Mapping): raise TypeError("Input to ibis.datafusion.connect must be a mapping") + elif config is not None and config: # warn if dict is not empty + warn_deprecated( + "Passing a mapping of tables names to files", + as_of="10.0", + instead="Please use the explicit `read_*` methods for the files you would like to load instead.", + ) if SessionConfig is not None: df_config = SessionConfig( {"datafusion.sql_parser.dialect": "PostgreSQL"} @@ -178,6 +170,57 @@ def _get_schema_using_query(self, query: str) -> sch.Schema: return PyArrowSchema.to_ibis(df.schema()) + def _register( + self, + source: str | Path | pa.Table | pa.RecordBatch | pa.Dataset | pd.DataFrame, + table_name: str | None = None, + **kwargs: Any, + ) -> ir.Table: + import pandas as pd + import pyarrow.dataset as ds + + if isinstance(source, (str, Path)): + first = str(source) + elif isinstance(source, pa.Table): + self.con.deregister_table(table_name) + self.con.register_record_batches(table_name, [source.to_batches()]) + return self.table(table_name) + elif isinstance(source, pa.RecordBatch): + self.con.deregister_table(table_name) + 
self.con.register_record_batches(table_name, [[source]]) + return self.table(table_name) + elif isinstance(source, ds.Dataset): + self.con.deregister_table(table_name) + self.con.register_dataset(table_name, source) + return self.table(table_name) + elif isinstance(source, pd.DataFrame): + return self._register(pa.Table.from_pandas(source), table_name, **kwargs) + else: + raise ValueError("`source` must be either a string or a pathlib.Path") + + if first.startswith(("parquet://", "parq://")) or first.endswith( + ("parq", "parquet") + ): + return self.read_parquet(source, table_name=table_name, **kwargs) + elif first.startswith(("csv://", "txt://")) or first.endswith( + ("csv", "tsv", "txt") + ): + return self.read_csv(source, table_name=table_name, **kwargs) + else: + self._register_failure() + return None + + def _register_failure(self): + import inspect + + msg = ", ".join( + m[0] for m in inspect.getmembers(self) if m[0].startswith("read_") + ) + raise ValueError( + f"Cannot infer appropriate read function for input, " + f"please call one of {msg} directly" + ) + def _register_builtin_udfs(self): from ibis.backends.datafusion import udfs @@ -345,68 +388,6 @@ def get_schema( table = database.table(table_name) return sch.schema(table.schema) - @deprecated( - as_of="9.1", - instead="use the explicit `read_*` method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.", - ) - def register( - self, - source: str | Path | pa.Table | pa.RecordBatch | pa.Dataset | pd.DataFrame, - table_name: str | None = None, - **kwargs: Any, - ) -> ir.Table: - return self._register(source, table_name, **kwargs) - - def _register( - self, - source: str | Path | pa.Table | pa.RecordBatch | pa.Dataset | pd.DataFrame, - table_name: str | None = None, - **kwargs: Any, - ) -> ir.Table: - import pandas as pd - - if isinstance(source, (str, Path)): - first = str(source) - elif isinstance(source, pa.Table): - self.con.deregister_table(table_name) - 
self.con.register_record_batches(table_name, [source.to_batches()]) - return self.table(table_name) - elif isinstance(source, pa.RecordBatch): - self.con.deregister_table(table_name) - self.con.register_record_batches(table_name, [[source]]) - return self.table(table_name) - elif isinstance(source, ds.Dataset): - self.con.deregister_table(table_name) - self.con.register_dataset(table_name, source) - return self.table(table_name) - elif isinstance(source, pd.DataFrame): - return self.register(pa.Table.from_pandas(source), table_name, **kwargs) - else: - raise ValueError("`source` must be either a string or a pathlib.Path") - - if first.startswith(("parquet://", "parq://")) or first.endswith( - ("parq", "parquet") - ): - return self.read_parquet(source, table_name=table_name, **kwargs) - elif first.startswith(("csv://", "txt://")) or first.endswith( - ("csv", "tsv", "txt") - ): - return self.read_csv(source, table_name=table_name, **kwargs) - else: - self._register_failure() - return None - - def _register_failure(self): - import inspect - - msg = ", ".join( - m[0] for m in inspect.getmembers(self) if m[0].startswith("read_") - ) - raise ValueError( - f"Cannot infer appropriate read function for input, " - f"please call one of {msg} directly" - ) - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: # self.con.register_table is broken, so we do this roundabout thing # of constructing a datafusion DataFrame, which has a side effect diff --git a/ibis/backends/datafusion/tests/test_connect.py b/ibis/backends/datafusion/tests/test_connect.py index 6b3773f8370f..ff0ea2cf482a 100644 --- a/ibis/backends/datafusion/tests/test_connect.py +++ b/ibis/backends/datafusion/tests/test_connect.py @@ -25,13 +25,15 @@ def test_none_config(): def test_str_config(name_to_path): config = {name: str(path) for name, path in name_to_path.items()} - conn = ibis.datafusion.connect(config) + with pytest.warns(FutureWarning): + conn = ibis.datafusion.connect(config) assert 
sorted(conn.list_tables()) == sorted(name_to_path) def test_path_config(name_to_path): config = name_to_path - conn = ibis.datafusion.connect(config) + with pytest.warns(FutureWarning): + conn = ibis.datafusion.connect(config) assert sorted(conn.list_tables()) == sorted(name_to_path) diff --git a/ibis/backends/datafusion/tests/test_register.py b/ibis/backends/datafusion/tests/test_register.py deleted file mode 100644 index de8bc971da47..000000000000 --- a/ibis/backends/datafusion/tests/test_register.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -import pathlib - -import pandas as pd -import pyarrow as pa -import pytest - -import ibis - - -@pytest.fixture -def conn(): - return ibis.datafusion.connect() - - -def test_read_csv(conn, data_dir): - t = conn.read_csv(data_dir / "csv" / "functional_alltypes.csv") - assert t.count().execute() - - -@pytest.mark.parametrize( - "function", - [pathlib.Path, str], -) -def test_read_csv_path_list(conn, data_dir, function): - path = data_dir / "csv" / "functional_alltypes.csv" - - t = conn.read_csv(path, table_name="alltypes1") - t2 = conn.read_csv([function(path), function(path)], table_name="alltypes2") - - assert t2.schema() == t.schema() - assert t2.count().execute() == 2 * t.count().execute() - - -def test_read_parquet(conn, data_dir): - t = conn.read_parquet(data_dir / "parquet" / "functional_alltypes.parquet") - assert t.count().execute() - - -def test_register_table(conn): - tab = pa.table({"x": [1, 2, 3]}) - conn.create_table("my_table", tab) - assert conn.table("my_table").x.sum().execute() == 6 - - -def test_register_pandas(conn): - df = pd.DataFrame({"x": [1, 2, 3]}) - conn.create_table("my_table", df) - assert conn.table("my_table").x.sum().execute() == 6 - - -def test_register_batches(conn): - batch = pa.record_batch([pa.array([1, 2, 3])], names=["x"]) - conn.create_table("my_table", batch) - assert conn.table("my_table").x.sum().execute() == 6 - - -def test_register_dataset(conn): - import 
pyarrow.dataset as ds - - tab = pa.table({"x": [1, 2, 3]}) - dataset = ds.InMemoryDataset(tab) - with pytest.warns(FutureWarning, match="v9.1"): - conn.register(dataset, "my_table") - assert conn.table("my_table").x.sum().execute() == 6 - - -def test_create_table_with_uppercase_name(conn): - tab = pa.table({"x": [1, 2, 3]}) - conn.create_table("MY_TABLE", tab) - assert conn.table("MY_TABLE").x.sum().execute() == 6 diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 03c463907443..fc93cc53b309 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -31,7 +31,6 @@ from ibis.backends.sql.compilers.base import STAR, AlterTable, C, RenameTable from ibis.common.dispatch import lazy_singledispatch from ibis.expr.operations.udf import InputType -from ibis.util import deprecated if TYPE_CHECKING: from collections.abc import Iterable, Mapping, MutableMapping, Sequence @@ -483,77 +482,6 @@ def drop_database( with self._safe_raw_sql(sge.Drop(this=name, kind="SCHEMA", replace=force)): pass - @deprecated( - as_of="9.1", - instead="use the explicit `read_*` method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.", - ) - def register( - self, - source: str | Path | Any, - table_name: str | None = None, - **kwargs: Any, - ) -> ir.Table: - """Register a data source as a table in the current database. - - Parameters - ---------- - source - The data source(s). May be a path to a file or directory of - parquet/csv files, an iterable of parquet or CSV files, a pandas - dataframe, a pyarrow table or dataset, or a postgres URI. - table_name - An optional name to use for the created table. This defaults to a - sequentially generated name. - **kwargs - Additional keyword arguments passed to DuckDB loading functions for - CSV or parquet. See https://duckdb.org/docs/data/csv and - https://duckdb.org/docs/data/parquet for more information. 
- - Returns - ------- - ir.Table - The just-registered table - - """ - - if isinstance(source, (str, Path)): - first = str(source) - elif isinstance(source, (list, tuple)): - first = source[0] - else: - try: - return self.read_in_memory(source, table_name=table_name, **kwargs) - except (duckdb.InvalidInputException, NameError): - self._register_failure() - - if first.startswith(("parquet://", "parq://")) or first.endswith( - ("parq", "parquet") - ): - return self.read_parquet(source, table_name=table_name, **kwargs) - elif first.startswith( - ("csv://", "csv.gz://", "txt://", "txt.gz://") - ) or first.endswith(("csv", "csv.gz", "tsv", "tsv.gz", "txt", "txt.gz")): - return self.read_csv(source, table_name=table_name, **kwargs) - elif first.startswith(("postgres://", "postgresql://")): - return self.read_postgres(source, table_name=table_name, **kwargs) - elif first.startswith("sqlite://"): - return self.read_sqlite( - first[len("sqlite://") :], table_name=table_name, **kwargs - ) - else: - self._register_failure() # noqa: RET503 - - def _register_failure(self): - import inspect - - msg = ", ".join( - name for name, _ in inspect.getmembers(self) if name.startswith("read_") - ) - raise ValueError( - f"Cannot infer appropriate read function for input, " - f"please call one of {msg} directly" - ) - @util.experimental def read_json( self, diff --git a/ibis/backends/duckdb/tests/test_register.py b/ibis/backends/duckdb/tests/test_io.py similarity index 96% rename from ibis/backends/duckdb/tests/test_register.py rename to ibis/backends/duckdb/tests/test_io.py index 0bb07c1d1f42..7410807daa38 100644 --- a/ibis/backends/duckdb/tests/test_register.py +++ b/ibis/backends/duckdb/tests/test_io.py @@ -229,20 +229,6 @@ def test_read_sqlite_no_table_name(con, tmp_path): con.read_sqlite(path) -@pytest.mark.xfail( - LINUX and SANDBOXED, - reason="nix on linux cannot download duckdb extensions or data due to sandboxing", - raises=duckdb.IOException, -) -def test_register_sqlite(con, 
tmp_path): - path = tmp_path / "test.db" - sqlite_con = sqlite3.connect(str(path)) - sqlite_con.execute("CREATE TABLE t AS SELECT 1 a UNION SELECT 2 UNION SELECT 3") - with pytest.warns(FutureWarning, match="v9.1"): - ft = con.register(f"sqlite://{path}", "t") - assert ft.count().execute() - - # Because we create a new connection and the test requires loading/installing a # DuckDB extension, we need to xfail these on Nix. @pytest.mark.xfail( diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index 9bbedd6987ae..0a533541e309 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -19,7 +19,7 @@ from ibis.common.dispatch import lazy_singledispatch from ibis.expr.rewrites import lower_stringslice, replace_parameter from ibis.formats.polars import PolarsSchema -from ibis.util import deprecated, gen_name, normalize_filename, normalize_filenames +from ibis.util import gen_name, normalize_filename, normalize_filenames if TYPE_CHECKING: from collections.abc import Iterable @@ -106,73 +106,6 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: def _finalize_memtable(self, name: str) -> None: self.drop_table(name, force=True) - @deprecated( - as_of="9.1", - instead="use the explicit `read_*` method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.", - ) - def register( - self, - source: str | Path | Any, - table_name: str | None = None, - **kwargs: Any, - ) -> ir.Table: - """Register a data source as a table in the current database. - - Parameters - ---------- - source - The data source(s). May be a path to a file, a parquet directory, or a pandas - dataframe. - table_name - An optional name to use for the created table. This defaults to - a sequentially generated name. - **kwargs - Additional keyword arguments passed to Polars loading functions for - CSV or parquet. 
- See https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_csv.html - and https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_parquet.html - for more information - - Returns - ------- - ir.Table - The just-registered table - - """ - - if isinstance(source, (str, Path)): - first = str(source) - elif isinstance(source, (list, tuple)): - first = str(source[0]) - else: - try: - return self.read_pandas(source, table_name=table_name, **kwargs) - except ValueError: - self._register_failure() - - if first.startswith(("parquet://", "parq://")) or first.endswith( - ("parq", "parquet") - ): - return self.read_parquet(source, table_name=table_name, **kwargs) - elif first.startswith( - ("csv://", "csv.gz://", "txt://", "txt.gz://") - ) or first.endswith(("csv", "csv.gz", "tsv", "tsv.gz", "txt", "txt.gz")): - return self.read_csv(source, table_name=table_name, **kwargs) - else: - self._register_failure() - return None - - def _register_failure(self): - import inspect - - msg = ", ".join( - m[0] for m in inspect.getmembers(self) if m[0].startswith("read_") - ) - raise ValueError( - f"Cannot infer appropriate read function for input, " - f"please call one of {msg} directly" - ) - def _add_table(self, name: str, obj: pl.LazyFrame | pl.DataFrame) -> None: if isinstance(obj, pl.DataFrame): obj = obj.lazy() diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index f8d9cfffb3c6..519b85d7d026 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -28,7 +28,6 @@ from ibis.backends.sql.compilers.base import AlterTable, RenameTable from ibis.expr.operations.udf import InputType from ibis.legacy.udf.vectorized import _coerce_to_series -from ibis.util import deprecated try: from pyspark.errors import ParseException @@ -912,65 +911,6 @@ def read_json( spark_df.createOrReplaceTempView(table_name) return self.table(table_name) - @deprecated( - as_of="9.1", - instead="use the explicit 
`read_*` method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.", - ) - def register( - self, - source: str | Path | Any, - table_name: str | None = None, - **kwargs: Any, - ) -> ir.Table: - """Register a data source as a table in the current database. - - Parameters - ---------- - source - The data source(s). May be a path to a file or directory of - parquet/csv files, or an iterable of CSV files. - table_name - An optional name to use for the created table. This defaults to - a random generated name. - **kwargs - Additional keyword arguments passed to PySpark loading functions for - CSV or parquet. - - Returns - ------- - ir.Table - The just-registered table - - """ - if isinstance(source, (str, Path)): - first = str(source) - elif isinstance(source, (list, tuple)): - first = source[0] - else: - self._register_failure() - - if first.startswith(("parquet://", "parq://")) or first.endswith( - ("parq", "parquet") - ): - return self.read_parquet(source, table_name=table_name, **kwargs) - elif first.startswith( - ("csv://", "csv.gz://", "txt://", "txt.gz://") - ) or first.endswith(("csv", "csv.gz", "tsv", "tsv.gz", "txt", "txt.gz")): - return self.read_csv(source, table_name=table_name, **kwargs) - else: - self._register_failure() # noqa: RET503 - - def _register_failure(self): - import inspect - - msg = ", ".join( - name for name, _ in inspect.getmembers(self) if name.startswith("read_") - ) - raise ValueError( - f"Cannot infer appropriate read function for input, " - f"please call one of {msg} directly" - ) - @util.experimental def to_delta( self, diff --git a/ibis/backends/tests/test_register.py b/ibis/backends/tests/test_io.py similarity index 60% rename from ibis/backends/tests/test_register.py rename to ibis/backends/tests/test_io.py index 1ca96eb42221..7223db704116 100644 --- a/ibis/backends/tests/test_register.py +++ b/ibis/backends/tests/test_io.py @@ -48,79 +48,78 @@ def gzip_csv(data_dir, tmp_path): return str(f.absolute()) -# 
TODO: rewrite or delete test when register api is removed +@pytest.fixture(scope="module") +def num_diamonds(data_dir): + with open(data_dir / "csv" / "diamonds.csv") as f: + # subtract 1 for the header + return sum(1 for _ in f) - 1 + + +@pytest.fixture(scope="module") +def ft_data(data_dir): + pq = pytest.importorskip("pyarrow.parquet") + nrows = 5 + table = pq.read_table(data_dir.joinpath("parquet", "functional_alltypes.parquet")) + return table.slice(0, nrows) + + +DIAMONDS_COLUMN_TYPES = { + # snowflake's `INFER_SCHEMA` returns this for the diamonds CSV `price` + # column type + "snowflake": { + "carat": "decimal(3, 2)", + "depth": "decimal(3, 1)", + "table": "decimal(3, 1)", + "x": "decimal(4, 2)", + "y": "decimal(4, 2)", + "z": "decimal(4, 2)", + }, + "pyspark": {"price": "int32"}, +} + + @pytest.mark.parametrize( - ("fname", "in_table_name", "out_table_name"), - [ - param("diamonds.csv", None, "ibis_read_csv_", id="default"), - param( - "csv://diamonds.csv", - "Diamonds2", - "Diamonds2", - id="csv_name", - marks=pytest.mark.notyet( - ["pyspark"], reason="pyspark lowercases view names" - ), - ), - param( - "file://diamonds.csv", - "fancy_stones", - "fancy_stones", - id="file_name", - ), - param( - "file://diamonds.csv", - "fancy stones", - "fancy stones", - id="file_atypical_name", - marks=pytest.mark.notyet( - ["pyspark"], reason="no spaces allowed in view names" - ), - ), - param( - ["file://diamonds.csv", "diamonds.csv"], - "fancy_stones2", - "fancy_stones2", - id="multi_csv", - marks=pytest.mark.notyet( - ["datafusion"], - reason="doesn't accept multiple files to scan or read", - ), - ), - ], + "in_table_name", + [param(None, id="default"), param("fancy_stones", id="file_name")], ) @pytest.mark.notyet( - [ - "bigquery", - "clickhouse", - "flink", - "impala", - "mssql", - "mysql", - "postgres", - "risingwave", - "snowflake", - "sqlite", - "trino", - "databricks", - ] + ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] ) 
-def test_register_csv(con, data_dir, fname, in_table_name, out_table_name): +def test_read_csv(con, data_dir, in_table_name, num_diamonds): + fname = "diamonds.csv" with pushd(data_dir / "csv"): - with pytest.warns(FutureWarning, match="v9.1"): - table = con.register(fname, table_name=in_table_name) + if con.name == "pyspark": + # pyspark doesn't respect CWD + fname = str(Path(fname).absolute()) + table = con.read_csv(fname, table_name=in_table_name) - assert any(out_table_name in t for t in con.list_tables()) - if con.name != "datafusion": - table.count().execute() + if in_table_name is not None: + assert table.op().name == in_table_name + + special_types = DIAMONDS_COLUMN_TYPES.get(con.name, {}) + + assert table.schema() == ibis.schema( + { + "carat": "float64", + "cut": "string", + "color": "string", + "clarity": "string", + "depth": "float64", + "table": "float64", + "price": "int64", + "x": "float64", + "y": "float64", + "z": "float64", + **special_types, + } + ) + assert table.count().execute() == num_diamonds -# TODO: rewrite or delete test when register api is removed @pytest.mark.notimpl(["datafusion"]) @pytest.mark.notyet( [ "bigquery", - "clickhouse", "flink", "impala", "mssql", @@ -133,19 +132,16 @@ def test_register_csv(con, data_dir, fname, in_table_name, out_table_name): "databricks", ] ) -def test_register_csv_gz(con, data_dir, gzip_csv): +def test_read_csv_gz(con, data_dir, gzip_csv): with pushd(data_dir): - with pytest.warns(FutureWarning, match="v9.1"): - table = con.register(gzip_csv) + table = con.read_csv(gzip_csv) assert table.count().execute() -# TODO: rewrite or delete test when register api is removed @pytest.mark.notyet( [ "bigquery", - "clickhouse", "flink", "impala", "mssql", @@ -157,193 +153,21 @@ def test_register_csv_gz(con, data_dir, gzip_csv): "trino", ] ) -def test_register_with_dotted_name(con, data_dir, tmp_path): +def test_read_csv_with_dotted_name(con, data_dir, tmp_path): basename = "foo.bar.baz/diamonds.csv" f = 
tmp_path.joinpath(basename) f.parent.mkdir() data = data_dir.joinpath("csv", "diamonds.csv").read_bytes() f.write_bytes(data) - with pytest.warns(FutureWarning, match="v9.1"): - table = con.register(str(f.absolute())) - - if con.name != "datafusion": - table.count().execute() - - -def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]: - """For each csv `names` in `data_dir` return a `pyarrow.Table`.""" - pac = pytest.importorskip("pyarrow.csv") - - table_name = path.stem - schema = TEST_TABLES[table_name] - convert_options = pac.ConvertOptions( - column_types={name: typ.to_pyarrow() for name, typ in schema.items()} - ) - data_dir = path.parent - return pac.read_csv(data_dir / f"{table_name}.csv", convert_options=convert_options) - - -# TODO: rewrite or delete test when register api is removed -@pytest.mark.parametrize( - ("fname", "in_table_name", "out_table_name"), - [ - param( - "parquet://functional_alltypes.parquet", None, "ibis_read_parquet", id="url" - ), - param("functional_alltypes.parquet", "funk_all", "funk_all", id="basename"), - param( - "parquet://functional_alltypes.parq", "funk_all", "funk_all", id="url_parq" - ), - param( - "parquet://functional_alltypes", None, "ibis_read_parquet", id="url_no_ext" - ), - ], -) -@pytest.mark.notyet( - [ - "bigquery", - "clickhouse", - "flink", - "impala", - "mssql", - "mysql", - "postgres", - "risingwave", - "snowflake", - "sqlite", - "trino", - ] -) -def test_register_parquet( - con, tmp_path, data_dir, fname, in_table_name, out_table_name -): - pq = pytest.importorskip("pyarrow.parquet") - - fname = Path(fname) - table = read_table(data_dir / "csv" / fname.name) - - pq.write_table(table, tmp_path / fname.name) - - with pushd(tmp_path): - with pytest.warns(FutureWarning, match="v9.1"): - table = con.register(f"parquet://{fname.name}", table_name=in_table_name) - - assert any(out_table_name in t for t in con.list_tables()) + table = con.read_csv(str(f.absolute())) if con.name != "datafusion": 
table.count().execute() -# TODO: rewrite or delete test when register api is removed -@pytest.mark.notyet( - [ - "bigquery", - "clickhouse", - "datafusion", - "flink", - "impala", - "mssql", - "mysql", - "postgres", - "risingwave", - "pyspark", - "snowflake", - "sqlite", - "trino", - ] -) -def test_register_iterator_parquet( - con, - tmp_path, - data_dir, -): - pq = pytest.importorskip("pyarrow.parquet") - - table = read_table(data_dir / "csv" / "functional_alltypes.csv") - - pq.write_table(table, tmp_path / "functional_alltypes.parquet") - - with pushd(tmp_path): - with pytest.warns(FutureWarning, match="v9.1"): - table = con.register( - [ - "parquet://functional_alltypes.parquet", - "functional_alltypes.parquet", - ], - table_name=None, - ) - - assert any("ibis_read_parquet" in t for t in con.list_tables()) - assert table.count().execute() - - -# TODO: remove entirely when `register` is removed -# This same functionality is implemented across all backends -# via `create_table` and tested in `test_client.py` -@pytest.mark.notimpl(["datafusion"]) -@pytest.mark.notyet( - [ - "bigquery", - "clickhouse", - "flink", - "impala", - "mssql", - "mysql", - "postgres", - "risingwave", - "pyspark", - "snowflake", - "sqlite", - "trino", - ] -) -def test_register_pandas(con): - pd = pytest.importorskip("pandas") - df = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]}) - - with pytest.warns(FutureWarning, match="v9.1"): - t = con.register(df) - assert t.x.sum().execute() == 6 - - with pytest.warns(FutureWarning, match="v9.1"): - t = con.register(df, "my_table") - assert t.op().name == "my_table" - assert t.x.sum().execute() == 6 - - -# TODO: remove entirely when `register` is removed -# This same functionality is implemented across all backends -# via `create_table` and tested in `test_client.py` -@pytest.mark.notimpl(["datafusion", "polars"]) -@pytest.mark.notyet( - [ - "bigquery", - "clickhouse", - "flink", - "impala", - "mssql", - "mysql", - "postgres", - "risingwave", - 
"pyspark", - "snowflake", - "sqlite", - "trino", - ] -) -def test_register_pyarrow_tables(con): - pa = pytest.importorskip("pyarrow") - pa_t = pa.Table.from_pydict({"x": [1, 2, 3], "y": ["a", "b", "c"]}) - - with pytest.warns(FutureWarning, match="v9.1"): - t = con.register(pa_t) - assert t.x.sum().execute() == 6 - - @pytest.mark.notyet( [ "bigquery", - "clickhouse", "flink", "impala", "mssql", @@ -355,7 +179,7 @@ def test_register_pyarrow_tables(con): "trino", ] ) -def test_csv_reregister_schema(con, tmp_path): +def test_read_csv_schema(con, tmp_path): foo = tmp_path.joinpath("foo.csv") with foo.open("w", newline="") as csvfile: csv.writer(csvfile, delimiter=",").writerows( @@ -367,9 +191,8 @@ def test_csv_reregister_schema(con, tmp_path): ] ) - with pytest.warns(FutureWarning, match="v9.1"): - # For a full file scan, expect correct schema based on final row - foo_table = con.register(foo, table_name="same") + # For a full file scan, expect correct schema based on final row + foo_table = con.read_csv(foo, table_name="same") result_schema = foo_table.schema() assert result_schema.names == ("cola", "colb", "colc") @@ -378,36 +201,32 @@ def test_csv_reregister_schema(con, tmp_path): assert result_schema["colc"].is_string() -@pytest.mark.notimpl( +@pytest.mark.notyet( [ - "bigquery", - "clickhouse", - "datafusion", "flink", "impala", - "mysql", "mssql", - "polars", + "mysql", "postgres", "risingwave", - "pyspark", - "snowflake", "sqlite", "trino", ] ) -def test_register_garbage(con, monkeypatch): - # monkeypatch to avoid downloading extensions in tests - monkeypatch.setattr(con, "_load_extensions", lambda x: True) +def test_read_csv_glob(con, tmp_path, ft_data): + pc = pytest.importorskip("pyarrow.csv") - duckdb = pytest.importorskip("duckdb") - with pytest.raises( - duckdb.IOException, match="No files found that match the pattern" - ): - con.read_csv("garbage_notafile") + nrows = len(ft_data) + ntables = 2 + ext = "csv" - with pytest.raises((FileNotFoundError, 
duckdb.IOException)): - con.read_parquet("garbage_notafile") + fnames = [f"data{i}.{ext}" for i in range(ntables)] + for fname in fnames: + pc.write_csv(ft_data, tmp_path / fname) + + table = con.read_csv(tmp_path / f"*.{ext}") + + assert table.count().execute() == nrows * ntables @pytest.mark.parametrize( @@ -440,40 +259,58 @@ def test_read_parquet(con, tmp_path, data_dir, fname, in_table_name): assert table.count().execute() -@pytest.fixture(scope="module") -def ft_data(data_dir): - pq = pytest.importorskip("pyarrow.parquet") - nrows = 5 - table = pq.read_table(data_dir.joinpath("parquet", "functional_alltypes.parquet")) - return table.slice(0, nrows) +def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]: + """Read `<path.stem>.csv` from `path.parent` as a `pyarrow.Table` typed per `TEST_TABLES`.""" + pac = pytest.importorskip("pyarrow.csv") + + table_name = path.stem + schema = TEST_TABLES[table_name] + convert_options = pac.ConvertOptions( + column_types={name: typ.to_pyarrow() for name, typ in schema.items()} + ) + data_dir = path.parent + return pac.read_csv(data_dir / f"{table_name}.csv", convert_options=convert_options) @pytest.mark.notyet( [ + "bigquery", + "clickhouse", + "datafusion", "flink", "impala", "mssql", "mysql", "postgres", "risingwave", + "pyspark", + "snowflake", "sqlite", "trino", ] ) -def test_read_parquet_glob(con, tmp_path, ft_data): +def test_read_parquet_iterator( + con, + tmp_path, + data_dir, +): pq = pytest.importorskip("pyarrow.parquet") - nrows = len(ft_data) - ntables = 2 - ext = "parquet" + table = read_table(data_dir / "csv" / "functional_alltypes.csv") - fnames = [f"data{i}.{ext}" for i in range(ntables)] - for fname in fnames: - pq.write_table(ft_data, tmp_path / fname) + pq.write_table(table, tmp_path / "functional_alltypes.parquet") - table = con.read_parquet(tmp_path / f"*.{ext}") + with pushd(tmp_path): + table = con.read_parquet( + [ + "parquet://functional_alltypes.parquet", + "functional_alltypes.parquet", + ], + table_name=None, + ) - 
assert table.count().execute() == nrows * ntables + assert any("ibis_read_parquet" in t for t in con.list_tables()) + assert table.count().execute() @pytest.mark.notyet( @@ -488,18 +325,18 @@ def test_read_parquet_glob(con, tmp_path, ft_data): "trino", ] ) -def test_read_csv_glob(con, tmp_path, ft_data): - pc = pytest.importorskip("pyarrow.csv") +def test_read_parquet_glob(con, tmp_path, ft_data): + pq = pytest.importorskip("pyarrow.parquet") nrows = len(ft_data) ntables = 2 - ext = "csv" + ext = "parquet" fnames = [f"data{i}.{ext}" for i in range(ntables)] for fname in fnames: - pc.write_csv(ft_data, tmp_path / fname) + pq.write_table(ft_data, tmp_path / fname) - table = con.read_csv(tmp_path / f"*.{ext}") + table = con.read_parquet(tmp_path / f"*.{ext}") assert table.count().execute() == nrows * ntables @@ -539,61 +376,33 @@ def test_read_json_glob(con, tmp_path, ft_data): assert table.count().execute() == nrows * ntables -@pytest.fixture(scope="module") -def num_diamonds(data_dir): - with open(data_dir / "csv" / "diamonds.csv") as f: - # subtract 1 for the header - return sum(1 for _ in f) - 1 - - -DIAMONDS_COLUMN_TYPES = { - # snowflake's `INFER_SCHEMA` returns this for the diamonds CSV `price` - # column type - "snowflake": { - "carat": "decimal(3, 2)", - "depth": "decimal(3, 1)", - "table": "decimal(3, 1)", - "x": "decimal(4, 2)", - "y": "decimal(4, 2)", - "z": "decimal(4, 2)", - }, - "pyspark": {"price": "int32"}, -} - - -@pytest.mark.parametrize( - "in_table_name", - [param(None, id="default"), param("fancy_stones", id="file_name")], -) -@pytest.mark.notyet( - ["flink", "impala", "mssql", "mysql", "postgres", "risingwave", "sqlite", "trino"] +@pytest.mark.notimpl( + [ + "bigquery", + "clickhouse", + "datafusion", + "flink", + "impala", + "mysql", + "mssql", + "polars", + "postgres", + "risingwave", + "pyspark", + "snowflake", + "sqlite", + "trino", + ] ) -def test_read_csv(con, data_dir, in_table_name, num_diamonds): - fname = "diamonds.csv" - with 
pushd(data_dir / "csv"): - if con.name == "pyspark": - # pyspark doesn't respect CWD - fname = str(Path(fname).absolute()) - table = con.read_csv(fname, table_name=in_table_name) - - if in_table_name is not None: - assert table.op().name == in_table_name +def test_read_garbage(con, monkeypatch): + # monkeypatch to avoid downloading extensions in tests + monkeypatch.setattr(con, "_load_extensions", lambda x: True) - special_types = DIAMONDS_COLUMN_TYPES.get(con.name, {}) + duckdb = pytest.importorskip("duckdb") + with pytest.raises( + duckdb.IOException, match="No files found that match the pattern" + ): + con.read_csv("garbage_notafile") - assert table.schema() == ibis.schema( - { - "carat": "float64", - "cut": "string", - "color": "string", - "clarity": "string", - "depth": "float64", - "table": "float64", - "price": "int64", - "x": "float64", - "y": "float64", - "z": "float64", - **special_types, - } - ) - assert table.count().execute() == num_diamonds + with pytest.raises((FileNotFoundError, duckdb.IOException)): + con.read_parquet("garbage_notafile")