Skip to content

Commit

Permalink
Improve some DB expression types (#1243)
Browse files Browse the repository at this point in the history
* Remove several things from `GeoAggregate` and `GeoFuncMixin` where the inherited types are correct
* Add missing kwargs `**extra_context` to `Func.as_sql`
* Remove `Agggregate.as_sql` since it’s a pass-through to `Func.as_sql`
* Reorganize `db/models/expressions.pyi` to follow the same order as upstream
* Add missing classes `DurationExpression`, `TemporalSubtraction`, `Star`, and `OrderByList`

Spotted the problem with `Aggregate` when adding Mypy to Django-MySQL, which has a custom aggregate class.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
adamchainz and pre-commit-ci[bot] authored Jan 9, 2023
1 parent a868e74 commit a7a1518
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 65 deletions.
15 changes: 0 additions & 15 deletions django-stubs/contrib/gis/db/models/aggregates.pyi
Original file line number Diff line number Diff line change
@@ -1,25 +1,10 @@
from typing import Any

from django.db.models import Aggregate
from django.db.models.sql.compiler import _AsSqlType

class GeoAggregate(Aggregate):
function: Any
is_extent: bool
@property
def output_field(self) -> Any: ...
def as_sql(
self, compiler: Any, connection: Any, function: Any | None = ..., **extra_context: Any
) -> _AsSqlType: ...
def as_oracle(self, compiler: Any, connection: Any, **extra_context: Any) -> Any: ...
def resolve_expression(
self,
query: Any | None = ...,
allow_joins: bool = ...,
reuse: Any | None = ...,
summarize: bool = ...,
for_save: bool = ...,
) -> Any: ...

class Collect(GeoAggregate):
name: str
Expand Down
7 changes: 0 additions & 7 deletions django-stubs/contrib/gis/db/models/functions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,13 @@ from typing import Any

from django.db.models import Func
from django.db.models import Transform as StandardTransform
from django.db.models.sql.compiler import _AsSqlType

NUMERIC_TYPES: Any

class GeoFuncMixin:
function: Any
geom_param_pos: Any
def __init__(self, *expressions: Any, **extra: Any) -> None: ...
@property
def geo_field(self) -> Any: ...
def as_sql(
self, compiler: Any, connection: Any, function: Any | None = ..., **extra_context: Any
) -> _AsSqlType: ...
def resolve_expression(self, *args: Any, **kwargs: Any) -> Any: ...

class GeoFunc(GeoFuncMixin, Func): ...

Expand Down
105 changes: 62 additions & 43 deletions django-stubs/db/models/expressions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
rhs: Combinable
def __init__(self, lhs: Combinable, connector: str, rhs: Combinable, output_field: Field | None = ...) -> None: ...

class DurationExpression(CombinedExpression):
def compile(self, side: Combinable, compiler: SQLCompiler, connection: BaseDatabaseWrapper) -> _AsSqlType: ...

class TemporalSubtraction(CombinedExpression):
def __init__(self, lhs: Combinable, rhs: Combinable) -> None: ...

class F(Combinable):
name: str
def __init__(self, name: str) -> None: ...
Expand Down Expand Up @@ -148,30 +154,24 @@ class OuterRef(F):
contains_aggregate: bool
def relabeled_clone(self: Self, relabels: Any) -> Self: ...

class Subquery(BaseExpression, Combinable):
class Func(SQLiteNumericMixin, Expression):
function: str
name: str
template: str
query: Query
arg_joiner: str
arity: int | None
source_expressions: list[Expression]
extra: dict[Any, Any]
def __init__(self, queryset: Query | QuerySet, output_field: Field | None = ..., **extra: Any) -> None: ...

class Exists(Subquery):
negated: bool
def __init__(self, queryset: Query | QuerySet, negated: bool = ..., **kwargs: Any) -> None: ...
def __invert__(self) -> Exists: ...

class OrderBy(Expression):
template: str
nulls_first: bool
nulls_last: bool
descending: bool
expression: Expression | F | Subquery
def __init__(
def __init__(self, *expressions: Any, output_field: Field | None = ..., **extra: Any) -> None: ...
def as_sql(
self,
expression: Expression | F | Subquery,
descending: bool = ...,
nulls_first: bool = ...,
nulls_last: bool = ...,
) -> None: ...
compiler: SQLCompiler,
connection: BaseDatabaseWrapper,
function: str | None = ...,
template: str | None = ...,
arg_joiner: str | None = ...,
**extra_context: Any,
) -> _AsSqlType: ...

class Value(Expression):
value: Any
Expand All @@ -182,15 +182,25 @@ class RawSQL(Expression):
sql: str
def __init__(self, sql: str, params: Sequence[Any], output_field: Field | None = ...) -> None: ...

class Func(SQLiteNumericMixin, Expression):
function: str
name: str
template: str
arg_joiner: str
arity: int | None
source_expressions: list[Expression]
extra: dict[Any, Any]
def __init__(self, *expressions: Any, output_field: Field | None = ..., **extra: Any) -> None: ...
class Star(Expression): ...

class Col(Expression):
target: Field
alias: str
contains_column_references: Literal[True]
possibly_multivalued: Literal[False]
def __init__(self, alias: str, target: Field, output_field: Field | None = ...) -> None: ...

class Ref(Expression):
def __init__(self, refs: str, source: Expression) -> None: ...

class ExpressionList(Func):
def __init__(self, *expressions: BaseExpression | Combinable, **extra: Any) -> None: ...

class OrderByList(Func): ...

class ExpressionWrapper(Expression):
def __init__(self, expression: Q | Combinable, output_field: Field) -> None: ...

class When(Expression):
template: str
Expand All @@ -208,21 +218,30 @@ class Case(Expression):
self, *cases: Any, default: Any | None = ..., output_field: Field | None = ..., **extra: Any
) -> None: ...

class ExpressionWrapper(Expression):
def __init__(self, expression: Q | Combinable, output_field: Field) -> None: ...

class Col(Expression):
target: Field
alias: str
contains_column_references: Literal[True]
possibly_multivalued: Literal[False]
def __init__(self, alias: str, target: Field, output_field: Field | None = ...) -> None: ...
class Subquery(BaseExpression, Combinable):
template: str
query: Query
extra: dict[Any, Any]
def __init__(self, queryset: Query | QuerySet, output_field: Field | None = ..., **extra: Any) -> None: ...

class Ref(Expression):
def __init__(self, refs: str, source: Expression) -> None: ...
class Exists(Subquery):
negated: bool
def __init__(self, queryset: Query | QuerySet, negated: bool = ..., **kwargs: Any) -> None: ...
def __invert__(self) -> Exists: ...

class ExpressionList(Func):
def __init__(self, *expressions: BaseExpression | Combinable, **extra: Any) -> None: ...
class OrderBy(Expression):
template: str
nulls_first: bool
nulls_last: bool
descending: bool
expression: Expression | F | Subquery
def __init__(
self,
expression: Expression | F | Subquery,
descending: bool = ...,
nulls_first: bool = ...,
nulls_last: bool = ...,
) -> None: ...

class Window(SQLiteNumericMixin, Expression):
template: str
Expand Down

0 comments on commit a7a1518

Please sign in to comment.