From 54dfad8ab575f7db663ca610bdd94461750a03e2 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Sat, 24 Aug 2019 15:35:05 -0700 Subject: [PATCH 1/7] Add plugin to infer more precise regex match types This pull request adds a plugin to make mypy infer more precise types when grabbing regex groups: the plugin will when possible analyze original regex to deduce whether a given group is required or not. ``` from typing_extensions import Final, Literal import re pattern: Final = re.compile("(a)(b)*") match: Final = pattern.match("") if match: reveal_type(match.groups()) # Revealed type is Tuple[str, Optional[str]] reveal_type(match.group(0)) # Revealed type is str reveal_type(match.group(1)) # Revealed type is str reveal_type(match.group(2)) # Revealed type is Optional[str] index: int reveal_type(match.group(index)) # Revealed type is Optional[str] # Error: Regex has 3 total groups, given group number 5 is too big match.group(5) ``` To track this information, I added in an optional 'metadata' dict field to the Instance class, similar to the metadata dict for plugins in TypeInfos. We skip serializing this dict if it does not contain any data. A limitation of this plugin is that both the pattern and the match variables must be declared to be final. Otherwise, we just default to using whatever types are defined in typeshed. This is because we set and erase the metadata field in exactly the same way we set and erase the `last_known_value` field in Instances: both kinds of info are "transient" and are unsafe to keep around if the variable reference is mutable. This limitation *does* end up limiting the usefulness of this plugin to some degree: it won't support common patterns like the below, since variables aren't allowed to be declared final inside loops: ``` for line in file: match = pattern.match(line) if match: ... ``` Possibly we can remove this limitation by making mypy less aggressive about removing this transient info by tracking the "lifetime" of this sort of data in some way? This pull request should mostly address https://github.com/python/mypy/issues/7363, though it's unclear if it really fully resolves it: we might want to do something about the limitation described above and re-tune typeshed first. The other mostly unrelated change this PR makes is to refactor some of the helper functions in checker.py into typeops.py so I could use them more cleanly in the plugin. --- mypy/binder.py | 4 +- mypy/checker.py | 101 +------- mypy/erasetype.py | 17 +- mypy/plugins/default.py | 12 + mypy/plugins/regex.py | 337 ++++++++++++++++++++++++++ mypy/test/testplugin.py | 93 +++++++ mypy/typeops.py | 95 +++++++- mypy/types.py | 16 +- test-data/unit/fixtures/floatdict.pyi | 2 +- test-data/unit/pythoneval.test | 130 ++++++++++ 10 files changed, 698 insertions(+), 109 deletions(-) create mode 100644 mypy/plugins/regex.py create mode 100644 mypy/test/testplugin.py diff --git a/mypy/binder.py b/mypy/binder.py index 109fef25ce6a..309a6c658840 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -10,7 +10,7 @@ from mypy.subtypes import is_subtype from mypy.join import join_simple from mypy.sametypes import is_same_type -from mypy.erasetype import remove_instance_last_known_values +from mypy.erasetype import remove_instance_transient_info from mypy.nodes import Expression, Var, RefExpr from mypy.literals import Key, literal, literal_hash, subkeys from mypy.nodes import IndexExpr, MemberExpr, NameExpr @@ -251,7 +251,7 @@ def assign_type(self, expr: Expression, restrict_any: bool = False) -> None: # We should erase last known value in binder, because if we are using it, # it means that the target is not final, and therefore can't hold a literal. - type = remove_instance_last_known_values(type) + type = remove_instance_transient_info(type) type = get_proper_type(type) declared_type = get_proper_type(declared_type) diff --git a/mypy/checker.py b/mypy/checker.py index 7cc1b04b5d91..88467c1ca3a3 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2,7 +2,6 @@ import itertools import fnmatch -import sys from contextlib import contextmanager from typing import ( @@ -50,7 +49,8 @@ from mypy.typeops import ( map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, erase_def_to_union_or_bound, erase_to_union_or_bound, - true_only, false_only, function_type, + true_only, false_only, function_type, is_singleton_type, + try_expanding_enum_to_union, coerce_to_literal, ) from mypy import message_registry from mypy.subtypes import ( @@ -63,7 +63,7 @@ from mypy.typevars import fill_typevars, has_no_typevars, fill_typevars_with_any from mypy.semanal import set_callable_name, refers_to_fullname from mypy.mro import calculate_mro -from mypy.erasetype import erase_typevars, remove_instance_last_known_values +from mypy.erasetype import erase_typevars, remove_instance_transient_info from mypy.expandtype import expand_type, expand_type_by_instance from mypy.visitor import NodeVisitor from mypy.join import join_types @@ -2069,7 +2069,7 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type if partial_types is not None: if not self.current_node_deferred: # Partial type can't be final, so strip any literal values. - rvalue_type = remove_instance_last_known_values(rvalue_type) + rvalue_type = remove_instance_transient_info(rvalue_type) inferred_type = make_simplified_union( [rvalue_type, NoneType()]) self.set_inferred_type(var, lvalue, inferred_type) @@ -2126,7 +2126,7 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type if inferred: rvalue_type = self.expr_checker.accept(rvalue) if not inferred.is_final: - rvalue_type = remove_instance_last_known_values(rvalue_type) + rvalue_type = remove_instance_transient_info(rvalue_type) self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue) def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type], @@ -4753,97 +4753,6 @@ def is_private(node_name: str) -> bool: return node_name.startswith('__') and not node_name.endswith('__') -def get_enum_values(typ: Instance) -> List[str]: - """Return the list of values for an Enum.""" - return [name for name, sym in typ.type.names.items() if isinstance(sym.node, Var)] - - -def is_singleton_type(typ: Type) -> bool: - """Returns 'true' if this type is a "singleton type" -- if there exists - exactly only one runtime value associated with this type. - - That is, given two values 'a' and 'b' that have the same type 't', - 'is_singleton_type(t)' returns True if and only if the expression 'a is b' is - always true. - - Currently, this returns True when given NoneTypes, enum LiteralTypes and - enum types with a single value. - - Note that other kinds of LiteralTypes cannot count as singleton types. For - example, suppose we do 'a = 100000 + 1' and 'b = 100001'. It is not guaranteed - that 'a is b' will always be true -- some implementations of Python will end up - constructing two distinct instances of 100001. - """ - typ = get_proper_type(typ) - # TODO: Also make this return True if the type is a bool LiteralType. - # Also make this return True if the type corresponds to ... (ellipsis) or NotImplemented? - return ( - isinstance(typ, NoneType) or (isinstance(typ, LiteralType) and typ.is_enum_literal()) - or (isinstance(typ, Instance) and typ.type.is_enum and len(get_enum_values(typ)) == 1) - ) - - -def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> ProperType: - """Attempts to recursively expand any enum Instances with the given target_fullname - into a Union of all of its component LiteralTypes. - - For example, if we have: - - class Color(Enum): - RED = 1 - BLUE = 2 - YELLOW = 3 - - class Status(Enum): - SUCCESS = 1 - FAILURE = 2 - UNKNOWN = 3 - - ...and if we call `try_expanding_enum_to_union(Union[Color, Status], 'module.Color')`, - this function will return Literal[Color.RED, Color.BLUE, Color.YELLOW, Status]. - """ - typ = get_proper_type(typ) - - if isinstance(typ, UnionType): - items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items] - return make_simplified_union(items) - elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname: - new_items = [] - for name, symbol in typ.type.names.items(): - if not isinstance(symbol.node, Var): - continue - new_items.append(LiteralType(name, typ)) - # SymbolTables are really just dicts, and dicts are guaranteed to preserve - # insertion order only starting with Python 3.7. So, we sort these for older - # versions of Python to help make tests deterministic. - # - # We could probably skip the sort for Python 3.6 since people probably run mypy - # only using CPython, but we might as well for the sake of full correctness. - if sys.version_info < (3, 7): - new_items.sort(key=lambda lit: lit.value) - return make_simplified_union(new_items) - else: - return typ - - -def coerce_to_literal(typ: Type) -> ProperType: - """Recursively converts any Instances that have a last_known_value or are - instances of enum types with a single value into the corresponding LiteralType. - """ - typ = get_proper_type(typ) - if isinstance(typ, UnionType): - new_items = [coerce_to_literal(item) for item in typ.items] - return make_simplified_union(new_items) - elif isinstance(typ, Instance): - if typ.last_known_value: - return typ.last_known_value - elif typ.type.is_enum: - enum_values = get_enum_values(typ) - if len(enum_values) == 1: - return LiteralType(value=enum_values[0], fallback=typ) - return typ - - def has_bool_item(typ: ProperType) -> bool: """Return True if type is 'bool' or a union with a 'bool' item.""" if is_named_instance(typ, 'builtins.bool'): diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 62580dfb0f12..2f3f89d29990 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -123,15 +123,16 @@ def visit_type_var(self, t: TypeVarType) -> Type: return t -def remove_instance_last_known_values(t: Type) -> Type: - return t.accept(LastKnownValueEraser()) - +def remove_instance_transient_info(t: Type) -> Type: + """Recursively removes any info from Instances that exist + on a per-instance basis. Currently, this means erasing the + last-known literal type and any plugin metadata. + """ + return t.accept(TransientInstanceInfoEraser()) -class LastKnownValueEraser(TypeTranslator): - """Removes the Literal[...] type that may be associated with any - Instance types.""" +class TransientInstanceInfoEraser(TypeTranslator): def visit_instance(self, t: Instance) -> Type: - if t.last_known_value: - return t.copy_modified(last_known_value=None) + if t.last_known_value or t.metadata: + return t.copy_modified(last_known_value=None, metadata={}) return t diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index ca9d3baad3bb..2271b94ed362 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -22,6 +22,7 @@ class DefaultPlugin(Plugin): def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: from mypy.plugins import ctypes + from mypy.plugins import regex if fullname == 'contextlib.contextmanager': return contextmanager_callback @@ -29,6 +30,10 @@ def get_function_hook(self, fullname: str return open_callback elif fullname == 'ctypes.Array': return ctypes.array_constructor_callback + elif fullname == 're.compile': + return regex.re_compile_callback + elif fullname in regex.FUNCTIONS_PRODUCING_MATCH_OBJECT: + return regex.re_direct_match_callback return None def get_method_signature_hook(self, fullname: str @@ -52,6 +57,7 @@ def get_method_signature_hook(self, fullname: str def get_method_hook(self, fullname: str ) -> Optional[Callable[[MethodContext], Type]]: from mypy.plugins import ctypes + from mypy.plugins import regex if fullname == 'typing.Mapping.get': return typed_dict_get_callback @@ -69,6 +75,12 @@ def get_method_hook(self, fullname: str return ctypes.array_iter_callback elif fullname == 'pathlib.Path.open': return path_open_callback + elif fullname in regex.METHODS_PRODUCING_MATCH_OBJECT: + return regex.re_get_match_callback + elif fullname == 'typing.Match.groups': + return regex.re_match_groups_callback + elif fullname in regex.METHODS_PRODUCING_GROUP: + return regex.re_match_group_callback return None def get_attribute_hook(self, fullname: str diff --git a/mypy/plugins/regex.py b/mypy/plugins/regex.py new file mode 100644 index 000000000000..c710e262f42c --- /dev/null +++ b/mypy/plugins/regex.py @@ -0,0 +1,337 @@ +""" +A plugin for analyzing regexes to determine how many groups a +regex can contain and whether those groups are always matched or not. +For example: + + pattern: Final = re.compile("(foo)(bar)?") + match: Final = pattern.match(input_text) + if match: + reveal_type(match.groups()) + +Without the plugin, the best we can really do is determine revealed +type is either Sequence[str] or Tuple[str, ...]. But with this plugin, +we can obtain a more precise type of Tuple[str, Optionl[str]]. We were +able to deduce th first group is mandatory and the second optional. + +Broadly, this plugin works by using the underlying builtin regex +parsing engine to obtain the regex AST. We can then crawl this AST +to obtain the mandatory groups, total number of groups, and any +named groups. + +We then inject this obtained data into the Pattern or Match objects +into a "metadata" field on a per-instance basis. + +Note that while we parse the regex, we at no point will ever actually +try matching anything against it. +""" + +from typing import Union, Iterator, Tuple, List, Any, Optional, Dict +from typing_extensions import Final + +from mypy.types import ( + Type, ProperType, Instance, NoneType, LiteralType, + TupleType, remove_optional, +) +from mypy.typeops import make_simplified_union, coerce_to_literal, get_proper_type +import mypy.plugin # To avoid circular imports. + +from sre_parse import parse, SubPattern +from sre_constants import ( + SUBPATTERN, MIN_REPEAT, MAX_REPEAT, GROUPREF_EXISTS, BRANCH, + error as SreError, _NamedIntConstant as NIC, +) + +STR_LIKE_TYPES = { + 'builtins.unicode', + 'builtins.str', + 'builtins.bytes', +} # type: Final + +FUNCTIONS_PRODUCING_MATCH_OBJECT = { + 're.search', + 're.match', + 're.fullmatch', +} + +METHODS_PRODUCING_MATCH_OBJECT = { + 'typing.Pattern.search', + 'typing.Pattern.match', + 'typing.Pattern.fullmatch', +} # type: Final + +METHODS_PRODUCING_GROUP = { + 'typing.Match.group', + 'typing.Match.__getitem__', +} + +OBJECTS_SUPPORTING_REGEX_METADATA = { + 'typing.Pattern', + 'typing.Match', +} + + +class RegexPluginException(Exception): + def __init__(self, msg: str) -> None: + super().__init__(msg) + self.msg = msg + + +def find_mandatory_groups(ast: Union[SubPattern, Tuple[NIC, Any]]) -> Iterator[int]: + """Yields the all group numbers that are guaranteed to match something + in the Match object corresponding to the given regex. + + For example, if the provided AST corresponds to the regex + "(a)(?:(b)|(c))(d)?(e)+(f)", this function would yield 1, 5, and 6. + + We do not yield 0 even though that group will always have a match. This + function only group numbers that can actually be found in the AST. + """ + if isinstance(ast, tuple): + data = [ast] # type: List[Tuple[NIC, Any]] + elif isinstance(ast, SubPattern): + data = ast.data + else: + raise RegexPluginException("Internal error: unexpected regex AST item '{}'".format(ast)) + + for op, av in data: + if op is SUBPATTERN: + group, _, _, children = av + # This can be 'None' for "extension notation groups" + if group is not None: + yield group + for child in children: + yield from find_mandatory_groups(child) + elif op in (MIN_REPEAT, MAX_REPEAT): + min_repeats, _, children = av + if min_repeats == 0: + continue + for child in children: + yield from find_mandatory_groups(child) + elif op in (BRANCH, GROUPREF_EXISTS): + # Note: We deliberately ignore branches (e.g. "(a)|(b)") or + # conditional matches (e.g. "(?(named-group)yes-branch|no-branch)". + # The whole point of a branch is that it'll be matched only + # some of the time, therefore no subgroups in either branch can + # ever be mandatory. + continue + elif isinstance(av, list): + for child in av: + yield from find_mandatory_groups(child) + + +def extract_regex_group_info(pattern: str) -> Tuple[List[int], int, Dict[str, int]]: + """Analyzes the given regex pattern and returns a tuple of: + + 1. A list of all mandatory group indexes in sorted order (including 0). + 2. The total number of groups, including optional groups and the zero-th group. + 3. A mapping of named groups to group indices. + + If the given str is not a valid regex, raises RegexPluginException. + """ + try: + ast = parse(pattern) + except SreError as ex: + raise RegexPluginException("Invalid regex: {}".format(ex.msg)) + + mandatory_groups = [0] + list(sorted(find_mandatory_groups(ast))) + total_groups = ast.pattern.groups + named_groups = ast.pattern.groupdict + + return mandatory_groups, total_groups, named_groups + + +def analyze_regex_pattern_call( + pattern_type: Type, + default_return_type: Type, +) -> Type: + """The re module contains several methods or functions + that accept some string containing a regex pattern and returns + either a typing.Pattern or typing.Match object. + + This function handles the core logic for extracting and + attaching this regex metadata to the return object in all + these cases. + """ + + pattern_type = coerce_to_literal(pattern_type) + if not isinstance(pattern_type, LiteralType): + return default_return_type + if pattern_type.fallback.type.fullname() not in STR_LIKE_TYPES: + return default_return_type + + return_type = get_proper_type(default_return_type) + if not isinstance(return_type, Instance): + return default_return_type + if return_type.type.fullname() not in OBJECTS_SUPPORTING_REGEX_METADATA: + return default_return_type + + pattern = pattern_type.value + assert isinstance(pattern, str) + mandatory_groups, total_groups, named_groups = extract_regex_group_info(pattern) + + metadata = { + "default_re_plugin": { + "mandatory_groups": mandatory_groups, + "total_groups": total_groups, + "named_groups": named_groups, + } + } + + return return_type.copy_modified( + metadata={**return_type.metadata, **metadata}, + ) + + +def extract_metadata(typ: ProperType) -> Optional[Tuple[Dict[str, Any], Instance]]: + """Returns the regex metadata from the given type, if it exists. + Otherwise returns None. + + This function is the dual of 'analyze_regex_pattern_call'. That function + tries finding and attaching the metadata to Pattern or Match objects; + this function tries extracting the attached metadata. + """ + if not isinstance(typ, Instance): + return None + + metadata = typ.metadata.get('default_re_plugin', None) + if metadata is None: + return None + + arg_type = get_proper_type(typ.args[0]) + if not isinstance(arg_type, Instance): + return None + + return metadata, arg_type + + +def re_direct_match_callback(ctx: mypy.plugin.FunctionContext) -> Type: + """Analyzes functions such as 're.match(PATTERN, INPUT)'""" + try: + return analyze_regex_pattern_call( + ctx.arg_types[0][0], + remove_optional(ctx.default_return_type), + ) + except RegexPluginException as ex: + ctx.api.fail(ex.msg, ctx.context) + return ctx.default_return_type + + +def re_compile_callback(ctx: mypy.plugin.FunctionContext) -> Type: + """Analyzes the 're.compile(PATTERN)' function.""" + try: + return analyze_regex_pattern_call( + ctx.arg_types[0][0], + ctx.default_return_type, + ) + except RegexPluginException as ex: + ctx.api.fail(ex.msg, ctx.context) + return ctx.default_return_type + + +def re_get_match_callback(ctx: mypy.plugin.MethodContext) -> Type: + """Analyzes the 'typing.Pattern.match(...)' method.""" + self_type = ctx.type + return_type = ctx.default_return_type + + if not isinstance(self_type, Instance) or 'default_re_plugin' not in self_type.metadata: + return return_type + + match_object = remove_optional(return_type) + assert isinstance(match_object, Instance) + + pattern_metadata = self_type.metadata['default_re_plugin'] + new_match_object = match_object.copy_modified(metadata={'default_re_plugin': pattern_metadata}) + return make_simplified_union([new_match_object, NoneType()]) + + +def re_match_groups_callback(ctx: mypy.plugin.MethodContext) -> Type: + """Analyzes the 'typing.Match.group(...)' method, which returns + a tuple of all matched groups.""" + info = extract_metadata(ctx.type) + if info is None: + return ctx.default_return_type + + metadata, mandatory_match_type = info + mandatory = set(metadata['mandatory_groups']) + total = metadata['total_groups'] + + if len(ctx.arg_types) > 0 and len(ctx.arg_types[0]) > 0: + default_type = ctx.arg_types[0][0] + else: + default_type = NoneType() + + optional_match_type = make_simplified_union([mandatory_match_type, default_type]) + + items = [] # type: List[Type] + for i in range(1, total): + if i in mandatory: + items.append(mandatory_match_type) + else: + items.append(optional_match_type) + + fallback = ctx.api.named_generic_type("builtins.tuple", [mandatory_match_type]) + return TupleType(items, fallback) + + +def re_match_group_callback(ctx: mypy.plugin.MethodContext) -> Type: + """Analyzes the 'typing.Match.group()' and '__getitem__(...)' methods.""" + info = extract_metadata(ctx.type) + if info is None: + return ctx.default_return_type + + metadata, mandatory_match_type = info + mandatory = set(metadata['mandatory_groups']) + total = metadata['total_groups'] + named_groups = metadata['named_groups'] + + if len(mandatory) != total: + optional_match_type = make_simplified_union([mandatory_match_type, NoneType()]) + else: + optional_match_type = mandatory_match_type + + possible_indices = [] + for arg_type in ctx.arg_types: + if len(arg_type) >= 1: + possible_indices.append(coerce_to_literal(arg_type[0])) + + outputs = [] # type: List[Type] + for possible_index in possible_indices: + if not isinstance(possible_index, LiteralType): + outputs.append(optional_match_type) + continue + + value = possible_index.value + fallback_name = possible_index.fallback.type.fullname() + + if isinstance(value, str) and fallback_name in STR_LIKE_TYPES: + if value not in named_groups: + ctx.api.fail("Regex does not contain group named '{}'".format(value), ctx.context) + outputs.append(optional_match_type) + continue + + index = named_groups[value] + elif isinstance(value, int): + if value < 0: + ctx.api.fail("Regex group number should not be negative", ctx.context) + outputs.append(optional_match_type) + continue + elif value >= total: + msg = "Regex has {} total groups, given group number {} is too big" + ctx.api.fail(msg.format(total, value), ctx.context) + outputs.append(optional_match_type) + continue + index = value + else: + outputs.append(optional_match_type) + continue + + if index in mandatory: + outputs.append(mandatory_match_type) + else: + outputs.append(optional_match_type) + + if len(outputs) == 1: + return outputs[0] + else: + fallback = ctx.api.named_generic_type("builtins.tuple", [mandatory_match_type]) + return TupleType(outputs, fallback) diff --git a/mypy/test/testplugin.py b/mypy/test/testplugin.py new file mode 100644 index 000000000000..137f5e2db894 --- /dev/null +++ b/mypy/test/testplugin.py @@ -0,0 +1,93 @@ +from typing import List, Dict + +from mypy.test.helpers import Suite, assert_equal +from mypy.plugins.regex import extract_regex_group_info, RegexPluginException + + +class RegexPluginSuite(Suite): + def test_regex_group_analysis(self) -> None: + def check(pattern: str, + expected_mandatory: List[int], + expected_total: int, + expected_named: Dict[str, int], + ) -> None: + actual_mandatory, actual_total, actual_named = extract_regex_group_info(pattern) + assert_equal(actual_mandatory, expected_mandatory) + assert_equal(actual_total, expected_total) + assert_equal(actual_named, expected_named) + + # Some conventions, to make reading these more clear: + # + # m1, m2, m3... -- text meant to be a part of mandatory groups + # o1, o2, o3... -- text meant to be a part of optional groups + # x, y, z -- other dummy filter text + # n1, n2, n3... -- names for named groups + + # Basic sanity checks + check(r"x", [0], 1, {}) + check(r"", [0], 1, {}) + check(r"(m1(m2(m3)(m4)))", [0, 1, 2, 3, 4], 5, {}) + + # Named groups + check(r"(?Pm1)(?P=n1)(?Po2)*", [0, 1], 3, {'n1': 1, 'n2': 2}) + check(r"(?Pfoo){0,4} (?Pbar)", [0, 2], 3, {'n1': 1, 'n2': 2}) + + # Repetition checks + check(r"(m1)(o2)?(m3)(o4)*(r5)+(o6)??", [0, 1, 3, 5], 7, {}) + check(r"(m1(o2)?)+", [0, 1], 3, {}) + check(r"(o1){0,3} (m2){2} (m3){1,2}", [0, 2, 3], 4, {}) + check(r"(o1){0,3}? (m2){2}? (m3){1,2}?", [0, 2, 3], 4, {}) + + # Branching + check(r"(o1)|(o2)(o3|x)", [0], 4, {}) + check(r"(m1(o2)|(o3))(m4|x)", [0, 1, 4], 5, {}) + check(r"(?:(o1)|(o2))(m3|x)", [0, 3], 4, {}) + + # Non-capturing groups + check(r"(?:x)(m1)", [0, 1], 2, {}) + check(r"(?:x)", [0], 1, {}) + + # Flag groups + # Note: Doing re.compile("(?a)foo") is equivalent to doing + # re.compile("foo", flags=re.A). You can also use inline + # flag groups "(?FLAGS:PATTERN)" to apply flags just for + # the specified pattern. + check(r"(?a)(?i)x", [0], 1, {}) + check(r"(?ai)x", [0], 1, {}) + check(r"(?a:(m1)(o2)*(?Pm3))", [0, 1, 3], 4, {'n3': 3}) + + # Lookahead assertions + check(r"(m1) (?=x) (m2)", [0, 1, 2], 3, {}) + check(r"(m1) (m2(?=x)) (m3)", [0, 1, 2, 3], 4, {}) + + # Negative lookahead assertions + check(r"(m1) (?!x) (m2)", [0, 1, 2], 3, {}) + check(r"(m1) (m2(?!x)) (m3)", [0, 1, 2, 3], 4, {}) + + # Positive lookbehind assertions + check(r"(m1)+ (?<=x)(m2)", [0, 1, 2], 3, {}) + check(r"(?<=x)x", [0], 1, {}) + + # Conditional matches + check(r"(?Pm1) (?(n1)x|y) (m2)", [0, 1, 2], 3, {"n1": 1}) + check(r"(?Po1)? (?(n1)x|y) (m2)", [0, 2], 3, {"n1": 1}) + check(r"(?Pm1) (?(n1)x) (m2)", [0, 1, 2], 3, {"n1": 1}) + check(r"(?Po1)? (?(n1)x) (m2)", [0, 2], 3, {"n1": 1}) + check(r"(m1) (?(1)x|y) (m2)", [0, 1, 2], 3, {}) + check(r"(o1)? (?(1)x|y) (m2)", [0, 2], 3, {}) + + # Comments + check(r"(m1)(?#comment)(r2)", [0, 1, 2], 3, {}) + + def test_regex_errors(self) -> None: + def check(pattern: str) -> None: + try: + extract_regex_group_info(pattern) + except RegexPluginException: + pass + else: + raise AssertionError("Did not throw expection for regex '{}'".format(pattern)) + + check(r"(unbalanced") + check(r"unbalanced)") + check(r"(?P=badgroupname)") diff --git a/mypy/typeops.py b/mypy/typeops.py index 39c81617c9ab..8a6d4a4df425 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -7,6 +7,8 @@ from typing import cast, Optional, List, Sequence, Set +import sys + from mypy.types import ( TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded, TypeVarType, TypeType, UninhabitedType, FormalArgument, UnionType, NoneType, @@ -15,7 +17,7 @@ ) from mypy.nodes import ( FuncBase, FuncItem, OverloadedFuncDef, TypeInfo, TypeVar, ARG_STAR, ARG_STAR2, Expression, - StrExpr + StrExpr, Var ) from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance, expand_type @@ -464,3 +466,94 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]] else: return None return strings + + +def get_enum_values(typ: Instance) -> List[str]: + """Return the list of values for an Enum.""" + return [name for name, sym in typ.type.names.items() if isinstance(sym.node, Var)] + + +def is_singleton_type(typ: Type) -> bool: + """Returns 'true' if this type is a "singleton type" -- if there exists + exactly only one runtime value associated with this type. + + That is, given two values 'a' and 'b' that have the same type 't', + 'is_singleton_type(t)' returns True if and only if the expression 'a is b' is + always true. + + Currently, this returns True when given NoneTypes, enum LiteralTypes and + enum types with a single value. + + Note that other kinds of LiteralTypes cannot count as singleton types. For + example, suppose we do 'a = 100000 + 1' and 'b = 100001'. It is not guaranteed + that 'a is b' will always be true -- some implementations of Python will end up + constructing two distinct instances of 100001. + """ + typ = get_proper_type(typ) + # TODO: Also make this return True if the type is a bool LiteralType. + # Also make this return True if the type corresponds to ... (ellipsis) or NotImplemented? + return ( + isinstance(typ, NoneType) or (isinstance(typ, LiteralType) and typ.is_enum_literal()) + or (isinstance(typ, Instance) and typ.type.is_enum and len(get_enum_values(typ)) == 1) + ) + + +def try_expanding_enum_to_union(typ: Type, target_fullname: str) -> ProperType: + """Attempts to recursively expand any enum Instances with the given target_fullname + into a Union of all of its component LiteralTypes. + + For example, if we have: + + class Color(Enum): + RED = 1 + BLUE = 2 + YELLOW = 3 + + class Status(Enum): + SUCCESS = 1 + FAILURE = 2 + UNKNOWN = 3 + + ...and if we call `try_expanding_enum_to_union(Union[Color, Status], 'module.Color')`, + this function will return Literal[Color.RED, Color.BLUE, Color.YELLOW, Status]. + """ + typ = get_proper_type(typ) + + if isinstance(typ, UnionType): + items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items] + return make_simplified_union(items) + elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname: + new_items = [] + for name, symbol in typ.type.names.items(): + if not isinstance(symbol.node, Var): + continue + new_items.append(LiteralType(name, typ)) + # SymbolTables are really just dicts, and dicts are guaranteed to preserve + # insertion order only starting with Python 3.7. So, we sort these for older + # versions of Python to help make tests deterministic. + # + # We could probably skip the sort for Python 3.6 since people probably run mypy + # only using CPython, but we might as well for the sake of full correctness. + if sys.version_info < (3, 7): + new_items.sort(key=lambda lit: lit.value) + return make_simplified_union(new_items) + else: + return typ + + +def coerce_to_literal(typ: Type) -> ProperType: + """Recursively converts any Instances that have a last_known_value or are + instances of enum types with a single value into the corresponding LiteralType. + """ + typ = get_proper_type(typ) + if isinstance(typ, UnionType): + new_items = [coerce_to_literal(item) for item in typ.items] + return make_simplified_union(new_items) + elif isinstance(typ, Instance): + if typ.last_known_value: + return typ.last_known_value + elif typ.type.is_enum: + enum_values = get_enum_values(typ) + if len(enum_values) == 1: + return LiteralType(value=enum_values[0], fallback=typ) + return typ diff --git a/mypy/types.py b/mypy/types.py index 34899c40f824..e85f8dae266f 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -693,10 +693,11 @@ class Instance(ProperType): The list of type variables may be empty. """ - __slots__ = ('type', 'args', 'erased', 'invalid', 'type_ref', 'last_known_value') + __slots__ = ('type', 'args', 'erased', 'invalid', 'type_ref', 'metadata', 'last_known_value') def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type], line: int = -1, column: int = -1, erased: bool = False, + metadata: Optional[Dict[str, JsonDict]] = None, last_known_value: Optional['LiteralType'] = None) -> None: super().__init__(line, column) self.type = typ @@ -709,6 +710,13 @@ def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type], # True if recovered after incorrect number of type arguments error self.invalid = False + # This is a dictionary that will be serialized and un-serialized as is. + # It is useful for plugins to add their data to save in the cache. + if metadata: + self.metadata = metadata # type: Dict[str, JsonDict] + else: + self.metadata = {} + # This field keeps track of the underlying Literal[...] value associated with # this instance, if one is known. # @@ -776,6 +784,8 @@ def serialize(self) -> Union[JsonDict, str]: } # type: JsonDict data['type_ref'] = type_ref data['args'] = [arg.serialize() for arg in self.args] + if self.metadata: + data['metadata'] = self.metadata if self.last_known_value is not None: data['last_known_value'] = self.last_known_value.serialize() return data @@ -794,12 +804,15 @@ def deserialize(cls, data: Union[JsonDict, str]) -> 'Instance': args = [deserialize_type(arg) for arg in args_list] inst = Instance(NOT_READY, args) inst.type_ref = data['type_ref'] # Will be fixed up by fixup.py later. + if 'metadata' in data: + inst.metadata = data['metadata'] if 'last_known_value' in data: inst.last_known_value = LiteralType.deserialize(data['last_known_value']) return inst def copy_modified(self, *, args: Bogus[List[Type]] = _dummy, + metadata: Bogus[Dict[str, JsonDict]] = _dummy, last_known_value: Bogus[Optional['LiteralType']] = _dummy) -> 'Instance': return Instance( self.type, @@ -807,6 +820,7 @@ def copy_modified(self, *, self.line, self.column, self.erased, + metadata if metadata is not _dummy else self.metadata, last_known_value if last_known_value is not _dummy else self.last_known_value, ) diff --git a/test-data/unit/fixtures/floatdict.pyi b/test-data/unit/fixtures/floatdict.pyi index 7d2f55a6f6dd..7baa7ca9206f 100644 --- a/test-data/unit/fixtures/floatdict.pyi +++ b/test-data/unit/fixtures/floatdict.pyi @@ -36,7 +36,7 @@ class list(Iterable[T], Generic[T]): def append(self, x: T) -> None: pass def extend(self, x: Iterable[T]) -> None: pass -class dict(Iterable[KT], Mapping[KT, VT], Generic[KT, VT]): +class dict(Mapping[KT, VT], Generic[KT, VT]): @overload def __init__(self, **kwargs: VT) -> None: pass @overload diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index f3a88ca47dcc..48d2014879f0 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1479,3 +1479,133 @@ def f_suppresses() -> int: [out] _testUnreachableWithStdlibContextManagersNoStrictOptional.py:9: error: Statement is unreachable _testUnreachableWithStdlibContextManagersNoStrictOptional.py:15: error: Statement is unreachable + +[case testRegexPluginBasicCase] +# mypy: strict-optional +import re +from typing_extensions import Final + +pattern1: Final = re.compile("(foo)*(bar)") +match1: Final = pattern1.match("blah") +if match1: + reveal_type(match1.groups()) + reveal_type(match1.groups(default="test")) + reveal_type(match1.group(0)) + reveal_type(match1.group(1)) + reveal_type(match1.group(2)) + reveal_type(match1.group(0, 1, 2)) + +pattern2: Final = re.compile(b"(?Pfoo){0,4} (?Pbar)") +match2: Final = pattern2.search(b"blah") +if match2: + reveal_type(match2.groups()) + reveal_type(match2[0]) + reveal_type(match2[1]) + reveal_type(match2[2]) + reveal_type(match2["n1"]) + reveal_type(match2["n2"]) +[out] +_testRegexPluginBasicCase.py:8: note: Revealed type is 'Tuple[Union[builtins.str*, None], builtins.str*]' +_testRegexPluginBasicCase.py:9: note: Revealed type is 'Tuple[builtins.str*, builtins.str*]' +_testRegexPluginBasicCase.py:10: note: Revealed type is 'builtins.str*' +_testRegexPluginBasicCase.py:11: note: Revealed type is 'Union[builtins.str*, None]' +_testRegexPluginBasicCase.py:12: note: Revealed type is 'builtins.str*' +_testRegexPluginBasicCase.py:13: note: Revealed type is 'Tuple[builtins.str*, Union[builtins.str*, None], builtins.str*]' +_testRegexPluginBasicCase.py:18: note: Revealed type is 'Tuple[Union[builtins.bytes*, None], builtins.bytes*]' +_testRegexPluginBasicCase.py:19: note: Revealed type is 'builtins.bytes*' +_testRegexPluginBasicCase.py:20: note: Revealed type is 'Union[builtins.bytes*, None]' +_testRegexPluginBasicCase.py:21: note: Revealed type is 'builtins.bytes*' +_testRegexPluginBasicCase.py:22: note: Revealed type is 'Union[builtins.bytes*, None]' +_testRegexPluginBasicCase.py:23: note: Revealed type is 'builtins.bytes*' + +[case testRegexPluginNoFinal] +# mypy: strict-optional +import re + +pattern = re.compile("(foo)*(bar)") +match = pattern.match("blah") +if match: + # TODO: Consider typeshed so we default to using stricter types given ambiguity + reveal_type(match.groups()) + reveal_type(match[1]) +[out] +_testRegexPluginNoFinal.py:8: note: Revealed type is 'typing.Sequence[builtins.str*]' +_testRegexPluginNoFinal.py:9: note: Revealed type is 'builtins.str*' + +[case testRegexPluginErrors] +# mypy: strict-optional +import re +from typing_extensions import Final + +invalid1 = re.compile("(bad") +invalid2: Final = re.compile("(bad") + +pattern: Final = re.compile("(a)(b)*(?Pc)") +match: Final = pattern.fullmatch("blah") +if match: + match.group(5) + match.group("bad") +[out] +_testRegexPluginErrors.py:5: error: Invalid regex: missing ), unterminated subpattern +_testRegexPluginErrors.py:6: error: Invalid regex: missing ), unterminated subpattern +_testRegexPluginErrors.py:11: error: Regex has 4 total groups, given group number 5 is too big +_testRegexPluginErrors.py:12: error: Regex does not contain group named 'bad' + +[case testRegexPluginDirectMethods] +# mypy: strict-optional +import re +from typing_extensions import Final + +match: Final = re.search("(foo)*(bar)", "blah") +if match: + reveal_type(match.groups()) + reveal_type(match.groups(default="test")) + reveal_type(match[0]) + reveal_type(match[1]) + reveal_type(match[2]) + reveal_type(match.group(0, 1, 2)) +[out] +_testRegexPluginDirectMethods.py:7: note: Revealed type is 'Tuple[Union[builtins.str*, None], builtins.str*]' +_testRegexPluginDirectMethods.py:8: note: Revealed type is 'Tuple[builtins.str*, builtins.str*]' +_testRegexPluginDirectMethods.py:9: note: Revealed type is 'builtins.str*' +_testRegexPluginDirectMethods.py:10: note: Revealed type is 'Union[builtins.str*, None]' +_testRegexPluginDirectMethods.py:11: note: Revealed type is 'builtins.str*' +_testRegexPluginDirectMethods.py:12: note: Revealed type is 'Tuple[builtins.str*, Union[builtins.str*, None], builtins.str*]' + +[case testRegexPluginUnknownArg] +# mypy: strict-optional +import re +from typing_extensions import Final + +index: int +name: str + +pattern1: Final = re.compile("(foo)*(bar)(?Pbaz)?(?Pqux)") +match1: Final = pattern1.match("blah") +if match1: + reveal_type(match1.groups()) + reveal_type(match1[index]) + reveal_type(match1[name]) + reveal_type(match1.group(0, index, name)) + +pattern2: Final = re.compile("(foo)(?Pbar)") +match2: Final = pattern2.match("blah") +if match2: + # No optional groups, so we can always return str + reveal_type(match2.groups()) + reveal_type(match2[index]) + reveal_type(match2[name]) + reveal_type(match2.group(0, index, name)) + match2["bad"] + match2[5] +[out] +_testRegexPluginUnknownArg.py:11: note: Revealed type is 'Tuple[Union[builtins.str*, None], builtins.str*, Union[builtins.str*, None], builtins.str*]' +_testRegexPluginUnknownArg.py:12: note: Revealed type is 'Union[builtins.str*, None]' +_testRegexPluginUnknownArg.py:13: note: Revealed type is 'Union[builtins.str*, None]' +_testRegexPluginUnknownArg.py:14: note: Revealed type is 'Tuple[builtins.str*, Union[builtins.str*, None], Union[builtins.str*, None]]' +_testRegexPluginUnknownArg.py:20: note: Revealed type is 'Tuple[builtins.str*, builtins.str*]' +_testRegexPluginUnknownArg.py:21: note: Revealed type is 'builtins.str*' +_testRegexPluginUnknownArg.py:22: note: Revealed type is 'builtins.str*' +_testRegexPluginUnknownArg.py:23: note: Revealed type is 'Tuple[builtins.str*, builtins.str*, builtins.str*]' +_testRegexPluginUnknownArg.py:24: error: Regex does not contain group named 'bad' +_testRegexPluginUnknownArg.py:25: error: Regex has 3 total groups, given group number 5 is too big From 9a7097ff7046bce49bf88b5a9c3bc3fdd33f3e57 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Sun, 27 Oct 2019 15:56:13 -0700 Subject: [PATCH 2/7] Fix compatibility issues with other Python versions --- mypy/plugins/regex.py | 17 ++++++++++++++--- mypy/test/testplugin.py | 10 ++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/mypy/plugins/regex.py b/mypy/plugins/regex.py index c710e262f42c..de467cf4d429 100644 --- a/mypy/plugins/regex.py +++ b/mypy/plugins/regex.py @@ -28,6 +28,8 @@ from typing import Union, Iterator, Tuple, List, Any, Optional, Dict from typing_extensions import Final +import sys + from mypy.types import ( Type, ProperType, Instance, NoneType, LiteralType, TupleType, remove_optional, @@ -95,7 +97,11 @@ def find_mandatory_groups(ast: Union[SubPattern, Tuple[NIC, Any]]) -> Iterator[i for op, av in data: if op is SUBPATTERN: - group, _, _, children = av + # Use relative indexing for maximum compatibility: + # av contains just these two elements in Python 3.5 + # but four elements for newer Pythons. + group, children = av[0], av[-1] + # This can be 'None' for "extension notation groups" if group is not None: yield group @@ -134,8 +140,13 @@ def extract_regex_group_info(pattern: str) -> Tuple[List[int], int, Dict[str, in raise RegexPluginException("Invalid regex: {}".format(ex.msg)) mandatory_groups = [0] + list(sorted(find_mandatory_groups(ast))) - total_groups = ast.pattern.groups - named_groups = ast.pattern.groupdict + + if sys.version_info >= (3, 8, 0): + state = ast.state + else: + state = ast.pattern + total_groups = state.groups + named_groups = state.groupdict return mandatory_groups, total_groups, named_groups diff --git a/mypy/test/testplugin.py b/mypy/test/testplugin.py index 137f5e2db894..0f9cba204cf5 100644 --- a/mypy/test/testplugin.py +++ b/mypy/test/testplugin.py @@ -1,4 +1,5 @@ from typing import List, Dict +import sys from mypy.test.helpers import Suite, assert_equal from mypy.plugins.regex import extract_regex_group_info, RegexPluginException @@ -47,14 +48,15 @@ def check(pattern: str, check(r"(?:x)(m1)", [0, 1], 2, {}) check(r"(?:x)", [0], 1, {}) - # Flag groups + # Flag groups, added in Python 3.6. # Note: Doing re.compile("(?a)foo") is equivalent to doing # re.compile("foo", flags=re.A). You can also use inline # flag groups "(?FLAGS:PATTERN)" to apply flags just for # the specified pattern. - check(r"(?a)(?i)x", [0], 1, {}) - check(r"(?ai)x", [0], 1, {}) - check(r"(?a:(m1)(o2)*(?Pm3))", [0, 1, 3], 4, {'n3': 3}) + if sys.version_info >= (3, 6): + check(r"(?s)(?i)x", [0], 1, {}) + check(r"(?si)x", [0], 1, {}) + check(r"(?s:(m1)(o2)*(?Pm3))", [0, 1, 3], 4, {'n3': 3}) # Lookahead assertions check(r"(m1) (?=x) (m2)", [0, 1, 2], 3, {}) From 45168cd7fa4118b0c25efd30844f3be8d081bf84 Mon Sep 17 00:00:00 2001 From: Jingchen Ye <97littleleaf11@gmail.com> Date: Wed, 5 Jan 2022 19:46:21 +0800 Subject: [PATCH 3/7] Fix --- mypy/checker.py | 2 +- mypy/checkexpr.py | 4 ++-- mypy/typeops.py | 2 -- mypy/types.py | 2 +- 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 28b4cbecfe61..fea9b47e8586 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4988,7 +4988,7 @@ def named_generic_type(self, name: str, args: List[Type]) -> Instance: the name refers to a compatible generic type. """ info = self.lookup_typeinfo(name) - args = [remove_instance_last_known_values(arg) for arg in args] + args = [remove_instance_transient_info(arg) for arg in args] # TODO: assert len(args) == len(info.defn.type_vars) return Instance(info, args) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index dfac5be27d95..2963cacfaaf8 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -41,7 +41,7 @@ import mypy.checker from mypy import types from mypy.sametypes import is_same_type -from mypy.erasetype import replace_meta_vars, erase_type, remove_instance_last_known_values +from mypy.erasetype import replace_meta_vars, erase_type, remove_instance_transient_info from mypy.maptype import map_instance_to_supertype from mypy.messages import MessageBuilder from mypy import message_registry @@ -3334,7 +3334,7 @@ def check_lst_expr(self, items: List[Expression], fullname: str, [(nodes.ARG_STAR if isinstance(i, StarExpr) else nodes.ARG_POS) for i in items], context)[0] - return remove_instance_last_known_values(out) + return remove_instance_transient_info(out) def visit_tuple_expr(self, e: TupleExpr) -> Type: """Type check a tuple expression.""" diff --git a/mypy/typeops.py b/mypy/typeops.py index 124f54eb53d0..f1d9f2daed55 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -10,8 +10,6 @@ import itertools import sys -import sys - from mypy.types import ( TupleType, Instance, FunctionLike, Type, CallableType, TypeVarLikeType, Overloaded, TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, diff --git a/mypy/types.py b/mypy/types.py index 6089a6413ec3..1b15ab3b8c58 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -958,7 +958,7 @@ def __init__(self, typ: mypy.nodes.TypeInfo, args: Sequence[Type], # this instance, if one is known. # # This field is set whenever possible within expressions, but is erased upon - # variable assignment (see erasetype.remove_instance_last_known_values) unless + # variable assignment (see erasetype.remove_instance_transient_info) unless # the variable is declared to be final. # # For example, consider the following program: From 9d468df0fdafd28f7628891b5eb52701a39a3b5c Mon Sep 17 00:00:00 2001 From: Jingchen Ye <97littleleaf11@gmail.com> Date: Wed, 5 Jan 2022 20:18:45 +0800 Subject: [PATCH 4/7] Fix --- mypy/plugins/regex.py | 20 +++++++++----------- mypy/types.py | 12 +++++------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/mypy/plugins/regex.py b/mypy/plugins/regex.py index de467cf4d429..01e9fda0ad93 100644 --- a/mypy/plugins/regex.py +++ b/mypy/plugins/regex.py @@ -43,11 +43,11 @@ error as SreError, _NamedIntConstant as NIC, ) -STR_LIKE_TYPES = { +STR_LIKE_TYPES: Final = { 'builtins.unicode', 'builtins.str', 'builtins.bytes', -} # type: Final +} FUNCTIONS_PRODUCING_MATCH_OBJECT = { 're.search', @@ -55,11 +55,11 @@ 're.fullmatch', } -METHODS_PRODUCING_MATCH_OBJECT = { +METHODS_PRODUCING_MATCH_OBJECT: Final = { 'typing.Pattern.search', 'typing.Pattern.match', 'typing.Pattern.fullmatch', -} # type: Final +} METHODS_PRODUCING_GROUP = { 'typing.Match.group', @@ -89,7 +89,7 @@ def find_mandatory_groups(ast: Union[SubPattern, Tuple[NIC, Any]]) -> Iterator[i function only group numbers that can actually be found in the AST. """ if isinstance(ast, tuple): - data = [ast] # type: List[Tuple[NIC, Any]] + data: List[Tuple[NIC, Any]] = [ast] elif isinstance(ast, SubPattern): data = ast.data else: @@ -151,10 +151,8 @@ def extract_regex_group_info(pattern: str) -> Tuple[List[int], int, Dict[str, in return mandatory_groups, total_groups, named_groups -def analyze_regex_pattern_call( - pattern_type: Type, - default_return_type: Type, -) -> Type: +def analyze_regex_pattern_call(pattern_type: Type, + default_return_type: Type) -> Type: """The re module contains several methods or functions that accept some string containing a regex pattern and returns either a typing.Pattern or typing.Match object. @@ -273,7 +271,7 @@ def re_match_groups_callback(ctx: mypy.plugin.MethodContext) -> Type: optional_match_type = make_simplified_union([mandatory_match_type, default_type]) - items = [] # type: List[Type] + items: List[Type] = [] for i in range(1, total): if i in mandatory: items.append(mandatory_match_type) @@ -305,7 +303,7 @@ def re_match_group_callback(ctx: mypy.plugin.MethodContext) -> Type: if len(arg_type) >= 1: possible_indices.append(coerce_to_literal(arg_type[0])) - outputs = [] # type: List[Type] + outputs: List[Type] = [] for possible_index in possible_indices: if not isinstance(possible_index, LiteralType): outputs.append(optional_match_type) diff --git a/mypy/types.py b/mypy/types.py index 1b15ab3b8c58..66a8911c6a72 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -950,7 +950,7 @@ def __init__(self, typ: mypy.nodes.TypeInfo, args: Sequence[Type], # This is a dictionary that will be serialized and un-serialized as is. # It is useful for plugins to add their data to save in the cache. if metadata: - self.metadata = metadata # type: Dict[str, JsonDict] + self.metadata: Dict[str, JsonDict] = metadata else: self.metadata = {} @@ -1017,11 +1017,9 @@ def serialize(self) -> Union[JsonDict, str]: type_ref = self.type.fullname if not self.args and not self.last_known_value: return type_ref - data: JsonDict = { - ".class": "Instance", - } - data["type_ref"] = type_ref - data["args"] = [arg.serialize() for arg in self.args] + data: JsonDict = {".class": "Instance", + "type_ref": type_ref, + "args": [arg.serialize() for arg in self.args]} if self.metadata: data['metadata'] = self.metadata if self.last_known_value is not None: @@ -1058,8 +1056,8 @@ def copy_modified(self, *, args if args is not _dummy else self.args, self.line, self.column, - metadata if metadata is not _dummy else self.metadata, erased if erased is not _dummy else self.erased, + metadata if metadata is not _dummy else self.metadata, last_known_value if last_known_value is not _dummy else self.last_known_value, ) From ced8cdaec0d913c8db5dc28bb806fd21cac91df0 Mon Sep 17 00:00:00 2001 From: Jingchen Ye <97littleleaf11@gmail.com> Date: Wed, 5 Jan 2022 20:21:22 +0800 Subject: [PATCH 5/7] Fix calling on a property --- mypy/plugins/regex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/plugins/regex.py b/mypy/plugins/regex.py index 01e9fda0ad93..c8e4cd6da3da 100644 --- a/mypy/plugins/regex.py +++ b/mypy/plugins/regex.py @@ -165,13 +165,13 @@ def analyze_regex_pattern_call(pattern_type: Type, pattern_type = coerce_to_literal(pattern_type) if not isinstance(pattern_type, LiteralType): return default_return_type - if pattern_type.fallback.type.fullname() not in STR_LIKE_TYPES: + if pattern_type.fallback.type.fullname not in STR_LIKE_TYPES: return default_return_type return_type = get_proper_type(default_return_type) if not isinstance(return_type, Instance): return default_return_type - if return_type.type.fullname() not in OBJECTS_SUPPORTING_REGEX_METADATA: + if return_type.type.fullname not in OBJECTS_SUPPORTING_REGEX_METADATA: return default_return_type pattern = pattern_type.value From f54eff29cb842de1cc6596d5afca81bf855eca73 Mon Sep 17 00:00:00 2001 From: Jingchen Ye <97littleleaf11@gmail.com> Date: Wed, 5 Jan 2022 20:30:37 +0800 Subject: [PATCH 6/7] Fix type check --- mypy/plugins/regex.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mypy/plugins/regex.py b/mypy/plugins/regex.py index c8e4cd6da3da..8b1777a031ae 100644 --- a/mypy/plugins/regex.py +++ b/mypy/plugins/regex.py @@ -141,7 +141,7 @@ def extract_regex_group_info(pattern: str) -> Tuple[List[int], int, Dict[str, in mandatory_groups = [0] + list(sorted(find_mandatory_groups(ast))) - if sys.version_info >= (3, 8, 0): + if sys.version_info >= (3, 8): state = ast.state else: state = ast.pattern @@ -162,7 +162,7 @@ def analyze_regex_pattern_call(pattern_type: Type, these cases. """ - pattern_type = coerce_to_literal(pattern_type) + pattern_type = get_proper_type(coerce_to_literal(pattern_type)) if not isinstance(pattern_type, LiteralType): return default_return_type if pattern_type.fallback.type.fullname not in STR_LIKE_TYPES: @@ -245,7 +245,7 @@ def re_get_match_callback(ctx: mypy.plugin.MethodContext) -> Type: if not isinstance(self_type, Instance) or 'default_re_plugin' not in self_type.metadata: return return_type - match_object = remove_optional(return_type) + match_object = get_proper_type(remove_optional(return_type)) assert isinstance(match_object, Instance) pattern_metadata = self_type.metadata['default_re_plugin'] @@ -301,7 +301,7 @@ def re_match_group_callback(ctx: mypy.plugin.MethodContext) -> Type: possible_indices = [] for arg_type in ctx.arg_types: if len(arg_type) >= 1: - possible_indices.append(coerce_to_literal(arg_type[0])) + possible_indices.append(get_proper_type(coerce_to_literal(arg_type[0]))) outputs: List[Type] = [] for possible_index in possible_indices: @@ -310,7 +310,7 @@ def re_match_group_callback(ctx: mypy.plugin.MethodContext) -> Type: continue value = possible_index.value - fallback_name = possible_index.fallback.type.fullname() + fallback_name = possible_index.fallback.type.fullname if isinstance(value, str) and fallback_name in STR_LIKE_TYPES: if value not in named_groups: From 3b38dc9295433e6357b8b2e5669278edfba5b768 Mon Sep 17 00:00:00 2001 From: Jingchen Ye <97littleleaf11@gmail.com> Date: Wed, 5 Jan 2022 20:56:06 +0800 Subject: [PATCH 7/7] Fix test --- test-data/unit/pythoneval.test | 56 +++++++++++++++++----------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 40bc1a441b8d..d8ae6c829ba5 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1598,18 +1598,18 @@ if match2: reveal_type(match2["n1"]) reveal_type(match2["n2"]) [out] -_testRegexPluginBasicCase.py:8: note: Revealed type is 'Tuple[Union[builtins.str*, None], builtins.str*]' -_testRegexPluginBasicCase.py:9: note: Revealed type is 'Tuple[builtins.str*, builtins.str*]' -_testRegexPluginBasicCase.py:10: note: Revealed type is 'builtins.str*' -_testRegexPluginBasicCase.py:11: note: Revealed type is 'Union[builtins.str*, None]' -_testRegexPluginBasicCase.py:12: note: Revealed type is 'builtins.str*' -_testRegexPluginBasicCase.py:13: note: Revealed type is 'Tuple[builtins.str*, Union[builtins.str*, None], builtins.str*]' -_testRegexPluginBasicCase.py:18: note: Revealed type is 'Tuple[Union[builtins.bytes*, None], builtins.bytes*]' -_testRegexPluginBasicCase.py:19: note: Revealed type is 'builtins.bytes*' -_testRegexPluginBasicCase.py:20: note: Revealed type is 'Union[builtins.bytes*, None]' -_testRegexPluginBasicCase.py:21: note: Revealed type is 'builtins.bytes*' -_testRegexPluginBasicCase.py:22: note: Revealed type is 'Union[builtins.bytes*, None]' -_testRegexPluginBasicCase.py:23: note: Revealed type is 'builtins.bytes*' +_testRegexPluginBasicCase.py:8: note: Revealed type is "Tuple[Union[builtins.str*, None], builtins.str*]" +_testRegexPluginBasicCase.py:9: note: Revealed type is "Tuple[builtins.str*, builtins.str*]" +_testRegexPluginBasicCase.py:10: note: Revealed type is "builtins.str*" +_testRegexPluginBasicCase.py:11: note: Revealed type is "Union[builtins.str*, None]" +_testRegexPluginBasicCase.py:12: note: Revealed type is "builtins.str*" +_testRegexPluginBasicCase.py:13: note: Revealed type is "Tuple[builtins.str*, Union[builtins.str*, None], builtins.str*]" +_testRegexPluginBasicCase.py:18: note: Revealed type is "Tuple[Union[builtins.bytes*, None], builtins.bytes*]" +_testRegexPluginBasicCase.py:19: note: Revealed type is "builtins.bytes*" +_testRegexPluginBasicCase.py:20: note: Revealed type is "Union[builtins.bytes*, None]" +_testRegexPluginBasicCase.py:21: note: Revealed type is "builtins.bytes*" +_testRegexPluginBasicCase.py:22: note: Revealed type is "Union[builtins.bytes*, None]" +_testRegexPluginBasicCase.py:23: note: Revealed type is "builtins.bytes*" [case testRegexPluginNoFinal] # mypy: strict-optional @@ -1622,8 +1622,8 @@ if match: reveal_type(match.groups()) reveal_type(match[1]) [out] -_testRegexPluginNoFinal.py:8: note: Revealed type is 'typing.Sequence[builtins.str*]' -_testRegexPluginNoFinal.py:9: note: Revealed type is 'builtins.str*' +_testRegexPluginNoFinal.py:8: note: Revealed type is "builtins.tuple[Union[builtins.str*, Any], ...]" +_testRegexPluginNoFinal.py:9: note: Revealed type is "Union[builtins.str*, Any]" [case testRegexPluginErrors] # mypy: strict-optional @@ -1658,12 +1658,12 @@ if match: reveal_type(match[2]) reveal_type(match.group(0, 1, 2)) [out] -_testRegexPluginDirectMethods.py:7: note: Revealed type is 'Tuple[Union[builtins.str*, None], builtins.str*]' -_testRegexPluginDirectMethods.py:8: note: Revealed type is 'Tuple[builtins.str*, builtins.str*]' -_testRegexPluginDirectMethods.py:9: note: Revealed type is 'builtins.str*' -_testRegexPluginDirectMethods.py:10: note: Revealed type is 'Union[builtins.str*, None]' -_testRegexPluginDirectMethods.py:11: note: Revealed type is 'builtins.str*' -_testRegexPluginDirectMethods.py:12: note: Revealed type is 'Tuple[builtins.str*, Union[builtins.str*, None], builtins.str*]' +_testRegexPluginDirectMethods.py:7: note: Revealed type is "Tuple[Union[builtins.str*, None], builtins.str*]" +_testRegexPluginDirectMethods.py:8: note: Revealed type is "Tuple[builtins.str*, builtins.str*]" +_testRegexPluginDirectMethods.py:9: note: Revealed type is "builtins.str*" +_testRegexPluginDirectMethods.py:10: note: Revealed type is "Union[builtins.str*, None]" +_testRegexPluginDirectMethods.py:11: note: Revealed type is "builtins.str*" +_testRegexPluginDirectMethods.py:12: note: Revealed type is "Tuple[builtins.str*, Union[builtins.str*, None], builtins.str*]" [case testRegexPluginUnknownArg] # mypy: strict-optional @@ -1692,13 +1692,13 @@ if match2: match2["bad"] match2[5] [out] -_testRegexPluginUnknownArg.py:11: note: Revealed type is 'Tuple[Union[builtins.str*, None], builtins.str*, Union[builtins.str*, None], builtins.str*]' -_testRegexPluginUnknownArg.py:12: note: Revealed type is 'Union[builtins.str*, None]' -_testRegexPluginUnknownArg.py:13: note: Revealed type is 'Union[builtins.str*, None]' -_testRegexPluginUnknownArg.py:14: note: Revealed type is 'Tuple[builtins.str*, Union[builtins.str*, None], Union[builtins.str*, None]]' -_testRegexPluginUnknownArg.py:20: note: Revealed type is 'Tuple[builtins.str*, builtins.str*]' -_testRegexPluginUnknownArg.py:21: note: Revealed type is 'builtins.str*' -_testRegexPluginUnknownArg.py:22: note: Revealed type is 'builtins.str*' -_testRegexPluginUnknownArg.py:23: note: Revealed type is 'Tuple[builtins.str*, builtins.str*, builtins.str*]' +_testRegexPluginUnknownArg.py:11: note: Revealed type is "Tuple[Union[builtins.str*, None], builtins.str*, Union[builtins.str*, None], builtins.str*]" +_testRegexPluginUnknownArg.py:12: note: Revealed type is "Union[builtins.str*, None]" +_testRegexPluginUnknownArg.py:13: note: Revealed type is "Union[builtins.str*, None]" +_testRegexPluginUnknownArg.py:14: note: Revealed type is "Tuple[builtins.str*, Union[builtins.str*, None], Union[builtins.str*, None]]" +_testRegexPluginUnknownArg.py:20: note: Revealed type is "Tuple[builtins.str*, builtins.str*]" +_testRegexPluginUnknownArg.py:21: note: Revealed type is "builtins.str*" +_testRegexPluginUnknownArg.py:22: note: Revealed type is "builtins.str*" +_testRegexPluginUnknownArg.py:23: note: Revealed type is "Tuple[builtins.str*, builtins.str*, builtins.str*]" _testRegexPluginUnknownArg.py:24: error: Regex does not contain group named 'bad' _testRegexPluginUnknownArg.py:25: error: Regex has 3 total groups, given group number 5 is too big