Skip to content

Commit

Permalink
Add TypeVar resolution for as_manager and from_queryset querysets (ty…
Browse files Browse the repository at this point in the history
  • Loading branch information
moranabadie committed Aug 31, 2023
1 parent a8e42cb commit f44ec31
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 4 deletions.
80 changes: 76 additions & 4 deletions mypy_django_plugin/transformers/managers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Final, Optional, Union
from typing import Final, Optional

from mypy.checker import TypeChecker
from mypy.nodes import (
Expand All @@ -8,6 +8,7 @@
FuncBase,
FuncDef,
MemberExpr,
Node,
OverloadedFuncDef,
RefExpr,
StrExpr,
Expand All @@ -18,7 +19,7 @@
from mypy.plugin import AttributeContext, ClassDefContext, DynamicClassDefContext
from mypy.semanal import SemanticAnalyzer
from mypy.semanal_shared import has_placeholder
from mypy.types import AnyType, CallableType, Instance, ProperType, TypeOfAny
from mypy.types import AnyType, CallableType, Instance, ProperType, TypeOfAny, TypeVarType
from mypy.types import Type as MypyType
from mypy.typevars import fill_typevars

Expand Down Expand Up @@ -71,15 +72,19 @@ def get_method_type_from_dynamic_manager(
queryset_info = helpers.lookup_fully_qualified_typeinfo(api, queryset_fullname)
assert queryset_info is not None

def get_funcdef_type(definition: Union[FuncBase, Decorator, None]) -> Optional[ProperType]:
def get_funcdef_type(definition: Optional[Node]) -> Optional[ProperType]:
# TODO: Handle @overload?
if isinstance(definition, FuncBase) and not isinstance(definition, OverloadedFuncDef):
return definition.type
elif isinstance(definition, Decorator):
return definition.func.type
return None

method_type = get_funcdef_type(queryset_info.get_method(method_name))
base_that_has_method = _get_base_containing_method(queryset_info, method_name)
if base_that_has_method is None:
return None

method_type = get_funcdef_type(base_that_has_method.type.names[method_name].node)
if method_type is None:
return None

Expand All @@ -105,6 +110,8 @@ def get_funcdef_type(definition: Union[FuncBase, Decorator, None]) -> Optional[P
ret_type = Instance(queryset_info, [manager_instance.args[0], manager_instance.args[0]])
variables = []

ret_type = _find_type_var_from_mro(ret_type, queryset_info, base_that_has_method)

# Drop any 'self' argument as our manager is already initialized
return method_type.copy_modified(
arg_types=method_type.arg_types[1:],
Expand All @@ -115,6 +122,71 @@ def get_funcdef_type(definition: Union[FuncBase, Decorator, None]) -> Optional[P
)


def _get_base_containing_method(queryset_info: TypeInfo, method_name: str) -> Instance | None:
for mro in queryset_info.mro:
for base in mro.bases:
if method_name in base.type.names:
return base
return None


def _manage_type_var(ret_type: MypyType, base: Instance) -> MypyType:
"""
Attempts to manage the concrete type corresponding to a `TypeVarType` in a class's base types.
If nothing is found, keep it unchanged.
Example:
T = TypeVar("T")
for def example_list_dict(self) -> list[dict[str, T]]
Tries to find the corresponding `TypeVar` of `T` if it has been set.
"""
if isinstance(ret_type, TypeVarType):
return _find_type_var(base, ret_type) or ret_type
elif isinstance(ret_type, Instance):
# Since it is an instance, recursively find the type var for all its args.
ret_type.args = tuple(_manage_type_var(item, base) for item in ret_type.args)
if isinstance(ret_type, ProperType) and hasattr(ret_type, "item"):
# For example TypeType has an item. find the type_var for this item
ret_type.item = _manage_type_var(ret_type.item, base)
if isinstance(ret_type, ProperType) and hasattr(ret_type, "items"):
# For example TypeList has items. find recursively type_var for its items
ret_type.items = [_manage_type_var(item, base) for item in ret_type.items]
return ret_type


def _find_type_var(base: Instance, target: TypeVarType) -> Optional[MypyType]:
"""
Attempts to find the concrete type corresponding to a TypeVarType in a class's base types.
Example:
Suppose queryset_info is based on a generic class `QuerySet[T]` and ret_type corresponds
to `T`, then this function will return the concrete type that `T` was instantiated with.
"""
for i, type_var in enumerate(base.type.defn.type_vars):
if type_var.fullname == target.fullname:
return base.args[i]
return None


def _find_type_var_from_mro(
ret_type: MypyType, queryset_info: TypeInfo, base_that_has_method: Instance
) -> Optional[MypyType]:
"""
Attempts to find the type var by looking at the mro elements (reversed).
"""
new_ret_type = ret_type
for i, mro in enumerate(reversed(queryset_info.mro)):
if mro in base_that_has_method.type.mro:
# No need to look at the mro before base_that_has_method
continue
for base in mro.bases:
new_ret_type = _manage_type_var(new_ret_type, base)
return new_ret_type or ret_type


def get_method_type_from_reverse_manager(
api: TypeChecker, method_name: str, manager_type_info: TypeInfo
) -> Optional[ProperType]:
Expand Down
99 changes: 99 additions & 0 deletions tests/typecheck/managers/querysets/test_as_manager.yml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,105 @@
class MyModel(models.Model):
objects = ManagerFromModelQuerySet
- case: handles_subclasses_of_subclasses_of_queryset
main: |
from myapp.models import MyModel
reveal_type(MyModel.objects.example()) # N: Revealed type is "myapp.models.MyModel"
reveal_type(MyModel.objects.example_2()) # N: Revealed type is "myapp.models.MyModel"
reveal_type(MyModel.objects.override()) # N: Revealed type is "myapp.models.MyModel"
reveal_type(MyModel.objects.override2()) # N: Revealed type is "myapp.models.MyModel"
reveal_type(MyModel.objects.dummy_override()) # N: Revealed type is "myapp.models.MyModel"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from typing import TypeVar
from django.db import models
_CTE = TypeVar("_CTE", bound=models.Model)
_CTE_2 = TypeVar("_CTE_2", bound=models.Model)
class _MyModelQuerySet(models.QuerySet[_CTE]):
def example(self) -> _CTE: ...
def override(self) -> _CTE: ...
def override2(self) -> _CTE: ...
def dummy_override(self) -> int: ...
class _MyModelQuerySet2(_MyModelQuerySet[_CTE_2]):
def example_2(self) -> _CTE_2: ...
def override(self) -> _CTE_2: ...
def override2(self) -> _CTE_2: ...
def dummy_override(self) -> _CTE_2: ... # type: ignore[override]
class MyModelQuerySet(_MyModelQuerySet2["MyModel"]):
def override(self) -> "MyModel": ...
class MyModel(models.Model):
objects = MyModelQuerySet.as_manager()
- case: handles_subclasses_of_queryset
main: |
from myapp.models import MyModel, BaseQuerySet
reveal_type(MyModel.objects.example()) # N: Revealed type is "myapp.models.MyModel"
reveal_type(MyModel.objects.example_2()) # N: Revealed type is "myapp.models.MyOtherModel"
reveal_type(MyModel.objects.example_list()) # N: Revealed type is "builtins.list[myapp.models.MyModel]"
reveal_type(MyModel.objects.example_type()) # N: Revealed type is "Type[myapp.models.MyModel]"
reveal_type(MyModel.objects.example_tuple_simple()) # N: Revealed type is "Tuple[myapp.models.MyModel]"
reveal_type(MyModel.objects.example_tuple_list()) # N: Revealed type is "builtins.tuple[myapp.models.MyModel, ...]"
reveal_type(MyModel.objects.example_tuple_double()) # N: Revealed type is "Tuple[builtins.int, myapp.models.MyModel]"
reveal_type(MyModel.objects.example_class()) # N: Revealed type is "myapp.models.Example[myapp.models.MyModel]"
reveal_type(MyModel.objects.example_type_class()) # N: Revealed type is "Type[myapp.models.Example[myapp.models.MyModel]]"
reveal_type(MyModel.objects.example_collection()) # N: Revealed type is "typing.Collection[myapp.models.MyModel]"
reveal_type(MyModel.objects.example_set()) # N: Revealed type is "builtins.set[myapp.models.MyModel]"
reveal_type(MyModel.objects.example_dict()) # N: Revealed type is "builtins.dict[builtins.str, myapp.models.MyModel]"
reveal_type(MyModel.objects.example_list_dict()) # N: Revealed type is "builtins.list[builtins.dict[myapp.models.MyOtherModel, myapp.models.MyModel]]"
class TestQuerySet(BaseQuerySet[str, MyModel]): ... # E: Type argument "str" of "BaseQuerySet" must be a subtype of "Model"
reveal_type(MyModel.objects.example_t(5)) # N: Revealed type is "builtins.int"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from typing import TypeVar, Generic, Collection
from django.db import models
_CTE = TypeVar("_CTE", bound=models.Model)
_CTE_2 = TypeVar("_CTE_2", bound=models.Model)
T = TypeVar("T")
class Example(Generic[_CTE]): ...
class BaseQuerySet(models.QuerySet[_CTE], Generic[_CTE, _CTE_2]):
def example(self) -> _CTE: ...
def example_2(self) -> _CTE_2: ...
def example_list(self) -> list[_CTE]: ...
def example_type(self) -> type[_CTE]: ...
def example_tuple_simple(self) -> tuple[_CTE]: ...
def example_tuple_list(self) -> tuple[_CTE, ...]: ...
def example_tuple_double(self) -> tuple[int, _CTE]: ...
def example_class(self) -> Example[_CTE]: ...
def example_type_class(self) -> type[Example[_CTE]]: ...
def example_collection(self) -> Collection[_CTE]: ...
def example_set(self) -> set[_CTE]: ...
def example_dict(self) -> dict[str, _CTE]: ...
def example_list_dict(self) -> list[dict[_CTE_2, _CTE]]: ...
def example_t(self, a: T) -> T: ...
class MyModelQuerySet(BaseQuerySet["MyModel", "MyOtherModel"]):
...
class MyModel(models.Model):
objects = MyModelQuerySet.as_manager()
class MyOtherModel(models.Model): ...
- case: reuses_generated_type_when_called_identically_for_multiple_managers
main: |
from myapp.models import MyModel
Expand Down
35 changes: 35 additions & 0 deletions tests/typecheck/managers/querysets/test_from_queryset.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,41 @@
NewManager = BaseManager.from_queryset(ModelQuerySet)
class MyModel(models.Model):
objects = NewManager()
- case: handles_subclasses_of_queryset
main: |
from myapp.models import MyModel
reveal_type(MyModel.objects.example()) # N: Revealed type is "myapp.models.MyModel"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/querysets.py
content: |
from typing import TypeVar, TYPE_CHECKING
from django.db import models
from django.db.models.manager import BaseManager
if TYPE_CHECKING:
from .models import MyModel
_CTE = TypeVar("_CTE", bound=models.Model)
class _MyModelQuerySet(models.QuerySet[_CTE]):
def example(self) -> _CTE: ...
class MyModelQuerySet(_MyModelQuerySet["MyModel"]):
...
- path: myapp/models.py
content: |
from django.db import models
from django.db.models.manager import BaseManager
from .querysets import MyModelQuerySet
NewManager = BaseManager.from_queryset(MyModelQuerySet)
class MyModel(models.Model):
objects = NewManager()
- case: from_queryset_generated_manager_imported_from_other_module
main: |
Expand Down

0 comments on commit f44ec31

Please sign in to comment.