Skip to content

Commit

Permalink
Simpler implementation of awaitmethod
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Oct 18, 2023
1 parent 2119ac4 commit bd6e58f
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 17 deletions.
38 changes: 31 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ async code blocked. If the code tries to access the event loop, e.g. by creatin
The `syncfunction()` decorator can be used to automatically wrap an async function
so that it is executed using `await_sync()`:

```python
```pycon
>>> @asynkit.syncfunction
... async def sync_function():
... async def async_function():
Expand Down Expand Up @@ -207,6 +207,7 @@ async def agen():
for v in range(3):
yield v


assert list(aiter_sync(agen())) == [1, 2, 3]
```

Expand Down Expand Up @@ -279,18 +280,14 @@ async def main():
This is similar to `contextvars.Context.run()` but works for async functions. This function is
implemented using [`CoroStart`](#corostart)

## `coro_iter()`

This helper function turns a coroutine function into an iterator. It is primarily
intended to be used by the [`awaitmethod()`](#awaitmethod) function decorator.

## `awaitmethod()`

This decorator turns the decorated method into a `Generator` as required for
`__await__` methods, which must only return `Iterator` objects.
It does so by invoking the `coro_iter()` helper.

This makes it simple to make a class _awaitable_ by decorating an `async`
This makes it simple to make a class instance _awaitable_ by decorating an `async`
`__await__()` method.

```python
Expand Down Expand Up @@ -319,6 +316,31 @@ asyncio.run(main())
```
Unlike a regular _coroutine_ (the result of calling a _coroutine function_), an object with an `__await__` method can potentially be awaited multiple times.

The method can also be a `classmethod` or `staticmethod:`
```python
class Constructor:
@staticmethod
@asynkit.awaitmethod
async def __await__():
await asyncio.sleep(0)
return Constructor()


async def construct():
return await Constructor
```

## `awaitmethod_iter()`

An alternative way of creating an __await__ method, it uses
the `coro_iter()` method to to create a coroutine iterator. It
is provided for completeness.

## `coro_iter()`

This helper function turns a coroutine function into an iterator. It is primarily
intended to be used by the [`awaitmethod_iter()`](#awaitmethod_iter) function decorator.

## `Monitor`

A `Monitor` object can be used to await a coroutine, while listening for _out of band_ messages
Expand Down Expand Up @@ -360,6 +382,7 @@ async def coro(m):
await m.oob("foo")
return "bar"


m = Monitor()
b = m(coro(m))
try:
Expand Down Expand Up @@ -434,6 +457,7 @@ async def stateful_parser(monitor, input_data):
# continue parsing, maye requesting more data
return await parsed_data(monitor, input_data)


m: Monitor[Tuple[Any, bytes]] = Monitor()
initial_data = b""
p = m(stateful_parser(m, b""))
Expand Down Expand Up @@ -766,6 +790,7 @@ be) and therefore it cannot cause collisions with other interrupts.
async def test():
async def task():
await asyncio.sleep(1)

create_pytask(task)
await asyncio.sleep(0)
assert task_is_blocked(task)
Expand All @@ -777,7 +802,6 @@ async def test():
pass
else:
assert False, "never happens"

```

### `create_pytask()`
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ black = "black ."
test = "pytest tests examples"
cov = "pytest --cov=asynkit --cov-report term-missing --cov-branch"
typing = "mypy -p asynkit -p tests -p examples"
blacken-docs = "blacken-docs README.md"
blackall = ["black", "blacken-docs"]
check = ["style", "lint", "typing", "cov"]

[tool.poe.tasks.style]
help = "Validate black code style"
cmd = "black . --check --diff"

[tool.poetry.dependencies]
python = "^3.8"
Expand Down
26 changes: 21 additions & 5 deletions src/asynkit/coroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
__all__ = [
"CoroStart",
"awaitmethod",
"awaitmethod_iter",
"coro_await",
"coro_eager",
"func_eager",
Expand Down Expand Up @@ -483,18 +484,33 @@ def coro_iter(coro: Coroutine[Any, Any, T]) -> Generator[Any, Any, T]:
return cast(T, exc.value)


def awaitmethod(
func: Callable[[S], Coroutine[Any, Any, T]]
) -> Callable[[S], Iterator[T]]:
def awaitmethod(func: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, Iterator[T]]:
"""
Decorator to make a function return an awaitable.
The function must be a coroutine function.
Specifically intended to be used for __await__ methods.
Can also be used for class or static methods.
"""

@functools.wraps(func)
def wrapper(self: S) -> Iterator[T]:
return coro_iter(func(self))
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[T]:
return func(*args, **kwargs).__await__()

return wrapper


def awaitmethod_iter(
func: Callable[P, Coroutine[Any, Any, T]]
) -> Callable[P, Iterator[T]]:
"""
Same as above, but implemented using the coro_iter helper.
Only included for completeness, it is better to use the
builtin coroutine.__await__() method.
"""

@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Iterator[T]:
return coro_iter(func(*args, **kwargs))

return wrapper

Expand Down
48 changes: 43 additions & 5 deletions tests/test_coro.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,29 +854,59 @@ def __init__(self, coro, args=None):
self.args = args or ["bar1", "bar2"]

def __await__(self):
# manually create a coroutine object
return asynkit.coro_iter(self.coro(self.args.pop(0)))

class Awaiter2(Awaiter):
"""
Test the awaitable decorator
Test the awaitmethod decorator
"""

@asynkit.awaitmethod_iter
async def __await__(self):
return await self.coro(self.args.pop(0))

class Awaiter3(Awaiter):
"""
Test the awaitmethod decorator
"""

@asynkit.awaitmethod
async def __await__(self):
return await self.coro(self.args.pop(0))

@pytest.mark.parametrize("awaiter", [Awaiter, Awaiter2])
class Awaiter4:
"""
Test the awaitmethod classmethod
"""

@classmethod
@asynkit.awaitmethod
async def __await__(cls) -> str:
return "Awaiter4"

class Awaiter5:
"""
Test the awaitmethod staticmethod
"""

@staticmethod
@asynkit.awaitmethod
async def __await__() -> str:
return "Awaiter5"

@pytest.mark.parametrize("awaiter", [Awaiter, Awaiter2, Awaiter3])
async def test_await(self, awaiter):
a = awaiter(self.coroutine1, ["bar1"])
assert await a == "foobar1"

@pytest.mark.parametrize("awaiter", [Awaiter, Awaiter2])
@pytest.mark.parametrize("awaiter", [Awaiter, Awaiter2, Awaiter3])
async def test_await_again(self, awaiter):
a = awaiter(self.coroutine1, ["bar2", "bar3"])
assert await a == "foobar2"
assert await a == "foobar3" # it can be awaited again

@pytest.mark.parametrize("awaiter", [Awaiter, Awaiter2])
@pytest.mark.parametrize("awaiter", [Awaiter, Awaiter2, Awaiter3])
async def test_await_exception(self, awaiter):
a = awaiter(self.coroutine2)
with pytest.raises(RuntimeError) as err:
Expand All @@ -886,7 +916,7 @@ async def test_await_exception(self, awaiter):
await a
assert err.value.args[0] == "foobar2"

@pytest.mark.parametrize("awaiter", [Awaiter, Awaiter2])
@pytest.mark.parametrize("awaiter", [Awaiter, Awaiter2, Awaiter3])
async def test_await_immediate(self, awaiter):
async def coroutine(arg):
return "coro" + arg
Expand Down Expand Up @@ -940,6 +970,14 @@ def helper():
assert err.value.value == "foo"
c.close()

@pytest.mark.parametrize("awaiter", [Awaiter4, Awaiter5])
async def test_await_static(self, awaiter):
a = awaiter()
if awaiter == self.Awaiter4:
assert await a == "Awaiter4"
else:
assert await a == "Awaiter5"


async def test_async_function():
def sync_method():
Expand Down

0 comments on commit bd6e58f

Please sign in to comment.