Skip to content

Commit

Permalink
Merge pull request #562 from KeepSafe/signals-v2
Browse files Browse the repository at this point in the history
Signals v2
  • Loading branch information
asvetlov committed Oct 16, 2015
2 parents e2eceae + 0c32f47 commit 6dd0eb7
Show file tree
Hide file tree
Showing 11 changed files with 288 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ vtest: flake .develop
cov cover coverage:
tox

cov-dev: flake develop
cov-dev: develop
@coverage erase
@coverage run -m pytest -s tests
@mv .coverage .coverage.accel
Expand Down
71 changes: 71 additions & 0 deletions aiohttp/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import asyncio
from itertools import count


class BaseSignal(list):

@asyncio.coroutine
def _send(self, *args, **kwargs):
for receiver in self:
res = receiver(*args, **kwargs)
if asyncio.iscoroutine(res) or isinstance(res, asyncio.Future):
yield from res

def copy(self):
raise NotImplementedError("copy() is forbidden")

def sort(self):
raise NotImplementedError("sort() is forbidden")


class Signal(BaseSignal):
"""Coroutine-based signal implementation.
To connect a callback to a signal, use any list method.
Signals are fired using the :meth:`send` coroutine, which takes named
arguments.
"""

def __init__(self, app):
super().__init__()
self._app = app
klass = self.__class__
self._name = klass.__module__ + ':' + klass.__qualname__
self._pre = app.on_pre_signal
self._post = app.on_post_signal

@asyncio.coroutine
def send(self, *args, **kwargs):
"""
Sends data to all registered receivers.
"""
ordinal = None
debug = self._app._debug
if debug:
ordinal = self._pre.ordinal()
yield from self._pre.send(ordinal, self._name, *args, **kwargs)
yield from self._send(*args, **kwargs)
if debug:
yield from self._post.send(ordinal, self._name, *args, **kwargs)


class DebugSignal(BaseSignal):

@asyncio.coroutine
def send(self, ordinal, name, *args, **kwargs):
yield from self._send(ordinal, name, *args, **kwargs)


class PreSignal(DebugSignal):

def __init__(self):
super().__init__()
self._counter = count(1)

def ordinal(self):
return next(self._counter)


class PostSignal(DebugSignal):
pass
24 changes: 23 additions & 1 deletion aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .web_urldispatcher import * # noqa
from .web_ws import * # noqa
from .protocol import HttpVersion # noqa
from .signals import Signal, PreSignal, PostSignal


import asyncio
Expand Down Expand Up @@ -179,13 +180,14 @@ class Application(dict):

def __init__(self, *, logger=web_logger, loop=None,
router=None, handler_factory=RequestHandlerFactory,
middlewares=()):
middlewares=(), debug=False):
if loop is None:
loop = asyncio.get_event_loop()
if router is None:
router = UrlDispatcher()
assert isinstance(router, AbstractRouter), router

self._debug = debug
self._router = router
self._handler_factory = handler_factory
self._finish_callbacks = []
Expand All @@ -196,6 +198,26 @@ def __init__(self, *, logger=web_logger, loop=None,
assert asyncio.iscoroutinefunction(factory), factory
self._middlewares = list(middlewares)

self._on_pre_signal = PreSignal()
self._on_post_signal = PostSignal()
self._on_response_prepare = Signal(self)

@property
def debug(self):
return self._debug

@property
def on_response_prepare(self):
return self._on_response_prepare

@property
def on_pre_signal(self):
return self._on_pre_signal

@property
def on_post_signal(self):
return self._on_post_signal

@property
def router(self):
return self._router
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/web_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,8 @@ def prepare(self, request):
resp_impl = self._start_pre_check(request)
if resp_impl is not None:
return resp_impl
yield from request.app.on_response_prepare.send(request=request,
response=self)

return self._start(request)

Expand Down
8 changes: 8 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@ aiohttp.protocol module
:undoc-members:
:show-inheritance:

aiohttp.signals module
----------------------

.. automodule:: aiohttp.signals
:members:
:undoc-members:
:show-inheritance:

aiohttp.streams module
----------------------

Expand Down
14 changes: 14 additions & 0 deletions docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,10 @@ StreamResponse

Use :meth:`prepare` instead.

.. warning:: The method doesn't call
:attr:`web.Application.on_response_prepare` signal, use
:meth:`prepare` instead.

.. coroutinemethod:: prepare(request)

:param aiohttp.web.Request request: HTTP request object, that the
Expand All @@ -568,6 +572,9 @@ StreamResponse
Send *HTTP header*. You should not change any header data after
calling this method.

The coroutine calls :attr:`web.Application.on_response_prepare`
signal handlers.

.. versionadded:: 0.18

.. method:: write(data)
Expand Down Expand Up @@ -920,6 +927,13 @@ arbitrary properties for later access from

:ref:`event loop<asyncio-event-loop>` used for processing HTTP requests.

.. attribute:: on_response_prepare

A :class:`~aiohttp.signals.Signal` that is fired at the beginning
of :meth:`StreamResponse.prepare` with parameters *request* and
*response*. It can be used, for example, to add custom headers to each
response before sending.

.. method:: make_handler(**kwargs)

Creates HTTP protocol factory for handling requests.
Expand Down
145 changes: 145 additions & 0 deletions tests/test_signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import asyncio
from unittest import mock
from aiohttp.multidict import CIMultiDict
from aiohttp.signals import Signal
from aiohttp.web import Application
from aiohttp.web import Request, Response
from aiohttp.protocol import HttpVersion11
from aiohttp.protocol import RawRequestMessage

import pytest


@pytest.fixture
def app(loop):
return Application(loop=loop)


@pytest.fixture
def debug_app(loop):
return Application(loop=loop, debug=True)


def make_request(app, method, path, headers=CIMultiDict()):
message = RawRequestMessage(method, path, HttpVersion11, headers,
False, False)
return request_from_message(message, app)


def request_from_message(message, app):
payload = mock.Mock()
transport = mock.Mock()
reader = mock.Mock()
writer = mock.Mock()
req = Request(app, message, payload,
transport, reader, writer)
return req


def test_add_response_prepare_signal_handler(loop, app):
callback = asyncio.coroutine(lambda request, response: None)
app.on_response_prepare.append(callback)


def test_add_signal_handler_not_a_callable(loop, app):
callback = True
app.on_response_prepare.append(callback)
with pytest.raises(TypeError):
app.on_response_prepare(None, None)


def test_function_signal_dispatch(loop, app):
signal = Signal(app)
kwargs = {'foo': 1, 'bar': 2}

callback_mock = mock.Mock()

@asyncio.coroutine
def callback(**kwargs):
callback_mock(**kwargs)

signal.append(callback)

loop.run_until_complete(signal.send(**kwargs))
callback_mock.assert_called_once_with(**kwargs)


def test_function_signal_dispatch2(loop, app):
signal = Signal(app)
args = {'a', 'b'}
kwargs = {'foo': 1, 'bar': 2}

callback_mock = mock.Mock()

@asyncio.coroutine
def callback(*args, **kwargs):
callback_mock(*args, **kwargs)

signal.append(callback)

loop.run_until_complete(signal.send(*args, **kwargs))
callback_mock.assert_called_once_with(*args, **kwargs)


def test_response_prepare(loop, app):
callback = mock.Mock()

@asyncio.coroutine
def cb(*args, **kwargs):
callback(*args, **kwargs)

app.on_response_prepare.append(cb)

request = make_request(app, 'GET', '/')
response = Response(body=b'')
loop.run_until_complete(response.prepare(request))

callback.assert_called_once_with(request=request,
response=response)


def test_non_coroutine(loop, app):
signal = Signal(app)
kwargs = {'foo': 1, 'bar': 2}

callback = mock.Mock()

signal.append(callback)

loop.run_until_complete(signal.send(**kwargs))
callback.assert_called_once_with(**kwargs)


def test_copy_forbidden(app):
signal = Signal(app)
with pytest.raises(NotImplementedError):
signal.copy()


def test_sort_forbidden(app):
l1 = lambda: None
l2 = lambda: None
l3 = lambda: None
signal = Signal(app)
signal.extend([l1, l2, l3])
with pytest.raises(NotImplementedError):
signal.sort()
assert signal == [l1, l2, l3]


def test_debug_signal(loop, debug_app):
assert debug_app.debug, "Should be True"
signal = Signal(debug_app)

callback = mock.Mock()
pre = mock.Mock()
post = mock.Mock()

signal.append(callback)
debug_app.on_pre_signal.append(pre)
debug_app.on_post_signal.append(post)

loop.run_until_complete(signal.send(1, a=2))
callback.assert_called_once_with(1, a=2)
pre.assert_called_once_with(1, 'aiohttp.signals:Signal', 1, a=2)
post.assert_called_once_with(1, 'aiohttp.signals:Signal', 1, a=2)
4 changes: 3 additions & 1 deletion tests/test_web_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from aiohttp.web import Request
from aiohttp.protocol import RawRequestMessage, HttpVersion11

from aiohttp import web
from aiohttp import signals, web


class TestHTTPExceptions(unittest.TestCase):
Expand All @@ -32,6 +32,8 @@ def append(self, data):

def make_request(self, method='GET', path='/', headers=CIMultiDict()):
self.app = mock.Mock()
self.app._debug = False
self.app.on_response_prepare = signals.Signal(self.app)
message = RawRequestMessage(method, path, HttpVersion11, headers,
False, False)
req = Request(self.app, message, self.payload,
Expand Down
3 changes: 3 additions & 0 deletions tests/test_web_request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import unittest
from unittest import mock
from aiohttp.signals import Signal
from aiohttp.web import Request
from aiohttp.multidict import MultiDict, CIMultiDict
from aiohttp.protocol import HttpVersion
Expand All @@ -23,6 +24,8 @@ def make_request(self, method, path, headers=CIMultiDict(), *,
if version < HttpVersion(1, 1):
closing = True
self.app = mock.Mock()
self.app._debug = False
self.app.on_response_prepare = Signal(self.app)
message = RawRequestMessage(method, path, version, headers, closing,
False)
self.payload = mock.Mock()
Expand Down
Loading

0 comments on commit 6dd0eb7

Please sign in to comment.