diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index d6e5ff7b7d..723ecd78e9 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -403,6 +403,10 @@ Socket objects left in an unknown state – possibly open, and possibly closed. The only reasonable thing to do is to close it. + .. method:: is_readable + + Check whether the socket is readable or not. + .. method:: sendfile `Not implemented yet! `__ diff --git a/newsfragments/760.feature.rst b/newsfragments/760.feature.rst new file mode 100644 index 0000000000..dd7ab9d34b --- /dev/null +++ b/newsfragments/760.feature.rst @@ -0,0 +1,2 @@ +Trio sockets have a new method `~trio.socket.SocketType.is_readable` that allows +you to check whether a socket is readable. This is useful for HTTP/1.1 clients. diff --git a/trio/_socket.py b/trio/_socket.py index e866213ffa..0e257849b0 100644 --- a/trio/_socket.py +++ b/trio/_socket.py @@ -1,5 +1,6 @@ import os as _os import sys as _sys +import select import socket as _stdlib_socket from functools import wraps as _wraps @@ -289,7 +290,7 @@ def socket( def _sniff_sockopts_for_fileno(family, type, proto, fileno): """Correct SOCKOPTS for given fileno, falling back to provided values. - + """ # Wrap the raw fileno into a Python socket object # This object might have the wrong metadata, but it lets us easily call getsockopt @@ -478,6 +479,15 @@ def shutdown(self, flag): if flag in [_stdlib_socket.SHUT_WR, _stdlib_socket.SHUT_RDWR]: self._did_shutdown_SHUT_WR = True + def is_readable(self): + # use select.select on Windows, and select.poll everywhere else + if _sys.platform == "win32": + rready, _, _ = select.select([self._sock], [], [], 0) + return bool(rready) + p = select.poll() + p.register(self._sock, select.POLLIN) + return bool(p.poll(0)) + async def wait_writable(self): await _core.wait_writable(self._sock) diff --git a/trio/tests/test_socket.py b/trio/tests/test_socket.py index beab12a999..c03c8bb8ff 100644 --- a/trio/tests/test_socket.py +++ b/trio/tests/test_socket.py @@ -404,6 +404,17 @@ async def test_SocketType_simple_server(address, socket_type): assert await client.recv(1) == b"x" +async def test_SocketType_is_readable(): + a, b = tsocket.socketpair() + with a, b: + assert not a.is_readable() + await b.send(b"x") + await _core.wait_readable(a) + assert a.is_readable() + assert await a.recv(1) == b"x" + assert not a.is_readable() + + # On some macOS systems, getaddrinfo likes to return V4-mapped addresses even # when we *don't* pass AI_V4MAPPED. # https://github.com/python-trio/trio/issues/580