From bf2a0b36984543471871191a90e6e9a24ca027e7 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Sun, 13 Nov 2022 13:57:15 -0800 Subject: [PATCH] stubtest: if a default is present in the stub, check that it is correct Helps with python/typeshed#8988. --- mypy/stubtest.py | 191 ++++++++++++++++++++++++++++++++++++++ mypy/test/teststubtest.py | 55 ++++++++++- 2 files changed, 245 insertions(+), 1 deletion(-) diff --git a/mypy/stubtest.py b/mypy/stubtest.py index 87ccbd3176df..e0ad7d18773f 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -7,6 +7,7 @@ from __future__ import annotations import argparse +import ast import collections.abc import copy import enum @@ -29,6 +30,7 @@ import mypy.build import mypy.modulefinder +import mypy.nodes import mypy.state import mypy.types import mypy.version @@ -36,6 +38,7 @@ from mypy.config_parser import parse_config_file from mypy.options import Options from mypy.util import FancyFormatter, bytes_to_human_readable_repr, is_dunder, plural_s +from mypy.visitor import ExpressionVisitor class Missing: @@ -540,6 +543,179 @@ def names_approx_match(a: str, b: str) -> bool: ) +class _NodeEvaluator(ExpressionVisitor[object]): + def visit_int_expr(self, o: mypy.nodes.IntExpr) -> int: + return o.value + + def visit_str_expr(self, o: mypy.nodes.StrExpr) -> str: + return o.value + + def visit_bytes_expr(self, o: mypy.nodes.BytesExpr) -> bytes: + try: + return ast.literal_eval(f"b'{o.value}'") + except SyntaxError: + return ast.literal_eval(f'b"{o.value}"') + + def visit_float_expr(self, o: mypy.nodes.FloatExpr) -> float: + return o.value + + def visit_complex_expr(self, o: mypy.nodes.ComplexExpr) -> object: + return o.value + + def visit_ellipsis(self, o: mypy.nodes.EllipsisExpr) -> object: + return Ellipsis + + def visit_star_expr(self, o: mypy.nodes.StarExpr) -> object: + return MISSING + + def visit_name_expr(self, o: mypy.nodes.NameExpr) -> object: + if o.name == "True": + return True + elif o.name == "False": + return False + elif o.name == "None": + return None + return MISSING + + def visit_member_expr(self, o: mypy.nodes.MemberExpr) -> object: + return MISSING + + def visit_yield_from_expr(self, o: mypy.nodes.YieldFromExpr) -> object: + return MISSING + + def visit_yield_expr(self, o: mypy.nodes.YieldExpr) -> object: + return MISSING + + def visit_call_expr(self, o: mypy.nodes.CallExpr) -> object: + return MISSING + + def visit_op_expr(self, o: mypy.nodes.OpExpr) -> object: + return MISSING + + def visit_comparison_expr(self, o: mypy.nodes.ComparisonExpr) -> object: + return MISSING + + def visit_cast_expr(self, o: mypy.nodes.CastExpr) -> object: + return o.expr.accept(self) + + def visit_assert_type_expr(self, o: mypy.nodes.AssertTypeExpr) -> object: + return o.expr.accept(self) + + def visit_reveal_expr(self, o: mypy.nodes.RevealExpr) -> object: + return MISSING + + def visit_super_expr(self, o: mypy.nodes.SuperExpr) -> object: + return MISSING + + def visit_unary_expr(self, o: mypy.nodes.UnaryExpr) -> object: + operand = o.expr.accept(self) + if operand is MISSING: + return MISSING + if o.op == "-": + if isinstance(operand, (int, float, complex)): + return -operand + elif o.op == "+": + if isinstance(operand, (int, float, complex)): + return +operand + elif o.op == "~": + if isinstance(operand, int): + return ~operand + elif o.op == "not": + if isinstance(operand, (bool, int, float, str, bytes)): + return not operand + return MISSING + + def visit_assignment_expr(self, o: mypy.nodes.AssignmentExpr) -> object: + return o.value.accept(self) + + def visit_list_expr(self, o: mypy.nodes.ListExpr) -> object: + items = [item.accept(self) for item in o.items] + if all(item is not MISSING for item in items): + return items + return MISSING + + def visit_dict_expr(self, o: mypy.nodes.DictExpr) -> object: + items = [ + (MISSING if key is None else key.accept(self), value.accept(self)) + for key, value in o.items + ] + if all(key is not MISSING and value is not None for key, value in items): + return dict(items) + return MISSING + + def visit_tuple_expr(self, o: mypy.nodes.TupleExpr) -> object: + items = [item.accept(self) for item in o.items] + if all(item is not MISSING for item in items): + return tuple(items) + return MISSING + + def visit_set_expr(self, o: mypy.nodes.SetExpr) -> object: + items = [item.accept(self) for item in o.items] + if all(item is not MISSING for item in items): + return set(items) + return MISSING + + def visit_index_expr(self, o: mypy.nodes.IndexExpr) -> object: + return MISSING + + def visit_type_application(self, o: mypy.nodes.TypeApplication) -> object: + return MISSING + + def visit_lambda_expr(self, o: mypy.nodes.LambdaExpr) -> object: + return MISSING + + def visit_list_comprehension(self, o: mypy.nodes.ListComprehension) -> object: + return MISSING + + def visit_set_comprehension(self, o: mypy.nodes.SetComprehension) -> object: + return MISSING + + def visit_dictionary_comprehension(self, o: mypy.nodes.DictionaryComprehension) -> object: + return MISSING + + def visit_generator_expr(self, o: mypy.nodes.GeneratorExpr) -> object: + return MISSING + + def visit_slice_expr(self, o: mypy.nodes.SliceExpr) -> object: + return MISSING + + def visit_conditional_expr(self, o: mypy.nodes.ConditionalExpr) -> object: + return MISSING + + def visit_type_var_expr(self, o: mypy.nodes.TypeVarExpr) -> object: + return MISSING + + def visit_paramspec_expr(self, o: mypy.nodes.ParamSpecExpr) -> object: + return MISSING + + def visit_type_var_tuple_expr(self, o: mypy.nodes.TypeVarTupleExpr) -> object: + return MISSING + + def visit_type_alias_expr(self, o: mypy.nodes.TypeAliasExpr) -> object: + return MISSING + + def visit_namedtuple_expr(self, o: mypy.nodes.NamedTupleExpr) -> object: + return MISSING + + def visit_enum_call_expr(self, o: mypy.nodes.EnumCallExpr) -> object: + return MISSING + + def visit_typeddict_expr(self, o: mypy.nodes.TypedDictExpr) -> object: + return MISSING + + def visit_newtype_expr(self, o: mypy.nodes.NewTypeExpr) -> object: + return MISSING + + def visit__promote_expr(self, o: mypy.nodes.PromoteExpr) -> object: + return MISSING + + def visit_await_expr(self, o: mypy.nodes.AwaitExpr) -> object: + return MISSING + + def visit_temp_node(self, o: mypy.nodes.TempNode) -> object: + return MISSING + + def _verify_arg_default_value( stub_arg: nodes.Argument, runtime_arg: inspect.Parameter ) -> Iterator[str]: @@ -573,6 +749,21 @@ def _verify_arg_default_value( f"has a default value of type {runtime_type}, " f"which is incompatible with stub argument type {stub_type}" ) + if runtime_arg.default is not ... and stub_arg.initializer is not None: + stub_default = stub_arg.initializer.accept(_NodeEvaluator()) + if ( + stub_default is not MISSING + and stub_default is not ... + and ( + stub_default != runtime_arg.default + or type(stub_default) is not type(runtime_arg.default) + ) + ): + yield ( + f'runtime argument "{runtime_arg.name}" ' + f"has a default value of {runtime_arg.default!r}, " + f"which is different from stub argument default {stub_default!r}" + ) else: if stub_arg.kind.is_optional(): yield ( diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 5a6904bfaaf4..e863f4f57568 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -302,7 +302,7 @@ def test_arg_kind(self) -> Iterator[Case]: ) @collect_cases - def test_default_value(self) -> Iterator[Case]: + def test_default_presence(self) -> Iterator[Case]: yield Case( stub="def f1(text: str = ...) -> None: ...", runtime="def f1(text = 'asdf'): pass", @@ -336,6 +336,59 @@ def f6(text: _T = ...) -> None: ... error="f6", ) + @collect_cases + def test_default_value(self) -> Iterator[Case]: + yield Case( + stub="def f1(text: str = 'x') -> None: ...", + runtime="def f1(text = 'y'): pass", + error="f1", + ) + yield Case( + stub='def f2(text: bytes = b"x\'") -> None: ...', + runtime='def f2(text = b"x\'"): pass', + error=None, + ) + yield Case( + stub='def f3(text: bytes = b"y\'") -> None: ...', + runtime='def f3(text = b"x\'"): pass', + error="f3", + ) + yield Case( + stub="def f4(text: object = 1) -> None: ...", + runtime="def f4(text = 1.0): pass", + error="f4", + ) + yield Case( + stub="def f5(text: object = True) -> None: ...", + runtime="def f5(text = 1): pass", + error="f5", + ) + yield Case( + stub="def f6(text: object = True) -> None: ...", + runtime="def f6(text = True): pass", + error=None, + ) + yield Case( + stub="def f7(text: object = not True) -> None: ...", + runtime="def f7(text = False): pass", + error=None, + ) + yield Case( + stub="def f8(text: object = not True) -> None: ...", + runtime="def f8(text = True): pass", + error="f8", + ) + yield Case( + stub="def f9(text: object = {1: 2}) -> None: ...", + runtime="def f9(text = {1: 3}): pass", + error="f9", + ) + yield Case( + stub="def f10(text: object = [1, 2]) -> None: ...", + runtime="def f10(text = [1, 2]): pass", + error=None, + ) + @collect_cases def test_static_class_method(self) -> Iterator[Case]: yield Case(