From e41ffedd0fcfdd04a59ebe3a69b755927e87da62 Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Mon, 15 Jul 2024 15:23:11 +0200 Subject: [PATCH] Support inheriting ManyToManyField from an abstract model (#2260) --- mypy_django_plugin/lib/helpers.py | 36 ++++++++++ mypy_django_plugin/transformers/manytomany.py | 54 +++------------ mypy_django_plugin/transformers/models.py | 61 +++++++++++----- tests/typecheck/fields/test_related.yml | 69 +++++++++++++++++++ .../typecheck/models/test_contrib_models.yml | 2 +- 5 files changed, 157 insertions(+), 65 deletions(-) diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index e69062a44..de8aeb1b4 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -18,6 +18,7 @@ MemberExpr, MypyFile, NameExpr, + RefExpr, StrExpr, SymbolNode, SymbolTable, @@ -497,3 +498,38 @@ def resolve_lazy_reference( def is_model_type(info: TypeInfo) -> bool: return info.metaclass_type is not None and info.metaclass_type.type.has_base(fullnames.MODEL_METACLASS_FULLNAME) + + +def get_model_from_expression( + expr: Expression, + *, + self_model: TypeInfo, + api: Union[TypeChecker, SemanticAnalyzer], + django_context: "DjangoContext", +) -> Optional[Instance]: + """ + Attempts to resolve an expression to a 'TypeInfo' instance. Any lazy reference + argument(e.g. ".") to a Django model is also attempted. + """ + if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo): + if is_model_type(expr.node): + return Instance(expr.node, []) + + if isinstance(expr, StrExpr) and expr.value == "self": + return Instance(self_model, []) + + lazy_reference = None + if isinstance(expr, StrExpr): + lazy_reference = expr.value + elif ( + isinstance(expr, MemberExpr) + and isinstance(expr.expr, NameExpr) + and f"{expr.expr.fullname}.{expr.name}" == fullnames.AUTH_USER_MODEL_FULLNAME + ): + lazy_reference = django_context.settings.AUTH_USER_MODEL + + if lazy_reference is not None: + model_info = resolve_lazy_reference(lazy_reference, api=api, django_context=django_context, ctx=expr) + if model_info is not None: + return Instance(model_info, []) + return None diff --git a/mypy_django_plugin/transformers/manytomany.py b/mypy_django_plugin/transformers/manytomany.py index c9fa3fc75..225d8f2c1 100644 --- a/mypy_django_plugin/transformers/manytomany.py +++ b/mypy_django_plugin/transformers/manytomany.py @@ -1,9 +1,7 @@ -from typing import NamedTuple, Optional, Tuple, Union +from typing import NamedTuple, Optional, Tuple -from mypy.checker import TypeChecker -from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, Node, RefExpr, StrExpr, TypeInfo +from mypy.nodes import AssignmentStmt, NameExpr, Node, TypeInfo from mypy.plugin import FunctionContext, MethodContext -from mypy.semanal import SemanticAnalyzer from mypy.types import Instance, ProperType, UninhabitedType from mypy.types import Type as MypyType @@ -72,24 +70,21 @@ def get_m2m_arguments( ) -> Optional[M2MArguments]: checker = helpers.get_typechecker_api(ctx) to_arg = ctx.args[0][0] - to_model: Optional[ProperType] - if isinstance(to_arg, StrExpr) and to_arg.value == "self": - to_model = Instance(model_info, []) - to_self = True - else: - to_model = get_model_from_expression(to_arg, api=checker, django_context=django_context) - to_self = False - + to_model = helpers.get_model_from_expression( + to_arg, self_model=model_info, api=checker, django_context=django_context + ) if to_model is None: # 'ManyToManyField()' requires the 'to' argument return None - to = M2MTo(arg=to_arg, model=to_model, self=to_self) + to = M2MTo(arg=to_arg, model=to_model, self=to_model.type == model_info) through = None if len(ctx.args) > 5 and ctx.args[5]: # 'ManyToManyField(..., through=)' was called through_arg = ctx.args[5][0] - through_model = get_model_from_expression(through_arg, api=checker, django_context=django_context) + through_model = helpers.get_model_from_expression( + through_arg, self_model=model_info, api=checker, django_context=django_context + ) if through_model is not None: through = M2MThrough(arg=through_arg, model=through_model) elif not helpers.is_abstract_model(model_info): @@ -119,37 +114,6 @@ def get_m2m_arguments( return M2MArguments(to=to, through=through) -def get_model_from_expression( - expr: Expression, - *, - api: Union[TypeChecker, SemanticAnalyzer], - django_context: DjangoContext, -) -> Optional[ProperType]: - """ - Attempts to resolve an expression to a 'TypeInfo' instance. Any lazy reference - argument(e.g. ".") to a Django model is also attempted. - """ - if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo): - if helpers.is_model_type(expr.node): - return Instance(expr.node, []) - - lazy_reference = None - if isinstance(expr, StrExpr): - lazy_reference = expr.value - elif ( - isinstance(expr, MemberExpr) - and isinstance(expr.expr, NameExpr) - and f"{expr.expr.fullname}.{expr.name}" == fullnames.AUTH_USER_MODEL_FULLNAME - ): - lazy_reference = django_context.settings.AUTH_USER_MODEL - - if lazy_reference is not None: - model_info = helpers.resolve_lazy_reference(lazy_reference, api=api, django_context=django_context, ctx=expr) - if model_info is not None: - return Instance(model_info, []) - return None - - def get_related_manager_and_model(ctx: MethodContext) -> Optional[Tuple[Instance, Instance, Instance]]: """ Returns a 3-tuple consisting of: diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 0888fff58..c203bc8c9 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -18,7 +18,6 @@ NameExpr, RefExpr, Statement, - StrExpr, SymbolTableNode, TypeInfo, Var, @@ -41,7 +40,7 @@ MANAGER_METHODS_RETURNING_QUERYSET, create_manager_info_from_from_queryset_call, ) -from mypy_django_plugin.transformers.manytomany import M2MArguments, M2MThrough, M2MTo, get_model_from_expression +from mypy_django_plugin.transformers.manytomany import M2MArguments, M2MThrough, M2MTo class ModelClassInitializer: @@ -677,12 +676,27 @@ def run(self) -> None: continue # Get the names of the implicit through model that will be generated through_model_name = f"{self.model_classdef.name}_{m2m_field_name}" - self.create_through_table_class( + through_model = self.create_through_table_class( field_name=m2m_field_name, model_name=through_model_name, model_fullname=f"{self.model_classdef.info.module_name}.{through_model_name}", m2m_args=args, ) + container = self.model_classdef.info.get_containing_type_info(m2m_field_name) + if ( + through_model is not None + and container is not None + and container.fullname != self.model_classdef.info.fullname + and helpers.is_abstract_model(container) + ): + # ManyToManyField is inherited from an abstract parent class, so in + # order to get the to and the through model argument right we + # override the ManyToManyField attribute on the current class + helpers.add_new_sym_for_info( + self.model_classdef.info, + name=m2m_field_name, + sym_type=Instance(self.m2m_field, [args.to.model, Instance(through_model, [])]), + ) # Create a 'ManyRelatedManager' class for the processed model self.create_many_related_manager(Instance(self.model_classdef.info, [])) if isinstance(args.to.model, Instance): @@ -717,6 +731,13 @@ def fk_field(self) -> TypeInfo: raise helpers.IncompleteDefnException() return info + @cached_property + def m2m_field(self) -> TypeInfo: + info = self.lookup_typeinfo(fullnames.MANYTOMANY_FIELD_FULLNAME) + if info is None: + raise helpers.IncompleteDefnException() + return info + @cached_property def manager_info(self) -> TypeInfo: info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME) @@ -746,18 +767,17 @@ def get_pk_instance(self, model: TypeInfo, /) -> Instance: def create_through_table_class( self, field_name: str, model_name: str, model_fullname: str, m2m_args: M2MArguments - ) -> None: - if ( - not isinstance(m2m_args.to.model, Instance) + ) -> Optional[TypeInfo]: + if not isinstance(m2m_args.to.model, Instance): + return None + elif m2m_args.through is not None: # Call has explicit 'through=', no need to create any implicit through table - or m2m_args.through is not None - ): - return + return m2m_args.through.model.type if isinstance(m2m_args.through.model, Instance) else None # If through model is already declared there's nothing more we should do through_model = self.lookup_typeinfo(model_fullname) if through_model is not None: - return + return through_model # Declare a new, empty, implicitly generated through model class named: '_' through_model = self.add_new_class_for_current_module(model_name, bases=[Instance(self.model_base, [])]) # We attempt to be a bit clever here and store the generated through model's fullname in @@ -823,6 +843,7 @@ def create_through_table_class( sym_type=Instance(self.manager_info, [Instance(through_model, [])]), is_classvar=True, ) + return through_model def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> Optional[M2MArguments]: """ @@ -848,22 +869,24 @@ def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> return None # Resolve the type of the 'to' argument expression - to_model: Optional[ProperType] - if isinstance(to_arg, StrExpr) and to_arg.value == "self": - to_model = Instance(self.model_classdef.info, []) - to_self = True - else: - to_model = get_model_from_expression(to_arg, api=self.api, django_context=self.django_context) - to_self = False + to_model = helpers.get_model_from_expression( + to_arg, self_model=self.model_classdef.info, api=self.api, django_context=self.django_context + ) if to_model is None: return None - to = M2MTo(arg=to_arg, model=to_model, self=to_self) + to = M2MTo( + arg=to_arg, + model=to_model, + self=to_model.type == self.model_classdef.info, + ) # Resolve the type of the 'through' argument expression through_arg = look_for["through"] through = None if through_arg is not None: - through_model = get_model_from_expression(through_arg, api=self.api, django_context=self.django_context) + through_model = helpers.get_model_from_expression( + through_arg, self_model=self.model_classdef.info, api=self.api, django_context=self.django_context + ) if through_model is not None: through = M2MThrough(arg=through_arg, model=through_model) diff --git a/tests/typecheck/fields/test_related.yml b/tests/typecheck/fields/test_related.yml index 843bf3a93..9a19e796d 100644 --- a/tests/typecheck/fields/test_related.yml +++ b/tests/typecheck/fields/test_related.yml @@ -1403,3 +1403,72 @@ class MyModel(models.Model): m2m_1 = models.ManyToManyField(other_models.Other, related_name="auto_through") m2m_2 = models.ManyToManyField(other_models.Other, related_name="custom_through", through=Through) + +- case: test_m2m_from_abstract_model + main: | + from myapp.models import First, Second + reveal_type(First().others) # N: Revealed type is "myapp.models.Other_ManyRelatedManager[myapp.models.First_others]" + reveal_type(First().others.get()) # N: Revealed type is "myapp.models.Other" + reveal_type(First.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Other, myapp.models.First_others]" + reveal_type(First.others.through) # N: Revealed type is "Type[myapp.models.First_others]" + reveal_type(First.others.through.objects.get()) # N: Revealed type is "myapp.models.First_others" + + reveal_type(Second().others) # N: Revealed type is "myapp.models.Other_ManyRelatedManager[myapp.models.Second_others]" + reveal_type(Second().others.get()) # N: Revealed type is "myapp.models.Other" + reveal_type(Second.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Other, myapp.models.Second_others]" + reveal_type(Second.others.through) # N: Revealed type is "Type[myapp.models.Second_others]" + reveal_type(Second.others.through.objects.get()) # N: Revealed type is "myapp.models.Second_others" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class Other(models.Model): + ... + + class Parent(models.Model): + others = models.ManyToManyField(Other) + + class Meta: + abstract = True + + class First(Parent): + ... + + class Second(Parent): + ... + +- case: test_m2m_self_on_abstract_model + main: | + from myapp.models import First, Second + reveal_type(First().others) # N: Revealed type is "myapp.models.First_ManyRelatedManager[myapp.models.First_others]" + reveal_type(First().others.get()) # N: Revealed type is "myapp.models.First" + reveal_type(First.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.First, myapp.models.First_others]" + reveal_type(First.others.through) # N: Revealed type is "Type[myapp.models.First_others]" + reveal_type(First.others.through.objects.get()) # N: Revealed type is "myapp.models.First_others" + + reveal_type(Second().others) # N: Revealed type is "myapp.models.Second_ManyRelatedManager[myapp.models.Second_others]" + reveal_type(Second().others.get()) # N: Revealed type is "myapp.models.Second" + reveal_type(Second.others) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[myapp.models.Second, myapp.models.Second_others]" + reveal_type(Second.others.through) # N: Revealed type is "Type[myapp.models.Second_others]" + reveal_type(Second.others.through.objects.get()) # N: Revealed type is "myapp.models.Second_others" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class Parent(models.Model): + others = models.ManyToManyField("self") + + class Meta: + abstract = True + + class First(Parent): + ... + + class Second(Parent): + ... diff --git a/tests/typecheck/models/test_contrib_models.yml b/tests/typecheck/models/test_contrib_models.yml index 3223059c7..ca7b14be0 100644 --- a/tests/typecheck/models/test_contrib_models.yml +++ b/tests/typecheck/models/test_contrib_models.yml @@ -15,7 +15,7 @@ reveal_type(User().is_anonymous) # N: Revealed type is "Literal[False]" reveal_type(User().groups.get()) # N: Revealed type is "django.contrib.auth.models.Group" reveal_type(User().user_permissions.get()) # N: Revealed type is "django.contrib.auth.models.Permission" - reveal_type(User.groups) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Group, django.db.models.base.Model]" + reveal_type(User.groups) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Group, django.contrib.auth.models.User_groups]" reveal_type(User.user_permissions) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Permission, django.db.models.base.Model]" from django.contrib.auth.models import AnonymousUser