From 623393df1b1ff133d9673a0042ef26f5abf33dfe Mon Sep 17 00:00:00 2001 From: DanCardin Date: Mon, 25 Sep 2023 16:38:42 -0400 Subject: [PATCH] fix: Ensure base model compatibility with DeclarativeBase subclasses. --- pyproject.toml | 2 +- .../compat/sqlalchemy.py | 10 +++++++++ src/pytest_mock_resources/sqlalchemy.py | 6 +++--- tests/__init__.py | 1 + tests/fixture/postgres/__init__.py | 0 .../postgres/test_sqlalchemy2_base_class.py | 21 +++++++++++++++++++ 6 files changed, 36 insertions(+), 4 deletions(-) create mode 100644 tests/fixture/postgres/__init__.py create mode 100644 tests/fixture/postgres/test_sqlalchemy2_base_class.py diff --git a/pyproject.toml b/pyproject.toml index 5ee24cf5..6450c8df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pytest-mock-resources" -version = "2.9.1" +version = "2.9.2" description = "A pytest plugin for easily instantiating reproducible mock resources." authors = [ "Omar Khan ", diff --git a/src/pytest_mock_resources/compat/sqlalchemy.py b/src/pytest_mock_resources/compat/sqlalchemy.py index 889caa22..f8c27365 100644 --- a/src/pytest_mock_resources/compat/sqlalchemy.py +++ b/src/pytest_mock_resources/compat/sqlalchemy.py @@ -1,10 +1,12 @@ import sqlalchemy import sqlalchemy.engine.url +from sqlalchemy.schema import MetaData from pytest_mock_resources.compat.import_ import ImportAdaptor version = getattr(sqlalchemy, "__version__", "") + if version.startswith("1.4") or version.startswith("2."): from sqlalchemy.ext import asyncio from sqlalchemy.orm import declarative_base, DeclarativeMeta @@ -29,6 +31,14 @@ def _select(*args, **kwargs): select = _select +def extract_model_base_metadata(base) -> sqlalchemy.MetaData | None: + metadata = getattr(base, "metadata", None) + if isinstance(metadata, MetaData): + return metadata + + return None + + __all__ = [ "asyncio", "declarative_base", diff --git a/src/pytest_mock_resources/sqlalchemy.py b/src/pytest_mock_resources/sqlalchemy.py index ff46689e..2f4b9f68 100644 --- a/src/pytest_mock_resources/sqlalchemy.py +++ b/src/pytest_mock_resources/sqlalchemy.py @@ -247,9 +247,9 @@ def normalize_actions( unique_metadata: Set[MetaData] = set() normalized_actions: List[T] = [] for action in ordered_actions: - if isinstance(action, compat.sqlalchemy.DeclarativeMeta): - action = action.metadata - normalized_actions.append(action) + metadata = compat.sqlalchemy.extract_model_base_metadata(action) + if metadata: + normalized_actions.append(metadata) elif isinstance(action, Rows): new_metadata = {row.metadata for row in action.rows} - unique_metadata diff --git a/tests/__init__.py b/tests/__init__.py index d47f203f..e6ca806e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,6 +4,7 @@ from pytest_mock_resources.compat import sqlalchemy +is_at_least_sqlalchemy2 = sqlalchemy.version.startswith("2.") is_sqlalchemy2 = sqlalchemy.version.startswith("1.4") or sqlalchemy.version.startswith("2.") skip_if_sqlalchemy2 = pytest.mark.skipif( is_sqlalchemy2, diff --git a/tests/fixture/postgres/__init__.py b/tests/fixture/postgres/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixture/postgres/test_sqlalchemy2_base_class.py b/tests/fixture/postgres/test_sqlalchemy2_base_class.py new file mode 100644 index 00000000..0bc6490b --- /dev/null +++ b/tests/fixture/postgres/test_sqlalchemy2_base_class.py @@ -0,0 +1,21 @@ +from tests import is_at_least_sqlalchemy2 + +if is_at_least_sqlalchemy2: + from sqlalchemy import Column, Integer + from sqlalchemy.orm import DeclarativeBase + + from pytest_mock_resources import create_postgres_fixture + + class Base(DeclarativeBase): + ... + + class Thing(Base): + __tablename__ = "thing" + + id = Column(Integer, autoincrement=True, primary_key=True) + + pg = create_postgres_fixture(Base, session=True) + + def test_creates_ddl(pg): + rows = pg.query(Thing).all() + assert len(rows) == 0