Skip to content

Commit

Permalink
Add type checking plugin support for functions (python#3299)
Browse files Browse the repository at this point in the history
* Add type checking plugin support for functions

The plugins allow implementing special-case logic for
inferring the return type of certain functions with
tricky signatures such as `open` in Python 3.

Include plugins for `open` and `contextlib.contextmanager`.

Some design considerations:

- The plugins have direct access to mypy internals. The
  idea is that most plugins will be included with mypy
  so mypy maintainers can update the plugins as needed.

- User-maintained plugins are currently not supported but
  could be added in the future. However, the intention is
  to not have a stable plugin API, at least initially.
  User-maintained plugins would have to track mypy internal
  API changes. Later on, we may decide to provide a more
  stable API if there seems to be a significant need. The
  preferred way would still be to keep plugins in the
  mypy repo.

* Add test case for additional special cases

* Fix handling of arguments other than simple positional ones

Also add comments and some defensive checks.
  • Loading branch information
JukkaL authored and ilevkivskyi committed May 25, 2017
1 parent a494197 commit 53879ef
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 6 deletions.
6 changes: 5 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2223,8 +2223,12 @@ def visit_decorator(self, e: Decorator) -> None:
continue
dec = self.expr_checker.accept(d)
temp = self.temp_node(sig)
fullname = None
if isinstance(d, RefExpr):
fullname = d.fullname
sig, t2 = self.expr_checker.check_call(dec, [temp],
[nodes.ARG_POS], e)
[nodes.ARG_POS], e,
callable_name=fullname)
sig = cast(FunctionLike, sig)
sig = set_callable_name(sig, e.func)
e.var.type = sig
Expand Down
43 changes: 39 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from mypy.util import split_module_names
from mypy.typevars import fill_typevars
from mypy.visitor import ExpressionVisitor
from mypy.funcplugins import get_function_plugin_callbacks, PluginCallback

from mypy import experiments

Expand Down Expand Up @@ -103,6 +104,7 @@ class ExpressionChecker(ExpressionVisitor[Type]):
type_context = None # type: List[Optional[Type]]

strfrm_checker = None # type: StringFormatterChecker
function_plugins = None # type: Dict[str, PluginCallback]

def __init__(self,
chk: 'mypy.checker.TypeChecker',
Expand All @@ -112,6 +114,7 @@ def __init__(self,
self.msg = msg
self.type_context = [None]
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)
self.function_plugins = get_function_plugin_callbacks(self.chk.options.python_version)

def visit_name_expr(self, e: NameExpr) -> Type:
"""Type check a name expression.
Expand Down Expand Up @@ -198,7 +201,11 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
isinstance(callee_type, CallableType)
and callee_type.implicit):
return self.msg.untyped_function_call(callee_type, e)
ret_type = self.check_call_expr_with_callee_type(callee_type, e)
if not isinstance(e.callee, RefExpr):
fullname = None
else:
fullname = e.callee.fullname
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname)
if isinstance(ret_type, UninhabitedType):
self.chk.binder.unreachable()
if not allow_none_return and isinstance(ret_type, NoneTyp):
Expand Down Expand Up @@ -330,21 +337,44 @@ def try_infer_partial_type(self, e: CallExpr) -> None:
list(full_item_types))
del partial_types[var]

def apply_function_plugin(self,
arg_types: List[Type],
inferred_ret_type: Type,
arg_kinds: List[int],
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: Optional[str]) -> Type:
"""Use special case logic to infer the return type for of a particular named function.
Return the inferred return type.
"""
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_types[formal].append(arg_types[actual])
formal_arg_exprs[formal].append(args[actual])
return self.function_plugins[fullname](
formal_arg_types, formal_arg_exprs, inferred_ret_type, self.chk.named_generic_type)

def check_call_expr_with_callee_type(self, callee_type: Type,
e: CallExpr) -> Type:
e: CallExpr, callable_name: Optional[str]) -> Type:
"""Type check call expression.
The given callee type overrides the type of the callee
expression.
"""
return self.check_call(callee_type, e.args, e.arg_kinds, e,
e.arg_names, callable_node=e.callee)[0]
e.arg_names, callable_node=e.callee,
callable_name=callable_name)[0]

def check_call(self, callee: Type, args: List[Expression],
arg_kinds: List[int], context: Context,
arg_names: List[str] = None,
callable_node: Expression = None,
arg_messages: MessageBuilder = None) -> Tuple[Type, Type]:
arg_messages: MessageBuilder = None,
callable_name: Optional[str] = None) -> Tuple[Type, Type]:
"""Type check a call.
Also infer type arguments if the callee is a generic function.
Expand Down Expand Up @@ -406,6 +436,11 @@ def check_call(self, callee: Type, args: List[Expression],
if callable_node:
# Store the inferred callable type.
self.chk.store_type(callable_node, callee)
if callable_name in self.function_plugins:
ret_type = self.apply_function_plugin(
arg_types, callee.ret_type, arg_kinds, formal_to_actual,
args, len(callee.arg_types), callable_name)
callee = callee.copy_modified(ret_type=ret_type)
return callee.ret_type, callee
elif isinstance(callee, Overloaded):
# Type check arguments in empty context. They will be checked again
Expand Down
81 changes: 81 additions & 0 deletions mypy/funcplugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Plugins that implement special type checking rules for individual functions.
The plugins infer better types for tricky functions such as "open".
"""

from typing import Tuple, Dict, Callable, List

from mypy.nodes import Expression, StrExpr
from mypy.types import Type, Instance, CallableType


# A callback that infers the return type of a function with a special signature.
#
# A no-op callback would just return the inferred return type, but a useful callback
# at least sometimes can infer a more precise type.
PluginCallback = Callable[
[
List[List[Type]], # List of types caller provides for each formal argument
List[List[Expression]], # Actual argument expressions for each formal argument
Type, # Return type for call inferred using the regular signature
Callable[[str, List[Type]], Type] # Callable for constructing a named instance type
],
Type # Return type inferred by the callback
]


def get_function_plugin_callbacks(python_version: Tuple[int, int]) -> Dict[str, PluginCallback]:
"""Return all available function plugins for a given Python version."""
if python_version[0] == 3:
return {
'builtins.open': open_callback,
'contextlib.contextmanager': contextmanager_callback,
}
else:
return {
'contextlib.contextmanager': contextmanager_callback,
}


def open_callback(
arg_types: List[List[Type]],
args: List[List[Expression]],
inferred_return_type: Type,
named_generic_type: Callable[[str, List[Type]], Type]) -> Type:
"""Infer a better return type for 'open'.
Infer IO[str] or IO[bytes] as the return value if the mode argument is not
given or is a literal.
"""
mode = None
if not arg_types or len(arg_types[1]) != 1:
mode = 'r'
elif isinstance(args[1][0], StrExpr):
mode = args[1][0].value
if mode is not None:
assert isinstance(inferred_return_type, Instance)
if 'b' in mode:
arg = named_generic_type('builtins.bytes', [])
else:
arg = named_generic_type('builtins.str', [])
return Instance(inferred_return_type.type, [arg])
return inferred_return_type


def contextmanager_callback(
arg_types: List[List[Type]],
args: List[List[Expression]],
inferred_return_type: Type,
named_generic_type: Callable[[str, List[Type]], Type]) -> Type:
"""Infer a better return type for 'contextlib.contextmanager'."""
# Be defensive, just in case.
if arg_types and len(arg_types[0]) == 1:
arg_type = arg_types[0][0]
if isinstance(arg_type, CallableType) and isinstance(inferred_return_type, CallableType):
# The stub signature doesn't preserve information about arguments so
# add them back here.
return inferred_return_type.copy_modified(
arg_types=arg_type.arg_types,
arg_kinds=arg_type.arg_kinds,
arg_names=arg_type.arg_names)
return inferred_return_type
52 changes: 51 additions & 1 deletion test-data/unit/pythoneval.test
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,33 @@ f.write('x')
f.write(b'x')
f.foobar()
[out]
_program.py:4: error: IO[Any] has no attribute "foobar"
_program.py:3: error: Argument 1 to "write" of "IO" has incompatible type "bytes"; expected "str"
_program.py:4: error: IO[str] has no attribute "foobar"

[case testOpenReturnTypeInference]
reveal_type(open('x'))
reveal_type(open('x', 'r'))
reveal_type(open('x', 'rb'))
mode = 'rb'
reveal_type(open('x', mode))
[out]
_program.py:1: error: Revealed type is 'typing.IO[builtins.str]'
_program.py:2: error: Revealed type is 'typing.IO[builtins.str]'
_program.py:3: error: Revealed type is 'typing.IO[builtins.bytes]'
_program.py:5: error: Revealed type is 'typing.IO[Any]'

[case testOpenReturnTypeInferenceSpecialCases]
reveal_type(open())
reveal_type(open(mode='rb', file='x'))
reveal_type(open(file='x', mode='rb'))
mode = 'rb'
reveal_type(open(mode=mode, file='r'))
[out]
_testOpenReturnTypeInferenceSpecialCases.py:1: error: Revealed type is 'typing.IO[builtins.str]'
_testOpenReturnTypeInferenceSpecialCases.py:1: error: Too few arguments for "open"
_testOpenReturnTypeInferenceSpecialCases.py:2: error: Revealed type is 'typing.IO[builtins.bytes]'
_testOpenReturnTypeInferenceSpecialCases.py:3: error: Revealed type is 'typing.IO[builtins.bytes]'
_testOpenReturnTypeInferenceSpecialCases.py:5: error: Revealed type is 'typing.IO[Any]'

[case testGenericPatterns]
from typing import Pattern
Expand Down Expand Up @@ -1286,3 +1312,27 @@ a[1] = 2, 'y'
a[:] = [('z', 3)]
[out]
_program.py:4: error: Incompatible types in assignment (expression has type "Tuple[int, str]", target has type "Tuple[str, int]")

[case testContextManager]
import contextlib
from contextlib import contextmanager
from typing import Iterator

@contextmanager
def f(x: int) -> Iterator[str]:
yield 'foo'

@contextlib.contextmanager
def g(*x: str) -> Iterator[int]:
yield 1

reveal_type(f)
reveal_type(g)

with f('') as s:
reveal_type(s)
[out]
_program.py:13: error: Revealed type is 'def (x: builtins.int) -> contextlib.GeneratorContextManager[builtins.str*]'
_program.py:14: error: Revealed type is 'def (*x: builtins.str) -> contextlib.GeneratorContextManager[builtins.int*]'
_program.py:16: error: Argument 1 to "f" has incompatible type "str"; expected "int"
_program.py:17: error: Revealed type is 'builtins.str*'

0 comments on commit 53879ef

Please sign in to comment.