From b93ed9e2aa83401ed6742be2bb11afbb8f06d6c2 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Tue, 5 Feb 2019 00:26:40 -0800 Subject: [PATCH 1/9] Permit referring to some generic types in generic ways Specifically, declare `SendChannel`, `ReceiveChannel`, `Listener`, and `RunVar` to be generic in one type parameter, and also support the `open_memory_channel[T](bufsize)` syntax at runtime. Until trio is able to support typing directly, this change allows users of external stubs to use correctly-typed channels without too many hacks. --- trio/_abc.py | 27 ++++++++++++++++---------- trio/_channel.py | 2 ++ trio/_core/_local.py | 6 ++++-- trio/_core/tests/test_local.py | 2 +- trio/_util.py | 35 ++++++++++++++++++++++++++++++++-- trio/tests/test_util.py | 18 ++++++++++++++++- 6 files changed, 74 insertions(+), 16 deletions(-) diff --git a/trio/_abc.py b/trio/_abc.py index c9957f1c5c..300a511fa8 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -1,4 +1,5 @@ from abc import ABCMeta, abstractmethod +from typing import Generic, TypeVar from ._util import aiter_compat from . import _core @@ -483,7 +484,13 @@ async def send_eof(self): """ -class Listener(AsyncResource): +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) +T_resource = TypeVar("T_resource", bound=AsyncResource, covariant=True) + + +class Listener(AsyncResource, Generic[T_resource]): """A standard interface for listening for incoming connections. :class:`Listener` objects also implement the :class:`AsyncResource` @@ -494,7 +501,7 @@ class Listener(AsyncResource): __slots__ = () @abstractmethod - async def accept(self): + async def accept(self) -> T_resource: """Wait until an incoming connection arrives, and then return it. Returns: @@ -521,7 +528,7 @@ async def accept(self): """ -class SendChannel(AsyncResource): +class SendChannel(AsyncResource, Generic[T_contra]): """A standard interface for sending Python objects to some receiver. :class:`SendChannel` objects also implement the :class:`AsyncResource` @@ -535,7 +542,7 @@ class SendChannel(AsyncResource): __slots__ = () @abstractmethod - def send_nowait(self, value): + def send_nowait(self, value: T_contra) -> None: """Attempt to send an object through the channel, without blocking. Args: @@ -553,7 +560,7 @@ def send_nowait(self, value): """ @abstractmethod - async def send(self, value): + async def send(self, value: T_contra) -> None: """Attempt to send an object through the channel, blocking if necessary. Args: @@ -570,7 +577,7 @@ async def send(self, value): """ @abstractmethod - def clone(self): + def clone(self: T) -> T: """Clone this send channel object. This returns a new :class:`SendChannel` object, which acts as a @@ -595,7 +602,7 @@ def clone(self): """ -class ReceiveChannel(AsyncResource): +class ReceiveChannel(AsyncResource, Generic[T_co]): """A standard interface for receiving Python objects from some sender. You can iterate over a :class:`ReceiveChannel` using an ``async for`` @@ -618,7 +625,7 @@ class ReceiveChannel(AsyncResource): __slots__ = () @abstractmethod - def receive_nowait(self): + def receive_nowait(self) -> T_co: """Attempt to receive an incoming object, without blocking. Returns: @@ -637,7 +644,7 @@ def receive_nowait(self): """ @abstractmethod - async def receive(self): + async def receive(self) -> T_co: """Attempt to receive an incoming object, blocking if necessary. It's legal for multiple tasks to call :meth:`receive` at the same @@ -658,7 +665,7 @@ async def receive(self): """ @abstractmethod - def clone(self): + def clone(self: T) -> T: """Clone this receive channel object. This returns a new :class:`ReceiveChannel` object, which acts as a diff --git a/trio/_channel.py b/trio/_channel.py index 87c2714a78..8b0a7d7426 100644 --- a/trio/_channel.py +++ b/trio/_channel.py @@ -6,8 +6,10 @@ from . import _core from .abc import SendChannel, ReceiveChannel +from ._util import generic_function +@generic_function def open_memory_channel(max_buffer_size): """Open a channel for passing objects between tasks within a process. diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 777273600f..07f9a07cb3 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,10 +1,12 @@ # Runvar implementations +from typing import Generic, TypeVar from . import _run __all__ = ["RunVar"] +T = TypeVar("T") -class _RunVarToken(object): +class _RunVarToken: _no_value = object() __slots__ = ("_var", "previous_value", "redeemed") @@ -19,7 +21,7 @@ def __init__(self, var, value): self.redeemed = False -class RunVar(object): +class RunVar(Generic[T]): """The run-local variant of a context variable. :class:`RunVar` objects are similar to context variable objects, diff --git a/trio/_core/tests/test_local.py b/trio/_core/tests/test_local.py index 7f403168ea..8cfd53bfaa 100644 --- a/trio/_core/tests/test_local.py +++ b/trio/_core/tests/test_local.py @@ -6,7 +6,7 @@ # scary runvar tests def test_runvar_smoketest(): t1 = _core.RunVar("test1") - t2 = _core.RunVar("test2", default="catfish") + t2 = _core.RunVar[str]("test2", default="catfish") assert "RunVar" in repr(t1) diff --git a/trio/_util.py b/trio/_util.py index bfe7138191..966a50fbf3 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -4,7 +4,7 @@ import signal import sys import pathlib -from functools import wraps +from functools import wraps, update_wrapper import typing as t import async_generator @@ -22,6 +22,7 @@ "ConflictDetector", "fixup_module_metadata", "fspath", + "indexable_function", ] # Equivalent to the C function raise(), which Python doesn't wrap @@ -177,7 +178,9 @@ def fix_one(obj): obj.__module__ = module_name if isinstance(obj, type): for attr_value in obj.__dict__.values(): - fix_one(attr_value) + # avoid infinite recursion when using typing.Generic + if attr_value is not obj: + fix_one(attr_value) for objname, obj in namespace.items(): if not objname.startswith("_"): # ignore private attributes @@ -242,3 +245,31 @@ def fspath(path) -> t.Union[str, bytes]: if hasattr(os, "fspath"): fspath = os.fspath # noqa + + +class generic_function: + """Decorator that makes a function indexable, to communicate + non-inferrable generic type parameters to a static type checker. + + If you write:: + + @generic_function + def open_memory_channel(max_buffer_size: int) -> Tuple[ + SendChannel[T], ReceiveChannel[T] + ]: ... + + it is valid at runtime to say ``open_memory_channel[bytes](5)``. + This behaves identically to ``open_memory_channel(5)`` at runtime, + and currently won't type-check without a mypy plugin or clever stubs, + but at least it becomes possible to write those. + """ + + def __init__(self, fn): + update_wrapper(self, fn) + self._fn = fn + + def __call__(self, *args, **kwargs): + return self._fn(*args, **kwargs) + + def __getitem__(self, _): + return self diff --git a/trio/tests/test_util.py b/trio/tests/test_util.py index cbfe6fd546..ea7b74deeb 100644 --- a/trio/tests/test_util.py +++ b/trio/tests/test_util.py @@ -7,7 +7,9 @@ from .. import _core from .._threads import run_sync_in_worker_thread -from .._util import signal_raise, ConflictDetector, fspath, is_main_thread +from .._util import ( + signal_raise, ConflictDetector, fspath, is_main_thread, generic_function +) from ..testing import wait_all_tasks_blocked, assert_checkpoints @@ -168,3 +170,17 @@ def not_main_thread(): assert not is_main_thread() await run_sync_in_worker_thread(not_main_thread) + + +def test_generic_function(): + @generic_function + def test_func(arg): + """Look, a docstring!""" + return arg + + assert test_func is test_func[int] is test_func[int, str] + assert test_func(42) == test_func[int](42) == 42 + assert test_func.__doc__ == "Look, a docstring!" + assert test_func.__qualname__ == "test_generic_function..test_func" + assert test_func.__name__ == "test_func" + assert test_func.__module__ == __name__ From ac9ce99011528efec2996cc6a788d907bd8025a8 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Tue, 5 Feb 2019 01:35:46 -0800 Subject: [PATCH 2/9] Remove annotations from ABC methods because they look ugly in Sphinx --- docs/source/reference-core.rst | 2 +- trio/_abc.py | 25 +++++++++++++++++-------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index a1d3a83389..9790cd61f9 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1221,7 +1221,7 @@ many cases, you just want to pass objects between different tasks inside a single process, and for that you can use :func:`trio.open_memory_channel`: -.. autofunction:: open_memory_channel +.. autofunction:: open_memory_channel(max_buffer_size) .. note:: If you've used the :mod:`threading` or :mod:`asyncio` modules, you may be familiar with :class:`queue.Queue` or diff --git a/trio/_abc.py b/trio/_abc.py index 300a511fa8..3ca395b7d5 100644 --- a/trio/_abc.py +++ b/trio/_abc.py @@ -484,9 +484,18 @@ async def send_eof(self): """ -T = TypeVar("T") +# The type of object produced by a ReceiveChannel (covariant because +# ReceiveChannel[Derived] can be passed to someone expecting +# ReceiveChannel[Base]) T_co = TypeVar("T_co", covariant=True) + +# The type of object accepted by a SendChannel (contravariant because +# SendChannel[Base] can be passed to someone expecting +# SendChannel[Derived]) T_contra = TypeVar("T_contra", contravariant=True) + +# The type of object produced by a Listener (covariant plus must be +# an AsyncResource) T_resource = TypeVar("T_resource", bound=AsyncResource, covariant=True) @@ -501,7 +510,7 @@ class Listener(AsyncResource, Generic[T_resource]): __slots__ = () @abstractmethod - async def accept(self) -> T_resource: + async def accept(self): """Wait until an incoming connection arrives, and then return it. Returns: @@ -542,7 +551,7 @@ class SendChannel(AsyncResource, Generic[T_contra]): __slots__ = () @abstractmethod - def send_nowait(self, value: T_contra) -> None: + def send_nowait(self, value): """Attempt to send an object through the channel, without blocking. Args: @@ -560,7 +569,7 @@ def send_nowait(self, value: T_contra) -> None: """ @abstractmethod - async def send(self, value: T_contra) -> None: + async def send(self, value): """Attempt to send an object through the channel, blocking if necessary. Args: @@ -577,7 +586,7 @@ async def send(self, value: T_contra) -> None: """ @abstractmethod - def clone(self: T) -> T: + def clone(self): """Clone this send channel object. This returns a new :class:`SendChannel` object, which acts as a @@ -625,7 +634,7 @@ class ReceiveChannel(AsyncResource, Generic[T_co]): __slots__ = () @abstractmethod - def receive_nowait(self) -> T_co: + def receive_nowait(self): """Attempt to receive an incoming object, without blocking. Returns: @@ -644,7 +653,7 @@ def receive_nowait(self) -> T_co: """ @abstractmethod - async def receive(self) -> T_co: + async def receive(self): """Attempt to receive an incoming object, blocking if necessary. It's legal for multiple tasks to call :meth:`receive` at the same @@ -665,7 +674,7 @@ async def receive(self) -> T_co: """ @abstractmethod - def clone(self: T) -> T: + def clone(self): """Clone this receive channel object. This returns a new :class:`ReceiveChannel` object, which acts as a From eff920e20ad84966923241b21d5cefd2d28fad15 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Tue, 5 Feb 2019 01:58:14 -0800 Subject: [PATCH 3/9] scattershot attempt to fix 3.5.0 CI --- trio/_core/_local.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 07f9a07cb3..88ea10579f 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -31,10 +31,10 @@ class RunVar(Generic[T]): """ _NO_DEFAULT = object() - __slots__ = ("_name", "_default") + __slots__ = ("_varname", "_default") def __init__(self, name, default=_NO_DEFAULT): - self._name = name + self._varname = name self._default = default def get(self, default=_NO_DEFAULT): @@ -97,4 +97,4 @@ def reset(self, token): token.redeemed = True def __repr__(self): - return ("".format(self._name)) + return ("".format(self._varname)) From 25e82dd42c567b1a20dc60eda2bbf15a2f77aae2 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Tue, 5 Feb 2019 10:13:34 -0800 Subject: [PATCH 4/9] Add newsfragment, revert RunVar changes --- newsfragments/908.feature.rst | 7 +++++++ trio/_core/_local.py | 12 +++++------- trio/_core/tests/test_local.py | 2 +- 3 files changed, 13 insertions(+), 8 deletions(-) create mode 100644 newsfragments/908.feature.rst diff --git a/newsfragments/908.feature.rst b/newsfragments/908.feature.rst new file mode 100644 index 0000000000..45379e753c --- /dev/null +++ b/newsfragments/908.feature.rst @@ -0,0 +1,7 @@ +:class:`~trio.abc.SendChannel`, :class:`~trio.abc.ReceiveChannel`, :class:`~trio.abc.Listener`, +and :func:`~trio.open_memory_channel` can now be referenced using a generic type parameter +(the type of object sent over the channel or produced by the listener) using PEP 484 syntax: +``trio.abc.SendChannel[bytes]``, ``trio.abc.Listener[trio.SocketStream]``, +``trio.open_memory_channel[MyMessage](5)``, etc. The added type information does not change +the runtime semantics, but permits better integration with external static type checkers. + diff --git a/trio/_core/_local.py b/trio/_core/_local.py index 88ea10579f..777273600f 100644 --- a/trio/_core/_local.py +++ b/trio/_core/_local.py @@ -1,12 +1,10 @@ # Runvar implementations -from typing import Generic, TypeVar from . import _run __all__ = ["RunVar"] -T = TypeVar("T") -class _RunVarToken: +class _RunVarToken(object): _no_value = object() __slots__ = ("_var", "previous_value", "redeemed") @@ -21,7 +19,7 @@ def __init__(self, var, value): self.redeemed = False -class RunVar(Generic[T]): +class RunVar(object): """The run-local variant of a context variable. :class:`RunVar` objects are similar to context variable objects, @@ -31,10 +29,10 @@ class RunVar(Generic[T]): """ _NO_DEFAULT = object() - __slots__ = ("_varname", "_default") + __slots__ = ("_name", "_default") def __init__(self, name, default=_NO_DEFAULT): - self._varname = name + self._name = name self._default = default def get(self, default=_NO_DEFAULT): @@ -97,4 +95,4 @@ def reset(self, token): token.redeemed = True def __repr__(self): - return ("".format(self._varname)) + return ("".format(self._name)) diff --git a/trio/_core/tests/test_local.py b/trio/_core/tests/test_local.py index 8cfd53bfaa..7f403168ea 100644 --- a/trio/_core/tests/test_local.py +++ b/trio/_core/tests/test_local.py @@ -6,7 +6,7 @@ # scary runvar tests def test_runvar_smoketest(): t1 = _core.RunVar("test1") - t2 = _core.RunVar[str]("test2", default="catfish") + t2 = _core.RunVar("test2", default="catfish") assert "RunVar" in repr(t1) From 77edcbbcc15a113ef9ff9a6e16a1381c6a1ef130 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Tue, 5 Feb 2019 10:18:25 -0800 Subject: [PATCH 5/9] add test of instantiating a SendChannel[T] subclass --- trio/tests/test_abc.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py index c8af0927d6..15945b1169 100644 --- a/trio/tests/test_abc.py +++ b/trio/tests/test_abc.py @@ -19,3 +19,15 @@ async def aclose(self): assert myar.record == [] assert myar.record == ["ac"] + + +def test_abc_generics(): + class SomeChannel(tabc.SendChannel[tabc.Stream]): + def send_nowait(self, value): raise RuntimeError + async def send(self, value): raise RuntimeError + def clone(self): raise RuntimeError + async def aclose(self): pass + + channel = SomeChannel() + with pytest.raises(RuntimeError): + channel.send_nowait(None) From 2b8d2b8b6845ef30354c8c98e8b2c827a4ab4937 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Tue, 5 Feb 2019 10:19:58 -0800 Subject: [PATCH 6/9] make it slotted --- trio/tests/test_abc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py index 15945b1169..a22de9fb0a 100644 --- a/trio/tests/test_abc.py +++ b/trio/tests/test_abc.py @@ -22,12 +22,13 @@ async def aclose(self): def test_abc_generics(): - class SomeChannel(tabc.SendChannel[tabc.Stream]): + class SlottedChannel(tabc.SendChannel[tabc.Stream]): + __slots__ = () def send_nowait(self, value): raise RuntimeError async def send(self, value): raise RuntimeError def clone(self): raise RuntimeError async def aclose(self): pass - channel = SomeChannel() + channel = SlottedChannel() with pytest.raises(RuntimeError): channel.send_nowait(None) From 22747c293aeae4f525ebd44cf3dc5e4d30cc7b54 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Tue, 5 Feb 2019 10:36:03 -0800 Subject: [PATCH 7/9] yapf --- trio/tests/test_abc.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py index a22de9fb0a..e1f1db7e3d 100644 --- a/trio/tests/test_abc.py +++ b/trio/tests/test_abc.py @@ -24,10 +24,18 @@ async def aclose(self): def test_abc_generics(): class SlottedChannel(tabc.SendChannel[tabc.Stream]): __slots__ = () - def send_nowait(self, value): raise RuntimeError - async def send(self, value): raise RuntimeError - def clone(self): raise RuntimeError - async def aclose(self): pass + + def send_nowait(self, value): + raise RuntimeError + + async def send(self, value): + raise RuntimeError + + def clone(self): + raise RuntimeError + + async def aclose(self): + pass channel = SlottedChannel() with pytest.raises(RuntimeError): From 6fda34354f5c29d07a6ce851ab4873fe46be12d8 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Tue, 5 Feb 2019 11:37:39 -0800 Subject: [PATCH 8/9] coverage --- trio/tests/test_abc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py index e1f1db7e3d..984c725824 100644 --- a/trio/tests/test_abc.py +++ b/trio/tests/test_abc.py @@ -23,19 +23,19 @@ async def aclose(self): def test_abc_generics(): class SlottedChannel(tabc.SendChannel[tabc.Stream]): - __slots__ = () + __slots__ = ("x",) def send_nowait(self, value): raise RuntimeError async def send(self, value): - raise RuntimeError + raise RuntimeError # pragma: no cover def clone(self): - raise RuntimeError + raise RuntimeError # pragma: no cover async def aclose(self): - pass + pass # pragma: no cover channel = SlottedChannel() with pytest.raises(RuntimeError): From 7c2d2d675041b270ee2ea83e7851e26fa65a74b0 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Wed, 6 Feb 2019 11:55:11 -0800 Subject: [PATCH 9/9] CR comments --- trio/_highlevel_socket.py | 2 +- trio/_ssl.py | 2 +- trio/_util.py | 14 ++++++++++---- trio/tests/test_abc.py | 7 +++++++ 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/trio/_highlevel_socket.py b/trio/_highlevel_socket.py index b6219c0bf9..aadd766880 100644 --- a/trio/_highlevel_socket.py +++ b/trio/_highlevel_socket.py @@ -318,7 +318,7 @@ def getsockopt(self, level, option, buffersize=0): pass -class SocketListener(Listener): +class SocketListener(Listener[SocketStream]): """A :class:`~trio.abc.Listener` that uses a listening socket to accept incoming connections as :class:`SocketStream` objects. diff --git a/trio/_ssl.py b/trio/_ssl.py index 2a3204bd3a..6f62121ccc 100644 --- a/trio/_ssl.py +++ b/trio/_ssl.py @@ -827,7 +827,7 @@ async def wait_send_all_might_not_block(self): await self.transport_stream.wait_send_all_might_not_block() -class SSLListener(Listener): +class SSLListener(Listener[SSLStream]): """A :class:`~trio.abc.Listener` for SSL/TLS-encrypted servers. :class:`SSLListener` wraps around another Listener, and converts diff --git a/trio/_util.py b/trio/_util.py index 966a50fbf3..d8dffa5628 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -22,7 +22,7 @@ "ConflictDetector", "fixup_module_metadata", "fspath", - "indexable_function", + "generic_function", ] # Equivalent to the C function raise(), which Python doesn't wrap @@ -172,15 +172,21 @@ def decorator(func): def fixup_module_metadata(module_name, namespace): + seen_ids = set() + def fix_one(obj): + # avoid infinite recursion (relevant when using + # typing.Generic, for example) + if id(obj) in seen_ids: + return + seen_ids.add(id(obj)) + mod = getattr(obj, "__module__", None) if mod is not None and mod.startswith("trio."): obj.__module__ = module_name if isinstance(obj, type): for attr_value in obj.__dict__.values(): - # avoid infinite recursion when using typing.Generic - if attr_value is not obj: - fix_one(attr_value) + fix_one(attr_value) for objname, obj in namespace.items(): if not objname.startswith("_"): # ignore private attributes diff --git a/trio/tests/test_abc.py b/trio/tests/test_abc.py index 984c725824..b55267b24f 100644 --- a/trio/tests/test_abc.py +++ b/trio/tests/test_abc.py @@ -22,6 +22,13 @@ async def aclose(self): def test_abc_generics(): + # Pythons below 3.5.2 had a typing.Generic that would throw + # errors when instantiating or subclassing a parameterized + # version of a class with any __slots__. This is why RunVar + # (which has slots) is not generic. This tests that + # the generic ABCs are fine, because while they are slotted + # they don't actually define any slots. + class SlottedChannel(tabc.SendChannel[tabc.Stream]): __slots__ = ("x",)