From cf6778e93f72a528f474ae3b65fcd7ecda944c13 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 13 Jan 2016 19:59:28 +0200 Subject: [PATCH 1/2] Work on proper Connection header --- aiohttp/protocol.py | 10 +++++++--- tests/test_client_functional.py | 1 - tests/test_protocol.py | 31 +++++++++++++++++++++++++------ tests/test_web_exceptions.py | 3 --- tests/test_web_functional.py | 22 ++++++++++++++++++++-- tests/test_web_response.py | 3 --- 6 files changed, 52 insertions(+), 18 deletions(-) diff --git a/aiohttp/protocol.py b/aiohttp/protocol.py index b4d6af09250..ffb4ec9cb7e 100644 --- a/aiohttp/protocol.py +++ b/aiohttp/protocol.py @@ -675,14 +675,18 @@ def send_headers(self, _sep=': ', _end='\r\n'): def _add_default_headers(self): # set the connection header + connection = None if self.upgrade: connection = 'upgrade' elif not self.closing if self.keepalive is None else self.keepalive: - connection = 'keep-alive' + if self.version == HttpVersion10: + connection = 'keep-alive' else: - connection = 'close' + if self.version == HttpVersion11: + connection = 'close' - self.headers[hdrs.CONNECTION] = connection + if connection is not None: + self.headers[hdrs.CONNECTION] = connection def write(self, chunk, *, drain=False, EOF_MARKER=EOF_MARKER, EOL_MARKER=EOL_MARKER): diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 5595eff15da..be82e305b7c 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -413,7 +413,6 @@ def handler(request): resp = yield from client.get('/') assert resp.status == 200 assert resp.raw_headers == ((b'CONTENT-LENGTH', b'0'), - (b'CONNECTION', b'keep-alive'), (b'DATE', mock.ANY), (b'SERVER', mock.ANY)) resp.close() diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 42937efaaf0..c66b668d79a 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -173,13 +173,21 @@ def test_add_headers_hop_headers(transport): assert [] == list(msg.headers) -def test_default_headers(transport): +def test_default_headers_http_10(transport): + msg = protocol.Response(transport, 200, + http_version=protocol.HttpVersion10) + msg._add_default_headers() + + assert 'DATE' in msg.headers + assert 'keep-alive' == msg.headers['CONNECTION'] + + +def test_default_headers_http_11(transport): msg = protocol.Response(transport, 200) msg._add_default_headers() - headers = [r for r, _ in msg.headers.items()] - assert 'DATE' in headers - assert 'CONNECTION' in headers + assert 'DATE' in msg.headers + assert 'CONNECTION' not in msg.headers def test_default_headers_server(transport): @@ -222,8 +230,9 @@ def test_default_headers_connection_close(transport): assert [('CONNECTION', 'close')] == headers -def test_default_headers_connection_keep_alive(transport): - msg = protocol.Response(transport, 200) +def test_default_headers_connection_keep_alive_http_10(transport): + msg = protocol.Response(transport, 200, + http_version=protocol.HttpVersion10) msg.keepalive = True msg._add_default_headers() @@ -231,6 +240,16 @@ def test_default_headers_connection_keep_alive(transport): assert [('CONNECTION', 'keep-alive')] == headers +def test_default_headers_connection_keep_alive_11(transport): + msg = protocol.Response(transport, 200, + http_version=protocol.HttpVersion11) + msg.keepalive = True + msg._add_default_headers() + + headers = [r for r in msg.headers.items() if r[0] == 'CONNECTION'] + assert 'CONNECTION' not in headers + + def test_send_headers(transport): write = transport.write = mock.Mock() diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index d3f761abf1a..b89c659e161 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -57,7 +57,6 @@ def test_HTTPOk(buf, request): assert re.match(('HTTP/1.1 200 OK\r\n' 'CONTENT-TYPE: text/plain; charset=utf-8\r\n' 'CONTENT-LENGTH: 7\r\n' - 'CONNECTION: keep-alive\r\n' 'DATE: .+\r\n' 'SERVER: .+\r\n\r\n' '200: OK'), txt) @@ -95,7 +94,6 @@ def test_HTTPFound(buf, request): 'CONTENT-TYPE: text/plain; charset=utf-8\r\n' 'CONTENT-LENGTH: 10\r\n' 'LOCATION: /redirect\r\n' - 'CONNECTION: keep-alive\r\n' 'DATE: .+\r\n' 'SERVER: .+\r\n\r\n' '302: Found', txt) @@ -122,7 +120,6 @@ def test_HTTPMethodNotAllowed(buf, request): 'CONTENT-TYPE: text/plain; charset=utf-8\r\n' 'CONTENT-LENGTH: 23\r\n' 'ALLOW: POST,PUT\r\n' - 'CONNECTION: keep-alive\r\n' 'DATE: .+\r\n' 'SERVER: .+\r\n\r\n' '405: Method Not Allowed', txt) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 8fab9d9b4a0..49f818d08d3 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -520,6 +520,23 @@ def go(): self.loop.run_until_complete(go()) + def test_http11_keep_alive_default(self): + + @asyncio.coroutine + def handler(request): + yield from request.read() + return web.Response(body=b'OK') + + @asyncio.coroutine + def go(): + _, _, url = yield from self.create_server('GET', '/', handler) + resp = yield from request('GET', url, loop=self.loop, + version=HttpVersion10) + self.assertNotIn('CONNECTION', resp.headers) + resp.close() + + self.loop.run_until_complete(go()) + def test_http10_keep_alive_default(self): @asyncio.coroutine @@ -532,7 +549,8 @@ def go(): _, _, url = yield from self.create_server('GET', '/', handler) resp = yield from request('GET', url, loop=self.loop, version=HttpVersion10) - self.assertEqual('close', resp.headers['CONNECTION']) + self.assertEqual(resp.version, HttpVersion10) + self.assertEqual('keep-alive', resp.headers['CONNECTION']) resp.close() self.loop.run_until_complete(go()) @@ -551,7 +569,7 @@ def go(): resp = yield from request('GET', url, loop=self.loop, headers=headers, version=HttpVersion(0, 9)) - self.assertEqual('close', resp.headers['CONNECTION']) + self.assertNotIn('CONNECTION', resp.headers) resp.close() self.loop.run_until_complete(go()) diff --git a/tests/test_web_response.py b/tests/test_web_response.py index a5e13d5fba2..6e49572aa52 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -813,7 +813,6 @@ def append(data): yield from resp.write_eof() txt = buf.decode('utf8') assert re.match('HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 0\r\n' - 'CONNECTION: keep-alive\r\n' 'DATE: .+\r\nSERVER: .+\r\n\r\n', txt) @@ -836,7 +835,6 @@ def append(data): yield from resp.write_eof() txt = buf.decode('utf8') assert re.match('HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 4\r\n' - 'CONNECTION: keep-alive\r\n' 'DATE: .+\r\nSERVER: .+\r\n\r\ndata', txt) @@ -861,7 +859,6 @@ def append(data): txt = buf.decode('utf8') assert re.match('HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 0\r\n' 'SET-COOKIE: name=value\r\n' - 'CONNECTION: keep-alive\r\n' 'DATE: .+\r\nSERVER: .+\r\n\r\n', txt) From e23f85beda485efec2be5a2b2921bbd1cb8b5824 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 13 Jan 2016 21:19:39 +0200 Subject: [PATCH 2/2] Enable keepalive by default for HTTP 1.0 --- aiohttp/protocol.py | 2 +- tests/test_web_functional.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/aiohttp/protocol.py b/aiohttp/protocol.py index ffb4ec9cb7e..ece8372da7b 100644 --- a/aiohttp/protocol.py +++ b/aiohttp/protocol.py @@ -886,7 +886,7 @@ def __init__(self, transport, method, path, http_version=HttpVersion11, close=False): # set the default for HTTP 1.0 to be different # will only be overwritten with keep-alive header - if http_version < HttpVersion11: + if http_version < HttpVersion10: close = True super().__init__(transport, http_version, close) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 49f818d08d3..6be948d3af4 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -531,7 +531,7 @@ def handler(request): def go(): _, _, url = yield from self.create_server('GET', '/', handler) resp = yield from request('GET', url, loop=self.loop, - version=HttpVersion10) + version=HttpVersion11) self.assertNotIn('CONNECTION', resp.headers) resp.close() @@ -547,11 +547,12 @@ def handler(request): @asyncio.coroutine def go(): _, _, url = yield from self.create_server('GET', '/', handler) - resp = yield from request('GET', url, loop=self.loop, - version=HttpVersion10) - self.assertEqual(resp.version, HttpVersion10) - self.assertEqual('keep-alive', resp.headers['CONNECTION']) - resp.close() + with ClientSession(loop=self.loop) as session: + resp = yield from session.get(url, + version=HttpVersion10) + self.assertEqual(resp.version, HttpVersion10) + self.assertEqual('keep-alive', resp.headers['CONNECTION']) + resp.close() self.loop.run_until_complete(go()) @@ -587,7 +588,7 @@ def go(): headers = {'Connection': 'close'} resp = yield from request('GET', url, loop=self.loop, headers=headers, version=HttpVersion10) - self.assertEqual('close', resp.headers['CONNECTION']) + self.assertNotIn('CONNECTION', resp.headers) resp.close() self.loop.run_until_complete(go())