Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plugin to infer more precise regex match types #7803

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mypy.subtypes import is_subtype
from mypy.join import join_simple
from mypy.sametypes import is_same_type
from mypy.erasetype import remove_instance_last_known_values
from mypy.erasetype import remove_instance_transient_info
from mypy.nodes import Expression, Var, RefExpr
from mypy.literals import Key, literal, literal_hash, subkeys
from mypy.nodes import IndexExpr, MemberExpr, AssignmentExpr, NameExpr
Expand Down Expand Up @@ -259,7 +259,7 @@ def assign_type(self, expr: Expression,
restrict_any: bool = False) -> None:
# We should erase last known value in binder, because if we are using it,
# it means that the target is not final, and therefore can't hold a literal.
type = remove_instance_last_known_values(type)
type = remove_instance_transient_info(type)

type = get_proper_type(type)
declared_type = get_proper_type(declared_type)
Expand Down
8 changes: 4 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
from mypy.typevars import fill_typevars, has_no_typevars, fill_typevars_with_any
from mypy.semanal import set_callable_name, refers_to_fullname
from mypy.mro import calculate_mro, MroError
from mypy.erasetype import erase_typevars, remove_instance_last_known_values, erase_type
from mypy.erasetype import erase_typevars, remove_instance_transient_info, erase_type
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.visitor import NodeVisitor
from mypy.join import join_types
Expand Down Expand Up @@ -2188,7 +2188,7 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
if partial_types is not None:
if not self.current_node_deferred:
# Partial type can't be final, so strip any literal values.
rvalue_type = remove_instance_last_known_values(rvalue_type)
rvalue_type = remove_instance_transient_info(rvalue_type)
inferred_type = make_simplified_union(
[rvalue_type, NoneType()])
self.set_inferred_type(var, lvalue, inferred_type)
Expand Down Expand Up @@ -2270,7 +2270,7 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
if inferred:
rvalue_type = self.expr_checker.accept(rvalue)
if not inferred.is_final:
rvalue_type = remove_instance_last_known_values(rvalue_type)
rvalue_type = remove_instance_transient_info(rvalue_type)
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
self.check_assignment_to_slots(lvalue)

Expand Down Expand Up @@ -4988,7 +4988,7 @@ def named_generic_type(self, name: str, args: List[Type]) -> Instance:
the name refers to a compatible generic type.
"""
info = self.lookup_typeinfo(name)
args = [remove_instance_last_known_values(arg) for arg in args]
args = [remove_instance_transient_info(arg) for arg in args]
# TODO: assert len(args) == len(info.defn.type_vars)
return Instance(info, args)

Expand Down
4 changes: 2 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import mypy.checker
from mypy import types
from mypy.sametypes import is_same_type
from mypy.erasetype import replace_meta_vars, erase_type, remove_instance_last_known_values
from mypy.erasetype import replace_meta_vars, erase_type, remove_instance_transient_info
from mypy.maptype import map_instance_to_supertype
from mypy.messages import MessageBuilder
from mypy import message_registry
Expand Down Expand Up @@ -3334,7 +3334,7 @@ def check_lst_expr(self, items: List[Expression], fullname: str,
[(nodes.ARG_STAR if isinstance(i, StarExpr) else nodes.ARG_POS)
for i in items],
context)[0]
return remove_instance_last_known_values(out)
return remove_instance_transient_info(out)

def visit_tuple_expr(self, e: TupleExpr) -> Type:
"""Type check a tuple expression."""
Expand Down
16 changes: 9 additions & 7 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,22 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type:
return t.copy_modified(args=[a.accept(self) for a in t.args])


def remove_instance_last_known_values(t: Type) -> Type:
return t.accept(LastKnownValueEraser())

def remove_instance_transient_info(t: Type) -> Type:
"""Recursively removes any info from Instances that exist
on a per-instance basis. Currently, this means erasing the
last-known literal type and any plugin metadata.
"""
return t.accept(TransientInstanceInfoEraser())

class LastKnownValueEraser(TypeTranslator):
"""Removes the Literal[...] type that may be associated with any
Instance types."""

class TransientInstanceInfoEraser(TypeTranslator):
def visit_instance(self, t: Instance) -> Type:
if not t.last_known_value and not t.args:
if not t.last_known_value and not t.args and not t.metadata:
return t
new_t = t.copy_modified(
args=[a.accept(self) for a in t.args],
last_known_value=None,
metadata={},
)
new_t.can_be_true = t.can_be_true
new_t.can_be_false = t.can_be_false
Expand Down
14 changes: 12 additions & 2 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,18 @@ class DefaultPlugin(Plugin):

def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
from mypy.plugins import ctypes, singledispatch
from mypy.plugins import ctypes, regex, singledispatch

if fullname in ('contextlib.contextmanager', 'contextlib.asynccontextmanager'):
return contextmanager_callback
elif fullname == 'builtins.open' and self.python_version[0] == 3:
return open_callback
elif fullname == 'ctypes.Array':
return ctypes.array_constructor_callback
elif fullname == 're.compile':
return regex.re_compile_callback
elif fullname in regex.FUNCTIONS_PRODUCING_MATCH_OBJECT:
return regex.re_direct_match_callback
elif fullname == 'functools.singledispatch':
return singledispatch.create_singledispatch_function_callback
return None
Expand All @@ -55,7 +59,7 @@ def get_method_signature_hook(self, fullname: str

def get_method_hook(self, fullname: str
) -> Optional[Callable[[MethodContext], Type]]:
from mypy.plugins import ctypes, singledispatch
from mypy.plugins import ctypes, regex, singledispatch

if fullname == 'typing.Mapping.get':
return typed_dict_get_callback
Expand All @@ -77,6 +81,12 @@ def get_method_hook(self, fullname: str
return ctypes.array_iter_callback
elif fullname == 'pathlib.Path.open':
return path_open_callback
elif fullname in regex.METHODS_PRODUCING_MATCH_OBJECT:
return regex.re_get_match_callback
elif fullname == 'typing.Match.groups':
return regex.re_match_groups_callback
elif fullname in regex.METHODS_PRODUCING_GROUP:
return regex.re_match_group_callback
elif fullname == singledispatch.SINGLEDISPATCH_REGISTER_METHOD:
return singledispatch.singledispatch_register_callback
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
Expand Down
Loading