Skip to content

Commit

Permalink
Support additinal attributes in callback protocols (#14084)
Browse files Browse the repository at this point in the history
Fixes #10976
Fixes #10403

This is quite straightforward. Note that we will not allow _arbitrary_
attributes on functions, only those that are defined in
`types.FunctionType` (or more precisely `builtins.function` that is
identical). We have a separate issue for arbitrary attributes
#2087
  • Loading branch information
ilevkivskyi authored Nov 13, 2022
1 parent 47a435f commit 57ce73d
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 29 deletions.
7 changes: 4 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5883,18 +5883,19 @@ def check_subtype(
if (
isinstance(supertype, Instance)
and supertype.type.is_protocol
and isinstance(subtype, (Instance, TupleType, TypedDictType))
and isinstance(subtype, (CallableType, Instance, TupleType, TypedDictType))
):
self.msg.report_protocol_problems(subtype, supertype, context, code=msg.code)
if isinstance(supertype, CallableType) and isinstance(subtype, Instance):
call = find_member("__call__", subtype, subtype, is_operator=True)
if call:
self.msg.note_call(subtype, call, context, code=msg.code)
if isinstance(subtype, (CallableType, Overloaded)) and isinstance(supertype, Instance):
if supertype.type.is_protocol and supertype.type.protocol_members == ["__call__"]:
if supertype.type.is_protocol and "__call__" in supertype.type.protocol_members:
call = find_member("__call__", supertype, subtype, is_operator=True)
assert call is not None
self.msg.note_call(supertype, call, context, code=msg.code)
if not is_subtype(subtype, call, options=self.options):
self.msg.note_call(supertype, call, context, code=msg.code)
self.check_possible_missing_await(subtype, supertype, context)
return False

Expand Down
9 changes: 5 additions & 4 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
original_actual = actual = self.actual
res: list[Constraint] = []
if isinstance(actual, (CallableType, Overloaded)) and template.type.is_protocol:
if template.type.protocol_members == ["__call__"]:
if "__call__" in template.type.protocol_members:
# Special case: a generic callback protocol
if not any(template == t for t in template.type.inferring):
template.type.inferring.append(template)
Expand All @@ -565,7 +565,6 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
subres = infer_constraints(call, actual, self.direction)
res.extend(subres)
template.type.inferring.pop()
return res
if isinstance(actual, CallableType) and actual.fallback is not None:
if actual.is_type_obj() and template.type.is_protocol:
ret_type = get_proper_type(actual.ret_type)
Expand Down Expand Up @@ -815,7 +814,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
# because some type may be considered a subtype of a protocol
# due to _promote, but still not implement the protocol.
not any(template == t for t in reversed(template.type.inferring))
and mypy.subtypes.is_protocol_implementation(instance, erased)
and mypy.subtypes.is_protocol_implementation(instance, erased, skip=["__call__"])
):
template.type.inferring.append(template)
res.extend(
Expand All @@ -831,7 +830,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
and
# We avoid infinite recursion for structural subtypes also here.
not any(instance == i for i in reversed(instance.type.inferring))
and mypy.subtypes.is_protocol_implementation(erased, instance)
and mypy.subtypes.is_protocol_implementation(erased, instance, skip=["__call__"])
):
instance.type.inferring.append(instance)
res.extend(
Expand Down Expand Up @@ -887,6 +886,8 @@ def infer_constraints_from_protocol_members(
inst = mypy.subtypes.find_member(member, instance, subtype, class_obj=class_obj)
temp = mypy.subtypes.find_member(member, template, subtype)
if inst is None or temp is None:
if member == "__call__":
continue
return [] # See #11020
# The above is safe since at this point we know that 'instance' is a subtype
# of (erased) 'template', therefore it defines all protocol members
Expand Down
27 changes: 16 additions & 11 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1866,6 +1866,7 @@ def report_protocol_problems(

class_obj = False
is_module = False
skip = []
if isinstance(subtype, TupleType):
if not isinstance(subtype.partial_fallback, Instance):
return
Expand All @@ -1880,20 +1881,22 @@ def report_protocol_problems(
class_obj = True
subtype = subtype.item
elif isinstance(subtype, CallableType):
if not subtype.is_type_obj():
return
ret_type = get_proper_type(subtype.ret_type)
if isinstance(ret_type, TupleType):
ret_type = ret_type.partial_fallback
if not isinstance(ret_type, Instance):
return
class_obj = True
subtype = ret_type
if subtype.is_type_obj():
ret_type = get_proper_type(subtype.ret_type)
if isinstance(ret_type, TupleType):
ret_type = ret_type.partial_fallback
if not isinstance(ret_type, Instance):
return
class_obj = True
subtype = ret_type
else:
subtype = subtype.fallback
skip = ["__call__"]
if subtype.extra_attrs and subtype.extra_attrs.mod_name:
is_module = True

# Report missing members
missing = get_missing_protocol_members(subtype, supertype)
missing = get_missing_protocol_members(subtype, supertype, skip=skip)
if (
missing
and len(missing) < len(supertype.type.protocol_members)
Expand Down Expand Up @@ -2605,13 +2608,15 @@ def variance_string(variance: int) -> str:
return "invariant"


def get_missing_protocol_members(left: Instance, right: Instance) -> list[str]:
def get_missing_protocol_members(left: Instance, right: Instance, skip: list[str]) -> list[str]:
"""Find all protocol members of 'right' that are not implemented
(i.e. completely missing) in 'left'.
"""
assert right.type.is_protocol
missing: list[str] = []
for member in right.type.protocol_members:
if member in skip:
continue
if not find_member(member, left, left):
missing.append(member)
return missing
Expand Down
25 changes: 19 additions & 6 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,13 +678,16 @@ def visit_callable_type(self, left: CallableType) -> bool:
elif isinstance(right, Overloaded):
return all(self._is_subtype(left, item) for item in right.items)
elif isinstance(right, Instance):
if right.type.is_protocol and right.type.protocol_members == ["__call__"]:
# OK, a callable can implement a protocol with a single `__call__` member.
if right.type.is_protocol and "__call__" in right.type.protocol_members:
# OK, a callable can implement a protocol with a `__call__` member.
# TODO: we should probably explicitly exclude self-types in this case.
call = find_member("__call__", right, left, is_operator=True)
assert call is not None
if self._is_subtype(left, call):
return True
if len(right.type.protocol_members) == 1:
return True
if is_protocol_implementation(left.fallback, right, skip=["__call__"]):
return True
if right.type.is_protocol and left.is_type_obj():
ret_type = get_proper_type(left.ret_type)
if isinstance(ret_type, TupleType):
Expand Down Expand Up @@ -792,12 +795,15 @@ def visit_literal_type(self, left: LiteralType) -> bool:
def visit_overloaded(self, left: Overloaded) -> bool:
right = self.right
if isinstance(right, Instance):
if right.type.is_protocol and right.type.protocol_members == ["__call__"]:
if right.type.is_protocol and "__call__" in right.type.protocol_members:
# same as for CallableType
call = find_member("__call__", right, left, is_operator=True)
assert call is not None
if self._is_subtype(left, call):
return True
if len(right.type.protocol_members) == 1:
return True
if is_protocol_implementation(left.fallback, right, skip=["__call__"]):
return True
return self._is_subtype(left.fallback, right)
elif isinstance(right, CallableType):
for item in left.items:
Expand Down Expand Up @@ -938,7 +944,11 @@ def pop_on_exit(stack: list[tuple[T, T]], left: T, right: T) -> Iterator[None]:


def is_protocol_implementation(
left: Instance, right: Instance, proper_subtype: bool = False, class_obj: bool = False
left: Instance,
right: Instance,
proper_subtype: bool = False,
class_obj: bool = False,
skip: list[str] | None = None,
) -> bool:
"""Check whether 'left' implements the protocol 'right'.
Expand All @@ -958,10 +968,13 @@ def f(self) -> A: ...
as well.
"""
assert right.type.is_protocol
if skip is None:
skip = []
# We need to record this check to generate protocol fine-grained dependencies.
TypeState.record_protocol_subtype_check(left.type, right.type)
# nominal subtyping currently ignores '__init__' and '__new__' signatures
members_not_to_check = {"__init__", "__new__"}
members_not_to_check.update(skip)
# Trivial check that circumvents the bug described in issue 9771:
if left.type.is_protocol:
members_right = set(right.type.protocol_members) - members_not_to_check
Expand Down
4 changes: 3 additions & 1 deletion mypy/test/testtypegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mypy import build
from mypy.errors import CompileError
from mypy.modulefinder import BuildSource
from mypy.nodes import NameExpr
from mypy.nodes import NameExpr, TempNode
from mypy.options import Options
from mypy.test.config import test_temp_dir
from mypy.test.data import DataDrivenTestCase, DataSuite
Expand Down Expand Up @@ -54,6 +54,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
# Filter nodes that should be included in the output.
keys = []
for node in nodes:
if isinstance(node, TempNode):
continue
if node.line != -1 and map[node]:
if ignore_node(node) or node in ignored:
continue
Expand Down
47 changes: 47 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -2642,6 +2642,53 @@ reveal_type([b, a]) # N: Revealed type is "builtins.list[def (x: def (__main__.
[builtins fixtures/list.pyi]
[out]

[case testCallbackProtocolFunctionAttributesSubtyping]
from typing import Protocol

class A(Protocol):
__name__: str
def __call__(self) -> str: ...

class B1(Protocol):
__name__: int
def __call__(self) -> str: ...

class B2(Protocol):
__name__: str
def __call__(self) -> int: ...

class B3(Protocol):
__name__: str
extra_stuff: int
def __call__(self) -> str: ...

def f() -> str: ...

reveal_type(f.__name__) # N: Revealed type is "builtins.str"
a: A = f # OK
b1: B1 = f # E: Incompatible types in assignment (expression has type "Callable[[], str]", variable has type "B1") \
# N: Following member(s) of "function" have conflicts: \
# N: __name__: expected "int", got "str"
b2: B2 = f # E: Incompatible types in assignment (expression has type "Callable[[], str]", variable has type "B2") \
# N: "B2.__call__" has type "Callable[[], int]"
b3: B3 = f # E: Incompatible types in assignment (expression has type "Callable[[], str]", variable has type "B3") \
# N: "function" is missing following "B3" protocol member: \
# N: extra_stuff

[case testCallbackProtocolFunctionAttributesInference]
from typing import Protocol, TypeVar, Generic, Tuple

T = TypeVar("T")
S = TypeVar("S", covariant=True)
class A(Protocol[T, S]):
__name__: T
def __call__(self) -> S: ...

def f() -> int: ...
def test(func: A[T, S]) -> Tuple[T, S]: ...
reveal_type(test(f)) # N: Revealed type is "Tuple[builtins.str, builtins.int]"
[builtins fixtures/tuple.pyi]

[case testProtocolsAlwaysABCs]
from typing import Protocol

Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/fine-grained-inspect.test
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class Meta(type):
==
{"C": ["meth", "x"]}
{"C": ["meth", "x"], "Meta": ["y"], "type": ["__init__"]}
{}
{"object": ["__init__"]}
{"function": ["__name__"]}
{"function": ["__name__"], "object": ["__init__"]}

[case testInspectDefBasic]
# inspect2: --show=definition foo.py:5:5
Expand Down
3 changes: 2 additions & 1 deletion test-data/unit/fixtures/tuple.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class tuple(Sequence[Tco], Generic[Tco]):
def __rmul__(self, n: int) -> Tuple[Tco, ...]: pass
def __add__(self, x: Tuple[Tco, ...]) -> Tuple[Tco, ...]: pass
def count(self, obj: object) -> int: pass
class function: pass
class function:
__name__: str
class ellipsis: pass
class classmethod: pass

Expand Down
3 changes: 2 additions & 1 deletion test-data/unit/lib-stub/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class float: pass
class str: pass
class bytes: pass

class function: pass
class function:
__name__: str
class ellipsis: pass

from typing import Generic, Sequence, TypeVar
Expand Down

0 comments on commit 57ce73d

Please sign in to comment.