diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 0b9b31aa9ea..b4b4533b6bf 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -154,6 +154,9 @@ def send_bytes(self, data): type(data)) self._writer.send(data, binary=True) + def send_json(self, data, *, dumps=json.dumps): + self.send_str(dumps(data)) + @asyncio.coroutine def write_eof(self): if self._eof_sent: @@ -280,12 +283,8 @@ def receive_bytes(self): @asyncio.coroutine def receive_json(self, *, loads=json.loads): - msg = yield from self.receive() - if msg.tp != MsgType.text: - raise TypeError( - "Received message {}:{!r} is not str".format(msg.tp, msg.data) - ) - return msg.json(loads=loads) + data = yield from self.receive_str() + return loads(data) def write(self, data): raise RuntimeError("Cannot call .write() for websocket") diff --git a/aiohttp/websocket_client.py b/aiohttp/websocket_client.py index e3e8c12a94a..c197c9b2a03 100644 --- a/aiohttp/websocket_client.py +++ b/aiohttp/websocket_client.py @@ -3,6 +3,7 @@ import asyncio import sys +import json from enum import IntEnum from .websocket import Message @@ -88,6 +89,9 @@ def send_bytes(self, data): type(data)) self._writer.send(data, binary=True) + def send_json(self, data, *, dumps=json.dumps): + self.send_str(dumps(data)) + @asyncio.coroutine def close(self, *, code=1000, message=b''): if not self._closed: @@ -171,6 +175,28 @@ def receive(self): finally: self._waiting = False + @asyncio.coroutine + def receive_str(self): + msg = yield from self.receive() + if msg.tp != MsgType.text: + raise TypeError( + "Received message {}:{!r} is not str".format(msg.tp, msg.data)) + return msg.data + + @asyncio.coroutine + def receive_bytes(self): + msg = yield from self.receive() + if msg.tp != MsgType.binary: + raise TypeError( + "Received message {}:{!r} is not bytes".format(msg.tp, + msg.data)) + return msg.data + + @asyncio.coroutine + def receive_json(self, *, loads=json.loads): + data = yield from self.receive_str() + return loads(data) + if PY_35: @asyncio.coroutine def __aiter__(self): diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 83dd38960a0..ee9d9fde397 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1277,6 +1277,22 @@ manually. :raise TypeError: if data is not :class:`bytes`, :class:`bytearray` or :class:`memoryview`. + .. method:: send_json(data, *, dumps=json.loads) + + Send *data* to peer as JSON string. + + :param data: data to send. + + :param callable dumps: any :term:`callable` that accepts an object and + returns a JSON string + (:func:`json.dumps` by default). + + :raise RuntimeError: if connection is not started or closing + + :raise ValueError: if data is not serializable object + + :raise TypeError: if value returned by :term:`dumps` is not :class:`str` + .. comethod:: close(*, code=1000, message=b'') A :ref:`coroutine` that initiates closing handshake by sending @@ -1306,6 +1322,40 @@ manually. :return: :class:`~aiohttp.websocket.Message`, `tp` is types of `~aiohttp.MsgType` + .. coroutinemethod:: receive_str() + + A :ref:`coroutine` that calls :meth:`receive` but + also asserts the message type is + :const:`~aiohttp.websocket.MSG_TEXT`. + + :return str: peer's message content. + + :raise TypeError: if message is :const:`~aiohttp.websocket.MSG_BINARY`. + + .. coroutinemethod:: receive_bytes() + + A :ref:`coroutine` that calls :meth:`receive` but + also asserts the message type is + :const:`~aiohttp.websocket.MSG_BINARY`. + + :return bytes: peer's message content. + + :raise TypeError: if message is :const:`~aiohttp.websocket.MSG_TEXT`. + + .. coroutinemethod:: receive_json(*, loads=json.loads) + + A :ref:`coroutine` that calls :meth:`receive_str` and loads + the JSON string to a Python dict. + + :param callable loads: any :term:`callable` that accepts + :class:`str` and returns :class:`dict` + with parsed JSON (:func:`json.loads` by + default). + + :return dict: loaded JSON content + + :raise TypeError: if message is :const:`~aiohttp.websocket.MSG_BINARY`. + :raise ValueError: if message is not valid JSON. Utilities --------- diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 5d781fab6e1..93226f7761a 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -830,6 +830,22 @@ WebSocketResponse :raise TypeError: if data is not :class:`bytes`, :class:`bytearray` or :class:`memoryview`. + .. method:: send_json(data, *, dumps=json.loads) + + Send *data* to peer as JSON string. + + :param data: data to send. + + :param callable dumps: any :term:`callable` that accepts an object and + returns a JSON string + (:func:`json.dumps` by default). + + :raise RuntimeError: if connection is not started or closing + + :raise ValueError: if data is not serializable object + + :raise TypeError: if value returned by :term:`dumps` is not :class:`str` + .. coroutinemethod:: close(*, code=1000, message=b'') A :ref:`coroutine` that initiates closing @@ -888,9 +904,8 @@ WebSocketResponse .. coroutinemethod:: receive_json(*, loads=json.loads) - A :ref:`coroutine` that calls :meth:`receive`, asserts the - message type is :const:`~aiohttp.websocket.MSG_TEXT`, and loads the JSON - string to a Python dict. + A :ref:`coroutine` that calls :meth:`receive_str` and loads the + JSON string to a Python dict. :param callable loads: any :term:`callable` that accepts :class:`str` and returns :class:`dict` diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index cd53eada871..99c0f53a9b7 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -65,6 +65,11 @@ def test_nonstarted_send_bytes(self): with self.assertRaises(RuntimeError): ws.send_bytes(b'bytes') + def test_nonstarted_send_json(self): + ws = WebSocketResponse() + with self.assertRaises(RuntimeError): + ws.send_json({'type': 'json'}) + def test_nonstarted_close(self): ws = WebSocketResponse() with self.assertRaises(RuntimeError): @@ -90,6 +95,16 @@ def go(): self.loop.run_until_complete(go()) + def test_nonstarted_receive_json(self): + + @asyncio.coroutine + def go(): + ws = WebSocketResponse() + with self.assertRaises(RuntimeError): + yield from ws.receive_json() + + self.loop.run_until_complete(go()) + def test_receive_str_nonstring(self): @asyncio.coroutine @@ -142,6 +157,13 @@ def test_send_bytes_nonbytes(self): with self.assertRaises(TypeError): ws.send_bytes('string') + def test_send_json_nonjson(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + self.loop.run_until_complete(ws.prepare(req)) + with self.assertRaises(TypeError): + ws.send_json(set()) + def test_write(self): ws = WebSocketResponse() with self.assertRaises(RuntimeError): @@ -196,6 +218,14 @@ def test_send_bytes_closed(self): with self.assertRaises(RuntimeError): ws.send_bytes(b'bytes') + def test_send_json_closed(self): + req = self.make_request('GET', '/') + ws = WebSocketResponse() + self.loop.run_until_complete(ws.prepare(req)) + self.loop.run_until_complete(ws.close()) + with self.assertRaises(RuntimeError): + ws.send_json({'type': 'json'}) + def test_ping_closed(self): req = self.make_request('GET', '/') ws = WebSocketResponse() diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index d73e54165a9..44bf5c8d4c8 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -40,14 +40,12 @@ def test_websocket_json_invalid_message(create_app_and_client): def handler(request): ws = web.WebSocketResponse() yield from ws.prepare(request) - msg = yield from ws.receive() - try: - msg.json() + yield from ws.receive_json() except ValueError: - ws.send_str("ValueError raised: '%s'" % msg.data) + ws.send_str('ValueError was raised') else: - raise Exception("No ValueError was raised") + raise Exception('No Exception') finally: yield from ws.close() return ws @@ -59,8 +57,32 @@ def handler(request): payload = 'NOT A VALID JSON STRING' ws.send_str(payload) - resp = yield from ws.receive() - assert payload in resp.data + data = yield from ws.receive_str() + assert 'ValueError was raised' in data + + +@pytest.mark.run_loop +def test_websocket_send_json(create_app_and_client): + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + data = yield from ws.receive_json() + ws.send_json(data) + + yield from ws.close() + return ws + + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', handler) + + ws = yield from client.ws_connect('/') + expected_value = 'value' + ws.send_json({'test': expected_value}) + + data = yield from ws.receive_json() + assert data['test'] == expected_value @pytest.mark.run_loop diff --git a/tests/test_web_websocket_functional_oldstyle.py b/tests/test_web_websocket_functional_oldstyle.py index 41f35cfe352..be6ac2b196b 100644 --- a/tests/test_web_websocket_functional_oldstyle.py +++ b/tests/test_web_websocket_functional_oldstyle.py @@ -117,6 +117,41 @@ def go(): self.loop.run_until_complete(go()) + def test_send_recv_json(self): + closed = helpers.create_future(self.loop) + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + data = yield from ws.receive_json() + ws.send_json({'response': data['request']}) + yield from ws.close() + closed.set_result(1) + return ws + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp, reader, writer = yield from self.connect_ws(url) + writer.send('{"request": "test"}') + msg = yield from reader.read() + data = msg.json() + self.assertEqual(msg.tp, websocket.MSG_TEXT) + self.assertEqual(data['response'], 'test') + + msg = yield from reader.read() + self.assertEqual(msg.tp, websocket.MSG_CLOSE) + self.assertEqual(msg.data, 1000) + self.assertEqual(msg.extra, '') + + writer.close() + + yield from closed + resp.close() + + self.loop.run_until_complete(go()) + def test_auto_pong_with_closing_by_peer(self): closed = helpers.create_future(self.loop) diff --git a/tests/test_websocket_client.py b/tests/test_websocket_client.py index 3aeb0223cf5..933676619e4 100644 --- a/tests/test_websocket_client.py +++ b/tests/test_websocket_client.py @@ -334,6 +334,7 @@ def test_send_data_after_close(self, m_req, m_os, WebSocketWriter): self.assertRaises(RuntimeError, resp.pong) self.assertRaises(RuntimeError, resp.send_str, 's') self.assertRaises(RuntimeError, resp.send_bytes, b'b') + self.assertRaises(RuntimeError, resp.send_json, {}) @mock.patch('aiohttp.client.WebSocketWriter') @mock.patch('aiohttp.client.os') @@ -357,6 +358,7 @@ def test_send_data_type_errors(self, m_req, m_os, WebSocketWriter): self.assertRaises(TypeError, resp.send_str, b's') self.assertRaises(TypeError, resp.send_bytes, 'b') + self.assertRaises(TypeError, resp.send_json, set()) @mock.patch('aiohttp.client.WebSocketWriter') @mock.patch('aiohttp.client.os') diff --git a/tests/test_websocket_client_functional.py b/tests/test_websocket_client_functional.py index c6992356543..980c28d4308 100644 --- a/tests/test_websocket_client_functional.py +++ b/tests/test_websocket_client_functional.py @@ -22,8 +22,8 @@ def handler(request): resp = yield from client.ws_connect('/') resp.send_str('ask') - msg = yield from resp.receive() - assert msg.data == 'ask/answer' + data = yield from resp.receive_str() + assert data == 'ask/answer' yield from resp.close() @@ -46,9 +46,33 @@ def handler(request): resp.send_bytes(b'ask') - msg = yield from resp.receive() - assert msg.data == b'ask/answer' + data = yield from resp.receive_bytes() + assert data == b'ask/answer' + + yield from resp.close() + + +@pytest.mark.run_loop +def test_send_recv_json(create_app_and_client): + + @asyncio.coroutine + def handler(request): + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + data = yield from ws.receive_json() + ws.send_json({'response': data['request']}) + yield from ws.close() + return ws + + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', handler) + resp = yield from client.ws_connect('/') + payload = {'request': 'test'} + resp.send_json(payload) + data = yield from resp.receive_json() + assert data['response'] == payload['request'] yield from resp.close()