Skip to content

Commit

Permalink
Improve types of as_sql() and as_<engine>() methods (#1315)
Browse files Browse the repository at this point in the history
Co-authored-by: Nick Pope <[email protected]>
  • Loading branch information
intgr and ngnpope authored Jan 25, 2023
1 parent e6d65d2 commit 32e5ebc
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 23 deletions.
11 changes: 10 additions & 1 deletion django-stubs/contrib/gis/db/backends/utils.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from collections.abc import Mapping, Sequence
from typing import Any

from django.contrib.gis.db.models.lookups import GISLookup
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.sql.compiler import _AsSqlType

class SpatialOperator:
Expand All @@ -9,4 +12,10 @@ class SpatialOperator:
def __init__(self, op: Any | None = ..., func: Any | None = ...) -> None: ...
@property
def default_template(self) -> Any: ...
def as_sql(self, connection: Any, lookup: Any, template_params: Any, sql_params: Any) -> _AsSqlType: ...
def as_sql(
self,
connection: BaseDatabaseWrapper,
lookup: GISLookup,
template_params: Mapping[str, Any],
sql_params: Sequence[Any],
) -> _AsSqlType: ...
4 changes: 3 additions & 1 deletion django-stubs/contrib/gis/db/models/aggregates.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Aggregate
from django.db.models.sql.compiler import SQLCompiler, _AsSqlType

class GeoAggregate(Aggregate):
is_extent: bool
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Collect(GeoAggregate):
name: str
Expand Down
38 changes: 23 additions & 15 deletions django-stubs/contrib/gis/db/models/functions.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Func
from django.db.models import Transform as StandardTransform
from django.db.models.sql.compiler import SQLCompiler, _AsSqlType

NUMERIC_TYPES: Any

Expand All @@ -17,17 +19,17 @@ class GeomOutputGeoFunc(GeoFunc):
def output_field(self) -> Any: ...

class SQLiteDecimalToFloatMixin:
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class OracleToleranceMixin:
tolerance: float
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Area(OracleToleranceMixin, GeoFunc):
arity: int
@property
def output_field(self) -> Any: ...
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Azimuth(GeoFunc):
output_field: Any
Expand All @@ -39,13 +41,13 @@ class AsGeoJSON(GeoFunc):
def __init__(
self, expression: Any, bbox: bool = ..., crs: bool = ..., precision: int = ..., **extra: Any
) -> None: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class AsGML(GeoFunc):
geom_param_pos: Any
output_field: Any
def __init__(self, expression: Any, version: int = ..., precision: int = ..., **extra: Any) -> None: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class AsKML(GeoFunc):
output_field: Any
Expand All @@ -65,7 +67,7 @@ class AsWKT(GeoFunc):

class BoundingCircle(OracleToleranceMixin, GeomOutputGeoFunc):
def __init__(self, expression: Any, num_seg: int = ..., **extra: Any) -> None: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
arity: int
Expand All @@ -83,8 +85,10 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
geom_param_pos: Any
spheroid: Any
def __init__(self, expr1: Any, expr2: Any, spheroid: Any | None = ..., **extra: Any) -> None: ...
def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_postgresql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
) -> _AsSqlType: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Envelope(GeomOutputGeoFunc):
arity: int
Expand All @@ -95,7 +99,7 @@ class ForcePolygonCW(GeomOutputGeoFunc):
class GeoHash(GeoFunc):
output_field: Any
def __init__(self, expression: Any, precision: Any | None = ..., **extra: Any) -> None: ...
def as_mysql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class GeometryDistance(GeoFunc):
output_field: Any
Expand All @@ -111,13 +115,15 @@ class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
class IsValid(OracleToleranceMixin, GeoFuncMixin, StandardTransform):
lookup_name: str
output_field: Any
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
spheroid: Any
def __init__(self, expr1: Any, spheroid: bool = ..., **extra: Any) -> None: ...
def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_postgresql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
) -> _AsSqlType: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class LineLocatePoint(GeoFunc):
output_field: Any
Expand All @@ -140,8 +146,10 @@ class NumPoints(GeoFunc):

class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
arity: int
def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_postgresql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
) -> _AsSqlType: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
arity: int
Expand All @@ -163,7 +171,7 @@ class Transform(GeomOutputGeoFunc):
def __init__(self, expression: Any, srid: Any, **extra: Any) -> None: ...

class Translate(Scale):
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Union(OracleToleranceMixin, GeomOutputGeoFunc):
arity: int
Expand Down
8 changes: 3 additions & 5 deletions django-stubs/db/models/functions/text.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@ from typing import Any
from django.db import models
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Func, Transform
from django.db.models.expressions import Combinable, Expression, F, Value
from django.db.models.expressions import Combinable, Expression, Value
from django.db.models.sql.compiler import SQLCompiler, _AsSqlType

# Typo: `extra_conteNt`, remains in 4.0
class MySQLSHA2Mixin:
def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_content: Any) -> _AsSqlType: ...
def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class OracleHashMixin:
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

# Typo: `extra_conteNt`, remains in 4.0
class PostgreSQLSHAMixin:
def as_postgresql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_content: Any
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
) -> _AsSqlType: ...

class Chr(Transform):
Expand Down
2 changes: 1 addition & 1 deletion django-stubs/db/models/sql/where.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class WhereNode(tree.Node):
resolved: bool
conditional: bool
def split_having(self, negated: bool = ...) -> tuple[WhereNode | None, WhereNode | None]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def get_group_by_cols(self, alias: str | None = ...) -> list[Expression]: ...
def relabel_aliases(self, change_map: dict[str | None, str]) -> None: ...
def clone(self) -> WhereNode: ...
Expand Down

0 comments on commit 32e5ebc

Please sign in to comment.