From 836434a7105a5664a9ebeb9d5139ae06cd307c59 Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Sat, 22 Jun 2024 09:36:04 +0200 Subject: [PATCH] Extract through table creation to separate method --- mypy_django_plugin/transformers/fields.py | 11 +- mypy_django_plugin/transformers/manytomany.py | 6 +- mypy_django_plugin/transformers/models.py | 210 ++++++++++-------- 3 files changed, 128 insertions(+), 99 deletions(-) diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index 731a102b5..cfd280aa4 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast from django.core.exceptions import FieldDoesNotExist from django.db.models.fields import AutoField, Field @@ -114,12 +114,17 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context ) +class FieldDescriptorTypes(NamedTuple): + set: MypyType + get: MypyType + + def get_field_descriptor_types( field_info: TypeInfo, *, is_set_nullable: bool, is_get_nullable: bool -) -> Tuple[MypyType, MypyType]: +) -> FieldDescriptorTypes: set_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_set_nullable) get_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_get_nullable) - return set_type, get_type + return FieldDescriptorTypes(set=set_type, get=get_type) def set_descriptor_types_for_field_callback(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: diff --git a/mypy_django_plugin/transformers/manytomany.py b/mypy_django_plugin/transformers/manytomany.py index 9d0f7d1a0..4e0617dd3 100644 --- a/mypy_django_plugin/transformers/manytomany.py +++ b/mypy_django_plugin/transformers/manytomany.py @@ -1,7 +1,7 @@ from typing import NamedTuple, Optional, Tuple, Union from mypy.checker import TypeChecker -from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, RefExpr, StrExpr, TypeInfo +from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, Node, RefExpr, StrExpr, TypeInfo from mypy.plugin import FunctionContext, MethodContext from mypy.semanal import SemanticAnalyzer from mypy.types import Instance, ProperType, TypeVarType, UninhabitedType @@ -12,12 +12,12 @@ class M2MThrough(NamedTuple): - arg: Optional[Expression] + arg: Optional[Node] model: ProperType class M2MTo(NamedTuple): - arg: Expression + arg: Node model: ProperType self: bool # ManyToManyField('self', ...) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index f99923c40..5c17b930a 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -36,7 +36,7 @@ from mypy_django_plugin.exceptions import UnregisteredModelError from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME -from mypy_django_plugin.transformers.fields import get_field_descriptor_types +from mypy_django_plugin.transformers.fields import FieldDescriptorTypes, get_field_descriptor_types from mypy_django_plugin.transformers.managers import ( MANAGER_METHODS_RETURNING_QUERYSET, create_manager_info_from_from_queryset_call, @@ -644,17 +644,6 @@ def run(self) -> None: # TODO: Create abstract through models? return - # Start out by prefetching a couple of dependencies needed to be able to declare any - # new, implicit, through model class. - model_base = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME) - fk_field = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME) - manager_info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME) - if model_base is None or fk_field is None or manager_info is None: - raise helpers.IncompleteDefnException() - - from_pk = self.get_pk_instance(self.model_classdef.info) - fk_set_type, fk_get_type = get_field_descriptor_types(fk_field, is_set_nullable=False, is_get_nullable=False) - for statement in self.statements(): # Check if this part of the class body is an assignment from a 'ManyToManyField' call # = ManyToManyField(...) @@ -675,90 +664,16 @@ def run(self) -> None: continue # Resolve argument information of the 'ManyToManyField(...)' call args = self.resolve_many_to_many_arguments(statement.rvalue, context=statement) - if ( - # Ignore calls without required 'to' argument, mypy will complain - args is None - or not isinstance(args.to.model, Instance) - # Call has explicit 'through=', no need to create any implicit through table - or args.through is not None - ): + # Ignore calls without required 'to' argument, mypy will complain + if args is None: continue - # Get the names of the implicit through model that will be generated through_model_name = f"{self.model_classdef.name}_{m2m_field_name}" - through_model_fullname = f"{self.model_classdef.info.module_name}.{through_model_name}" - # If implicit through model is already declared there's nothing more we should do - through_model = self.lookup_typeinfo(through_model_fullname) - if through_model is not None: - continue - # Declare a new, empty, implicitly generated through model class named: '_' - through_model = self.add_new_class_for_current_module( - through_model_name, bases=[Instance(model_base, [])] - ) - # We attempt to be a bit clever here and store the generated through model's fullname in - # the metadata of the class containing the 'ManyToManyField' call expression, where its - # identifier is the field name of the 'ManyToManyField'. This would allow the containing - # model to always find the implicit through model, so that it doesn't get lost. - model_metadata = helpers.get_django_metadata(self.model_classdef.info) - model_metadata.setdefault("m2m_throughs", {}) - model_metadata["m2m_throughs"][m2m_field_name] = through_model.fullname - # Add a 'pk' symbol to the model class - helpers.add_new_sym_for_info( - through_model, name="pk", sym_type=self.default_pk_instance.copy_modified() - ) - # Add an 'id' symbol to the model class - helpers.add_new_sym_for_info( - through_model, name="id", sym_type=self.default_pk_instance.copy_modified() - ) - # Add the foreign key to the model containing the 'ManyToManyField' call: - # or from_ - from_name = ( - f"from_{self.model_classdef.name.lower()}" if args.to.self else self.model_classdef.name.lower() - ) - helpers.add_new_sym_for_info( - through_model, - name=from_name, - sym_type=Instance( - fk_field, - [ - helpers.convert_any_to_type(fk_set_type, Instance(self.model_classdef.info, [])), - helpers.convert_any_to_type(fk_get_type, Instance(self.model_classdef.info, [])), - ], - ), - ) - # Add the foreign key's '_id' field: _id or from__id - helpers.add_new_sym_for_info(through_model, name=f"{from_name}_id", sym_type=from_pk.copy_modified()) - # Add the foreign key to the model on the opposite side of the relation - # i.e. the model given as 'to' argument to the 'ManyToManyField' call: - # or to_ - to_name = f"to_{args.to.model.type.name.lower()}" if args.to.self else args.to.model.type.name.lower() - helpers.add_new_sym_for_info( - through_model, - name=to_name, - sym_type=Instance( - fk_field, - [ - helpers.convert_any_to_type(fk_set_type, args.to.model), - helpers.convert_any_to_type(fk_get_type, args.to.model), - ], - ), - ) - # Add the foreign key's '_id' field: _id or to__id - other_pk = self.get_pk_instance(args.to.model.type) - helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified()) - # Add a manager named 'objects' - helpers.add_new_sym_for_info( - through_model, - name="objects", - sym_type=Instance(manager_info, [Instance(through_model, [])]), - is_classvar=True, - ) - # Also add manager as '_default_manager' attribute - helpers.add_new_sym_for_info( - through_model, - name="_default_manager", - sym_type=Instance(manager_info, [Instance(through_model, [])]), - is_classvar=True, + self.create_through_table_class( + field_name=m2m_field_name, + model_name=through_model_name, + model_fullname=f"{self.model_classdef.info.module_name}.{through_model_name}", + m2m_args=args, ) @cached_property @@ -771,6 +686,35 @@ def default_pk_instance(self) -> Instance: list(get_field_descriptor_types(default_pk_field, is_set_nullable=True, is_get_nullable=False)), ) + @cached_property + def model_pk_instance(self) -> Instance: + return self.get_pk_instance(self.model_classdef.info) + + @cached_property + def model_base(self) -> TypeInfo: + info = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME) + if info is None: + raise helpers.IncompleteDefnException() + return info + + @cached_property + def fk_field(self) -> TypeInfo: + info = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME) + if info is None: + raise helpers.IncompleteDefnException() + return info + + @cached_property + def manager_info(self) -> TypeInfo: + info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME) + if info is None: + raise helpers.IncompleteDefnException() + return info + + @cached_property + def fk_field_types(self) -> FieldDescriptorTypes: + return get_field_descriptor_types(self.fk_field, is_set_nullable=False, is_get_nullable=False) + def get_pk_instance(self, model: TypeInfo, /) -> Instance: """ Get a primary key instance of provided model's type info. If primary key can't be resolved, @@ -783,6 +727,86 @@ def get_pk_instance(self, model: TypeInfo, /) -> Instance: return pk.type return self.default_pk_instance + def create_through_table_class( + self, field_name: str, model_name: str, model_fullname: str, m2m_args: M2MArguments + ) -> None: + if ( + not isinstance(m2m_args.to.model, Instance) + # Call has explicit 'through=', no need to create any implicit through table + or m2m_args.through is not None + ): + return + + # If through model is already declared there's nothing more we should do + through_model = self.lookup_typeinfo(model_fullname) + if through_model is not None: + return + # Declare a new, empty, implicitly generated through model class named: '_' + through_model = self.add_new_class_for_current_module(model_name, bases=[Instance(self.model_base, [])]) + # We attempt to be a bit clever here and store the generated through model's fullname in + # the metadata of the class containing the 'ManyToManyField' call expression, where its + # identifier is the field name of the 'ManyToManyField'. This would allow the containing + # model to always find the implicit through model, so that it doesn't get lost. + model_metadata = helpers.get_django_metadata(self.model_classdef.info) + model_metadata.setdefault("m2m_throughs", {}) + model_metadata["m2m_throughs"][field_name] = through_model.fullname + # Add a 'pk' symbol to the model class + helpers.add_new_sym_for_info(through_model, name="pk", sym_type=self.default_pk_instance.copy_modified()) + # Add an 'id' symbol to the model class + helpers.add_new_sym_for_info(through_model, name="id", sym_type=self.default_pk_instance.copy_modified()) + # Add the foreign key to the model containing the 'ManyToManyField' call: + # or from_ + from_name = f"from_{self.model_classdef.name.lower()}" if m2m_args.to.self else self.model_classdef.name.lower() + helpers.add_new_sym_for_info( + through_model, + name=from_name, + sym_type=Instance( + self.fk_field, + [ + helpers.convert_any_to_type(self.fk_field_types.set, Instance(self.model_classdef.info, [])), + helpers.convert_any_to_type(self.fk_field_types.get, Instance(self.model_classdef.info, [])), + ], + ), + ) + # Add the foreign key's '_id' field: _id or from__id + helpers.add_new_sym_for_info( + through_model, name=f"{from_name}_id", sym_type=self.model_pk_instance.copy_modified() + ) + # Add the foreign key to the model on the opposite side of the relation + # i.e. the model given as 'to' argument to the 'ManyToManyField' call: + # or to_ + to_name = ( + f"to_{m2m_args.to.model.type.name.lower()}" if m2m_args.to.self else m2m_args.to.model.type.name.lower() + ) + helpers.add_new_sym_for_info( + through_model, + name=to_name, + sym_type=Instance( + self.fk_field, + [ + helpers.convert_any_to_type(self.fk_field_types.set, m2m_args.to.model), + helpers.convert_any_to_type(self.fk_field_types.get, m2m_args.to.model), + ], + ), + ) + # Add the foreign key's '_id' field: _id or to__id + other_pk = self.get_pk_instance(m2m_args.to.model.type) + helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified()) + # Add a manager named 'objects' + helpers.add_new_sym_for_info( + through_model, + name="objects", + sym_type=Instance(self.manager_info, [Instance(through_model, [])]), + is_classvar=True, + ) + # Also add manager as '_default_manager' attribute + helpers.add_new_sym_for_info( + through_model, + name="_default_manager", + sym_type=Instance(self.manager_info, [Instance(through_model, [])]), + is_classvar=True, + ) + def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> Optional[M2MArguments]: """ Inspect a 'ManyToManyField(...)' call to collect argument data on any 'to' and