From 625a94a8eddf1505ed4d0aaa0d9baa4a60080664 Mon Sep 17 00:00:00 2001 From: Marti Raudsepp Date: Mon, 9 Jan 2023 13:26:52 +0200 Subject: [PATCH] Improve types of as_sql() and as_() methods --- .../contrib/gis/db/models/aggregates.pyi | 4 +- .../contrib/gis/db/models/functions.pyi | 38 +++++++++++-------- django-stubs/db/models/functions/text.pyi | 4 +- django-stubs/db/models/sql/where.pyi | 2 +- 4 files changed, 29 insertions(+), 19 deletions(-) diff --git a/django-stubs/contrib/gis/db/models/aggregates.pyi b/django-stubs/contrib/gis/db/models/aggregates.pyi index 0d1ca38b2f..88497cfc68 100644 --- a/django-stubs/contrib/gis/db/models/aggregates.pyi +++ b/django-stubs/contrib/gis/db/models/aggregates.pyi @@ -1,10 +1,12 @@ from typing import Any +from django.db.backends.base.base import BaseDatabaseWrapper from django.db.models import Aggregate +from django.db.models.sql.compiler import SQLCompiler, _AsSqlType class GeoAggregate(Aggregate): is_extent: bool - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Collect(GeoAggregate): name: str diff --git a/django-stubs/contrib/gis/db/models/functions.pyi b/django-stubs/contrib/gis/db/models/functions.pyi index 89efe2d607..93eb34c027 100644 --- a/django-stubs/contrib/gis/db/models/functions.pyi +++ b/django-stubs/contrib/gis/db/models/functions.pyi @@ -1,5 +1,7 @@ from typing import Any +from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.backends.mysql.compiler import SQLCompiler from django.db.models import Func from django.db.models import Transform as StandardTransform @@ -17,17 +19,17 @@ class GeomOutputGeoFunc(GeoFunc): def output_field(self) -> Any: ... class SQLiteDecimalToFloatMixin: - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class OracleToleranceMixin: tolerance: float - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Area(OracleToleranceMixin, GeoFunc): arity: int @property def output_field(self) -> Any: ... - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Azimuth(GeoFunc): output_field: Any @@ -39,13 +41,13 @@ class AsGeoJSON(GeoFunc): def __init__( self, expression: Any, bbox: bool = ..., crs: bool = ..., precision: int = ..., **extra: Any ) -> None: ... - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class AsGML(GeoFunc): geom_param_pos: Any output_field: Any def __init__(self, expression: Any, version: int = ..., precision: int = ..., **extra: Any) -> None: ... - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class AsKML(GeoFunc): output_field: Any @@ -65,7 +67,7 @@ class AsWKT(GeoFunc): class BoundingCircle(OracleToleranceMixin, GeomOutputGeoFunc): def __init__(self, expression: Any, num_seg: int = ..., **extra: Any) -> None: ... - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Centroid(OracleToleranceMixin, GeomOutputGeoFunc): arity: int @@ -83,8 +85,10 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc): geom_param_pos: Any spheroid: Any def __init__(self, expr1: Any, expr2: Any, spheroid: Any | None = ..., **extra: Any) -> None: ... - def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_postgresql( + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any + ) -> _AsSqlType: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Envelope(GeomOutputGeoFunc): arity: int @@ -95,7 +99,7 @@ class ForcePolygonCW(GeomOutputGeoFunc): class GeoHash(GeoFunc): output_field: Any def __init__(self, expression: Any, precision: Any | None = ..., **extra: Any) -> None: ... - def as_mysql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class GeometryDistance(GeoFunc): output_field: Any @@ -111,13 +115,15 @@ class Intersection(OracleToleranceMixin, GeomOutputGeoFunc): class IsValid(OracleToleranceMixin, GeoFuncMixin, StandardTransform): lookup_name: str output_field: Any - def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc): spheroid: Any def __init__(self, expr1: Any, spheroid: bool = ..., **extra: Any) -> None: ... - def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_postgresql( + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any + ) -> _AsSqlType: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class LineLocatePoint(GeoFunc): output_field: Any @@ -140,8 +146,10 @@ class NumPoints(GeoFunc): class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc): arity: int - def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_postgresql( + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any + ) -> _AsSqlType: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc): arity: int @@ -163,7 +171,7 @@ class Transform(GeomOutputGeoFunc): def __init__(self, expression: Any, srid: Any, **extra: Any) -> None: ... class Translate(Scale): - def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ... + def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class Union(OracleToleranceMixin, GeomOutputGeoFunc): arity: int diff --git a/django-stubs/db/models/functions/text.pyi b/django-stubs/db/models/functions/text.pyi index dedbf69713..3c2a1ff514 100644 --- a/django-stubs/db/models/functions/text.pyi +++ b/django-stubs/db/models/functions/text.pyi @@ -8,7 +8,7 @@ from django.db.models.sql.compiler import SQLCompiler, _AsSqlType # Typo: `extra_conteNt`, remains in 4.0 class MySQLSHA2Mixin: - def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_content: Any) -> _AsSqlType: ... + def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... class OracleHashMixin: def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ... @@ -16,7 +16,7 @@ class OracleHashMixin: # Typo: `extra_conteNt`, remains in 4.0 class PostgreSQLSHAMixin: def as_postgresql( - self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_content: Any + self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any ) -> _AsSqlType: ... class Chr(Transform): diff --git a/django-stubs/db/models/sql/where.pyi b/django-stubs/db/models/sql/where.pyi index aa959b5183..cc4e07b97f 100644 --- a/django-stubs/db/models/sql/where.pyi +++ b/django-stubs/db/models/sql/where.pyi @@ -17,7 +17,7 @@ class WhereNode(tree.Node): resolved: bool conditional: bool def split_having(self, negated: bool = ...) -> tuple[WhereNode | None, WhereNode | None]: ... - def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ... + def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ... def get_group_by_cols(self, alias: str | None = ...) -> list[Expression]: ... def relabel_aliases(self, change_map: dict[str | None, str]) -> None: ... def clone(self) -> WhereNode: ...