From c6e76b8b92556525243ca58c5027c707c1f6489a Mon Sep 17 00:00:00 2001 From: David Lord Date: Sun, 29 Jan 2023 11:21:58 -0800 Subject: [PATCH] update typing for sqlalchemy 2 --- pdm.lock | 13 +------ pyproject.toml | 2 +- src/flask_sqlalchemy/extension.py | 33 +++++++++-------- src/flask_sqlalchemy/model.py | 9 ++--- src/flask_sqlalchemy/session.py | 2 +- src/flask_sqlalchemy/table.py | 39 +++++++++++++++++++++ src/flask_sqlalchemy/track_modifications.py | 2 +- tests/test_legacy_query.py | 4 +-- tests/test_metadata.py | 2 +- tests/test_model_name.py | 6 ++-- tests/test_session.py | 4 +-- 11 files changed, 73 insertions(+), 43 deletions(-) create mode 100644 src/flask_sqlalchemy/table.py diff --git a/pdm.lock b/pdm.lock index eab670a2..59f0a596 100644 --- a/pdm.lock +++ b/pdm.lock @@ -427,17 +427,6 @@ dependencies = [ "typing-extensions>=4.2.0", ] -[[package]] -name = "sqlalchemy" -version = "2.0.0" -extras = ["mypy"] -requires_python = ">=3.7" -summary = "Database Abstraction Library" -dependencies = [ - "mypy>=0.910", - "sqlalchemy==2.0.0", -] - [[package]] name = "tomli" version = "2.0.1" @@ -521,7 +510,7 @@ summary = "Backport of pathlib-compatible object wrapper for zip files" [metadata] lock_version = "4.1" -content_hash = "sha256:44bad6aae01e17b37f8234b6f2fd4d61e8e43b786473dc29c9d610f3b34e55d6" +content_hash = "sha256:2decebce9f60e5e1669dbe8b3b6ff96ad20f63b899810677caf6306c0a868e3e" [metadata.files] "alabaster 0.7.13" = [ diff --git a/pyproject.toml b/pyproject.toml index 719f693e..7476ae9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ coverage = [ mypy = [ "mypy", "pytest", - "sqlalchemy[mypy]", + "sqlalchemy", ] docs = [ "sphinx", diff --git a/src/flask_sqlalchemy/extension.py b/src/flask_sqlalchemy/extension.py index 930173dd..a64ced0a 100644 --- a/src/flask_sqlalchemy/extension.py +++ b/src/flask_sqlalchemy/extension.py @@ -8,7 +8,6 @@ import sqlalchemy.event import sqlalchemy.exc import sqlalchemy.orm -import sqlalchemy.pool from flask import abort from flask import current_app from flask import Flask @@ -22,6 +21,7 @@ from .query import Query from .session import _app_ctx_id from .session import Session +from .table import _Table class SQLAlchemy: @@ -126,8 +126,8 @@ def __init__( *, metadata: sa.MetaData | None = None, session_options: dict[str, t.Any] | None = None, - query_class: t.Type[Query] = Query, - model_class: t.Type[Model] | sa.orm.DeclarativeMeta = Model, + query_class: type[Query] = Query, + model_class: type[Model] | sa.orm.DeclarativeMeta = Model, engine_options: dict[str, t.Any] | None = None, add_models_to_shell: bool = True, ): @@ -336,7 +336,9 @@ def init_app(self, app: Flask) -> None: track_modifications._listen(self.session) - def _make_scoped_session(self, options: dict[str, t.Any]) -> sa.orm.scoped_session: + def _make_scoped_session( + self, options: dict[str, t.Any] + ) -> sa.orm.scoped_session[Session]: """Create a :class:`sqlalchemy.orm.scoping.scoped_session` around the factory from :meth:`_make_session_factory`. The result is available as :attr:`session`. @@ -363,7 +365,7 @@ def _make_scoped_session(self, options: dict[str, t.Any]) -> sa.orm.scoped_sessi def _make_session_factory( self, options: dict[str, t.Any] - ) -> sa.orm.sessionmaker[Session]: # type: ignore[type-var] + ) -> sa.orm.sessionmaker[Session]: """Create the SQLAlchemy :class:`sqlalchemy.orm.sessionmaker` used by :meth:`_make_scoped_session`. @@ -438,7 +440,7 @@ def _make_metadata(self, bind_key: str | None) -> sa.MetaData: self.metadatas[bind_key] = metadata return metadata - def _make_table_class(self) -> t.Type[sa.Table]: + def _make_table_class(self) -> type[_Table]: """Create a SQLAlchemy :class:`sqlalchemy.schema.Table` class that chooses a metadata automatically based on the ``bind_key``. The result is available as :attr:`Table`. @@ -450,7 +452,7 @@ def _make_table_class(self) -> t.Type[sa.Table]: .. versionadded:: 3.0 """ - class Table(sa.Table): + class Table(_Table): def __new__( cls, *args: t.Any, bind_key: str | None = None, **kwargs: t.Any ) -> Table: @@ -475,13 +477,13 @@ def __new__( bind_key = kwargs["info"].get("bind_key") metadata = self._make_metadata(bind_key) - return super().__new__(cls, args[0], metadata, *args[1:], **kwargs) + return super().__new__(cls, *[args[0], metadata, *args[1:]], **kwargs) return Table def _make_declarative_base( - self, model: t.Type[Model] | sa.orm.DeclarativeMeta - ) -> t.Type[t.Any]: + self, model: type[Model] | sa.orm.DeclarativeMeta + ) -> type[t.Any]: """Create a SQLAlchemy declarative model class. The result is available as :attr:`Model`. @@ -728,7 +730,7 @@ def get_binds(self) -> dict[sa.Table, sa.engine.Engine]: } def get_or_404( - self, entity: t.Type[t.Any], ident: t.Any, *, description: str | None = None + self, entity: type[t.Any], ident: t.Any, *, description: str | None = None ) -> t.Any: """Like :meth:`session.get() ` but aborts with a ``404 Not Found`` error instead of returning ``None``. @@ -747,7 +749,7 @@ def get_or_404( return value def first_or_404( - self, statement: sa.sql.Select, *, description: str | None = None + self, statement: sa.sql.Select[t.Any], *, description: str | None = None ) -> t.Any: """Like :meth:`Result.scalar() `, but aborts with a ``404 Not Found`` error instead of returning ``None``. @@ -765,7 +767,7 @@ def first_or_404( return value def one_or_404( - self, statement: sa.sql.Select, *, description: str | None = None + self, statement: sa.sql.Select[t.Any], *, description: str | None = None ) -> t.Any: """Like :meth:`Result.scalar_one() `, but aborts with a ``404 Not Found`` error instead of raising ``NoResultFound`` @@ -783,7 +785,7 @@ def one_or_404( def paginate( self, - select: sa.sql.Select, + select: sa.sql.Select[t.Any], *, page: int | None = None, per_page: int | None = None, @@ -971,7 +973,8 @@ def _relation( """ # Deprecated, removed in SQLAlchemy 2.0. Accessed through ``__getattr__``. self._set_rel_query(kwargs) - return sa.orm.relation(*args, **kwargs) + f = sa.orm.relation # type: ignore[attr-defined] + return f(*args, **kwargs) # type: ignore[no-any-return] def __getattr__(self, name: str) -> t.Any: if name == "db": diff --git a/src/flask_sqlalchemy/model.py b/src/flask_sqlalchemy/model.py index 058338c9..e49bda05 100644 --- a/src/flask_sqlalchemy/model.py +++ b/src/flask_sqlalchemy/model.py @@ -19,14 +19,14 @@ class _QueryProperty: """ @t.overload - def __get__(self, obj: None, cls: t.Type[Model]) -> Query: + def __get__(self, obj: None, cls: type[Model]) -> Query: ... @t.overload - def __get__(self, obj: Model, cls: t.Type[Model]) -> Query: + def __get__(self, obj: Model, cls: type[Model]) -> Query: ... - def __get__(self, obj: Model | None, cls: t.Type[Model]) -> Query: + def __get__(self, obj: Model | None, cls: type[Model]) -> Query: return cls.query_class( cls, session=cls.__fsa__.session() # type: ignore[arg-type] ) @@ -47,7 +47,7 @@ class Model: :meta private: """ - query_class: t.ClassVar[t.Type[Query]] = Query + query_class: t.ClassVar[type[Query]] = Query """Query class used by :attr:`query`. Defaults to :attr:`.SQLAlchemy.Query`, which defaults to :class:`.Query`. """ @@ -63,6 +63,7 @@ class Model: def __repr__(self) -> str: state = sa.inspect(self) + assert state is not None if state.transient: pk = f"(transient {id(self)})" diff --git a/src/flask_sqlalchemy/session.py b/src/flask_sqlalchemy/session.py index f8035e5e..e50c2926 100644 --- a/src/flask_sqlalchemy/session.py +++ b/src/flask_sqlalchemy/session.py @@ -28,7 +28,7 @@ def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None: self._db = db self._model_changes: dict[object, tuple[t.Any, str]] = {} - def get_bind( # type: ignore[override] + def get_bind( self, mapper: t.Any | None = None, clause: t.Any | None = None, diff --git a/src/flask_sqlalchemy/table.py b/src/flask_sqlalchemy/table.py new file mode 100644 index 00000000..ab08a692 --- /dev/null +++ b/src/flask_sqlalchemy/table.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import typing as t + +import sqlalchemy as sa +import sqlalchemy.sql.schema as sa_sql_schema + + +class _Table(sa.Table): + @t.overload + def __init__( + self, + name: str, + *args: sa_sql_schema.SchemaItem, + bind_key: str | None = None, + **kwargs: t.Any, + ) -> None: + ... + + @t.overload + def __init__( + self, + name: str, + metadata: sa.MetaData, + *args: sa_sql_schema.SchemaItem, + **kwargs: t.Any, + ) -> None: + ... + + @t.overload + def __init__( + self, name: str, *args: sa_sql_schema.SchemaItem, **kwargs: t.Any + ) -> None: + ... + + def __init__( + self, name: str, *args: sa_sql_schema.SchemaItem, **kwargs: t.Any + ) -> None: + super().__init__(name, *args, **kwargs) # type: ignore[arg-type] diff --git a/src/flask_sqlalchemy/track_modifications.py b/src/flask_sqlalchemy/track_modifications.py index fac5e411..c40c1aec 100644 --- a/src/flask_sqlalchemy/track_modifications.py +++ b/src/flask_sqlalchemy/track_modifications.py @@ -29,7 +29,7 @@ """ -def _listen(session: sa.orm.scoped_session) -> None: +def _listen(session: sa.orm.scoped_session[Session]) -> None: sa.event.listen(session, "before_flush", _record_ops, named=True) sa.event.listen(session, "before_commit", _record_ops, named=True) sa.event.listen(session, "before_commit", _before_commit) diff --git a/tests/test_legacy_query.py b/tests/test_legacy_query.py index fd71e3fd..9e7ea26e 100644 --- a/tests/test_legacy_query.py +++ b/tests/test_legacy_query.py @@ -82,7 +82,7 @@ class Parent(db.Model): class Child(db.Model): id = sa.Column(sa.Integer, primary_key=True) - parent_id = sa.Column(sa.ForeignKey(Parent.id)) + parent_id = sa.Column(sa.ForeignKey(Parent.id)) # type: ignore[var-annotated] parent2 = db.relationship( Parent, backref=db.backref("children2", lazy="dynamic", viewonly=True), @@ -109,7 +109,7 @@ class Parent(db.Model): class Child(db.Model): id = sa.Column(sa.Integer, primary_key=True) - parent_id = sa.Column(sa.ForeignKey(Parent.id)) + parent_id = sa.Column(sa.ForeignKey(Parent.id)) # type: ignore[var-annotated] parent2 = db.relationship( Parent, backref=db.backref("children2", lazy="dynamic", viewonly=True), diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 3cff84fa..c426918e 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -27,7 +27,7 @@ def test_custom_metadata() -> None: def test_metadata_from_custom_model() -> None: base = sa.orm.declarative_base(cls=Model, metaclass=DefaultMeta) - metadata = base.metadata # type: ignore[attr-defined] + metadata = base.metadata db = SQLAlchemy(model_class=base) assert db.Model.metadata is metadata assert db.Model.metadata is db.metadata diff --git a/tests/test_model_name.py b/tests/test_model_name.py index 2c37572d..35ac7c3c 100644 --- a/tests/test_model_name.py +++ b/tests/test_model_name.py @@ -110,7 +110,7 @@ def test_mixin_attr(db: SQLAlchemy) -> None: """ class Mixin: - @sa.orm.declared_attr + @sa.orm.declared_attr # type: ignore[arg-type] def __tablename__(cls) -> str: # noqa: B902 return cls.__name__.upper() # type: ignore[attr-defined,no-any-return] @@ -210,7 +210,7 @@ class class_property: def __init__(self, f: t.Callable[..., t.Any]) -> None: self.f = f - def __get__(self, instance: t.Any, owner: t.Type[t.Any]) -> t.Any: + def __get__(self, instance: t.Any, owner: type[t.Any]) -> t.Any: return self.f(owner) class Duck(db.Model): @@ -221,7 +221,7 @@ class ns: floats = False class Witch(Duck): - @sa.orm.declared_attr + @sa.orm.declared_attr # type: ignore[arg-type] def is_duck(self) -> None: # declared attrs will be accessed during mapper configuration, # but make sure they're not accessed before that diff --git a/tests/test_session.py b/tests/test_session.py index f9378f15..40dd0233 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -79,9 +79,7 @@ class User(db.Model): __mapper_args__ = {"polymorphic_on": type, "polymorphic_identity": "user"} class Admin(User): - id = sa.Column( # type: ignore[assignment] - sa.ForeignKey(User.id), primary_key=True - ) + id = sa.Column(sa.ForeignKey(User.id), primary_key=True) org = sa.Column(sa.String, nullable=False) __mapper_args__ = {"polymorphic_identity": "admin"}