Support for chunk_size (#1277)
* Support iter_raw(chunk_size=...) and aiter_raw(chunk_size=...)

* Unit tests for ByteChunker

* Support iter_bytes(chunk_size=...)

* Add TextChunker

* Support iter_text(chunk_size=...)

* Fix merge with master

Co-authored-by: Florimond Manca <[email protected]>
Authored by tomchristie and florimondmanca on Nov 25, 2020 · 1 parent c4d2e6f · commit 27df5e4
Showing 4 changed files with 295 additions and 32 deletions.
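To give a sense of the feature, here is a minimal usage sketch (the URL, the chunk size, and the process() handler are illustrative, not part of this commit). With this change, a streamed response body can be consumed in fixed-size pieces rather than in whatever sizes the network happens to deliver:

    import httpx

    # Stream a response and read the decoded body in exact 4096-byte chunks;
    # only the final chunk may be shorter.
    with httpx.stream("GET", "https://www.example.com") as response:
        for chunk in response.iter_bytes(chunk_size=4096):
            process(chunk)  # hypothetical handler for each fixed-size chunk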
httpx/_decoders.py (79 additions, 0 deletions)
@@ -4,6 +4,7 @@
 See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding
 """
 import codecs
+import io
 import typing
 import zlib

@@ -155,6 +156,84 @@ def flush(self) -> bytes:
         return data


+class ByteChunker:
+    """
+    Handles returning byte content in fixed-size chunks.
+    """
+
+    def __init__(self, chunk_size: int = None) -> None:
+        self._buffer = io.BytesIO()
+        self._chunk_size = chunk_size
+
+    def decode(self, content: bytes) -> typing.List[bytes]:
+        if self._chunk_size is None:
+            return [content]
+
+        self._buffer.write(content)
+        if self._buffer.tell() >= self._chunk_size:
+            value = self._buffer.getvalue()
+            chunks = [
+                value[i : i + self._chunk_size]
+                for i in range(0, len(value), self._chunk_size)
+            ]
+            if len(chunks[-1]) == self._chunk_size:
+                self._buffer.seek(0)
+                self._buffer.truncate()
+                return chunks
+            else:
+                self._buffer.seek(0)
+                self._buffer.write(chunks[-1])
+                self._buffer.truncate()
+                return chunks[:-1]
+        else:
+            return []
+
+    def flush(self) -> typing.List[bytes]:
+        value = self._buffer.getvalue()
+        self._buffer.seek(0)
+        self._buffer.truncate()
+        return [value] if value else []
+
+
+class TextChunker:
+    """
+    Handles returning text content in fixed-size chunks.
+    """
+
+    def __init__(self, chunk_size: int = None) -> None:
+        self._buffer = io.StringIO()
+        self._chunk_size = chunk_size
+
+    def decode(self, content: str) -> typing.List[str]:
+        if self._chunk_size is None:
+            return [content]
+
+        self._buffer.write(content)
+        if self._buffer.tell() >= self._chunk_size:
+            value = self._buffer.getvalue()
+            chunks = [
+                value[i : i + self._chunk_size]
+                for i in range(0, len(value), self._chunk_size)
+            ]
+            if len(chunks[-1]) == self._chunk_size:
+                self._buffer.seek(0)
+                self._buffer.truncate()
+                return chunks
+            else:
+                self._buffer.seek(0)
+                self._buffer.write(chunks[-1])
+                self._buffer.truncate()
+                return chunks[:-1]
+        else:
+            return []
+
+    def flush(self) -> typing.List[str]:
+        value = self._buffer.getvalue()
+        self._buffer.seek(0)
+        self._buffer.truncate()
+        return [value] if value else []
+
+
 class TextDecoder:
     """
     Handles incrementally decoding bytes into text
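Both chunker classes share the same contract: decode() buffers incoming parts and returns only complete chunk_size-sized pieces, while flush() returns whatever remainder is left. A small sketch of that behavior, importing from the private httpx._decoders module purely for illustration (TextChunker behaves identically for str):

    from httpx._decoders import ByteChunker  # internal module, not public API

    chunker = ByteChunker(chunk_size=5)
    parts = [b"abc", b"defgh", b"ij"]  # irregular parts, as read off a socket

    output = []
    for part in parts:
        output.extend(chunker.decode(part))  # emits only complete 5-byte chunks
    output.extend(chunker.flush())           # emits any short remainder

    assert output == [b"abcde", b"fghij"]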
httpx/_models.py (76 additions, 28 deletions)
@@ -15,10 +15,12 @@
 from ._content import PlainByteStream, encode_request, encode_response
 from ._decoders import (
     SUPPORTED_DECODERS,
+    ByteChunker,
     ContentDecoder,
     IdentityDecoder,
     LineDecoder,
     MultiDecoder,
+    TextChunker,
     TextDecoder,
 )
 from ._exceptions import (
@@ -1162,31 +1164,47 @@ def read(self) -> bytes:
         self._content = b"".join(self.iter_bytes())
         return self._content

-    def iter_bytes(self) -> typing.Iterator[bytes]:
+    def iter_bytes(self, chunk_size: int = None) -> typing.Iterator[bytes]:
         """
         A byte-iterator over the decoded response content.
         This allows us to handle gzip, deflate, and brotli encoded responses.
         """
         if hasattr(self, "_content"):
-            yield self._content
+            chunk_size = len(self._content) if chunk_size is None else chunk_size
+            for i in range(0, len(self._content), chunk_size):
+                yield self._content[i : i + chunk_size]
         else:
             decoder = self._get_content_decoder()
+            chunker = ByteChunker(chunk_size=chunk_size)
             with self._wrap_decoder_errors():
-                for chunk in self.iter_raw():
-                    yield decoder.decode(chunk)
-                yield decoder.flush()
+                for raw_bytes in self.iter_raw():
+                    decoded = decoder.decode(raw_bytes)
+                    for chunk in chunker.decode(decoded):
+                        yield chunk
+                decoded = decoder.flush()
+                for chunk in chunker.decode(decoded):
+                    yield chunk
+                for chunk in chunker.flush():
+                    yield chunk

-    def iter_text(self) -> typing.Iterator[str]:
+    def iter_text(self, chunk_size: int = None) -> typing.Iterator[str]:
         """
         A str-iterator over the decoded response content
         that handles both gzip, deflate, etc but also detects the content's
         string encoding.
         """
         decoder = TextDecoder(encoding=self.encoding)
+        chunker = TextChunker(chunk_size=chunk_size)
         with self._wrap_decoder_errors():
-            for chunk in self.iter_bytes():
-                yield decoder.decode(chunk)
-            yield decoder.flush()
+            for byte_content in self.iter_bytes():
+                text_content = decoder.decode(byte_content)
+                for chunk in chunker.decode(text_content):
+                    yield chunk
+            text_content = decoder.flush()
+            for chunk in chunker.decode(text_content):
+                yield chunk
+            for chunk in chunker.flush():
+                yield chunk

     def iter_lines(self) -> typing.Iterator[str]:
         decoder = LineDecoder()
@@ -1197,7 +1215,7 @@ def iter_lines(self) -> typing.Iterator[str]:
             for line in decoder.flush():
                 yield line

-    def iter_raw(self) -> typing.Iterator[bytes]:
+    def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]:
         """
         A byte-iterator over the raw response content.
         """
@@ -1210,10 +1228,17 @@ def iter_raw(self) -> typing.Iterator[bytes]:

         self.is_stream_consumed = True
         self._num_bytes_downloaded = 0
+        chunker = ByteChunker(chunk_size=chunk_size)

         with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
-            for part in self.stream:
-                self._num_bytes_downloaded += len(part)
-                yield part
+            for raw_stream_bytes in self.stream:
+                self._num_bytes_downloaded += len(raw_stream_bytes)
+                for chunk in chunker.decode(raw_stream_bytes):
+                    yield chunk
+
+        for chunk in chunker.flush():
+            yield chunk

         self.close()

     def close(self) -> None:
@@ -1234,31 +1259,47 @@ async def aread(self) -> bytes:
         self._content = b"".join([part async for part in self.aiter_bytes()])
         return self._content

-    async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
+    async def aiter_bytes(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]:
         """
         A byte-iterator over the decoded response content.
         This allows us to handle gzip, deflate, and brotli encoded responses.
         """
         if hasattr(self, "_content"):
-            yield self._content
+            chunk_size = len(self._content) if chunk_size is None else chunk_size
+            for i in range(0, len(self._content), chunk_size):
+                yield self._content[i : i + chunk_size]
         else:
             decoder = self._get_content_decoder()
+            chunker = ByteChunker(chunk_size=chunk_size)
             with self._wrap_decoder_errors():
-                async for chunk in self.aiter_raw():
-                    yield decoder.decode(chunk)
-                yield decoder.flush()
+                async for raw_bytes in self.aiter_raw():
+                    decoded = decoder.decode(raw_bytes)
+                    for chunk in chunker.decode(decoded):
+                        yield chunk
+                decoded = decoder.flush()
+                for chunk in chunker.decode(decoded):
+                    yield chunk
+                for chunk in chunker.flush():
+                    yield chunk

-    async def aiter_text(self) -> typing.AsyncIterator[str]:
+    async def aiter_text(self, chunk_size: int = None) -> typing.AsyncIterator[str]:
         """
         A str-iterator over the decoded response content
         that handles both gzip, deflate, etc but also detects the content's
         string encoding.
         """
         decoder = TextDecoder(encoding=self.encoding)
+        chunker = TextChunker(chunk_size=chunk_size)
         with self._wrap_decoder_errors():
-            async for chunk in self.aiter_bytes():
-                yield decoder.decode(chunk)
-            yield decoder.flush()
+            async for byte_content in self.aiter_bytes():
+                text_content = decoder.decode(byte_content)
+                for chunk in chunker.decode(text_content):
+                    yield chunk
+            text_content = decoder.flush()
+            for chunk in chunker.decode(text_content):
+                yield chunk
+            for chunk in chunker.flush():
+                yield chunk

     async def aiter_lines(self) -> typing.AsyncIterator[str]:
         decoder = LineDecoder()
@@ -1269,7 +1310,7 @@ async def aiter_lines(self) -> typing.AsyncIterator[str]:
             for line in decoder.flush():
                 yield line

-    async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
+    async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]:
         """
         A byte-iterator over the raw response content.
         """
@@ -1282,10 +1323,17 @@ async def aiter_raw(self) -> typing.AsyncIterator[bytes]:

         self.is_stream_consumed = True
         self._num_bytes_downloaded = 0
+        chunker = ByteChunker(chunk_size=chunk_size)

         with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
-            async for part in self.stream:
-                self._num_bytes_downloaded += len(part)
-                yield part
+            async for raw_stream_bytes in self.stream:
+                self._num_bytes_downloaded += len(raw_stream_bytes)
+                for chunk in chunker.decode(raw_stream_bytes):
+                    yield chunk
+
+        for chunk in chunker.flush():
+            yield chunk

         await self.aclose()

     async def aclose(self) -> None:
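The async methods mirror the sync ones, so the same parameter works with an AsyncClient. A brief sketch under the same caveats as above (URL and chunk size are illustrative):

    import asyncio
    import httpx

    async def main() -> None:
        async with httpx.AsyncClient() as client:
            async with client.stream("GET", "https://www.example.com") as response:
                async for chunk in response.aiter_bytes(chunk_size=1024):
                    print(len(chunk))  # 1024 for each chunk except possibly the last

    asyncio.run(main())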