Skip to content

Commit

Permalink
update typing for sqlalchemy 2
Browse files Browse the repository at this point in the history
  • Loading branch information
davidism committed Jan 29, 2023
1 parent 3b9965a commit c6e76b8
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 43 deletions.
13 changes: 1 addition & 12 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ coverage = [
mypy = [
"mypy",
"pytest",
"sqlalchemy[mypy]",
"sqlalchemy",
]
docs = [
"sphinx",
Expand Down
33 changes: 18 additions & 15 deletions src/flask_sqlalchemy/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import sqlalchemy.event
import sqlalchemy.exc
import sqlalchemy.orm
import sqlalchemy.pool
from flask import abort
from flask import current_app
from flask import Flask
Expand All @@ -22,6 +21,7 @@
from .query import Query
from .session import _app_ctx_id
from .session import Session
from .table import _Table


class SQLAlchemy:
Expand Down Expand Up @@ -126,8 +126,8 @@ def __init__(
*,
metadata: sa.MetaData | None = None,
session_options: dict[str, t.Any] | None = None,
query_class: t.Type[Query] = Query,
model_class: t.Type[Model] | sa.orm.DeclarativeMeta = Model,
query_class: type[Query] = Query,
model_class: type[Model] | sa.orm.DeclarativeMeta = Model,
engine_options: dict[str, t.Any] | None = None,
add_models_to_shell: bool = True,
):
Expand Down Expand Up @@ -336,7 +336,9 @@ def init_app(self, app: Flask) -> None:

track_modifications._listen(self.session)

def _make_scoped_session(self, options: dict[str, t.Any]) -> sa.orm.scoped_session:
def _make_scoped_session(
self, options: dict[str, t.Any]
) -> sa.orm.scoped_session[Session]:
"""Create a :class:`sqlalchemy.orm.scoping.scoped_session` around the factory
from :meth:`_make_session_factory`. The result is available as :attr:`session`.
Expand All @@ -363,7 +365,7 @@ def _make_scoped_session(self, options: dict[str, t.Any]) -> sa.orm.scoped_sessi

def _make_session_factory(
self, options: dict[str, t.Any]
) -> sa.orm.sessionmaker[Session]: # type: ignore[type-var]
) -> sa.orm.sessionmaker[Session]:
"""Create the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker` used by
:meth:`_make_scoped_session`.
Expand Down Expand Up @@ -438,7 +440,7 @@ def _make_metadata(self, bind_key: str | None) -> sa.MetaData:
self.metadatas[bind_key] = metadata
return metadata

def _make_table_class(self) -> t.Type[sa.Table]:
def _make_table_class(self) -> type[_Table]:
"""Create a SQLAlchemy :class:`sqlalchemy.schema.Table` class that chooses a
metadata automatically based on the ``bind_key``. The result is available as
:attr:`Table`.
Expand All @@ -450,7 +452,7 @@ def _make_table_class(self) -> t.Type[sa.Table]:
.. versionadded:: 3.0
"""

class Table(sa.Table):
class Table(_Table):
def __new__(
cls, *args: t.Any, bind_key: str | None = None, **kwargs: t.Any
) -> Table:
Expand All @@ -475,13 +477,13 @@ def __new__(
bind_key = kwargs["info"].get("bind_key")

metadata = self._make_metadata(bind_key)
return super().__new__(cls, args[0], metadata, *args[1:], **kwargs)
return super().__new__(cls, *[args[0], metadata, *args[1:]], **kwargs)

return Table

def _make_declarative_base(
self, model: t.Type[Model] | sa.orm.DeclarativeMeta
) -> t.Type[t.Any]:
self, model: type[Model] | sa.orm.DeclarativeMeta
) -> type[t.Any]:
"""Create a SQLAlchemy declarative model class. The result is available as
:attr:`Model`.
Expand Down Expand Up @@ -728,7 +730,7 @@ def get_binds(self) -> dict[sa.Table, sa.engine.Engine]:
}

def get_or_404(
self, entity: t.Type[t.Any], ident: t.Any, *, description: str | None = None
self, entity: type[t.Any], ident: t.Any, *, description: str | None = None
) -> t.Any:
"""Like :meth:`session.get() <sqlalchemy.orm.Session.get>` but aborts with a
``404 Not Found`` error instead of returning ``None``.
Expand All @@ -747,7 +749,7 @@ def get_or_404(
return value

def first_or_404(
self, statement: sa.sql.Select, *, description: str | None = None
self, statement: sa.sql.Select[t.Any], *, description: str | None = None
) -> t.Any:
"""Like :meth:`Result.scalar() <sqlalchemy.engine.Result.scalar>`, but aborts
with a ``404 Not Found`` error instead of returning ``None``.
Expand All @@ -765,7 +767,7 @@ def first_or_404(
return value

def one_or_404(
self, statement: sa.sql.Select, *, description: str | None = None
self, statement: sa.sql.Select[t.Any], *, description: str | None = None
) -> t.Any:
"""Like :meth:`Result.scalar_one() <sqlalchemy.engine.Result.scalar_one>`,
but aborts with a ``404 Not Found`` error instead of raising ``NoResultFound``
Expand All @@ -783,7 +785,7 @@ def one_or_404(

def paginate(
self,
select: sa.sql.Select,
select: sa.sql.Select[t.Any],
*,
page: int | None = None,
per_page: int | None = None,
Expand Down Expand Up @@ -971,7 +973,8 @@ def _relation(
"""
# Deprecated, removed in SQLAlchemy 2.0. Accessed through ``__getattr__``.
self._set_rel_query(kwargs)
return sa.orm.relation(*args, **kwargs)
f = sa.orm.relation # type: ignore[attr-defined]
return f(*args, **kwargs) # type: ignore[no-any-return]

def __getattr__(self, name: str) -> t.Any:
if name == "db":
Expand Down
9 changes: 5 additions & 4 deletions src/flask_sqlalchemy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ class _QueryProperty:
"""

@t.overload
def __get__(self, obj: None, cls: t.Type[Model]) -> Query:
def __get__(self, obj: None, cls: type[Model]) -> Query:
...

@t.overload
def __get__(self, obj: Model, cls: t.Type[Model]) -> Query:
def __get__(self, obj: Model, cls: type[Model]) -> Query:
...

def __get__(self, obj: Model | None, cls: t.Type[Model]) -> Query:
def __get__(self, obj: Model | None, cls: type[Model]) -> Query:
return cls.query_class(
cls, session=cls.__fsa__.session() # type: ignore[arg-type]
)
Expand All @@ -47,7 +47,7 @@ class Model:
:meta private:
"""

query_class: t.ClassVar[t.Type[Query]] = Query
query_class: t.ClassVar[type[Query]] = Query
"""Query class used by :attr:`query`. Defaults to :attr:`.SQLAlchemy.Query`, which
defaults to :class:`.Query`.
"""
Expand All @@ -63,6 +63,7 @@ class Model:

def __repr__(self) -> str:
state = sa.inspect(self)
assert state is not None

if state.transient:
pk = f"(transient {id(self)})"
Expand Down
2 changes: 1 addition & 1 deletion src/flask_sqlalchemy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None:
self._db = db
self._model_changes: dict[object, tuple[t.Any, str]] = {}

def get_bind( # type: ignore[override]
def get_bind(
self,
mapper: t.Any | None = None,
clause: t.Any | None = None,
Expand Down
39 changes: 39 additions & 0 deletions src/flask_sqlalchemy/table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

import typing as t

import sqlalchemy as sa
import sqlalchemy.sql.schema as sa_sql_schema


class _Table(sa.Table):
@t.overload
def __init__(
self,
name: str,
*args: sa_sql_schema.SchemaItem,
bind_key: str | None = None,
**kwargs: t.Any,
) -> None:
...

@t.overload
def __init__(
self,
name: str,
metadata: sa.MetaData,
*args: sa_sql_schema.SchemaItem,
**kwargs: t.Any,
) -> None:
...

@t.overload
def __init__(
self, name: str, *args: sa_sql_schema.SchemaItem, **kwargs: t.Any
) -> None:
...

def __init__(
self, name: str, *args: sa_sql_schema.SchemaItem, **kwargs: t.Any
) -> None:
super().__init__(name, *args, **kwargs) # type: ignore[arg-type]
2 changes: 1 addition & 1 deletion src/flask_sqlalchemy/track_modifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"""


def _listen(session: sa.orm.scoped_session) -> None:
def _listen(session: sa.orm.scoped_session[Session]) -> None:
sa.event.listen(session, "before_flush", _record_ops, named=True)
sa.event.listen(session, "before_commit", _record_ops, named=True)
sa.event.listen(session, "before_commit", _before_commit)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_legacy_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class Parent(db.Model):

class Child(db.Model):
id = sa.Column(sa.Integer, primary_key=True)
parent_id = sa.Column(sa.ForeignKey(Parent.id))
parent_id = sa.Column(sa.ForeignKey(Parent.id)) # type: ignore[var-annotated]
parent2 = db.relationship(
Parent,
backref=db.backref("children2", lazy="dynamic", viewonly=True),
Expand All @@ -109,7 +109,7 @@ class Parent(db.Model):

class Child(db.Model):
id = sa.Column(sa.Integer, primary_key=True)
parent_id = sa.Column(sa.ForeignKey(Parent.id))
parent_id = sa.Column(sa.ForeignKey(Parent.id)) # type: ignore[var-annotated]
parent2 = db.relationship(
Parent,
backref=db.backref("children2", lazy="dynamic", viewonly=True),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_custom_metadata() -> None:

def test_metadata_from_custom_model() -> None:
base = sa.orm.declarative_base(cls=Model, metaclass=DefaultMeta)
metadata = base.metadata # type: ignore[attr-defined]
metadata = base.metadata
db = SQLAlchemy(model_class=base)
assert db.Model.metadata is metadata
assert db.Model.metadata is db.metadata
Expand Down
6 changes: 3 additions & 3 deletions tests/test_model_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_mixin_attr(db: SQLAlchemy) -> None:
"""

class Mixin:
@sa.orm.declared_attr
@sa.orm.declared_attr # type: ignore[arg-type]
def __tablename__(cls) -> str: # noqa: B902
return cls.__name__.upper() # type: ignore[attr-defined,no-any-return]

Expand Down Expand Up @@ -210,7 +210,7 @@ class class_property:
def __init__(self, f: t.Callable[..., t.Any]) -> None:
self.f = f

def __get__(self, instance: t.Any, owner: t.Type[t.Any]) -> t.Any:
def __get__(self, instance: t.Any, owner: type[t.Any]) -> t.Any:
return self.f(owner)

class Duck(db.Model):
Expand All @@ -221,7 +221,7 @@ class ns:
floats = False

class Witch(Duck):
@sa.orm.declared_attr
@sa.orm.declared_attr # type: ignore[arg-type]
def is_duck(self) -> None:
# declared attrs will be accessed during mapper configuration,
# but make sure they're not accessed before that
Expand Down
4 changes: 1 addition & 3 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ class User(db.Model):
__mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"}

class Admin(User):
id = sa.Column( # type: ignore[assignment]
sa.ForeignKey(User.id), primary_key=True
)
id = sa.Column(sa.ForeignKey(User.id), primary_key=True)
org = sa.Column(sa.String, nullable=False)

__mapper_args__ = {"polymorphic_identity": "admin"}
Expand Down

0 comments on commit c6e76b8

Please sign in to comment.