refactor(duckdb): use .sql instead of .execute in performance-sensitive locations (#8669)
jcrist authored Mar 17, 2024
1 parent 5e10d17 commit aa6aa0c
Showing 1 changed file with 54 additions and 40 deletions.

ibis/backends/duckdb/__init__.py (94 changes: 54 additions & 40 deletions)
@@ -1183,6 +1183,27 @@ def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
 
         super()._run_pre_execute_hooks(expr)
 
+    def _to_duckdb_relation(
+        self,
+        expr: ir.Expr,
+        *,
+        params: Mapping[ir.Scalar, Any] | None = None,
+        limit: int | str | None = None,
+    ):
+        """Preprocess the expr, and return a ``duckdb.DuckDBPyRelation`` object.
+
+        When retrieving in-memory results, it's faster to use `duckdb_con.sql`
+        than `duckdb_con.execute`, as the query planner can take advantage of
+        knowing the output type. Since the relation objects aren't compatible
+        with the dbapi, we choose to only use them in select internal methods
+        where performance might matter, and use the standard
+        `duckdb_con.execute` everywhere else.
+        """
+        self._run_pre_execute_hooks(expr)
+        table_expr = expr.as_table()
+        sql = self.compile(table_expr, limit=limit, params=params)
+        return self.con.sql(sql)
+
     def to_pyarrow_batches(
         self,
         expr: ir.Expr,
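
For context on the docstring's claim: `con.sql` builds a lazy, typed `DuckDBPyRelation`, while `con.execute` eagerly returns a DB-API style cursor. A minimal standalone sketch (not part of the diff) using plain duckdb:

    import duckdb

    con = duckdb.connect()

    # .sql builds a lazy DuckDBPyRelation; the planner knows the requested
    # output format when a consumer such as .arrow(), .df(), or .torch()
    # finally pulls results
    rel = con.sql("SELECT 42 AS answer")
    tbl = rel.arrow()  # pyarrow.Table

    # .execute runs eagerly and returns a DB-API style cursor/connection
    cur = con.execute("SELECT 42 AS answer")
    rows = cur.fetchall()  # [(42,)]
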
@@ -1220,12 +1241,7 @@ def to_pyarrow_batches(
         def batch_producer(cur):
             yield from cur.fetch_record_batch(rows_per_batch=chunk_size)
 
-        # TODO: check that this is still handled correctly
-        # batch_producer keeps the `self.con` member alive long enough to
-        # exhaust the record batch reader, even if the backend or connection
-        # have gone out of scope in the caller
         result = self.raw_sql(sql)
-
         return pa.RecordBatchReader.from_batches(
             expr.as_table().schema().to_pyarrow(), batch_producer(result)
         )
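
The `batch_producer` indirection matters because the generator closes over the cursor, keeping it (and therefore `self.con`) alive until the reader is exhausted. A standalone sketch of the same pattern, assuming a throwaway in-memory query:

    import duckdb
    import pyarrow as pa

    con = duckdb.connect()
    cur = con.execute("SELECT i FROM range(5) AS t(i)")

    def batch_producer(cur, chunk_size=1_000_000):
        # the generator holds a reference to `cur` until fully consumed
        yield from cur.fetch_record_batch(rows_per_batch=chunk_size)

    reader = pa.RecordBatchReader.from_batches(
        pa.schema([("i", pa.int64())]), batch_producer(cur)
    )
    assert reader.read_all().num_rows == 5
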
@@ -1238,14 +1254,40 @@ def to_pyarrow(
         limit: int | str | None = None,
         **_: Any,
     ) -> pa.Table:
-        self._run_pre_execute_hooks(expr)
-        table = expr.as_table()
-        sql = self.compile(table, limit=limit, params=params)
-
-        with self._safe_raw_sql(sql) as cur:
-            table = cur.fetch_arrow_table()
-
-        return expr.__pyarrow_result__(table)
+        table = self._to_duckdb_relation(expr, params=params, limit=limit).arrow()
+        return expr.__pyarrow_result__(table)
+
+    def execute(
+        self,
+        expr: ir.Expr,
+        params: Mapping | None = None,
+        limit: str | None = "default",
+        **_: Any,
+    ) -> Any:
+        """Execute an expression."""
+        import pandas as pd
+        import pyarrow.types as pat
+
+        table = self._to_duckdb_relation(expr, params=params, limit=limit).arrow()
+
+        df = pd.DataFrame(
+            {
+                name: (
+                    col.to_pylist()
+                    if (
+                        pat.is_nested(col.type)
+                        or
+                        # pyarrow / duckdb type null literals columns as int32?
+                        # but calling `to_pylist()` will render it as None
+                        col.null_count
+                    )
+                    else col.to_pandas(timestamp_as_object=True)
+                )
+                for name, col in zip(table.column_names, table.columns)
+            }
+        )
+        df = DuckDBPandasData.convert_table(df, expr.as_table().schema())
+        return expr.__pandas_result__(df)
 
     @util.experimental
     def to_torch(
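
The `to_pylist()` branch in the new `execute` exists because a bare NULL literal comes back from duckdb as an all-null int32 column, and `to_pandas()` would decay those nulls to float NaN. A standalone pyarrow sketch:

    import pyarrow as pa

    # stand-in for a `SELECT NULL` column as duckdb returns it
    col = pa.array([None, None], type=pa.int32())

    print(col.to_pandas().tolist())  # [nan, nan]: nulls decay to float NaN
    print(col.to_pylist())           # [None, None]: nulls survive as None
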
@@ -1275,9 +1317,7 @@ def to_torch(
             A dictionary of torch tensors, keyed by column name.
         """
-        compiled = self.compile(expr, limit=limit, params=params, **kwargs)
-        with self._safe_raw_sql(compiled) as cur:
-            return cur.torch()
+        return self._to_duckdb_relation(expr, params=params, limit=limit).torch()
 
     @util.experimental
     def to_parquet(
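
`DuckDBPyRelation.torch()` converts the result directly into the dict of torch tensors that `to_torch` returns; a standalone sketch (assumes the torch package is installed):

    import duckdb

    con = duckdb.connect()
    tensors = con.sql("SELECT x FROM range(3) AS t(x)").torch()
    print(tensors["x"])  # tensor([0, 1, 2])
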
@@ -1377,32 +1417,6 @@ def to_csv(
         with self._safe_raw_sql(copy_cmd):
             pass
 
-    def _fetch_from_cursor(
-        self, cursor: duckdb.DuckDBPyConnection, schema: sch.Schema
-    ) -> pd.DataFrame:
-        import pandas as pd
-        import pyarrow.types as pat
-
-        table = cursor.fetch_arrow_table()
-
-        df = pd.DataFrame(
-            {
-                name: (
-                    col.to_pylist()
-                    if (
-                        pat.is_nested(col.type)
-                        or
-                        # pyarrow / duckdb type null literals columns as int32?
-                        # but calling `to_pylist()` will render it as None
-                        col.null_count
-                    )
-                    else col.to_pandas(timestamp_as_object=True)
-                )
-                for name, col in zip(table.column_names, table.columns)
-            }
-        )
-        return DuckDBPandasData.convert_table(df, schema)
-
     def _get_schema_using_query(self, query: str) -> sch.Schema:
         with self._safe_raw_sql(f"DESCRIBE {query}") as cur:
             rows = cur.fetch_arrow_table()
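`_get_schema_using_query` stays on the cursor path: it parses the Arrow table that duckdb's DESCRIBE returns, one row per output column. A standalone sketch:

    import duckdb

    con = duckdb.connect()
    rows = con.execute("DESCRIBE SELECT 1 AS n, 'a' AS s").fetch_arrow_table()
    print(rows.column("column_name").to_pylist())  # ['n', 's']
    print(rows.column("column_type").to_pylist())  # ['INTEGER', 'VARCHAR']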
