diff --git a/extended_mypy_django_plugin/plugin/_plugin.py b/extended_mypy_django_plugin/plugin/_plugin.py index 1c82699..e1268a1 100644 --- a/extended_mypy_django_plugin/plugin/_plugin.py +++ b/extended_mypy_django_plugin/plugin/_plugin.py @@ -12,6 +12,7 @@ ClassDefContext, DynamicClassDefContext, FunctionContext, + MethodContext, ) from mypy.semanal import SemanticAnalyzer from mypy.typeanal import TypeAnalyser @@ -294,13 +295,24 @@ def run(self, ctx: AttributeContext) -> MypyType: ctx, resolve_manager_method_from_instance=resolve_manager_method_from_instance ) - @_hook.hook - class get_function_hook(Hook[FunctionContext, MypyType]): + class SharedCallableHookLogic: """ - Find functions that return a concrete annotation with a type var and resolve that annotation + Shared logic for modifying the return type of methods and functions that use a concrete + annotation with a type variable. """ + def __init__(self, fullname: str, plugin: "ExtendedMypyStubs") -> None: + self.plugin = plugin + self.store = plugin.store + self.fullname = fullname + def choose(self) -> bool: + """ + Choose methods and functions either returning a type guard or have a generic + return type. + + We determine whether the return type is a concrete annotation or not in the run method. + """ if self.fullname.startswith("builtins."): return False @@ -312,9 +324,9 @@ def choose(self) -> bool: if not isinstance(call, CallableType): return False - return call.is_generic() + return bool(call.type_guard or call.is_generic()) - def run(self, ctx: FunctionContext) -> MypyType: + def run(self, ctx: MethodContext | FunctionContext) -> MypyType | None: assert isinstance(ctx.api, TypeChecker) type_checking = actions.TypeChecking( @@ -323,7 +335,40 @@ def run(self, ctx: FunctionContext) -> MypyType: lookup_info=self.plugin._lookup_info, ) - result = type_checking.modify_return_type(ctx) + return type_checking.modify_return_type(ctx) + + @_hook.hook + class get_method_hook(Hook[MethodContext, MypyType]): + def extra_init(self) -> None: + self.shared_logic = self.plugin.SharedCallableHookLogic( + fullname=self.fullname, plugin=self.plugin + ) + + def choose(self) -> bool: + return self.shared_logic.choose() + + def run(self, ctx: MethodContext) -> MypyType: + result = self.shared_logic.run(ctx) + if result is not None: + return result + + if self.super_hook is not None: + return self.super_hook(ctx) + + return ctx.default_return_type + + @_hook.hook + class get_function_hook(Hook[FunctionContext, MypyType]): + def extra_init(self) -> None: + self.shared_logic = self.plugin.SharedCallableHookLogic( + fullname=self.fullname, plugin=self.plugin + ) + + def choose(self) -> bool: + return self.shared_logic.choose() + + def run(self, ctx: FunctionContext) -> MypyType: + result = self.shared_logic.run(ctx) if result is not None: return result