Skip to content

Commit

Permalink
Improve types of as_sql() and as_<engine>() methods
Browse files Browse the repository at this point in the history
  • Loading branch information
adamchainz authored and intgr committed Jan 9, 2023
1 parent a7a1518 commit f860731
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 19 deletions.
4 changes: 3 additions & 1 deletion django-stubs/contrib/gis/db/models/aggregates.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Aggregate
from django.db.models.sql.compiler import SQLCompiler, _AsSqlType

class GeoAggregate(Aggregate):
is_extent: bool
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Collect(GeoAggregate):
name: str
Expand Down
38 changes: 23 additions & 15 deletions django-stubs/contrib/gis/db/models/functions.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.mysql.compiler import SQLCompiler
from django.db.models import Func
from django.db.models import Transform as StandardTransform

Expand All @@ -17,17 +19,17 @@ class GeomOutputGeoFunc(GeoFunc):
def output_field(self) -> Any: ...

class SQLiteDecimalToFloatMixin:
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class OracleToleranceMixin:
tolerance: float
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Area(OracleToleranceMixin, GeoFunc):
arity: int
@property
def output_field(self) -> Any: ...
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Azimuth(GeoFunc):
output_field: Any
Expand All @@ -39,13 +41,13 @@ class AsGeoJSON(GeoFunc):
def __init__(
self, expression: Any, bbox: bool = ..., crs: bool = ..., precision: int = ..., **extra: Any
) -> None: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class AsGML(GeoFunc):
geom_param_pos: Any
output_field: Any
def __init__(self, expression: Any, version: int = ..., precision: int = ..., **extra: Any) -> None: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class AsKML(GeoFunc):
output_field: Any
Expand All @@ -65,7 +67,7 @@ class AsWKT(GeoFunc):

class BoundingCircle(OracleToleranceMixin, GeomOutputGeoFunc):
def __init__(self, expression: Any, num_seg: int = ..., **extra: Any) -> None: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
arity: int
Expand All @@ -83,8 +85,10 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
geom_param_pos: Any
spheroid: Any
def __init__(self, expr1: Any, expr2: Any, spheroid: Any | None = ..., **extra: Any) -> None: ...
def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_postgresql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
) -> _AsSqlType: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Envelope(GeomOutputGeoFunc):
arity: int
Expand All @@ -95,7 +99,7 @@ class ForcePolygonCW(GeomOutputGeoFunc):
class GeoHash(GeoFunc):
output_field: Any
def __init__(self, expression: Any, precision: Any | None = ..., **extra: Any) -> None: ...
def as_mysql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class GeometryDistance(GeoFunc):
output_field: Any
Expand All @@ -111,13 +115,15 @@ class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
class IsValid(OracleToleranceMixin, GeoFuncMixin, StandardTransform):
lookup_name: str
output_field: Any
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
spheroid: Any
def __init__(self, expr1: Any, spheroid: bool = ..., **extra: Any) -> None: ...
def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_postgresql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
) -> _AsSqlType: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class LineLocatePoint(GeoFunc):
output_field: Any
Expand All @@ -140,8 +146,10 @@ class NumPoints(GeoFunc):

class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
arity: int
def as_postgresql(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_postgresql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
) -> _AsSqlType: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
arity: int
Expand All @@ -163,7 +171,7 @@ class Transform(GeomOutputGeoFunc):
def __init__(self, expression: Any, srid: Any, **extra: Any) -> None: ...

class Translate(Scale):
def as_sqlite(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def as_sqlite(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class Union(OracleToleranceMixin, GeomOutputGeoFunc):
arity: int
Expand Down
4 changes: 2 additions & 2 deletions django-stubs/db/models/functions/text.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ from django.db.models.sql.compiler import SQLCompiler, _AsSqlType

# Typo: `extra_conteNt`, remains in 4.0
class MySQLSHA2Mixin:
def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_content: Any) -> _AsSqlType: ...
def as_mysql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

class OracleHashMixin:
def as_oracle(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any) -> _AsSqlType: ...

# Typo: `extra_conteNt`, remains in 4.0
class PostgreSQLSHAMixin:
def as_postgresql(
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_content: Any
self, compiler: SQLCompiler, connection: BaseDatabaseWrapper, **extra_context: Any
) -> _AsSqlType: ...

class Chr(Transform):
Expand Down
2 changes: 1 addition & 1 deletion django-stubs/db/models/sql/where.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class WhereNode(tree.Node):
resolved: bool
conditional: bool
def split_having(self, negated: bool = ...) -> tuple[WhereNode | None, WhereNode | None]: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> Any: ...
def as_sql(self, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...
def get_group_by_cols(self, alias: str | None = ...) -> list[Expression]: ...
def relabel_aliases(self, change_map: dict[str | None, str]) -> None: ...
def clone(self) -> WhereNode: ...
Expand Down

0 comments on commit f860731

Please sign in to comment.