diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 81731cfeb9d9..c5f2f1f1f743 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -3408,7 +3408,7 @@ def write_database( if engine == "adbc": if if_exists == "fail": raise NotImplementedError( - "`if_exists` not yet supported with engine ADBC" + "`if_exists = 'fail'` not supported for ADBC engine" ) elif if_exists == "replace": mode = "create" diff --git a/py-polars/tests/unit/io/test_database_write.py b/py-polars/tests/unit/io/test_database_write.py index 8f89c5b414de..871a0990e141 100644 --- a/py-polars/tests/unit/io/test_database_write.py +++ b/py-polars/tests/unit/io/test_database_write.py @@ -1,7 +1,8 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING +from contextlib import suppress +from typing import TYPE_CHECKING, Any import pytest @@ -11,93 +12,155 @@ if TYPE_CHECKING: from pathlib import Path - from polars.type_aliases import DbWriteEngine, DbWriteMode + from polars.type_aliases import DbWriteEngine -@pytest.fixture() -def sample_df() -> pl.DataFrame: - return pl.DataFrame( +def adbc_sqlite_driver_version(*args: Any, **kwargs: Any) -> str: + with suppress(ModuleNotFoundError): # not available on 3.8/windows + import adbc_driver_sqlite + + return getattr(adbc_driver_sqlite, "__version__", "n/a") + return "n/a" + + +@pytest.mark.write_disk() +@pytest.mark.parametrize("engine", ["adbc", "sqlalchemy"]) +@pytest.mark.skipif( + sys.version_info < (3, 9) or sys.platform == "win32", + reason="adbc_driver_sqlite not available below Python 3.9 / on Windows", +) +def test_write_database_create(engine: DbWriteEngine, tmp_path: Path) -> None: + df = pl.DataFrame( { - "id": [1, 2], + "id": [1234, 5678], "name": ["misc", "other"], - "value": [100.0, -99.0], - "date": ["2020-01-01", "2021-12-31"], + "value": [1000.0, -9999.0], } ) + tmp_path.mkdir(exist_ok=True) + test_db = str(tmp_path / f"test_{engine}.db") + table_name = "test_create" + + df.write_database( + table_name=table_name, + connection=f"sqlite:///{test_db}", + if_exists="replace", + engine=engine, + ) + result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}") + assert_frame_equal(result, df) + + +@pytest.mark.write_disk() +@pytest.mark.parametrize("engine", ["adbc", "sqlalchemy"]) +@pytest.mark.skipif( + sys.version_info < (3, 9) or sys.platform == "win32", + reason="adbc_driver_sqlite not available below Python 3.9 / on Windows", +) +def test_write_database_append(engine: DbWriteEngine, tmp_path: Path) -> None: + df = pl.DataFrame( + { + "key": ["xx", "yy", "zz"], + "value": [123, None, 789], + "other": [5.5, 7.0, None], + } + ) + + tmp_path.mkdir(exist_ok=True) + test_db = str(tmp_path / f"test_{engine}.db") + table_name = "test_append" + + df.write_database( + table_name=table_name, + connection=f"sqlite:///{test_db}", + if_exists="replace", + engine=engine, + ) + + ExpectedError = NotImplementedError if engine == "adbc" else ValueError + with pytest.raises(ExpectedError): + df.write_database( + table_name=table_name, + connection=f"sqlite:///{test_db}", + if_exists="fail", + engine=engine, + ) + + df.write_database( + table_name=table_name, + connection=f"sqlite:///{test_db}", + if_exists="append", + engine=engine, + ) + result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}") + assert_frame_equal(result, pl.concat([df, df])) @pytest.mark.write_disk() @pytest.mark.parametrize( - ("engine", "mode"), + "engine", [ pytest.param( "adbc", - "create", - id="adbc_create", - marks=pytest.mark.skipif( - sys.version_info < (3, 9) or sys.platform == "win32", - reason="adbc_driver_sqlite not available below Python 3.9 / on Windows", - ), - ), - pytest.param( - "adbc", - "append", - id="adbc_append", marks=pytest.mark.skipif( - sys.version_info < (3, 9) or sys.platform == "win32", - reason="adbc_driver_sqlite not available below Python 3.9 / on Windows", + # see: https://github.com/apache/arrow-adbc/issues/1000 + adbc_sqlite_driver_version() == "0.6.0", + reason="ADBC SQLite driver v0.6.0 has a bug with quoted/qualified table names", ), ), - pytest.param( - "sqlalchemy", - "create", - id="sa_create", - ), - pytest.param( - "sqlalchemy", - "append", - id="sa_append", - ), + "sqlalchemy", ], ) -def test_write_database( - engine: DbWriteEngine, mode: DbWriteMode, sample_df: pl.DataFrame, tmp_path: Path +@pytest.mark.skipif( + sys.version_info < (3, 9) or sys.platform == "win32", + reason="adbc_driver_sqlite not available below Python 3.9 / on Windows", +) +def test_write_database_create_quoted_tablename( + engine: DbWriteEngine, tmp_path: Path ) -> None: + df = pl.DataFrame({"col x": [100, 200, 300], "col y": ["a", "b", "c"]}) + tmp_path.mkdir(exist_ok=True) - tmp_db = f"test_{engine}.db" - test_db = str(tmp_path / tmp_db) + test_db = str(tmp_path / f"test_{engine}.db") - # note: test a table name that requires quotes to ensure that we handle - # it correctly (also supply an explicit db schema with/without quotes) - table_name = "test_data" + # table name requires quoting, and is qualified with the implicit 'main' schema + table_name = 'main."test-append"' - sample_df.write_database( + df.write_database( table_name=table_name, connection=f"sqlite:///{test_db}", if_exists="replace", engine=engine, ) - if mode == "append": - sample_df.write_database( - table_name=table_name, - connection=f"sqlite:///{test_db}", - if_exists="append", - engine=engine, - ) - sample_df = pl.concat([sample_df, sample_df]) - result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}") - sample_df = sample_df.with_columns(pl.col("date").cast(pl.Utf8)) - assert_frame_equal(sample_df, result) + assert_frame_equal(result, df) + - # check that some invalid parameters raise errors - for invalid_params in ( - {"table_name": "w.x.y.z"}, - {"if_exists": "crunk", "table_name": table_name}, +def test_write_database_errors() -> None: + # confirm that invalid parameter values raise errors + df = pl.DataFrame({"colx": [1, 2, 3]}) + + with pytest.raises( + ValueError, match="`table_name` appears to be invalid: 'w.x.y.z'" + ): + df.write_database( + connection="sqlite:///:memory:", table_name="w.x.y.z", engine="sqlalchemy" + ) + + with pytest.raises( + NotImplementedError, match="`if_exists = 'fail'` not supported for ADBC engine" ): - with pytest.raises((ValueError, NotImplementedError)): - sample_df.write_database( - connection=f"sqlite:///{test_db}", - engine=engine, - **invalid_params, # type: ignore[arg-type] - ) + df.write_database( + connection="sqlite:///:memory:", + table_name="test_errs", + if_exists="fail", + engine="adbc", + ) + + with pytest.raises(ValueError, match="'do_something' is not valid for if_exists"): + df.write_database( + connection="sqlite:///:memory:", + table_name="main.test_errs", + if_exists="do_something", # type: ignore[arg-type] + engine="sqlalchemy", + )