Skip to content

Commit

Permalink
Add get_expression_type to CheckerPluginInterface (#15369)
Browse files Browse the repository at this point in the history
Fixes #14845.

p.s. In the issue above, I was concerned that adding this method would
create an avenue for infinite recursions (if called carelessly), but in
fact I haven't managed to induce it, e.g. FunctionSigContext has `args`
but not the call expression itself.
  • Loading branch information
ikonst authored Jun 11, 2023
1 parent e7b917e commit a108c67
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 28 deletions.
3 changes: 3 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6793,6 +6793,9 @@ def has_valid_attribute(self, typ: Type, name: str) -> bool:
)
return not watcher.has_new_errors()

def get_expression_type(self, node: Expression, type_context: Type | None = None) -> Type:
return self.expr_checker.accept(node, type_context=type_context)


class CollectArgTypeVarTypes(TypeTraverserVisitor):
"""Collects the non-nested argument types in a set."""
Expand Down
5 changes: 5 additions & 0 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,11 @@ def named_generic_type(self, name: str, args: list[Type]) -> Instance:
"""Construct an instance of a builtin type with given type arguments."""
raise NotImplementedError

@abstractmethod
def get_expression_type(self, node: Expression, type_context: Type | None = None) -> Type:
"""Checks the type of the given expression."""
raise NotImplementedError


@trait
class SemanticAnalyzerPluginInterface:
Expand Down
19 changes: 4 additions & 15 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import mypy.plugin # To avoid circular imports.
from mypy.applytype import apply_generic_arguments
from mypy.checker import TypeChecker
from mypy.errorcodes import LITERAL_REQ
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
Expand Down Expand Up @@ -1048,13 +1047,7 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
return ctx.default_signature # leave it to the type checker to complain

inst_arg = ctx.args[0][0]

# <hack>
assert isinstance(ctx.api, TypeChecker)
inst_type = ctx.api.expr_checker.accept(inst_arg)
# </hack>

inst_type = get_proper_type(inst_type)
inst_type = get_proper_type(ctx.api.get_expression_type(inst_arg))
inst_type_str = format_type_bare(inst_type, ctx.api.options)

attr_types = _get_expanded_attr_types(ctx, inst_type, inst_type, inst_type)
Expand All @@ -1074,14 +1067,10 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl

def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
"""Provide the signature for `attrs.fields`."""
if not ctx.args or len(ctx.args) != 1 or not ctx.args[0] or not ctx.args[0][0]:
if len(ctx.args) != 1 or len(ctx.args[0]) != 1:
return ctx.default_signature

# <hack>
assert isinstance(ctx.api, TypeChecker)
inst_type = ctx.api.expr_checker.accept(ctx.args[0][0])
# </hack>
proper_type = get_proper_type(inst_type)
proper_type = get_proper_type(ctx.api.get_expression_type(ctx.args[0][0]))

# fields(Any) -> Any, fields(type[Any]) -> Any
if (
Expand All @@ -1098,7 +1087,7 @@ def fields_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
inner = get_proper_type(proper_type.upper_bound)
if isinstance(inner, Instance):
# We need to work arg_types to compensate for the attrs stubs.
arg_types = [inst_type]
arg_types = [proper_type]
cls = inner.type
elif isinstance(proper_type, CallableType):
cls = proper_type.type_object()
Expand Down
5 changes: 4 additions & 1 deletion test-data/unit/check-custom-plugin.test
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,10 @@ plugins=<ROOT>/test-data/unit/plugins/descriptor.py
# flags: --config-file tmp/mypy.ini

def dynamic_signature(arg1: str) -> str: ...
reveal_type(dynamic_signature(1)) # N: Revealed type is "builtins.int"
a: int = 1
reveal_type(dynamic_signature(a)) # N: Revealed type is "builtins.int"
b: bytes = b'foo'
reveal_type(dynamic_signature(b)) # N: Revealed type is "builtins.bytes"
[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/function_sig_hook.py
Expand Down
21 changes: 9 additions & 12 deletions test-data/unit/plugins/function_sig_hook.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
from mypy.plugin import CallableType, CheckerPluginInterface, FunctionSigContext, Plugin
from mypy.types import Instance, Type
from mypy.plugin import CallableType, FunctionSigContext, Plugin


class FunctionSigPlugin(Plugin):
def get_function_signature_hook(self, fullname):
if fullname == '__main__.dynamic_signature':
return my_hook
return None

def _str_to_int(api: CheckerPluginInterface, typ: Type) -> Type:
if isinstance(typ, Instance):
if typ.type.fullname == 'builtins.str':
return api.named_generic_type('builtins.int', [])
elif typ.args:
return typ.copy_modified(args=[_str_to_int(api, t) for t in typ.args])

return typ

def my_hook(ctx: FunctionSigContext) -> CallableType:
arg1_args = ctx.args[0]
if len(arg1_args) != 1:
return ctx.default_signature
arg1_type = ctx.api.get_expression_type(arg1_args[0])
return ctx.default_signature.copy_modified(
arg_types=[_str_to_int(ctx.api, t) for t in ctx.default_signature.arg_types],
ret_type=_str_to_int(ctx.api, ctx.default_signature.ret_type),
arg_types=[arg1_type],
ret_type=arg1_type,
)


def plugin(version):
return FunctionSigPlugin

0 comments on commit a108c67

Please sign in to comment.