Skip to content

Commit

Permalink
Allow configuration of header size limits (#912)
Browse files Browse the repository at this point in the history
* Do not fail logging if message is not defined

* Allow user specification of max_headers, etc.

Users can specify `max_headers`, `max_line_size`, and `max_field_size`
via `app.make_handler()`.  Fixes #909.
  • Loading branch information
djmitche authored and asvetlov committed Jul 12, 2016
1 parent 913df8a commit 03bbd52
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 8 deletions.
2 changes: 2 additions & 0 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ def _format_e(key, args):

@staticmethod
def _format_i(key, args):
if not args[0]:
return '(no headers)'
return args[0].headers.get(multidict.upstr(key), '-')

@staticmethod
Expand Down
17 changes: 14 additions & 3 deletions aiohttp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ class ServerHttpProtocol(aiohttp.StreamProtocol):
:param str access_log_format: access log format string
:param loop: Optional event loop
:param int max_line_size: Optional maximum header line size
:param int max_field_size: Optional maximum header field size
:param int max_headers: Optional maximum header size
"""
_request_count = 0
_request_handler = None
Expand All @@ -81,9 +87,6 @@ class ServerHttpProtocol(aiohttp.StreamProtocol):
_keep_alive_handle = None # keep alive timer handle
_timeout_handle = None # slow request timer handle

_request_prefix = aiohttp.HttpPrefixParser() # HTTP method parser
_request_parser = aiohttp.HttpRequestParser() # default request parser

def __init__(self, *, loop=None,
keep_alive=75, # NGINX default value is 75 secs
keep_alive_on=True,
Expand All @@ -103,6 +106,14 @@ def __init__(self, *, loop=None,
self._timeout = timeout # slow request timeout
self._loop = loop if loop is not None else asyncio.get_event_loop()

parser_kwargs = {}
for kwarg in ['max_line_size', 'max_field_size', 'max_headers']:
if kwarg in kwargs:
parser_kwargs[kwarg] = kwargs.pop(kwarg)

self._request_prefix = aiohttp.HttpPrefixParser()
self._request_parser = aiohttp.HttpRequestParser(**parser_kwargs)

self.logger = log or logger
self.debug = debug
self.access_log = access_log
Expand Down
5 changes: 3 additions & 2 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ def test_logger_no_message_and_environ():
mock_logger = mock.Mock()
mock_transport = mock.Mock()
mock_transport.get_extra_info.return_value = ("127.0.0.3", 0)
access_logger = helpers.AccessLogger(mock_logger, "%r %{FOOBAR}e")
access_logger = helpers.AccessLogger(mock_logger,
"%r %{FOOBAR}e %{content-type}i")
access_logger.log(None, None, None, mock_transport, 0.0)
mock_logger.info.assert_called_with("- -")
mock_logger.info.assert_called_with("- - (no headers)")


def test_reify():
Expand Down
46 changes: 43 additions & 3 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,18 @@ def find_unused_port(self):

@asyncio.coroutine
def create_server(self, method, path, handler=None, ssl_ctx=None,
logger=log.server_logger):
app = web.Application(loop=self.loop)
logger=log.server_logger, handler_kwargs=None):
app = web.Application(
loop=self.loop)
if handler:
app.router.add_route(method, path, handler)

port = self.find_unused_port()
self.handler = app.make_handler(
keep_alive_on=False,
access_log=log.access_logger,
logger=logger)
logger=logger,
**(handler_kwargs or {}))
srv = yield from self.loop.create_server(
self.handler, '127.0.0.1', port, ssl=ssl_ctx)
protocol = "https" if ssl_ctx else "http"
Expand Down Expand Up @@ -776,6 +778,44 @@ def go():

self.loop.run_until_complete(go())

def test_large_header(self):

@asyncio.coroutine
def handler(request):
return web.Response()

@asyncio.coroutine
def go():
_, srv, url = yield from self.create_server('GET', '/', handler)
headers = {'Long-Header': 'ab' * 8129}
resp = yield from request('GET', url,
headers=headers,
loop=self.loop)
self.assertEqual(400, resp.status)
yield from resp.release()

self.loop.run_until_complete(go())

def test_large_header_allowed(self):

@asyncio.coroutine
def handler(request):
return web.Response()

@asyncio.coroutine
def go():
handler_kwargs = {'max_field_size': 81920}
_, srv, url = yield from self.create_server(
'GET', '/', handler, handler_kwargs=handler_kwargs)
headers = {'Long-Header': 'ab' * 8129}
resp = yield from request('GET', url,
headers=headers,
loop=self.loop)
self.assertEqual(200, resp.status)
yield from resp.release()

self.loop.run_until_complete(go())

def test_get_with_empty_arg_with_equal(self):

@asyncio.coroutine
Expand Down

0 comments on commit 03bbd52

Please sign in to comment.