Skip to content

Commit

Permalink
test(python): refactored write_database unit tests to properly sepa…
Browse files Browse the repository at this point in the history
…rate concerns (#10773)
  • Loading branch information
alexander-beedie authored Aug 29, 2023
1 parent d7a7c3e commit 4963a43
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 62 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3408,7 +3408,7 @@ def write_database(
if engine == "adbc":
if if_exists == "fail":
raise NotImplementedError(
"`if_exists` not yet supported with engine ADBC"
"`if_exists = 'fail'` not supported for ADBC engine"
)
elif if_exists == "replace":
mode = "create"
Expand Down
185 changes: 124 additions & 61 deletions py-polars/tests/unit/io/test_database_write.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING
from contextlib import suppress
from typing import TYPE_CHECKING, Any

import pytest

Expand All @@ -11,93 +12,155 @@
if TYPE_CHECKING:
from pathlib import Path

from polars.type_aliases import DbWriteEngine, DbWriteMode
from polars.type_aliases import DbWriteEngine


@pytest.fixture()
def sample_df() -> pl.DataFrame:
return pl.DataFrame(
def adbc_sqlite_driver_version(*args: Any, **kwargs: Any) -> str:
with suppress(ModuleNotFoundError): # not available on 3.8/windows
import adbc_driver_sqlite

return getattr(adbc_driver_sqlite, "__version__", "n/a")
return "n/a"


@pytest.mark.write_disk()
@pytest.mark.parametrize("engine", ["adbc", "sqlalchemy"])
@pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc_driver_sqlite not available below Python 3.9 / on Windows",
)
def test_write_database_create(engine: DbWriteEngine, tmp_path: Path) -> None:
df = pl.DataFrame(
{
"id": [1, 2],
"id": [1234, 5678],
"name": ["misc", "other"],
"value": [100.0, -99.0],
"date": ["2020-01-01", "2021-12-31"],
"value": [1000.0, -9999.0],
}
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
table_name = "test_create"

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
if_exists="replace",
engine=engine,
)
result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}")
assert_frame_equal(result, df)


@pytest.mark.write_disk()
@pytest.mark.parametrize("engine", ["adbc", "sqlalchemy"])
@pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc_driver_sqlite not available below Python 3.9 / on Windows",
)
def test_write_database_append(engine: DbWriteEngine, tmp_path: Path) -> None:
df = pl.DataFrame(
{
"key": ["xx", "yy", "zz"],
"value": [123, None, 789],
"other": [5.5, 7.0, None],
}
)

tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
table_name = "test_append"

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
if_exists="replace",
engine=engine,
)

ExpectedError = NotImplementedError if engine == "adbc" else ValueError
with pytest.raises(ExpectedError):
df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
if_exists="fail",
engine=engine,
)

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
if_exists="append",
engine=engine,
)
result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}")
assert_frame_equal(result, pl.concat([df, df]))


@pytest.mark.write_disk()
@pytest.mark.parametrize(
("engine", "mode"),
"engine",
[
pytest.param(
"adbc",
"create",
id="adbc_create",
marks=pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc_driver_sqlite not available below Python 3.9 / on Windows",
),
),
pytest.param(
"adbc",
"append",
id="adbc_append",
marks=pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc_driver_sqlite not available below Python 3.9 / on Windows",
# see: https://github.com/apache/arrow-adbc/issues/1000
adbc_sqlite_driver_version() == "0.6.0",
reason="ADBC SQLite driver v0.6.0 has a bug with quoted/qualified table names",
),
),
pytest.param(
"sqlalchemy",
"create",
id="sa_create",
),
pytest.param(
"sqlalchemy",
"append",
id="sa_append",
),
"sqlalchemy",
],
)
def test_write_database(
engine: DbWriteEngine, mode: DbWriteMode, sample_df: pl.DataFrame, tmp_path: Path
@pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc_driver_sqlite not available below Python 3.9 / on Windows",
)
def test_write_database_create_quoted_tablename(
engine: DbWriteEngine, tmp_path: Path
) -> None:
df = pl.DataFrame({"col x": [100, 200, 300], "col y": ["a", "b", "c"]})

tmp_path.mkdir(exist_ok=True)
tmp_db = f"test_{engine}.db"
test_db = str(tmp_path / tmp_db)
test_db = str(tmp_path / f"test_{engine}.db")

# note: test a table name that requires quotes to ensure that we handle
# it correctly (also supply an explicit db schema with/without quotes)
table_name = "test_data"
# table name requires quoting, and is qualified with the implicit 'main' schema
table_name = 'main."test-append"'

sample_df.write_database(
df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
if_exists="replace",
engine=engine,
)
if mode == "append":
sample_df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
if_exists="append",
engine=engine,
)
sample_df = pl.concat([sample_df, sample_df])

result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}")
sample_df = sample_df.with_columns(pl.col("date").cast(pl.Utf8))
assert_frame_equal(sample_df, result)
assert_frame_equal(result, df)


# check that some invalid parameters raise errors
for invalid_params in (
{"table_name": "w.x.y.z"},
{"if_exists": "crunk", "table_name": table_name},
def test_write_database_errors() -> None:
# confirm that invalid parameter values raise errors
df = pl.DataFrame({"colx": [1, 2, 3]})

with pytest.raises(
ValueError, match="`table_name` appears to be invalid: 'w.x.y.z'"
):
df.write_database(
connection="sqlite:///:memory:", table_name="w.x.y.z", engine="sqlalchemy"
)

with pytest.raises(
NotImplementedError, match="`if_exists = 'fail'` not supported for ADBC engine"
):
with pytest.raises((ValueError, NotImplementedError)):
sample_df.write_database(
connection=f"sqlite:///{test_db}",
engine=engine,
**invalid_params, # type: ignore[arg-type]
)
df.write_database(
connection="sqlite:///:memory:",
table_name="test_errs",
if_exists="fail",
engine="adbc",
)

with pytest.raises(ValueError, match="'do_something' is not valid for if_exists"):
df.write_database(
connection="sqlite:///:memory:",
table_name="main.test_errs",
if_exists="do_something", # type: ignore[arg-type]
engine="sqlalchemy",
)

0 comments on commit 4963a43

Please sign in to comment.