Skip to content

Commit

Permalink
Fix #1828: make enable_compression work on HTTP/1.0 (#1910)
Browse files Browse the repository at this point in the history
  • Loading branch information
hubo1016 authored and asvetlov committed Jun 22, 2017
1 parent 075c34c commit 58a7a58
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ Changes

- Cancel websocket heartbeat on close #1793

- Make enable_compression work on HTTP/1.0 #1828


2.0.6 (2017-04-04)
------------------
Expand Down
67 changes: 51 additions & 16 deletions aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
import time
import warnings
import zlib
from email.utils import parsedate

from multidict import CIMultiDict, CIMultiDictProxy
Expand Down Expand Up @@ -299,7 +300,10 @@ def _do_start_compression(self, coding):
if coding != ContentCoding.identity:
self.headers[hdrs.CONTENT_ENCODING] = coding.value
self._payload_writer.enable_compression(coding.value)
self._chunked = True
# Compressed payload may have different content length,
# remove the header
if hdrs.CONTENT_LENGTH in self._headers:
del self._headers[hdrs.CONTENT_LENGTH]

def _start_compression(self, request):
if self._compression_force:
Expand Down Expand Up @@ -362,11 +366,14 @@ def _start(self, request,
del headers[CONTENT_LENGTH]
elif self._length_check:
writer.length = self.content_length
if writer.length is None and version >= HttpVersion11:
writer.enable_chunking()
headers[TRANSFER_ENCODING] = 'chunked'
if CONTENT_LENGTH in headers:
del headers[CONTENT_LENGTH]
if writer.length is None:
if version >= HttpVersion11:
writer.enable_chunking()
headers[TRANSFER_ENCODING] = 'chunked'
if CONTENT_LENGTH in headers:
del headers[CONTENT_LENGTH]
else:
keep_alive = False

headers.setdefault(CONTENT_TYPE, 'application/octet-stream')
headers.setdefault(DATE, request.time_service.strtime())
Expand Down Expand Up @@ -489,6 +496,8 @@ def __init__(self, *, body=None, status=200,
else:
self.body = body

self._compressed_body = None

@property
def body(self):
return self._body
Expand All @@ -513,12 +522,10 @@ def body(self, body,

headers = self._headers

# enable chunked encoding if needed
# set content-length header if needed
if not self._chunked and CONTENT_LENGTH not in headers:
size = body.size
if size is None:
self._chunked = True
elif CONTENT_LENGTH not in headers:
if size is not None:
headers[CONTENT_LENGTH] = str(size)

# set content-type
Expand All @@ -531,6 +538,8 @@ def body(self, body,
if key not in headers:
headers[key] = value

self._compressed_body = None

@property
def text(self):
if self._body is None:
Expand All @@ -549,6 +558,7 @@ def text(self, text):

self._body = text.encode(self.charset)
self._body_payload = False
self._compressed_body = None

@property
def content_length(self):
Expand All @@ -558,7 +568,13 @@ def content_length(self):
if hdrs.CONTENT_LENGTH in self.headers:
return super().content_length

if self._body is not None:
if self._compressed_body is not None:
# Return length of the compressed body
return len(self._compressed_body)
elif self._body_payload:
# A payload without content length, or a compressed payload
return None
elif self._body is not None:
return len(self._body)
else:
return 0
Expand All @@ -571,7 +587,10 @@ def content_length(self, value):
def write_eof(self):
if self._eof_sent:
return
body = self._body
if self._compressed_body is not None:
body = self._compressed_body
else:
body = self._body
if body is not None:
if (self._req._method == hdrs.METH_HEAD or
self._status in [204, 304]):
Expand All @@ -586,13 +605,29 @@ def write_eof(self):

def _start(self, request):
if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers:
if self._body is not None:
self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body))
else:
self._headers[hdrs.CONTENT_LENGTH] = '0'
if not self._body_payload:
if self._body is not None:
self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body))
else:
self._headers[hdrs.CONTENT_LENGTH] = '0'

return super()._start(request)

def _do_start_compression(self, coding):
if self._body_payload or self._chunked:
return super()._do_start_compression(coding)
if coding != ContentCoding.identity:
# Instead of using _payload_writer.enable_compression,
# compress the whole body
zlib_mode = (16 + zlib.MAX_WBITS
if coding.value == 'gzip' else -zlib.MAX_WBITS)
compressobj = zlib.compressobj(wbits=zlib_mode)
self._compressed_body = compressobj.compress(self._body) +\
compressobj.flush()
self._headers[hdrs.CONTENT_ENCODING] = coding.value
self._headers[hdrs.CONTENT_LENGTH] = \
str(len(self._compressed_body))


def json_response(data=sentinel, *, text=None, body=None, status=200,
reason=None, headers=None, content_type='application/json',
Expand Down
38 changes: 38 additions & 0 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,25 @@ def handler(request):
resp.close()


@asyncio.coroutine
def test_encoding_deflate_nochunk(loop, test_client):
@asyncio.coroutine
def handler(request):
resp = web.Response(text='text')
resp.enable_compression(web.ContentCoding.deflate)
return resp

app = web.Application()
app.router.add_get('/', handler)
client = yield from test_client(app)

resp = yield from client.get('/')
assert 200 == resp.status
txt = yield from resp.text()
assert txt == 'text'
resp.close()


@asyncio.coroutine
def test_encoding_gzip(loop, test_client):
@asyncio.coroutine
Expand All @@ -1775,6 +1794,25 @@ def handler(request):
resp.close()


@asyncio.coroutine
def test_encoding_gzip_nochunk(loop, test_client):
@asyncio.coroutine
def handler(request):
resp = web.Response(text='text')
resp.enable_compression(web.ContentCoding.gzip)
return resp

app = web.Application()
app.router.add_get('/', handler)
client = yield from test_client(app)

resp = yield from client.get('/')
assert 200 == resp.status
txt = yield from resp.text()
assert txt == 'text'
resp.close()


@asyncio.coroutine
def test_bad_payload_compression(loop, test_client):
@asyncio.coroutine
Expand Down
140 changes: 139 additions & 1 deletion tests/test_web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from multidict import CIMultiDict

from aiohttp import HttpVersion, HttpVersion10, HttpVersion11, hdrs, signals
from aiohttp.payload import BytesPayload
from aiohttp.test_utils import make_mocked_request
from aiohttp.web import ContentCoding, Response, StreamResponse, json_response

Expand Down Expand Up @@ -385,15 +386,152 @@ def test_force_compression_no_accept_gzip():


@asyncio.coroutine
def test_delete_content_length_if_compression_enabled():
def test_change_content_length_if_compression_enabled():
req = make_request('GET', '/')
resp = Response(body=b'answer')
resp.enable_compression(ContentCoding.gzip)

yield from resp.prepare(req)
assert resp.content_length is not None and \
resp.content_length != len(b'answer')


@asyncio.coroutine
def test_set_content_length_if_compression_enabled():
writer = mock.Mock()

def write_headers(status_line, headers):
assert hdrs.CONTENT_LENGTH in headers
assert headers[hdrs.CONTENT_LENGTH] == '26'
assert hdrs.TRANSFER_ENCODING not in headers

writer.write_headers.side_effect = write_headers
req = make_request('GET', '/', payload_writer=writer)
resp = Response(body=b'answer')
resp.enable_compression(ContentCoding.gzip)

yield from resp.prepare(req)
assert resp.content_length == 26
del resp.headers[hdrs.CONTENT_LENGTH]
assert resp.content_length == 26


@asyncio.coroutine
def test_remove_content_length_if_compression_enabled_http11():
writer = mock.Mock()

def write_headers(status_line, headers):
assert hdrs.CONTENT_LENGTH not in headers
assert headers.get(hdrs.TRANSFER_ENCODING, '') == 'chunked'

writer.write_headers.side_effect = write_headers
req = make_request('GET', '/', payload_writer=writer)
resp = StreamResponse()
resp.content_length = 123
resp.enable_compression(ContentCoding.gzip)
yield from resp.prepare(req)
assert resp.content_length is None


@asyncio.coroutine
def test_remove_content_length_if_compression_enabled_http10():
writer = mock.Mock()

def write_headers(status_line, headers):
assert hdrs.CONTENT_LENGTH not in headers
assert hdrs.TRANSFER_ENCODING not in headers

writer.write_headers.side_effect = write_headers
req = make_request('GET', '/', version=HttpVersion10,
payload_writer=writer)
resp = StreamResponse()
resp.content_length = 123
resp.enable_compression(ContentCoding.gzip)
yield from resp.prepare(req)
assert resp.content_length is None


@asyncio.coroutine
def test_force_compression_identity():
writer = mock.Mock()

def write_headers(status_line, headers):
assert hdrs.CONTENT_LENGTH in headers
assert hdrs.TRANSFER_ENCODING not in headers

writer.write_headers.side_effect = write_headers
req = make_request('GET', '/',
payload_writer=writer)
resp = StreamResponse()
resp.content_length = 123
resp.enable_compression(ContentCoding.identity)
yield from resp.prepare(req)
assert resp.content_length == 123


@asyncio.coroutine
def test_force_compression_identity_response():
writer = mock.Mock()

def write_headers(status_line, headers):
assert headers[hdrs.CONTENT_LENGTH] == "6"
assert hdrs.TRANSFER_ENCODING not in headers

writer.write_headers.side_effect = write_headers
req = make_request('GET', '/',
payload_writer=writer)
resp = Response(body=b'answer')
resp.enable_compression(ContentCoding.identity)
yield from resp.prepare(req)
assert resp.content_length == 6


@asyncio.coroutine
def test_remove_content_length_if_compression_enabled_on_payload_http11():
writer = mock.Mock()

def write_headers(status_line, headers):
assert hdrs.CONTENT_LENGTH not in headers
assert headers.get(hdrs.TRANSFER_ENCODING, '') == 'chunked'

writer.write_headers.side_effect = write_headers
req = make_request('GET', '/', payload_writer=writer)
payload = BytesPayload(b'answer', headers={"X-Test-Header": "test"})
resp = Response(body=payload)
assert resp.content_length == 6
resp.body = payload
resp.enable_compression(ContentCoding.gzip)
yield from resp.prepare(req)
assert resp.content_length is None


@asyncio.coroutine
def test_remove_content_length_if_compression_enabled_on_payload_http10():
writer = mock.Mock()

def write_headers(status_line, headers):
assert hdrs.CONTENT_LENGTH not in headers
assert hdrs.TRANSFER_ENCODING not in headers

writer.write_headers.side_effect = write_headers
req = make_request('GET', '/', version=HttpVersion10,
payload_writer=writer)
resp = Response(body=BytesPayload(b'answer'))
resp.enable_compression(ContentCoding.gzip)
yield from resp.prepare(req)
assert resp.content_length is None


@asyncio.coroutine
def test_content_length_on_chunked():
req = make_request('GET', '/')
resp = Response(body=b'answer')
assert resp.content_length == 6
resp.enable_chunked_encoding()
assert resp.content_length is None
yield from resp.prepare(req)


@asyncio.coroutine
def test_write_non_byteish():
resp = StreamResponse()
Expand Down

0 comments on commit 58a7a58

Please sign in to comment.