Skip to content

Commit

Permalink
Extract through table creation to separate method (#2229)
Browse files Browse the repository at this point in the history
  • Loading branch information
flaeppe authored Jun 22, 2024
1 parent c693e2a commit ec37d06
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 99 deletions.
11 changes: 8 additions & 3 deletions mypy_django_plugin/transformers/fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast

from django.core.exceptions import FieldDoesNotExist
from django.db.models.fields import AutoField, Field
Expand Down Expand Up @@ -114,12 +114,17 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
)


class FieldDescriptorTypes(NamedTuple):
set: MypyType
get: MypyType


def get_field_descriptor_types(
field_info: TypeInfo, *, is_set_nullable: bool, is_get_nullable: bool
) -> Tuple[MypyType, MypyType]:
) -> FieldDescriptorTypes:
set_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_set_nullable)
get_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_get_nullable)
return set_type, get_type
return FieldDescriptorTypes(set=set_type, get=get_type)


def set_descriptor_types_for_field_callback(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
Expand Down
6 changes: 3 additions & 3 deletions mypy_django_plugin/transformers/manytomany.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import NamedTuple, Optional, Tuple, Union

from mypy.checker import TypeChecker
from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, RefExpr, StrExpr, TypeInfo
from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, Node, RefExpr, StrExpr, TypeInfo
from mypy.plugin import FunctionContext, MethodContext
from mypy.semanal import SemanticAnalyzer
from mypy.types import Instance, ProperType, TypeVarType, UninhabitedType
Expand All @@ -12,12 +12,12 @@


class M2MThrough(NamedTuple):
arg: Optional[Expression]
arg: Optional[Node]
model: ProperType


class M2MTo(NamedTuple):
arg: Expression
arg: Node
model: ProperType
self: bool # ManyToManyField('self', ...)

Expand Down
210 changes: 117 additions & 93 deletions mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME
from mypy_django_plugin.transformers.fields import get_field_descriptor_types
from mypy_django_plugin.transformers.fields import FieldDescriptorTypes, get_field_descriptor_types
from mypy_django_plugin.transformers.managers import (
MANAGER_METHODS_RETURNING_QUERYSET,
create_manager_info_from_from_queryset_call,
Expand Down Expand Up @@ -644,17 +644,6 @@ def run(self) -> None:
# TODO: Create abstract through models?
return

# Start out by prefetching a couple of dependencies needed to be able to declare any
# new, implicit, through model class.
model_base = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME)
fk_field = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME)
manager_info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME)
if model_base is None or fk_field is None or manager_info is None:
raise helpers.IncompleteDefnException()

from_pk = self.get_pk_instance(self.model_classdef.info)
fk_set_type, fk_get_type = get_field_descriptor_types(fk_field, is_set_nullable=False, is_get_nullable=False)

for statement in self.statements():
# Check if this part of the class body is an assignment from a 'ManyToManyField' call
# <field> = ManyToManyField(...)
Expand All @@ -675,90 +664,16 @@ def run(self) -> None:
continue
# Resolve argument information of the 'ManyToManyField(...)' call
args = self.resolve_many_to_many_arguments(statement.rvalue, context=statement)
if (
# Ignore calls without required 'to' argument, mypy will complain
args is None
or not isinstance(args.to.model, Instance)
# Call has explicit 'through=', no need to create any implicit through table
or args.through is not None
):
# Ignore calls without required 'to' argument, mypy will complain
if args is None:
continue

# Get the names of the implicit through model that will be generated
through_model_name = f"{self.model_classdef.name}_{m2m_field_name}"
through_model_fullname = f"{self.model_classdef.info.module_name}.{through_model_name}"
# If implicit through model is already declared there's nothing more we should do
through_model = self.lookup_typeinfo(through_model_fullname)
if through_model is not None:
continue
# Declare a new, empty, implicitly generated through model class named: '<Model>_<field_name>'
through_model = self.add_new_class_for_current_module(
through_model_name, bases=[Instance(model_base, [])]
)
# We attempt to be a bit clever here and store the generated through model's fullname in
# the metadata of the class containing the 'ManyToManyField' call expression, where its
# identifier is the field name of the 'ManyToManyField'. This would allow the containing
# model to always find the implicit through model, so that it doesn't get lost.
model_metadata = helpers.get_django_metadata(self.model_classdef.info)
model_metadata.setdefault("m2m_throughs", {})
model_metadata["m2m_throughs"][m2m_field_name] = through_model.fullname
# Add a 'pk' symbol to the model class
helpers.add_new_sym_for_info(
through_model, name="pk", sym_type=self.default_pk_instance.copy_modified()
)
# Add an 'id' symbol to the model class
helpers.add_new_sym_for_info(
through_model, name="id", sym_type=self.default_pk_instance.copy_modified()
)
# Add the foreign key to the model containing the 'ManyToManyField' call:
# <containing_model> or from_<model>
from_name = (
f"from_{self.model_classdef.name.lower()}" if args.to.self else self.model_classdef.name.lower()
)
helpers.add_new_sym_for_info(
through_model,
name=from_name,
sym_type=Instance(
fk_field,
[
helpers.convert_any_to_type(fk_set_type, Instance(self.model_classdef.info, [])),
helpers.convert_any_to_type(fk_get_type, Instance(self.model_classdef.info, [])),
],
),
)
# Add the foreign key's '_id' field: <containing_model>_id or from_<model>_id
helpers.add_new_sym_for_info(through_model, name=f"{from_name}_id", sym_type=from_pk.copy_modified())
# Add the foreign key to the model on the opposite side of the relation
# i.e. the model given as 'to' argument to the 'ManyToManyField' call:
# <other_model> or to_<model>
to_name = f"to_{args.to.model.type.name.lower()}" if args.to.self else args.to.model.type.name.lower()
helpers.add_new_sym_for_info(
through_model,
name=to_name,
sym_type=Instance(
fk_field,
[
helpers.convert_any_to_type(fk_set_type, args.to.model),
helpers.convert_any_to_type(fk_get_type, args.to.model),
],
),
)
# Add the foreign key's '_id' field: <other_model>_id or to_<model>_id
other_pk = self.get_pk_instance(args.to.model.type)
helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified())
# Add a manager named 'objects'
helpers.add_new_sym_for_info(
through_model,
name="objects",
sym_type=Instance(manager_info, [Instance(through_model, [])]),
is_classvar=True,
)
# Also add manager as '_default_manager' attribute
helpers.add_new_sym_for_info(
through_model,
name="_default_manager",
sym_type=Instance(manager_info, [Instance(through_model, [])]),
is_classvar=True,
self.create_through_table_class(
field_name=m2m_field_name,
model_name=through_model_name,
model_fullname=f"{self.model_classdef.info.module_name}.{through_model_name}",
m2m_args=args,
)

@cached_property
Expand All @@ -771,6 +686,35 @@ def default_pk_instance(self) -> Instance:
list(get_field_descriptor_types(default_pk_field, is_set_nullable=True, is_get_nullable=False)),
)

@cached_property
def model_pk_instance(self) -> Instance:
return self.get_pk_instance(self.model_classdef.info)

@cached_property
def model_base(self) -> TypeInfo:
info = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME)
if info is None:
raise helpers.IncompleteDefnException()
return info

@cached_property
def fk_field(self) -> TypeInfo:
info = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME)
if info is None:
raise helpers.IncompleteDefnException()
return info

@cached_property
def manager_info(self) -> TypeInfo:
info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME)
if info is None:
raise helpers.IncompleteDefnException()
return info

@cached_property
def fk_field_types(self) -> FieldDescriptorTypes:
return get_field_descriptor_types(self.fk_field, is_set_nullable=False, is_get_nullable=False)

def get_pk_instance(self, model: TypeInfo, /) -> Instance:
"""
Get a primary key instance of provided model's type info. If primary key can't be resolved,
Expand All @@ -783,6 +727,86 @@ def get_pk_instance(self, model: TypeInfo, /) -> Instance:
return pk.type
return self.default_pk_instance

def create_through_table_class(
self, field_name: str, model_name: str, model_fullname: str, m2m_args: M2MArguments
) -> None:
if (
not isinstance(m2m_args.to.model, Instance)
# Call has explicit 'through=', no need to create any implicit through table
or m2m_args.through is not None
):
return

# If through model is already declared there's nothing more we should do
through_model = self.lookup_typeinfo(model_fullname)
if through_model is not None:
return
# Declare a new, empty, implicitly generated through model class named: '<Model>_<field_name>'
through_model = self.add_new_class_for_current_module(model_name, bases=[Instance(self.model_base, [])])
# We attempt to be a bit clever here and store the generated through model's fullname in
# the metadata of the class containing the 'ManyToManyField' call expression, where its
# identifier is the field name of the 'ManyToManyField'. This would allow the containing
# model to always find the implicit through model, so that it doesn't get lost.
model_metadata = helpers.get_django_metadata(self.model_classdef.info)
model_metadata.setdefault("m2m_throughs", {})
model_metadata["m2m_throughs"][field_name] = through_model.fullname
# Add a 'pk' symbol to the model class
helpers.add_new_sym_for_info(through_model, name="pk", sym_type=self.default_pk_instance.copy_modified())
# Add an 'id' symbol to the model class
helpers.add_new_sym_for_info(through_model, name="id", sym_type=self.default_pk_instance.copy_modified())
# Add the foreign key to the model containing the 'ManyToManyField' call:
# <containing_model> or from_<model>
from_name = f"from_{self.model_classdef.name.lower()}" if m2m_args.to.self else self.model_classdef.name.lower()
helpers.add_new_sym_for_info(
through_model,
name=from_name,
sym_type=Instance(
self.fk_field,
[
helpers.convert_any_to_type(self.fk_field_types.set, Instance(self.model_classdef.info, [])),
helpers.convert_any_to_type(self.fk_field_types.get, Instance(self.model_classdef.info, [])),
],
),
)
# Add the foreign key's '_id' field: <containing_model>_id or from_<model>_id
helpers.add_new_sym_for_info(
through_model, name=f"{from_name}_id", sym_type=self.model_pk_instance.copy_modified()
)
# Add the foreign key to the model on the opposite side of the relation
# i.e. the model given as 'to' argument to the 'ManyToManyField' call:
# <other_model> or to_<model>
to_name = (
f"to_{m2m_args.to.model.type.name.lower()}" if m2m_args.to.self else m2m_args.to.model.type.name.lower()
)
helpers.add_new_sym_for_info(
through_model,
name=to_name,
sym_type=Instance(
self.fk_field,
[
helpers.convert_any_to_type(self.fk_field_types.set, m2m_args.to.model),
helpers.convert_any_to_type(self.fk_field_types.get, m2m_args.to.model),
],
),
)
# Add the foreign key's '_id' field: <other_model>_id or to_<model>_id
other_pk = self.get_pk_instance(m2m_args.to.model.type)
helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified())
# Add a manager named 'objects'
helpers.add_new_sym_for_info(
through_model,
name="objects",
sym_type=Instance(self.manager_info, [Instance(through_model, [])]),
is_classvar=True,
)
# Also add manager as '_default_manager' attribute
helpers.add_new_sym_for_info(
through_model,
name="_default_manager",
sym_type=Instance(self.manager_info, [Instance(through_model, [])]),
is_classvar=True,
)

def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> Optional[M2MArguments]:
"""
Inspect a 'ManyToManyField(...)' call to collect argument data on any 'to' and
Expand Down

0 comments on commit ec37d06

Please sign in to comment.