Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-46805: Add low level UDP socket functions to asyncio #31455

Merged
merged 14 commits into from
Mar 13, 2022
Merged
35 changes: 35 additions & 0 deletions Doc/library/asyncio-eventloop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,29 @@ convenient.

.. versionadded:: 3.7

.. coroutinemethod:: loop.sock_recvfrom(sock, bufsize)

Receive a datagram of up to *bufsize* from *sock*. Asynchronous version of
:meth:`socket.recvfrom() <socket.socket.recvfrom>`.

Return a tuple of (received data, remote address).

*sock* must be a non-blocking socket.

.. versionadded:: 3.11

.. coroutinemethod:: loop.sock_recvfrom_into(sock, buf, nbytes=0)

Receive a datagram of up to *nbytes* from *sock* into *buf*.
Asynchronous version of
:meth:`socket.recvfrom_into() <socket.socket.recvfrom_into>`.

Return a tuple of (received data, remote address).
asvetlov marked this conversation as resolved.
Show resolved Hide resolved

*sock* must be a non-blocking socket.

.. versionadded:: 3.11

.. coroutinemethod:: loop.sock_sendall(sock, data)

Send *data* to the *sock* socket. Asynchronous version of
Expand All @@ -940,6 +963,18 @@ convenient.
method, before Python 3.7 it returned a :class:`Future`.
Since Python 3.7, this is an ``async def`` method.

.. coroutinemethod:: loop.sock_sendto(sock, data, address)

Send a datagram from *sock* to *address*.
Asynchronous version of
:meth:`socket.sendto() <socket.socket.sendto>`.

Return the number of bytes sent.

*sock* must be a non-blocking socket.

.. versionadded:: 3.11

.. coroutinemethod:: loop.sock_connect(sock, address)

Connect *sock* to a remote socket at *address*.
Expand Down
9 changes: 9 additions & 0 deletions Doc/library/asyncio-llapi-index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,18 @@ See also the main documentation section about the
* - ``await`` :meth:`loop.sock_recv_into`
- Receive data from the :class:`~socket.socket` into a buffer.

* - ``await`` :meth:`loop.sock_recvfrom`
- Receive a datagram from the :class:`~socket.socket`.

* - ``await`` :meth:`loop.sock_recvfrom_into`
- Receive a datagram from the :class:`~socket.socket` into a buffer.

* - ``await`` :meth:`loop.sock_sendall`
- Send data to the :class:`~socket.socket`.

* - ``await`` :meth:`loop.sock_sendto`
- Send a datagram via the :class:`~socket.socket` to the given address.

* - ``await`` :meth:`loop.sock_connect`
- Connect the :class:`~socket.socket`.

Expand Down
9 changes: 9 additions & 0 deletions Lib/asyncio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,9 +546,18 @@ async def sock_recv(self, sock, nbytes):
async def sock_recv_into(self, sock, buf):
raise NotImplementedError

async def sock_recvfrom(self, sock, bufsize):
raise NotImplementedError

async def sock_recvfrom_into(self, sock, buf, nbytes=0):
raise NotImplementedError

async def sock_sendall(self, sock, data):
raise NotImplementedError

async def sock_sendto(self, sock, data, address):
raise NotImplementedError

async def sock_connect(self, sock, address):
raise NotImplementedError

Expand Down
12 changes: 12 additions & 0 deletions Lib/asyncio/proactor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,9 +700,21 @@ async def sock_recv(self, sock, n):
async def sock_recv_into(self, sock, buf):
return await self._proactor.recv_into(sock, buf)

async def sock_recvfrom(self, sock, bufsize):
return await self._proactor.recvfrom(sock, bufsize)

async def sock_recvfrom_into(self, sock, buf, nbytes=0):
if not nbytes:
nbytes = len(buf)

return await self._proactor.recvfrom_into(sock, buf, nbytes)

async def sock_sendall(self, sock, data):
return await self._proactor.send(sock, data)

async def sock_sendto(self, sock, data, address):
return await self._proactor.sendto(sock, data, 0, address)

async def sock_connect(self, sock, address):
return await self._proactor.connect(sock, address)

Expand Down
124 changes: 124 additions & 0 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,88 @@ def _sock_recv_into(self, fut, sock, buf):
else:
fut.set_result(nbytes)

async def sock_recvfrom(self, sock, bufsize):
"""Receive a datagram from a datagram socket.

The return value is a tuple of (bytes, address) representing the
datagram received and the address it came from.
The maximum amount of data to be received at once is specified by
nbytes.
"""
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
return sock.recvfrom(bufsize)
except (BlockingIOError, InterruptedError):
pass
fut = self.create_future()
fd = sock.fileno()
self._ensure_fd_no_transport(fd)
handle = self._add_reader(fd, self._sock_recvfrom, fut, sock, bufsize)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd, handle=handle))
return await fut

def _sock_recvfrom(self, fut, sock, bufsize):
# _sock_recvfrom() can add itself as an I/O callback if the operation
# can't be done immediately. Don't use it directly, call
# sock_recvfrom().
if fut.done():
return
try:
result = sock.recvfrom(bufsize)
except (BlockingIOError, InterruptedError):
return # try again next time
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
fut.set_exception(exc)
else:
fut.set_result(result)

async def sock_recvfrom_into(self, sock, buf, nbytes=0):
"""Receive data from the socket.

The received data is written into *buf* (a writable buffer).
The return value is a tuple of (number of bytes written, address).
"""
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
if not nbytes:
nbytes = len(buf)

try:
return sock.recvfrom_into(buf, nbytes)
except (BlockingIOError, InterruptedError):
pass
fut = self.create_future()
fd = sock.fileno()
self._ensure_fd_no_transport(fd)
handle = self._add_reader(fd, self._sock_recvfrom_into, fut, sock, buf,
nbytes)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd, handle=handle))
return await fut

def _sock_recvfrom_into(self, fut, sock, buf, bufsize):
# _sock_recv_into() can add itself as an I/O callback if the operation
# can't be done immediately. Don't use it directly, call
# sock_recv_into().
if fut.done():
return
try:
result = sock.recvfrom_into(buf, bufsize)
except (BlockingIOError, InterruptedError):
return # try again next time
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
fut.set_exception(exc)
else:
fut.set_result(result)

async def sock_sendall(self, sock, data):
"""Send data to the socket.

Expand Down Expand Up @@ -487,6 +569,48 @@ def _sock_sendall(self, fut, sock, view, pos):
else:
pos[0] = start

async def sock_sendto(self, sock, data, address):
"""Send data to the socket.

The socket must be connected to a remote socket. This method continues
to send data from data until either all data has been sent or an
error occurs. None is returned on success. On error, an exception is
raised, and there is no way to determine how much data, if any, was
successfully processed by the receiving end of the connection.
"""
base_events._check_ssl_socket(sock)
if self._debug and sock.gettimeout() != 0:
raise ValueError("the socket must be non-blocking")
try:
return sock.sendto(data, address)
except (BlockingIOError, InterruptedError):
pass

fut = self.create_future()
fd = sock.fileno()
self._ensure_fd_no_transport(fd)
# use a trick with a list in closure to store a mutable state
handle = self._add_writer(fd, self._sock_sendto, fut, sock, data,
address)
fut.add_done_callback(
functools.partial(self._sock_write_done, fd, handle=handle))
return await fut

def _sock_sendto(self, fut, sock, data, address):
if fut.done():
# Future cancellation can be scheduled on previous loop iteration
return
try:
n = sock.sendto(data, 0, address)
except (BlockingIOError, InterruptedError):
return
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
fut.set_exception(exc)
else:
fut.set_result(n)

async def sock_connect(self, sock, address):
"""Connect to a remote socket at address.

Expand Down
20 changes: 20 additions & 0 deletions Lib/asyncio/windows_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,26 @@ def finish_recv(trans, key, ov):

return self._register(ov, conn, finish_recv)

def recvfrom_into(self, conn, buf, flags=0):
self._register_with_iocp(conn)
ov = _overlapped.Overlapped(NULL)
try:
ov.WSARecvFromInto(conn.fileno(), buf, flags)
except BrokenPipeError:
return self._result((b'', None))
asvetlov marked this conversation as resolved.
Show resolved Hide resolved

def finish_recv(trans, key, ov):
try:
return ov.getresult()
except OSError as exc:
if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
_overlapped.ERROR_OPERATION_ABORTED):
raise ConnectionResetError(*exc.args)
else:
raise

return self._register(ov, conn, finish_recv)

def sendto(self, conn, buf, flags=0, addr=None):
self._register_with_iocp(conn)
ov = _overlapped.Overlapped(NULL)
Expand Down
75 changes: 74 additions & 1 deletion Lib/test/test_asyncio/test_sock_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from asyncio import proactor_events
from itertools import cycle, islice
from unittest.mock import patch, Mock
from test.test_asyncio import utils as test_utils
from test import support
from test.support import socket_helper


def tearDownModule():
asyncio.set_event_loop_policy(None)

Expand Down Expand Up @@ -380,6 +380,79 @@ def test_huge_content_recvinto(self):
self.loop.run_until_complete(
self._basetest_huge_content_recvinto(httpd.address))

async def _basetest_datagram_recvfrom(self, server_address):
# Happy path, sock.sendto() returns immediately
data = b'\x01' * 4096
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
await self.loop.sock_sendto(sock, data, server_address)
received_data, from_addr = await self.loop.sock_recvfrom(
sock, 4096)
self.assertEqual(received_data, data)
self.assertEqual(from_addr, server_address)

def test_recvfrom(self):
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_recvfrom(server_address))

async def _basetest_datagram_recvfrom_into(self, server_address):
# Happy path, sock.sendto() returns immediately
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)

buf = bytearray(4096)
data = b'\x01' * 4096
await self.loop.sock_sendto(sock, data, server_address)
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
sock, buf)
self.assertEqual(num_bytes, 4096)
self.assertEqual(buf, data)
self.assertEqual(from_addr, server_address)

buf = bytearray(8192)
await self.loop.sock_sendto(sock, data, server_address)
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
sock, buf, 4096)
self.assertEqual(num_bytes, 4096)
self.assertEqual(buf[:4096], data[:4096])
self.assertEqual(from_addr, server_address)

def test_recvfrom_into(self):
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_recvfrom_into(server_address))

async def _basetest_datagram_sendto_blocking(self, server_address):
# Sad path, sock.sendto() raises BlockingIOError
# This involves patching sock.sendto() to raise BlockingIOError but
# sendto() is not used by the proactor event loop
data = b'\x01' * 4096
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
mock_sock = Mock(sock)
mock_sock.gettimeout = sock.gettimeout
mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
mock_sock.fileno = sock.fileno
self.loop.call_soon(
lambda: setattr(mock_sock, 'sendto', sock.sendto)
)
await self.loop.sock_sendto(mock_sock, data, server_address)

received_data, from_addr = await self.loop.sock_recvfrom(
sock, 4096)
self.assertEqual(received_data, data)
self.assertEqual(from_addr, server_address)

def test_sendto_blocking(self):
if sys.platform == 'win32':
if isinstance(self.loop, asyncio.ProactorEventLoop):
raise unittest.SkipTest('Not relevant to ProactorEventLoop')

with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_sendto_blocking(server_address))

@socket_helper.skip_unless_bind_unix_socket
def test_unix_sock_client_ops(self):
with test_utils.run_test_unix_server() as httpd:
Expand Down
25 changes: 25 additions & 0 deletions Lib/test/test_asyncio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,31 @@ def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
server_ssl_cls=SSLWSGIServer)


def echo_datagrams(sock):
while True:
data, addr = sock.recvfrom(4096)
if data == b'STOP':
sock.close()
break
else:
sock.sendto(data, addr)


@contextlib.contextmanager
def run_udp_echo_server(*, host='127.0.0.1', port=0):
addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
family, type, proto, _, sockaddr = addr_info[0]
sock = socket.socket(family, type, proto)
sock.bind((host, port))
thread = threading.Thread(target=lambda: echo_datagrams(sock))
thread.start()
try:
yield sock.getsockname()
finally:
sock.sendto(b'STOP', sock.getsockname())
thread.join()


def make_test_protocol(base):
dct = {}
for name in dir(base):
Expand Down
Loading