From dad022c43b273b62fa6820b7f896ef265e3a6498 Mon Sep 17 00:00:00 2001 From: Marti Raudsepp Date: Mon, 9 Jan 2023 13:26:52 +0200 Subject: [PATCH 1/4] Improve types of as_sql() and as_() methods --- .../contrib/gis/db/models/aggregates.pyi | 4 +- .../contrib/gis/db/models/functions.pyi | 38 +++++++++++-------- django-stubs/db/models/functions/text.pyi | 4 +- django-stubs/db/models/sql/where.pyi | 2 +- 4 files changed, 29 insertions(+), 19 deletions(-) diff --git a/django-stubs/contrib/gis/db/models/aggregates.pyi b/django-stubs/contrib/gis/db/models/aggregates.pyi index 0d1ca38b2..88497cfc6 100644 --- a/django-stubs/contrib/gis/db/models/aggregates.pyi +++ b/django-stubs/contrib/gis/db/models/aggregates.pyi @@ -1,10 +1,12 @@ from typing import Any +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import Aggregate +from django.db.models.sql.compiler import SQLCompiler, _AsSqlType class GeoAggregate(Aggregate): is_extent: bool - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Collect(GeoAggregate): name: str diff --git a/django-stubs/contrib/gis/db/models/functions.pyi b/django-stubs/contrib/gis/db/models/functions.pyi index 89efe2d60..93eb34c02 100644 --- a/django-stubs/contrib/gis/db/models/functions.pyi +++ b/django-stubs/contrib/gis/db/models/functions.pyi @@ -1,5 +1,7 @@ from typing import Any +from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.backends.mysql.compiler import SQLCompiler from django.db.models import Func from django.db.models import Transform as StandardTransform @@ -17,17 +19,17 @@ class GeomOutputGeoFunc(GeoFunc): def output_field(self) -> Any: ... class SQLiteDecimalToFloatMixin: - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class OracleToleranceMixin: tolerance: float - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Area(OracleToleranceMixin, GeoFunc): arity: int @property def output_field(self) -> Any: ... - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Azimuth(GeoFunc): output_field: Any @@ -39,13 +41,13 @@ class AsGeoJSON(GeoFunc): def __init__( self, expression: Any, bbox: bool = ..., crs: bool = ..., precision: int = ..., **extra: Any ) -> None: ... - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class AsGML(GeoFunc): geom_param_pos: Any output_field: Any def __init__(self, expression: Any, version: int = ..., precision: int = ..., **extra: Any) -> None: ... - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class AsKML(GeoFunc): output_field: Any @@ -65,7 +67,7 @@ class AsWKT(GeoFunc): class BoundingCircle(OracleToleranceMixin, GeomOutputGeoFunc): def __init__(self, expression: Any, num_seg: int = ..., **extra: Any) -> None: ... - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Centroid(OracleToleranceMixin, GeomOutputGeoFunc): arity: int @@ -83,8 +85,10 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): geom_param_pos: Any spheroid: Any def __init__(self, expr1: Any, expr2: Any, spheroid: Any | None = ..., **extra: Any) -> None: ... - def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_postgresql( + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any + ) -> _AsSqlType: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Envelope(GeomOutputGeoFunc): arity: int @@ -95,7 +99,7 @@ class ForcePolygonCW(GeomOutputGeoFunc): class GeoHash(GeoFunc): output_field: Any def __init__(self, expression: Any, precision: Any | None = ..., **extra: Any) -> None: ... - def as_mysql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class GeometryDistance(GeoFunc): output_field: Any @@ -111,13 +115,15 @@ class Intersection(OracleToleranceMixin, GeomOutputGeoFunc): class IsValid(OracleToleranceMixin, GeoFuncMixin, StandardTransform): lookup_name: str output_field: Any - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): spheroid: Any def __init__(self, expr1: Any, spheroid: bool = ..., **extra: Any) -> None: ... - def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_postgresql( + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any + ) -> _AsSqlType: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class LineLocatePoint(GeoFunc): output_field: Any @@ -140,8 +146,10 @@ class NumPoints(GeoFunc): class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): arity: int - def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_postgresql( + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any + ) -> _AsSqlType: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc): arity: int @@ -163,7 +171,7 @@ class Transform(GeomOutputGeoFunc): def __init__(self, expression: Any, srid: Any, **extra: Any) -> None: ... class Translate(Scale): - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Union(OracleToleranceMixin, GeomOutputGeoFunc): arity: int diff --git a/django-stubs/db/models/functions/text.pyi b/django-stubs/db/models/functions/text.pyi index dedbf6971..3c2a1ff51 100644 --- a/django-stubs/db/models/functions/text.pyi +++ b/django-stubs/db/models/functions/text.pyi @@ -8,7 +8,7 @@ from django.db.models.sql.compiler import SQLCompiler, _AsSqlType # Typo: `extra_conteNt`, remains in 4.0 class MySQLSHA2Mixin: - def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_content: Any) -> _AsSqlType: ... + def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class OracleHashMixin: def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... @@ -16,7 +16,7 @@ class OracleHashMixin: # Typo: `extra_conteNt`, remains in 4.0 class PostgreSQLSHAMixin: def as_postgresql( - self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_content: Any + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any ) -> _AsSqlType: ... class Chr(Transform): diff --git a/django-stubs/db/models/sql/where.pyi b/django-stubs/db/models/sql/where.pyi index aa959b518..cc4e07b97 100644 --- a/django-stubs/db/models/sql/where.pyi +++ b/django-stubs/db/models/sql/where.pyi @@ -17,7 +17,7 @@ class WhereNode(tree.Node): resolved: bool conditional: bool def split_having(self, negated: bool = ...) -> tuple[WhereNode | None, WhereNode | None]: ... - def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... + def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ... def get_group_by_cols(self, alias: str | None = ...) -> list[Expression]: ... def relabel_aliases(self, change_map: dict[str | None, str]) -> None: ... def clone(self) -> WhereNode: ... From adb5fcfbbb4519637a092558757e0ac2adda9df6 Mon Sep 17 00:00:00 2001 From: Marti Raudsepp Date: Mon, 9 Jan 2023 13:27:53 +0200 Subject: [PATCH 2/4] Remove obsolete comments --- django-stubs/db/models/functions/text.pyi | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/django-stubs/db/models/functions/text.pyi b/django-stubs/db/models/functions/text.pyi index 3c2a1ff51..e71cbf571 100644 --- a/django-stubs/db/models/functions/text.pyi +++ b/django-stubs/db/models/functions/text.pyi @@ -3,17 +3,15 @@ from typing import Any from django.db import models from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import Func, Transform -from django.db.models.expressions import Combinable, Expression, F, Value +from django.db.models.expressions import Combinable, Expression, Value from django.db.models.sql.compiler import SQLCompiler, _AsSqlType -# Typo: `extra_conteNt`, remains in 4.0 class MySQLSHA2Mixin: def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class OracleHashMixin: def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... -# Typo: `extra_conteNt`, remains in 4.0 class PostgreSQLSHAMixin: def as_postgresql( self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any From 1b320cc7efdf4ae133acd904ba0c06880a6b540c Mon Sep 17 00:00:00 2001 From: Marti Raudsepp Date: Wed, 25 Jan 2023 15:20:14 +0200 Subject: [PATCH 3/4] Typehint SpatialOperator.as_sql as well --- django-stubs/contrib/gis/db/backends/utils.pyi | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/django-stubs/contrib/gis/db/backends/utils.pyi b/django-stubs/contrib/gis/db/backends/utils.pyi index 23a977074..54e15f709 100644 --- a/django-stubs/contrib/gis/db/backends/utils.pyi +++ b/django-stubs/contrib/gis/db/backends/utils.pyi @@ -1,5 +1,8 @@ +from collections.abc import Mapping, Sequence from typing import Any +from django.contrib.gis.db.models.lookups import GISLookup +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models.sql.compiler import _AsSqlType class SpatialOperator: @@ -9,4 +12,10 @@ class SpatialOperator: def __init__(self, op: Any | None = ..., func: Any | None = ...) -> None: ... @property def default_template(self) -> Any: ... - def as_sql(self, connection: Any, lookup: Any, template_params: Any, sql_params: Any) -> _AsSqlType: ... + def as_sql( + self, + connection: BaseDatabaseWrapper, + lookup: GISLookup, + template_params: Mapping[str, Any], + sql_params: Sequence[Any], + ) -> _AsSqlType: ... From 1f04ad2c4d6cc5e737ac6f5e6c3d3b1c914b2596 Mon Sep 17 00:00:00 2001 From: Marti Raudsepp Date: Wed, 25 Jan 2023 15:25:40 +0200 Subject: [PATCH 4/4] Fix mypy-self-check, facepalm! --- django-stubs/contrib/gis/db/models/functions.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django-stubs/contrib/gis/db/models/functions.pyi b/django-stubs/contrib/gis/db/models/functions.pyi index 93eb34c02..f2a7fc599 100644 --- a/django-stubs/contrib/gis/db/models/functions.pyi +++ b/django-stubs/contrib/gis/db/models/functions.pyi @@ -1,9 +1,9 @@ from typing import Any from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.backends.mysql.compiler import SQLCompiler from django.db.models import Func from django.db.models import Transform as StandardTransform +from django.db.models.sql.compiler import SQLCompiler, _AsSqlType NUMERIC_TYPES: Any