Skip to content

Commit

Permalink
Renamed check_cancelled() to checkpoint() and exposed it internally
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Aug 16, 2020
1 parent 23803be commit 1d548ca
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 36 deletions.
36 changes: 18 additions & 18 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ async def wrapper():
#

async def sleep(delay: float) -> None:
await check_cancelled()
await checkpoint()
await asyncio.sleep(delay)


Expand Down Expand Up @@ -310,7 +310,7 @@ def shield(self) -> bool:
return self._shield


async def check_cancelled():
async def checkpoint():
try:
cancel_scope = _task_states[current_task()].cancel_scope
except KeyError:
Expand Down Expand Up @@ -509,7 +509,7 @@ def thread_worker():
if not cancelled:
loop.call_soon_threadsafe(queue.put_nowait, (result, None))

await check_cancelled()
await checkpoint()
loop = get_running_loop()
task = current_task()
queue: asyncio.Queue[_Retval_Queue_Type] = asyncio.Queue(1)
Expand Down Expand Up @@ -630,7 +630,7 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]:


async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int):
await check_cancelled()
await checkpoint()
if shell:
process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout,
stderr=stderr)
Expand Down Expand Up @@ -725,7 +725,7 @@ def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):

async def receive(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
await check_cancelled()
await checkpoint()
if not self._protocol.read_queue and not self._transport.is_closing():
self._protocol.read_event.clear()
self._transport.resume_reading()
Expand All @@ -751,7 +751,7 @@ async def receive(self, max_bytes: int = 65536) -> bytes:

async def send(self, item: bytes) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
try:
self._transport.write(item)
except RuntimeError as exc:
Expand Down Expand Up @@ -826,7 +826,7 @@ def local_address(self) -> SockAddrType:

async def accept(self) -> abc.SocketStream:
with self._accept_guard:
await check_cancelled()
await checkpoint()
try:
client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
except asyncio.CancelledError:
Expand Down Expand Up @@ -880,7 +880,7 @@ def setsockopt(self, level, optname, value, *args) -> None:

async def receive(self) -> Tuple[bytes, IPSockAddrType]:
with self._receive_guard:
await check_cancelled()
await checkpoint()

# If the buffer is empty, ask for more data
if not self._protocol.read_queue and not self._transport.is_closing():
Expand All @@ -897,7 +897,7 @@ async def receive(self) -> Tuple[bytes, IPSockAddrType]:

async def send(self, item: UDPPacketType) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
await self._protocol.write_event.wait()
if self._closed:
raise ClosedResourceError
Expand Down Expand Up @@ -943,7 +943,7 @@ def setsockopt(self, level, optname, value, *args) -> None:

async def receive(self) -> bytes:
with self._receive_guard:
await check_cancelled()
await checkpoint()

# If the buffer is empty, ask for more data
if not self._protocol.read_queue and not self._transport.is_closing():
Expand All @@ -962,7 +962,7 @@ async def receive(self) -> bytes:

async def send(self, item: bytes) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
await self._protocol.write_event.wait()
if self._closed:
raise ClosedResourceError
Expand Down Expand Up @@ -1029,7 +1029,7 @@ async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Tuple[str, st


async def wait_socket_readable(sock: socket.SocketType) -> None:
await check_cancelled()
await checkpoint()
if _read_events.get(sock):
raise BusyResourceError('reading from') from None

Expand All @@ -1050,7 +1050,7 @@ async def wait_socket_readable(sock: socket.SocketType) -> None:


async def wait_socket_writable(sock: socket.SocketType) -> None:
await check_cancelled()
await checkpoint()
if _write_events.get(sock):
raise BusyResourceError('writing to') from None

Expand Down Expand Up @@ -1082,7 +1082,7 @@ def locked(self) -> bool:
return self._lock.locked()

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._lock.acquire()

async def release(self) -> None:
Expand All @@ -1095,7 +1095,7 @@ def __init__(self, lock: Optional[Lock]):
self._condition = asyncio.Condition(asyncio_lock)

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._condition.acquire()

async def release(self) -> None:
Expand All @@ -1111,7 +1111,7 @@ async def notify_all(self):
self._condition.notify_all()

async def wait(self):
await check_cancelled()
await checkpoint()
return await self._condition.wait()


Expand All @@ -1126,7 +1126,7 @@ def is_set(self) -> bool:
return self._event.is_set()

async def wait(self):
await check_cancelled()
await checkpoint()
await self._event.wait()


Expand All @@ -1135,7 +1135,7 @@ def __init__(self, value: int):
self._semaphore = asyncio.Semaphore(value)

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._semaphore.acquire()

async def release(self) -> None:
Expand Down
36 changes: 18 additions & 18 deletions src/anyio/_backends/_curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def wrapper():
#

async def sleep(delay: float):
await check_cancelled()
await checkpoint()
await curio.sleep(delay)


Expand Down Expand Up @@ -215,7 +215,7 @@ def shield(self) -> bool:
return self._shield


async def check_cancelled():
async def checkpoint():
try:
cancel_scope = _task_states[await curio.current_task()].cancel_scope
except KeyError:
Expand Down Expand Up @@ -421,7 +421,7 @@ def thread_worker():
if not helper_task.cancelled:
queue.put(None)

await check_cancelled()
await checkpoint()
task = await curio.current_task()
queue = curio.UniversalQueue(maxsize=1)
finish_event = curio.Event()
Expand Down Expand Up @@ -554,7 +554,7 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]:


async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int):
await check_cancelled()
await checkpoint()
process = curio.subprocess.Popen(command, stdin=stdin, stdout=stdout, stderr=stderr,
shell=shell)
stdin_stream = FileStreamWrapper(process.stdin) if process.stdin else None
Expand Down Expand Up @@ -634,7 +634,7 @@ def remote_address(self) -> Union[IPSockAddrType, str]:

async def receive(self, max_bytes: int = 65536) -> bytes:
with self._receive_guard:
await check_cancelled()
await checkpoint()
try:
data = await self._curio_socket.recv(max_bytes)
except (OSError, AttributeError) as exc:
Expand All @@ -647,7 +647,7 @@ async def receive(self, max_bytes: int = 65536) -> bytes:

async def send(self, item: bytes) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
try:
await self._curio_socket.sendall(item)
except (OSError, AttributeError) as exc:
Expand All @@ -664,7 +664,7 @@ def __init__(self, raw_socket: socket.SocketType):

async def accept(self) -> SocketStream:
with self._accept_guard:
await check_cancelled()
await checkpoint()
try:
curio_socket, _addr = await self._curio_socket.accept()
except (OSError, AttributeError) as exc:
Expand All @@ -684,15 +684,15 @@ def __init__(self, curio_socket: curio.io.Socket):

async def receive(self) -> Tuple[bytes, IPSockAddrType]:
with self._receive_guard:
await check_cancelled()
await checkpoint()
try:
return await self._curio_socket.recvfrom(65536)
except (OSError, AttributeError) as exc:
self._convert_socket_error(exc)

async def send(self, item: UDPPacketType) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
try:
await self._curio_socket.sendto(*item)
except (OSError, AttributeError) as exc:
Expand All @@ -714,15 +714,15 @@ def remote_address(self) -> IPSockAddrType:

async def receive(self) -> bytes:
with self._receive_guard:
await check_cancelled()
await checkpoint()
try:
return await self._curio_socket.recv(65536)
except (OSError, AttributeError) as exc:
self._convert_socket_error(exc)

async def send(self, item: bytes) -> None:
with self._send_guard:
await check_cancelled()
await checkpoint()
try:
await self._curio_socket.send(item)
except (OSError, AttributeError) as exc:
Expand Down Expand Up @@ -772,7 +772,7 @@ def getaddrinfo(host: Union[bytearray, bytes, str], port: Union[str, int, None],


async def wait_socket_readable(sock):
await check_cancelled()
await checkpoint()
if _reader_tasks.get(sock):
raise BusyResourceError('reading from') from None

Expand All @@ -789,7 +789,7 @@ async def wait_socket_readable(sock):


async def wait_socket_writable(sock):
await check_cancelled()
await checkpoint()
if _writer_tasks.get(sock):
raise BusyResourceError('writing to') from None

Expand Down Expand Up @@ -817,7 +817,7 @@ def locked(self) -> bool:
return self._lock.locked()

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._lock.acquire()

async def release(self) -> None:
Expand All @@ -830,7 +830,7 @@ def __init__(self, lock: Optional[Lock]):
self._condition = curio.Condition(curio_lock)

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._condition.acquire()

async def release(self) -> None:
Expand All @@ -846,7 +846,7 @@ async def notify_all(self):
await self._condition.notify_all()

async def wait(self):
await check_cancelled()
await checkpoint()
return await self._condition.wait()


Expand All @@ -861,7 +861,7 @@ def is_set(self) -> bool:
return self._event.is_set()

async def wait(self):
await check_cancelled()
await checkpoint()
return await self._event.wait()


Expand All @@ -870,7 +870,7 @@ def __init__(self, value: int):
self._semaphore = curio.Semaphore(value)

async def acquire(self) -> None:
await check_cancelled()
await checkpoint()
await self._semaphore.acquire()

async def release(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#

CancelledError = trio.Cancelled
checkpoint = trio.lowlevel.checkpoint


class CancelScope(abc.CancelScope):
Expand Down
6 changes: 6 additions & 0 deletions src/anyio/_core/_lowlevel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from anyio._core._eventloop import get_asynclib


async def checkpoint():
"""Checks for cancellation and allows the scheduler to switch to another task."""
await get_asynclib().checkpoint()

0 comments on commit 1d548ca

Please sign in to comment.