diff --git a/CHANGES.rst b/CHANGES.rst index e9ac893d0aa..d90c3042c80 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -141,6 +141,8 @@ Changes - Cancel websocket heartbeat on close #1793 +- Make enable_compression work on HTTP/1.0 #1828 + 2.0.6 (2017-04-04) ------------------ diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index d17e1227efb..1feec8fa089 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -5,6 +5,7 @@ import math import time import warnings +import zlib from email.utils import parsedate from multidict import CIMultiDict, CIMultiDictProxy @@ -299,7 +300,10 @@ def _do_start_compression(self, coding): if coding != ContentCoding.identity: self.headers[hdrs.CONTENT_ENCODING] = coding.value self._payload_writer.enable_compression(coding.value) - self._chunked = True + # Compressed payload may have different content length, + # remove the header + if hdrs.CONTENT_LENGTH in self._headers: + del self._headers[hdrs.CONTENT_LENGTH] def _start_compression(self, request): if self._compression_force: @@ -362,11 +366,14 @@ def _start(self, request, del headers[CONTENT_LENGTH] elif self._length_check: writer.length = self.content_length - if writer.length is None and version >= HttpVersion11: - writer.enable_chunking() - headers[TRANSFER_ENCODING] = 'chunked' - if CONTENT_LENGTH in headers: - del headers[CONTENT_LENGTH] + if writer.length is None: + if version >= HttpVersion11: + writer.enable_chunking() + headers[TRANSFER_ENCODING] = 'chunked' + if CONTENT_LENGTH in headers: + del headers[CONTENT_LENGTH] + else: + keep_alive = False headers.setdefault(CONTENT_TYPE, 'application/octet-stream') headers.setdefault(DATE, request.time_service.strtime()) @@ -489,6 +496,8 @@ def __init__(self, *, body=None, status=200, else: self.body = body + self._compressed_body = None + @property def body(self): return self._body @@ -513,12 +522,10 @@ def body(self, body, headers = self._headers - # enable chunked encoding if needed + # set content-length header if needed if not self._chunked and CONTENT_LENGTH not in headers: size = body.size - if size is None: - self._chunked = True - elif CONTENT_LENGTH not in headers: + if size is not None: headers[CONTENT_LENGTH] = str(size) # set content-type @@ -531,6 +538,8 @@ def body(self, body, if key not in headers: headers[key] = value + self._compressed_body = None + @property def text(self): if self._body is None: @@ -549,6 +558,7 @@ def text(self, text): self._body = text.encode(self.charset) self._body_payload = False + self._compressed_body = None @property def content_length(self): @@ -558,7 +568,13 @@ def content_length(self): if hdrs.CONTENT_LENGTH in self.headers: return super().content_length - if self._body is not None: + if self._compressed_body is not None: + # Return length of the compressed body + return len(self._compressed_body) + elif self._body_payload: + # A payload without content length, or a compressed payload + return None + elif self._body is not None: return len(self._body) else: return 0 @@ -571,7 +587,10 @@ def content_length(self, value): def write_eof(self): if self._eof_sent: return - body = self._body + if self._compressed_body is not None: + body = self._compressed_body + else: + body = self._body if body is not None: if (self._req._method == hdrs.METH_HEAD or self._status in [204, 304]): @@ -586,13 +605,29 @@ def write_eof(self): def _start(self, request): if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers: - if self._body is not None: - self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body)) - else: - self._headers[hdrs.CONTENT_LENGTH] = '0' + if not self._body_payload: + if self._body is not None: + self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body)) + else: + self._headers[hdrs.CONTENT_LENGTH] = '0' return super()._start(request) + def _do_start_compression(self, coding): + if self._body_payload or self._chunked: + return super()._do_start_compression(coding) + if coding != ContentCoding.identity: + # Instead of using _payload_writer.enable_compression, + # compress the whole body + zlib_mode = (16 + zlib.MAX_WBITS + if coding.value == 'gzip' else -zlib.MAX_WBITS) + compressobj = zlib.compressobj(wbits=zlib_mode) + self._compressed_body = compressobj.compress(self._body) +\ + compressobj.flush() + self._headers[hdrs.CONTENT_ENCODING] = coding.value + self._headers[hdrs.CONTENT_LENGTH] = \ + str(len(self._compressed_body)) + def json_response(data=sentinel, *, text=None, body=None, status=200, reason=None, headers=None, content_type='application/json', diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 4304c845bb2..d192c339ddb 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -1755,6 +1755,25 @@ def handler(request): resp.close() +@asyncio.coroutine +def test_encoding_deflate_nochunk(loop, test_client): + @asyncio.coroutine + def handler(request): + resp = web.Response(text='text') + resp.enable_compression(web.ContentCoding.deflate) + return resp + + app = web.Application() + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + txt = yield from resp.text() + assert txt == 'text' + resp.close() + + @asyncio.coroutine def test_encoding_gzip(loop, test_client): @asyncio.coroutine @@ -1775,6 +1794,25 @@ def handler(request): resp.close() +@asyncio.coroutine +def test_encoding_gzip_nochunk(loop, test_client): + @asyncio.coroutine + def handler(request): + resp = web.Response(text='text') + resp.enable_compression(web.ContentCoding.gzip) + return resp + + app = web.Application() + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + txt = yield from resp.text() + assert txt == 'text' + resp.close() + + @asyncio.coroutine def test_bad_payload_compression(loop, test_client): @asyncio.coroutine diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 8a9b0832c1f..9df6670d555 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -8,6 +8,7 @@ from multidict import CIMultiDict from aiohttp import HttpVersion, HttpVersion10, HttpVersion11, hdrs, signals +from aiohttp.payload import BytesPayload from aiohttp.test_utils import make_mocked_request from aiohttp.web import ContentCoding, Response, StreamResponse, json_response @@ -385,15 +386,152 @@ def test_force_compression_no_accept_gzip(): @asyncio.coroutine -def test_delete_content_length_if_compression_enabled(): +def test_change_content_length_if_compression_enabled(): req = make_request('GET', '/') resp = Response(body=b'answer') resp.enable_compression(ContentCoding.gzip) + yield from resp.prepare(req) + assert resp.content_length is not None and \ + resp.content_length != len(b'answer') + + +@asyncio.coroutine +def test_set_content_length_if_compression_enabled(): + writer = mock.Mock() + + def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH in headers + assert headers[hdrs.CONTENT_LENGTH] == '26' + assert hdrs.TRANSFER_ENCODING not in headers + + writer.write_headers.side_effect = write_headers + req = make_request('GET', '/', payload_writer=writer) + resp = Response(body=b'answer') + resp.enable_compression(ContentCoding.gzip) + + yield from resp.prepare(req) + assert resp.content_length == 26 + del resp.headers[hdrs.CONTENT_LENGTH] + assert resp.content_length == 26 + + +@asyncio.coroutine +def test_remove_content_length_if_compression_enabled_http11(): + writer = mock.Mock() + + def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH not in headers + assert headers.get(hdrs.TRANSFER_ENCODING, '') == 'chunked' + + writer.write_headers.side_effect = write_headers + req = make_request('GET', '/', payload_writer=writer) + resp = StreamResponse() + resp.content_length = 123 + resp.enable_compression(ContentCoding.gzip) + yield from resp.prepare(req) + assert resp.content_length is None + + +@asyncio.coroutine +def test_remove_content_length_if_compression_enabled_http10(): + writer = mock.Mock() + + def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH not in headers + assert hdrs.TRANSFER_ENCODING not in headers + + writer.write_headers.side_effect = write_headers + req = make_request('GET', '/', version=HttpVersion10, + payload_writer=writer) + resp = StreamResponse() + resp.content_length = 123 + resp.enable_compression(ContentCoding.gzip) yield from resp.prepare(req) assert resp.content_length is None +@asyncio.coroutine +def test_force_compression_identity(): + writer = mock.Mock() + + def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH in headers + assert hdrs.TRANSFER_ENCODING not in headers + + writer.write_headers.side_effect = write_headers + req = make_request('GET', '/', + payload_writer=writer) + resp = StreamResponse() + resp.content_length = 123 + resp.enable_compression(ContentCoding.identity) + yield from resp.prepare(req) + assert resp.content_length == 123 + + +@asyncio.coroutine +def test_force_compression_identity_response(): + writer = mock.Mock() + + def write_headers(status_line, headers): + assert headers[hdrs.CONTENT_LENGTH] == "6" + assert hdrs.TRANSFER_ENCODING not in headers + + writer.write_headers.side_effect = write_headers + req = make_request('GET', '/', + payload_writer=writer) + resp = Response(body=b'answer') + resp.enable_compression(ContentCoding.identity) + yield from resp.prepare(req) + assert resp.content_length == 6 + + +@asyncio.coroutine +def test_remove_content_length_if_compression_enabled_on_payload_http11(): + writer = mock.Mock() + + def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH not in headers + assert headers.get(hdrs.TRANSFER_ENCODING, '') == 'chunked' + + writer.write_headers.side_effect = write_headers + req = make_request('GET', '/', payload_writer=writer) + payload = BytesPayload(b'answer', headers={"X-Test-Header": "test"}) + resp = Response(body=payload) + assert resp.content_length == 6 + resp.body = payload + resp.enable_compression(ContentCoding.gzip) + yield from resp.prepare(req) + assert resp.content_length is None + + +@asyncio.coroutine +def test_remove_content_length_if_compression_enabled_on_payload_http10(): + writer = mock.Mock() + + def write_headers(status_line, headers): + assert hdrs.CONTENT_LENGTH not in headers + assert hdrs.TRANSFER_ENCODING not in headers + + writer.write_headers.side_effect = write_headers + req = make_request('GET', '/', version=HttpVersion10, + payload_writer=writer) + resp = Response(body=BytesPayload(b'answer')) + resp.enable_compression(ContentCoding.gzip) + yield from resp.prepare(req) + assert resp.content_length is None + + +@asyncio.coroutine +def test_content_length_on_chunked(): + req = make_request('GET', '/') + resp = Response(body=b'answer') + assert resp.content_length == 6 + resp.enable_chunked_encoding() + assert resp.content_length is None + yield from resp.prepare(req) + + @asyncio.coroutine def test_write_non_byteish(): resp = StreamResponse()