Skip to content

Commit

Permalink
Allow modifying method return types as well as function return types
Browse files Browse the repository at this point in the history
  • Loading branch information
delfick committed May 28, 2024
1 parent 58acb62 commit db1ef71
Showing 1 changed file with 51 additions and 6 deletions.
57 changes: 51 additions & 6 deletions extended_mypy_django_plugin/plugin/_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ClassDefContext,
DynamicClassDefContext,
FunctionContext,
MethodContext,
)
from mypy.semanal import SemanticAnalyzer
from mypy.typeanal import TypeAnalyser
Expand Down Expand Up @@ -294,13 +295,24 @@ def run(self, ctx: AttributeContext) -> MypyType:
ctx, resolve_manager_method_from_instance=resolve_manager_method_from_instance
)

@_hook.hook
class get_function_hook(Hook[FunctionContext, MypyType]):
class SharedCallableHookLogic:
"""
Find functions that return a concrete annotation with a type var and resolve that annotation
Shared logic for modifying the return type of methods and functions that use a concrete
annotation with a type variable.
"""

def __init__(self, fullname: str, plugin: "ExtendedMypyStubs") -> None:
self.plugin = plugin
self.store = plugin.store
self.fullname = fullname

def choose(self) -> bool:
"""
Choose methods and functions either returning a type guard or have a generic
return type.
We determine whether the return type is a concrete annotation or not in the run method.
"""
if self.fullname.startswith("builtins."):
return False

Expand All @@ -312,9 +324,9 @@ def choose(self) -> bool:
if not isinstance(call, CallableType):
return False

return call.is_generic()
return bool(call.type_guard or call.is_generic())

def run(self, ctx: FunctionContext) -> MypyType:
def run(self, ctx: MethodContext | FunctionContext) -> MypyType | None:
assert isinstance(ctx.api, TypeChecker)

type_checking = actions.TypeChecking(
Expand All @@ -323,7 +335,40 @@ def run(self, ctx: FunctionContext) -> MypyType:
lookup_info=self.plugin._lookup_info,
)

result = type_checking.modify_return_type(ctx)
return type_checking.modify_return_type(ctx)

@_hook.hook
class get_method_hook(Hook[MethodContext, MypyType]):
def extra_init(self) -> None:
self.shared_logic = self.plugin.SharedCallableHookLogic(
fullname=self.fullname, plugin=self.plugin
)

def choose(self) -> bool:
return self.shared_logic.choose()

def run(self, ctx: MethodContext) -> MypyType:
result = self.shared_logic.run(ctx)
if result is not None:
return result

if self.super_hook is not None:
return self.super_hook(ctx)

return ctx.default_return_type

@_hook.hook
class get_function_hook(Hook[FunctionContext, MypyType]):
def extra_init(self) -> None:
self.shared_logic = self.plugin.SharedCallableHookLogic(
fullname=self.fullname, plugin=self.plugin
)

def choose(self) -> bool:
return self.shared_logic.choose()

def run(self, ctx: FunctionContext) -> MypyType:
result = self.shared_logic.run(ctx)

if result is not None:
return result
Expand Down

0 comments on commit db1ef71

Please sign in to comment.