Commit 5f4ba9f: Convert multipart to async/await syntax

asvetlov committed Nov 10, 2017
1 parent e59fd1f commit 5f4ba9f

Showing 2 changed files with 45 additions and 60 deletions.
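The change itself is mechanical: every generator-based coroutine declared with @asyncio.coroutine becomes a native async def coroutine, and every yield from on an awaitable becomes await. A minimal standalone sketch of the before/after pattern (illustrative only, not code from this commit):

    import asyncio


    # Before: a generator-based coroutine, the pre-Python 3.5 style
    # that this commit removes.
    @asyncio.coroutine
    def read_all_old(reader):
        data = yield from reader.read()
        return data


    # After: a native coroutine, the style adopted throughout this commit.
    async def read_all_new(reader):
        data = await reader.read()
        return data

Both behave identically on the event loop, but only a native coroutine may contain async for and async with blocks in its body, which the reader classes below rely on.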
95 changes: 40 additions & 55 deletions aiohttp/multipart.py
@@ -1,4 +1,3 @@
-import asyncio
 import base64
 import binascii
 import json
@@ -175,9 +174,8 @@ def __init__(self, resp, stream):
     def __aiter__(self):
         return self
 
-    @asyncio.coroutine
-    def __anext__(self):
-        part = yield from self.next()
+    async def __anext__(self):
+        part = await self.next()
         if part is None:
             raise StopAsyncIteration  # NOQA
         return part
@@ -220,22 +218,19 @@ def __init__(self, boundary, headers, content):
     def __aiter__(self):
         return self
 
-    @asyncio.coroutine
-    def __anext__(self):
-        part = yield from self.next()
+    async def __anext__(self):
+        part = await self.next()
         if part is None:
             raise StopAsyncIteration  # NOQA
         return part
 
-    @asyncio.coroutine
-    def next(self):
-        item = yield from self.read()
+    async def next(self):
+        item = await self.read()
         if not item:
             return None
         return item
 
-    @asyncio.coroutine
-    def read(self, *, decode=False):
+    async def read(self, *, decode=False):
         """Reads body part data.
 
         decode: Decodes data following by encoding
@@ -246,29 +241,28 @@ def read(self, *, decode=False):
             return b''
         data = bytearray()
         while not self._at_eof:
-            data.extend((yield from self.read_chunk(self.chunk_size)))
+            data.extend((await self.read_chunk(self.chunk_size)))
         if decode:
             return self.decode(data)
         return data
 
-    @asyncio.coroutine
-    def read_chunk(self, size=chunk_size):
+    async def read_chunk(self, size=chunk_size):
         """Reads body part content chunk of the specified size.
 
         size: chunk size
         """
         if self._at_eof:
             return b''
         if self._length:
-            chunk = yield from self._read_chunk_from_length(size)
+            chunk = await self._read_chunk_from_length(size)
         else:
-            chunk = yield from self._read_chunk_from_stream(size)
+            chunk = await self._read_chunk_from_stream(size)
 
         self._read_bytes += len(chunk)
         if self._read_bytes == self._length:
             self._at_eof = True
         if self._at_eof:
-            clrf = yield from self._content.readline()
+            clrf = await self._content.readline()
             assert b'\r\n' == clrf, \
                 'reader did not read all the data or it is malformed'
         return chunk
@@ -312,16 +306,15 @@ async def _read_chunk_from_stream(self, size):
         self._prev_chunk = chunk
         return result
 
-    @asyncio.coroutine
-    def readline(self):
+    async def readline(self):
         """Reads body part by line by line."""
         if self._at_eof:
             return b''
 
         if self._unread:
             line = self._unread.popleft()
         else:
-            line = yield from self._content.readline()
+            line = await self._content.readline()
 
         if line.startswith(self._boundary):
             # the very last boundary may not come with \r\n,
@@ -335,45 +328,41 @@ def readline(self):
             self._unread.append(line)
             return b''
         else:
-            next_line = yield from self._content.readline()
+            next_line = await self._content.readline()
             if next_line.startswith(self._boundary):
                 line = line[:-2]  # strip CRLF but only once
             self._unread.append(next_line)
 
         return line
 
-    @asyncio.coroutine
-    def release(self):
+    async def release(self):
         """Like read(), but reads all the data to the void."""
         if self._at_eof:
             return
         while not self._at_eof:
-            yield from self.read_chunk(self.chunk_size)
+            await self.read_chunk(self.chunk_size)
 
-    @asyncio.coroutine
-    def text(self, *, encoding=None):
+    async def text(self, *, encoding=None):
         """Like read(), but assumes that body part contains text data."""
-        data = yield from self.read(decode=True)
+        data = await self.read(decode=True)
         # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm # NOQA
         # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send # NOQA
         encoding = encoding or self.get_charset(default='utf-8')
         return data.decode(encoding)
 
-    @asyncio.coroutine
-    def json(self, *, encoding=None):
+    async def json(self, *, encoding=None):
         """Like read(), but assumes that body parts contains JSON data."""
-        data = yield from self.read(decode=True)
+        data = await self.read(decode=True)
         if not data:
             return None
         encoding = encoding or self.get_charset(default='utf-8')
         return json.loads(data.decode(encoding))
 
-    @asyncio.coroutine
-    def form(self, *, encoding=None):
+    async def form(self, *, encoding=None):
         """Like read(), but assumes that body parts contains form
         urlencoded data.
         """
-        data = yield from self.read(decode=True)
+        data = await self.read(decode=True)
         if not data:
             return None
         encoding = encoding or self.get_charset(default='utf-8')
@@ -491,9 +480,8 @@ def __init__(self, headers, content):
     def __aiter__(self):
         return self
 
-    @asyncio.coroutine
-    def __anext__(self):
-        part = yield from self.next()
+    async def __anext__(self):
+        part = await self.next()
         if part is None:
             raise StopAsyncIteration  # NOQA
         return part
@@ -768,30 +756,29 @@ def size(self):
         total += 2 + len(self._boundary) + 4  # b'--'+self._boundary+b'--\r\n'
         return total
 
-    @asyncio.coroutine
-    def write(self, writer):
+    async def write(self, writer):
         """Write body."""
         if not self._parts:
             return
 
         for part, headers, encoding, te_encoding in self._parts:
-            yield from writer.write(b'--' + self._boundary + b'\r\n')
-            yield from writer.write(headers)
+            await writer.write(b'--' + self._boundary + b'\r\n')
+            await writer.write(headers)
 
             if encoding or te_encoding:
                 w = MultipartPayloadWriter(writer)
                 if encoding:
                     w.enable_compression(encoding)
                 if te_encoding:
                     w.enable_encoding(te_encoding)
-                yield from part.write(w)
-                yield from w.write_eof()
+                await part.write(w)
+                await w.write_eof()
             else:
-                yield from part.write(writer)
+                await part.write(writer)
 
-            yield from writer.write(b'\r\n')
+            await writer.write(b'\r\n')
 
-        yield from writer.write(b'--' + self._boundary + b'--\r\n')
+        await writer.write(b'--' + self._boundary + b'--\r\n')
 
 
 class MultipartPayloadWriter:
@@ -813,21 +800,19 @@ def enable_compression(self, encoding='deflate'):
                      if encoding == 'gzip' else -zlib.MAX_WBITS)
         self._compress = zlib.compressobj(wbits=zlib_mode)
 
-    @asyncio.coroutine
-    def write_eof(self):
+    async def write_eof(self):
         if self._compress is not None:
             chunk = self._compress.flush()
             if chunk:
                 self._compress = None
-                yield from self.write(chunk)
+                await self.write(chunk)
 
         if self._encoding == 'base64':
             if self._encoding_buffer:
-                yield from self._writer.write(base64.b64encode(
+                await self._writer.write(base64.b64encode(
                     self._encoding_buffer))
 
-    @asyncio.coroutine
-    def write(self, chunk):
+    async def write(self, chunk):
         if self._compress is not None:
             if chunk:
                 chunk = self._compress.compress(chunk)
@@ -844,8 +829,8 @@ def write(self, chunk):
                 buffer[:div * 3], buffer[div * 3:])
             if enc_chunk:
                 enc_chunk = base64.b64encode(enc_chunk)
-                yield from self._writer.write(enc_chunk)
+                await self._writer.write(enc_chunk)
         elif self._encoding == 'quoted-printable':
-            yield from self._writer.write(binascii.b2a_qp(chunk))
+            await self._writer.write(binascii.b2a_qp(chunk))
         else:
-            yield from self._writer.write(chunk)
+            await self._writer.write(chunk)
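With __anext__ and the part-reading methods now native coroutines, a multipart body is consumed naturally with async for. A hypothetical usage sketch (the handler and the field handling are assumptions for illustration, not part of this commit):

    from aiohttp import web


    async def handle_upload(request):
        reader = await request.multipart()
        async for part in reader:  # drives the converted __anext__()/next()
            if part.headers.get('Content-Type') == 'application/json':
                metadata = await part.json()  # converted coroutine
            else:
                data = await part.read(decode=True)
        return web.Response(text='ok')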
10 changes: 5 additions & 5 deletions tests/test_multipart.py
@@ -924,15 +924,15 @@ def test_writer_content_transfer_encoding_unknown(buf, stream, writer):
     writer.append('Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'unknown'})
 
 
-class MultipartWriterTestCase(unittest.TestCase):
+class MultipartWriterTestCase(TestCase):
 
     def setUp(self):
         super().setUp()
         self.buf = bytearray()
         self.stream = mock.Mock()
 
-        def write(chunk):
+        async def write(chunk):
             self.buf.extend(chunk)
             return ()
 
         self.stream.write.side_effect = write

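Making the mocked write an async def is what keeps this setUp compatible with the converted writer: mock.Mock invokes the side effect and returns its result (now a coroutine object), which the code under test awaits. A standalone sketch of the pattern (names are illustrative):

    import asyncio
    from unittest import mock

    buf = bytearray()
    stream = mock.Mock()


    async def fake_write(chunk):
        # stream.write(chunk) returns this coroutine; the caller awaits it.
        buf.extend(chunk)


    stream.write.side_effect = fake_write


    async def producer(writer):
        await writer.write(b'hello ')
        await writer.write(b'world')


    loop = asyncio.new_event_loop()
    loop.run_until_complete(producer(stream))
    loop.close()
    assert buf == b'hello world'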
@@ -1004,8 +1004,8 @@ def test_append_multipart(self):
         part = self.writer._parts[0][0]
         self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed')
 
-    def test_write(self):
-        self.assertEqual([], list(self.writer.write(self.stream)))
+    async def test_write(self):
+        await self.writer.write(self.stream)
 
     def test_with(self):
         with aiohttp.multipart.MultipartWriter(boundary=':') as writer:
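An async def test_write only runs if the test harness drives the event loop, which is why the class now derives from the imported TestCase rather than unittest.TestCase. A minimal sketch of how coroutine test methods can be executed under plain unittest (an illustration of the idea, not aiohttp's actual helper):

    import asyncio
    import functools
    import unittest


    def async_test(coro):
        """Run an async def test method to completion on a fresh loop."""
        @functools.wraps(coro)
        def wrapper(self, *args, **kwargs):
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(coro(self, *args, **kwargs))
            finally:
                loop.close()
        return wrapper


    class ExampleTestCase(unittest.TestCase):

        @async_test
        async def test_write(self):
            await asyncio.sleep(0)  # stand-in for awaiting writer.write(...)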
