Skip to content

Commit

Permalink
fix: Ensure base model compatibility with DeclarativeBase subclasses.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin committed Sep 25, 2023
1 parent 2236bbe commit 623393d
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytest-mock-resources"
version = "2.9.1"
version = "2.9.2"
description = "A pytest plugin for easily instantiating reproducible mock resources."
authors = [
"Omar Khan <[email protected]>",
Expand Down
10 changes: 10 additions & 0 deletions src/pytest_mock_resources/compat/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import sqlalchemy
import sqlalchemy.engine.url
from sqlalchemy.schema import MetaData

from pytest_mock_resources.compat.import_ import ImportAdaptor

version = getattr(sqlalchemy, "__version__", "")


if version.startswith("1.4") or version.startswith("2."):
from sqlalchemy.ext import asyncio
from sqlalchemy.orm import declarative_base, DeclarativeMeta
Expand All @@ -29,6 +31,14 @@ def _select(*args, **kwargs):
select = _select


def extract_model_base_metadata(base) -> sqlalchemy.MetaData | None:
metadata = getattr(base, "metadata", None)
if isinstance(metadata, MetaData):
return metadata

return None


__all__ = [
"asyncio",
"declarative_base",
Expand Down
6 changes: 3 additions & 3 deletions src/pytest_mock_resources/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ def normalize_actions(
unique_metadata: Set[MetaData] = set()
normalized_actions: List[T] = []
for action in ordered_actions:
if isinstance(action, compat.sqlalchemy.DeclarativeMeta):
action = action.metadata
normalized_actions.append(action)
metadata = compat.sqlalchemy.extract_model_base_metadata(action)
if metadata:
normalized_actions.append(metadata)

elif isinstance(action, Rows):
new_metadata = {row.metadata for row in action.rows} - unique_metadata
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pytest_mock_resources.compat import sqlalchemy

is_at_least_sqlalchemy2 = sqlalchemy.version.startswith("2.")
is_sqlalchemy2 = sqlalchemy.version.startswith("1.4") or sqlalchemy.version.startswith("2.")
skip_if_sqlalchemy2 = pytest.mark.skipif(
is_sqlalchemy2,
Expand Down
Empty file.
21 changes: 21 additions & 0 deletions tests/fixture/postgres/test_sqlalchemy2_base_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from tests import is_at_least_sqlalchemy2

if is_at_least_sqlalchemy2:
from sqlalchemy import Column, Integer
from sqlalchemy.orm import DeclarativeBase

from pytest_mock_resources import create_postgres_fixture

class Base(DeclarativeBase):
...

class Thing(Base):
__tablename__ = "thing"

id = Column(Integer, autoincrement=True, primary_key=True)

pg = create_postgres_fixture(Base, session=True)

def test_creates_ddl(pg):
rows = pg.query(Thing).all()
assert len(rows) == 0

0 comments on commit 623393d

Please sign in to comment.