Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the response's request parameter optional #1238

Merged
merged 9 commits into from
Sep 1, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 13 additions & 26 deletions httpx/_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,13 @@

import chardet

from ._exceptions import DecodingError

try:
import brotli
except ImportError: # pragma: nocover
brotli = None

if typing.TYPE_CHECKING: # pragma: no cover
from ._models import Request


class Decoder:
def __init__(self, request: "Request") -> None:
self.request = request

def decode(self, data: bytes) -> bytes:
raise NotImplementedError() # pragma: nocover

Expand All @@ -50,8 +42,7 @@ class DeflateDecoder(Decoder):
See: https://stackoverflow.com/questions/1838699
"""

def __init__(self, request: "Request") -> None:
self.request = request
def __init__(self) -> None:
self.first_attempt = True
self.decompressor = zlib.decompressobj()

Expand All @@ -64,13 +55,13 @@ def decode(self, data: bytes) -> bytes:
if was_first_attempt:
self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
return self.decode(data)
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))

def flush(self) -> bytes:
try:
return self.decompressor.flush()
except zlib.error as exc: # pragma: nocover
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))


class GZipDecoder(Decoder):
Expand All @@ -80,21 +71,20 @@ class GZipDecoder(Decoder):
See: https://stackoverflow.com/questions/1838699
"""

def __init__(self, request: "Request") -> None:
self.request = request
def __init__(self) -> None:
self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16)

def decode(self, data: bytes) -> bytes:
try:
return self.decompressor.decompress(data)
except zlib.error as exc:
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))

def flush(self) -> bytes:
try:
return self.decompressor.flush()
except zlib.error as exc: # pragma: nocover
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))


class BrotliDecoder(Decoder):
Expand All @@ -107,15 +97,14 @@ class BrotliDecoder(Decoder):
name. The top branches are for 'brotlipy' and bottom branches for 'Brotli'
"""

def __init__(self, request: "Request") -> None:
def __init__(self) -> None:
if brotli is None: # pragma: nocover
raise ImportError(
"Using 'BrotliDecoder', but the 'brotlipy' or 'brotli' library "
"is not installed."
"Make sure to install httpx using `pip install httpx[brotli]`."
) from None

self.request = request
self.decompressor = brotli.Decompressor()
self.seen_data = False
if hasattr(self.decompressor, "decompress"):
Expand All @@ -130,7 +119,7 @@ def decode(self, data: bytes) -> bytes:
try:
return self._decompress(data)
except brotli.error as exc:
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))

def flush(self) -> bytes:
if not self.seen_data:
Expand All @@ -140,7 +129,7 @@ def flush(self) -> bytes:
self.decompressor.finish()
return b""
except brotli.error as exc: # pragma: nocover
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))


class MultiDecoder(Decoder):
Expand Down Expand Up @@ -173,8 +162,7 @@ class TextDecoder:
Handles incrementally decoding bytes into text
"""

def __init__(self, request: "Request", encoding: typing.Optional[str] = None):
self.request = request
def __init__(self, encoding: typing.Optional[str] = None):
self.decoder: typing.Optional[codecs.IncrementalDecoder] = (
None if encoding is None else codecs.getincrementaldecoder(encoding)()
)
Expand Down Expand Up @@ -209,7 +197,7 @@ def decode(self, data: bytes) -> str:

return text
except UnicodeDecodeError as exc: # pragma: nocover
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))

def flush(self) -> str:
try:
Expand All @@ -222,14 +210,13 @@ def flush(self) -> str:

return self.decoder.decode(b"", True)
except UnicodeDecodeError as exc: # pragma: nocover
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))

def _detector_result(self) -> str:
self.detector.close()
result = self.detector.result["encoding"]
if not result: # pragma: nocover
message = "Unable to determine encoding of content"
raise DecodingError(message, request=self.request)
raise ValueError("Unable to determine encoding of content")

return result

Expand Down
121 changes: 89 additions & 32 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ._exceptions import (
HTTPCORE_EXC_MAP,
CookieConflict,
DecodingError,
HTTPStatusError,
InvalidURL,
NotRedirectResponse,
Expand Down Expand Up @@ -689,7 +690,7 @@ def __init__(
self,
status_code: int,
*,
request: Request,
request: Request = None,
http_version: str = None,
headers: HeaderTypes = None,
stream: ContentStream = None,
Expand All @@ -700,7 +701,8 @@ def __init__(
self.http_version = http_version
self.headers = Headers(headers)

self.request = request
self._request: typing.Optional[Request] = request

self.call_next: typing.Optional[typing.Callable] = None

self.history = [] if history is None else list(history)
Expand All @@ -726,6 +728,19 @@ def elapsed(self) -> datetime.timedelta:
)
return self._elapsed

@property
def request(self) -> Request:
"""
Returns the request instance associated to the current response.
"""
if self._request is None:
raise RuntimeError("'.request' may only be accessed if initialized")
tomchristie marked this conversation as resolved.
Show resolved Hide resolved
return self._request

@request.setter
def request(self, value: Request) -> None:
self._request = value

@property
def reason_phrase(self) -> str:
return codes.get_reason_phrase(self.status_code)
Expand Down Expand Up @@ -811,7 +826,7 @@ def decoder(self) -> Decoder:
value = value.strip().lower()
try:
decoder_cls = SUPPORTED_DECODERS[value]
decoders.append(decoder_cls(request=self.request))
decoders.append(decoder_cls())
except KeyError:
continue

Expand All @@ -820,7 +835,7 @@ def decoder(self) -> Decoder:
elif len(decoders) > 1:
self._decoder = MultiDecoder(children=decoders)
else:
self._decoder = IdentityDecoder(request=self.request)
self._decoder = IdentityDecoder()

return self._decoder

Expand All @@ -843,10 +858,14 @@ def raise_for_status(self) -> None:

if codes.is_client_error(self.status_code):
message = message.format(self, error_type="Client Error")
raise HTTPStatusError(message, request=self.request, response=self)
if self._request is None:
raise ValueError(message)
raise HTTPStatusError(message, request=self._request, response=self)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raising a ValueError inside raise_for_status seems a bit off ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest we start the method with something like...

if self._request is None:
    raise RuntimeError("Cannot call `raise_for_status` as the request instance has not been set on this response.")

elif codes.is_server_error(self.status_code):
message = message.format(self, error_type="Server Error")
raise HTTPStatusError(message, request=self.request, response=self)
if self._request is None:
raise ValueError(message)
raise HTTPStatusError(message, request=self._request, response=self)

def json(self, **kwargs: typing.Any) -> typing.Any:
if self.charset_encoding is None and self.content and len(self.content) > 3:
Expand Down Expand Up @@ -898,28 +917,46 @@ def iter_bytes(self) -> typing.Iterator[bytes]:
if hasattr(self, "_content"):
yield self._content
else:
for chunk in self.iter_raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
try:
for chunk in self.iter_raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
except ValueError as exc:
if self._request is None:
raise exc
else:
raise DecodingError(message=str(exc), request=self.request) from exc

def iter_text(self) -> typing.Iterator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
decoder = TextDecoder(request=self.request, encoding=self.charset_encoding)
for chunk in self.iter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()
decoder = TextDecoder(encoding=self.charset_encoding)
try:
for chunk in self.iter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()
except ValueError as exc:
if self._request is None:
raise exc
else:
raise DecodingError(message=str(exc), request=self.request) from exc

def iter_lines(self) -> typing.Iterator[str]:
decoder = LineDecoder()
for text in self.iter_text():
for line in decoder.decode(text):
try:
for text in self.iter_text():
for line in decoder.decode(text):
yield line
for line in decoder.flush():
yield line
for line in decoder.flush():
yield line
except ValueError as exc:
if self._request is None:
raise exc
else:
raise DecodingError(message=str(exc), request=self.request) from exc

def iter_raw(self) -> typing.Iterator[bytes]:
"""
Expand All @@ -931,7 +968,7 @@ def iter_raw(self) -> typing.Iterator[bytes]:
raise ResponseClosed()

self.is_stream_consumed = True
with map_exceptions(HTTPCORE_EXC_MAP, request=self.request):
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
for part in self._raw_stream:
yield part
self.close()
Expand All @@ -956,7 +993,8 @@ def close(self) -> None:
"""
if not self.is_closed:
self.is_closed = True
self._elapsed = self.request.timer.elapsed
if self._request is not None:
self._elapsed = self.request.timer.elapsed
self._raw_stream.close()

async def aread(self) -> bytes:
Expand All @@ -975,28 +1013,46 @@ async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
if hasattr(self, "_content"):
yield self._content
else:
async for chunk in self.aiter_raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
try:
async for chunk in self.aiter_raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
except ValueError as exc:
if self._request is None:
raise exc
else:
raise DecodingError(message=str(exc), request=self.request) from exc

async def aiter_text(self) -> typing.AsyncIterator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
decoder = TextDecoder(request=self.request, encoding=self.charset_encoding)
async for chunk in self.aiter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()
decoder = TextDecoder(encoding=self.charset_encoding)
try:
async for chunk in self.aiter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()
except ValueError as exc:
if self._request is None:
raise exc
else:
raise DecodingError(message=str(exc), request=self.request) from exc

async def aiter_lines(self) -> typing.AsyncIterator[str]:
decoder = LineDecoder()
async for text in self.aiter_text():
for line in decoder.decode(text):
try:
async for text in self.aiter_text():
for line in decoder.decode(text):
yield line
for line in decoder.flush():
yield line
for line in decoder.flush():
yield line
except ValueError as exc:
if self._request is None:
raise exc
else:
raise DecodingError(message=str(exc), request=self.request) from exc

async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
"""
Expand All @@ -1008,7 +1064,7 @@ async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
raise ResponseClosed()

self.is_stream_consumed = True
with map_exceptions(HTTPCORE_EXC_MAP, request=self.request):
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
async for part in self._raw_stream:
yield part
await self.aclose()
Expand All @@ -1032,7 +1088,8 @@ async def aclose(self) -> None:
"""
if not self.is_closed:
self.is_closed = True
self._elapsed = self.request.timer.elapsed
if self._request is not None:
self._elapsed = self.request.timer.elapsed
await self._raw_stream.aclose()


Expand Down
Loading