From f1410dfc63d06a8af675c001d6c7a6ea4758cbf1 Mon Sep 17 00:00:00 2001 From: Bill Kalantzis Date: Thu, 2 Jun 2016 14:08:39 -0700 Subject: [PATCH] added headers to ClientSession.ws_connnect #785 --- aiohttp/client.py | 19 +++++++--- tests/test_websocket_client_functional.py | 44 ++++++++++++++++++++++- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/aiohttp/client.py b/aiohttp/client.py index 67329103e51..bdbb3eb3db8 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -250,7 +250,8 @@ def ws_connect(self, url, *, autoclose=True, autoping=True, auth=None, - origin=None): + origin=None, + headers=None): """Initiate websocket connection.""" return _WSRequestContextManager( self._ws_connect(url, @@ -259,7 +260,8 @@ def ws_connect(self, url, *, autoclose=autoclose, autoping=autoping, auth=auth, - origin=origin)) + origin=origin, + headers=headers)) @asyncio.coroutine def _ws_connect(self, url, *, @@ -268,16 +270,25 @@ def _ws_connect(self, url, *, autoclose=True, autoping=True, auth=None, - origin=None): + origin=None, + headers=None): sec_key = base64.b64encode(os.urandom(16)) - headers = { + if headers is None: + headers = {} + + default_headers = { hdrs.UPGRADE: hdrs.WEBSOCKET, hdrs.CONNECTION: hdrs.UPGRADE, hdrs.SEC_WEBSOCKET_VERSION: '13', hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(), } + + for key, value in default_headers.items(): + if key not in headers: + headers[key] = value + if protocols: headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols) if origin is not None: diff --git a/tests/test_websocket_client_functional.py b/tests/test_websocket_client_functional.py index 335722c9141..206a46beb37 100644 --- a/tests/test_websocket_client_functional.py +++ b/tests/test_websocket_client_functional.py @@ -1,7 +1,7 @@ import aiohttp import asyncio import pytest -from aiohttp import web +from aiohttp import web, hdrs @pytest.mark.run_loop @@ -283,3 +283,45 @@ def handler(request): yield from asyncio.sleep(0.1, loop=loop) assert resp.closed assert resp.exception() is None + + +@pytest.mark.run_loop +def test_override_default_headers(create_app_and_client, loop): + + @asyncio.coroutine + def handler(request): + assert request.headers[hdrs.SEC_WEBSOCKET_VERSION] == '8' + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + ws.send_str('answer') + yield from ws.close() + return ws + + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', handler) + resp = yield from client.ws_connect('/', headers={hdrs.SEC_WEBSOCKET_VERSION: '8'}) + msg = yield from resp.receive() + assert msg.data == 'answer' + yield from resp.close() + + +@pytest.mark.run_loop +def test_additional_headers(create_app_and_client, loop): + + @asyncio.coroutine + def handler(request): + assert request.headers['x-hdr'] == 'xtra' + ws = web.WebSocketResponse() + yield from ws.prepare(request) + + ws.send_str('answer') + yield from ws.close() + return ws + + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', handler) + resp = yield from client.ws_connect('/', headers={'x-hdr': 'xtra'}) + msg = yield from resp.receive() + assert msg.data == 'answer' + yield from resp.close()