diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 6e2fa7fb8..def89f988 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -355,8 +355,28 @@ def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], return prepared_arguments, return_type +def get_method_return_type(semanal_api: SemanticAnalyzer, method_node: FuncDef) -> Optional[MypyType]: + method_type = method_node.type + if not isinstance(method_type, CallableType): + if not semanal_api.final_iteration: + semanal_api.defer() + return None + + bound_return_type = semanal_api.anal_type(method_type.ret_type, allow_placeholder=True) + + assert bound_return_type is not None + + if isinstance(bound_return_type, PlaceholderNode): + return None + return bound_return_type + + def copy_method_to_another_class( - ctx: ClassDefContext, self_type: Instance, new_method_name: str, method_node: FuncDef + ctx: ClassDefContext, + self_type: Instance, + new_method_name: str, + method_node: FuncDef, + override_return_type: Optional[MypyType] = None, ) -> None: semanal_api = get_semanal_api(ctx) if method_node.type is None: @@ -374,12 +394,8 @@ def copy_method_to_another_class( semanal_api.defer() return - arguments = [] - bound_return_type = semanal_api.anal_type(method_type.ret_type, allow_placeholder=True) - - assert bound_return_type is not None - - if isinstance(bound_return_type, PlaceholderNode): + bound_return_type = get_method_return_type(semanal_api, method_node) + if bound_return_type is None: return try: @@ -387,6 +403,7 @@ def copy_method_to_another_class( except AttributeError: original_arguments = [] + arguments = [] for arg_name, arg_type, original_argument in zip( method_type.arg_names[1:], method_type.arg_types[1:], original_arguments ): @@ -412,4 +429,7 @@ def copy_method_to_another_class( argument.set_line(original_argument) arguments.append(argument) + if override_return_type is not None: + bound_return_type = override_return_type + add_method(ctx, new_method_name, args=arguments, return_type=bound_return_type, self_type=self_type) diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py index 13bb39bc4..a0ee1dbe3 100644 --- a/mypy_django_plugin/transformers/managers.py +++ b/mypy_django_plugin/transformers/managers.py @@ -1,6 +1,6 @@ from mypy.nodes import GDEF, Decorator, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo from mypy.plugin import ClassDefContext, DynamicClassDefContext -from mypy.types import AnyType, Instance, TypeOfAny +from mypy.types import AnyType, Instance, TypeOfAny, get_proper_type from mypy_django_plugin.lib import fullnames, helpers @@ -32,12 +32,6 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte new_manager_info = semanal_api.basic_new_typeinfo( ctx.name, basetype_or_fallback=Instance(base_manager_info, [AnyType(TypeOfAny.unannotated)]), line=ctx.call.line ) - new_manager_info.line = ctx.call.line - new_manager_info.defn.line = ctx.call.line - new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type() - - current_module = semanal_api.cur_mod_node - current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True) sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname) assert sym is not None @@ -52,6 +46,14 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte derived_queryset_info = sym.node assert isinstance(derived_queryset_info, TypeInfo) + new_manager_info.line = ctx.call.line + new_manager_info.defn.line = ctx.call.line + # new_manager_info.bases.append(Instance(derived_queryset_info, [AnyType(TypeOfAny.unannotated)])) + new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type() + + current_module = semanal_api.cur_mod_node + current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True) + if len(ctx.call.args) > 1: expr = ctx.call.args[1] assert isinstance(expr, StrExpr) @@ -66,9 +68,14 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api) self_type = Instance(new_manager_info, []) + + queryset_method_names = [] + # we need to copy all methods in MRO before django.db.models.query.QuerySet for class_mro_info in derived_queryset_info.mro: if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME: + for name, sym in class_mro_info.names.items(): + queryset_method_names.append(name) break for name, sym in class_mro_info.names.items(): if isinstance(sym.node, FuncDef): @@ -80,3 +87,39 @@ def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefConte helpers.copy_method_to_another_class( class_def_context, self_type, new_method_name=name, method_node=func_node ) + + # Copy/alter all methods in common between BaseManager/QuerySet over to the new manager if their return type is + # QuerySet. Alter the return type to be the custom queryset. + for manager_mro_info in new_manager_info.mro: + if manager_mro_info.fullname != fullnames.BASE_MANAGER_CLASS_FULLNAME: + continue + + for name, sym in manager_mro_info.names.items(): + if name not in queryset_method_names: + continue + + if isinstance(sym.node, FuncDef): + func_node = sym.node + elif isinstance(sym.node, Decorator): + func_node = sym.node.func + else: + continue + bound_return_type = helpers.get_method_return_type(semanal_api, func_node) + if bound_return_type is None: + continue + + bound_return_type = get_proper_type(bound_return_type) + if not isinstance(bound_return_type, Instance): + continue + if not bound_return_type.type.has_base(fullnames.QUERYSET_CLASS_FULLNAME): + continue + + return_type = Instance(derived_queryset_info, bound_return_type.args) + + helpers.copy_method_to_another_class( + class_def_context, + self_type, + new_method_name=name, + method_node=func_node, + override_return_type=return_type, + ) diff --git a/tests/typecheck/managers/querysets/test_from_queryset.yml b/tests/typecheck/managers/querysets/test_from_queryset.yml index 53dceac33..0678909d5 100644 --- a/tests/typecheck/managers/querysets/test_from_queryset.yml +++ b/tests/typecheck/managers/querysets/test_from_queryset.yml @@ -4,6 +4,7 @@ reveal_type(MyModel().objects) # N: Revealed type is "myapp.models.MyModel_NewManager[myapp.models.MyModel]" reveal_type(MyModel().objects.get()) # N: Revealed type is "myapp.models.MyModel*" reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is "builtins.str" + MyModel.objects.filter(id=1).queryset_method() installed_apps: - myapp files: