Skip to content

Commit

Permalink
Formatting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
syastrov committed May 10, 2021
1 parent db604d6 commit ebf9500
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 63 deletions.
2 changes: 1 addition & 1 deletion mypy_django_plugin/django/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]:
# Strip suffix which is present if the model came from QuerySet.annotate
if is_annotated_model_fullname(fullname):
fullname, _, _ = fullname.rpartition(ANNOTATED_SUFFIX)
module, _, model_cls_name = fullname.rpartition('.')
module, _, model_cls_name = fullname.rpartition(".")
for model_cls in self.model_modules.get(module, set()):
if model_cls.__name__ == model_cls_name:
return model_cls
Expand Down
6 changes: 3 additions & 3 deletions mypy_django_plugin/lib/fullnames.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
DUMMY_SETTINGS_BASE_CLASS = "django.conf._DjangoConfLazyObject"

QUERYSET_CLASS_FULLNAME = "django.db.models.query.QuerySet"
BASE_QUERYSET_CLASS_FULLNAME = 'django.db.models.query._BaseQuerySet'
VALUES_QUERYSET_CLASS_FULLNAME = 'django.db.models.query.ValuesQuerySet'
BASE_QUERYSET_CLASS_FULLNAME = "django.db.models.query._BaseQuerySet"
VALUES_QUERYSET_CLASS_FULLNAME = "django.db.models.query.ValuesQuerySet"
BASE_MANAGER_CLASS_FULLNAME = "django.db.models.manager.BaseManager"
MANAGER_CLASS_FULLNAME = "django.db.models.manager.Manager"
RELATED_MANAGER_CLASS = "django.db.models.manager.RelatedManager"
Expand All @@ -37,4 +37,4 @@

F_EXPRESSION_FULLNAME = "django.db.models.expressions.F"

ANY_ATTR_ALLOWED_CLASS_FULLNAME = 'django._AnyAttrAllowed'
ANY_ATTR_ALLOWED_CLASS_FULLNAME = "django._AnyAttrAllowed"
11 changes: 6 additions & 5 deletions mypy_django_plugin/lib/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,15 @@ def get_current_module(api: TypeChecker) -> MypyFile:
return current_module


def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]',
extra_bases: Optional[List[Instance]] = None) -> TupleType:
def make_oneoff_named_tuple(
api: TypeChecker, name: str, fields: "OrderedDict[str, MypyType]", extra_bases: Optional[List[Instance]] = None
) -> TupleType:
current_module = get_current_module(api)
if extra_bases is None:
extra_bases = []
namedtuple_info = add_new_class_for_module(current_module, name,
bases=[api.named_generic_type('typing.NamedTuple', [])] + extra_bases,
fields=fields)
namedtuple_info = add_new_class_for_module(
current_module, name, bases=[api.named_generic_type("typing.NamedTuple", [])] + extra_bases, fields=fields
)
return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, []))


Expand Down
2 changes: 1 addition & 1 deletion mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], M
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)

if method_name == 'annotate':
if method_name == "annotate":
info = self._get_typeinfo_or_none(class_fullname)
if info and info.has_base(fullnames.BASE_QUERYSET_CLASS_FULLNAME):
return partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context)
Expand Down
126 changes: 73 additions & 53 deletions mypy_django_plugin/transformers/querysets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,20 @@
from mypy.types import Type as MypyType
from mypy.types import TypedDictType, TypeOfAny

from mypy_django_plugin.django.context import (
DjangoContext, LookupsAreUnsupported,
)
from mypy_django_plugin.django.context import DjangoContext, LookupsAreUnsupported
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.constants import ANNOTATED_SUFFIX
from mypy_django_plugin.lib.fullnames import (
ANY_ATTR_ALLOWED_CLASS_FULLNAME, VALUES_QUERYSET_CLASS_FULLNAME,
)
from mypy_django_plugin.lib.helpers import (
add_new_class_for_module, is_annotated_model_fullname,
)
from mypy_django_plugin.lib.fullnames import ANY_ATTR_ALLOWED_CLASS_FULLNAME, VALUES_QUERYSET_CLASS_FULLNAME
from mypy_django_plugin.lib.helpers import add_new_class_for_module, is_annotated_model_fullname


def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
for base_type in [queryset_type, *queryset_type.type.bases]:
if (len(base_type.args)
and isinstance(base_type.args[0], Instance)
and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)):
if (
len(base_type.args)
and isinstance(base_type.args[0], Instance)
and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)
):
return base_type.args[0]
return None

Expand All @@ -39,15 +35,21 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
assert isinstance(default_return_type, Instance)

outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class()
if (outer_model_info is None
or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)):
if outer_model_info is None or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME):
return default_return_type

return helpers.reparametrize_instance(default_return_type, [Instance(outer_model_info, [])])


def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
*, method: str, lookup: str, silent_on_error: bool = False) -> Optional[MypyType]:
def get_field_type_from_lookup(
ctx: MethodContext,
django_context: DjangoContext,
model_cls: Type[Model],
*,
method: str,
lookup: str,
silent_on_error: bool = False
) -> Optional[MypyType]:
try:
lookup_field = django_context.resolve_lookup_into_field(model_cls, lookup)
except FieldError as exc:
Expand All @@ -57,20 +59,26 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
except LookupsAreUnsupported:
return AnyType(TypeOfAny.explicit)

if ((isinstance(lookup_field, RelatedField) and lookup_field.column == lookup)
or isinstance(lookup_field, ForeignObjectRel)):
if (isinstance(lookup_field, RelatedField) and lookup_field.column == lookup) or isinstance(
lookup_field, ForeignObjectRel
):
related_model_cls = django_context.get_field_related_model_cls(lookup_field)
if related_model_cls is None:
return AnyType(TypeOfAny.from_error)
lookup_field = django_context.get_primary_key_field(related_model_cls)

field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx),
lookup_field, method=method)
field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), lookup_field, method=method)
return field_get_type


def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
is_annotated: bool, flat: bool, named: bool) -> MypyType:
def get_values_list_row_type(
ctx: MethodContext,
django_context: DjangoContext,
model_cls: Type[Model],
is_annotated: bool,
flat: bool,
named: bool,
) -> MypyType:
field_lookups = resolve_field_lookups(ctx.args[0], django_context)
if field_lookups is None:
return AnyType(TypeOfAny.from_error)
Expand All @@ -79,27 +87,30 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
if len(field_lookups) == 0:
if flat:
primary_key_field = django_context.get_primary_key_field(model_cls)
lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls,
lookup=primary_key_field.attname, method='values_list')
lookup_type = get_field_type_from_lookup(
ctx, django_context, model_cls, lookup=primary_key_field.attname, method="values_list"
)
assert lookup_type is not None
return lookup_type
elif named:
column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
column_types: "OrderedDict[str, MypyType]" = OrderedDict()
for field in django_context.get_model_fields(model_cls):
column_type = django_context.get_field_get_type(typechecker_api, field,
method='values_list')
column_type = django_context.get_field_get_type(typechecker_api, field, method="values_list")
column_types[field.attname] = column_type
if is_annotated:
# Return a NamedTuple with a fallback so that it's possible to access any field
return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types, extra_bases=[
typechecker_api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])
])
return helpers.make_oneoff_named_tuple(
typechecker_api,
"Row",
column_types,
extra_bases=[typechecker_api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])],
)
else:
return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
return helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
else:
# flat=False, named=False, all fields
if is_annotated:
return typechecker_api.named_generic_type('builtins.tuple', [AnyType(TypeOfAny.special_form)])
return typechecker_api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.special_form)])
field_lookups = []
for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname)
Expand All @@ -110,9 +121,9 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,

column_types = OrderedDict()
for field_lookup in field_lookups:
lookup_field_type = get_field_type_from_lookup(ctx, django_context, model_cls,
lookup=field_lookup, method='values_list',
silent_on_error=is_annotated)
lookup_field_type = get_field_type_from_lookup(
ctx, django_context, model_cls, lookup=field_lookup, method="values_list", silent_on_error=is_annotated
)
if lookup_field_type is None:
if is_annotated:
lookup_field_type = AnyType(TypeOfAny.from_omitted_generics)
Expand All @@ -124,7 +135,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
assert len(column_types) == 1
row_type = next(iter(column_types.values()))
elif named:
row_type = helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types)
row_type = helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types)
else:
row_type = helpers.make_tuple(typechecker_api, list(column_types.values()))

Expand All @@ -144,13 +155,13 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
if model_cls is None:
return ctx.default_return_type

flat_expr = helpers.get_call_argument_by_name(ctx, 'flat')
flat_expr = helpers.get_call_argument_by_name(ctx, "flat")
if flat_expr is not None and isinstance(flat_expr, NameExpr):
flat = helpers.parse_bool(flat_expr)
else:
flat = False

named_expr = helpers.get_call_argument_by_name(ctx, 'named')
named_expr = helpers.get_call_argument_by_name(ctx, "named")
if named_expr is not None and isinstance(named_expr, NameExpr):
named = helpers.parse_bool(named_expr)
else:
Expand All @@ -165,8 +176,9 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context:
named = named or False

is_annotated = is_annotated_model_fullname(model_type.type.fullname)
row_type = get_values_list_row_type(ctx, django_context, model_cls, is_annotated=is_annotated,
flat=flat, named=named)
row_type = get_values_list_row_type(
ctx, django_context, model_cls, is_annotated=is_annotated, flat=flat, named=named
)
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])


Expand All @@ -187,34 +199,41 @@ def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: Dj
if model_type.type.has_base(ANY_ATTR_ALLOWED_CLASS_FULLNAME):
annotated_type = model_type
else:
annotated_typeinfo = helpers.lookup_fully_qualified_typeinfo(cast(TypeChecker, api),
model_module_name + "." + type_name)
annotated_typeinfo = helpers.lookup_fully_qualified_typeinfo(
cast(TypeChecker, api), model_module_name + "." + type_name
)
if annotated_typeinfo is None:
model_module_file = api.modules[model_module_name] # type: ignore
annotated_model_type = api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])

# Create a new class in the same module as the model, with the same name as the model but with a suffix
# The class inherits from the model and an internal class which allows get/set of any attribute.
# Essentially, this is a way of making an "intersection" type between the two types.
annotated_typeinfo = add_new_class_for_module(model_module_file, type_name,
bases=[model_type, annotated_model_type, ], )
annotated_typeinfo = add_new_class_for_module(
model_module_file,
type_name,
bases=[
model_type,
annotated_model_type,
],
)
annotated_type = Instance(annotated_typeinfo, [])
if ctx.type.type.has_base(VALUES_QUERYSET_CLASS_FULLNAME):
original_row_type: MypyType = ctx.default_return_type.args[1]
row_type: MypyType = original_row_type
if isinstance(original_row_type, TypedDictType):
row_type = api.named_generic_type('builtins.dict', [api.named_generic_type('builtins.str', []),
AnyType(TypeOfAny.from_omitted_generics)])
row_type = api.named_generic_type(
"builtins.dict", [api.named_generic_type("builtins.str", []), AnyType(TypeOfAny.from_omitted_generics)]
)
elif isinstance(original_row_type, TupleType):
fallback: Instance = original_row_type.partial_fallback
if fallback is not None and fallback.type.has_base('typing.NamedTuple'):
if fallback is not None and fallback.type.has_base("typing.NamedTuple"):
# TODO: Use a NamedTuple which contains the known fields, but also
# falls back to allowing any attribute access.
row_type = AnyType(TypeOfAny.implementation_artifact)
else:
row_type = api.named_generic_type('builtins.tuple', [AnyType(TypeOfAny.from_omitted_generics)])
return helpers.reparametrize_instance(ctx.default_return_type,
[annotated_type, row_type])
row_type = api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.from_omitted_generics)])
return helpers.reparametrize_instance(ctx.default_return_type, [annotated_type, row_type])
else:
return helpers.reparametrize_instance(ctx.default_return_type, [annotated_type])

Expand Down Expand Up @@ -253,10 +272,11 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
for field in django_context.get_model_fields(model_cls):
field_lookups.append(field.attname)

column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
column_types: "OrderedDict[str, MypyType]" = OrderedDict()
for field_lookup in field_lookups:
field_lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls,
lookup=field_lookup, method='values')
field_lookup_type = get_field_type_from_lookup(
ctx, django_context, model_cls, lookup=field_lookup, method="values"
)
if field_lookup_type is None:
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])

Expand Down

0 comments on commit ebf9500

Please sign in to comment.