Enhance to_sql, read_sql
MJuddBooth committed Aug 7, 2023
1 parent 861b2fb commit dbdba6f
Showing 1 changed file with 84 additions and 3 deletions.
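For orientation (not part of the diff): a minimal round-trip sketch of the two new parameters, assuming this branch is installed. The module-level to_sql is used because DataFrame.to_sql is not touched by this commit; the engine URL, table name, and data are placeholders.

import pandas as pd
from sqlalchemy import create_engine
from pandas.io.sql import to_sql

engine = create_engine("sqlite://")
df = pd.DataFrame({"value": [1.0, 2.0]}, index=pd.Index([10, 20], name="id"))

# create_pk=True: build the table's PRIMARY KEY from the frame's index column(s)
to_sql(df, "measurements", engine, create_pk=True)

# infer_index=True: set the returned index from the table's primary key
roundtrip = pd.read_sql_table("measurements", engine, infer_index=True)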
87 changes: 84 additions & 3 deletions pandas/io/sql.py
@@ -240,6 +240,7 @@ def read_sql_table(
columns: list[str] | None = ...,
chunksize: None = ...,
dtype_backend: DtypeBackend | lib.NoDefault = ...,
infer_index: bool = False,
) -> DataFrame:
...

@@ -255,6 +256,7 @@ def read_sql_table(
columns: list[str] | None = ...,
chunksize: int = ...,
dtype_backend: DtypeBackend | lib.NoDefault = ...,
infer_index: bool = False,
) -> Iterator[DataFrame]:
...

@@ -269,6 +271,7 @@ def read_sql_table(
columns: list[str] | None = None,
chunksize: int | None = None,
dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default,
infer_index: bool = False,
) -> DataFrame | Iterator[DataFrame]:
"""
Read SQL database table into a DataFrame.
@@ -315,6 +318,9 @@ def read_sql_table(
DataFrame.
.. versionadded:: 2.0
infer_index : bool, default False
    If True and the table has a primary key, those columns will be
    used to set the index on the returned DataFrame.
Returns
-------
@@ -353,6 +359,7 @@ def read_sql_table(
columns=columns,
chunksize=chunksize,
dtype_backend=dtype_backend,
infer_index=infer_index,
)

if table is not None:
@@ -507,6 +514,7 @@ def read_sql(
chunksize: None = ...,
dtype_backend: DtypeBackend | lib.NoDefault = ...,
dtype: DtypeArg | None = None,
infer_index: bool = False,
) -> DataFrame:
...

@@ -523,6 +531,7 @@ def read_sql(
chunksize: int = ...,
dtype_backend: DtypeBackend | lib.NoDefault = ...,
dtype: DtypeArg | None = None,
infer_index: bool = False,
) -> Iterator[DataFrame]:
...

@@ -538,6 +547,7 @@ def read_sql(
chunksize: int | None = None,
dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default,
dtype: DtypeArg | None = None,
infer_index: bool = False,
) -> DataFrame | Iterator[DataFrame]:
"""
Read SQL query or database table into a DataFrame.
@@ -601,6 +611,9 @@ def read_sql(
The argument is ignored if a table is passed instead of a query.
.. versionadded:: 2.0.0
infer_index : bool, default False
    If True and reading from a table, infer the index columns from
    any primary keys that are present on the table.
Returns
-------
@@ -677,6 +690,7 @@ def read_sql(
columns=columns,
chunksize=chunksize,
dtype_backend=dtype_backend,
infer_index=infer_index,
)
else:
return pandas_sql.read_query(
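A sketch of the dispatch above, with placeholder names: infer_index only takes effect when read_sql is given a table name on a SQLAlchemy connectable; for a SQL query the keyword is accepted but ignored, as the docstring notes.

import pandas as pd
from sqlalchemy import create_engine

engine = create_engine("sqlite:///example.db")  # assumed existing database
by_table = pd.read_sql("measurements", engine, infer_index=True)  # primary key -> index
by_query = pd.read_sql("SELECT * FROM measurements", engine, infer_index=True)  # ignored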
@@ -702,6 +716,7 @@ def to_sql(
chunksize: int | None = None,
dtype: DtypeArg | None = None,
method: Literal["multi"] | Callable | None = None,
create_pk: bool = False,
engine: str = "auto",
**engine_kwargs,
) -> int | None:
@@ -748,6 +763,9 @@ def to_sql(
Details and a sample callable implementation can be found in the
section :ref:`insert method <io.sql.method>`.
create_pk : bool, default False
    If True, a primary key composed of the index column(s) will be
    created on the table.
engine : {'auto', 'sqlalchemy'}, default 'auto'
SQL engine library to use. If 'auto', then the option
``io.sql.engine`` is used. The default ``io.sql.engine``
@@ -795,6 +813,7 @@ def to_sql(
chunksize=chunksize,
dtype=dtype,
method=method,
create_pk=create_pk,
engine=engine,
**engine_kwargs,
)
@@ -877,7 +896,7 @@ def __init__(
pandas_sql_engine,
frame=None,
index: bool | str | list[str] | None = True,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
prefix: str = "pandas",
index_label=None,
schema=None,
@@ -928,6 +947,8 @@ def create(self) -> None:
if self.if_exists == "replace":
self.pd_sql.drop_table(self.name, self.schema)
self._execute_create()
elif self.if_exists == "truncate":
self.pd_sql.truncate_table(self.name, self.schema)
elif self.if_exists == "append":
pass
else:
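A usage sketch of the new "truncate" mode handled above, with placeholder names; it assumes the rest of the branch also accepts "truncate" wherever if_exists is validated (not shown in this hunk). The existing table is kept and only its rows are removed before the insert.

import sqlite3
import pandas as pd

con = sqlite3.connect(":memory:")
df = pd.DataFrame({"x": [1, 2]})
df.to_sql("t", con, index=False)                        # first write creates the table
df.to_sql("t", con, if_exists="truncate", index=False)  # keep the table, clear the rows first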
@@ -1074,6 +1095,7 @@ def _query_iterator(
coerce_float: bool = True,
parse_dates=None,
dtype_backend: DtypeBackend | Literal["numpy"] = "numpy",
infer_index: bool = False,
):
"""Return generator through chunked result set."""
has_read_data = False
@@ -1098,7 +1120,10 @@

if self.index is not None:
self.frame.set_index(self.index, inplace=True)

elif infer_index:
index = [c.name for c in self.table.columns if c.primary_key]
if index:
self.frame.set_index(index, inplace=True)
yield self.frame

def read(
Expand All @@ -1109,6 +1134,7 @@ def read(
columns=None,
chunksize: int | None = None,
dtype_backend: DtypeBackend | Literal["numpy"] = "numpy",
infer_index: bool = False,
) -> DataFrame | Iterator[DataFrame]:
from sqlalchemy import select

@@ -1132,6 +1158,7 @@ def read(
coerce_float=coerce_float,
parse_dates=parse_dates,
dtype_backend=dtype_backend,
infer_index=infer_index,
)
else:
data = result.fetchall()
@@ -1145,6 +1172,10 @@ def read(

if self.index is not None:
self.frame.set_index(self.index, inplace=True)
elif infer_index:
index = [c.name for c in self.table.columns if c.primary_key]
if index:
self.frame.set_index(index, inplace=True)

return self.frame
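The primary-key inference used in _query_iterator and read above boils down to plain SQLAlchemy column metadata. A standalone sketch with a throwaway in-memory table:

from sqlalchemy import Column, Float, Integer, MetaData, Table, create_engine

engine = create_engine("sqlite://")
meta = MetaData()
orders = Table(
    "orders",
    meta,
    Column("order_id", Integer, primary_key=True),
    Column("amount", Float),
)
meta.create_all(engine)

# the same test applied to self.table.columns above
pk_cols = [c.name for c in orders.columns if c.primary_key]
assert pk_cols == ["order_id"]  # these columns become the DataFrame index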

@@ -1401,6 +1432,7 @@ def read_table(
schema: str | None = None,
chunksize: int | None = None,
dtype_backend: DtypeBackend | Literal["numpy"] = "numpy",
infer_index: bool = False,
) -> DataFrame | Iterator[DataFrame]:
raise NotImplementedError

@@ -1430,6 +1462,7 @@ def to_sql(
chunksize: int | None = None,
dtype: DtypeArg | None = None,
method: Literal["multi"] | Callable | None = None,
create_pk: bool = False,
engine: str = "auto",
**engine_kwargs,
) -> int | None:
@@ -1609,6 +1642,7 @@ def read_table(
schema: str | None = None,
chunksize: int | None = None,
dtype_backend: DtypeBackend | Literal["numpy"] = "numpy",
infer_index: bool = False,
) -> DataFrame | Iterator[DataFrame]:
"""
Read SQL database table into a DataFrame.
@@ -1651,6 +1685,10 @@ def read_table(
DataFrame.
.. versionadded:: 2.0
infer_index : bool, default False
    If True and the table has a primary key, those columns will be
    used to set the index on the returned DataFrame.
Returns
-------
@@ -1673,6 +1711,7 @@ def read_table(
columns=columns,
chunksize=chunksize,
dtype_backend=dtype_backend,
infer_index=infer_index,
)

@staticmethod
@@ -1808,11 +1847,12 @@ def prep_table(
self,
frame,
name: str,
if_exists: Literal["fail", "replace", "append"] = "fail",
if_exists: Literal["fail", "replace", "append", "truncate"] = "fail",
index: bool | str | list[str] | None = True,
index_label=None,
schema=None,
dtype: DtypeArg | None = None,
create_pk: bool = False,
) -> SQLTable:
"""
Prepares table in the database for data insertion. Creates it if needed, etc.
@@ -1849,6 +1889,23 @@ def prep_table(
schema=schema,
dtype=dtype,
)

if create_pk:
# Somewhat wasteful, but re-create the table with keys set from
# the index. Then we can use table.index; otherwise we would need
# to recreate a bunch of logic.
keys = table.index
table = SQLTable(
name,
self,
frame=frame,
index=index,
if_exists=if_exists,
index_label=index_label,
schema=schema,
keys=keys,
dtype=dtype,
)
table.create()
return table
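In effect, create_pk re-creates the SQLTable with keys taken from the frame's index, so the emitted DDL carries a PRIMARY KEY constraint. A quick way to see comparable DDL with the existing public helper (table name and columns are placeholders):

import pandas as pd
from pandas.io.sql import get_schema

df = pd.DataFrame({"id": [1, 2], "value": [0.5, 1.5]})
print(get_schema(df, "measurements", keys=["id"]))
# roughly: CREATE TABLE "measurements" ("id" INTEGER, "value" REAL,
#          CONSTRAINT measurements_pk PRIMARY KEY ("id"))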

@@ -1892,6 +1949,7 @@ def to_sql(
chunksize: int | None = None,
dtype: DtypeArg | None = None,
method: Literal["multi"] | Callable | None = None,
create_pk: bool = False,
engine: str = "auto",
**engine_kwargs,
) -> int | None:
@@ -1933,6 +1991,8 @@ def to_sql(
Details and a sample callable implementation can be found in the
section :ref:`insert method <io.sql.method>`.
create_pk : bool, default False
    If True, use the index column(s) to create a primary key on the table.
engine : {'auto', 'sqlalchemy'}, default 'auto'
SQL engine library to use. If 'auto', then the option
``io.sql.engine`` is used. The default ``io.sql.engine``
@@ -1953,6 +2013,7 @@ def to_sql(
index_label=index_label,
schema=schema,
dtype=dtype,
create_pk=create_pk,
)

total_inserted = sql_engine.insert_records(
@@ -2003,6 +2064,22 @@ def drop_table(self, table_name: str, schema: str | None = None) -> None:
self.get_table(table_name, schema).drop(bind=self.con)
self.meta.clear()

def truncate_table(self, table_name: str, schema: str | None = None) -> None:
    schema = schema or self.meta.schema
    if self.has_table(table_name, schema):
        self.meta.reflect(
            bind=self.con, only=[table_name], schema=schema, views=True
        )
        with self.run_transaction():
            table = self.get_table(table_name, schema)
            # FIXME: there should be a more natural way to properly
            # quote the table and schema names
            fullname = ".".join(
                _get_valid_sqlite_name(x)
                for x in [table.schema, table.name]
                if x is not None
            )
            self.execute(f"TRUNCATE TABLE {fullname} RESTART IDENTITY")
        self.meta.clear()

def _create_sql_schema(
self,
frame: DataFrame,
@@ -2463,6 +2540,10 @@ def drop_table(self, name: str, schema: str | None = None) -> None:
drop_sql = f"DROP TABLE {_get_valid_sqlite_name(name)}"
self.execute(drop_sql)

def truncate_table(self, name: str, schema: str | None = None) -> None:
delete_sql = f"DELETE FROM {_get_valid_sqlite_name(name)}"
self.execute(delete_sql)
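SQLite has no TRUNCATE statement, so the fallback driver approximates it with DELETE FROM. One behavioral difference worth flagging (an observation, not part of the diff): this does not reset an AUTOINCREMENT counter, which lives in sqlite_sequence.

import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE t (id INTEGER PRIMARY KEY AUTOINCREMENT, x REAL)")
con.execute("INSERT INTO t (x) VALUES (1.0)")
con.execute("DELETE FROM t")  # what truncate_table issues for SQLite
# to also reset the counter one would need:
# con.execute("DELETE FROM sqlite_sequence WHERE name = 't'")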

def _create_sql_schema(
self,
frame,
