Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support additional plugin hooks #3534

Merged
merged 11 commits into from
Jun 21, 2017
38 changes: 19 additions & 19 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,9 @@ def build(sources: List[BuildSource],
lib_path.insert(0, alt_lib_path)

reports = Reports(data_dir, options.report_dirs)

source_set = BuildSourceSet(sources)
errors = Errors(options.show_error_context, options.show_column_numbers)
plugin = load_plugins(options, errors)

# Construct a build manager object to hold state during the build.
#
Expand All @@ -184,9 +185,8 @@ def build(sources: List[BuildSource],
reports=reports,
options=options,
version_id=__version__,
plugin=DefaultPlugin(options.python_version))

manager.plugin = load_custom_plugins(manager.plugin, options, manager.errors)
plugin=plugin,
errors=errors)

try:
graph = dispatch(sources, manager)
Expand Down Expand Up @@ -337,13 +337,14 @@ def import_priority(imp: ImportBase, toplevel_priority: int) -> int:
return toplevel_priority


def load_custom_plugins(default_plugin: Plugin, options: Options, errors: Errors) -> Plugin:
"""Load custom plugins if any are configured.
def load_plugins(options: Options, errors: Errors) -> Plugin:
"""Load all configured plugins.

Return a plugin that chains all custom plugins (if any) and falls
back to default_plugin.
Return a plugin that encapsulates all plugins chained together. Always
at least include the default plugin (it's last in the chain).
"""

default_plugin = DefaultPlugin(options) # type: Plugin
if not options.config_file:
return default_plugin

Expand All @@ -355,8 +356,8 @@ def plugin_error(message: str) -> None:
errors.report(line, 0, message)
errors.raise_error()

custom_plugins = [] # type: List[Plugin]
errors.set_file(options.config_file, None)
custom_plugins = []
for plugin_path in options.plugins:
# Plugin paths are relative to the config file location.
plugin_path = os.path.join(os.path.dirname(options.config_file), plugin_path)
Expand Down Expand Up @@ -395,15 +396,12 @@ def plugin_error(message: str) -> None:
'Return value of "plugin" must be a subclass of "mypy.plugin.Plugin" '
'(in {})'.format(plugin_path))
try:
custom_plugins.append(plugin_type(options.python_version))
custom_plugins.append(plugin_type(options))
except Exception:
print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__))
raise # Propagate to display traceback
if not custom_plugins:
return default_plugin
else:
# Custom plugins take precendence over built-in plugins.
return ChainedPlugin(options.python_version, custom_plugins + [default_plugin])
# Custom plugins take precedence over the default plugin.
return ChainedPlugin(options, custom_plugins + [default_plugin])


def find_config_file_line_number(path: str, section: str, setting_name: str) -> int:
Expand Down Expand Up @@ -447,12 +445,12 @@ class BuildManager:
semantic_analyzer_pass3:
Semantic analyzer, pass 3
all_types: Map {Expression: Type} collected from all modules
errors: Used for reporting all errors
options: Build options
missing_modules: Set of modules that could not be imported encountered so far
stale_modules: Set of modules that needed to be rechecked
version_id: The current mypy version (based on commit id when possible)
plugin: Active mypy plugin(s)
errors: Used for reporting all errors
"""

def __init__(self, data_dir: str,
Expand All @@ -462,10 +460,11 @@ def __init__(self, data_dir: str,
reports: Reports,
options: Options,
version_id: str,
plugin: Plugin) -> None:
plugin: Plugin,
errors: Errors) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add this to the docstring too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was mentioned already, but moved to another place where it might be easier to find.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In some future PR maybe we can remove the attributes from the docstring and instead use class-level annotations + comments for them.

self.start_time = time.time()
self.data_dir = data_dir
self.errors = Errors(options.show_error_context, options.show_column_numbers)
self.errors = errors
self.errors.set_ignore_prefix(ignore_prefix)
self.lib_path = tuple(lib_path)
self.source_set = source_set
Expand All @@ -474,8 +473,9 @@ def __init__(self, data_dir: str,
self.version_id = version_id
self.modules = {} # type: Dict[str, MypyFile]
self.missing_modules = set() # type: Set[str]
self.plugin = plugin
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JukkaL you have self.plugin = plugin twice in the init (on 476 and 485 at the end of the init). Is that on purpose?

self.semantic_analyzer = SemanticAnalyzer(self.modules, self.missing_modules,
lib_path, self.errors)
lib_path, self.errors, self.plugin)
self.modules = self.semantic_analyzer.modules
self.semantic_analyzer_pass3 = ThirdPass(self.modules, self.errors)
self.all_types = {} # type: Dict[Expression, Type]
Expand Down
4 changes: 2 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from mypy.binder import ConditionalTypeBinder, get_declaration
from mypy.meet import is_overlapping_types
from mypy.options import Options
from mypy.plugin import Plugin
from mypy.plugin import Plugin, CheckerPluginInterface

from mypy import experiments

Expand All @@ -80,7 +80,7 @@
])


class TypeChecker(NodeVisitor[None]):
class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
"""Mypy type checker.

Type check mypy source files that have been semantically analyzed.
Expand Down
36 changes: 23 additions & 13 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from mypy.util import split_module_names
from mypy.typevars import fill_typevars
from mypy.visitor import ExpressionVisitor
from mypy.plugin import Plugin, PluginContext, MethodSignatureHook
from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext
from mypy.typeanal import make_optional_type

from mypy import experiments
Expand Down Expand Up @@ -380,6 +380,13 @@ def apply_function_plugin(self,
context: Context) -> Type:
"""Use special case logic to infer the return type of a specific named function/method.

Caller must ensure that a plugin hook exists. There are two different cases:

- If object_type is None, the caller must ensure that a function hook exists
for fullname.
- If object_type is not None, the caller must ensure that a method hook exists
for fullname.

Return the inferred return type.
"""
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
Expand All @@ -392,17 +399,21 @@ def apply_function_plugin(self,
# Apply function plugin
callback = self.plugin.get_function_hook(fullname)
assert callback is not None # Assume that caller ensures this
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This still feels a bit awkward to me, but I agree that other ways to factor this out aren't much better. Maybe the docstring should just call out that there are two calling cases and that the caller should ensure the relevant callback isn't None?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated docstring

return callback(formal_arg_types, formal_arg_exprs, inferred_ret_type,
self.chk.named_generic_type)
return callback(
FunctionContext(formal_arg_types, inferred_ret_type, formal_arg_exprs,
context, self.chk))
else:
# Apply method plugin
method_callback = self.plugin.get_method_hook(fullname)
assert method_callback is not None # Assume that caller ensures this
return method_callback(object_type, formal_arg_types, formal_arg_exprs,
inferred_ret_type, self.create_plugin_context(context))

def apply_method_signature_hook(self, e: CallExpr, callee: FunctionLike, object_type: Type,
signature_hook: MethodSignatureHook) -> FunctionLike:
return method_callback(
MethodContext(object_type, formal_arg_types,
inferred_ret_type, formal_arg_exprs,
context, self.chk))

def apply_method_signature_hook(
self, e: CallExpr, callee: FunctionLike, object_type: Type,
signature_hook: Callable[[MethodSigContext], CallableType]) -> FunctionLike:
"""Apply a plugin hook that may infer a more precise signature for a method."""
if isinstance(callee, CallableType):
arg_kinds = e.arg_kinds
Expand All @@ -417,8 +428,8 @@ def apply_method_signature_hook(self, e: CallExpr, callee: FunctionLike, object_
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_exprs[formal].append(args[actual])
return signature_hook(object_type, formal_arg_exprs, callee,
self.chk.named_generic_type)
return signature_hook(
MethodSigContext(object_type, formal_arg_exprs, callee, e, self.chk))
else:
assert isinstance(callee, Overloaded)
items = []
Expand All @@ -428,9 +439,6 @@ def apply_method_signature_hook(self, e: CallExpr, callee: FunctionLike, object_
items.append(adjusted)
return Overloaded(items)

def create_plugin_context(self, context: Context) -> PluginContext:
return PluginContext(self.chk.named_generic_type, self.msg, context)

def check_call_expr_with_callee_type(self,
callee_type: Type,
e: CallExpr,
Expand Down Expand Up @@ -475,6 +483,8 @@ def check_call(self, callee: Type, args: List[Expression],
"""
arg_messages = arg_messages or self.msg
if isinstance(callee, CallableType):
if callable_name is None and callee.name:
callable_name = callee.name
if (isinstance(callable_node, RefExpr)
and callable_node.fullname in ('enum.Enum', 'enum.IntEnum',
'enum.Flag', 'enum.IntFlag')):
Expand Down
29 changes: 18 additions & 11 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mypy.expandtype import expand_type_by_instance, expand_type, freshen_function_type_vars
from mypy.infer import infer_type_arguments
from mypy.typevars import fill_typevars
from mypy.plugin import Plugin, AttributeContext
from mypy import messages
from mypy import subtypes
MYPY = False
Expand All @@ -36,8 +37,8 @@ def analyze_member_access(name: str,
not_ready_callback: Callable[[str, Context], None],
msg: MessageBuilder, *,
original_type: Type,
override_info: TypeInfo = None,
chk: 'mypy.checker.TypeChecker' = None) -> Type:
chk: 'mypy.checker.TypeChecker',
override_info: TypeInfo = None) -> Type:
"""Return the type of attribute `name` of typ.

This is a general operation that supports various different variations:
Expand Down Expand Up @@ -77,7 +78,7 @@ def analyze_member_access(name: str,
assert isinstance(method, OverloadedFuncDef)
first_item = cast(Decorator, method.items[0])
return analyze_var(name, first_item.var, typ, info, node, is_lvalue, msg,
original_type, not_ready_callback)
original_type, not_ready_callback, chk=chk)
if is_lvalue:
msg.cant_assign_to_method(node)
signature = function_type(method, builtin_type('builtins.function'))
Expand All @@ -102,7 +103,7 @@ def analyze_member_access(name: str,
# The base object has dynamic type.
return AnyType()
elif isinstance(typ, NoneTyp):
if chk and chk.should_suppress_optional_error([typ]):
if chk.should_suppress_optional_error([typ]):
return AnyType()
# The only attribute NoneType has are those it inherits from object
return analyze_member_access(name, builtin_type('builtins.object'), node, is_lvalue,
Expand Down Expand Up @@ -200,7 +201,7 @@ def analyze_member_access(name: str,
is_operator, builtin_type, not_ready_callback, msg,
original_type=original_type, chk=chk)

if chk and chk.should_suppress_optional_error([typ]):
if chk.should_suppress_optional_error([typ]):
return AnyType()
return msg.has_no_attr(original_type, typ, name, node)

Expand Down Expand Up @@ -228,7 +229,7 @@ def analyze_member_var_access(name: str, itype: Instance, info: TypeInfo,

if isinstance(v, Var):
return analyze_var(name, v, itype, info, node, is_lvalue, msg,
original_type, not_ready_callback)
original_type, not_ready_callback, chk=chk)
elif isinstance(v, FuncDef):
assert False, "Did not expect a function"
elif not v and name not in ['__getattr__', '__setattr__', '__getattribute__']:
Expand Down Expand Up @@ -270,7 +271,8 @@ def analyze_member_var_access(name: str, itype: Instance, info: TypeInfo,

def analyze_var(name: str, var: Var, itype: Instance, info: TypeInfo, node: Context,
is_lvalue: bool, msg: MessageBuilder, original_type: Type,
not_ready_callback: Callable[[str, Context], None]) -> Type:
not_ready_callback: Callable[[str, Context], None], *,
chk: 'mypy.checker.TypeChecker') -> Type:
"""Analyze access to an attribute via a Var node.

This is conceptually part of analyze_member_access and the arguments are similar.
Expand All @@ -289,6 +291,7 @@ def analyze_var(name: str, var: Var, itype: Instance, info: TypeInfo, node: Cont
msg.read_only_property(name, info, node)
if is_lvalue and var.is_classvar:
msg.cant_assign_to_classvar(name, node)
result = t
if var.is_initialized_in_class and isinstance(t, FunctionLike) and not t.is_type_obj():
if is_lvalue:
if var.is_property:
Expand All @@ -308,15 +311,19 @@ def analyze_var(name: str, var: Var, itype: Instance, info: TypeInfo, node: Cont
# A property cannot have an overloaded type => the cast
# is fine.
assert isinstance(signature, CallableType)
return signature.ret_type
result = signature.ret_type
else:
return signature
return t
result = signature
else:
if not var.is_ready:
not_ready_callback(var.name(), node)
# Implicit 'Any' type.
return AnyType()
result = AnyType()
fullname = '{}.{}'.format(var.info.fullname(), name)
hook = chk.plugin.get_attribute_hook(fullname)
if hook:
result = hook(AttributeContext(original_type, result, node, chk))
return result


def freeze_type_vars(member_type: Type) -> None:
Expand Down
Loading