From ebf950010af43cf8c45036d32bedb5a76571cda4 Mon Sep 17 00:00:00 2001 From: Seth Yastrov Date: Mon, 10 May 2021 08:57:42 +0200 Subject: [PATCH] Formatting fixes --- mypy_django_plugin/django/context.py | 2 +- mypy_django_plugin/lib/fullnames.py | 6 +- mypy_django_plugin/lib/helpers.py | 11 +- mypy_django_plugin/main.py | 2 +- mypy_django_plugin/transformers/querysets.py | 126 +++++++++++-------- 5 files changed, 84 insertions(+), 63 deletions(-) diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index e938f9ea6..3485e3376 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -111,7 +111,7 @@ def get_model_class_by_fullname(self, fullname: str) -> Optional[Type[Model]]: # Strip suffix which is present if the model came from QuerySet.annotate if is_annotated_model_fullname(fullname): fullname, _, _ = fullname.rpartition(ANNOTATED_SUFFIX) - module, _, model_cls_name = fullname.rpartition('.') + module, _, model_cls_name = fullname.rpartition(".") for model_cls in self.model_modules.get(module, set()): if model_cls.__name__ == model_cls_name: return model_cls diff --git a/mypy_django_plugin/lib/fullnames.py b/mypy_django_plugin/lib/fullnames.py index 883fa70c8..aadc52411 100644 --- a/mypy_django_plugin/lib/fullnames.py +++ b/mypy_django_plugin/lib/fullnames.py @@ -12,8 +12,8 @@ DUMMY_SETTINGS_BASE_CLASS = "django.conf._DjangoConfLazyObject" QUERYSET_CLASS_FULLNAME = "django.db.models.query.QuerySet" -BASE_QUERYSET_CLASS_FULLNAME = 'django.db.models.query._BaseQuerySet' -VALUES_QUERYSET_CLASS_FULLNAME = 'django.db.models.query.ValuesQuerySet' +BASE_QUERYSET_CLASS_FULLNAME = "django.db.models.query._BaseQuerySet" +VALUES_QUERYSET_CLASS_FULLNAME = "django.db.models.query.ValuesQuerySet" BASE_MANAGER_CLASS_FULLNAME = "django.db.models.manager.BaseManager" MANAGER_CLASS_FULLNAME = "django.db.models.manager.Manager" RELATED_MANAGER_CLASS = "django.db.models.manager.RelatedManager" @@ -37,4 +37,4 @@ F_EXPRESSION_FULLNAME = "django.db.models.expressions.F" -ANY_ATTR_ALLOWED_CLASS_FULLNAME = 'django._AnyAttrAllowed' +ANY_ATTR_ALLOWED_CLASS_FULLNAME = "django._AnyAttrAllowed" diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index f29f5f7ff..687cb9cf7 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -234,14 +234,15 @@ def get_current_module(api: TypeChecker) -> MypyFile: return current_module -def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]', - extra_bases: Optional[List[Instance]] = None) -> TupleType: +def make_oneoff_named_tuple( + api: TypeChecker, name: str, fields: "OrderedDict[str, MypyType]", extra_bases: Optional[List[Instance]] = None +) -> TupleType: current_module = get_current_module(api) if extra_bases is None: extra_bases = [] - namedtuple_info = add_new_class_for_module(current_module, name, - bases=[api.named_generic_type('typing.NamedTuple', [])] + extra_bases, - fields=fields) + namedtuple_info = add_new_class_for_module( + current_module, name, bases=[api.named_generic_type("typing.NamedTuple", [])] + extra_bases, fields=fields + ) return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, [])) diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index b7a8d8d10..f953e6064 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -235,7 +235,7 @@ def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], M if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context) - if method_name == 'annotate': + if method_name == "annotate": info = self._get_typeinfo_or_none(class_fullname) if info and info.has_base(fullnames.BASE_QUERYSET_CLASS_FULLNAME): return partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context) diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 396e1adf0..410b7da9a 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -12,24 +12,20 @@ from mypy.types import Type as MypyType from mypy.types import TypedDictType, TypeOfAny -from mypy_django_plugin.django.context import ( - DjangoContext, LookupsAreUnsupported, -) +from mypy_django_plugin.django.context import DjangoContext, LookupsAreUnsupported from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib.constants import ANNOTATED_SUFFIX -from mypy_django_plugin.lib.fullnames import ( - ANY_ATTR_ALLOWED_CLASS_FULLNAME, VALUES_QUERYSET_CLASS_FULLNAME, -) -from mypy_django_plugin.lib.helpers import ( - add_new_class_for_module, is_annotated_model_fullname, -) +from mypy_django_plugin.lib.fullnames import ANY_ATTR_ALLOWED_CLASS_FULLNAME, VALUES_QUERYSET_CLASS_FULLNAME +from mypy_django_plugin.lib.helpers import add_new_class_for_module, is_annotated_model_fullname def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]: for base_type in [queryset_type, *queryset_type.type.bases]: - if (len(base_type.args) - and isinstance(base_type.args[0], Instance) - and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)): + if ( + len(base_type.args) + and isinstance(base_type.args[0], Instance) + and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME) + ): return base_type.args[0] return None @@ -39,15 +35,21 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType: assert isinstance(default_return_type, Instance) outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() - if (outer_model_info is None - or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)): + if outer_model_info is None or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME): return default_return_type return helpers.reparametrize_instance(default_return_type, [Instance(outer_model_info, [])]) -def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], - *, method: str, lookup: str, silent_on_error: bool = False) -> Optional[MypyType]: +def get_field_type_from_lookup( + ctx: MethodContext, + django_context: DjangoContext, + model_cls: Type[Model], + *, + method: str, + lookup: str, + silent_on_error: bool = False +) -> Optional[MypyType]: try: lookup_field = django_context.resolve_lookup_into_field(model_cls, lookup) except FieldError as exc: @@ -57,20 +59,26 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext except LookupsAreUnsupported: return AnyType(TypeOfAny.explicit) - if ((isinstance(lookup_field, RelatedField) and lookup_field.column == lookup) - or isinstance(lookup_field, ForeignObjectRel)): + if (isinstance(lookup_field, RelatedField) and lookup_field.column == lookup) or isinstance( + lookup_field, ForeignObjectRel + ): related_model_cls = django_context.get_field_related_model_cls(lookup_field) if related_model_cls is None: return AnyType(TypeOfAny.from_error) lookup_field = django_context.get_primary_key_field(related_model_cls) - field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), - lookup_field, method=method) + field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), lookup_field, method=method) return field_get_type -def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], - is_annotated: bool, flat: bool, named: bool) -> MypyType: +def get_values_list_row_type( + ctx: MethodContext, + django_context: DjangoContext, + model_cls: Type[Model], + is_annotated: bool, + flat: bool, + named: bool, +) -> MypyType: field_lookups = resolve_field_lookups(ctx.args[0], django_context) if field_lookups is None: return AnyType(TypeOfAny.from_error) @@ -79,27 +87,30 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, if len(field_lookups) == 0: if flat: primary_key_field = django_context.get_primary_key_field(model_cls) - lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls, - lookup=primary_key_field.attname, method='values_list') + lookup_type = get_field_type_from_lookup( + ctx, django_context, model_cls, lookup=primary_key_field.attname, method="values_list" + ) assert lookup_type is not None return lookup_type elif named: - column_types: 'OrderedDict[str, MypyType]' = OrderedDict() + column_types: "OrderedDict[str, MypyType]" = OrderedDict() for field in django_context.get_model_fields(model_cls): - column_type = django_context.get_field_get_type(typechecker_api, field, - method='values_list') + column_type = django_context.get_field_get_type(typechecker_api, field, method="values_list") column_types[field.attname] = column_type if is_annotated: # Return a NamedTuple with a fallback so that it's possible to access any field - return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types, extra_bases=[ - typechecker_api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, []) - ]) + return helpers.make_oneoff_named_tuple( + typechecker_api, + "Row", + column_types, + extra_bases=[typechecker_api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, [])], + ) else: - return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) + return helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types) else: # flat=False, named=False, all fields if is_annotated: - return typechecker_api.named_generic_type('builtins.tuple', [AnyType(TypeOfAny.special_form)]) + return typechecker_api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.special_form)]) field_lookups = [] for field in django_context.get_model_fields(model_cls): field_lookups.append(field.attname) @@ -110,9 +121,9 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, column_types = OrderedDict() for field_lookup in field_lookups: - lookup_field_type = get_field_type_from_lookup(ctx, django_context, model_cls, - lookup=field_lookup, method='values_list', - silent_on_error=is_annotated) + lookup_field_type = get_field_type_from_lookup( + ctx, django_context, model_cls, lookup=field_lookup, method="values_list", silent_on_error=is_annotated + ) if lookup_field_type is None: if is_annotated: lookup_field_type = AnyType(TypeOfAny.from_omitted_generics) @@ -124,7 +135,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, assert len(column_types) == 1 row_type = next(iter(column_types.values())) elif named: - row_type = helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) + row_type = helpers.make_oneoff_named_tuple(typechecker_api, "Row", column_types) else: row_type = helpers.make_tuple(typechecker_api, list(column_types.values())) @@ -144,13 +155,13 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: if model_cls is None: return ctx.default_return_type - flat_expr = helpers.get_call_argument_by_name(ctx, 'flat') + flat_expr = helpers.get_call_argument_by_name(ctx, "flat") if flat_expr is not None and isinstance(flat_expr, NameExpr): flat = helpers.parse_bool(flat_expr) else: flat = False - named_expr = helpers.get_call_argument_by_name(ctx, 'named') + named_expr = helpers.get_call_argument_by_name(ctx, "named") if named_expr is not None and isinstance(named_expr, NameExpr): named = helpers.parse_bool(named_expr) else: @@ -165,8 +176,9 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: named = named or False is_annotated = is_annotated_model_fullname(model_type.type.fullname) - row_type = get_values_list_row_type(ctx, django_context, model_cls, is_annotated=is_annotated, - flat=flat, named=named) + row_type = get_values_list_row_type( + ctx, django_context, model_cls, is_annotated=is_annotated, flat=flat, named=named + ) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) @@ -187,8 +199,9 @@ def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: Dj if model_type.type.has_base(ANY_ATTR_ALLOWED_CLASS_FULLNAME): annotated_type = model_type else: - annotated_typeinfo = helpers.lookup_fully_qualified_typeinfo(cast(TypeChecker, api), - model_module_name + "." + type_name) + annotated_typeinfo = helpers.lookup_fully_qualified_typeinfo( + cast(TypeChecker, api), model_module_name + "." + type_name + ) if annotated_typeinfo is None: model_module_file = api.modules[model_module_name] # type: ignore annotated_model_type = api.named_generic_type(ANY_ATTR_ALLOWED_CLASS_FULLNAME, []) @@ -196,25 +209,31 @@ def extract_proper_type_queryset_annotate(ctx: MethodContext, django_context: Dj # Create a new class in the same module as the model, with the same name as the model but with a suffix # The class inherits from the model and an internal class which allows get/set of any attribute. # Essentially, this is a way of making an "intersection" type between the two types. - annotated_typeinfo = add_new_class_for_module(model_module_file, type_name, - bases=[model_type, annotated_model_type, ], ) + annotated_typeinfo = add_new_class_for_module( + model_module_file, + type_name, + bases=[ + model_type, + annotated_model_type, + ], + ) annotated_type = Instance(annotated_typeinfo, []) if ctx.type.type.has_base(VALUES_QUERYSET_CLASS_FULLNAME): original_row_type: MypyType = ctx.default_return_type.args[1] row_type: MypyType = original_row_type if isinstance(original_row_type, TypedDictType): - row_type = api.named_generic_type('builtins.dict', [api.named_generic_type('builtins.str', []), - AnyType(TypeOfAny.from_omitted_generics)]) + row_type = api.named_generic_type( + "builtins.dict", [api.named_generic_type("builtins.str", []), AnyType(TypeOfAny.from_omitted_generics)] + ) elif isinstance(original_row_type, TupleType): fallback: Instance = original_row_type.partial_fallback - if fallback is not None and fallback.type.has_base('typing.NamedTuple'): + if fallback is not None and fallback.type.has_base("typing.NamedTuple"): # TODO: Use a NamedTuple which contains the known fields, but also # falls back to allowing any attribute access. row_type = AnyType(TypeOfAny.implementation_artifact) else: - row_type = api.named_generic_type('builtins.tuple', [AnyType(TypeOfAny.from_omitted_generics)]) - return helpers.reparametrize_instance(ctx.default_return_type, - [annotated_type, row_type]) + row_type = api.named_generic_type("builtins.tuple", [AnyType(TypeOfAny.from_omitted_generics)]) + return helpers.reparametrize_instance(ctx.default_return_type, [annotated_type, row_type]) else: return helpers.reparametrize_instance(ctx.default_return_type, [annotated_type]) @@ -253,10 +272,11 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan for field in django_context.get_model_fields(model_cls): field_lookups.append(field.attname) - column_types: 'OrderedDict[str, MypyType]' = OrderedDict() + column_types: "OrderedDict[str, MypyType]" = OrderedDict() for field_lookup in field_lookups: - field_lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls, - lookup=field_lookup, method='values') + field_lookup_type = get_field_type_from_lookup( + ctx, django_context, model_cls, lookup=field_lookup, method="values" + ) if field_lookup_type is None: return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])