Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support additinal attributes in callback protocols #14084

Merged
merged 2 commits into from
Nov 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5875,18 +5875,19 @@ def check_subtype(
if (
isinstance(supertype, Instance)
and supertype.type.is_protocol
and isinstance(subtype, (Instance, TupleType, TypedDictType))
and isinstance(subtype, (CallableType, Instance, TupleType, TypedDictType))
):
self.msg.report_protocol_problems(subtype, supertype, context, code=msg.code)
if isinstance(supertype, CallableType) and isinstance(subtype, Instance):
call = find_member("__call__", subtype, subtype, is_operator=True)
if call:
self.msg.note_call(subtype, call, context, code=msg.code)
if isinstance(subtype, (CallableType, Overloaded)) and isinstance(supertype, Instance):
if supertype.type.is_protocol and supertype.type.protocol_members == ["__call__"]:
if supertype.type.is_protocol and "__call__" in supertype.type.protocol_members:
call = find_member("__call__", supertype, subtype, is_operator=True)
assert call is not None
self.msg.note_call(supertype, call, context, code=msg.code)
if not is_subtype(subtype, call, options=self.options):
self.msg.note_call(supertype, call, context, code=msg.code)
self.check_possible_missing_await(subtype, supertype, context)
return False

Expand Down
9 changes: 5 additions & 4 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
original_actual = actual = self.actual
res: list[Constraint] = []
if isinstance(actual, (CallableType, Overloaded)) and template.type.is_protocol:
if template.type.protocol_members == ["__call__"]:
if "__call__" in template.type.protocol_members:
# Special case: a generic callback protocol
if not any(template == t for t in template.type.inferring):
template.type.inferring.append(template)
Expand All @@ -565,7 +565,6 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
subres = infer_constraints(call, actual, self.direction)
res.extend(subres)
template.type.inferring.pop()
return res
if isinstance(actual, CallableType) and actual.fallback is not None:
if actual.is_type_obj() and template.type.is_protocol:
ret_type = get_proper_type(actual.ret_type)
Expand Down Expand Up @@ -815,7 +814,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
# because some type may be considered a subtype of a protocol
# due to _promote, but still not implement the protocol.
not any(template == t for t in reversed(template.type.inferring))
and mypy.subtypes.is_protocol_implementation(instance, erased)
and mypy.subtypes.is_protocol_implementation(instance, erased, skip=["__call__"])
):
template.type.inferring.append(template)
res.extend(
Expand All @@ -831,7 +830,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
and
# We avoid infinite recursion for structural subtypes also here.
not any(instance == i for i in reversed(instance.type.inferring))
and mypy.subtypes.is_protocol_implementation(erased, instance)
and mypy.subtypes.is_protocol_implementation(erased, instance, skip=["__call__"])
):
instance.type.inferring.append(instance)
res.extend(
Expand Down Expand Up @@ -887,6 +886,8 @@ def infer_constraints_from_protocol_members(
inst = mypy.subtypes.find_member(member, instance, subtype, class_obj=class_obj)
temp = mypy.subtypes.find_member(member, template, subtype)
if inst is None or temp is None:
if member == "__call__":
continue
return [] # See #11020
# The above is safe since at this point we know that 'instance' is a subtype
# of (erased) 'template', therefore it defines all protocol members
Expand Down
27 changes: 16 additions & 11 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1866,6 +1866,7 @@ def report_protocol_problems(

class_obj = False
is_module = False
skip = []
if isinstance(subtype, TupleType):
if not isinstance(subtype.partial_fallback, Instance):
return
Expand All @@ -1880,20 +1881,22 @@ def report_protocol_problems(
class_obj = True
subtype = subtype.item
elif isinstance(subtype, CallableType):
if not subtype.is_type_obj():
return
ret_type = get_proper_type(subtype.ret_type)
if isinstance(ret_type, TupleType):
ret_type = ret_type.partial_fallback
if not isinstance(ret_type, Instance):
return
class_obj = True
subtype = ret_type
if subtype.is_type_obj():
ret_type = get_proper_type(subtype.ret_type)
if isinstance(ret_type, TupleType):
ret_type = ret_type.partial_fallback
if not isinstance(ret_type, Instance):
return
class_obj = True
subtype = ret_type
else:
subtype = subtype.fallback
skip = ["__call__"]
if subtype.extra_attrs and subtype.extra_attrs.mod_name:
is_module = True

# Report missing members
missing = get_missing_protocol_members(subtype, supertype)
missing = get_missing_protocol_members(subtype, supertype, skip=skip)
if (
missing
and len(missing) < len(supertype.type.protocol_members)
Expand Down Expand Up @@ -2605,13 +2608,15 @@ def variance_string(variance: int) -> str:
return "invariant"


def get_missing_protocol_members(left: Instance, right: Instance) -> list[str]:
def get_missing_protocol_members(left: Instance, right: Instance, skip: list[str]) -> list[str]:
"""Find all protocol members of 'right' that are not implemented
(i.e. completely missing) in 'left'.
"""
assert right.type.is_protocol
missing: list[str] = []
for member in right.type.protocol_members:
if member in skip:
continue
if not find_member(member, left, left):
missing.append(member)
return missing
Expand Down
25 changes: 19 additions & 6 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,13 +678,16 @@ def visit_callable_type(self, left: CallableType) -> bool:
elif isinstance(right, Overloaded):
return all(self._is_subtype(left, item) for item in right.items)
elif isinstance(right, Instance):
if right.type.is_protocol and right.type.protocol_members == ["__call__"]:
# OK, a callable can implement a protocol with a single `__call__` member.
if right.type.is_protocol and "__call__" in right.type.protocol_members:
# OK, a callable can implement a protocol with a `__call__` member.
# TODO: we should probably explicitly exclude self-types in this case.
call = find_member("__call__", right, left, is_operator=True)
assert call is not None
if self._is_subtype(left, call):
return True
if len(right.type.protocol_members) == 1:
return True
if is_protocol_implementation(left.fallback, right, skip=["__call__"]):
return True
if right.type.is_protocol and left.is_type_obj():
ret_type = get_proper_type(left.ret_type)
if isinstance(ret_type, TupleType):
Expand Down Expand Up @@ -792,12 +795,15 @@ def visit_literal_type(self, left: LiteralType) -> bool:
def visit_overloaded(self, left: Overloaded) -> bool:
right = self.right
if isinstance(right, Instance):
if right.type.is_protocol and right.type.protocol_members == ["__call__"]:
if right.type.is_protocol and "__call__" in right.type.protocol_members:
# same as for CallableType
call = find_member("__call__", right, left, is_operator=True)
assert call is not None
if self._is_subtype(left, call):
return True
if len(right.type.protocol_members) == 1:
return True
if is_protocol_implementation(left.fallback, right, skip=["__call__"]):
return True
return self._is_subtype(left.fallback, right)
elif isinstance(right, CallableType):
for item in left.items:
Expand Down Expand Up @@ -938,7 +944,11 @@ def pop_on_exit(stack: list[tuple[T, T]], left: T, right: T) -> Iterator[None]:


def is_protocol_implementation(
left: Instance, right: Instance, proper_subtype: bool = False, class_obj: bool = False
left: Instance,
right: Instance,
proper_subtype: bool = False,
class_obj: bool = False,
skip: list[str] | None = None,
) -> bool:
"""Check whether 'left' implements the protocol 'right'.

Expand All @@ -958,10 +968,13 @@ def f(self) -> A: ...
as well.
"""
assert right.type.is_protocol
if skip is None:
skip = []
# We need to record this check to generate protocol fine-grained dependencies.
TypeState.record_protocol_subtype_check(left.type, right.type)
# nominal subtyping currently ignores '__init__' and '__new__' signatures
members_not_to_check = {"__init__", "__new__"}
members_not_to_check.update(skip)
# Trivial check that circumvents the bug described in issue 9771:
if left.type.is_protocol:
members_right = set(right.type.protocol_members) - members_not_to_check
Expand Down
4 changes: 3 additions & 1 deletion mypy/test/testtypegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mypy import build
from mypy.errors import CompileError
from mypy.modulefinder import BuildSource
from mypy.nodes import NameExpr
from mypy.nodes import NameExpr, TempNode
from mypy.options import Options
from mypy.test.config import test_temp_dir
from mypy.test.data import DataDrivenTestCase, DataSuite
Expand Down Expand Up @@ -54,6 +54,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
# Filter nodes that should be included in the output.
keys = []
for node in nodes:
if isinstance(node, TempNode):
continue
if node.line != -1 and map[node]:
if ignore_node(node) or node in ignored:
continue
Expand Down
47 changes: 47 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -2642,6 +2642,53 @@ reveal_type([b, a]) # N: Revealed type is "builtins.list[def (x: def (__main__.
[builtins fixtures/list.pyi]
[out]

[case testCallbackProtocolFunctionAttributesSubtyping]
from typing import Protocol

class A(Protocol):
__name__: str
def __call__(self) -> str: ...

class B1(Protocol):
__name__: int
def __call__(self) -> str: ...

class B2(Protocol):
__name__: str
def __call__(self) -> int: ...

class B3(Protocol):
__name__: str
extra_stuff: int
def __call__(self) -> str: ...

def f() -> str: ...

reveal_type(f.__name__) # N: Revealed type is "builtins.str"
a: A = f # OK
b1: B1 = f # E: Incompatible types in assignment (expression has type "Callable[[], str]", variable has type "B1") \
# N: Following member(s) of "function" have conflicts: \
# N: __name__: expected "int", got "str"
b2: B2 = f # E: Incompatible types in assignment (expression has type "Callable[[], str]", variable has type "B2") \
# N: "B2.__call__" has type "Callable[[], int]"
b3: B3 = f # E: Incompatible types in assignment (expression has type "Callable[[], str]", variable has type "B3") \
# N: "function" is missing following "B3" protocol member: \
# N: extra_stuff

[case testCallbackProtocolFunctionAttributesInference]
from typing import Protocol, TypeVar, Generic, Tuple

T = TypeVar("T")
S = TypeVar("S", covariant=True)
class A(Protocol[T, S]):
__name__: T
def __call__(self) -> S: ...

def f() -> int: ...
def test(func: A[T, S]) -> Tuple[T, S]: ...
reveal_type(test(f)) # N: Revealed type is "Tuple[builtins.str, builtins.int]"
[builtins fixtures/tuple.pyi]

[case testProtocolsAlwaysABCs]
from typing import Protocol

Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/fine-grained-inspect.test
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class Meta(type):
==
{"C": ["meth", "x"]}
{"C": ["meth", "x"], "Meta": ["y"], "type": ["__init__"]}
{}
{"object": ["__init__"]}
{"function": ["__name__"]}
{"function": ["__name__"], "object": ["__init__"]}

[case testInspectDefBasic]
# inspect2: --show=definition foo.py:5:5
Expand Down
3 changes: 2 additions & 1 deletion test-data/unit/fixtures/tuple.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class tuple(Sequence[Tco], Generic[Tco]):
def __rmul__(self, n: int) -> Tuple[Tco, ...]: pass
def __add__(self, x: Tuple[Tco, ...]) -> Tuple[Tco, ...]: pass
def count(self, obj: object) -> int: pass
class function: pass
class function:
__name__: str
class ellipsis: pass
class classmethod: pass

Expand Down
3 changes: 2 additions & 1 deletion test-data/unit/lib-stub/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class float: pass
class str: pass
class bytes: pass

class function: pass
class function:
__name__: str
class ellipsis: pass

from typing import Generic, Sequence, TypeVar
Expand Down