Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ExitStack, AbstractContextManager and AsyncAbstractContextManager generic in return type of __exit__ #11048

Merged
merged 18 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions stdlib/contextlib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,33 @@ if sys.version_info >= (3, 11):
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_io = TypeVar("_T_io", bound=IO[str] | None)
_ExitT_co = TypeVar("_ExitT_co", covariant=True, bound=bool | None, default=bool | None)
_F = TypeVar("_F", bound=Callable[..., Any])
_P = ParamSpec("_P")

_ExitFunc: TypeAlias = Callable[[type[BaseException] | None, BaseException | None, TracebackType | None], bool | None]
_CM_EF = TypeVar("_CM_EF", bound=AbstractContextManager[Any] | _ExitFunc)
_CM_EF = TypeVar("_CM_EF", bound=AbstractContextManager[Any, Any] | _ExitFunc)

@runtime_checkable
class AbstractContextManager(Protocol[_T_co]):
class AbstractContextManager(Protocol[_T_co, _ExitT_co]):
def __enter__(self) -> _T_co: ...
@abstractmethod
def __exit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /
) -> bool | None: ...
) -> _ExitT_co: ...

@runtime_checkable
class AbstractAsyncContextManager(Protocol[_T_co]):
class AbstractAsyncContextManager(Protocol[_T_co, _ExitT_co]):
async def __aenter__(self) -> _T_co: ...
@abstractmethod
async def __aexit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /
) -> bool | None: ...
) -> _ExitT_co: ...

class ContextDecorator:
def __call__(self, func: _F) -> _F: ...

class _GeneratorContextManager(AbstractContextManager[_T_co], ContextDecorator):
class _GeneratorContextManager(AbstractContextManager[_T_co, bool | None], ContextDecorator):
# __init__ and all instance attributes are actually inherited from _GeneratorContextManagerBase
# _GeneratorContextManagerBase is more trouble than it's worth to include in the stub; see #6676
def __init__(self, func: Callable[..., Iterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
Expand All @@ -81,7 +82,7 @@ if sys.version_info >= (3, 10):
class AsyncContextDecorator:
def __call__(self, func: _AF) -> _AF: ...

class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co], AsyncContextDecorator):
class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co, bool | None], AsyncContextDecorator):
# __init__ and these attributes are actually defined in the base class _GeneratorContextManagerBase,
# which is more trouble than it's worth to include in the stub (see #6676)
def __init__(self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
Expand All @@ -94,7 +95,7 @@ if sys.version_info >= (3, 10):
) -> bool | None: ...

else:
class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co]):
class _AsyncGeneratorContextManager(AbstractAsyncContextManager[_T_co, bool | None]):
def __init__(self, func: Callable[..., AsyncIterator[_T_co]], args: tuple[Any, ...], kwds: dict[str, Any]) -> None: ...
gen: AsyncGenerator[_T_co, Any]
func: Callable[..., AsyncGenerator[_T_co, Any]]
Expand All @@ -111,7 +112,7 @@ class _SupportsClose(Protocol):

_SupportsCloseT = TypeVar("_SupportsCloseT", bound=_SupportsClose)

class closing(AbstractContextManager[_SupportsCloseT]):
class closing(AbstractContextManager[_SupportsCloseT, None]):
def __init__(self, thing: _SupportsCloseT) -> None: ...
def __exit__(self, *exc_info: Unused) -> None: ...

Expand All @@ -121,17 +122,17 @@ if sys.version_info >= (3, 10):

_SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose)

class aclosing(AbstractAsyncContextManager[_SupportsAcloseT]):
class aclosing(AbstractAsyncContextManager[_SupportsAcloseT, None]):
def __init__(self, thing: _SupportsAcloseT) -> None: ...
async def __aexit__(self, *exc_info: Unused) -> None: ...

class suppress(AbstractContextManager[None]):
class suppress(AbstractContextManager[None, bool]):
def __init__(self, *exceptions: type[BaseException]) -> None: ...
def __exit__(
self, exctype: type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None
) -> bool: ...

class _RedirectStream(AbstractContextManager[_T_io]):
class _RedirectStream(AbstractContextManager[_T_io, None]):
def __init__(self, new_target: _T_io) -> None: ...
def __exit__(
self, exctype: type[BaseException] | None, excinst: BaseException | None, exctb: TracebackType | None
Expand All @@ -142,27 +143,27 @@ class redirect_stderr(_RedirectStream[_T_io]): ...

# In reality this is a subclass of `AbstractContextManager`;
# see #7961 for why we don't do that in the stub
class ExitStack(metaclass=abc.ABCMeta):
def enter_context(self, cm: AbstractContextManager[_T]) -> _T: ...
class ExitStack(Generic[_ExitT_co], metaclass=abc.ABCMeta):
def enter_context(self, cm: AbstractContextManager[_T, _ExitT_co]) -> _T: ...
def push(self, exit: _CM_EF) -> _CM_EF: ...
def callback(self, callback: Callable[_P, _T], /, *args: _P.args, **kwds: _P.kwargs) -> Callable[_P, _T]: ...
def pop_all(self) -> Self: ...
def close(self) -> None: ...
def __enter__(self) -> Self: ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, /
) -> bool: ...
) -> _ExitT_co: ...

_ExitCoroFunc: TypeAlias = Callable[
[type[BaseException] | None, BaseException | None, TracebackType | None], Awaitable[bool | None]
]
_ACM_EF = TypeVar("_ACM_EF", bound=AbstractAsyncContextManager[Any] | _ExitCoroFunc)
_ACM_EF = TypeVar("_ACM_EF", bound=AbstractAsyncContextManager[Any, Any] | _ExitCoroFunc)

# In reality this is a subclass of `AbstractAsyncContextManager`;
# see #7961 for why we don't do that in the stub
class AsyncExitStack(metaclass=abc.ABCMeta):
def enter_context(self, cm: AbstractContextManager[_T]) -> _T: ...
async def enter_async_context(self, cm: AbstractAsyncContextManager[_T]) -> _T: ...
class AsyncExitStack(Generic[_ExitT_co], metaclass=abc.ABCMeta):
def enter_context(self, cm: AbstractContextManager[_T, _ExitT_co]) -> _T: ...
async def enter_async_context(self, cm: AbstractAsyncContextManager[_T, _ExitT_co]) -> _T: ...
def push(self, exit: _CM_EF) -> _CM_EF: ...
def push_async_exit(self, exit: _ACM_EF) -> _ACM_EF: ...
def callback(self, callback: Callable[_P, _T], /, *args: _P.args, **kwds: _P.kwargs) -> Callable[_P, _T]: ...
Expand All @@ -177,7 +178,7 @@ class AsyncExitStack(metaclass=abc.ABCMeta):
) -> bool: ...

if sys.version_info >= (3, 10):
class nullcontext(AbstractContextManager[_T], AbstractAsyncContextManager[_T]):
class nullcontext(AbstractContextManager[_T, None], AbstractAsyncContextManager[_T, None]):
enter_result: _T
@overload
def __init__(self: nullcontext[None], enter_result: None = None) -> None: ...
Expand All @@ -189,7 +190,7 @@ if sys.version_info >= (3, 10):
async def __aexit__(self, *exctype: Unused) -> None: ...

else:
class nullcontext(AbstractContextManager[_T]):
class nullcontext(AbstractContextManager[_T, None]):
enter_result: _T
@overload
def __init__(self: nullcontext[None], enter_result: None = None) -> None: ...
Expand All @@ -201,7 +202,7 @@ else:
if sys.version_info >= (3, 11):
_T_fd_or_any_path = TypeVar("_T_fd_or_any_path", bound=FileDescriptorOrPath)

class chdir(AbstractContextManager[None], Generic[_T_fd_or_any_path]):
class chdir(AbstractContextManager[None, None], Generic[_T_fd_or_any_path]):
path: _T_fd_or_any_path
def __init__(self, path: _T_fd_or_any_path) -> None: ...
def __enter__(self) -> None: ...
Expand Down
4 changes: 2 additions & 2 deletions stdlib/multiprocessing/synchronize.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Barrier(threading.Barrier):
self, parties: int, action: Callable[[], object] | None = None, timeout: float | None = None, *ctx: BaseContext
) -> None: ...

class Condition(AbstractContextManager[bool]):
class Condition(AbstractContextManager[bool, None]):
def __init__(self, lock: _LockLike | None = None, *, ctx: BaseContext) -> None: ...
def notify(self, n: int = 1) -> None: ...
def notify_all(self) -> None: ...
Expand All @@ -34,7 +34,7 @@ class Event:
def wait(self, timeout: float | None = None) -> bool: ...

# Not part of public API
class SemLock(AbstractContextManager[bool]):
class SemLock(AbstractContextManager[bool, None]):
def acquire(self, block: bool = ..., timeout: float | None = ...) -> bool: ...
def release(self) -> None: ...
def __exit__(
Expand Down
2 changes: 1 addition & 1 deletion stdlib/os/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def replace(
) -> None: ...
def rmdir(path: StrOrBytesPath, *, dir_fd: int | None = None) -> None: ...

class _ScandirIterator(Iterator[DirEntry[AnyStr]], AbstractContextManager[_ScandirIterator[AnyStr]]):
class _ScandirIterator(Iterator[DirEntry[AnyStr]], AbstractContextManager[_ScandirIterator[AnyStr], None]):
def __next__(self) -> DirEntry[AnyStr]: ...
def __exit__(self, *args: Unused) -> None: ...
def close(self) -> None: ...
Expand Down
4 changes: 2 additions & 2 deletions stdlib/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ if sys.version_info >= (3, 11):
if sys.version_info >= (3, 12):
__all__ += ["TypeAliasType", "override"]

ContextManager = AbstractContextManager
AsyncContextManager = AbstractAsyncContextManager
ContextManager = AbstractContextManager[_T_co, bool | None] # noqa: Y026
AsyncContextManager = AbstractAsyncContextManager[_T_co, bool | None] # noqa: Y026
Daverball marked this conversation as resolved.
Show resolved Hide resolved

# This itself is only available during type checking
def type_check_only(func_or_cls: _F) -> _F: ...
Expand Down
Loading