Skip to content

Commit

Permalink
Support inheriting ManyToManyField from an abstract model (#2260)
Browse files Browse the repository at this point in the history
  • Loading branch information
flaeppe authored Jul 15, 2024
1 parent a16b9ba commit e41ffed
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 65 deletions.
36 changes: 36 additions & 0 deletions mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MemberExpr,
MypyFile,
NameExpr,
RefExpr,
StrExpr,
SymbolNode,
SymbolTable,
Expand Down Expand Up @@ -497,3 +498,38 @@ def resolve_lazy_reference(

def is_model_type(info: TypeInfo) -> bool:
return info.metaclass_type is not None and info.metaclass_type.type.has_base(fullnames.MODEL_METACLASS_FULLNAME)


def get_model_from_expression(
expr: Expression,
*,
self_model: TypeInfo,
api: Union[TypeChecker, SemanticAnalyzer],
django_context: "DjangoContext",
) -> Optional[Instance]:
"""
Attempts to resolve an expression to a 'TypeInfo' instance. Any lazy reference
argument(e.g. "<app_label>.<object_name>") to a Django model is also attempted.
"""
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo):
if is_model_type(expr.node):
return Instance(expr.node, [])

if isinstance(expr, StrExpr) and expr.value == "self":
return Instance(self_model, [])

lazy_reference = None
if isinstance(expr, StrExpr):
lazy_reference = expr.value
elif (
isinstance(expr, MemberExpr)
and isinstance(expr.expr, NameExpr)
and f"{expr.expr.fullname}.{expr.name}" == fullnames.AUTH_USER_MODEL_FULLNAME
):
lazy_reference = django_context.settings.AUTH_USER_MODEL

if lazy_reference is not None:
model_info = resolve_lazy_reference(lazy_reference, api=api, django_context=django_context, ctx=expr)
if model_info is not None:
return Instance(model_info, [])
return None
54 changes: 9 additions & 45 deletions mypy_django_plugin/transformers/manytomany.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from typing import NamedTuple, Optional, Tuple, Union
from typing import NamedTuple, Optional, Tuple

from mypy.checker import TypeChecker
from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, Node, RefExpr, StrExpr, TypeInfo
from mypy.nodes import AssignmentStmt, NameExpr, Node, TypeInfo
from mypy.plugin import FunctionContext, MethodContext
from mypy.semanal import SemanticAnalyzer
from mypy.types import Instance, ProperType, UninhabitedType
from mypy.types import Type as MypyType

Expand Down Expand Up @@ -72,24 +70,21 @@ def get_m2m_arguments(
) -> Optional[M2MArguments]:
checker = helpers.get_typechecker_api(ctx)
to_arg = ctx.args[0][0]
to_model: Optional[ProperType]
if isinstance(to_arg, StrExpr) and to_arg.value == "self":
to_model = Instance(model_info, [])
to_self = True
else:
to_model = get_model_from_expression(to_arg, api=checker, django_context=django_context)
to_self = False

to_model = helpers.get_model_from_expression(
to_arg, self_model=model_info, api=checker, django_context=django_context
)
if to_model is None:
# 'ManyToManyField()' requires the 'to' argument
return None
to = M2MTo(arg=to_arg, model=to_model, self=to_self)
to = M2MTo(arg=to_arg, model=to_model, self=to_model.type == model_info)

through = None
if len(ctx.args) > 5 and ctx.args[5]:
# 'ManyToManyField(..., through=)' was called
through_arg = ctx.args[5][0]
through_model = get_model_from_expression(through_arg, api=checker, django_context=django_context)
through_model = helpers.get_model_from_expression(
through_arg, self_model=model_info, api=checker, django_context=django_context
)
if through_model is not None:
through = M2MThrough(arg=through_arg, model=through_model)
elif not helpers.is_abstract_model(model_info):
Expand Down Expand Up @@ -119,37 +114,6 @@ def get_m2m_arguments(
return M2MArguments(to=to, through=through)


def get_model_from_expression(
expr: Expression,
*,
api: Union[TypeChecker, SemanticAnalyzer],
django_context: DjangoContext,
) -> Optional[ProperType]:
"""
Attempts to resolve an expression to a 'TypeInfo' instance. Any lazy reference
argument(e.g. "<app_label>.<object_name>") to a Django model is also attempted.
"""
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo):
if helpers.is_model_type(expr.node):
return Instance(expr.node, [])

lazy_reference = None
if isinstance(expr, StrExpr):
lazy_reference = expr.value
elif (
isinstance(expr, MemberExpr)
and isinstance(expr.expr, NameExpr)
and f"{expr.expr.fullname}.{expr.name}" == fullnames.AUTH_USER_MODEL_FULLNAME
):
lazy_reference = django_context.settings.AUTH_USER_MODEL

if lazy_reference is not None:
model_info = helpers.resolve_lazy_reference(lazy_reference, api=api, django_context=django_context, ctx=expr)
if model_info is not None:
return Instance(model_info, [])
return None


def get_related_manager_and_model(ctx: MethodContext) -> Optional[Tuple[Instance, Instance, Instance]]:
"""
Returns a 3-tuple consisting of:
Expand Down
61 changes: 42 additions & 19 deletions mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
NameExpr,
RefExpr,
Statement,
StrExpr,
SymbolTableNode,
TypeInfo,
Var,
Expand All @@ -41,7 +40,7 @@
MANAGER_METHODS_RETURNING_QUERYSET,
create_manager_info_from_from_queryset_call,
)
from mypy_django_plugin.transformers.manytomany import M2MArguments, M2MThrough, M2MTo, get_model_from_expression
from mypy_django_plugin.transformers.manytomany import M2MArguments, M2MThrough, M2MTo


class ModelClassInitializer:
Expand Down Expand Up @@ -677,12 +676,27 @@ def run(self) -> None:
continue
# Get the names of the implicit through model that will be generated
through_model_name = f"{self.model_classdef.name}_{m2m_field_name}"
self.create_through_table_class(
through_model = self.create_through_table_class(
field_name=m2m_field_name,
model_name=through_model_name,
model_fullname=f"{self.model_classdef.info.module_name}.{through_model_name}",
m2m_args=args,
)
container = self.model_classdef.info.get_containing_type_info(m2m_field_name)
if (
through_model is not None
and container is not None
and container.fullname != self.model_classdef.info.fullname
and helpers.is_abstract_model(container)
):
# ManyToManyField is inherited from an abstract parent class, so in
# order to get the to and the through model argument right we
# override the ManyToManyField attribute on the current class
helpers.add_new_sym_for_info(
self.model_classdef.info,
name=m2m_field_name,
sym_type=Instance(self.m2m_field, [args.to.model, Instance(through_model, [])]),
)
# Create a 'ManyRelatedManager' class for the processed model
self.create_many_related_manager(Instance(self.model_classdef.info, []))
if isinstance(args.to.model, Instance):
Expand Down Expand Up @@ -717,6 +731,13 @@ def fk_field(self) -> TypeInfo:
raise helpers.IncompleteDefnException()
return info

@cached_property
def m2m_field(self) -> TypeInfo:
info = self.lookup_typeinfo(fullnames.MANYTOMANY_FIELD_FULLNAME)
if info is None:
raise helpers.IncompleteDefnException()
return info

@cached_property
def manager_info(self) -> TypeInfo:
info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME)
Expand Down Expand Up @@ -746,18 +767,17 @@ def get_pk_instance(self, model: TypeInfo, /) -> Instance:

def create_through_table_class(
self, field_name: str, model_name: str, model_fullname: str, m2m_args: M2MArguments
) -> None:
if (
not isinstance(m2m_args.to.model, Instance)
) -> Optional[TypeInfo]:
if not isinstance(m2m_args.to.model, Instance):
return None
elif m2m_args.through is not None:
# Call has explicit 'through=', no need to create any implicit through table
or m2m_args.through is not None
):
return
return m2m_args.through.model.type if isinstance(m2m_args.through.model, Instance) else None

# If through model is already declared there's nothing more we should do
through_model = self.lookup_typeinfo(model_fullname)
if through_model is not None:
return
return through_model
# Declare a new, empty, implicitly generated through model class named: '<Model>_<field_name>'
through_model = self.add_new_class_for_current_module(model_name, bases=[Instance(self.model_base, [])])
# We attempt to be a bit clever here and store the generated through model's fullname in
Expand Down Expand Up @@ -823,6 +843,7 @@ def create_through_table_class(
sym_type=Instance(self.manager_info, [Instance(through_model, [])]),
is_classvar=True,
)
return through_model

def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> Optional[M2MArguments]:
"""
Expand All @@ -848,22 +869,24 @@ def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) ->
return None

# Resolve the type of the 'to' argument expression
to_model: Optional[ProperType]
if isinstance(to_arg, StrExpr) and to_arg.value == "self":
to_model = Instance(self.model_classdef.info, [])
to_self = True
else:
to_model = get_model_from_expression(to_arg, api=self.api, django_context=self.django_context)
to_self = False
to_model = helpers.get_model_from_expression(
to_arg, self_model=self.model_classdef.info, api=self.api, django_context=self.django_context
)
if to_model is None:
return None
to = M2MTo(arg=to_arg, model=to_model, self=to_self)
to = M2MTo(
arg=to_arg,
model=to_model,
self=to_model.type == self.model_classdef.info,
)

# Resolve the type of the 'through' argument expression
through_arg = look_for["through"]
through = None
if through_arg is not None:
through_model = get_model_from_expression(through_arg, api=self.api, django_context=self.django_context)
through_model = helpers.get_model_from_expression(
through_arg, self_model=self.model_classdef.info, api=self.api, django_context=self.django_context
)
if through_model is not None:
through = M2MThrough(arg=through_arg, model=through_model)

Expand Down
69 changes: 69 additions & 0 deletions tests/typecheck/fields/test_related.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1403,3 +1403,72 @@
class MyModel(models.Model):
m2m_1 = models.ManyToManyField(other_models.Other, related_name="auto_through")
m2m_2 = models.ManyToManyField(other_models.Other, related_name="custom_through", through=Through)
- case: test_m2m_from_abstract_model
main: |
from myapp.models import First, Second
reveal_type(First().others) # N: Revealed type is "myapp.models.Other_ManyRelatedManager[myapp.models.First_others]"
reveal_type(First().others.get()) # N: Revealed type is "myapp.models.Other"
reveal_type(First.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Other, myapp.models.First_others]"
reveal_type(First.others.through) # N: Revealed type is "Type[myapp.models.First_others]"
reveal_type(First.others.through.objects.get()) # N: Revealed type is "myapp.models.First_others"
reveal_type(Second().others) # N: Revealed type is "myapp.models.Other_ManyRelatedManager[myapp.models.Second_others]"
reveal_type(Second().others.get()) # N: Revealed type is "myapp.models.Other"
reveal_type(Second.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Other, myapp.models.Second_others]"
reveal_type(Second.others.through) # N: Revealed type is "Type[myapp.models.Second_others]"
reveal_type(Second.others.through.objects.get()) # N: Revealed type is "myapp.models.Second_others"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class Other(models.Model):
...
class Parent(models.Model):
others = models.ManyToManyField(Other)
class Meta:
abstract = True
class First(Parent):
...
class Second(Parent):
...
- case: test_m2m_self_on_abstract_model
main: |
from myapp.models import First, Second
reveal_type(First().others) # N: Revealed type is "myapp.models.First_ManyRelatedManager[myapp.models.First_others]"
reveal_type(First().others.get()) # N: Revealed type is "myapp.models.First"
reveal_type(First.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.First, myapp.models.First_others]"
reveal_type(First.others.through) # N: Revealed type is "Type[myapp.models.First_others]"
reveal_type(First.others.through.objects.get()) # N: Revealed type is "myapp.models.First_others"
reveal_type(Second().others) # N: Revealed type is "myapp.models.Second_ManyRelatedManager[myapp.models.Second_others]"
reveal_type(Second().others.get()) # N: Revealed type is "myapp.models.Second"
reveal_type(Second.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Second, myapp.models.Second_others]"
reveal_type(Second.others.through) # N: Revealed type is "Type[myapp.models.Second_others]"
reveal_type(Second.others.through.objects.get()) # N: Revealed type is "myapp.models.Second_others"
installed_apps:
- myapp
files:
- path: myapp/__init__.py
- path: myapp/models.py
content: |
from django.db import models
class Parent(models.Model):
others = models.ManyToManyField("self")
class Meta:
abstract = True
class First(Parent):
...
class Second(Parent):
...
2 changes: 1 addition & 1 deletion tests/typecheck/models/test_contrib_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
reveal_type(User().is_anonymous) # N: Revealed type is "Literal[False]"
reveal_type(User().groups.get()) # N: Revealed type is "django.contrib.auth.models.Group"
reveal_type(User().user_permissions.get()) # N: Revealed type is "django.contrib.auth.models.Permission"
reveal_type(User.groups) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Group, django.db.models.base.Model]"
reveal_type(User.groups) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Group, django.contrib.auth.models.User_groups]"
reveal_type(User.user_permissions) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Permission, django.db.models.base.Model]"
from django.contrib.auth.models import AnonymousUser
Expand Down

0 comments on commit e41ffed

Please sign in to comment.