diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 35f0f8f1164..47caad26d49 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -214,6 +214,7 @@ def __init__(self, boundary, headers, content): self._read_bytes = 0 self._unread = deque() self._prev_chunk = None + self._content_eof = 0 @asyncio.coroutine def __aiter__(self): @@ -310,13 +311,14 @@ def _read_chunk_from_stream(self, size): self._prev_chunk = yield from self._content.read(size) chunk = yield from self._content.read(size) - + self._content_eof += int(self._content.at_eof()) + assert self._content_eof < 3, "Reading after EOF" window = self._prev_chunk + chunk sub = b'\r\n' + self._boundary if first_chunk: idx = window.find(sub) else: - idx = window.find(sub, len(self._prev_chunk) - len(sub)) + idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub))) if idx >= 0: # pushing boundary back to content self._content.unread_data(window[idx:]) @@ -325,6 +327,10 @@ def _read_chunk_from_stream(self, size): chunk = window[len(self._prev_chunk):idx] if not chunk: self._at_eof = True + if 0 < len(chunk) < len(sub) and not self._content_eof: + self._prev_chunk += chunk + self._at_eof = False + return b'' result = self._prev_chunk self._prev_chunk = chunk return result diff --git a/tests/test_multipart.py b/tests/test_multipart.py index bd736a261e2..cb39656ad4d 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -69,6 +69,9 @@ def __init__(self, content): def read(self, size=None): return self.content.read(size) + def at_eof(self): + return self.content.tell() == len(self.content.getbuffer()) + @asyncio.coroutine def readline(self): return self.content.readline() @@ -191,6 +194,47 @@ def prepare(data): c3 = yield from obj.read_chunk(8) self.assertEqual(c3, b'!') + def test_read_all_at_once(self): + stream = Stream(b'Hello, World!\r\n--:--\r\n') + obj = aiohttp.multipart.BodyPartReader(self.boundary, {}, stream) + result = yield from obj.read_chunk() + self.assertEqual(b'Hello, World!', result) + result = yield from obj.read_chunk() + self.assertEqual(b'', result) + self.assertTrue(obj.at_eof()) + + def test_read_incomplete_body_chunked(self): + stream = Stream(b'Hello, World!\r\n-') + obj = aiohttp.multipart.BodyPartReader(self.boundary, {}, stream) + result = b'' + with self.assertRaises(AssertionError): + for _ in range(4): + result += yield from obj.read_chunk(7) + self.assertEqual(b'Hello, World!\r\n-', result) + + def test_read_boundary_with_incomplete_chunk(self): + stream = Stream(b'') + + def prepare(data): + f = asyncio.Future(loop=self.loop) + f.set_result(data) + return f + + with mock.patch.object(stream, 'read', side_effect=[ + prepare(b'Hello, World'), + prepare(b'!\r\n'), + prepare(b'--:'), + prepare(b'') + ]): + obj = aiohttp.multipart.BodyPartReader( + self.boundary, {}, stream) + c1 = yield from obj.read_chunk(12) + self.assertEqual(c1, b'Hello, World') + c2 = yield from obj.read_chunk(8) + self.assertEqual(c2, b'!') + c3 = yield from obj.read_chunk(8) + self.assertEqual(c3, b'') + def test_multi_read_chunk(self): stream = Stream(b'Hello,\r\n--:\r\n\r\nworld!\r\n--:--') obj = aiohttp.multipart.BodyPartReader(self.boundary, {}, stream)