diff --git a/ibis/backends/mysql/__init__.py b/ibis/backends/mysql/__init__.py index fb9ce68453aa4..c64f6251cd3ee 100644 --- a/ibis/backends/mysql/__init__.py +++ b/ibis/backends/mysql/__init__.py @@ -5,6 +5,7 @@ import contextlib import re import warnings +import weakref from functools import cached_property from operator import itemgetter from typing import TYPE_CHECKING, Any @@ -255,18 +256,19 @@ def drop_database(self, name: str, force: bool = False) -> None: def begin(self): con = self.con cur = con.cursor() + autocommit = con.autocommit_mode - if not con.autocommit_mode: + if not autocommit: con.begin() try: yield cur except Exception: - if not con.autocommit_mode: + if not autocommit: con.rollback() raise else: - if not con.autocommit_mode: + if not autocommit: con.commit() finally: cur.close() @@ -283,20 +285,22 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: query = query.sql(dialect=self.name) con = self.con + autocommit = con.autocommit_mode + cursor = con.cursor() - if not con.autocommit_mode: + if not autocommit: con.begin() try: cursor.execute(query, **kwargs) except Exception: - if not con.autocommit_mode: + if not autocommit: con.rollback() cursor.close() raise else: - if not con.autocommit_mode: + if not autocommit: con.commit() return cursor @@ -504,6 +508,30 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: if not df.empty: cur.executemany(sql, data) + def _register_memtable_finalizer(self, op: ops.InMemoryTable): + weakref.finalize(op, self.drop_table, op.name, force=True, temp=True) + + def drop_table( + self, + name: str, + database: tuple[str, str] | str | None = None, + force: bool = False, + temp: bool = False, + ) -> None: + table_loc = self._warn_and_create_table_loc(database, None) + catalog, db = self._to_catalog_db_tuple(table_loc) + + drop_stmt = sg.exp.Drop( + kind="TABLE", + this=sg.table(name, db=db, catalog=catalog, quoted=self.compiler.quoted), + exists=force, + temporary=temp, + ) + drop_stmt_sql = drop_stmt.sql(self.dialect) + + with self.con.cursor() as cur: + cur.execute(drop_stmt_sql) + @util.experimental def to_pyarrow_batches( self,