Skip to content

Commit

Permalink
feat: Improve read_database typing (#19444)
Browse files Browse the repository at this point in the history
  • Loading branch information
wakabame authored Oct 27, 2024
1 parent dc47e92 commit 687811d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
10 changes: 6 additions & 4 deletions py-polars/polars/io/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
except ImportError:
Selectable: TypeAlias = Any # type: ignore[no-redef]

from sqlalchemy.sql.elements import TextClause


@overload
def read_database(
query: str | Selectable,
query: str | TextClause | Selectable,
connection: ConnectionOrCursor | str,
*,
iter_batches: Literal[False] = ...,
Expand All @@ -41,7 +43,7 @@ def read_database(

@overload
def read_database(
query: str | Selectable,
query: str | TextClause | Selectable,
connection: ConnectionOrCursor | str,
*,
iter_batches: Literal[True],
Expand All @@ -54,7 +56,7 @@ def read_database(

@overload
def read_database(
query: str | Selectable,
query: str | TextClause | Selectable,
connection: ConnectionOrCursor | str,
*,
iter_batches: bool,
Expand All @@ -66,7 +68,7 @@ def read_database(


def read_database(
query: str | Selectable,
query: str | TextClause | Selectable,
connection: ConnectionOrCursor | str,
*,
iter_batches: bool = False,
Expand Down
35 changes: 34 additions & 1 deletion py-polars/tests/unit/io/database/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pyarrow as pa
import pytest
import sqlalchemy
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.expression import cast as alchemy_cast

Expand Down Expand Up @@ -383,6 +383,39 @@ def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None:
assert_frame_equal(batches[0], expected)


def test_read_database_alchemy_textclause(tmp_sqlite_db: Path) -> None:
# various flavours of alchemy connection
alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

# establish sqlalchemy "textclause" and validate usage
textclause_query = text("""
SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
FROM test_data
WHERE value < 0
""")

expected = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

for conn in (alchemy_session, alchemy_engine, alchemy_conn):
assert_frame_equal(
pl.read_database(textclause_query, connection=conn),
expected,
)

batches = list(
pl.read_database(
textclause_query,
connection=conn,
iter_batches=True,
batch_size=1,
)
)
assert len(batches) == 1
assert_frame_equal(batches[0], expected)


def test_read_database_parameterised(tmp_sqlite_db: Path) -> None:
# raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
Expand Down

0 comments on commit 687811d

Please sign in to comment.