diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 82fb98615100f..11f95ff104767 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -1,20 +1,3 @@ -"""SQL io tests - -The SQL tests are broken down in different classes: - -- `PandasSQLTest`: base class with common methods for all test classes -- Tests for the public API (only tests with sqlite3) - - `_TestSQLApi` base class - - `TestSQLApi`: test the public API with sqlalchemy engine - - `TestSQLiteFallbackApi`: test the public API with a sqlite DBAPI - connection -- Tests for the different SQL flavors (flavor specific type conversions) - - Tests for the sqlalchemy mode: `_TestSQLAlchemy` is the base class with - common methods. The different tested flavors (sqlite3, MySQL, - PostgreSQL) derive from the base class - - Tests for the fallback mode (`TestSQLiteFallback`) - -""" from __future__ import annotations import contextlib @@ -29,6 +12,7 @@ from io import StringIO from pathlib import Path import sqlite3 +from typing import TYPE_CHECKING import uuid import numpy as np @@ -69,13 +53,9 @@ read_sql_table, ) -try: +if TYPE_CHECKING: import sqlalchemy - SQLALCHEMY_INSTALLED = True -except ImportError: - SQLALCHEMY_INSTALLED = False - @pytest.fixture def sql_strings(): @@ -426,13 +406,16 @@ def drop_table( conn.commit() else: with conn.begin() as con: - sql.SQLDatabase(con).drop_table(table_name) + with sql.SQLDatabase(con) as db: + db.drop_table(table_name) def drop_view( view_name: str, conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, ): + import sqlalchemy + if isinstance(conn, sqlite3.Connection): conn.execute(f"DROP VIEW IF EXISTS {sql._get_valid_sqlite_name(view_name)}") conn.commit() @@ -1151,10 +1134,9 @@ def test_read_sql_iris_parameter(conn, request, sql_strings, flavor): conn = request.getfixturevalue(conn) query = sql_strings["read_parameters"][flavor(conn_name)] params = ("Iris-setosa", 5.1) - pandasSQL = pandasSQL_builder(conn) - - with pandasSQL.run_transaction(): - iris_frame = pandasSQL.read_query(query, params=params) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_frame = pandasSQL.read_query(query, params=params) check_iris_frame(iris_frame) @@ -1179,9 +1161,9 @@ def test_read_sql_iris_no_parameter_with_percent(conn, request, sql_strings, fla conn = request.getfixturevalue(conn) query = sql_strings["read_no_parameters_with_percent"][flavor(conn_name)] - pandasSQL = pandasSQL_builder(conn) - with pandasSQL.run_transaction(): - iris_frame = pandasSQL.read_query(query, params=None) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_frame = pandasSQL.read_query(query, params=None) check_iris_frame(iris_frame) @@ -1914,7 +1896,8 @@ def test_warning_case_insensitive_table_name(conn, request, test_frame1): r"sensitivity issues. Consider using lower case table names." ), ): - sql.SQLDatabase(conn).check_case_sensitive("TABLE1", "") + with sql.SQLDatabase(conn) as db: + db.check_case_sensitive("TABLE1", "") # Test that the warning is certainly NOT triggered in a normal case. with tm.assert_produces_warning(None): @@ -1930,10 +1913,10 @@ def test_sqlalchemy_type_mapping(conn, request): df = DataFrame( {"time": to_datetime(["2014-12-12 01:54", "2014-12-11 02:54"], utc=True)} ) - db = sql.SQLDatabase(conn) - table = sql.SQLTable("test_type", db, frame=df) - # GH 9086: TIMESTAMP is the suggested type for datetimes with timezones - assert isinstance(table.table.c["time"].type, TIMESTAMP) + with sql.SQLDatabase(conn) as db: + table = sql.SQLTable("test_type", db, frame=df) + # GH 9086: TIMESTAMP is the suggested type for datetimes with timezones + assert isinstance(table.table.c["time"].type, TIMESTAMP) @pytest.mark.parametrize("conn", sqlalchemy_connectable) @@ -1961,10 +1944,10 @@ def test_sqlalchemy_integer_mapping(conn, request, integer, expected): # GH35076 Map pandas integer to optimal SQLAlchemy integer type conn = request.getfixturevalue(conn) df = DataFrame([0, 1], columns=["a"], dtype=integer) - db = sql.SQLDatabase(conn) - table = sql.SQLTable("test_type", db, frame=df) + with sql.SQLDatabase(conn) as db: + table = sql.SQLTable("test_type", db, frame=df) - result = str(table.table.c.a.type) + result = str(table.table.c.a.type) assert result == expected @@ -1974,16 +1957,16 @@ def test_sqlalchemy_integer_overload_mapping(conn, request, integer): conn = request.getfixturevalue(conn) # GH35076 Map pandas integer to optimal SQLAlchemy integer type df = DataFrame([0, 1], columns=["a"], dtype=integer) - db = sql.SQLDatabase(conn) - with pytest.raises( - ValueError, match="Unsigned 64 bit integer datatype is not supported" - ): - sql.SQLTable("test_type", db, frame=df) + with sql.SQLDatabase(conn) as db: + with pytest.raises( + ValueError, match="Unsigned 64 bit integer datatype is not supported" + ): + sql.SQLTable("test_type", db, frame=df) -@pytest.mark.skipif(not SQLALCHEMY_INSTALLED, reason="fails without SQLAlchemy") @pytest.mark.parametrize("conn", all_connectable) def test_database_uri_string(conn, request, test_frame1): + pytest.importorskip("sqlalchemy") conn = request.getfixturevalue(conn) # Test read_sql and .to_sql method with a database URI (GH10654) # db_uri = 'sqlite:///:memory:' # raises @@ -2003,9 +1986,9 @@ def test_database_uri_string(conn, request, test_frame1): @td.skip_if_installed("pg8000") -@pytest.mark.skipif(not SQLALCHEMY_INSTALLED, reason="fails without SQLAlchemy") @pytest.mark.parametrize("conn", all_connectable) def test_pg8000_sqlalchemy_passthrough_error(conn, request): + pytest.importorskip("sqlalchemy") conn = request.getfixturevalue(conn) # using driver that will not be installed on CI to trigger error # in sqlalchemy.create_engine -> test passing of this error to user @@ -2076,7 +2059,7 @@ def test_sql_open_close(test_frame3): tm.assert_frame_equal(test_frame3, result) -@pytest.mark.skipif(SQLALCHEMY_INSTALLED, reason="SQLAlchemy is installed") +@td.skip_if_installed("sqlalchemy") def test_con_string_import_error(): conn = "mysql://root@localhost/pandas" msg = "Using URI string without sqlalchemy installed" @@ -2084,7 +2067,7 @@ def test_con_string_import_error(): sql.read_sql("SELECT * FROM iris", conn) -@pytest.mark.skipif(SQLALCHEMY_INSTALLED, reason="SQLAlchemy is installed") +@td.skip_if_installed("sqlalchemy") def test_con_unknown_dbapi2_class_does_not_error_without_sql_alchemy_installed(): class MockSqliteConnection: def __init__(self, *args, **kwargs) -> None: @@ -2167,20 +2150,20 @@ def test_drop_table(conn, request): from sqlalchemy import inspect temp_frame = DataFrame({"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}) - pandasSQL = sql.SQLDatabase(conn) - with pandasSQL.run_transaction(): - assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 + with sql.SQLDatabase(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 - insp = inspect(conn) - assert insp.has_table("temp_frame") + insp = inspect(conn) + assert insp.has_table("temp_frame") - with pandasSQL.run_transaction(): - pandasSQL.drop_table("temp_frame") - try: - insp.clear_cache() # needed with SQLAlchemy 2.0, unavailable prior - except AttributeError: - pass - assert not insp.has_table("temp_frame") + with pandasSQL.run_transaction(): + pandasSQL.drop_table("temp_frame") + try: + insp.clear_cache() # needed with SQLAlchemy 2.0, unavailable prior + except AttributeError: + pass + assert not insp.has_table("temp_frame") @pytest.mark.parametrize("conn", all_connectable) @@ -2206,9 +2189,10 @@ def test_roundtrip(conn, request, test_frame1): def test_execute_sql(conn, request): conn = request.getfixturevalue(conn) pandasSQL = pandasSQL_builder(conn) - with pandasSQL.run_transaction(): - iris_results = pandasSQL.execute("SELECT * FROM iris") - row = iris_results.fetchone() + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_results = pandasSQL.execute("SELECT * FROM iris") + row = iris_results.fetchone() tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"]) @@ -2616,10 +2600,10 @@ def test_to_sql_save_index(conn, request): [(1, 2.1, "line1"), (2, 1.5, "line2")], columns=["A", "B", "C"], index=["A"] ) - pandasSQL = pandasSQL_builder(conn) tbl_name = "test_to_sql_saves_index" - with pandasSQL.run_transaction(): - assert pandasSQL.to_sql(df, tbl_name) == 2 + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(df, tbl_name) == 2 if conn_name in {"sqlite_buildin", "sqlite_str"}: ixs = sql.read_sql_query( @@ -2648,55 +2632,55 @@ def test_transactions(conn, request): conn = request.getfixturevalue(conn) stmt = "CREATE TABLE test_trans (A INT, B TEXT)" - pandasSQL = pandasSQL_builder(conn) if conn_name != "sqlite_buildin": from sqlalchemy import text stmt = text(stmt) - with pandasSQL.run_transaction() as trans: - trans.execute(stmt) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction() as trans: + trans.execute(stmt) @pytest.mark.parametrize("conn", all_connectable) def test_transaction_rollback(conn, request): conn = request.getfixturevalue(conn) - pandasSQL = pandasSQL_builder(conn) - with pandasSQL.run_transaction() as trans: - stmt = "CREATE TABLE test_trans (A INT, B TEXT)" - if isinstance(pandasSQL, SQLiteDatabase): - trans.execute(stmt) - else: - from sqlalchemy import text + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction() as trans: + stmt = "CREATE TABLE test_trans (A INT, B TEXT)" + if isinstance(pandasSQL, SQLiteDatabase): + trans.execute(stmt) + else: + from sqlalchemy import text - stmt = text(stmt) - trans.execute(stmt) + stmt = text(stmt) + trans.execute(stmt) - class DummyException(Exception): - pass + class DummyException(Exception): + pass - # Make sure when transaction is rolled back, no rows get inserted - ins_sql = "INSERT INTO test_trans (A,B) VALUES (1, 'blah')" - if isinstance(pandasSQL, SQLDatabase): - from sqlalchemy import text + # Make sure when transaction is rolled back, no rows get inserted + ins_sql = "INSERT INTO test_trans (A,B) VALUES (1, 'blah')" + if isinstance(pandasSQL, SQLDatabase): + from sqlalchemy import text + + ins_sql = text(ins_sql) + try: + with pandasSQL.run_transaction() as trans: + trans.execute(ins_sql) + raise DummyException("error") + except DummyException: + # ignore raised exception + pass + with pandasSQL.run_transaction(): + res = pandasSQL.read_query("SELECT * FROM test_trans") + assert len(res) == 0 - ins_sql = text(ins_sql) - try: + # Make sure when transaction is committed, rows do get inserted with pandasSQL.run_transaction() as trans: trans.execute(ins_sql) - raise DummyException("error") - except DummyException: - # ignore raised exception - pass - with pandasSQL.run_transaction(): - res = pandasSQL.read_query("SELECT * FROM test_trans") - assert len(res) == 0 - - # Make sure when transaction is committed, rows do get inserted - with pandasSQL.run_transaction() as trans: - trans.execute(ins_sql) - res2 = pandasSQL.read_query("SELECT * FROM test_trans") - assert len(res2) == 1 + res2 = pandasSQL.read_query("SELECT * FROM test_trans") + assert len(res2) == 1 @pytest.mark.parametrize("conn", sqlalchemy_connectable) @@ -2975,9 +2959,9 @@ def test_invalid_engine(conn, request, test_frame1): conn = request.getfixturevalue(conn) msg = "engine must be one of 'auto', 'sqlalchemy'" - pandasSQL = pandasSQL_builder(conn) - with pytest.raises(ValueError, match=msg): - pandasSQL.to_sql(test_frame1, "test_frame1", engine="bad_engine") + with pandasSQL_builder(conn) as pandasSQL: + with pytest.raises(ValueError, match=msg): + pandasSQL.to_sql(test_frame1, "test_frame1", engine="bad_engine") @pytest.mark.parametrize("conn", all_connectable) @@ -2985,10 +2969,10 @@ def test_to_sql_with_sql_engine(conn, request, test_frame1): """`to_sql` with the `engine` param""" # mostly copied from this class's `_to_sql()` method conn = request.getfixturevalue(conn) - pandasSQL = pandasSQL_builder(conn) - with pandasSQL.run_transaction(): - assert pandasSQL.to_sql(test_frame1, "test_frame1", engine="auto") == 4 - assert pandasSQL.has_table("test_frame1") + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame1", engine="auto") == 4 + assert pandasSQL.has_table("test_frame1") num_entries = len(test_frame1) num_rows = count_rows(conn, "test_frame1") @@ -3000,10 +2984,10 @@ def test_options_sqlalchemy(conn, request, test_frame1): # use the set option conn = request.getfixturevalue(conn) with pd.option_context("io.sql.engine", "sqlalchemy"): - pandasSQL = pandasSQL_builder(conn) - with pandasSQL.run_transaction(): - assert pandasSQL.to_sql(test_frame1, "test_frame1") == 4 - assert pandasSQL.has_table("test_frame1") + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame1") == 4 + assert pandasSQL.has_table("test_frame1") num_entries = len(test_frame1) num_rows = count_rows(conn, "test_frame1") @@ -3015,18 +2999,18 @@ def test_options_auto(conn, request, test_frame1): # use the set option conn = request.getfixturevalue(conn) with pd.option_context("io.sql.engine", "auto"): - pandasSQL = pandasSQL_builder(conn) - with pandasSQL.run_transaction(): - assert pandasSQL.to_sql(test_frame1, "test_frame1") == 4 - assert pandasSQL.has_table("test_frame1") + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame1") == 4 + assert pandasSQL.has_table("test_frame1") num_entries = len(test_frame1) num_rows = count_rows(conn, "test_frame1") assert num_rows == num_entries -@pytest.mark.skipif(not SQLALCHEMY_INSTALLED, reason="fails without SQLAlchemy") def test_options_get_engine(): + pytest.importorskip("sqlalchemy") assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine) with pd.option_context("io.sql.engine", "sqlalchemy"): @@ -3463,17 +3447,16 @@ def test_self_join_date_columns(postgresql_psycopg2_engine): def test_create_and_drop_table(sqlite_sqlalchemy_memory_engine): conn = sqlite_sqlalchemy_memory_engine temp_frame = DataFrame({"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}) - pandasSQL = sql.SQLDatabase(conn) - - with pandasSQL.run_transaction(): - assert pandasSQL.to_sql(temp_frame, "drop_test_frame") == 4 + with sql.SQLDatabase(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(temp_frame, "drop_test_frame") == 4 - assert pandasSQL.has_table("drop_test_frame") + assert pandasSQL.has_table("drop_test_frame") - with pandasSQL.run_transaction(): - pandasSQL.drop_table("drop_test_frame") + with pandasSQL.run_transaction(): + pandasSQL.drop_table("drop_test_frame") - assert not pandasSQL.has_table("drop_test_frame") + assert not pandasSQL.has_table("drop_test_frame") def test_sqlite_datetime_date(sqlite_buildin):