Skip to content

Commit

Permalink
Enable flake8-pyi (#2821)
Browse files Browse the repository at this point in the history
* Enable flake8-pyi

* Fix "PYI041: Use `float` instead of `int | float`"

* Rename `getaddrinfoResponse -> GetAddrInfoResponse` to fix `PYI042` error

* Revert `int | float` -> `float` change and ignore the errors
Planning to do another PR soon to make it `int | None` eventually

---------

Co-authored-by: John Litborn <[email protected]>
  • Loading branch information
CoolCat467 and jakkdl authored Oct 28, 2023
1 parent 52b3aea commit 6f9ab95
Show file tree
Hide file tree
Showing 10 changed files with 31 additions and 25 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ select = [
"B", # flake8-bugbear
"YTT", # flake8-2020
"ASYNC", # flake8-async
"PYI", # flake8-pyi
]
extend-ignore = [
'F403', # undefined-local-with-import-star
Expand Down
6 changes: 3 additions & 3 deletions trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


def _open_memory_channel(
max_buffer_size: int | float,
max_buffer_size: int | float, # noqa: PYI041
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
"""Open a channel for passing objects between tasks within a process.
Expand Down Expand Up @@ -95,11 +95,11 @@ def _open_memory_channel(
# Need to use Tuple instead of tuple due to CI check running on 3.8
class open_memory_channel(Tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]):
def __new__( # type: ignore[misc] # "must return a subtype"
cls, max_buffer_size: int | float
cls, max_buffer_size: int | float # noqa: PYI041
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
return _open_memory_channel(max_buffer_size)

def __init__(self, max_buffer_size: int | float):
def __init__(self, max_buffer_size: int | float): # noqa: PYI041
...

else:
Expand Down
2 changes: 1 addition & 1 deletion trio/_core/_multierror.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def __new__( # type: ignore[misc] # mypy says __new__ must return a class inst
# In an earlier version of the code, we didn't define __init__ and
# simply set the `exceptions` attribute directly on the new object.
# However, linters expect attributes to be initialized in __init__.
from_class: type[Self] | type[NonBaseMultiError] = cls
from_class: type[Self | NonBaseMultiError] = cls
if all(isinstance(exc, Exception) for exc in exceptions):
from_class = NonBaseMultiError

Expand Down
8 changes: 5 additions & 3 deletions trio/_core/_parking_lot.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort:

await _core.wait_task_rescheduled(abort_fn)

def _pop_several(self, count: int | float) -> Iterator[Task]:
def _pop_several(self, count: int | float) -> Iterator[Task]: # noqa: PYI041
if isinstance(count, float):
if math.isinf(count):
count = len(self._parked)
Expand All @@ -159,7 +159,7 @@ def _pop_several(self, count: int | float) -> Iterator[Task]:
yield task

@_core.enable_ki_protection
def unpark(self, *, count: int | float = 1) -> list[Task]:
def unpark(self, *, count: int | float = 1) -> list[Task]: # noqa: PYI041
"""Unpark one or more tasks.
This wakes up ``count`` tasks that are blocked in :meth:`park`. If
Expand All @@ -180,7 +180,9 @@ def unpark_all(self) -> list[Task]:
return self.unpark(count=len(self))

@_core.enable_ki_protection
def repark(self, new_lot: ParkingLot, *, count: int | float = 1) -> None:
def repark(
self, new_lot: ParkingLot, *, count: int | float = 1 # noqa: PYI041
) -> None:
"""Move parked tasks from one :class:`ParkingLot` object to another.
This dequeues ``count`` tasks from one lot, and requeues them on
Expand Down
2 changes: 1 addition & 1 deletion trio/_dtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def challenge_for(


class _Queue(Generic[_T]):
def __init__(self, incoming_packets_buffer: int | float):
def __init__(self, incoming_packets_buffer: int | float): # noqa: PYI041
self.s, self.r = trio.open_memory_channel[_T](incoming_packets_buffer)


Expand Down
9 changes: 6 additions & 3 deletions trio/_highlevel_open_tcp_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
# so this is unnecessary -- we can just pass in "infinity" and get the maximum
# that way. (Verified on Windows, Linux, macOS using
# notes-to-self/measure-listen-backlog.py)
def _compute_backlog(backlog: int | float | None) -> int:
def _compute_backlog(backlog: int | float | None) -> int: # noqa: PYI041
# Many systems (Linux, BSDs, ...) store the backlog in a uint16 and are
# missing overflow protection, so we apply our own overflow protection.
# https://github.com/golang/go/issues/5030
Expand All @@ -57,7 +57,10 @@ def _compute_backlog(backlog: int | float | None) -> int:


async def open_tcp_listeners(
port: int, *, host: str | bytes | None = None, backlog: int | float | None = None
port: int,
*,
host: str | bytes | None = None,
backlog: int | float | None = None, # noqa: PYI041
) -> list[trio.SocketListener]:
"""Create :class:`SocketListener` objects to listen for TCP connections.
Expand Down Expand Up @@ -169,7 +172,7 @@ async def serve_tcp(
port: int,
*,
host: str | bytes | None = None,
backlog: int | float | None = None,
backlog: int | float | None = None, # noqa: PYI041
handler_nursery: trio.Nursery | None = None,
task_status: TaskStatus[list[trio.SocketListener]] = trio.TASK_STATUS_IGNORED,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions trio/_highlevel_ssl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def open_ssl_over_tcp_listeners(
*,
host: str | bytes | None = None,
https_compatible: bool = False,
backlog: int | float | None = None,
backlog: int | float | None = None, # noqa: PYI041
) -> list[trio.SSLListener]:
"""Start listening for SSL/TLS-encrypted TCP connections to the given port.
Expand Down Expand Up @@ -101,7 +101,7 @@ async def serve_ssl_over_tcp(
*,
host: str | bytes | None = None,
https_compatible: bool = False,
backlog: int | float | None = None,
backlog: int | float | None = None, # noqa: PYI041
handler_nursery: trio.Nursery | None = None,
task_status: trio.TaskStatus[list[trio.SSLListener]] = trio.TASK_STATUS_IGNORED,
) -> NoReturn:
Expand Down
4 changes: 2 additions & 2 deletions trio/_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ._util import NoPublicConstructor, final

if TYPE_CHECKING:
from typing_extensions import TypeAlias
from typing_extensions import Self, TypeAlias


# Only subscriptable in 3.9+
Expand Down Expand Up @@ -215,7 +215,7 @@ def returncode(self) -> int | None:
issue=1104,
instead="run_process or nursery.start(run_process, ...)",
)
async def __aenter__(self) -> Process:
async def __aenter__(self) -> Self:
return self

@deprecated(
Expand Down
4 changes: 2 additions & 2 deletions trio/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class CapacityLimiter(AsyncContextManagerMixin):
"""

# total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
def __init__(self, total_tokens: int | float):
def __init__(self, total_tokens: int | float): # noqa: PYI041
self._lot = ParkingLot()
self._borrowers: set[Task | object] = set()
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
Expand Down Expand Up @@ -245,7 +245,7 @@ def total_tokens(self) -> int | float:
return self._total_tokens

@total_tokens.setter
def total_tokens(self, new_total_tokens: int | float) -> None:
def total_tokens(self, new_total_tokens: int | float) -> None: # noqa: PYI041
if not isinstance(new_total_tokens, int) and new_total_tokens != math.inf:
raise TypeError("total_tokens must be an int or math.inf")
if new_total_tokens < 1:
Expand Down
16 changes: 8 additions & 8 deletions trio/_tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@
str,
Union[Tuple[str, int], Tuple[str, int, int, int]],
]
getaddrinfoResponse: TypeAlias = List[GaiTuple]
GetAddrInfoResponse: TypeAlias = List[GaiTuple]
else:
GaiTuple: object
getaddrinfoResponse = object
GetAddrInfoResponse = object

################################################################
# utils
################################################################


class MonkeypatchedGAI:
def __init__(self, orig_getaddrinfo: Callable[..., getaddrinfoResponse]):
def __init__(self, orig_getaddrinfo: Callable[..., GetAddrInfoResponse]):
self._orig_getaddrinfo = orig_getaddrinfo
self._responses: dict[tuple[Any, ...], getaddrinfoResponse | str] = {}
self._responses: dict[tuple[Any, ...], GetAddrInfoResponse | str] = {}
self.record: list[tuple[Any, ...]] = []

# get a normalized getaddrinfo argument tuple
Expand All @@ -54,11 +54,11 @@ def _frozenbind(self, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
return frozenbound

def set(
self, response: getaddrinfoResponse | str, *args: Any, **kwargs: Any
self, response: GetAddrInfoResponse | str, *args: Any, **kwargs: Any
) -> None:
self._responses[self._frozenbind(*args, **kwargs)] = response

def getaddrinfo(self, *args: Any, **kwargs: Any) -> getaddrinfoResponse | str:
def getaddrinfo(self, *args: Any, **kwargs: Any) -> GetAddrInfoResponse | str:
bound = self._frozenbind(*args, **kwargs)
self.record.append(bound)
if bound in self._responses:
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_socket_has_some_reexports() -> None:


async def test_getaddrinfo(monkeygai: MonkeypatchedGAI) -> None:
def check(got: getaddrinfoResponse, expected: getaddrinfoResponse) -> None:
def check(got: GetAddrInfoResponse, expected: GetAddrInfoResponse) -> None:
# win32 returns 0 for the proto field
# musl and glibc have inconsistent handling of the canonical name
# field (https://github.com/python-trio/trio/issues/1499)
Expand All @@ -137,7 +137,7 @@ def interesting_fields(
return (family, type, sockaddr)

def filtered(
gai_list: getaddrinfoResponse,
gai_list: GetAddrInfoResponse,
) -> list[
tuple[
AddressFamily,
Expand Down

0 comments on commit 6f9ab95

Please sign in to comment.