Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lenient handling of trivial Callable suffixes #15913

Merged
merged 8 commits into from
Sep 14, 2023
4 changes: 3 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,7 +1208,9 @@ def check_func_def(
):
if defn.is_class or defn.name == "__new__":
ref_type = mypy.types.TypeType.make_normalized(ref_type)
erased = get_proper_type(erase_to_bound(arg_type))
# This level of erasure matches the one in checkmember.check_self_arg(),
# better keep these two checks consistent.
erased = get_proper_type(erase_typevars(erase_to_bound(arg_type)))
ilevkivskyi marked this conversation as resolved.
Show resolved Hide resolved
if not is_subtype(ref_type, erased, ignore_type_params=True):
if (
isinstance(erased, Instance)
Expand Down
2 changes: 2 additions & 0 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,8 @@ def f(self: S) -> T: ...
return functype
else:
selfarg = get_proper_type(item.arg_types[0])
# This level of erasure matches the one in checker.check_func_def(),
# better keep these two checks consistent.
if subtypes.is_subtype(dispatched_arg_type, erase_typevars(erase_to_bound(selfarg))):
new_items.append(item)
elif isinstance(selfarg, ParamSpecType):
Expand Down
3 changes: 3 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2132,6 +2132,9 @@ def report_protocol_problems(
not is_subtype(subtype, erase_type(supertype), options=self.options)
or not subtype.type.defn.type_vars
or not supertype.type.defn.type_vars
# Always show detailed message for ParamSpec
or subtype.type.has_param_spec_type
or supertype.type.has_param_spec_type
):
type_name = format_type(subtype, self.options, module_names=True)
self.note(f"Following member(s) of {type_name} have conflicts:", context, code=code)
Expand Down
19 changes: 16 additions & 3 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,18 @@ def are_trivial_parameters(param: Parameters | NormalizedCallableType) -> bool:
)


def is_trivial_suffix(param: Parameters | NormalizedCallableType) -> bool:
param_star = param.var_arg()
param_star2 = param.kw_arg()
return (
param.arg_kinds[-2:] == [ARG_STAR, ARG_STAR2]
and param_star is not None
and isinstance(get_proper_type(param_star.typ), AnyType)
and param_star2 is not None
and isinstance(get_proper_type(param_star2.typ), AnyType)
)


def are_parameters_compatible(
left: Parameters | NormalizedCallableType,
right: Parameters | NormalizedCallableType,
Expand All @@ -1540,6 +1552,7 @@ def are_parameters_compatible(
# Treat "def _(*a: Any, **kw: Any) -> X" similarly to "Callable[..., X]"
if are_trivial_parameters(right):
return True
trivial_suffix = is_trivial_suffix(right)

# Match up corresponding arguments and check them for compatibility. In
# every pair (argL, argR) of corresponding arguments from L and R, argL must
Expand Down Expand Up @@ -1570,7 +1583,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
if right_arg is None:
return False
if left_arg is None:
return not allow_partial_overlap
return not allow_partial_overlap and not trivial_suffix
return not is_compat(right_arg.typ, left_arg.typ)

if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2):
Expand All @@ -1594,7 +1607,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# arguments. Get all further positional args of left, and make sure
# they're more general than the corresponding member in right.
# TODO: are we handling UnpackType correctly here?
if right_star is not None:
if right_star is not None and not trivial_suffix:
# Synthesize an anonymous formal argument for the right
right_by_position = right.try_synthesizing_arg_from_vararg(None)
assert right_by_position is not None
Expand All @@ -1621,7 +1634,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# Phase 1d: Check kw args. Right has an infinite series of optional named
# arguments. Get all further named args of left, and make sure
# they're more general than the corresponding member in right.
if right_star2 is not None:
if right_star2 is not None and not trivial_suffix:
right_names = {name for name in right.arg_names if name is not None}
left_only_names = set()
for name, kind in zip(left.arg_names, left.arg_kinds):
Expand Down
4 changes: 4 additions & 0 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def supported_self_type(typ: ProperType) -> bool:
"""
if isinstance(typ, TypeType):
return supported_self_type(typ.item)
if isinstance(typ, CallableType):
# Special case: allow class callable instead of Type[...] as cls annotation,
# as well as callable self for callback protocols.
return True
return isinstance(typ, TypeVarType) or (
isinstance(typ, Instance) and typ != fill_typevars(typ.type)
)
Expand Down
31 changes: 31 additions & 0 deletions test-data/unit/check-callable.test
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,34 @@ a: A
a() # E: Missing positional argument "other" in call to "__call__" of "A"
a(a)
a(lambda: None)

[case testCallableSubtypingTrivialSuffix]
from typing import Any, Protocol

class Call(Protocol):
def __call__(self, x: int, *args: Any, **kwargs: Any) -> None: ...

def f1() -> None: ...
a1: Call = f1 # E: Incompatible types in assignment (expression has type "Callable[[], None]", variable has type "Call") \
# N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]"
def f2(x: str) -> None: ...
a2: Call = f2 # E: Incompatible types in assignment (expression has type "Callable[[str], None]", variable has type "Call") \
# N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]"
def f3(y: int) -> None: ...
a3: Call = f3 # E: Incompatible types in assignment (expression has type "Callable[[int], None]", variable has type "Call") \
# N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]"
def f4(x: int) -> None: ...
a4: Call = f4

def f5(x: int, y: int) -> None: ...
a5: Call = f5

def f6(x: int, y: int = 0) -> None: ...
a6: Call = f6

def f7(x: int, *, y: int) -> None: ...
a7: Call = f7

def f8(x: int, *args: int, **kwargs: str) -> None: ...
a8: Call = f8
[builtins fixtures/tuple.pyi]
12 changes: 6 additions & 6 deletions test-data/unit/check-modules.test
Original file line number Diff line number Diff line change
Expand Up @@ -3193,29 +3193,29 @@ from test1 import aaaa # E: Module "test1" has no attribute "aaaa"
import b
[file a.py]
class Foo:
def frobnicate(self, x, *args, **kwargs): pass
def frobnicate(self, x: str, *args, **kwargs): pass
[file b.py]
from a import Foo
class Bar(Foo):
def frobnicate(self) -> None: pass
[file b.py.2]
from a import Foo
class Bar(Foo):
def frobnicate(self, *args) -> None: pass
def frobnicate(self, *args: int) -> None: pass
[file b.py.3]
from a import Foo
class Bar(Foo):
def frobnicate(self, *args) -> None: pass # type: ignore[override] # I know
def frobnicate(self, *args: int) -> None: pass # type: ignore[override] # I know
[builtins fixtures/dict.pyi]
[out1]
tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo"
tmp/b.py:3: note: Superclass:
tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any
tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any
tmp/b.py:3: note: Subclass:
tmp/b.py:3: note: def frobnicate(self) -> None
[out2]
tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo"
tmp/b.py:3: note: Superclass:
tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any
tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any
tmp/b.py:3: note: Subclass:
tmp/b.py:3: note: def frobnicate(self, *args: Any) -> None
tmp/b.py:3: note: def frobnicate(self, *args: int) -> None
139 changes: 138 additions & 1 deletion test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1729,7 +1729,12 @@ class A(Protocol[P]):
...

def bar(b: A[P]) -> A[Concatenate[int, P]]:
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]")
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]") \
# N: Following member(s) of "A[P]" have conflicts: \
# N: Expected: \
# N: def foo(self, int, /, *args: P.args, **kwargs: P.kwargs) -> Any \
# N: Got: \
# N: def foo(self, *args: P.args, **kwargs: P.kwargs) -> Any
[builtins fixtures/paramspec.pyi]

[case testParamSpecPrefixSubtypingValidNonStrict]
Expand Down Expand Up @@ -1825,6 +1830,138 @@ c: C[int, [int, str], str] # E: Nested parameter specifications are not allowed
reveal_type(c) # N: Revealed type is "__main__.C[Any]"
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateSelfType]
from typing import Callable
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
class A:
def __init__(self, a_param_1: str) -> None: ...

@classmethod
def add_params(cls: Callable[P, A]) -> Callable[Concatenate[float, P], A]:
def new_constructor(i: float, *args: P.args, **kwargs: P.kwargs) -> A:
return cls(*args, **kwargs)
return new_constructor

@classmethod
def remove_params(cls: Callable[Concatenate[str, P], A]) -> Callable[P, A]:
def new_constructor(*args: P.args, **kwargs: P.kwargs) -> A:
return cls("my_special_str", *args, **kwargs)
return new_constructor

reveal_type(A.add_params()) # N: Revealed type is "def (builtins.float, a_param_1: builtins.str) -> __main__.A"
reveal_type(A.remove_params()) # N: Revealed type is "def () -> __main__.A"
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateCallbackProtocol]
from typing import Protocol, TypeVar
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
R = TypeVar("R", covariant=True)

class Path: ...

class Function(Protocol[P, R]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...

def file_cache(fn: Function[Concatenate[Path, P], R]) -> Function[P, R]:
def wrapper(*args: P.args, **kw: P.kwargs) -> R:
return fn(Path(), *args, **kw)
return wrapper

@file_cache
def get_thing(path: Path, *, some_arg: int) -> int: ...
reveal_type(get_thing) # N: Revealed type is "__main__.Function[[*, some_arg: builtins.int], builtins.int]"
get_thing(some_arg=1) # OK
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateKeywordOnly]
from typing import Callable, TypeVar
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
R = TypeVar("R")

class Path: ...

def file_cache(fn: Callable[Concatenate[Path, P], R]) -> Callable[P, R]:
def wrapper(*args: P.args, **kw: P.kwargs) -> R:
return fn(Path(), *args, **kw)
return wrapper

@file_cache
def get_thing(path: Path, *, some_arg: int) -> int: ...
reveal_type(get_thing) # N: Revealed type is "def (*, some_arg: builtins.int) -> builtins.int"
get_thing(some_arg=1) # OK
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateCallbackApply]
from typing import Callable, Protocol
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")

class FuncType(Protocol[P]):
def __call__(self, x: int, s: str, *args: P.args, **kw_args: P.kwargs) -> str: ...

def forwarder1(fp: FuncType[P], *args: P.args, **kw_args: P.kwargs) -> str:
return fp(0, '', *args, **kw_args)

def forwarder2(fp: Callable[Concatenate[int, str, P], str], *args: P.args, **kw_args: P.kwargs) -> str:
return fp(0, '', *args, **kw_args)

def my_f(x: int, s: str, d: bool) -> str: ...
forwarder1(my_f, True) # OK
forwarder2(my_f, True) # OK
forwarder1(my_f, 1.0) # E: Argument 2 to "forwarder1" has incompatible type "float"; expected "bool"
forwarder2(my_f, 1.0) # E: Argument 2 to "forwarder2" has incompatible type "float"; expected "bool"
[builtins fixtures/paramspec.pyi]

[case testParamSpecCallbackProtocolSelf]
from typing import Callable, Protocol, TypeVar
from typing_extensions import ParamSpec, Concatenate

Params = ParamSpec("Params")
Result = TypeVar("Result", covariant=True)

class FancyMethod(Protocol):
def __call__(self, arg1: int, arg2: str) -> bool: ...
def return_me(self: Callable[Params, Result]) -> Callable[Params, Result]: ...
def return_part(self: Callable[Concatenate[int, Params], Result]) -> Callable[Params, Result]: ...

m: FancyMethod
reveal_type(m.return_me()) # N: Revealed type is "def (arg1: builtins.int, arg2: builtins.str) -> builtins.bool"
reveal_type(m.return_part()) # N: Revealed type is "def (arg2: builtins.str) -> builtins.bool"
[builtins fixtures/paramspec.pyi]

[case testParamSpecInferenceCallableAgainstAny]
from typing import Callable, TypeVar, Any
from typing_extensions import ParamSpec, Concatenate

_P = ParamSpec("_P")
_R = TypeVar("_R")

class A: ...
a = A()

def a_func(
func: Callable[Concatenate[A, _P], _R],
) -> Callable[Concatenate[Any, _P], _R]:
def wrapper(__a: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R:
return func(a, *args, **kwargs)
return wrapper

def test(a, *args): ...
x: Any
y: object

a_func(test)
x = a_func(test)
y = a_func(test)
[builtins fixtures/paramspec.pyi]

[case testParamSpecInferenceWithCallbackProtocol]
from typing import Protocol, Callable, ParamSpec

Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/paramspec.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class object:

class function: ...
class ellipsis: ...
class classmethod: ...

class type:
def __init__(self, *a: object) -> None: ...
Expand Down