Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract through table creation to separate method #2229

Merged
merged 1 commit into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions mypy_django_plugin/transformers/fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast

from django.core.exceptions import FieldDoesNotExist
from django.db.models.fields import AutoField, Field
Expand Down Expand Up @@ -114,12 +114,17 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
)


class FieldDescriptorTypes(NamedTuple):
set: MypyType
get: MypyType


def get_field_descriptor_types(
field_info: TypeInfo, *, is_set_nullable: bool, is_get_nullable: bool
) -> Tuple[MypyType, MypyType]:
) -> FieldDescriptorTypes:
set_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_set_nullable)
get_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_get_nullable)
return set_type, get_type
return FieldDescriptorTypes(set=set_type, get=get_type)


def set_descriptor_types_for_field_callback(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
Expand Down
6 changes: 3 additions & 3 deletions mypy_django_plugin/transformers/manytomany.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import NamedTuple, Optional, Tuple, Union

from mypy.checker import TypeChecker
from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, RefExpr, StrExpr, TypeInfo
from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, Node, RefExpr, StrExpr, TypeInfo
from mypy.plugin import FunctionContext, MethodContext
from mypy.semanal import SemanticAnalyzer
from mypy.types import Instance, ProperType, TypeVarType, UninhabitedType
Expand All @@ -12,12 +12,12 @@


class M2MThrough(NamedTuple):
arg: Optional[Expression]
arg: Optional[Node]
model: ProperType


class M2MTo(NamedTuple):
arg: Expression
arg: Node
model: ProperType
self: bool # ManyToManyField('self', ...)

Expand Down
210 changes: 117 additions & 93 deletions mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME
from mypy_django_plugin.transformers.fields import get_field_descriptor_types
from mypy_django_plugin.transformers.fields import FieldDescriptorTypes, get_field_descriptor_types
from mypy_django_plugin.transformers.managers import (
MANAGER_METHODS_RETURNING_QUERYSET,
create_manager_info_from_from_queryset_call,
Expand Down Expand Up @@ -644,17 +644,6 @@ def run(self) -> None:
# TODO: Create abstract through models?
return

# Start out by prefetching a couple of dependencies needed to be able to declare any
# new, implicit, through model class.
model_base = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME)
fk_field = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME)
manager_info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME)
if model_base is None or fk_field is None or manager_info is None:
raise helpers.IncompleteDefnException()

from_pk = self.get_pk_instance(self.model_classdef.info)
fk_set_type, fk_get_type = get_field_descriptor_types(fk_field, is_set_nullable=False, is_get_nullable=False)

for statement in self.statements():
# Check if this part of the class body is an assignment from a 'ManyToManyField' call
# <field> = ManyToManyField(...)
Expand All @@ -675,90 +664,16 @@ def run(self) -> None:
continue
# Resolve argument information of the 'ManyToManyField(...)' call
args = self.resolve_many_to_many_arguments(statement.rvalue, context=statement)
if (
# Ignore calls without required 'to' argument, mypy will complain
args is None
or not isinstance(args.to.model, Instance)
# Call has explicit 'through=', no need to create any implicit through table
or args.through is not None
):
# Ignore calls without required 'to' argument, mypy will complain
if args is None:
continue

# Get the names of the implicit through model that will be generated
through_model_name = f"{self.model_classdef.name}_{m2m_field_name}"
through_model_fullname = f"{self.model_classdef.info.module_name}.{through_model_name}"
# If implicit through model is already declared there's nothing more we should do
through_model = self.lookup_typeinfo(through_model_fullname)
if through_model is not None:
continue
# Declare a new, empty, implicitly generated through model class named: '<Model>_<field_name>'
through_model = self.add_new_class_for_current_module(
through_model_name, bases=[Instance(model_base, [])]
)
# We attempt to be a bit clever here and store the generated through model's fullname in
# the metadata of the class containing the 'ManyToManyField' call expression, where its
# identifier is the field name of the 'ManyToManyField'. This would allow the containing
# model to always find the implicit through model, so that it doesn't get lost.
model_metadata = helpers.get_django_metadata(self.model_classdef.info)
model_metadata.setdefault("m2m_throughs", {})
model_metadata["m2m_throughs"][m2m_field_name] = through_model.fullname
# Add a 'pk' symbol to the model class
helpers.add_new_sym_for_info(
through_model, name="pk", sym_type=self.default_pk_instance.copy_modified()
)
# Add an 'id' symbol to the model class
helpers.add_new_sym_for_info(
through_model, name="id", sym_type=self.default_pk_instance.copy_modified()
)
# Add the foreign key to the model containing the 'ManyToManyField' call:
# <containing_model> or from_<model>
from_name = (
f"from_{self.model_classdef.name.lower()}" if args.to.self else self.model_classdef.name.lower()
)
helpers.add_new_sym_for_info(
through_model,
name=from_name,
sym_type=Instance(
fk_field,
[
helpers.convert_any_to_type(fk_set_type, Instance(self.model_classdef.info, [])),
helpers.convert_any_to_type(fk_get_type, Instance(self.model_classdef.info, [])),
],
),
)
# Add the foreign key's '_id' field: <containing_model>_id or from_<model>_id
helpers.add_new_sym_for_info(through_model, name=f"{from_name}_id", sym_type=from_pk.copy_modified())
# Add the foreign key to the model on the opposite side of the relation
# i.e. the model given as 'to' argument to the 'ManyToManyField' call:
# <other_model> or to_<model>
to_name = f"to_{args.to.model.type.name.lower()}" if args.to.self else args.to.model.type.name.lower()
helpers.add_new_sym_for_info(
through_model,
name=to_name,
sym_type=Instance(
fk_field,
[
helpers.convert_any_to_type(fk_set_type, args.to.model),
helpers.convert_any_to_type(fk_get_type, args.to.model),
],
),
)
# Add the foreign key's '_id' field: <other_model>_id or to_<model>_id
other_pk = self.get_pk_instance(args.to.model.type)
helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified())
# Add a manager named 'objects'
helpers.add_new_sym_for_info(
through_model,
name="objects",
sym_type=Instance(manager_info, [Instance(through_model, [])]),
is_classvar=True,
)
# Also add manager as '_default_manager' attribute
helpers.add_new_sym_for_info(
through_model,
name="_default_manager",
sym_type=Instance(manager_info, [Instance(through_model, [])]),
is_classvar=True,
self.create_through_table_class(
field_name=m2m_field_name,
model_name=through_model_name,
model_fullname=f"{self.model_classdef.info.module_name}.{through_model_name}",
m2m_args=args,
)

@cached_property
Expand All @@ -771,6 +686,35 @@ def default_pk_instance(self) -> Instance:
list(get_field_descriptor_types(default_pk_field, is_set_nullable=True, is_get_nullable=False)),
)

@cached_property
def model_pk_instance(self) -> Instance:
return self.get_pk_instance(self.model_classdef.info)

@cached_property
def model_base(self) -> TypeInfo:
info = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME)
if info is None:
raise helpers.IncompleteDefnException()
return info

@cached_property
def fk_field(self) -> TypeInfo:
info = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME)
if info is None:
raise helpers.IncompleteDefnException()
return info

@cached_property
def manager_info(self) -> TypeInfo:
info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME)
if info is None:
raise helpers.IncompleteDefnException()
return info

@cached_property
def fk_field_types(self) -> FieldDescriptorTypes:
return get_field_descriptor_types(self.fk_field, is_set_nullable=False, is_get_nullable=False)

def get_pk_instance(self, model: TypeInfo, /) -> Instance:
"""
Get a primary key instance of provided model's type info. If primary key can't be resolved,
Expand All @@ -783,6 +727,86 @@ def get_pk_instance(self, model: TypeInfo, /) -> Instance:
return pk.type
return self.default_pk_instance

def create_through_table_class(
self, field_name: str, model_name: str, model_fullname: str, m2m_args: M2MArguments
) -> None:
if (
not isinstance(m2m_args.to.model, Instance)
# Call has explicit 'through=', no need to create any implicit through table
or m2m_args.through is not None
):
return

# If through model is already declared there's nothing more we should do
through_model = self.lookup_typeinfo(model_fullname)
if through_model is not None:
return
# Declare a new, empty, implicitly generated through model class named: '<Model>_<field_name>'
through_model = self.add_new_class_for_current_module(model_name, bases=[Instance(self.model_base, [])])
# We attempt to be a bit clever here and store the generated through model's fullname in
# the metadata of the class containing the 'ManyToManyField' call expression, where its
# identifier is the field name of the 'ManyToManyField'. This would allow the containing
# model to always find the implicit through model, so that it doesn't get lost.
model_metadata = helpers.get_django_metadata(self.model_classdef.info)
model_metadata.setdefault("m2m_throughs", {})
model_metadata["m2m_throughs"][field_name] = through_model.fullname
# Add a 'pk' symbol to the model class
helpers.add_new_sym_for_info(through_model, name="pk", sym_type=self.default_pk_instance.copy_modified())
# Add an 'id' symbol to the model class
helpers.add_new_sym_for_info(through_model, name="id", sym_type=self.default_pk_instance.copy_modified())
# Add the foreign key to the model containing the 'ManyToManyField' call:
# <containing_model> or from_<model>
from_name = f"from_{self.model_classdef.name.lower()}" if m2m_args.to.self else self.model_classdef.name.lower()
helpers.add_new_sym_for_info(
through_model,
name=from_name,
sym_type=Instance(
self.fk_field,
[
helpers.convert_any_to_type(self.fk_field_types.set, Instance(self.model_classdef.info, [])),
helpers.convert_any_to_type(self.fk_field_types.get, Instance(self.model_classdef.info, [])),
],
),
)
# Add the foreign key's '_id' field: <containing_model>_id or from_<model>_id
helpers.add_new_sym_for_info(
through_model, name=f"{from_name}_id", sym_type=self.model_pk_instance.copy_modified()
)
# Add the foreign key to the model on the opposite side of the relation
# i.e. the model given as 'to' argument to the 'ManyToManyField' call:
# <other_model> or to_<model>
to_name = (
f"to_{m2m_args.to.model.type.name.lower()}" if m2m_args.to.self else m2m_args.to.model.type.name.lower()
)
helpers.add_new_sym_for_info(
through_model,
name=to_name,
sym_type=Instance(
self.fk_field,
[
helpers.convert_any_to_type(self.fk_field_types.set, m2m_args.to.model),
helpers.convert_any_to_type(self.fk_field_types.get, m2m_args.to.model),
],
),
)
# Add the foreign key's '_id' field: <other_model>_id or to_<model>_id
other_pk = self.get_pk_instance(m2m_args.to.model.type)
helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified())
# Add a manager named 'objects'
helpers.add_new_sym_for_info(
through_model,
name="objects",
sym_type=Instance(self.manager_info, [Instance(through_model, [])]),
is_classvar=True,
)
# Also add manager as '_default_manager' attribute
helpers.add_new_sym_for_info(
through_model,
name="_default_manager",
sym_type=Instance(self.manager_info, [Instance(through_model, [])]),
is_classvar=True,
)

def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> Optional[M2MArguments]:
"""
Inspect a 'ManyToManyField(...)' call to collect argument data on any 'to' and
Expand Down
Loading