Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/sqlalchemy_ddl_dialect_condition…
Browse files Browse the repository at this point in the history
…al' into sqlalchemy-2.0-fixes
  • Loading branch information
martinburchell committed Jan 23, 2025
2 parents fd81680 + 69d9b64 commit 955dae4
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 7 deletions.
8 changes: 5 additions & 3 deletions cardinal_pythonlib/sqlalchemy/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,11 @@ def execute_ddl(
Previously we would use DDL(sql, bind=engine).execute(), but this has gone
in SQLAlchemy 2.0.
If you want dialect-conditional execution, create the DDL object with e.g.
ddl = DDL(sql).execute_if(dialect=SqlaDialectName.SQLSERVER), and pass that
DDL object to this function.
Note that creating the DDL object with e.g. ddl =
DDL(sql).execute_if(dialect=SqlaDialectName.SQLSERVER), and passing that
DDL object to this function, does NOT make execution condition; it executes
regardless. The execute_if() construct is used for listeners; see
https://docs.sqlalchemy.org/en/20/core/ddl.html#sqlalchemy.schema.ExecutableDDLElement.execute_if
"""
assert bool(sql) ^ (ddl is not None) # one or the other.
if sql:
Expand Down
8 changes: 7 additions & 1 deletion cardinal_pythonlib/sqlalchemy/sqlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from sqlalchemy.schema import DDL

from cardinal_pythonlib.sqlalchemy.dialect import (
get_dialect_name,
quote_identifier,
SqlaDialectName,
)
Expand All @@ -49,7 +50,12 @@ def _exec_ddl_if_sqlserver(engine: Engine, sql: str) -> None:
"""
Execute DDL only if we are running on Microsoft SQL Server.
"""
ddl = DDL(sql).execute_if(dialect=SqlaDialectName.SQLSERVER)
# DO NOT USE DDL(sql).execute_if(dialect=SqlaDialectName.SQLSERVER).
# IT IS NOT EXECUTED CONDITIONALLY VIA THIS METHOD.
dialect_name = get_dialect_name(engine)
if dialect_name != SqlaDialectName.SQLSERVER:
return
ddl = DDL(sql)
execute_ddl(engine, ddl=ddl)


Expand Down
82 changes: 80 additions & 2 deletions cardinal_pythonlib/sqlalchemy/tests/schema_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,19 @@

import logging
import unittest
import sys

from sqlalchemy import event, inspect, select
from sqlalchemy.dialects.mssql.base import MSDialect, DECIMAL as MS_DECIMAL
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.engine import create_engine
from sqlalchemy.exc import NoSuchTableError, OperationalError
from sqlalchemy.ext import compiler
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import declarative_base, Session, sessionmaker
from sqlalchemy.schema import (
Column,
CreateTable,
DDL,
DDLElement,
Index,
MetaData,
Expand All @@ -59,6 +61,10 @@
Time,
)

from cardinal_pythonlib.sqlalchemy.engine_func import (
get_dialect_name,
SqlaDialectName,
)
from cardinal_pythonlib.sqlalchemy.schema import (
add_index,
column_creation_ddl,
Expand Down Expand Up @@ -98,6 +104,9 @@
view_exists,
)
from cardinal_pythonlib.sqlalchemy.session import SQLITE_MEMORY_URL
from cardinal_pythonlib.sqlalchemy.sqlserver import (
if_sqlserver_disable_constraints,
)

Base = declarative_base()
log = logging.getLogger(__name__)
Expand Down Expand Up @@ -485,7 +494,7 @@ def test_mssql_transaction_count(self) -> None:


class YetMoreSchemaTests(unittest.TestCase):
def __init__(self, *args, echo: bool = False, **kwargs) -> None:
def __init__(self, *args, echo: bool = True, **kwargs) -> None:
self.echo = echo
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -631,6 +640,75 @@ def test_execute_ddl(self) -> None:
with self.assertRaises(AssertionError):
execute_ddl(self.engine) # neither

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Dialect conditionality for DDL
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@staticmethod
def _present_in_log_output(_cm, msg: str) -> bool:
"""
Detects whether a string is present, INCLUDING AS A SUBSTRING, in
log output captured from an assertLogs() context manager.
"""
return any(msg in line for line in _cm.output)

def test_ddl_dialect_conditionality_1(self) -> None:
self.engine.echo = True # will write to log at INFO level

# 1. Check that logging capture works, and our _present_in_log_output
# function.
with self.assertLogs(level=logging.INFO) as cm:
log.info("dummy call")
self.assertTrue(self._present_in_log_output(cm, "dummy"))

# 2. Check our dialect is as expected: SQLite.
dialect_name = get_dialect_name(self.engine)
self.assertEqual(dialect_name, SqlaDialectName.SQLITE)

# 3. Seeing if DDL built with execute_if() will execute "directly" when
# set to execute-if-SQLite. It executes - but not conditionally!
ddl_yes = DDL("CREATE TABLE yesplease (a INT)").execute_if(
dialect=SqlaDialectName.SQLITE
)
with self.assertLogs(level=logging.INFO) as cm:
execute_ddl(self.engine, ddl=ddl_yes)
self.assertTrue(
self._present_in_log_output(cm, "CREATE TABLE yesplease")
)

# 4. Seeing if DDL built with execute_if() will execute "directly" when
# set to execute-if-MySQL. It executes - therefore not conditionally!
# I'd misunderstood this: it is NOT conditionally executed.
ddl_no = DDL("CREATE TABLE nothanks (a INT)").execute_if(
dialect=SqlaDialectName.MYSQL
)
with self.assertLogs(level=logging.INFO) as cm:
execute_ddl(self.engine, ddl=ddl_no)
self.assertTrue(
self._present_in_log_output(cm, "CREATE TABLE nothanks")
)
# I'd thought this would be false, but it is true.

def test_ddl_dialect_conditionality_2(self) -> None:
# Therefore:
self.engine.echo = True # will write to log at INFO level
# The test above (test_ddl_dialect_conditionality_1) proves that
# this code will log something if SQL is emitted.

session = sessionmaker(
bind=self.engine, future=True
)() # type: Session

if sys.version_info < (3, 10):
log.warning(
"Unable to use unittest.TestCase.assertNoLogs; "
"needs Python 3.10; skipping test"
)
return
with self.assertNoLogs(level=logging.INFO):
with if_sqlserver_disable_constraints(session, tablename="person"):
pass
# Should do nothing, therefore emit no logs.


class SchemaAbstractTests(unittest.TestCase):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion cardinal_pythonlib/version_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,5 @@
"""

VERSION_STRING = "2.0.0"
VERSION_STRING = "2.0.1"
# Use semantic versioning: https://semver.org/
6 changes: 6 additions & 0 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -846,3 +846,9 @@ Quick links:
was apply filters (if required) and execute.

- Multiple internal changes to support SQLAlchemy 2.

**2.0.1 (2025-01-22)**

- Bugfix to ``cardinal_pythonlib.sqlalchemy.sqlserver`` functions as they
were executing unconditionally, regardless of SQLAlchemy dialect (they should
have been conditional to SQL Server).

0 comments on commit 955dae4

Please sign in to comment.