Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ExitStack, AbstractContextManager and AsyncAbstractContextManager generic in return type of __exit__ #11048

Merged
merged 18 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions stdlib/contextlib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,33 @@ if sys.version_info >= (3, 11):
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_io = TypeVar("_T_io", bound=IO[str] | None)
_ExitT_co = TypeVar("_ExitT_co", covariant=True, bound=bool | None, default=bool | None)
_F = TypeVar("_F", bound=Callable[..., Any])
_P = ParamSpec("_P")

_ExitFunc: TypeAlias = Callable[[type[BaseException] | None, BaseException | None, TracebackType | None], bool | None]
_CM_EF = TypeVar("_CM_EF", bound=AbstractContextManager[Any] | _ExitFunc)
_CM_EF = TypeVar("_CM_EF", bound=AbstractContextManager[Any, Any] | _ExitFunc)

@runtime_checkable
class AbstractContextManager(Protocol[_T_co]):
class AbstractContextManager(Protocol[_T_co, _ExitT_co]):
def __enter__(self) -> _T_co: ...
@abstractmethod
def __exit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /
) -> bool | None: ...
) -> _ExitT_co: ...

@runtime_checkable
class AbstractAsyncContextManager(Protocol[_T_co]):
class AbstractAsyncContextManager(Protocol[_T_co, _ExitT_co]):
async def __aenter__(self) -> _T_co: ...
@abstractmethod
async def __aexit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /
) -> bool | None: ...
) -> _ExitT_co: ...

class ContextDecorator:
def __call__(self, func: _F) -> _F: ...

class _GeneratorContextManager(AbstractContextManager[_T_co], ContextDecorator):
class _GeneratorContextManager(AbstractContextManager[_T_co, bool | None], ContextDecorator):
# __init__ and all instance attributes are actually inherited from _GeneratorContextManagerBase
# _GeneratorContextManagerBase is more trouble than it's worth to include in the stub; see #6676
def __init__(self, func: Callable[..., Iterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
Expand All @@ -81,7 +82,7 @@ if sys.version_info >= (3, 10):
class AsyncContextDecorator:
def __call__(self, func: _AF) -> _AF: ...

class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co], AsyncContextDecorator):
class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co, bool | None], AsyncContextDecorator):
# __init__ and these attributes are actually defined in the base class _GeneratorContextManagerBase,
# which is more trouble than it's worth to include in the stub (see #6676)
def __init__(self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
Expand All @@ -94,7 +95,7 @@ if sys.version_info >= (3, 10):
) -> bool | None: ...

else:
class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co]):
class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co, bool | None]):
def __init__(self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
gen: AsyncGenerator[_T_co, Any]
func: Callable[..., AsyncGenerator[_T_co, Any]]
Expand All @@ -111,7 +112,7 @@ class _SupportsClose(Protocol):

_SupportsCloseT = TypeVar("_SupportsCloseT", bound=_SupportsClose)

class closing(AbstractContextManager[_SupportsCloseT]):
class closing(AbstractContextManager[_SupportsCloseT, None]):
def __init__(self, thing: _SupportsCloseT) -> None: ...
def __exit__(self, *exc_info: Unused) -> None: ...

Expand All @@ -121,17 +122,17 @@ if sys.version_info >= (3, 10):

_SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose)

class aclosing(AbstractAsyncContextManager[_SupportsAcloseT]):
class aclosing(AbstractAsyncContextManager[_SupportsAcloseT, None]):
def __init__(self, thing: _SupportsAcloseT) -> None: ...
async def __aexit__(self, *exc_info: Unused) -> None: ...

class suppress(AbstractContextManager[None]):
class suppress(AbstractContextManager[None, bool]):
def __init__(self, *exceptions: type[BaseException]) -> None: ...
def __exit__(
self, exctype: type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None
) -> bool: ...

class _RedirectStream(AbstractContextManager[_T_io]):
class _RedirectStream(AbstractContextManager[_T_io, None]):
def __init__(self, new_target: _T_io) -> None: ...
def __exit__(
self, exctype: type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None
Expand All @@ -142,27 +143,27 @@ class redirect_stderr(_RedirectStream[_T_io]): ...

# In reality this is a subclass of `AbstractContextManager`;
# see #7961 for why we don't do that in the stub
class ExitStack(metaclass=abc.ABCMeta):
def enter_context(self, cm: AbstractContextManager[_T]) -> _T: ...
class ExitStack(Generic[_ExitT_co], metaclass=abc.ABCMeta):
def enter_context(self, cm: AbstractContextManager[_T, _ExitT_co]) -> _T: ...
def push(self, exit: _CM_EF) -> _CM_EF: ...
def callback(self, callback: Callable[_P, _T], /, *args: _P.args, **kwds: _P.kwargs) -> Callable[_P, _T]: ...
def pop_all(self) -> Self: ...
def close(self) -> None: ...
def __enter__(self) -> Self: ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /
) -> bool: ...
) -> _ExitT_co: ...

_ExitCoroFunc: TypeAlias = Callable[
[type[BaseException] | None, BaseException | None, TracebackType | None], Awaitable[bool | None]
]
_ACM_EF = TypeVar("_ACM_EF", bound=AbstractAsyncContextManager[Any] | _ExitCoroFunc)
_ACM_EF = TypeVar("_ACM_EF", bound=AbstractAsyncContextManager[Any, Any] | _ExitCoroFunc)

# In reality this is a subclass of `AbstractAsyncContextManager`;
# see #7961 for why we don't do that in the stub
class AsyncExitStack(metaclass=abc.ABCMeta):
def enter_context(self, cm: AbstractContextManager[_T]) -> _T: ...
async def enter_async_context(self, cm: AbstractAsyncContextManager[_T]) -> _T: ...
class AsyncExitStack(Generic[_ExitT_co], metaclass=abc.ABCMeta):
def enter_context(self, cm: AbstractContextManager[_T, _ExitT_co]) -> _T: ...
async def enter_async_context(self, cm: AbstractAsyncContextManager[_T, _ExitT_co]) -> _T: ...
def push(self, exit: _CM_EF) -> _CM_EF: ...
def push_async_exit(self, exit: _ACM_EF) -> _ACM_EF: ...
def callback(self, callback: Callable[_P, _T], /, *args: _P.args, **kwds: _P.kwargs) -> Callable[_P, _T]: ...
Expand All @@ -177,7 +178,7 @@ class AsyncExitStack(metaclass=abc.ABCMeta):
) -> bool: ...

if sys.version_info >= (3, 10):
class nullcontext(AbstractContextManager[_T], AbstractAsyncContextManager[_T]):
class nullcontext(AbstractContextManager[_T, None], AbstractAsyncContextManager[_T, None]):
enter_result: _T
@overload
def __init__(self: nullcontext[None], enter_result: None = None) -> None: ...
Expand All @@ -189,7 +190,7 @@ if sys.version_info >= (3, 10):
async def __aexit__(self, *exctype: Unused) -> None: ...

else:
class nullcontext(AbstractContextManager[_T]):
class nullcontext(AbstractContextManager[_T, None]):
enter_result: _T
@overload
def __init__(self: nullcontext[None], enter_result: None = None) -> None: ...
Expand All @@ -201,7 +202,7 @@ else:
if sys.version_info >= (3, 11):
_T_fd_or_any_path = TypeVar("_T_fd_or_any_path", bound=FileDescriptorOrPath)

class chdir(AbstractContextManager[None], Generic[_T_fd_or_any_path]):
class chdir(AbstractContextManager[None, None], Generic[_T_fd_or_any_path]):
path: _T_fd_or_any_path
def __init__(self, path: _T_fd_or_any_path) -> None: ...
def __enter__(self) -> None: ...
Expand Down
4 changes: 2 additions & 2 deletions stdlib/multiprocessing/synchronize.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Barrier(threading.Barrier):
self, parties: int, action: Callable[[], object] | None = None, timeout: float | None = None, *ctx: BaseContext
) -> None: ...

class Condition(AbstractContextManager[bool]):
class Condition(AbstractContextManager[bool, None]):
def __init__(self, lock: _LockLike | None = None, *, ctx: BaseContext) -> None: ...
def notify(self, n: int = 1) -> None: ...
def notify_all(self) -> None: ...
Expand All @@ -34,7 +34,7 @@ class Event:
def wait(self, timeout: float | None = None) -> bool: ...

# Not part of public API
class SemLock(AbstractContextManager[bool]):
class SemLock(AbstractContextManager[bool, None]):
def acquire(self, block: bool = ..., timeout: float | None = ...) -> bool: ...
def release(self) -> None: ...
def __exit__(
Expand Down
2 changes: 1 addition & 1 deletion stdlib/os/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def replace(
) -> None: ...
def rmdir(path: StrOrBytesPath, *, dir_fd: int | None = None) -> None: ...

class _ScandirIterator(Iterator[DirEntry[AnyStr]], AbstractContextManager[_ScandirIterator[AnyStr]]):
class _ScandirIterator(Iterator[DirEntry[AnyStr]], AbstractContextManager[_ScandirIterator[AnyStr], None]):
def __next__(self) -> DirEntry[AnyStr]: ...
def __exit__(self, *args: Unused) -> None: ...
def close(self) -> None: ...
Expand Down
15 changes: 12 additions & 3 deletions stdlib/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ if sys.version_info >= (3, 11):
if sys.version_info >= (3, 12):
__all__ += ["TypeAliasType", "override"]

ContextManager = AbstractContextManager
AsyncContextManager = AbstractAsyncContextManager

Any = object()

def final(f: _T) -> _T: ...
Expand Down Expand Up @@ -431,6 +428,18 @@ class Generator(Iterator[_YieldT_co], Generic[_YieldT_co, _SendT_contra, _Return
@property
def gi_yieldfrom(self) -> Generator[Any, Any, Any] | None: ...

# NOTE: Technically we would like this to be able to accept a second parameter as well, just
# like it's counterpart in contextlib, however `typing._SpecialGenericAlias` enforces the
# correct number of arguments at runtime, so we would be hiding runtime errors.
@runtime_checkable
class ContextManager(AbstractContextManager[_T_co, bool | None], Protocol[_T_co]): ...

# NOTE: Technically we would like this to be able to accept a second parameter as well, just
# like it's counterpart in contextlib, however `typing._SpecialGenericAlias` enforces the
# correct number of arguments at runtime, so we would be hiding runtime errors.
@runtime_checkable
class AsyncContextManager(AbstractAsyncContextManager[_T_co, bool | None], Protocol[_T_co]): ...

@runtime_checkable
class Awaitable(Protocol[_T_co]):
@abstractmethod
Expand Down
4 changes: 4 additions & 0 deletions tests/stubtest_allowlists/py3_common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,10 @@ typing(_extensions)?\.TextIO\.errors
typing(_extensions)?\.TextIO\.line_buffering
typing(_extensions)?\.TextIO\.newlines

# These are typing._SpecialGenericAlias at runtime, which is not a real type, but it
# behaves like one in most cases
typing(_extensions)?\.(Async)?ContextManager

types.MethodType.__closure__ # read-only but not actually a property; stubtest thinks it doesn't exist.
types.MethodType.__defaults__ # read-only but not actually a property; stubtest thinks it doesn't exist.
types.ModuleType.__dict__ # read-only but not actually a property; stubtest thinks it's a mutable attribute.
Expand Down