diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 58272769..e0790c1c 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -165,7 +165,7 @@ async def wrapper(): # async def sleep(delay: float) -> None: - await check_cancelled() + await checkpoint() await asyncio.sleep(delay) @@ -310,7 +310,7 @@ def shield(self) -> bool: return self._shield -async def check_cancelled(): +async def checkpoint(): try: cancel_scope = _task_states[current_task()].cancel_scope except KeyError: @@ -509,7 +509,7 @@ def thread_worker(): if not cancelled: loop.call_soon_threadsafe(queue.put_nowait, (result, None)) - await check_cancelled() + await checkpoint() loop = get_running_loop() task = current_task() queue: asyncio.Queue[_Retval_Queue_Type] = asyncio.Queue(1) @@ -630,7 +630,7 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]: async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int): - await check_cancelled() + await checkpoint() if shell: process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout, stderr=stderr) @@ -725,7 +725,7 @@ def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol): async def receive(self, max_bytes: int = 65536) -> bytes: with self._receive_guard: - await check_cancelled() + await checkpoint() if not self._protocol.read_queue and not self._transport.is_closing(): self._protocol.read_event.clear() self._transport.resume_reading() @@ -751,7 +751,7 @@ async def receive(self, max_bytes: int = 65536) -> bytes: async def send(self, item: bytes) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() try: self._transport.write(item) except RuntimeError as exc: @@ -826,7 +826,7 @@ def local_address(self) -> SockAddrType: async def accept(self) -> abc.SocketStream: with self._accept_guard: - await check_cancelled() + await checkpoint() try: client_sock, _addr = await self._loop.sock_accept(self._raw_socket) except asyncio.CancelledError: @@ -880,7 +880,7 @@ def setsockopt(self, level, optname, value, *args) -> None: async def receive(self) -> Tuple[bytes, IPSockAddrType]: with self._receive_guard: - await check_cancelled() + await checkpoint() # If the buffer is empty, ask for more data if not self._protocol.read_queue and not self._transport.is_closing(): @@ -897,7 +897,7 @@ async def receive(self) -> Tuple[bytes, IPSockAddrType]: async def send(self, item: UDPPacketType) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() await self._protocol.write_event.wait() if self._closed: raise ClosedResourceError @@ -943,7 +943,7 @@ def setsockopt(self, level, optname, value, *args) -> None: async def receive(self) -> bytes: with self._receive_guard: - await check_cancelled() + await checkpoint() # If the buffer is empty, ask for more data if not self._protocol.read_queue and not self._transport.is_closing(): @@ -962,7 +962,7 @@ async def receive(self) -> bytes: async def send(self, item: bytes) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() await self._protocol.write_event.wait() if self._closed: raise ClosedResourceError @@ -1029,7 +1029,7 @@ async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Tuple[str, st async def wait_socket_readable(sock: socket.SocketType) -> None: - await check_cancelled() + await checkpoint() if _read_events.get(sock): raise BusyResourceError('reading from') from None @@ -1050,7 +1050,7 @@ async def wait_socket_readable(sock: socket.SocketType) -> None: async def wait_socket_writable(sock: socket.SocketType) -> None: - await check_cancelled() + await checkpoint() if _write_events.get(sock): raise BusyResourceError('writing to') from None @@ -1082,7 +1082,7 @@ def locked(self) -> bool: return self._lock.locked() async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._lock.acquire() async def release(self) -> None: @@ -1095,7 +1095,7 @@ def __init__(self, lock: Optional[Lock]): self._condition = asyncio.Condition(asyncio_lock) async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._condition.acquire() async def release(self) -> None: @@ -1111,7 +1111,7 @@ async def notify_all(self): self._condition.notify_all() async def wait(self): - await check_cancelled() + await checkpoint() return await self._condition.wait() @@ -1126,7 +1126,7 @@ def is_set(self) -> bool: return self._event.is_set() async def wait(self): - await check_cancelled() + await checkpoint() await self._event.wait() @@ -1135,7 +1135,7 @@ def __init__(self, value: int): self._semaphore = asyncio.Semaphore(value) async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._semaphore.acquire() async def release(self) -> None: diff --git a/src/anyio/_backends/_curio.py b/src/anyio/_backends/_curio.py index 68daf545..9537e740 100644 --- a/src/anyio/_backends/_curio.py +++ b/src/anyio/_backends/_curio.py @@ -72,7 +72,7 @@ async def wrapper(): # async def sleep(delay: float): - await check_cancelled() + await checkpoint() await curio.sleep(delay) @@ -215,7 +215,7 @@ def shield(self) -> bool: return self._shield -async def check_cancelled(): +async def checkpoint(): try: cancel_scope = _task_states[await curio.current_task()].cancel_scope except KeyError: @@ -421,7 +421,7 @@ def thread_worker(): if not helper_task.cancelled: queue.put(None) - await check_cancelled() + await checkpoint() task = await curio.current_task() queue = curio.UniversalQueue(maxsize=1) finish_event = curio.Event() @@ -554,7 +554,7 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]: async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int): - await check_cancelled() + await checkpoint() process = curio.subprocess.Popen(command, stdin=stdin, stdout=stdout, stderr=stderr, shell=shell) stdin_stream = FileStreamWrapper(process.stdin) if process.stdin else None @@ -634,7 +634,7 @@ def remote_address(self) -> Union[IPSockAddrType, str]: async def receive(self, max_bytes: int = 65536) -> bytes: with self._receive_guard: - await check_cancelled() + await checkpoint() try: data = await self._curio_socket.recv(max_bytes) except (OSError, AttributeError) as exc: @@ -647,7 +647,7 @@ async def receive(self, max_bytes: int = 65536) -> bytes: async def send(self, item: bytes) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() try: await self._curio_socket.sendall(item) except (OSError, AttributeError) as exc: @@ -664,7 +664,7 @@ def __init__(self, raw_socket: socket.SocketType): async def accept(self) -> SocketStream: with self._accept_guard: - await check_cancelled() + await checkpoint() try: curio_socket, _addr = await self._curio_socket.accept() except (OSError, AttributeError) as exc: @@ -684,7 +684,7 @@ def __init__(self, curio_socket: curio.io.Socket): async def receive(self) -> Tuple[bytes, IPSockAddrType]: with self._receive_guard: - await check_cancelled() + await checkpoint() try: return await self._curio_socket.recvfrom(65536) except (OSError, AttributeError) as exc: @@ -692,7 +692,7 @@ async def receive(self) -> Tuple[bytes, IPSockAddrType]: async def send(self, item: UDPPacketType) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() try: await self._curio_socket.sendto(*item) except (OSError, AttributeError) as exc: @@ -714,7 +714,7 @@ def remote_address(self) -> IPSockAddrType: async def receive(self) -> bytes: with self._receive_guard: - await check_cancelled() + await checkpoint() try: return await self._curio_socket.recv(65536) except (OSError, AttributeError) as exc: @@ -722,7 +722,7 @@ async def receive(self) -> bytes: async def send(self, item: bytes) -> None: with self._send_guard: - await check_cancelled() + await checkpoint() try: await self._curio_socket.send(item) except (OSError, AttributeError) as exc: @@ -772,7 +772,7 @@ def getaddrinfo(host: Union[bytearray, bytes, str], port: Union[str, int, None], async def wait_socket_readable(sock): - await check_cancelled() + await checkpoint() if _reader_tasks.get(sock): raise BusyResourceError('reading from') from None @@ -789,7 +789,7 @@ async def wait_socket_readable(sock): async def wait_socket_writable(sock): - await check_cancelled() + await checkpoint() if _writer_tasks.get(sock): raise BusyResourceError('writing to') from None @@ -817,7 +817,7 @@ def locked(self) -> bool: return self._lock.locked() async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._lock.acquire() async def release(self) -> None: @@ -830,7 +830,7 @@ def __init__(self, lock: Optional[Lock]): self._condition = curio.Condition(curio_lock) async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._condition.acquire() async def release(self) -> None: @@ -846,7 +846,7 @@ async def notify_all(self): await self._condition.notify_all() async def wait(self): - await check_cancelled() + await checkpoint() return await self._condition.wait() @@ -861,7 +861,7 @@ def is_set(self) -> bool: return self._event.is_set() async def wait(self): - await check_cancelled() + await checkpoint() return await self._event.wait() @@ -870,7 +870,7 @@ def __init__(self, value: int): self._semaphore = curio.Semaphore(value) async def acquire(self) -> None: - await check_cancelled() + await checkpoint() await self._semaphore.acquire() async def release(self) -> None: diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 8f3d3a54..d9325794 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -54,6 +54,7 @@ # CancelledError = trio.Cancelled +checkpoint = trio.lowlevel.checkpoint class CancelScope(abc.CancelScope): diff --git a/src/anyio/_core/_lowlevel.py b/src/anyio/_core/_lowlevel.py new file mode 100644 index 00000000..1fa87378 --- /dev/null +++ b/src/anyio/_core/_lowlevel.py @@ -0,0 +1,6 @@ +from anyio._core._eventloop import get_asynclib + + +async def checkpoint(): + """Checks for cancellation and allows the scheduler to switch to another task.""" + await get_asynclib().checkpoint()