From 5f4ba9f54bf97cd4ef4fd0e1882f7a934cffc1c7 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 10 Nov 2017 11:24:05 +0200 Subject: [PATCH] Convert multipart to async/await syntax --- aiohttp/multipart.py | 95 +++++++++++++++++------------------------ tests/test_multipart.py | 10 ++--- 2 files changed, 45 insertions(+), 60 deletions(-) diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index e28932313aa..bc90895ba50 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -1,4 +1,3 @@ -import asyncio import base64 import binascii import json @@ -175,9 +174,8 @@ def __init__(self, resp, stream): def __aiter__(self): return self - @asyncio.coroutine - def __anext__(self): - part = yield from self.next() + async def __anext__(self): + part = await self.next() if part is None: raise StopAsyncIteration # NOQA return part @@ -220,22 +218,19 @@ def __init__(self, boundary, headers, content): def __aiter__(self): return self - @asyncio.coroutine - def __anext__(self): - part = yield from self.next() + async def __anext__(self): + part = await self.next() if part is None: raise StopAsyncIteration # NOQA return part - @asyncio.coroutine - def next(self): - item = yield from self.read() + async def next(self): + item = await self.read() if not item: return None return item - @asyncio.coroutine - def read(self, *, decode=False): + async def read(self, *, decode=False): """Reads body part data. decode: Decodes data following by encoding @@ -246,13 +241,12 @@ def read(self, *, decode=False): return b'' data = bytearray() while not self._at_eof: - data.extend((yield from self.read_chunk(self.chunk_size))) + data.extend((await self.read_chunk(self.chunk_size))) if decode: return self.decode(data) return data - @asyncio.coroutine - def read_chunk(self, size=chunk_size): + async def read_chunk(self, size=chunk_size): """Reads body part content chunk of the specified size. size: chunk size @@ -260,15 +254,15 @@ def read_chunk(self, size=chunk_size): if self._at_eof: return b'' if self._length: - chunk = yield from self._read_chunk_from_length(size) + chunk = await self._read_chunk_from_length(size) else: - chunk = yield from self._read_chunk_from_stream(size) + chunk = await self._read_chunk_from_stream(size) self._read_bytes += len(chunk) if self._read_bytes == self._length: self._at_eof = True if self._at_eof: - clrf = yield from self._content.readline() + clrf = await self._content.readline() assert b'\r\n' == clrf, \ 'reader did not read all the data or it is malformed' return chunk @@ -312,8 +306,7 @@ async def _read_chunk_from_stream(self, size): self._prev_chunk = chunk return result - @asyncio.coroutine - def readline(self): + async def readline(self): """Reads body part by line by line.""" if self._at_eof: return b'' @@ -321,7 +314,7 @@ def readline(self): if self._unread: line = self._unread.popleft() else: - line = yield from self._content.readline() + line = await self._content.readline() if line.startswith(self._boundary): # the very last boundary may not come with \r\n, @@ -335,45 +328,41 @@ def readline(self): self._unread.append(line) return b'' else: - next_line = yield from self._content.readline() + next_line = await self._content.readline() if next_line.startswith(self._boundary): line = line[:-2] # strip CRLF but only once self._unread.append(next_line) return line - @asyncio.coroutine - def release(self): + async def release(self): """Like read(), but reads all the data to the void.""" if self._at_eof: return while not self._at_eof: - yield from self.read_chunk(self.chunk_size) + await self.read_chunk(self.chunk_size) - @asyncio.coroutine - def text(self, *, encoding=None): + async def text(self, *, encoding=None): """Like read(), but assumes that body part contains text data.""" - data = yield from self.read(decode=True) + data = await self.read(decode=True) # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm # NOQA # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send # NOQA encoding = encoding or self.get_charset(default='utf-8') return data.decode(encoding) - @asyncio.coroutine - def json(self, *, encoding=None): + async def json(self, *, encoding=None): """Like read(), but assumes that body parts contains JSON data.""" - data = yield from self.read(decode=True) + data = await self.read(decode=True) if not data: return None encoding = encoding or self.get_charset(default='utf-8') return json.loads(data.decode(encoding)) - @asyncio.coroutine - def form(self, *, encoding=None): + async def form(self, *, encoding=None): """Like read(), but assumes that body parts contains form urlencoded data. """ - data = yield from self.read(decode=True) + data = await self.read(decode=True) if not data: return None encoding = encoding or self.get_charset(default='utf-8') @@ -491,9 +480,8 @@ def __init__(self, headers, content): def __aiter__(self): return self - @asyncio.coroutine - def __anext__(self): - part = yield from self.next() + async def __anext__(self): + part = await self.next() if part is None: raise StopAsyncIteration # NOQA return part @@ -768,15 +756,14 @@ def size(self): total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n' return total - @asyncio.coroutine - def write(self, writer): + async def write(self, writer): """Write body.""" if not self._parts: return for part, headers, encoding, te_encoding in self._parts: - yield from writer.write(b'--' + self._boundary + b'\r\n') - yield from writer.write(headers) + await writer.write(b'--' + self._boundary + b'\r\n') + await writer.write(headers) if encoding or te_encoding: w = MultipartPayloadWriter(writer) @@ -784,14 +771,14 @@ def write(self, writer): w.enable_compression(encoding) if te_encoding: w.enable_encoding(te_encoding) - yield from part.write(w) - yield from w.write_eof() + await part.write(w) + await w.write_eof() else: - yield from part.write(writer) + await part.write(writer) - yield from writer.write(b'\r\n') + await writer.write(b'\r\n') - yield from writer.write(b'--' + self._boundary + b'--\r\n') + await writer.write(b'--' + self._boundary + b'--\r\n') class MultipartPayloadWriter: @@ -813,21 +800,19 @@ def enable_compression(self, encoding='deflate'): if encoding == 'gzip' else -zlib.MAX_WBITS) self._compress = zlib.compressobj(wbits=zlib_mode) - @asyncio.coroutine - def write_eof(self): + async def write_eof(self): if self._compress is not None: chunk = self._compress.flush() if chunk: self._compress = None - yield from self.write(chunk) + await self.write(chunk) if self._encoding == 'base64': if self._encoding_buffer: - yield from self._writer.write(base64.b64encode( + await self._writer.write(base64.b64encode( self._encoding_buffer)) - @asyncio.coroutine - def write(self, chunk): + async def write(self, chunk): if self._compress is not None: if chunk: chunk = self._compress.compress(chunk) @@ -844,8 +829,8 @@ def write(self, chunk): buffer[:div * 3], buffer[div * 3:]) if enc_chunk: enc_chunk = base64.b64encode(enc_chunk) - yield from self._writer.write(enc_chunk) + await self._writer.write(enc_chunk) elif self._encoding == 'quoted-printable': - yield from self._writer.write(binascii.b2a_qp(chunk)) + await self._writer.write(binascii.b2a_qp(chunk)) else: - yield from self._writer.write(chunk) + await self._writer.write(chunk) diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 4100d79fdad..5e4d3d106f2 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -924,15 +924,15 @@ def test_writer_content_transfer_encoding_unknown(buf, stream, writer): writer.append('Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'unknown'}) -class MultipartWriterTestCase(unittest.TestCase): +class MultipartWriterTestCase(TestCase): def setUp(self): + super().setUp() self.buf = bytearray() self.stream = mock.Mock() - def write(chunk): + async def write(chunk): self.buf.extend(chunk) - return () self.stream.write.side_effect = write @@ -1004,8 +1004,8 @@ def test_append_multipart(self): part = self.writer._parts[0][0] self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') - def test_write(self): - self.assertEqual([], list(self.writer.write(self.stream))) + async def test_write(self): + await self.writer.write(self.stream) def test_with(self): with aiohttp.multipart.MultipartWriter(boundary=':') as writer: