diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index dc2b7bd68ae..3ef0ddea5e8 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -332,6 +332,8 @@ def _format_e(key, args): @staticmethod def _format_i(key, args): + if not args[0]: + return '(no headers)' return args[0].headers.get(multidict.upstr(key), '-') @staticmethod diff --git a/aiohttp/server.py b/aiohttp/server.py index 482df45edff..6ece4efac4d 100644 --- a/aiohttp/server.py +++ b/aiohttp/server.py @@ -73,6 +73,12 @@ class ServerHttpProtocol(aiohttp.StreamProtocol): :param str access_log_format: access log format string :param loop: Optional event loop + + :param int max_line_size: Optional maximum header line size + + :param int max_field_size: Optional maximum header field size + + :param int max_headers: Optional maximum header size """ _request_count = 0 _request_handler = None @@ -81,9 +87,6 @@ class ServerHttpProtocol(aiohttp.StreamProtocol): _keep_alive_handle = None # keep alive timer handle _timeout_handle = None # slow request timer handle - _request_prefix = aiohttp.HttpPrefixParser() # HTTP method parser - _request_parser = aiohttp.HttpRequestParser() # default request parser - def __init__(self, *, loop=None, keep_alive=75, # NGINX default value is 75 secs keep_alive_on=True, @@ -103,6 +106,14 @@ def __init__(self, *, loop=None, self._timeout = timeout # slow request timeout self._loop = loop if loop is not None else asyncio.get_event_loop() + parser_kwargs = {} + for kwarg in ['max_line_size', 'max_field_size', 'max_headers']: + if kwarg in kwargs: + parser_kwargs[kwarg] = kwargs.pop(kwarg) + + self._request_prefix = aiohttp.HttpPrefixParser() + self._request_parser = aiohttp.HttpRequestParser(**parser_kwargs) + self.logger = log or logger self.debug = debug self.access_log = access_log diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 1ec96adae0f..53f7702ab30 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -177,9 +177,10 @@ def test_logger_no_message_and_environ(): mock_logger = mock.Mock() mock_transport = mock.Mock() mock_transport.get_extra_info.return_value = ("127.0.0.3", 0) - access_logger = helpers.AccessLogger(mock_logger, "%r %{FOOBAR}e") + access_logger = helpers.AccessLogger(mock_logger, + "%r %{FOOBAR}e %{content-type}i") access_logger.log(None, None, None, mock_transport, 0.0) - mock_logger.info.assert_called_with("- -") + mock_logger.info.assert_called_with("- - (no headers)") def test_reify(): diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index f9463b6e60d..113abdcb24f 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -43,8 +43,9 @@ def find_unused_port(self): @asyncio.coroutine def create_server(self, method, path, handler=None, ssl_ctx=None, - logger=log.server_logger): - app = web.Application(loop=self.loop) + logger=log.server_logger, handler_kwargs=None): + app = web.Application( + loop=self.loop) if handler: app.router.add_route(method, path, handler) @@ -52,7 +53,8 @@ def create_server(self, method, path, handler=None, ssl_ctx=None, self.handler = app.make_handler( keep_alive_on=False, access_log=log.access_logger, - logger=logger) + logger=logger, + **(handler_kwargs or {})) srv = yield from self.loop.create_server( self.handler, '127.0.0.1', port, ssl=ssl_ctx) protocol = "https" if ssl_ctx else "http" @@ -776,6 +778,44 @@ def go(): self.loop.run_until_complete(go()) + def test_large_header(self): + + @asyncio.coroutine + def handler(request): + return web.Response() + + @asyncio.coroutine + def go(): + _, srv, url = yield from self.create_server('GET', '/', handler) + headers = {'Long-Header': 'ab' * 8129} + resp = yield from request('GET', url, + headers=headers, + loop=self.loop) + self.assertEqual(400, resp.status) + yield from resp.release() + + self.loop.run_until_complete(go()) + + def test_large_header_allowed(self): + + @asyncio.coroutine + def handler(request): + return web.Response() + + @asyncio.coroutine + def go(): + handler_kwargs = {'max_field_size': 81920} + _, srv, url = yield from self.create_server( + 'GET', '/', handler, handler_kwargs=handler_kwargs) + headers = {'Long-Header': 'ab' * 8129} + resp = yield from request('GET', url, + headers=headers, + loop=self.loop) + self.assertEqual(200, resp.status) + yield from resp.release() + + self.loop.run_until_complete(go()) + def test_get_with_empty_arg_with_equal(self): @asyncio.coroutine