diff --git a/docs/source/command_line.rst b/docs/source/command_line.rst index 678d09c21a2e..488d719e8f50 100644 --- a/docs/source/command_line.rst +++ b/docs/source/command_line.rst @@ -386,6 +386,23 @@ of the above sections. # 'items' now has type List[List[str]] ... +``--strict-equality`` + By default, mypy allows always-false comparisons like ``42 == 'no'``. + Use this flag to prohibit such comparisons of non-overlapping types, and + similar identity and container checks: + + .. code-block:: python + + from typing import Text + + text: Text + if b'some bytes' in text: # Error: non-overlapping check! + ... + if text != b'other bytes': # Error: non-overlapping check! + ... + + assert text is not None # OK, this special case is allowed. + ``--strict`` This flag mode enables all optional error checking flags. You can see the list of flags enabled by strict mode in the full ``mypy --help`` output. diff --git a/docs/source/config_file.rst b/docs/source/config_file.rst index a2e64d69a83b..8575a3b38a8a 100644 --- a/docs/source/config_file.rst +++ b/docs/source/config_file.rst @@ -294,6 +294,10 @@ Miscellaneous strictness flags Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition. +``strict_equality`` (bool, default False) + Prohibit equality checks, identity checks, and container checks between + non-overlapping types. + Global-only options ******************* diff --git a/mypy/checker.py b/mypy/checker.py index 7559186964f5..758121f82dfa 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2997,6 +2997,25 @@ def analyze_iterable_item_type(self, expr: Expression) -> Tuple[Type, Type]: nextmethod = 'next' return iterator, echk.check_method_call_by_name(nextmethod, iterator, [], [], expr)[0] + def analyze_container_item_type(self, typ: Type) -> Optional[Type]: + """Check if a type is a nominal container of a union of such. + + Return the corresponding container item type. + """ + if isinstance(typ, UnionType): + types = [] # type: List[Type] + for item in typ.items: + c_type = self.analyze_container_item_type(item) + if c_type: + types.append(c_type) + return UnionType.make_union(types) + if isinstance(typ, Instance) and typ.type.has_base('typing.Container'): + supertype = self.named_type('typing.Container').type + super_instance = map_instance_to_supertype(typ, supertype) + assert len(super_instance.args) == 1 + return super_instance.args[0] + return None + def analyze_index_variables(self, index: Expression, item_type: Type, infer_lvalue_type: bool, context: Context) -> None: """Type check or infer for loop or list comprehension index vars.""" diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index eed21748c49e..60aa0bd75e69 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -44,7 +44,7 @@ from mypy import message_registry from mypy.infer import infer_type_arguments, infer_function_type_arguments from mypy import join -from mypy.meet import narrow_declared_type +from mypy.meet import narrow_declared_type, is_overlapping_types from mypy.subtypes import ( is_subtype, is_proper_subtype, is_equivalent, find_member, non_method_protocol_members, ) @@ -1914,6 +1914,11 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: _, method_type = self.check_method_call_by_name( '__contains__', right_type, [left], [ARG_POS], e, local_errors) sub_result = self.bool_type() + # Container item type for strict type overlap checks. Note: we need to only + # check for nominal type, because a usual "Unsupported operands for in" + # will be reported for types incompatible with __contains__(). + # See testCustomContainsCheckStrictEquality for an example. + cont_type = self.chk.analyze_container_item_type(right_type) if isinstance(right_type, PartialType): # We don't really know if this is an error or not, so just shut up. pass @@ -1929,16 +1934,29 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: self.named_type('builtins.function')) if not is_subtype(left_type, itertype): self.msg.unsupported_operand_types('in', left_type, right_type, e) + # Only show dangerous overlap if there are no other errors. + elif (not local_errors.is_errors() and cont_type and + self.dangerous_comparison(left_type, cont_type)): + self.msg.dangerous_comparison(left_type, cont_type, 'container', e) else: self.msg.add_errors(local_errors) elif operator in nodes.op_methods: method = self.get_operator_method(operator) + err_count = self.msg.errors.total_errors() sub_result, method_type = self.check_op(method, left_type, right, e, - allow_reverse=True) + allow_reverse=True) + # Only show dangerous overlap if there are no other errors. See + # testCustomEqCheckStrictEquality for an example. + if self.msg.errors.total_errors() == err_count and operator in ('==', '!='): + right_type = self.accept(right) + if self.dangerous_comparison(left_type, right_type): + self.msg.dangerous_comparison(left_type, right_type, 'equality', e) elif operator == 'is' or operator == 'is not': - self.accept(right) # validate the right operand + right_type = self.accept(right) # validate the right operand sub_result = self.bool_type() + if self.dangerous_comparison(left_type, right_type): + self.msg.dangerous_comparison(left_type, right_type, 'identity', e) method_type = None else: raise RuntimeError('Unknown comparison operator {}'.format(operator)) @@ -1954,6 +1972,30 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: assert result is not None return result + def dangerous_comparison(self, left: Type, right: Type) -> bool: + """Check for dangerous non-overlapping comparisons like 42 == 'no'. + + Rules: + * X and None are overlapping even in strict-optional mode. This is to allow + 'assert x is not None' for x defined as 'x = None # type: str' in class body + (otherwise mypy itself would have couple dozen errors because of this). + * Optional[X] and Optional[Y] are non-overlapping if X and Y are + non-overlapping, although technically None is overlap, it is most + likely an error. + * Any overlaps with everything, i.e. always safe. + * Promotions are ignored, so both 'abc' == b'abc' and 1 == 1.0 + are errors. This is mostly needed for bytes vs unicode, and + int vs float are added just for consistency. + """ + if not self.chk.options.strict_equality: + return False + if isinstance(left, NoneTyp) or isinstance(right, NoneTyp): + return False + if isinstance(left, UnionType) and isinstance(right, UnionType): + left = remove_optional(left) + right = remove_optional(right) + return not is_overlapping_types(left, right, ignore_promotions=True) + def get_operator_method(self, op: str) -> str: if op == '/' and self.chk.options.python_version[0] == 2: # TODO also check for "from __future__ import division" diff --git a/mypy/errors.py b/mypy/errors.py index a177b5d6805a..0053e3ec08c4 100644 --- a/mypy/errors.py +++ b/mypy/errors.py @@ -169,6 +169,9 @@ def copy(self) -> 'Errors': new.scope = self.scope return new + def total_errors(self) -> int: + return sum(len(errs) for errs in self.error_info_map.values()) + def set_ignore_prefix(self, prefix: str) -> None: """Set path prefix that will be removed from all paths.""" prefix = os.path.normpath(prefix) diff --git a/mypy/main.py b/mypy/main.py index 5bccd0af2a3f..445f18e2eb08 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -527,6 +527,11 @@ def add_invertible_flag(flag: str, help="Allow unconditional variable redefinition with a new type", group=strictness_group) + add_invertible_flag('--strict-equality', default=False, strict_flag=False, + help="Prohibit equality, identity, and container checks for" + " non-overlapping types", + group=strictness_group) + incremental_group = parser.add_argument_group( title='Incremental mode', description="Adjust how mypy incrementally type checks and caches modules. " diff --git a/mypy/meet.py b/mypy/meet.py index 62f1c9f85356..10d5b051293a 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -221,9 +221,28 @@ def is_none_typevar_overlap(t1: Type, t2: Type) -> bool: # As before, we degrade into 'Instance' whenever possible. if isinstance(left, TypeType) and isinstance(right, TypeType): - # TODO: Can Callable[[...], T] and Type[T] be partially overlapping? return _is_overlapping_types(left.item, right.item) + def _type_object_overlap(left: Type, right: Type) -> bool: + """Special cases for type object types overlaps.""" + # TODO: these checks are a bit in gray area, adjust if they cause problems. + # 1. Type[C] vs Callable[..., C], where the latter is class object. + if isinstance(left, TypeType) and isinstance(right, CallableType) and right.is_type_obj(): + return _is_overlapping_types(left.item, right.ret_type) + # 2. Type[C] vs Meta, where Meta is a metaclass for C. + if (isinstance(left, TypeType) and isinstance(left.item, Instance) and + isinstance(right, Instance)): + left_meta = left.item.type.metaclass_type + if left_meta is not None: + return _is_overlapping_types(left_meta, right) + # builtins.type (default metaclass) overlaps with all metaclasses + return right.type.has_base('builtins.type') + # 3. Callable[..., C] vs Meta is considered below, when we switch to fallbacks. + return False + + if isinstance(left, TypeType) or isinstance(right, TypeType): + return _type_object_overlap(left, right) or _type_object_overlap(right, left) + if isinstance(left, CallableType) and isinstance(right, CallableType): return is_callable_compatible(left, right, is_compat=_is_overlapping_types, diff --git a/mypy/messages.py b/mypy/messages.py index 3512bd878cb6..ce7a90ea32d1 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -975,6 +975,13 @@ def incompatible_typevar_value(self, .format(typevar_name, callable_name(callee) or 'function', self.format(typ)), context) + def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None: + left_str = 'element' if kind == 'container' else 'left operand' + right_str = 'container item' if kind == 'container' else 'right operand' + message = 'Non-overlapping {} check ({} type: {}, {} type: {})' + left_typ, right_typ = self.format_distinctly(left, right) + self.fail(message.format(kind, left_str, left_typ, right_str, right_typ), ctx) + def overload_inconsistently_applies_decorator(self, decorator: str, context: Context) -> None: self.fail( 'Overload does not consistently use the "@{}" '.format(decorator) diff --git a/mypy/options.py b/mypy/options.py index 506c127fb3e8..66d7854a2ec8 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -22,6 +22,7 @@ class BuildType: # Please keep this list sorted "allow_untyped_globals", "allow_redefinition", + "strict_equality", "always_false", "always_true", "check_untyped_defs", @@ -157,6 +158,10 @@ def __init__(self) -> None: # and the same nesting level as the initialization self.allow_redefinition = False + # Prohibit equality, identity, and container checks for non-overlapping types. + # This makes 1 == '1', 1 in ['1'], and 1 is '1' errors. + self.strict_equality = False + # Variable names considered True self.always_true = [] # type: List[str] diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 008b697845d0..e893f5fe3db9 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -508,7 +508,7 @@ class C: # type: ('int') -> bool pass -[builtins_py2 fixtures/bool.pyi] +[builtins_py2 fixtures/bool_py2.pyi] [case cmpIgnoredPy3] @@ -604,7 +604,7 @@ class X: class Y: def __lt__(self, o: 'Y') -> A: pass def __gt__(self, o: 'Y') -> A: pass - def __eq__(self, o: 'Y') -> B: pass + def __eq__(self, o: 'Y') -> B: pass # type: ignore [builtins fixtures/bool.pyi] @@ -1947,3 +1947,175 @@ a.__pow__() # E: Too few arguments for "__pow__" of "int" x, y = [], [] # E: Need type annotation for 'x' \ # E: Need type annotation for 'y' [builtins fixtures/list.pyi] + +[case testStrictEqualityEq] +# flags: --strict-equality +class A: ... +class B: ... +class C(B): ... + +A() == B() # E: Non-overlapping equality check (left operand type: "A", right operand type: "B") +B() == C() +C() == B() +A() != B() # E: Non-overlapping equality check (left operand type: "A", right operand type: "B") +B() != C() +C() != B() +[builtins fixtures/bool.pyi] + +[case testStrictEqualityIs] +# flags: --strict-equality +class A: ... +class B: ... +class C(B): ... + +A() is B() # E: Non-overlapping identity check (left operand type: "A", right operand type: "B") +B() is C() +C() is B() +A() is not B() # E: Non-overlapping identity check (left operand type: "A", right operand type: "B") +B() is not C() +C() is not B() +[builtins fixtures/bool.pyi] + +[case testStrictEqualityContains] +# flags: --strict-equality +class A: ... +class B: ... +class C(B): ... + +A() in [B()] # E: Non-overlapping container check (element type: "A", container item type: "B") +B() in [C()] +C() in [B()] +A() not in [B()] # E: Non-overlapping container check (element type: "A", container item type: "B") +B() not in [C()] +C() not in [B()] +[builtins fixtures/list.pyi] +[typing fixtures/typing-full.pyi] + +[case testStrictEqualityUnions] +# flags: --strict-equality +from typing import Container, Union + +class A: ... +class B: ... + +a: Union[int, str] +b: Union[A, B] + +a == 42 +b == 42 # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "int") + +a is 42 +b is 42 # E: Non-overlapping identity check (left operand type: "Union[A, B]", right operand type: "int") + +ca: Union[Container[int], Container[str]] +cb: Union[Container[A], Container[B]] + +42 in ca +42 in cb # E: Non-overlapping container check (element type: "int", container item type: "Union[A, B]") +[builtins fixtures/bool.pyi] +[typing fixtures/typing-full.pyi] + +[case testStrictEqualityNoPromote] +# flags: --strict-equality +'a' == b'a' # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes") +b'a' in 'abc' # E: Non-overlapping container check (element type: "bytes", container item type: "str") + +x: str +y: bytes +x != y # E: Non-overlapping equality check (left operand type: "str", right operand type: "bytes") +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-full.pyi] + +[case testStrictEqualityAny] +# flags: --strict-equality +from typing import Any, Container + +x: Any +c: Container[str] +x in c +x == 42 +x is 42 +[builtins fixtures/bool.pyi] +[typing fixtures/typing-full.pyi] + +[case testStrictEqualityStrictOptional] +# flags: --strict-equality --strict-optional + +x: str +if x is not None: # OK even with strict-optional + pass +[builtins fixtures/bool.pyi] + +[case testStrictEqualityNoStrictOptional] +# flags: --strict-equality --no-strict-optional + +x: str +if x is not None: # OK without strict-optional + pass +[builtins fixtures/bool.pyi] + +[case testStrictEqualityEqNoOptionalOverlap] +# flags: --strict-equality --strict-optional +from typing import Optional + +x: Optional[str] +y: Optional[int] +if x == y: # E: Non-overlapping equality check (left operand type: "Optional[str]", right operand type: "Optional[int]") + ... +[builtins fixtures/bool.pyi] + +[case testCustomEqCheckStrictEquality] +# flags: --strict-equality +class A: + def __eq__(self, other: A) -> bool: # type: ignore + ... +class B: + def __eq__(self, other: B) -> bool: # type: ignore + ... + +# Don't report non-overlapping check if there is already and error. +A() == B() # E: Unsupported operand types for == ("A" and "B") +[builtins fixtures/bool.pyi] + +[case testCustomContainsCheckStrictEquality] +# flags: --strict-equality +class A: + def __contains__(self, other: A) -> bool: + ... + +# Don't report non-overlapping check if there is already and error. +42 in A() # E: Unsupported operand types for in ("int" and "A") +[builtins fixtures/bool.pyi] + +[case testStrictEqualityTypeVsCallable] +# flags: --strict-equality +from typing import Type, List +class C: ... +class D(C): ... +class Bad: ... + +subclasses: List[Type[C]] +object in subclasses +D in subclasses +Bad in subclasses # E: Non-overlapping container check (element type: "Type[Bad]", container item type: "Type[C]") +[builtins fixtures/list.pyi] +[typing fixtures/typing-full.pyi] + +[case testStrictEqualityMetaclass] +# flags: --strict-equality +from typing import List, Type + +class Meta(type): ... + +class A(metaclass=Meta): ... +class B(metaclass=Meta): ... +class C: ... + +o: Type[object] +exp: List[Meta] + +A in exp +C in exp # E: Non-overlapping container check (element type: "Type[C]", container item type: "Meta") +o in exp +[builtins fixtures/list.pyi] +[typing fixtures/typing-full.pyi] diff --git a/test-data/unit/check-flags.test b/test-data/unit/check-flags.test index cb3156d8a39a..09818180efc2 100644 --- a/test-data/unit/check-flags.test +++ b/test-data/unit/check-flags.test @@ -1107,3 +1107,16 @@ class A(Generic[T]): def f(c: A) -> None: # E: Missing type parameters for generic type pass [out] + +[case testStrictEqualityPerFile] +# flags: --config-file tmp/mypy.ini +import b +42 == 'no' # E: Non-overlapping equality check (left operand type: "int", right operand type: "str") +[file b.py] +42 == 'no' +[file mypy.ini] +[[mypy] +strict_equality = True +[[mypy-b] +strict_equality = False +[builtins fixtures/bool.pyi] diff --git a/test-data/unit/fixtures/async_await.pyi b/test-data/unit/fixtures/async_await.pyi index b6161f45dacf..ed64289c0d4d 100644 --- a/test-data/unit/fixtures/async_await.pyi +++ b/test-data/unit/fixtures/async_await.pyi @@ -5,6 +5,7 @@ U = typing.TypeVar('U') class list(typing.Sequence[T]): def __iter__(self) -> typing.Iterator[T]: ... def __getitem__(self, i: int) -> T: ... + def __contains__(self, item: object) -> bool: ... class object: def __init__(self) -> None: pass @@ -12,7 +13,7 @@ class type: pass class function: pass class int: pass class str: pass -class bool: pass +class bool(int): pass class dict(typing.Generic[T, U]): pass class set(typing.Generic[T]): pass class tuple(typing.Generic[T]): pass diff --git a/test-data/unit/fixtures/bool.pyi b/test-data/unit/fixtures/bool.pyi index bf506d97312f..07bc461819a0 100644 --- a/test-data/unit/fixtures/bool.pyi +++ b/test-data/unit/fixtures/bool.pyi @@ -4,6 +4,8 @@ T = TypeVar('T') class object: def __init__(self) -> None: pass + def __eq__(self, other: object) -> bool: pass + def __ne__(self, other: object) -> bool: pass class type: pass class tuple(Generic[T]): pass diff --git a/test-data/unit/fixtures/bool_py2.pyi b/test-data/unit/fixtures/bool_py2.pyi new file mode 100644 index 000000000000..b2c935132d57 --- /dev/null +++ b/test-data/unit/fixtures/bool_py2.pyi @@ -0,0 +1,16 @@ +# builtins stub used in boolean-related test cases. +from typing import Generic, TypeVar +import sys +T = TypeVar('T') + +class object: + def __init__(self) -> None: pass + +class type: pass +class tuple(Generic[T]): pass +class function: pass +class bool: pass +class int: pass +class str: pass +class unicode: pass +class ellipsis: pass diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index 93648b274d98..d7e8d11b7d0b 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -39,11 +39,12 @@ class list(Sequence[T]): # needed by some test cases def __getitem__(self, x: int) -> T: pass def __iter__(self) -> Iterator[T]: pass def __mul__(self, x: int) -> list[T]: pass + def __contains__(self, item: object) -> bool: pass class tuple(Generic[T]): pass class function: pass class float: pass -class bool: pass +class bool(int): pass class ellipsis: pass def isinstance(x: object, t: Union[type, Tuple[type, ...]]) -> bool: pass diff --git a/test-data/unit/fixtures/isinstancelist.pyi b/test-data/unit/fixtures/isinstancelist.pyi index 6b93f16d2247..25ff5888a2cf 100644 --- a/test-data/unit/fixtures/isinstancelist.pyi +++ b/test-data/unit/fixtures/isinstancelist.pyi @@ -35,6 +35,7 @@ class list(Sequence[T]): def __setitem__(self, x: int, v: T) -> None: pass def __getitem__(self, x: int) -> T: pass def __add__(self, x: List[T]) -> T: pass + def __contains__(self, item: object) -> bool: pass class dict(Mapping[KT, VT]): @overload diff --git a/test-data/unit/fixtures/primitives.pyi b/test-data/unit/fixtures/primitives.pyi index a2c1f390f65c..796196fa08c6 100644 --- a/test-data/unit/fixtures/primitives.pyi +++ b/test-data/unit/fixtures/primitives.pyi @@ -1,10 +1,12 @@ # builtins stub with non-generic primitive types -from typing import Generic, TypeVar +from typing import Generic, TypeVar, Sequence, Iterator T = TypeVar('T') class object: def __init__(self) -> None: pass def __str__(self) -> str: pass + def __eq__(self, other: object) -> bool: pass + def __ne__(self, other: object) -> bool: pass class type: def __init__(self, x) -> None: pass @@ -15,10 +17,14 @@ class float: def __float__(self) -> float: pass class complex: pass class bool(int): pass -class str: +class str(Sequence[str]): def __add__(self, s: str) -> str: pass + def __iter__(self) -> Iterator[str]: pass + def __contains__(self, other: object) -> bool: pass + def __getitem__(self, item: int) -> str: pass def format(self, *args) -> str: pass class bytes: pass class bytearray: pass class tuple(Generic[T]): pass class function: pass +class ellipsis: pass diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index 82e119d043e1..0bfc46c0d992 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -39,10 +39,10 @@ S = TypeVar('S') # to silence the protocol variance checks. Maybe it is better to use type: ignore? @runtime -class Container(Protocol[T_contra]): +class Container(Protocol[T_co]): @abstractmethod # Use int because bool isn't in the default test builtins - def __contains__(self, arg: T_contra) -> int: pass + def __contains__(self, arg: object) -> int: pass @runtime class Sized(Protocol): @@ -117,7 +117,7 @@ class AsyncIterator(AsyncIterable[T], Protocol): @abstractmethod def __anext__(self) -> Awaitable[T]: pass -class Sequence(Iterable[T_co]): +class Sequence(Iterable[T_co], Container[T_co]): @abstractmethod def __getitem__(self, n: Any) -> T_co: pass