diff --git a/mypy/checker.py b/mypy/checker.py index b90221a0a5a5..0abd744b8aa6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2,11 +2,12 @@ import itertools import fnmatch +from collections import defaultdict from contextlib import contextmanager from typing import ( Any, Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, - Iterable, Sequence, Mapping, Generic, AbstractSet, Callable + Iterable, Sequence, Mapping, Generic, AbstractSet, Callable, overload ) from typing_extensions import Final, TypeAlias as _TypeAlias @@ -25,8 +26,7 @@ Import, ImportFrom, ImportAll, ImportBase, TypeAlias, ARG_POS, ARG_STAR, LITERAL_TYPE, LDEF, MDEF, GDEF, CONTRAVARIANT, COVARIANT, INVARIANT, TypeVarExpr, AssignmentExpr, - is_final_node, - ARG_NAMED) + is_final_node, ARG_NAMED, MatchStmt) from mypy import nodes from mypy import operators from mypy.literals import literal, literal_hash, Key @@ -50,6 +50,7 @@ type_object_type, analyze_decorator_or_funcbase_access, ) +from mypy.checkpattern import PatternChecker from mypy.semanal_enum import ENUM_BASES, ENUM_SPECIAL_PROPS from mypy.typeops import ( map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, @@ -171,6 +172,8 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface): # Helper for type checking expressions expr_checker: mypy.checkexpr.ExpressionChecker + pattern_checker: PatternChecker + tscope: Scope scope: "CheckerScope" # Stack of function return types @@ -235,6 +238,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option self.msg = MessageBuilder(errors, modules) self.plugin = plugin self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg, self.plugin) + self.pattern_checker = PatternChecker(self, self.msg, self.plugin) self.tscope = Scope() self.scope = CheckerScope(tree) self.binder = ConditionalTypeBinder() @@ -1434,6 +1438,19 @@ def check_setattr_method(self, typ: Type, context: Context) -> None: if not is_subtype(typ, method_type): self.msg.invalid_signature_for_special_method(typ, context, '__setattr__') + def check_match_args(self, var: Var, typ: Type, context: Context) -> None: + """Check that __match_args__ is final and contains literal strings""" + + if not var.is_final: + self.note("__match_args__ must be final for checking of match statements to work", + context, code=codes.LITERAL_REQ) + + typ = get_proper_type(typ) + if not isinstance(typ, TupleType) or \ + not all([is_string_literal(item) for item in typ.items]): + self.msg.note("__match_args__ must be a tuple containing string literals for checking " + "of match statements to work", context, code=codes.LITERAL_REQ) + def expand_typevars(self, defn: FuncItem, typ: CallableType) -> List[Tuple[FuncItem, CallableType]]: # TODO use generator @@ -2166,6 +2183,10 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type else: self.check_getattr_method(signature, lvalue, name) + if name == '__match_args__' and inferred is not None: + typ = self.expr_checker.accept(rvalue) + self.check_match_args(inferred, typ, lvalue) + # Defer PartialType's super type checking. if (isinstance(lvalue, RefExpr) and not (isinstance(lvalue_type, PartialType) and lvalue_type.type is None)): @@ -3904,6 +3925,75 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None: self.binder.handle_continue() return None + def visit_match_stmt(self, s: MatchStmt) -> None: + with self.binder.frame_context(can_skip=False, fall_through=0): + subject_type = get_proper_type(self.expr_checker.accept(s.subject)) + + if isinstance(subject_type, DeletedType): + self.msg.deleted_as_rvalue(subject_type, s) + + pattern_types = [self.pattern_checker.accept(p, subject_type) for p in s.patterns] + + type_maps: List[TypeMap] = [t.captures for t in pattern_types] + self.infer_variable_types_from_type_maps(type_maps) + + for pattern_type, g, b in zip(pattern_types, s.guards, s.bodies): + with self.binder.frame_context(can_skip=True, fall_through=2): + if b.is_unreachable or isinstance(get_proper_type(pattern_type.type), + UninhabitedType): + self.push_type_map(None) + else: + self.binder.put(s.subject, pattern_type.type) + self.push_type_map(pattern_type.captures) + if g is not None: + gt = get_proper_type(self.expr_checker.accept(g)) + + if isinstance(gt, DeletedType): + self.msg.deleted_as_rvalue(gt, s) + + if_map, _ = self.find_isinstance_check(g) + + self.push_type_map(if_map) + self.accept(b) + + # This is needed due to a quirk in frame_context. Without it types will stay narrowed + # after the match. + with self.binder.frame_context(can_skip=False, fall_through=2): + pass + + def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> None: + all_captures: Dict[Var, List[Tuple[NameExpr, Type]]] = defaultdict(list) + for tm in type_maps: + if tm is not None: + for expr, typ in tm.items(): + if isinstance(expr, NameExpr): + node = expr.node + assert isinstance(node, Var) + all_captures[node].append((expr, typ)) + + for var, captures in all_captures.items(): + conflict = False + types: List[Type] = [] + for expr, typ in captures: + types.append(typ) + + previous_type, _, inferred = self.check_lvalue(expr) + if previous_type is not None: + conflict = True + self.check_subtype(typ, previous_type, expr, + msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, + subtype_label="pattern captures type", + supertype_label="variable has type") + for type_map in type_maps: + if type_map is not None and expr in type_map: + del type_map[expr] + + if not conflict: + new_type = UnionType.make_union(types) + # Infer the union type at the first occurrence + first_occurrence, _ = captures[0] + self.infer_variable_type(var, first_occurrence, new_type, first_occurrence) + def make_fake_typeinfo(self, curr_module_fullname: str, class_gen_name: str, @@ -4268,11 +4358,14 @@ def is_type_call(expr: CallExpr) -> bool: if_maps: List[TypeMap] = [] else_maps: List[TypeMap] = [] for expr in exprs_in_type_calls: - current_if_map, current_else_map = self.conditional_type_map_with_intersection( - expr, + current_if_type, current_else_type = self.conditional_types_with_intersection( type_map[expr], - type_being_compared + type_being_compared, + expr ) + current_if_map, current_else_map = conditional_types_to_typemaps(expr, + current_if_type, + current_else_type) if_maps.append(current_if_map) else_maps.append(current_else_map) @@ -4328,10 +4421,13 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if len(node.args) != 2: # the error will be reported elsewhere return {}, {} if literal(expr) == LITERAL_TYPE: - return self.conditional_type_map_with_intersection( + return conditional_types_to_typemaps( expr, - type_map[expr], - get_isinstance_type(node.args[1], type_map), + *self.conditional_types_with_intersection( + type_map[expr], + get_isinstance_type(node.args[1], type_map), + expr + ) ) elif refers_to_fullname(node.callee, 'builtins.issubclass'): if len(node.args) != 2: # the error will be reported elsewhere @@ -4850,10 +4946,11 @@ def refine_identity_comparison_expression(self, if sum_type_name is not None: expr_type = try_expanding_sum_type_to_union(expr_type, sum_type_name) - # We intentionally use 'conditional_type_map' directly here instead of - # 'self.conditional_type_map_with_intersection': we only compute ad-hoc + # We intentionally use 'conditional_types' directly here instead of + # 'self.conditional_types_with_intersection': we only compute ad-hoc # intersections when working with pure instances. - partial_type_maps.append(conditional_type_map(expr, expr_type, target_type)) + types = conditional_types(expr_type, target_type) + partial_type_maps.append(conditional_types_to_typemaps(expr, *types)) return reduce_conditional_maps(partial_type_maps) @@ -5280,54 +5377,72 @@ def infer_issubclass_maps(self, node: CallExpr, # Any other object whose type we don't know precisely # for example, Any or a custom metaclass. return {}, {} # unknown type - yes_map, no_map = self.conditional_type_map_with_intersection(expr, vartype, type) + yes_type, no_type = self.conditional_types_with_intersection(vartype, type, expr) + yes_map, no_map = conditional_types_to_typemaps(expr, yes_type, no_type) yes_map, no_map = map(convert_to_typetype, (yes_map, no_map)) return yes_map, no_map - def conditional_type_map_with_intersection(self, - expr: Expression, - expr_type: Type, - type_ranges: Optional[List[TypeRange]], - ) -> Tuple[TypeMap, TypeMap]: - # For some reason, doing "yes_map, no_map = conditional_type_maps(...)" + @overload + def conditional_types_with_intersection(self, + expr_type: Type, + type_ranges: Optional[List[TypeRange]], + ctx: Context, + default: None = None + ) -> Tuple[Optional[Type], Optional[Type]]: ... + + @overload + def conditional_types_with_intersection(self, + expr_type: Type, + type_ranges: Optional[List[TypeRange]], + ctx: Context, + default: Type + ) -> Tuple[Type, Type]: ... + + def conditional_types_with_intersection(self, + expr_type: Type, + type_ranges: Optional[List[TypeRange]], + ctx: Context, + default: Optional[Type] = None + ) -> Tuple[Optional[Type], Optional[Type]]: + initial_types = conditional_types(expr_type, type_ranges, default) + # For some reason, doing "yes_map, no_map = conditional_types_to_typemaps(...)" # doesn't work: mypyc will decide that 'yes_map' is of type None if we try. - initial_maps = conditional_type_map(expr, expr_type, type_ranges) - yes_map: TypeMap = initial_maps[0] - no_map: TypeMap = initial_maps[1] + yes_type: Optional[Type] = initial_types[0] + no_type: Optional[Type] = initial_types[1] - if yes_map is not None or type_ranges is None: - return yes_map, no_map + if not isinstance(get_proper_type(yes_type), UninhabitedType) or type_ranges is None: + return yes_type, no_type - # If conditions_type_map was unable to successfully narrow the expr_type + # If conditional_types was unable to successfully narrow the expr_type # using the type_ranges and concluded if-branch is unreachable, we try # computing it again using a different algorithm that tries to generate # an ad-hoc intersection between the expr_type and the type_ranges. - expr_type = get_proper_type(expr_type) - if isinstance(expr_type, UnionType): - possible_expr_types = get_proper_types(expr_type.relevant_items()) + proper_type = get_proper_type(expr_type) + if isinstance(proper_type, UnionType): + possible_expr_types = get_proper_types(proper_type.relevant_items()) else: - possible_expr_types = [expr_type] + possible_expr_types = [proper_type] possible_target_types = [] for tr in type_ranges: item = get_proper_type(tr.item) if not isinstance(item, Instance) or tr.is_upper_bound: - return yes_map, no_map + return yes_type, no_type possible_target_types.append(item) out = [] for v in possible_expr_types: if not isinstance(v, Instance): - return yes_map, no_map + return yes_type, no_type for t in possible_target_types: - intersection = self.intersect_instances((v, t), expr) + intersection = self.intersect_instances((v, t), ctx) if intersection is None: continue out.append(intersection) if len(out) == 0: - return None, {} + return UninhabitedType(), expr_type new_yes_type = make_simplified_union(out) - return {expr: new_yes_type}, {} + return new_yes_type, expr_type def is_writable_attribute(self, node: Node) -> bool: """Check if an attribute is writable""" @@ -5340,48 +5455,75 @@ def is_writable_attribute(self, node: Node) -> bool: return False -def conditional_type_map(expr: Expression, - current_type: Optional[Type], - proposed_type_ranges: Optional[List[TypeRange]], - ) -> Tuple[TypeMap, TypeMap]: - """Takes in an expression, the current type of the expression, and a - proposed type of that expression. +@overload +def conditional_types(current_type: Type, + proposed_type_ranges: Optional[List[TypeRange]], + default: None = None + ) -> Tuple[Optional[Type], Optional[Type]]: ... + + +@overload +def conditional_types(current_type: Type, + proposed_type_ranges: Optional[List[TypeRange]], + default: Type + ) -> Tuple[Type, Type]: ... - Returns a 2-tuple: The first element is a map from the expression to - the proposed type, if the expression can be the proposed type. The - second element is a map from the expression to the type it would hold - if it was not the proposed type, if any. None means bot, {} means top""" + +def conditional_types(current_type: Type, + proposed_type_ranges: Optional[List[TypeRange]], + default: Optional[Type] = None + ) -> Tuple[Optional[Type], Optional[Type]]: + """Takes in the current type and a proposed type of an expression. + + Returns a 2-tuple: The first element is the proposed type, if the expression + can be the proposed type. The second element is the type it would hold + if it was not the proposed type, if any. UninhabitedType means unreachable. + None means no new information can be inferred. If default is set it is returned + instead.""" if proposed_type_ranges: proposed_items = [type_range.item for type_range in proposed_type_ranges] proposed_type = make_simplified_union(proposed_items) - if current_type: - if isinstance(proposed_type, AnyType): - # We don't really know much about the proposed type, so we shouldn't - # attempt to narrow anything. Instead, we broaden the expr to Any to - # avoid false positives - return {expr: proposed_type}, {} - elif (not any(type_range.is_upper_bound for type_range in proposed_type_ranges) - and is_proper_subtype(current_type, proposed_type)): - # Expression is always of one of the types in proposed_type_ranges - return {}, None - elif not is_overlapping_types(current_type, proposed_type, - prohibit_none_typevar_overlap=True): - # Expression is never of any type in proposed_type_ranges - return None, {} - else: - # we can only restrict when the type is precise, not bounded - proposed_precise_type = UnionType.make_union([ - type_range.item - for type_range in proposed_type_ranges - if not type_range.is_upper_bound - ]) - remaining_type = restrict_subtype_away(current_type, proposed_precise_type) - return {expr: proposed_type}, {expr: remaining_type} + if isinstance(proposed_type, AnyType): + # We don't really know much about the proposed type, so we shouldn't + # attempt to narrow anything. Instead, we broaden the expr to Any to + # avoid false positives + return proposed_type, default + elif (not any(type_range.is_upper_bound for type_range in proposed_type_ranges) + and is_proper_subtype(current_type, proposed_type)): + # Expression is always of one of the types in proposed_type_ranges + return default, UninhabitedType() + elif not is_overlapping_types(current_type, proposed_type, + prohibit_none_typevar_overlap=True): + # Expression is never of any type in proposed_type_ranges + return UninhabitedType(), default else: - return {expr: proposed_type}, {} + # we can only restrict when the type is precise, not bounded + proposed_precise_type = UnionType.make_union([type_range.item + for type_range in proposed_type_ranges + if not type_range.is_upper_bound]) + remaining_type = restrict_subtype_away(current_type, proposed_precise_type) + return proposed_type, remaining_type else: # An isinstance check, but we don't understand the type - return {}, {} + return current_type, default + + +def conditional_types_to_typemaps(expr: Expression, + yes_type: Optional[Type], + no_type: Optional[Type] + ) -> Tuple[TypeMap, TypeMap]: + maps: List[TypeMap] = [] + for typ in (yes_type, no_type): + proper_type = get_proper_type(typ) + if isinstance(proper_type, UninhabitedType): + maps.append(None) + elif proper_type is None: + maps.append({}) + else: + assert typ is not None + maps.append({expr: typ}) + + return cast(Tuple[TypeMap, TypeMap], tuple(maps)) def gen_unique_name(base: str, table: SymbolTable) -> str: @@ -6191,6 +6333,11 @@ def is_private(node_name: str) -> bool: return node_name.startswith('__') and not node_name.endswith('__') +def is_string_literal(typ: Type) -> bool: + strs = try_getting_str_literals_from_type(typ) + return strs is not None and len(strs) == 1 + + def has_bool_item(typ: ProperType) -> bool: """Return True if type is 'bool' or a union with a 'bool' item.""" if is_named_instance(typ, 'builtins.bool'): diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 1647339ef217..648d48a639ff 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3076,7 +3076,11 @@ def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) else: return union - def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type: + def visit_typeddict_index_expr(self, td_type: TypedDictType, + index: Expression, + local_errors: Optional[MessageBuilder] = None + ) -> Type: + local_errors = local_errors or self.msg if isinstance(index, (StrExpr, UnicodeExpr)): key_names = [index.value] else: @@ -3096,14 +3100,14 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) and key_type.fallback.type.fullname != 'builtins.bytes'): key_names.append(key_type.value) else: - self.msg.typeddict_key_must_be_string_literal(td_type, index) + local_errors.typeddict_key_must_be_string_literal(td_type, index) return AnyType(TypeOfAny.from_error) value_types = [] for key_name in key_names: value_type = td_type.items.get(key_name) if value_type is None: - self.msg.typeddict_key_not_found(td_type, key_name, index) + local_errors.typeddict_key_not_found(td_type, key_name, index) return AnyType(TypeOfAny.from_error) else: value_types.append(value_type) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py new file mode 100644 index 000000000000..2c40e856be88 --- /dev/null +++ b/mypy/checkpattern.py @@ -0,0 +1,683 @@ +"""Pattern checker. This file is conceptually part of TypeChecker.""" +from collections import defaultdict +from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union +from typing_extensions import Final + +import mypy.checker +from mypy.checkmember import analyze_member_access +from mypy.expandtype import expand_type_by_instance +from mypy.join import join_types +from mypy.literals import literal_hash +from mypy.maptype import map_instance_to_supertype +from mypy.meet import narrow_declared_type +from mypy import message_registry +from mypy.messages import MessageBuilder +from mypy.nodes import Expression, ARG_POS, TypeAlias, TypeInfo, Var, NameExpr +from mypy.patterns import ( + Pattern, AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, + ClassPattern, SingletonPattern +) +from mypy.plugin import Plugin +from mypy.subtypes import is_subtype +from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union +from mypy.types import ( + ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type, + TypedDictType, TupleType, NoneType, UnionType +) +from mypy.typevars import fill_typevars +from mypy.visitor import PatternVisitor + +self_match_type_names: Final = [ + "builtins.bool", + "builtins.bytearray", + "builtins.bytes", + "builtins.dict", + "builtins.float", + "builtins.frozenset", + "builtins.int", + "builtins.list", + "builtins.set", + "builtins.str", + "builtins.tuple", +] + +non_sequence_match_type_names: Final = [ + "builtins.str", + "builtins.bytes", + "builtins.bytearray" +] + + +# For every Pattern a PatternType can be calculated. This requires recursively calculating +# the PatternTypes of the sub-patterns first. +# Using the data in the PatternType the match subject and captured names can be narrowed/inferred. +PatternType = NamedTuple( + 'PatternType', + [ + ('type', Type), # The type the match subject can be narrowed to + ('rest_type', Type), # For exhaustiveness checking. Not used yet + ('captures', Dict[Expression, Type]), # The variables captured by the pattern + ]) + + +class PatternChecker(PatternVisitor[PatternType]): + """Pattern checker. + + This class checks if a pattern can match a type, what the type can be narrowed to, and what + type capture patterns should be inferred as. + """ + + # Some services are provided by a TypeChecker instance. + chk: 'mypy.checker.TypeChecker' + # This is shared with TypeChecker, but stored also here for convenience. + msg: MessageBuilder + # Currently unused + plugin: Plugin + # The expression being matched against the pattern + subject: Expression + + subject_type: Type + # Type of the subject to check the (sub)pattern against + type_context: List[Type] + # Types that match against self instead of their __match_args__ if used as a class pattern + # Filled in from self_match_type_names + self_match_types: List[Type] + # Types that are sequences, but don't match sequence patterns. Filled in from + # non_sequence_match_type_names + non_sequence_match_types: List[Type] + + def __init__(self, + chk: 'mypy.checker.TypeChecker', + msg: MessageBuilder, plugin: Plugin + ) -> None: + self.chk = chk + self.msg = msg + self.plugin = plugin + + self.type_context = [] + self.self_match_types = self.generate_types_from_names(self_match_type_names) + self.non_sequence_match_types = self.generate_types_from_names( + non_sequence_match_type_names + ) + + def accept(self, o: Pattern, type_context: Type) -> PatternType: + self.type_context.append(type_context) + result = o.accept(self) + self.type_context.pop() + + return result + + def visit_as_pattern(self, o: AsPattern) -> PatternType: + current_type = self.type_context[-1] + if o.pattern is not None: + pattern_type = self.accept(o.pattern, current_type) + typ, rest_type, type_map = pattern_type + else: + typ, rest_type, type_map = current_type, UninhabitedType(), {} + + if not is_uninhabited(typ) and o.name is not None: + typ, _ = self.chk.conditional_types_with_intersection(current_type, + [get_type_range(typ)], + o, + default=current_type) + if not is_uninhabited(typ): + type_map[o.name] = typ + + return PatternType(typ, rest_type, type_map) + + def visit_or_pattern(self, o: OrPattern) -> PatternType: + current_type = self.type_context[-1] + + # + # Check all the subpatterns + # + pattern_types = [] + for pattern in o.patterns: + pattern_type = self.accept(pattern, current_type) + pattern_types.append(pattern_type) + current_type = pattern_type.rest_type + + # + # Collect the final type + # + types = [] + for pattern_type in pattern_types: + if not is_uninhabited(pattern_type.type): + types.append(pattern_type.type) + + # + # Check the capture types + # + capture_types: Dict[Var, List[Tuple[Expression, Type]]] = defaultdict(list) + # Collect captures from the first subpattern + for expr, typ in pattern_types[0].captures.items(): + node = get_var(expr) + capture_types[node].append((expr, typ)) + + # Check if other subpatterns capture the same names + for i, pattern_type in enumerate(pattern_types[1:]): + vars = {get_var(expr) for expr, _ in pattern_type.captures.items()} + if capture_types.keys() != vars: + self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i]) + for expr, typ in pattern_type.captures.items(): + node = get_var(expr) + capture_types[node].append((expr, typ)) + + captures: Dict[Expression, Type] = {} + for var, capture_list in capture_types.items(): + typ = UninhabitedType() + for _, other in capture_list: + typ = join_types(typ, other) + + captures[capture_list[0][0]] = typ + + union_type = make_simplified_union(types) + return PatternType(union_type, current_type, captures) + + def visit_value_pattern(self, o: ValuePattern) -> PatternType: + current_type = self.type_context[-1] + typ = self.chk.expr_checker.accept(o.expr) + narrowed_type, rest_type = self.chk.conditional_types_with_intersection( + current_type, + [get_type_range(typ)], + o, + default=current_type + ) + return PatternType(narrowed_type, rest_type, {}) + + def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: + current_type = self.type_context[-1] + value: Union[bool, None] = o.value + if isinstance(value, bool): + typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool") + elif value is None: + typ = NoneType() + else: + assert False + + narrowed_type, rest_type = self.chk.conditional_types_with_intersection( + current_type, + [get_type_range(typ)], + o, + default=current_type + ) + return PatternType(narrowed_type, rest_type, {}) + + def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: + # + # check for existence of a starred pattern + # + current_type = get_proper_type(self.type_context[-1]) + if not self.can_match_sequence(current_type): + return self.early_non_match() + star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)] + star_position: Optional[int] = None + if len(star_positions) == 1: + star_position = star_positions[0] + elif len(star_positions) >= 2: + assert False, "Parser should prevent multiple starred patterns" + required_patterns = len(o.patterns) + if star_position is not None: + required_patterns -= 1 + + # + # get inner types of original type + # + if isinstance(current_type, TupleType): + inner_types = current_type.items + size_diff = len(inner_types) - required_patterns + if size_diff < 0: + return self.early_non_match() + elif size_diff > 0 and star_position is None: + return self.early_non_match() + else: + inner_type = self.get_sequence_type(current_type) + if inner_type is None: + inner_type = self.chk.named_type("builtins.object") + inner_types = [inner_type] * len(o.patterns) + + # + # match inner patterns + # + contracted_new_inner_types: List[Type] = [] + contracted_rest_inner_types: List[Type] = [] + captures: Dict[Expression, Type] = {} + + contracted_inner_types = self.contract_starred_pattern_types(inner_types, + star_position, + required_patterns) + can_match = True + for p, t in zip(o.patterns, contracted_inner_types): + pattern_type = self.accept(p, t) + typ, rest, type_map = pattern_type + if is_uninhabited(typ): + can_match = False + else: + contracted_new_inner_types.append(typ) + contracted_rest_inner_types.append(rest) + self.update_type_map(captures, type_map) + new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types, + star_position, + len(inner_types)) + + # + # Calculate new type + # + new_type: Type + rest_type: Type = current_type + if not can_match: + new_type = UninhabitedType() + elif isinstance(current_type, TupleType): + narrowed_inner_types = [] + inner_rest_types = [] + for inner_type, new_inner_type in zip(inner_types, new_inner_types): + narrowed_inner_type, inner_rest_type = \ + self.chk.conditional_types_with_intersection( + new_inner_type, + [get_type_range(inner_type)], + o, + default=new_inner_type + ) + narrowed_inner_types.append(narrowed_inner_type) + inner_rest_types.append(inner_rest_type) + if all(not is_uninhabited(typ) for typ in narrowed_inner_types): + new_type = TupleType(narrowed_inner_types, current_type.partial_fallback) + else: + new_type = UninhabitedType() + + if all(is_uninhabited(typ) for typ in inner_rest_types): + # All subpatterns always match, so we can apply negative narrowing + new_type, rest_type = self.chk.conditional_types_with_intersection( + current_type, [get_type_range(new_type)], o, default=current_type + ) + else: + new_inner_type = UninhabitedType() + for typ in new_inner_types: + new_inner_type = join_types(new_inner_type, typ) + new_type = self.construct_sequence_child(current_type, new_inner_type) + if not is_subtype(new_type, current_type): + new_type = current_type + return PatternType(new_type, rest_type, captures) + + def get_sequence_type(self, t: Type) -> Optional[Type]: + t = get_proper_type(t) + if isinstance(t, AnyType): + return AnyType(TypeOfAny.from_another_any, t) + if isinstance(t, UnionType): + items = [self.get_sequence_type(item) for item in t.items] + not_none_items = [item for item in items if item is not None] + if len(not_none_items) > 0: + return make_simplified_union(not_none_items) + else: + return None + + if self.chk.type_is_iterable(t) and isinstance(t, Instance): + return self.chk.iterable_item_type(t) + else: + return None + + def contract_starred_pattern_types(self, + types: List[Type], + star_pos: Optional[int], + num_patterns: int + ) -> List[Type]: + """ + Contracts a list of types in a sequence pattern depending on the position of a starred + capture pattern. + + For example if the sequence pattern [a, *b, c] is matched against types [bool, int, str, + bytes] the contracted types are [bool, Union[int, str], bytes]. + + If star_pos in None the types are returned unchanged. + """ + if star_pos is None: + return types + new_types = types[:star_pos] + star_length = len(types) - num_patterns + new_types.append(make_simplified_union(types[star_pos:star_pos+star_length])) + new_types += types[star_pos+star_length:] + + return new_types + + def expand_starred_pattern_types(self, + types: List[Type], + star_pos: Optional[int], + num_types: int + ) -> List[Type]: + """ + Undoes the contraction done by contract_starred_pattern_types. + + For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended + to lenght 4 the result is [bool, int, int, str]. + """ + if star_pos is None: + return types + new_types = types[:star_pos] + star_length = num_types - len(types) + 1 + new_types += [types[star_pos]] * star_length + new_types += types[star_pos+1:] + + return new_types + + def visit_starred_pattern(self, o: StarredPattern) -> PatternType: + captures: Dict[Expression, Type] = {} + if o.capture is not None: + list_type = self.chk.named_generic_type('builtins.list', [self.type_context[-1]]) + captures[o.capture] = list_type + return PatternType(self.type_context[-1], UninhabitedType(), captures) + + def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: + current_type = get_proper_type(self.type_context[-1]) + can_match = True + captures: Dict[Expression, Type] = {} + for key, value in zip(o.keys, o.values): + inner_type = self.get_mapping_item_type(o, current_type, key) + if inner_type is None: + can_match = False + inner_type = self.chk.named_type("builtins.object") + pattern_type = self.accept(value, inner_type) + if is_uninhabited(pattern_type.type): + can_match = False + else: + self.update_type_map(captures, pattern_type.captures) + + if o.rest is not None: + mapping = self.chk.named_type("typing.Mapping") + if is_subtype(current_type, mapping) and isinstance(current_type, Instance): + mapping_inst = map_instance_to_supertype(current_type, mapping.type) + dict_typeinfo = self.chk.lookup_typeinfo("builtins.dict") + dict_type = fill_typevars(dict_typeinfo) + rest_type = expand_type_by_instance(dict_type, mapping_inst) + else: + object_type = self.chk.named_type("builtins.object") + rest_type = self.chk.named_generic_type("builtins.dict", + [object_type, object_type]) + + captures[o.rest] = rest_type + + if can_match: + # We can't narrow the type here, as Mapping key is invariant. + new_type = self.type_context[-1] + else: + new_type = UninhabitedType() + return PatternType(new_type, current_type, captures) + + def get_mapping_item_type(self, + pattern: MappingPattern, + mapping_type: Type, + key: Expression + ) -> Optional[Type]: + local_errors = self.msg.clean_copy() + local_errors.disable_count = 0 + mapping_type = get_proper_type(mapping_type) + if isinstance(mapping_type, TypedDictType): + result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr( + mapping_type, key, local_errors=local_errors) + # If we can't determine the type statically fall back to treating it as a normal + # mapping + if local_errors.is_errors(): + local_errors = self.msg.clean_copy() + local_errors.disable_count = 0 + result = self.get_simple_mapping_item_type(pattern, + mapping_type, + key, + local_errors) + + if local_errors.is_errors(): + result = None + else: + result = self.get_simple_mapping_item_type(pattern, + mapping_type, + key, + local_errors) + return result + + def get_simple_mapping_item_type(self, + pattern: MappingPattern, + mapping_type: Type, + key: Expression, + local_errors: MessageBuilder + ) -> Type: + result, _ = self.chk.expr_checker.check_method_call_by_name('__getitem__', + mapping_type, + [key], + [ARG_POS], + pattern, + local_errors=local_errors) + return result + + def visit_class_pattern(self, o: ClassPattern) -> PatternType: + current_type = get_proper_type(self.type_context[-1]) + + # + # Check class type + # + type_info = o.class_ref.node + assert type_info is not None + if isinstance(type_info, TypeAlias) and not type_info.no_args: + self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o) + return self.early_non_match() + if isinstance(type_info, TypeInfo): + any_type = AnyType(TypeOfAny.implementation_artifact) + typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars)) + elif isinstance(type_info, TypeAlias): + typ = type_info.target + else: + if isinstance(type_info, Var): + name = str(type_info.type) + else: + name = type_info.name + self.msg.fail(message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name), o.class_ref) + return self.early_non_match() + + new_type, rest_type = self.chk.conditional_types_with_intersection( + current_type, [get_type_range(typ)], o, default=current_type + ) + if is_uninhabited(new_type): + return self.early_non_match() + # TODO: Do I need this? + narrowed_type = narrow_declared_type(current_type, new_type) + + # + # Convert positional to keyword patterns + # + keyword_pairs: List[Tuple[Optional[str], Pattern]] = [] + match_arg_set: Set[str] = set() + + captures: Dict[Expression, Type] = {} + + if len(o.positionals) != 0: + if self.should_self_match(typ): + if len(o.positionals) > 1: + self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o) + pattern_type = self.accept(o.positionals[0], narrowed_type) + if not is_uninhabited(pattern_type.type): + return PatternType(pattern_type.type, + join_types(rest_type, pattern_type.rest_type), + pattern_type.captures) + captures = pattern_type.captures + else: + local_errors = self.msg.clean_copy() + match_args_type = analyze_member_access("__match_args__", typ, o, + False, False, False, + local_errors, + original_type=typ, + chk=self.chk) + + if local_errors.is_errors(): + self.msg.fail(message_registry.MISSING_MATCH_ARGS.format(typ), o) + return self.early_non_match() + + proper_match_args_type = get_proper_type(match_args_type) + if isinstance(proper_match_args_type, TupleType): + match_arg_names = get_match_arg_names(proper_match_args_type) + + if len(o.positionals) > len(match_arg_names): + self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o) + return self.early_non_match() + else: + match_arg_names = [None] * len(o.positionals) + + for arg_name, pos in zip(match_arg_names, o.positionals): + keyword_pairs.append((arg_name, pos)) + if arg_name is not None: + match_arg_set.add(arg_name) + + # + # Check for duplicate patterns + # + keyword_arg_set = set() + has_duplicates = False + for key, value in zip(o.keyword_keys, o.keyword_values): + keyword_pairs.append((key, value)) + if key in match_arg_set: + self.msg.fail( + message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key), + value + ) + has_duplicates = True + elif key in keyword_arg_set: + self.msg.fail(message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key), + value) + has_duplicates = True + keyword_arg_set.add(key) + + if has_duplicates: + return self.early_non_match() + + # + # Check keyword patterns + # + can_match = True + for keyword, pattern in keyword_pairs: + key_type: Optional[Type] = None + local_errors = self.msg.clean_copy() + if keyword is not None: + key_type = analyze_member_access(keyword, + narrowed_type, + pattern, + False, + False, + False, + local_errors, + original_type=new_type, + chk=self.chk) + else: + key_type = AnyType(TypeOfAny.from_error) + if local_errors.is_errors() or key_type is None: + key_type = AnyType(TypeOfAny.from_error) + self.msg.fail(message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(typ, keyword), + value) + + inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type) + if is_uninhabited(inner_type): + can_match = False + else: + self.update_type_map(captures, inner_captures) + if not is_uninhabited(inner_rest_type): + rest_type = current_type + + if not can_match: + new_type = UninhabitedType() + return PatternType(new_type, rest_type, captures) + + def should_self_match(self, typ: Type) -> bool: + typ = get_proper_type(typ) + if isinstance(typ, Instance) and typ.type.is_named_tuple: + return False + for other in self.self_match_types: + if is_subtype(typ, other): + return True + return False + + def can_match_sequence(self, typ: ProperType) -> bool: + if isinstance(typ, UnionType): + return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items) + for other in self.non_sequence_match_types: + # We have to ignore promotions, as memoryview should match, but bytes, + # which it can be promoted to, shouldn't + if is_subtype(typ, other, ignore_promotions=True): + return False + sequence = self.chk.named_type("typing.Sequence") + # If the static type is more general than sequence the actual type could still match + return is_subtype(typ, sequence) or is_subtype(sequence, typ) + + def generate_types_from_names(self, type_names: List[str]) -> List[Type]: + types: List[Type] = [] + for name in type_names: + try: + types.append(self.chk.named_type(name)) + except KeyError as e: + # Some built in types are not defined in all test cases + if not name.startswith('builtins.'): + raise e + pass + + return types + + def update_type_map(self, + original_type_map: Dict[Expression, Type], + extra_type_map: Dict[Expression, Type] + ) -> None: + # Calculating this would not be needed if TypeMap directly used literal hashes instead of + # expressions, as suggested in the TODO above it's definition + already_captured = set(literal_hash(expr) for expr in original_type_map) + for expr, typ in extra_type_map.items(): + if literal_hash(expr) in already_captured: + node = get_var(expr) + self.msg.fail(message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), + expr) + else: + original_type_map[expr] = typ + + def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type: + """ + If outer_type is a child class of typing.Sequence returns a new instance of + outer_type, that is a Sequence of inner_type. If outer_type is not a child class of + typing.Sequence just returns a Sequence of inner_type + + For example: + construct_sequence_child(List[int], str) = List[str] + """ + sequence = self.chk.named_generic_type("typing.Sequence", [inner_type]) + if is_subtype(outer_type, self.chk.named_type("typing.Sequence")): + proper_type = get_proper_type(outer_type) + assert isinstance(proper_type, Instance) + empty_type = fill_typevars(proper_type.type) + partial_type = expand_type_by_instance(empty_type, sequence) + return expand_type_by_instance(partial_type, proper_type) + else: + return sequence + + def early_non_match(self) -> PatternType: + return PatternType(UninhabitedType(), self.type_context[-1], {}) + + +def get_match_arg_names(typ: TupleType) -> List[Optional[str]]: + args: List[Optional[str]] = [] + for item in typ.items: + values = try_getting_str_literals_from_type(item) + if values is None or len(values) != 1: + args.append(None) + else: + args.append(values[0]) + return args + + +def get_var(expr: Expression) -> Var: + """ + Warning: this in only true for expressions captured by a match statement. + Don't call it from anywhere else + """ + assert isinstance(expr, NameExpr) + node = expr.node + assert isinstance(node, Var) + return node + + +def get_type_range(typ: Type) -> 'mypy.checker.TypeRange': + return mypy.checker.TypeRange(typ, is_upper_bound=False) + + +def is_uninhabited(typ: Type) -> bool: + return isinstance(get_proper_type(typ), UninhabitedType) diff --git a/mypy/errorcodes.py b/mypy/errorcodes.py index 2a07bbb8597b..b6c317107467 100644 --- a/mypy/errorcodes.py +++ b/mypy/errorcodes.py @@ -94,6 +94,9 @@ def __str__(self) -> str: EXIT_RETURN: Final = ErrorCode( "exit-return", "Warn about too general return type for '__exit__'", "General" ) +LITERAL_REQ: Final = ErrorCode( + "literal-required", "Check that value is a literal", 'General' +) # These error codes aren't enabled by default. NO_UNTYPED_DEF: Final[ErrorCode] = ErrorCode( diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 34fe2c0da32d..b2ad07b111af 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -8,6 +8,7 @@ from typing import ( Tuple, Union, TypeVar, Callable, Sequence, Optional, Any, Dict, cast, List ) + from typing_extensions import Final, Literal, overload from mypy.sharedparse import ( @@ -19,18 +20,22 @@ ClassDef, Decorator, Block, Var, OperatorAssignmentStmt, ExpressionStmt, AssignmentStmt, ReturnStmt, RaiseStmt, AssertStmt, DelStmt, BreakStmt, ContinueStmt, PassStmt, GlobalDecl, - WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, + WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, MatchStmt, TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr, DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, UnaryExpr, LambdaExpr, ComparisonExpr, AssignmentExpr, StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension, SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument, - AwaitExpr, TempNode, Expression, Statement, + AwaitExpr, TempNode, RefExpr, Expression, Statement, ArgKind, ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR2, check_arg_names, FakeInfo, ) +from mypy.patterns import ( + AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, + ClassPattern, SingletonPattern +) from mypy.types import ( Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument, TypeOfAny, Instance, RawExpressionType, ProperType, UnionType, @@ -106,6 +111,27 @@ def ast3_parse(source: Union[str, bytes], filename: str, mode: str, # These don't exist before 3.8 NamedExpr = Any Constant = Any + + if sys.version_info >= (3, 10): + Match = ast3.Match + MatchValue = ast3.MatchValue + MatchSingleton = ast3.MatchSingleton + MatchSequence = ast3.MatchSequence + MatchStar = ast3.MatchStar + MatchMapping = ast3.MatchMapping + MatchClass = ast3.MatchClass + MatchAs = ast3.MatchAs + MatchOr = ast3.MatchOr + else: + Match = Any + MatchValue = Any + MatchSingleton = Any + MatchSequence = Any + MatchStar = Any + MatchMapping = Any + MatchClass = Any + MatchAs = Any + MatchOr = Any except ImportError: try: from typed_ast import ast35 # type: ignore[attr-defined] # noqa: F401 @@ -1286,11 +1312,74 @@ def visit_Index(self, n: Index) -> Node: # cast for mypyc's benefit on Python 3.9 return self.visit(cast(Any, n).value) - def visit_Match(self, n: Any) -> Node: - self.fail("Match statement is not supported", - line=n.lineno, column=n.col_offset, blocker=True) - # Just return some valid node - return PassStmt() + # Match(expr subject, match_case* cases) # python 3.10 and later + def visit_Match(self, n: Match) -> MatchStmt: + node = MatchStmt(self.visit(n.subject), + [self.visit(c.pattern) for c in n.cases], + [self.visit(c.guard) for c in n.cases], + [self.as_required_block(c.body, n.lineno) for c in n.cases]) + return self.set_line(node, n) + + def visit_MatchValue(self, n: MatchValue) -> ValuePattern: + node = ValuePattern(self.visit(n.value)) + return self.set_line(node, n) + + def visit_MatchSingleton(self, n: MatchSingleton) -> SingletonPattern: + node = SingletonPattern(n.value) + return self.set_line(node, n) + + def visit_MatchSequence(self, n: MatchSequence) -> SequencePattern: + patterns = [self.visit(p) for p in n.patterns] + stars = [p for p in patterns if isinstance(p, StarredPattern)] + assert len(stars) < 2 + + node = SequencePattern(patterns) + return self.set_line(node, n) + + def visit_MatchStar(self, n: MatchStar) -> StarredPattern: + if n.name is None: + node = StarredPattern(None) + else: + node = StarredPattern(NameExpr(n.name)) + + return self.set_line(node, n) + + def visit_MatchMapping(self, n: MatchMapping) -> MappingPattern: + keys = [self.visit(k) for k in n.keys] + values = [self.visit(v) for v in n.patterns] + + if n.rest is None: + rest = None + else: + rest = NameExpr(n.rest) + + node = MappingPattern(keys, values, rest) + return self.set_line(node, n) + + def visit_MatchClass(self, n: MatchClass) -> ClassPattern: + class_ref = self.visit(n.cls) + assert isinstance(class_ref, RefExpr) + positionals = [self.visit(p) for p in n.patterns] + keyword_keys = n.kwd_attrs + keyword_values = [self.visit(p) for p in n.kwd_patterns] + + node = ClassPattern(class_ref, positionals, keyword_keys, keyword_values) + return self.set_line(node, n) + + # MatchAs(expr pattern, identifier name) + def visit_MatchAs(self, n: MatchAs) -> AsPattern: + if n.name is None: + name = None + else: + name = NameExpr(n.name) + name = self.set_line(name, n) + node = AsPattern(self.visit(n.pattern), name) + return self.set_line(node, n) + + # MatchOr(expr* pattern) + def visit_MatchOr(self, n: MatchOr) -> OrPattern: + node = OrPattern([self.visit(pattern) for pattern in n.patterns]) + return self.set_line(node, n) class TypeConverter: diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 77dff1154833..1477cc4da575 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -64,6 +64,7 @@ def format(self, *args: object, **kwargs: object) -> "ErrorMessage": INCOMPATIBLE_TYPES_IN_YIELD: Final = ErrorMessage('Incompatible types in "yield"') INCOMPATIBLE_TYPES_IN_YIELD_FROM: Final = ErrorMessage('Incompatible types in "yield from"') INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION: Final = "Incompatible types in string interpolation" +INCOMPATIBLE_TYPES_IN_CAPTURE: Final = ErrorMessage('Incompatible types in capture pattern') MUST_HAVE_NONE_RETURN_TYPE: Final = ErrorMessage('The return type of "{}" must be None') INVALID_TUPLE_INDEX_TYPE: Final = ErrorMessage("Invalid tuple index type") TUPLE_INDEX_OUT_OF_RANGE: Final = ErrorMessage("Tuple index out of range") @@ -229,3 +230,18 @@ def format(self, *args: object, **kwargs: object) -> "ErrorMessage": CONTIGUOUS_ITERABLE_EXPECTED: Final = ErrorMessage("Contiguous iterable with same type expected") ITERABLE_TYPE_EXPECTED: Final = ErrorMessage("Invalid type '{}' for *expr (iterable expected)") TYPE_GUARD_POS_ARG_REQUIRED: Final = ErrorMessage("Type guard requires positional argument") + +# Match Statement +MISSING_MATCH_ARGS: Final = 'Class "{}" doesn\'t define "__match_args__"' +OR_PATTERN_ALTERNATIVE_NAMES: Final = "Alternative patterns bind different names" +CLASS_PATTERN_GENERIC_TYPE_ALIAS: Final = ( + "Class pattern class must not be a type alias with type parameters" +) +CLASS_PATTERN_TYPE_REQUIRED: Final = 'Expected type in class pattern; found "{}"' +CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS: Final = "Too many positional patterns for class pattern" +CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL: Final = ( + 'Keyword "{}" already matches a positional pattern' +) +CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN: Final = 'Duplicate keyword pattern "{}"' +CLASS_PATTERN_UNKNOWN_KEYWORD: Final = 'Class "{}" has no attribute "{}"' +MULTIPLE_ASSIGNMENTS_IN_PATTERN: Final = 'Multiple assignments to name "{}" in pattern' diff --git a/mypy/messages.py b/mypy/messages.py index da284cc88ba4..a7fec2cf4178 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1250,7 +1250,7 @@ def typeddict_key_must_be_string_literal( context: Context) -> None: self.fail( 'TypedDict key must be a string literal; expected one of {}'.format( - format_item_name_list(typ.items.keys())), context) + format_item_name_list(typ.items.keys())), context, code=codes.LITERAL_REQ) def typeddict_key_not_found( self, diff --git a/mypy/nodes.py b/mypy/nodes.py index 156d756030ae..b8c24c983a1b 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -17,6 +17,9 @@ from mypy.bogus_type import Bogus +if TYPE_CHECKING: + from mypy.patterns import Pattern + class Context: """Base type for objects that are valid as error message locations.""" @@ -1363,6 +1366,25 @@ def accept(self, visitor: StatementVisitor[T]) -> T: return visitor.visit_with_stmt(self) +class MatchStmt(Statement): + subject: Expression + patterns: List['Pattern'] + guards: List[Optional[Expression]] + bodies: List[Block] + + def __init__(self, subject: Expression, patterns: List['Pattern'], + guards: List[Optional[Expression]], bodies: List[Block]) -> None: + super().__init__() + assert len(patterns) == len(guards) == len(bodies) + self.subject = subject + self.patterns = patterns + self.guards = guards + self.bodies = bodies + + def accept(self, visitor: StatementVisitor[T]) -> T: + return visitor.visit_match_stmt(self) + + class PrintStmt(Statement): """Python 2 print statement""" diff --git a/mypy/patterns.py b/mypy/patterns.py new file mode 100644 index 000000000000..8557fac6daf6 --- /dev/null +++ b/mypy/patterns.py @@ -0,0 +1,132 @@ +"""Classes for representing match statement patterns.""" +from typing import TypeVar, List, Optional, Union + +from mypy_extensions import trait + +from mypy.nodes import Node, RefExpr, NameExpr, Expression +from mypy.visitor import PatternVisitor + + +T = TypeVar('T') + + +@trait +class Pattern(Node): + """A pattern node.""" + + __slots__ = () + + def accept(self, visitor: PatternVisitor[T]) -> T: + raise RuntimeError('Not implemented') + + +class AsPattern(Pattern): + # The python ast, and therefore also our ast merges capture, wildcard and as patterns into one + # for easier handling. + # If pattern is None this is a capture pattern. If name and pattern are both none this is a + # wildcard pattern. + # Only name being None should not happen but also won't break anything. + pattern: Optional[Pattern] + name: Optional[NameExpr] + + def __init__(self, pattern: Optional[Pattern], name: Optional[NameExpr]) -> None: + super().__init__() + self.pattern = pattern + self.name = name + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_as_pattern(self) + + +class OrPattern(Pattern): + patterns: List[Pattern] + + def __init__(self, patterns: List[Pattern]) -> None: + super().__init__() + self.patterns = patterns + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_or_pattern(self) + + +class ValuePattern(Pattern): + expr: Expression + + def __init__(self, expr: Expression): + super().__init__() + self.expr = expr + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_value_pattern(self) + + +class SingletonPattern(Pattern): + # This can be exactly True, False or None + value: Union[bool, None] + + def __init__(self, value: Union[bool, None]): + super().__init__() + self.value = value + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_singleton_pattern(self) + + +class SequencePattern(Pattern): + patterns: List[Pattern] + + def __init__(self, patterns: List[Pattern]): + super().__init__() + self.patterns = patterns + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_sequence_pattern(self) + + +class StarredPattern(Pattern): + # None corresponds to *_ in a list pattern. It will match multiple items but won't bind them to + # a name. + capture: Optional[NameExpr] + + def __init__(self, capture: Optional[NameExpr]): + super().__init__() + self.capture = capture + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_starred_pattern(self) + + +class MappingPattern(Pattern): + keys: List[Expression] + values: List[Pattern] + rest: Optional[NameExpr] + + def __init__(self, keys: List[Expression], values: List[Pattern], + rest: Optional[NameExpr]): + super().__init__() + assert len(keys) == len(values) + self.keys = keys + self.values = values + self.rest = rest + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_mapping_pattern(self) + + +class ClassPattern(Pattern): + class_ref: RefExpr + positionals: List[Pattern] + keyword_keys: List[str] + keyword_values: List[Pattern] + + def __init__(self, class_ref: RefExpr, positionals: List[Pattern], keyword_keys: List[str], + keyword_values: List[Pattern]): + super().__init__() + assert len(keyword_keys) == len(keyword_values) + self.class_ref = class_ref + self.positionals = positionals + self.keyword_keys = keyword_keys + self.keyword_values = keyword_values + + def accept(self, visitor: PatternVisitor[T]) -> T: + return visitor.visit_class_pattern(self) diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 1beb53849327..95f4618da4a1 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -156,6 +156,33 @@ def add_method_to_class( info.defn.defs.body.append(func) +def add_attribute_to_class( + api: SemanticAnalyzerPluginInterface, + cls: ClassDef, + name: str, + typ: Type, + final: bool = False, +) -> None: + """ + Adds a new attribute to a class definition. + This currently only generates the symbol table entry and no corresponding AssignmentStatement + """ + info = cls.info + + # NOTE: we would like the plugin generated node to dominate, but we still + # need to keep any existing definitions so they get semantically analyzed. + if name in info.names: + # Get a nice unique name instead. + r_name = get_unique_redefinition_name(name, info.names) + info.names[r_name] = info.names[name] + + node = Var(name, typ) + node.info = info + node.is_final = final + node._fullname = info.fullname + '.' + name + info.names[name] = SymbolTableNode(MDEF, node, plugin_generated=True) + + def deserialize_and_fixup_type( data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface ) -> Type: diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 6d78bc17e615..d4c19b8a770b 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -10,12 +10,12 @@ ) from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface from mypy.plugins.common import ( - add_method, _get_decorator_bool_argument, deserialize_and_fixup_type, + add_method, _get_decorator_bool_argument, deserialize_and_fixup_type, add_attribute_to_class, ) from mypy.typeops import map_type_from_supertype from mypy.types import ( - Type, Instance, NoneType, TypeVarType, CallableType, get_proper_type, - AnyType, TypeOfAny, + Type, Instance, NoneType, TypeVarType, CallableType, TupleType, LiteralType, + get_proper_type, AnyType, TypeOfAny, ) from mypy.server.trigger import make_wildcard_trigger @@ -131,6 +131,7 @@ def transform(self) -> None: 'order': _get_decorator_bool_argument(self._ctx, 'order', False), 'frozen': _get_decorator_bool_argument(self._ctx, 'frozen', False), 'slots': _get_decorator_bool_argument(self._ctx, 'slots', False), + 'match_args': _get_decorator_bool_argument(self._ctx, 'match_args', True), } py_version = self._ctx.api.options.python_version @@ -200,6 +201,16 @@ def transform(self) -> None: self.reset_init_only_vars(info, attributes) + if (decorator_arguments['match_args'] and + ('__match_args__' not in info.names or + info.names['__match_args__'].plugin_generated) and + attributes): + str_type = ctx.api.named_type("builtins.str") + literals: List[Type] = [LiteralType(attr.name, str_type) + for attr in attributes if attr.is_in_init] + match_args_type = TupleType(literals, ctx.api.named_type("builtins.tuple")) + add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type, final=True) + self._add_dataclass_fields_magic_attribute() info.metadata['dataclass'] = { diff --git a/mypy/reachability.py b/mypy/reachability.py index 44a21b993cfc..eec472376317 100644 --- a/mypy/reachability.py +++ b/mypy/reachability.py @@ -4,11 +4,12 @@ from typing_extensions import Final from mypy.nodes import ( - Expression, IfStmt, Block, AssertStmt, NameExpr, UnaryExpr, MemberExpr, OpExpr, ComparisonExpr, - StrExpr, UnicodeExpr, CallExpr, IntExpr, TupleExpr, IndexExpr, SliceExpr, Import, ImportFrom, - ImportAll, LITERAL_YES + Expression, IfStmt, Block, AssertStmt, MatchStmt, NameExpr, UnaryExpr, MemberExpr, OpExpr, + ComparisonExpr, StrExpr, UnicodeExpr, CallExpr, IntExpr, TupleExpr, IndexExpr, SliceExpr, + Import, ImportFrom, ImportAll, LITERAL_YES ) from mypy.options import Options +from mypy.patterns import Pattern, AsPattern, OrPattern from mypy.traverser import TraverserVisitor from mypy.literals import literal @@ -63,6 +64,30 @@ def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None: break +def infer_reachability_of_match_statement(s: MatchStmt, options: Options) -> None: + for i, guard in enumerate(s.guards): + pattern_value = infer_pattern_value(s.patterns[i]) + + if guard is not None: + guard_value = infer_condition_value(guard, options) + else: + guard_value = ALWAYS_TRUE + + if pattern_value in (ALWAYS_FALSE, MYPY_FALSE) \ + or guard_value in (ALWAYS_FALSE, MYPY_FALSE): + # The case is considered always false, so we skip the case body. + mark_block_unreachable(s.bodies[i]) + elif pattern_value in (ALWAYS_FALSE, MYPY_TRUE) \ + and guard_value in (ALWAYS_TRUE, MYPY_TRUE): + for body in s.bodies[i + 1:]: + mark_block_unreachable(body) + + if guard_value == MYPY_TRUE: + # This condition is false at runtime; this will affect + # import priorities. + mark_block_mypy_only(s.bodies[i]) + + def assert_will_always_fail(s: AssertStmt, options: Options) -> bool: return infer_condition_value(s.expr, options) in (ALWAYS_FALSE, MYPY_FALSE) @@ -118,6 +143,16 @@ def infer_condition_value(expr: Expression, options: Options) -> int: return result +def infer_pattern_value(pattern: Pattern) -> int: + if isinstance(pattern, AsPattern) and pattern.pattern is None: + return ALWAYS_TRUE + elif isinstance(pattern, OrPattern) and \ + any(infer_pattern_value(p) == ALWAYS_TRUE for p in pattern.patterns): + return ALWAYS_TRUE + else: + return TRUTH_VALUE_UNKNOWN + + def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> int: """Consider whether expr is a comparison involving sys.version_info. diff --git a/mypy/renaming.py b/mypy/renaming.py index a43abb13c688..c200e94d58e7 100644 --- a/mypy/renaming.py +++ b/mypy/renaming.py @@ -4,9 +4,10 @@ from mypy.nodes import ( Block, AssignmentStmt, NameExpr, MypyFile, FuncDef, Lvalue, ListExpr, TupleExpr, - WhileStmt, ForStmt, BreakStmt, ContinueStmt, TryStmt, WithStmt, StarExpr, ImportFrom, - MemberExpr, IndexExpr, Import, ClassDef + WhileStmt, ForStmt, BreakStmt, ContinueStmt, TryStmt, WithStmt, MatchStmt, StarExpr, + ImportFrom, MemberExpr, IndexExpr, Import, ClassDef ) +from mypy.patterns import AsPattern from mypy.traverser import TraverserVisitor # Scope kinds @@ -159,6 +160,21 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: for lvalue in s.lvalues: self.analyze_lvalue(lvalue) + def visit_match_stmt(self, s: MatchStmt) -> None: + for i in range(len(s.patterns)): + with self.enter_block(): + s.patterns[i].accept(self) + guard = s.guards[i] + if guard is not None: + guard.accept(self) + # We already entered a block, so visit this block's statements directly + for stmt in s.bodies[i].body: + stmt.accept(self) + + def visit_capture_pattern(self, p: AsPattern) -> None: + if p.name is not None: + self.analyze_lvalue(p.name) + def analyze_lvalue(self, lvalue: Lvalue, is_nested: bool = False) -> None: """Process assignment; in particular, keep track of (re)defined names. diff --git a/mypy/semanal.py b/mypy/semanal.py index 9a5076beb6f4..c16b5f0ac291 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -78,6 +78,11 @@ typing_extensions_aliases, EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr, ParamSpecExpr, EllipsisExpr, TypeVarLikeExpr, FuncBase, implicit_module_attrs, + MatchStmt +) +from mypy.patterns import ( + AsPattern, OrPattern, ValuePattern, SequencePattern, + StarredPattern, MappingPattern, ClassPattern ) from mypy.tvar_scope import TypeVarLikeScope from mypy.typevars import fill_typevars @@ -119,8 +124,8 @@ from mypy.semanal_enum import EnumCallAnalyzer, ENUM_BASES from mypy.semanal_newtype import NewTypeAnalyzer from mypy.reachability import ( - infer_reachability_of_if_statement, infer_condition_value, ALWAYS_FALSE, ALWAYS_TRUE, - MYPY_TRUE, MYPY_FALSE + infer_reachability_of_if_statement, infer_reachability_of_match_statement, + infer_condition_value, ALWAYS_FALSE, ALWAYS_TRUE, MYPY_TRUE, MYPY_FALSE ) from mypy.mro import calculate_mro, MroError @@ -3735,6 +3740,17 @@ def visit_exec_stmt(self, s: ExecStmt) -> None: if s.locals: s.locals.accept(self) + def visit_match_stmt(self, s: MatchStmt) -> None: + self.statement = s + infer_reachability_of_match_statement(s, self.options) + s.subject.accept(self) + for i in range(len(s.patterns)): + s.patterns[i].accept(self) + guard = s.guards[i] + if guard is not None: + guard.accept(self) + self.visit_block(s.bodies[i]) + # # Expressions # @@ -4201,6 +4217,46 @@ def visit_await_expr(self, expr: AwaitExpr) -> None: self.fail('"await" outside coroutine ("async def")', expr) expr.expr.accept(self) + # + # Patterns + # + + def visit_as_pattern(self, p: AsPattern) -> None: + if p.pattern is not None: + p.pattern.accept(self) + if p.name is not None: + self.analyze_lvalue(p.name) + + def visit_or_pattern(self, p: OrPattern) -> None: + for pattern in p.patterns: + pattern.accept(self) + + def visit_value_pattern(self, p: ValuePattern) -> None: + p.expr.accept(self) + + def visit_sequence_pattern(self, p: SequencePattern) -> None: + for pattern in p.patterns: + pattern.accept(self) + + def visit_starred_pattern(self, p: StarredPattern) -> None: + if p.capture is not None: + self.analyze_lvalue(p.capture) + + def visit_mapping_pattern(self, p: MappingPattern) -> None: + for key in p.keys: + key.accept(self) + for value in p.values: + value.accept(self) + if p.rest is not None: + self.analyze_lvalue(p.rest) + + def visit_class_pattern(self, p: ClassPattern) -> None: + p.class_ref.accept(self) + for pos in p.positionals: + pos.accept(self) + for v in p.keyword_values: + v.accept(self) + # # Lookup functions # diff --git a/mypy/semanal_namedtuple.py b/mypy/semanal_namedtuple.py index 8930c63d2bef..2357225caebb 100644 --- a/mypy/semanal_namedtuple.py +++ b/mypy/semanal_namedtuple.py @@ -9,7 +9,7 @@ from mypy.types import ( Type, TupleType, AnyType, TypeOfAny, CallableType, TypeType, TypeVarType, - UnboundType, + UnboundType, LiteralType, ) from mypy.semanal_shared import ( SemanticAnalyzerInterface, set_callable_name, calculate_tuple_fallback, PRIORITY_FALLBACKS @@ -398,6 +398,9 @@ def build_namedtuple_typeinfo(self, iterable_type = self.api.named_type_or_none('typing.Iterable', [implicit_any]) function_type = self.api.named_type('builtins.function') + literals: List[Type] = [LiteralType(item, strtype) for item in items] + match_args_type = TupleType(literals, basetuple_type) + info = self.api.basic_new_typeinfo(name, fallback, line) info.is_named_tuple = True tuple_base = TupleType(types, fallback) @@ -436,6 +439,7 @@ def add_field(var: Var, is_initialized_in_class: bool = False, add_field(Var('_source', strtype), is_initialized_in_class=True) add_field(Var('__annotations__', ordereddictype), is_initialized_in_class=True) add_field(Var('__doc__', strtype), is_initialized_in_class=True) + add_field(Var('__match_args__', match_args_type), is_initialized_in_class=True) tvd = TypeVarType(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME, -1, [], info.tuple_type) diff --git a/mypy/semanal_pass1.py b/mypy/semanal_pass1.py index 0296788e3990..2b096f08082a 100644 --- a/mypy/semanal_pass1.py +++ b/mypy/semanal_pass1.py @@ -2,11 +2,14 @@ from mypy.nodes import ( MypyFile, AssertStmt, IfStmt, Block, AssignmentStmt, ExpressionStmt, ReturnStmt, ForStmt, - Import, ImportAll, ImportFrom, ClassDef, FuncDef + MatchStmt, Import, ImportAll, ImportFrom, ClassDef, FuncDef ) from mypy.traverser import TraverserVisitor from mypy.options import Options -from mypy.reachability import infer_reachability_of_if_statement, assert_will_always_fail +from mypy.reachability import ( + infer_reachability_of_if_statement, assert_will_always_fail, + infer_reachability_of_match_statement +) class SemanticAnalyzerPreAnalysis(TraverserVisitor): @@ -102,6 +105,14 @@ def visit_block(self, b: Block) -> None: return super().visit_block(b) + def visit_match_stmt(self, s: MatchStmt) -> None: + infer_reachability_of_match_statement(s, self.options) + for guard in s.guards: + if guard is not None: + guard.accept(self) + for body in s.bodies: + body.accept(self) + # The remaining methods are an optimization: don't visit nested expressions # of common statements, since they can have no effect. diff --git a/mypy/strconv.py b/mypy/strconv.py index c63063af0776..22534a44971d 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -4,11 +4,15 @@ import os from typing import Any, List, Tuple, Optional, Union, Sequence +from typing_extensions import TYPE_CHECKING from mypy.util import short_type, IdMapper import mypy.nodes from mypy.visitor import NodeVisitor +if TYPE_CHECKING: + import mypy.patterns + class StrConv(NodeVisitor[str]): """Visitor for converting a node to a human-readable string. @@ -311,6 +315,15 @@ def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> str: def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> str: return self.dump([o.expr, o.globals, o.locals], o) + def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> str: + a: List[Any] = [o.subject] + for i in range(len(o.patterns)): + a.append(('Pattern', [o.patterns[i]])) + if o.guards[i] is not None: + a.append(('Guard', [o.guards[i]])) + a.append(('Body', o.bodies[i].body)) + return self.dump(a, o) + # Expressions # Simple expressions @@ -537,6 +550,42 @@ def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> str: def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> str: return self.dump([o.type], o) + def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> str: + return self.dump([o.pattern, o.name], o) + + def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> str: + return self.dump(o.patterns, o) + + def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> str: + return self.dump([o.expr], o) + + def visit_singleton_pattern(self, o: 'mypy.patterns.SingletonPattern') -> str: + return self.dump([o.value], o) + + def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> str: + return self.dump(o.patterns, o) + + def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> str: + return self.dump([o.capture], o) + + def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> str: + a: List[Any] = [] + for i in range(len(o.keys)): + a.append(('Key', [o.keys[i]])) + a.append(('Value', [o.values[i]])) + if o.rest is not None: + a.append(('Rest', [o.rest])) + return self.dump(a, o) + + def visit_class_pattern(self, o: 'mypy.patterns.ClassPattern') -> str: + a: List[Any] = [o.class_ref] + if len(o.positionals) > 0: + a.append(('Positionals', o.positionals)) + for i in range(len(o.keyword_keys)): + a.append(('Keyword', [o.keyword_keys[i], o.keyword_values[i]])) + + return self.dump(a, o) + def dump_tagged(nodes: Sequence[object], tag: Optional[str], str_conv: 'StrConv') -> str: """Convert an array into a pretty-printed multiline string representation. diff --git a/mypy/subtypes.py b/mypy/subtypes.py index f90009445b09..e9fc29d4061b 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1446,7 +1446,7 @@ def visit_type_type(self, left: TypeType) -> bool: if right.type.fullname == 'builtins.type': # TODO: Strictly speaking, the type builtins.type is considered equivalent to # Type[Any]. However, this would break the is_proper_subtype check in - # conditional_type_map for cases like isinstance(x, type) when the type + # conditional_types for cases like isinstance(x, type) when the type # of x is Type[int]. It's unclear what's the right way to address this. return True if right.type.fullname == 'builtins.object': diff --git a/mypy/test/helpers.py b/mypy/test/helpers.py index fbd44bca868b..f9f117634c21 100644 --- a/mypy/test/helpers.py +++ b/mypy/test/helpers.py @@ -287,6 +287,8 @@ def num_skipped_suffix_lines(a1: List[str], a2: List[str]) -> int: def testfile_pyversion(path: str) -> Tuple[int, int]: if path.endswith('python2.test'): return defaults.PYTHON2_VERSION + elif path.endswith('python310.test'): + return 3, 10 else: return defaults.PYTHON3_VERSION diff --git a/mypy/test/testparse.py b/mypy/test/testparse.py index e9ff6839bc2c..1587147c0777 100644 --- a/mypy/test/testparse.py +++ b/mypy/test/testparse.py @@ -18,6 +18,9 @@ class ParserSuite(DataSuite): files = ['parse.test', 'parse-python2.test'] + if sys.version_info >= (3, 10): + files.append('parse-python310.test') + def run_case(self, testcase: DataDrivenTestCase) -> None: test_parser(testcase) @@ -31,6 +34,8 @@ def test_parser(testcase: DataDrivenTestCase) -> None: if testcase.file.endswith('python2.test'): options.python_version = defaults.PYTHON2_VERSION + elif testcase.file.endswith('python310.test'): + options.python_version = (3, 10) else: options.python_version = defaults.PYTHON3_VERSION diff --git a/mypy/test/testsemanal.py b/mypy/test/testsemanal.py index a71bac53619d..441f9ab32dbb 100644 --- a/mypy/test/testsemanal.py +++ b/mypy/test/testsemanal.py @@ -1,6 +1,7 @@ """Semantic analyzer test cases""" import os.path +import sys from typing import Dict, List @@ -38,6 +39,10 @@ ] +if sys.version_info >= (3, 10): + semanal_files.append('semanal-python310.test') + + def get_semanal_options(program_text: str, testcase: DataDrivenTestCase) -> Options: options = parse_options(program_text, testcase, 1) options.use_builtins_fixtures = True @@ -104,6 +109,8 @@ def test_semanal(testcase: DataDrivenTestCase) -> None: class SemAnalErrorSuite(DataSuite): files = ['semanal-errors.test'] + if sys.version_info >= (3, 10): + semanal_files.append('semanal-errors-python310.test') def run_case(self, testcase: DataDrivenTestCase) -> None: test_semanal_error(testcase) diff --git a/mypy/traverser.py b/mypy/traverser.py index a5f993bd2fa5..6ab97116a12e 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -3,13 +3,17 @@ from typing import List, Tuple from mypy_extensions import mypyc_attr +from mypy.patterns import ( + AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, + ClassPattern +) from mypy.visitor import NodeVisitor from mypy.nodes import ( Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef, ExpressionStmt, AssignmentStmt, OperatorAssignmentStmt, WhileStmt, ForStmt, ReturnStmt, AssertStmt, DelStmt, IfStmt, RaiseStmt, - TryStmt, WithStmt, NameExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, RevealExpr, - UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, AssignmentExpr, + TryStmt, WithStmt, MatchStmt, NameExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, + RevealExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, AssignmentExpr, GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension, ConditionalExpr, TypeApplication, ExecStmt, Import, ImportFrom, LambdaExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr, @@ -156,6 +160,15 @@ def visit_with_stmt(self, o: WithStmt) -> None: targ.accept(self) o.body.accept(self) + def visit_match_stmt(self, o: MatchStmt) -> None: + o.subject.accept(self) + for i in range(len(o.patterns)): + o.patterns[i].accept(self) + guard = o.guards[i] + if guard is not None: + guard.accept(self) + o.bodies[i].accept(self) + def visit_member_expr(self, o: MemberExpr) -> None: o.expr.accept(self) @@ -279,6 +292,42 @@ def visit_await_expr(self, o: AwaitExpr) -> None: def visit_super_expr(self, o: SuperExpr) -> None: o.call.accept(self) + def visit_as_pattern(self, o: AsPattern) -> None: + if o.pattern is not None: + o.pattern.accept(self) + if o.name is not None: + o.name.accept(self) + + def visit_or_pattern(self, o: OrPattern) -> None: + for p in o.patterns: + p.accept(self) + + def visit_value_pattern(self, o: ValuePattern) -> None: + o.expr.accept(self) + + def visit_sequence_pattern(self, o: SequencePattern) -> None: + for p in o.patterns: + p.accept(self) + + def visit_starred_patten(self, o: StarredPattern) -> None: + if o.capture is not None: + o.capture.accept(self) + + def visit_mapping_pattern(self, o: MappingPattern) -> None: + for key in o.keys: + key.accept(self) + for value in o.values: + value.accept(self) + if o.rest is not None: + o.rest.accept(self) + + def visit_class_pattern(self, o: ClassPattern) -> None: + o.class_ref.accept(self) + for p in o.positionals: + p.accept(self) + for v in o.keyword_values: + v.accept(self) + def visit_import(self, o: Import) -> None: for a in o.assignments: a.accept(self) diff --git a/mypy/typeops.py b/mypy/typeops.py index 9ba170b4b822..59035274f0d9 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -795,7 +795,7 @@ def coerce_to_literal(typ: Type) -> Type: typ = get_proper_type(typ) if isinstance(typ, UnionType): new_items = [coerce_to_literal(item) for item in typ.items] - return make_simplified_union(new_items) + return UnionType.make_union(new_items) elif isinstance(typ, Instance): if typ.last_known_value: return typ.last_known_value diff --git a/mypy/visitor.py b/mypy/visitor.py index b98ec773bbe3..9d3ebb6818b4 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: # break import cycle only needed for mypy import mypy.nodes + import mypy.patterns T = TypeVar('T') @@ -310,10 +311,50 @@ def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> T: pass + @abstractmethod + def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> T: + pass + @trait @mypyc_attr(allow_interpreted_subclasses=True) -class NodeVisitor(Generic[T], ExpressionVisitor[T], StatementVisitor[T]): +class PatternVisitor(Generic[T]): + @abstractmethod + def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> T: + pass + + @abstractmethod + def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> T: + pass + + @abstractmethod + def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> T: + pass + + @abstractmethod + def visit_singleton_pattern(self, o: 'mypy.patterns.SingletonPattern') -> T: + pass + + @abstractmethod + def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> T: + pass + + @abstractmethod + def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> T: + pass + + @abstractmethod + def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> T: + pass + + @abstractmethod + def visit_class_pattern(self, o: 'mypy.patterns.ClassPattern') -> T: + pass + + +@trait +@mypyc_attr(allow_interpreted_subclasses=True) +class NodeVisitor(Generic[T], ExpressionVisitor[T], StatementVisitor[T], PatternVisitor[T]): """Empty base class for parse tree node visitors. The T type argument specifies the return type of the visit @@ -429,6 +470,9 @@ def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> T: pass + def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> T: + pass + # Expressions (default no-op implementation) def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> T: @@ -562,3 +606,29 @@ def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T: def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: pass + + # Patterns + + def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> T: + pass + + def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> T: + pass + + def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> T: + pass + + def visit_singleton_pattern(self, o: 'mypy.patterns.SingletonPattern') -> T: + pass + + def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> T: + pass + + def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> T: + pass + + def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> T: + pass + + def visit_class_pattern(self, o: 'mypy.patterns.ClassPattern') -> T: + pass diff --git a/mypyc/irbuild/visitor.py b/mypyc/irbuild/visitor.py index 1a6a84809707..43cfd457667d 100644 --- a/mypyc/irbuild/visitor.py +++ b/mypyc/irbuild/visitor.py @@ -16,7 +16,8 @@ FloatExpr, GeneratorExpr, GlobalDecl, LambdaExpr, ListComprehension, SetComprehension, NamedTupleExpr, NewTypeExpr, NonlocalDecl, OverloadedFuncDef, PrintStmt, RaiseStmt, RevealExpr, SetExpr, SliceExpr, StarExpr, SuperExpr, TryStmt, TypeAliasExpr, TypeApplication, - TypeVarExpr, TypedDictExpr, UnicodeExpr, WithStmt, YieldFromExpr, YieldExpr, ParamSpecExpr + TypeVarExpr, TypedDictExpr, UnicodeExpr, WithStmt, YieldFromExpr, YieldExpr, ParamSpecExpr, + MatchStmt ) from mypyc.ir.ops import Value @@ -179,6 +180,9 @@ def visit_nonlocal_decl(self, stmt: NonlocalDecl) -> None: # Pure declaration -- no runtime effect pass + def visit_match_stmt(self, stmt: MatchStmt) -> None: + self.bail("Match statements are not yet supported", stmt.line) + # Expressions def visit_name_expr(self, expr: NameExpr) -> Value: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index d6b25ef456d9..16cdc69ec1b7 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1227,7 +1227,7 @@ def unreachable(x: Union[str, List[str]]) -> None: elif isinstance(x, list): reveal_type(x) # N: Revealed type is "builtins.list[builtins.str]" else: - reveal_type(x) # N: Revealed type is "" + reveal_type(x) # No output: this branch is unreachable def all_parts_covered(x: Union[str, List[str], List[int], int]) -> None: if isinstance(x, str): diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 3bcac61855b4..602faba7fbca 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1,7 +1,1177 @@ -[case testMatchStatementNotSupported] -# flags: --python-version 3.10 -match str(): # E: Match statement is not supported - case 'x': - 1 + '' +-- Capture Pattern -- +[case testCapturePatternType] +class A: ... +m: A + +match m: + case a: + reveal_type(a) # N: Revealed type is "__main__.A" + + +-- Literal Pattern -- +[case testLiteralPatternNarrows] +m: object + +match m: + case 1: + reveal_type(m) # N: Revealed type is "Literal[1]?" + +[case testLiteralPatternAlreadyNarrower] +m: bool + +match m: + case 1: + reveal_type(m) # N: Revealed type is "builtins.bool" + +[case testLiteralPatternUnreachable] +# primitives are needed because otherwise mypy doesn't see that int and str are incompatible +m: int + +match m: + case "str": + reveal_type(m) +[builtins fixtures/primitives.pyi] + + +-- Value Pattern -- +[case testValuePatternNarrows] +import b +m: object + +match m: + case b.b: + reveal_type(m) # N: Revealed type is "builtins.int" +[file b.py] +b: int + +[case testValuePatternAlreadyNarrower] +import b +m: bool + +match m: + case b.b: + reveal_type(m) # N: Revealed type is "builtins.bool" +[file b.py] +b: int + +[case testValuePatternIntersect] +import b + +class A: ... +m: A + +match m: + case b.b: + reveal_type(m) # N: Revealed type is "__main__." +[file b.py] +class B: ... +b: B + +[case testValuePatternUnreachable] +# primitives are needed because otherwise mypy doesn't see that int and str are incompatible +import b + +m: int + +match m: + case b.b: + reveal_type(m) +[file b.py] +b: str +[builtins fixtures/primitives.pyi] + + +-- Sequence Pattern -- +[case testSequencePatternCaptures] +from typing import List +m: List[int] + +match m: + case [a]: + reveal_type(a) # N: Revealed type is "builtins.int*" +[builtins fixtures/list.pyi] + +[case testSequencePatternCapturesStarred] +from typing import Sequence +m: Sequence[int] + +match m: + case [a, *b]: + reveal_type(a) # N: Revealed type is "builtins.int*" + reveal_type(b) # N: Revealed type is "builtins.list[builtins.int*]" +[builtins fixtures/list.pyi] + +[case testSequencePatternNarrowsInner] +from typing import Sequence +m: Sequence[object] + +match m: + case [1, True]: + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]" + +[case testSequencePatternNarrowsOuter] +from typing import Sequence +m: object + +match m: + case [1, True]: + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]" + +[case testSequencePatternAlreadyNarrowerInner] +from typing import Sequence +m: Sequence[bool] + +match m: + case [1, True]: + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.bool]" + +[case testSequencePatternAlreadyNarrowerOuter] +from typing import Sequence +m: Sequence[object] + +match m: + case [1, True]: + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.int]" + +[case testSequencePatternAlreadyNarrowerBoth] +from typing import Sequence +m: Sequence[bool] + +match m: + case [1, True]: + reveal_type(m) # N: Revealed type is "typing.Sequence[builtins.bool]" + +[case testNestedSequencePatternNarrowsInner] +from typing import Sequence +m: Sequence[Sequence[object]] + +match m: + case [[1], [True]]: + reveal_type(m) # N: Revealed type is "typing.Sequence[typing.Sequence[builtins.int]]" + +[case testNestedSequencePatternNarrowsOuter] +from typing import Sequence +m: object + +match m: + case [[1], [True]]: + reveal_type(m) # N: Revealed type is "typing.Sequence[typing.Sequence[builtins.int]]" + +[case testSequencePatternDoesntNarrowInvariant] +from typing import List +m: List[object] + +match m: + case [1]: + reveal_type(m) # N: Revealed type is "builtins.list[builtins.object]" +[builtins fixtures/list.pyi] + +[case testSequencePatternMatches] +import array, collections +from typing import Sequence, Iterable + +m1: object +m2: Sequence[int] +m3: array.array[int] +m4: collections.deque[int] +m5: list[int] +m6: memoryview +m7: range +m8: tuple[int] + +m9: str +m10: bytes +m11: bytearray + +match m1: + case [a]: + reveal_type(a) # N: Revealed type is "builtins.object" + +match m2: + case [b]: + reveal_type(b) # N: Revealed type is "builtins.int*" + +match m3: + case [c]: + reveal_type(c) # N: Revealed type is "builtins.int*" + +match m4: + case [d]: + reveal_type(d) # N: Revealed type is "builtins.int*" + +match m5: + case [e]: + reveal_type(e) # N: Revealed type is "builtins.int*" + +match m6: + case [f]: + reveal_type(f) # N: Revealed type is "builtins.int*" + +match m7: + case [g]: + reveal_type(g) # N: Revealed type is "builtins.int*" + +match m8: + case [h]: + reveal_type(h) # N: Revealed type is "builtins.int" + +match m9: + case [i]: + reveal_type(i) + +match m10: + case [j]: + reveal_type(j) + +match m11: + case [k]: + reveal_type(k) +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + + +[case testSequencePatternCapturesTuple] +from typing import Tuple +m: Tuple[int, str, bool] + +match m: + case [a, b, c]: + reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(c) # N: Revealed type is "builtins.bool" + reveal_type(m) # N: Revealed type is "Tuple[builtins.int, builtins.str, builtins.bool]" +[builtins fixtures/list.pyi] + +[case testSequencePatternTupleTooLong] +from typing import Tuple +m: Tuple[int, str] + +match m: + case [a, b, c]: + reveal_type(a) + reveal_type(b) + reveal_type(c) +[builtins fixtures/list.pyi] + +[case testSequencePatternTupleTooShort] +from typing import Tuple +m: Tuple[int, str, bool] + +match m: + case [a, b]: + reveal_type(a) + reveal_type(b) +[builtins fixtures/list.pyi] + +[case testSequencePatternTupleNarrows] +from typing import Tuple +m: Tuple[object, object] + +match m: + case [1, "str"]: + reveal_type(m) # N: Revealed type is "Tuple[Literal[1]?, Literal['str']?]" +[builtins fixtures/list.pyi] + +[case testSequencePatternTupleStarred] +from typing import Tuple +m: Tuple[int, str, bool] + +match m: + case [a, *b, c]: + reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(b) # N: Revealed type is "builtins.list[builtins.str]" + reveal_type(c) # N: Revealed type is "builtins.bool" + reveal_type(m) # N: Revealed type is "Tuple[builtins.int, builtins.str, builtins.bool]" +[builtins fixtures/list.pyi] + +[case testSequencePatternTupleStarredUnion] +from typing import Tuple +m: Tuple[int, str, float, bool] + +match m: + case [a, *b, c]: + reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(b) # N: Revealed type is "builtins.list[Union[builtins.str, builtins.float]]" + reveal_type(c) # N: Revealed type is "builtins.bool" + reveal_type(m) # N: Revealed type is "Tuple[builtins.int, builtins.str, builtins.float, builtins.bool]" +[builtins fixtures/list.pyi] + + +[case testSequencePatternTupleStarredTooShort] +from typing import Tuple +m: Tuple[int] +reveal_type(m) # N: Revealed type is "Tuple[builtins.int]" + +match m: + case [a, *b, c]: + reveal_type(a) + reveal_type(b) + reveal_type(c) +[builtins fixtures/list.pyi] + +[case testNonMatchingSequencePattern] +from typing import List + +x: List[int] +match x: + case [str()]: + pass + +[case testSequenceUnion-skip] +from typing import List, Union +m: Union[List[List[str]], str] + +match m: + case [list(['str'])]: + reveal_type(m) # N: Revealed type is "builtins.list[builtins.list[builtins.str]]" +[builtins fixtures/list.pyi] + +-- Mapping Pattern -- +[case testMappingPatternCaptures] +from typing import Dict +import b +m: Dict[str, int] + +match m: + case {"key": v}: + reveal_type(v) # N: Revealed type is "builtins.int*" + case {b.b: v2}: + reveal_type(v2) # N: Revealed type is "builtins.int*" +[file b.py] +b: str +[builtins fixtures/dict.pyi] + +[case testMappingPatternCapturesWrongKeyType] +# This is not actually unreachable, as a subclass of dict could accept keys with different types +from typing import Dict +import b +m: Dict[str, int] + +match m: + case {1: v}: + reveal_type(v) # N: Revealed type is "builtins.int*" + case {b.b: v2}: + reveal_type(v2) # N: Revealed type is "builtins.int*" +[file b.py] +b: int +[builtins fixtures/dict.pyi] + +[case testMappingPatternCapturesTypedDict] +from typing import TypedDict + +class A(TypedDict): + a: str + b: int + +m: A + +match m: + case {"a": v}: + reveal_type(v) # N: Revealed type is "builtins.str" + case {"b": v2}: + reveal_type(v2) # N: Revealed type is "builtins.int" + case {"a": v3, "b": v4}: + reveal_type(v3) # N: Revealed type is "builtins.str" + reveal_type(v4) # N: Revealed type is "builtins.int" + case {"o": v5}: + reveal_type(v5) # N: Revealed type is "builtins.object*" +[typing fixtures/typing-typeddict.pyi] + +[case testMappingPatternCapturesTypedDictWithLiteral] +from typing import TypedDict +import b + +class A(TypedDict): + a: str + b: int + +m: A + +match m: + case {b.a: v}: + reveal_type(v) # N: Revealed type is "builtins.str" + case {b.b: v2}: + reveal_type(v2) # N: Revealed type is "builtins.int" + case {b.a: v3, b.b: v4}: + reveal_type(v3) # N: Revealed type is "builtins.str" + reveal_type(v4) # N: Revealed type is "builtins.int" + case {b.o: v5}: + reveal_type(v5) # N: Revealed type is "builtins.object*" +[file b.py] +from typing import Final, Literal +a: Final = "a" +b: Literal["b"] = "b" +o: Final[str] = "o" +[typing fixtures/typing-typeddict.pyi] + +[case testMappingPatternCapturesTypedDictWithNonLiteral] +from typing import TypedDict +import b + +class A(TypedDict): + a: str + b: int + +m: A + +match m: + case {b.a: v}: + reveal_type(v) # N: Revealed type is "builtins.object*" +[file b.py] +from typing import Final, Literal +a: str +[typing fixtures/typing-typeddict.pyi] + +[case testMappingPatternCapturesTypedDictUnreachable] +# TypedDict keys are always str, so this is actually unreachable +from typing import TypedDict +import b + +class A(TypedDict): + a: str + b: int + +m: A + +match m: + case {1: v}: + reveal_type(v) + case {b.b: v2}: + reveal_type(v2) +[file b.py] +b: int +[typing fixtures/typing-typeddict.pyi] + +[case testMappingPatternCaptureRest] +m: object + +match m: + case {'k': 1, **r}: + reveal_type(r) # N: Revealed type is "builtins.dict[builtins.object, builtins.object]" +[builtins fixtures/dict.pyi] + +[case testMappingPatternCaptureRestFromMapping] +from typing import Mapping + +m: Mapping[str, int] + +match m: + case {'k': 1, **r}: + reveal_type(r) # N: Revealed type is "builtins.dict[builtins.str*, builtins.int*]" +[builtins fixtures/dict.pyi] + +-- Mapping patterns currently don't narrow -- + +-- Class Pattern -- +[case testClassPatternCapturePositional] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int + +m: A + +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] + +[case testClassPatternMemberClassCapturePositional] +import b + +m: b.A + +match m: + case b.A(i, j): + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" +[file b.py] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int +[builtins fixtures/tuple.pyi] + +[case testClassPatternCaptureKeyword] +class A: + a: str + b: int + +m: A + +match m: + case A(a=i, b=j): + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" + +[case testClassPatternCaptureSelf] +m: object + +match m: + case bool(a): + reveal_type(a) # N: Revealed type is "builtins.bool" + case bytearray(b): + reveal_type(b) # N: Revealed type is "builtins.bytearray" + case bytes(c): + reveal_type(c) # N: Revealed type is "builtins.bytes" + case dict(d): + reveal_type(d) # N: Revealed type is "builtins.dict[Any, Any]" + case float(e): + reveal_type(e) # N: Revealed type is "builtins.float" + case frozenset(f): + reveal_type(f) # N: Revealed type is "builtins.frozenset[Any]" + case int(g): + reveal_type(g) # N: Revealed type is "builtins.int" + case list(h): + reveal_type(h) # N: Revealed type is "builtins.list[Any]" + case set(i): + reveal_type(i) # N: Revealed type is "builtins.set[Any]" + case str(j): + reveal_type(j) # N: Revealed type is "builtins.str" + case tuple(k): + reveal_type(k) # N: Revealed type is "builtins.tuple[Any]" +[builtins fixtures/primitives.pyi] + +[case testClassPatternNarrowSelfCapture] +m: object + +match m: + case bool(): + reveal_type(m) # N: Revealed type is "builtins.bool" + case bytearray(): + reveal_type(m) # N: Revealed type is "builtins.bytearray" + case bytes(): + reveal_type(m) # N: Revealed type is "builtins.bytes" + case dict(): + reveal_type(m) # N: Revealed type is "builtins.dict[Any, Any]" + case float(): + reveal_type(m) # N: Revealed type is "builtins.float" + case frozenset(): + reveal_type(m) # N: Revealed type is "builtins.frozenset[Any]" + case int(): + reveal_type(m) # N: Revealed type is "builtins.int" + case list(): + reveal_type(m) # N: Revealed type is "builtins.list[Any]" + case set(): + reveal_type(m) # N: Revealed type is "builtins.set[Any]" + case str(): + reveal_type(m) # N: Revealed type is "builtins.str" + case tuple(): + reveal_type(m) # N: Revealed type is "builtins.tuple[Any]" +[builtins fixtures/primitives.pyi] + +[case testClassPatternCaptureDataclass] +from dataclasses import dataclass + +@dataclass +class A: + a: str + b: int + +m: A + +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" +[builtins fixtures/dataclasses.pyi] + +[case testClassPatternCaptureDataclassNoMatchArgs] +from dataclasses import dataclass + +@dataclass(match_args=False) +class A: + a: str + b: int + +m: A + +match m: + case A(i, j): # E: Class "__main__.A" doesn't define "__match_args__" + pass +[builtins fixtures/dataclasses.pyi] + +[case testClassPatternCaptureDataclassPartialMatchArgs] +from dataclasses import dataclass, field + +@dataclass +class A: + a: str + b: int = field(init=False) + +m: A + +match m: + case A(i, j): # E: Too many positional patterns for class pattern + pass + case A(k): + reveal_type(k) # N: Revealed type is "builtins.str" +[builtins fixtures/dataclasses.pyi] + +[case testClassPatternCaptureNamedTupleInline] +from collections import namedtuple + +A = namedtuple("A", ["a", "b"]) + +m: A + +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is "Any" + reveal_type(j) # N: Revealed type is "Any" +[builtins fixtures/list.pyi] + +[case testClassPatternCaptureNamedTupleInlineTyped] +from typing import NamedTuple + +A = NamedTuple("A", [("a", str), ("b", int)]) + +m: A + +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" +[builtins fixtures/list.pyi] + +[case testClassPatternCaptureNamedTupleClass] +from typing import NamedTuple + +class A(NamedTuple): + a: str + b: int + +m: A + +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] + +[case testClassPatternCaptureGeneric] +from typing import Generic, TypeVar + +T = TypeVar('T') + +class A(Generic[T]): + a: T + +m: object + +match m: + case A(a=i): + reveal_type(m) # N: Revealed type is "__main__.A[Any]" + reveal_type(i) # N: Revealed type is "Any" + +[case testClassPatternCaptureGenericAlreadyKnown] +from typing import Generic, TypeVar + +T = TypeVar('T') + +class A(Generic[T]): + a: T + +m: A[int] + +match m: + case A(a=i): + reveal_type(m) # N: Revealed type is "__main__.A[builtins.int]" + reveal_type(i) # N: Revealed type is "builtins.int*" + +[case testClassPatternCaptureFilledGenericTypeAlias] +from typing import Generic, TypeVar + +T = TypeVar('T') + +class A(Generic[T]): + a: T + +B = A[int] + +m: object + +match m: + case B(a=i): # E: Class pattern class must not be a type alias with type parameters + reveal_type(i) + +[case testClassPatternCaptureGenericTypeAlias] +from typing import Generic, TypeVar + +T = TypeVar('T') + +class A(Generic[T]): + a: T + +B = A + +m: object + +match m: + case B(a=i): + pass + +[case testClassPatternNarrows] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int + +m: object + +match m: + case A(): + reveal_type(m) # N: Revealed type is "__main__.A" + case A(i, j): + reveal_type(m) # N: Revealed type is "__main__.A" +[builtins fixtures/tuple.pyi] + +[case testClassPatternNarrowsUnion] +from typing import Final, Union + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int + +class B: + __match_args__: Final = ("a", "b") + a: int + b: str + +m: Union[A, B] + +match m: + case A(): + reveal_type(m) # N: Revealed type is "__main__.A" + +match m: + case A(i, j): + reveal_type(m) # N: Revealed type is "__main__.A" + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "builtins.int" + +match m: + case B(): + reveal_type(m) # N: Revealed type is "__main__.B" + +match m: + case B(k, l): + reveal_type(m) # N: Revealed type is "__main__.B" + reveal_type(k) # N: Revealed type is "builtins.int" + reveal_type(l) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] + +[case testClassPatternAlreadyNarrower] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int +class B(A): ... + +m: B + +match m: + case A(): + reveal_type(m) # N: Revealed type is "__main__.B" + +match m: + case A(i, j): + reveal_type(m) # N: Revealed type is "__main__.B" +[builtins fixtures/tuple.pyi] + +[case testClassPatternIntersection] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int +class B: ... + +m: B + +match m: + case A(): + reveal_type(m) # N: Revealed type is "__main__." + case A(i, j): + reveal_type(m) # N: Revealed type is "__main__.1" +[builtins fixtures/tuple.pyi] + +[case testClassPatternNonexistentKeyword] +class A: ... + +m: object + +match m: + case A(a=j): # E: Class "__main__.A" has no attribute "a" + reveal_type(m) # N: Revealed type is "__main__.A" + reveal_type(j) # N: Revealed type is "Any" + +[case testClassPatternDuplicateKeyword] +class A: + a: str + +m: object + +match m: + case A(a=i, a=j): # E: Duplicate keyword pattern "a" + pass + +[case testClassPatternDuplicateImplicitKeyword] +from typing import Final + +class A: + __match_args__: Final = ("a",) + a: str + +m: object + +match m: + case A(i, a=j): # E: Keyword "a" already matches a positional pattern + pass +[builtins fixtures/tuple.pyi] + +[case testClassPatternTooManyPositionals] +from typing import Final + +class A: + __match_args__: Final = ("a", "b") + a: str + b: int + +m: object + +match m: + case A(i, j, k): # E: Too many positional patterns for class pattern + pass +[builtins fixtures/tuple.pyi] + +[case testClassPatternIsNotType] +a = 1 +m: object + +match m: + case a(i, j): # E: Expected type in class pattern; found "builtins.int" + reveal_type(i) + reveal_type(j) + +[case testClassPatternNestedGenerics] +# From cpython test_patma.py +x = [[{0: 0}]] +match x: + case list([({-0-0j: int(real=0+0j, imag=0-0j) | (1) as z},)]): + y = 0 + +reveal_type(x) # N: Revealed type is "builtins.list[builtins.list*[builtins.dict*[builtins.int*, builtins.int*]]]" +reveal_type(y) # N: Revealed type is "builtins.int" +reveal_type(z) # N: Revealed type is "builtins.int*" +[builtins fixtures/dict.pyi] + +[case testNonFinalMatchArgs] +class A: + __match_args__ = ("a", "b") # N: __match_args__ must be final for checking of match statements to work + a: str + b: int + +m: object + +match m: + case A(i, j): + reveal_type(i) # N: Revealed type is "Any" + reveal_type(j) # N: Revealed type is "Any" +[builtins fixtures/tuple.pyi] + +[case testAnyTupleMatchArgs] +from typing import Tuple, Any + +class A: + __match_args__: Tuple[Any, ...] + a: str + b: int + +m: object + +match m: + case A(i, j, k): + reveal_type(i) # N: Revealed type is "Any" + reveal_type(j) # N: Revealed type is "Any" + reveal_type(k) # N: Revealed type is "Any" +[builtins fixtures/tuple.pyi] + +[case testNonLiteralMatchArgs] +from typing import Final + +b: str = "b" +class A: + __match_args__: Final = ("a", b) # N: __match_args__ must be a tuple containing string literals for checking of match statements to work + a: str + b: int + +m: object + +match m: + case A(i, j, k): # E: Too many positional patterns for class pattern + pass + case A(i, j): + reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(j) # N: Revealed type is "Any" +[builtins fixtures/tuple.pyi] + +[case testExternalMatchArgs] +from typing import Final, Literal + +args: Final = ("a", "b") +class A: + __match_args__: Final = args + a: str + b: int + +arg: Final = "a" +arg2: Literal["b"] = "b" +class B: + __match_args__: Final = (arg, arg2) + a: str + b: int + +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-medium.pyi] + + +-- As Pattern -- +[case testAsPattern] +m: int + +match m: + case x as l: + reveal_type(x) # N: Revealed type is "builtins.int" + reveal_type(l) # N: Revealed type is "builtins.int" + +[case testAsPatternNarrows] +m: object + +match m: + case int() as l: + reveal_type(l) # N: Revealed type is "builtins.int" + +[case testAsPatternCapturesOr] +m: object + +match m: + case 1 | 2 as n: + reveal_type(n) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]" + +[case testAsPatternAlreadyNarrower] +m: bool + +match m: + case int() as l: + reveal_type(l) # N: Revealed type is "builtins.bool" + + +-- Or Pattern -- +[case testOrPatternNarrows] +m: object + +match m: + case 1 | 2: + reveal_type(m) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]" + +[case testOrPatternNarrowsStr] +m: object + +match m: + case "foo" | "bar": + reveal_type(m) # N: Revealed type is "Union[Literal['foo']?, Literal['bar']?]" + +[case testOrPatternNarrowsUnion] +m: object + +match m: + case 1 | "foo": + reveal_type(m) # N: Revealed type is "Union[Literal[1]?, Literal['foo']?]" + +[case testOrPatterCapturesMissing] +from typing import List +m: List[int] + +match m: + case [x, y] | list(x): # E: Alternative patterns bind different names + reveal_type(x) # N: Revealed type is "builtins.object" + reveal_type(y) # N: Revealed type is "builtins.int*" +[builtins fixtures/list.pyi] + +[case testOrPatternCapturesJoin] +m: object + +match m: + case list(x) | dict(x): + reveal_type(x) # N: Revealed type is "typing.Iterable[Any]" +[builtins fixtures/dict.pyi] + + +-- Interactions -- +[case testCapturePatternMultipleCases] +m: object + +match m: + case int(x): + reveal_type(x) # N: Revealed type is "builtins.int" + case str(x): + reveal_type(x) # N: Revealed type is "builtins.str" + +reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]" + +[case testCapturePatternMultipleCaptures] +from typing import Iterable + +m: Iterable[int] + +match m: + case [x, x]: # E: Multiple assignments to name "x" in pattern + reveal_type(x) # N: Revealed type is "builtins.int" +[builtins fixtures/list.pyi] + +[case testCapturePatternPreexistingSame] +a: int +m: int + +match m: + case a: + reveal_type(a) # N: Revealed type is "builtins.int" + +[case testCapturePatternPreexistingIncompatible] +a: str +m: int + +match m: + case a: # E: Incompatible types in capture pattern (pattern captures type "int", variable has type "str") + reveal_type(a) # N: Revealed type is "builtins.str" + +[case testCapturePatternPreexistingIncompatibleLater] +a: str +m: object + +match m: + case str(a): + reveal_type(a) # N: Revealed type is "builtins.str" + case int(a): # E: Incompatible types in capture pattern (pattern captures type "int", variable has type "str") + reveal_type(a) # N: Revealed type is "builtins.str" + + +-- Guards -- +[case testSimplePatternGuard] +m: str + +def guard() -> bool: ... + +match m: + case a if guard(): + reveal_type(a) # N: Revealed type is "builtins.str" + +[case testAlwaysTruePatternGuard] +m: str + +match m: + case a if True: + reveal_type(a) # N: Revealed type is "builtins.str" + +[case testAlwaysFalsePatternGuard] +m: str + +match m: + case a if False: + reveal_type(a) + +[case testRedefiningPatternGuard] +# flags: --strict-optional +m: str + +match m: + case a if a := 1: # E: Incompatible types in assignment (expression has type "int", variable has type "str") + reveal_type(a) # N: Revealed type is "" + +[case testAssigningPatternGuard] +m: str + +match m: + case a if a := "test": + reveal_type(a) # N: Revealed type is "builtins.str" + +[case testNarrowingPatternGuard] +m: object + +match m: + case a if isinstance(a, str): + reveal_type(a) # N: Revealed type is "builtins.str" +[builtins fixtures/isinstancelist.pyi] + +[case testIncompatiblePatternGuard] +class A: ... +class B: ... + +m: A + +match m: + case a if isinstance(a, B): + reveal_type(a) # N: Revealed type is "__main__." +[builtins fixtures/isinstancelist.pyi] + +[case testUnreachablePatternGuard] +m: str + +match m: + case a if isinstance(a, int): + reveal_type(a) +[builtins fixtures/isinstancelist.pyi] + +-- Exhaustiveness -- +[case testUnionNegativeNarrowing-skip] +from typing import Union + +m: Union[str, int] + +match m: + case str(a): + reveal_type(a) # N: Revealed type is "builtins.str" + reveal_type(m) # N: Revealed type is "builtins.str" + case b: + reveal_type(b) # N: Revealed type is "builtins.int" + reveal_type(m) # N: Revealed type is "builtins.int" + +[case testOrPatternNegativeNarrowing-skip] +from typing import Union + +m: Union[str, bytes, int] + +match m: + case str(a) | bytes(a): + reveal_type(a) # N: Revealed type is "builtins.object" + reveal_type(m) # N: Revealed type is "Union[builtins.str, builtins.bytes]" + case b: + reveal_type(b) # N: Revealed type is "builtins.int" + +[case testExhaustiveReturn-skip] +def foo(value) -> int: + match value: + case "bar": + return 1 case _: - 1 + b'' + return 2 + +[case testNoneExhaustiveReturn-skip] +def foo(value) -> int: # E: Missing return statement + match value: + case "bar": + return 1 + case 2: + return 2 diff --git a/test-data/unit/deps.test b/test-data/unit/deps.test index d833a05e201a..fd593a975ca0 100644 --- a/test-data/unit/deps.test +++ b/test-data/unit/deps.test @@ -1436,6 +1436,7 @@ class B(A): -> , m -> -> , m.B.__init__ + -> -> -> -> diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index fd509de8a6c2..f8a5e3481d13 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -32,7 +32,11 @@ class dict(Mapping[KT, VT]): def __len__(self) -> int: ... class int: # for convenience - def __add__(self, x: int) -> int: pass + def __add__(self, x: Union[int, complex]) -> int: pass + def __sub__(self, x: Union[int, complex]) -> int: pass + def __neg__(self): pass + real: int + imag: int class str: pass # for keyword argument key type class unicode: pass # needed for py2 docstrings diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index 71f59a9c1d8c..c72838535443 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -1,5 +1,6 @@ # builtins stub with non-generic primitive types -from typing import Generic, TypeVar, Sequence, Iterator, Mapping +from typing import Generic, TypeVar, Sequence, Iterator, Mapping, Iterable, overload + T = TypeVar('T') V = TypeVar('V') @@ -48,5 +49,20 @@ class list(Sequence[T]): def __getitem__(self, item: int) -> T: pass class dict(Mapping[T, V]): def __iter__(self) -> Iterator[T]: pass +class set(Iterable[T]): + def __iter__(self) -> Iterator[T]: pass +class frozenset(Iterable[T]): + def __iter__(self) -> Iterator[T]: pass class function: pass class ellipsis: pass + +class range(Sequence[int]): + @overload + def __init__(self, stop: int) -> None: pass + @overload + def __init__(self, start: int, stop: int, step: int = ...) -> None: pass + def count(self, value: int) -> int: pass + def index(self, value: int) -> int: pass + def __getitem__(self, i: int) -> int: pass + def __iter__(self) -> Iterator[int]: pass + def __contains__(self, other: object) -> bool: pass diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index 3de9f934b255..739bf703f3e7 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -129,6 +129,10 @@ class Sequence(Iterable[T_co], Container[T_co]): @abstractmethod def __getitem__(self, n: Any) -> T_co: pass +class MutableSequence(Sequence[T]): + @abstractmethod + def __setitem__(self, n: Any, o: T) -> None: pass + class Mapping(Iterable[T], Generic[T, T_co], metaclass=ABCMeta): def __getitem__(self, key: T) -> T_co: pass @overload diff --git a/test-data/unit/lib-stub/collections.pyi b/test-data/unit/lib-stub/collections.pyi index 71f797e565e8..7ea264f764ee 100644 --- a/test-data/unit/lib-stub/collections.pyi +++ b/test-data/unit/lib-stub/collections.pyi @@ -1,4 +1,4 @@ -from typing import Any, Iterable, Union, Optional, Dict, TypeVar, overload, Optional, Callable, Sized +from typing import Any, Iterable, Union, Dict, TypeVar, Optional, Callable, Generic, Sequence, MutableMapping def namedtuple( typename: str, @@ -20,6 +20,6 @@ class defaultdict(Dict[KT, VT]): class Counter(Dict[KT, int], Generic[KT]): ... -class deque(Sized, Iterable[KT], Reversible[KT], Generic[KT]): ... +class deque(Sequence[KT], Generic[KT]): ... class ChainMap(MutableMapping[KT, VT], Generic[KT, VT]): ... diff --git a/test-data/unit/lib-stub/dataclasses.pyi b/test-data/unit/lib-stub/dataclasses.pyi index e0491e14876b..bd33b459266c 100644 --- a/test-data/unit/lib-stub/dataclasses.pyi +++ b/test-data/unit/lib-stub/dataclasses.pyi @@ -15,7 +15,6 @@ def dataclass(*, init: bool = ..., repr: bool = ..., eq: bool = ..., order: bool unsafe_hash: bool = ..., frozen: bool = ..., match_args: bool = ..., kw_only: bool = ..., slots: bool = ...) -> Callable[[Type[_T]], Type[_T]]: ... - @overload def field(*, default: _T, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., diff --git a/test-data/unit/lib-stub/types.pyi b/test-data/unit/lib-stub/types.pyi index 6fc596ecbf13..3ac4945ef5a7 100644 --- a/test-data/unit/lib-stub/types.pyi +++ b/test-data/unit/lib-stub/types.pyi @@ -5,8 +5,6 @@ _T = TypeVar('_T') def coroutine(func: _T) -> _T: pass -class bool: ... - class ModuleType: __file__ = ... # type: str diff --git a/test-data/unit/merge.test b/test-data/unit/merge.test index 836ad87857f8..b5d68899f019 100644 --- a/test-data/unit/merge.test +++ b/test-data/unit/merge.test @@ -671,15 +671,16 @@ TypeInfo<2>( _NT<6> __annotations__<7> (builtins.object<1>) __doc__<8> (builtins.str<9>) - __new__<10> - _asdict<11> - _field_defaults<12> (builtins.object<1>) - _field_types<13> (builtins.object<1>) - _fields<14> (Tuple[builtins.str<9>]) - _make<15> - _replace<16> - _source<17> (builtins.str<9>) - x<18> (target.A<0>))) + __match_args__<10> (Tuple[Literal['x']]) + __new__<11> + _asdict<12> + _field_defaults<13> (builtins.object<1>) + _field_types<14> (builtins.object<1>) + _fields<15> (Tuple[builtins.str<9>]) + _make<16> + _replace<17> + _source<18> (builtins.str<9>) + x<19> (target.A<0>))) ==> TypeInfo<0>( Name(target.A) @@ -694,16 +695,17 @@ TypeInfo<2>( _NT<6> __annotations__<7> (builtins.object<1>) __doc__<8> (builtins.str<9>) - __new__<10> - _asdict<11> - _field_defaults<12> (builtins.object<1>) - _field_types<13> (builtins.object<1>) - _fields<14> (Tuple[builtins.str<9>, builtins.str<9>]) - _make<15> - _replace<16> - _source<17> (builtins.str<9>) - x<18> (target.A<0>) - y<19> (target.A<0>))) + __match_args__<10> (Tuple[Literal['x'], Literal['y']]) + __new__<11> + _asdict<12> + _field_defaults<13> (builtins.object<1>) + _field_types<14> (builtins.object<1>) + _fields<15> (Tuple[builtins.str<9>, builtins.str<9>]) + _make<16> + _replace<17> + _source<18> (builtins.str<9>) + x<19> (target.A<0>) + y<20> (target.A<0>))) [case testUnionType_types] import target diff --git a/test-data/unit/parse-python310.test b/test-data/unit/parse-python310.test new file mode 100644 index 000000000000..87e0e9d5d283 --- /dev/null +++ b/test-data/unit/parse-python310.test @@ -0,0 +1,603 @@ +-- Test cases for parser -- Python 3.10 syntax (match statement) +-- +-- See parse.test for a description of this file format. + +[case testSimpleMatch] +match a: + case 1: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + ValuePattern:2( + IntExpr(1))) + Body( + PassStmt:3()))) + + +[case testTupleMatch] +match a, b: + case 1: + pass +[out] +MypyFile:1( + MatchStmt:1( + TupleExpr:1( + NameExpr(a) + NameExpr(b)) + Pattern( + ValuePattern:2( + IntExpr(1))) + Body( + PassStmt:3()))) + +[case testMatchWithGuard] +match a: + case 1 if f(): + pass + case d if d > 5: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + ValuePattern:2( + IntExpr(1))) + Guard( + CallExpr:2( + NameExpr(f) + Args())) + Body( + PassStmt:3()) + Pattern( + AsPattern:4( + NameExpr(d))) + Guard( + ComparisonExpr:4( + > + NameExpr(d) + IntExpr(5))) + Body( + PassStmt:5()))) + +[case testAsPattern] +match a: + case 1 as b: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + AsPattern:2( + ValuePattern:2( + IntExpr(1)) + NameExpr(b))) + Body( + PassStmt:3()))) + + +[case testLiteralPattern] +match a: + case 1: + pass + case -1: + pass + case 1+2j: + pass + case -1+2j: + pass + case 1-2j: + pass + case -1-2j: + pass + case "str": + pass + case b"bytes": + pass + case r"raw_string": + pass + case None: + pass + case True: + pass + case False: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + ValuePattern:2( + IntExpr(1))) + Body( + PassStmt:3()) + Pattern( + ValuePattern:4( + UnaryExpr:4( + - + IntExpr(1)))) + Body( + PassStmt:5()) + Pattern( + ValuePattern:6( + OpExpr:6( + + + IntExpr(1) + ComplexExpr(2j)))) + Body( + PassStmt:7()) + Pattern( + ValuePattern:8( + OpExpr:8( + + + UnaryExpr:8( + - + IntExpr(1)) + ComplexExpr(2j)))) + Body( + PassStmt:9()) + Pattern( + ValuePattern:10( + OpExpr:10( + - + IntExpr(1) + ComplexExpr(2j)))) + Body( + PassStmt:11()) + Pattern( + ValuePattern:12( + OpExpr:12( + - + UnaryExpr:12( + - + IntExpr(1)) + ComplexExpr(2j)))) + Body( + PassStmt:13()) + Pattern( + ValuePattern:14( + StrExpr(str))) + Body( + PassStmt:15()) + Pattern( + ValuePattern:16( + BytesExpr(bytes))) + Body( + PassStmt:17()) + Pattern( + ValuePattern:18( + StrExpr(raw_string))) + Body( + PassStmt:19()) + Pattern( + SingletonPattern:20()) + Body( + PassStmt:21()) + Pattern( + SingletonPattern:22( + True)) + Body( + PassStmt:23()) + Pattern( + SingletonPattern:24( + False)) + Body( + PassStmt:25()))) + +[case testCapturePattern] +match a: + case x: + pass + case longName: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + AsPattern:2( + NameExpr(x))) + Body( + PassStmt:3()) + Pattern( + AsPattern:4( + NameExpr(longName))) + Body( + PassStmt:5()))) + +[case testWildcardPattern] +match a: + case _: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + AsPattern:2()) + Body( + PassStmt:3()))) + +[case testValuePattern] +match a: + case b.c: + pass + case b.c.d.e.f: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + ValuePattern:2( + MemberExpr:2( + NameExpr(b) + c))) + Body( + PassStmt:3()) + Pattern( + ValuePattern:4( + MemberExpr:4( + MemberExpr:4( + MemberExpr:4( + MemberExpr:4( + NameExpr(b) + c) + d) + e) + f))) + Body( + PassStmt:5()))) + +[case testGroupPattern] +# This is optimized out by the compiler. It doesn't appear in the ast +match a: + case (1): + pass +[out] +MypyFile:1( + MatchStmt:2( + NameExpr(a) + Pattern( + ValuePattern:3( + IntExpr(1))) + Body( + PassStmt:4()))) + +[case testSequencePattern] +match a: + case []: + pass + case (): + pass + case [1]: + pass + case (1,): + pass + case 1,: + pass + case [1, 2, 3]: + pass + case (1, 2, 3): + pass + case 1, 2, 3: + pass + case [1, *a, 2]: + pass + case (1, *a, 2): + pass + case 1, *a, 2: + pass + case [1, *_, 2]: + pass + case (1, *_, 2): + pass + case 1, *_, 2: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + SequencePattern:2()) + Body( + PassStmt:3()) + Pattern( + SequencePattern:4()) + Body( + PassStmt:5()) + Pattern( + SequencePattern:6( + ValuePattern:6( + IntExpr(1)))) + Body( + PassStmt:7()) + Pattern( + SequencePattern:8( + ValuePattern:8( + IntExpr(1)))) + Body( + PassStmt:9()) + Pattern( + SequencePattern:10( + ValuePattern:10( + IntExpr(1)))) + Body( + PassStmt:11()) + Pattern( + SequencePattern:12( + ValuePattern:12( + IntExpr(1)) + ValuePattern:12( + IntExpr(2)) + ValuePattern:12( + IntExpr(3)))) + Body( + PassStmt:13()) + Pattern( + SequencePattern:14( + ValuePattern:14( + IntExpr(1)) + ValuePattern:14( + IntExpr(2)) + ValuePattern:14( + IntExpr(3)))) + Body( + PassStmt:15()) + Pattern( + SequencePattern:16( + ValuePattern:16( + IntExpr(1)) + ValuePattern:16( + IntExpr(2)) + ValuePattern:16( + IntExpr(3)))) + Body( + PassStmt:17()) + Pattern( + SequencePattern:18( + ValuePattern:18( + IntExpr(1)) + StarredPattern:18( + NameExpr(a)) + ValuePattern:18( + IntExpr(2)))) + Body( + PassStmt:19()) + Pattern( + SequencePattern:20( + ValuePattern:20( + IntExpr(1)) + StarredPattern:20( + NameExpr(a)) + ValuePattern:20( + IntExpr(2)))) + Body( + PassStmt:21()) + Pattern( + SequencePattern:22( + ValuePattern:22( + IntExpr(1)) + StarredPattern:22( + NameExpr(a)) + ValuePattern:22( + IntExpr(2)))) + Body( + PassStmt:23()) + Pattern( + SequencePattern:24( + ValuePattern:24( + IntExpr(1)) + StarredPattern:24() + ValuePattern:24( + IntExpr(2)))) + Body( + PassStmt:25()) + Pattern( + SequencePattern:26( + ValuePattern:26( + IntExpr(1)) + StarredPattern:26() + ValuePattern:26( + IntExpr(2)))) + Body( + PassStmt:27()) + Pattern( + SequencePattern:28( + ValuePattern:28( + IntExpr(1)) + StarredPattern:28() + ValuePattern:28( + IntExpr(2)))) + Body( + PassStmt:29()))) + +[case testMappingPattern] +match a: + case {'k': v}: + pass + case {a.b: v}: + pass + case {1: v}: + pass + case {a.c: v}: + pass + case {'k': v1, a.b: v2, 1: v3, a.c: v4}: + pass + case {'k1': 1, 'k2': "str", 'k3': b'bytes', 'k4': None}: + pass + case {'k': v, **r}: + pass + case {**r}: + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + MappingPattern:2( + Key( + StrExpr(k)) + Value( + AsPattern:2( + NameExpr(v))))) + Body( + PassStmt:3()) + Pattern( + MappingPattern:4( + Key( + MemberExpr:4( + NameExpr(a) + b)) + Value( + AsPattern:4( + NameExpr(v))))) + Body( + PassStmt:5()) + Pattern( + MappingPattern:6( + Key( + IntExpr(1)) + Value( + AsPattern:6( + NameExpr(v))))) + Body( + PassStmt:7()) + Pattern( + MappingPattern:8( + Key( + MemberExpr:8( + NameExpr(a) + c)) + Value( + AsPattern:8( + NameExpr(v))))) + Body( + PassStmt:9()) + Pattern( + MappingPattern:10( + Key( + StrExpr(k)) + Value( + AsPattern:10( + NameExpr(v1))) + Key( + MemberExpr:10( + NameExpr(a) + b)) + Value( + AsPattern:10( + NameExpr(v2))) + Key( + IntExpr(1)) + Value( + AsPattern:10( + NameExpr(v3))) + Key( + MemberExpr:10( + NameExpr(a) + c)) + Value( + AsPattern:10( + NameExpr(v4))))) + Body( + PassStmt:11()) + Pattern( + MappingPattern:12( + Key( + StrExpr(k1)) + Value( + ValuePattern:12( + IntExpr(1))) + Key( + StrExpr(k2)) + Value( + ValuePattern:12( + StrExpr(str))) + Key( + StrExpr(k3)) + Value( + ValuePattern:12( + BytesExpr(bytes))) + Key( + StrExpr(k4)) + Value( + SingletonPattern:12()))) + Body( + PassStmt:13()) + Pattern( + MappingPattern:14( + Key( + StrExpr(k)) + Value( + AsPattern:14( + NameExpr(v))) + Rest( + NameExpr(r)))) + Body( + PassStmt:15()) + Pattern( + MappingPattern:16( + Rest( + NameExpr(r)))) + Body( + PassStmt:17()))) + +[case testClassPattern] +match a: + case A(): + pass + case B(1, 2): + pass + case B(1, b=2): + pass + case B(a=1, b=2): + pass +[out] +MypyFile:1( + MatchStmt:1( + NameExpr(a) + Pattern( + ClassPattern:2( + NameExpr(A))) + Body( + PassStmt:3()) + Pattern( + ClassPattern:4( + NameExpr(B) + Positionals( + ValuePattern:4( + IntExpr(1)) + ValuePattern:4( + IntExpr(2))))) + Body( + PassStmt:5()) + Pattern( + ClassPattern:6( + NameExpr(B) + Positionals( + ValuePattern:6( + IntExpr(1))) + Keyword( + b + ValuePattern:6( + IntExpr(2))))) + Body( + PassStmt:7()) + Pattern( + ClassPattern:8( + NameExpr(B) + Keyword( + a + ValuePattern:8( + IntExpr(1))) + Keyword( + b + ValuePattern:8( + IntExpr(2))))) + Body( + PassStmt:9()))) diff --git a/test-data/unit/semanal-errors-python310.test b/test-data/unit/semanal-errors-python310.test new file mode 100644 index 000000000000..68c158cddae6 --- /dev/null +++ b/test-data/unit/semanal-errors-python310.test @@ -0,0 +1,43 @@ +[case testMatchUndefinedSubject] +import typing +match x: + case _: + pass +[out] +main:2: error: Name "x" is not defined + +[case testMatchUndefinedValuePattern] +import typing +x = 1 +match x: + case a.b: + pass +[out] +main:4: error: Name "a" is not defined + +[case testMatchUndefinedClassPattern] +import typing +x = 1 +match x: + case A(): + pass +[out] +main:4: error: Name "A" is not defined + +[case testNoneBindingWildcardPattern] +import typing +x = 1 +match x: + case _: + _ +[out] +main:5: error: Name "_" is not defined + +[case testNoneBindingStarredWildcardPattern] +import typing +x = 1 +match x: + case [*_]: + _ +[out] +main:5: error: Name "_" is not defined diff --git a/test-data/unit/semanal-python310.test b/test-data/unit/semanal-python310.test new file mode 100644 index 000000000000..a009636575dc --- /dev/null +++ b/test-data/unit/semanal-python310.test @@ -0,0 +1,204 @@ +-- Python 3.10 semantic analysis test cases. + +[case testCapturePattern] +x = 1 +match x: + case a: + a +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + AsPattern:3( + NameExpr(a* [__main__.a]))) + Body( + ExpressionStmt:4( + NameExpr(a [__main__.a]))))) + +[case testCapturePatternOutliving] +x = 1 +match x: + case a: + pass +a +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + AsPattern:3( + NameExpr(a* [__main__.a]))) + Body( + PassStmt:4())) + ExpressionStmt:5( + NameExpr(a [__main__.a]))) + +[case testNestedCapturePatterns] +x = 1 +match x: + case ([a], {'k': b}): + a + b +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + SequencePattern:3( + SequencePattern:3( + AsPattern:3( + NameExpr(a* [__main__.a]))) + MappingPattern:3( + Key( + StrExpr(k)) + Value( + AsPattern:3( + NameExpr(b* [__main__.b])))))) + Body( + ExpressionStmt:4( + NameExpr(a [__main__.a])) + ExpressionStmt:5( + NameExpr(b [__main__.b]))))) + +[case testMappingPatternRest] +x = 1 +match x: + case {**r}: + r +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + MappingPattern:3( + Rest( + NameExpr(r* [__main__.r])))) + Body( + ExpressionStmt:4( + NameExpr(r [__main__.r]))))) + + +[case testAsPattern] +x = 1 +match x: + case 1 as a: + a +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + AsPattern:3( + ValuePattern:3( + IntExpr(1)) + NameExpr(a* [__main__.a]))) + Body( + ExpressionStmt:4( + NameExpr(a [__main__.a]))))) + +[case testGuard] +x = 1 +a = 1 +match x: + case 1 if a: + pass +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + AssignmentStmt:2( + NameExpr(a* [__main__.a]) + IntExpr(1)) + MatchStmt:3( + NameExpr(x [__main__.x]) + Pattern( + ValuePattern:4( + IntExpr(1))) + Guard( + NameExpr(a [__main__.a])) + Body( + PassStmt:5()))) + +[case testCapturePatternInGuard] +x = 1 +match x: + case a if a: + pass +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + AsPattern:3( + NameExpr(a* [__main__.a]))) + Guard( + NameExpr(a [__main__.a])) + Body( + PassStmt:4()))) + +[case testAsPatternInGuard] +x = 1 +match x: + case 1 as a if a: + pass +[out] +MypyFile:1( + AssignmentStmt:1( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:2( + NameExpr(x [__main__.x]) + Pattern( + AsPattern:3( + ValuePattern:3( + IntExpr(1)) + NameExpr(a* [__main__.a]))) + Guard( + NameExpr(a [__main__.a])) + Body( + PassStmt:4()))) + +[case testValuePattern] +import _a + +x = 1 +match x: + case _a.b: + pass +[file _a.py] +b = 1 +[out] +MypyFile:1( + Import:1(_a) + AssignmentStmt:3( + NameExpr(x* [__main__.x]) + IntExpr(1)) + MatchStmt:4( + NameExpr(x [__main__.x]) + Pattern( + ValuePattern:5( + MemberExpr:5( + NameExpr(_a) + b [_a.b]))) + Body( + PassStmt:6()))) diff --git a/test-requirements.txt b/test-requirements.txt index c5db79ada816..ac9e27c1bb2c 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -6,7 +6,8 @@ flake8-bugbear flake8-pyi>=20.5 lxml>=4.4.0 psutil>=4.0 -pytest>=6.2.0,<7.0.0 +# pytest 6.2.3 does not support Python 3.10 +pytest>=6.2.4,<7.0.0 pytest-xdist>=1.34.0,<2.0.0 pytest-forked>=1.3.0,<2.0.0 pytest-cov>=2.10.0,<3.0.0