Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the response's request parameter optional #1238

Merged
merged 9 commits into from
Sep 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 13 additions & 26 deletions httpx/_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,13 @@

import chardet

from ._exceptions import DecodingError

try:
import brotli
except ImportError: # pragma: nocover
brotli = None

if typing.TYPE_CHECKING: # pragma: no cover
from ._models import Request


class Decoder:
def __init__(self, request: "Request") -> None:
self.request = request

def decode(self, data: bytes) -> bytes:
raise NotImplementedError() # pragma: nocover

Expand All @@ -50,8 +42,7 @@ class DeflateDecoder(Decoder):
See: https://stackoverflow.com/questions/1838699
"""

def __init__(self, request: "Request") -> None:
self.request = request
def __init__(self) -> None:
self.first_attempt = True
self.decompressor = zlib.decompressobj()

Expand All @@ -64,13 +55,13 @@ def decode(self, data: bytes) -> bytes:
if was_first_attempt:
self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
return self.decode(data)
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))

def flush(self) -> bytes:
try:
return self.decompressor.flush()
except zlib.error as exc: # pragma: nocover
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))


class GZipDecoder(Decoder):
Expand All @@ -80,21 +71,20 @@ class GZipDecoder(Decoder):
See: https://stackoverflow.com/questions/1838699
"""

def __init__(self, request: "Request") -> None:
self.request = request
def __init__(self) -> None:
self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16)

def decode(self, data: bytes) -> bytes:
try:
return self.decompressor.decompress(data)
except zlib.error as exc:
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))

def flush(self) -> bytes:
try:
return self.decompressor.flush()
except zlib.error as exc: # pragma: nocover
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))


class BrotliDecoder(Decoder):
Expand All @@ -107,15 +97,14 @@ class BrotliDecoder(Decoder):
name. The top branches are for 'brotlipy' and bottom branches for 'Brotli'
"""

def __init__(self, request: "Request") -> None:
def __init__(self) -> None:
if brotli is None: # pragma: nocover
raise ImportError(
"Using 'BrotliDecoder', but the 'brotlipy' or 'brotli' library "
"is not installed."
"Make sure to install httpx using `pip install httpx[brotli]`."
) from None

self.request = request
self.decompressor = brotli.Decompressor()
self.seen_data = False
if hasattr(self.decompressor, "decompress"):
Expand All @@ -130,7 +119,7 @@ def decode(self, data: bytes) -> bytes:
try:
return self._decompress(data)
except brotli.error as exc:
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))

def flush(self) -> bytes:
if not self.seen_data:
Expand All @@ -140,7 +129,7 @@ def flush(self) -> bytes:
self.decompressor.finish()
return b""
except brotli.error as exc: # pragma: nocover
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))


class MultiDecoder(Decoder):
Expand Down Expand Up @@ -173,8 +162,7 @@ class TextDecoder:
Handles incrementally decoding bytes into text
"""

def __init__(self, request: "Request", encoding: typing.Optional[str] = None):
self.request = request
def __init__(self, encoding: typing.Optional[str] = None):
self.decoder: typing.Optional[codecs.IncrementalDecoder] = (
None if encoding is None else codecs.getincrementaldecoder(encoding)()
)
Expand Down Expand Up @@ -209,7 +197,7 @@ def decode(self, data: bytes) -> str:

return text
except UnicodeDecodeError as exc: # pragma: nocover
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))

def flush(self) -> str:
try:
Expand All @@ -222,14 +210,13 @@ def flush(self) -> str:

return self.decoder.decode(b"", True)
except UnicodeDecodeError as exc: # pragma: nocover
raise DecodingError(message=str(exc), request=self.request)
raise ValueError(str(exc))

def _detector_result(self) -> str:
self.detector.close()
result = self.detector.result["encoding"]
if not result: # pragma: nocover
message = "Unable to determine encoding of content"
raise DecodingError(message, request=self.request)
raise ValueError("Unable to determine encoding of content")

return result

Expand Down
108 changes: 76 additions & 32 deletions httpx/_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import cgi
import contextlib
import datetime
import email.message
import json as jsonlib
Expand Down Expand Up @@ -26,6 +27,7 @@
from ._exceptions import (
HTTPCORE_EXC_MAP,
CookieConflict,
DecodingError,
HTTPStatusError,
InvalidURL,
NotRedirectResponse,
Expand Down Expand Up @@ -689,7 +691,7 @@ def __init__(
self,
status_code: int,
*,
request: Request,
request: Request = None,
http_version: str = None,
headers: HeaderTypes = None,
stream: ContentStream = None,
Expand All @@ -700,7 +702,8 @@ def __init__(
self.http_version = http_version
self.headers = Headers(headers)

self.request = request
self._request: typing.Optional[Request] = request

self.call_next: typing.Optional[typing.Callable] = None

self.history = [] if history is None else list(history)
Expand All @@ -726,6 +729,21 @@ def elapsed(self) -> datetime.timedelta:
)
return self._elapsed

@property
def request(self) -> Request:
"""
Returns the request instance associated to the current response.
"""
if self._request is None:
raise RuntimeError(
"The request instance has not been set on this response."
)
return self._request

@request.setter
def request(self, value: Request) -> None:
self._request = value

@property
def reason_phrase(self) -> str:
return codes.get_reason_phrase(self.status_code)
Expand Down Expand Up @@ -811,7 +829,7 @@ def decoder(self) -> Decoder:
value = value.strip().lower()
try:
decoder_cls = SUPPORTED_DECODERS[value]
decoders.append(decoder_cls(request=self.request))
decoders.append(decoder_cls())
except KeyError:
continue

Expand All @@ -820,7 +838,7 @@ def decoder(self) -> Decoder:
elif len(decoders) > 1:
self._decoder = MultiDecoder(children=decoders)
else:
self._decoder = IdentityDecoder(request=self.request)
self._decoder = IdentityDecoder()

return self._decoder

Expand All @@ -841,12 +859,19 @@ def raise_for_status(self) -> None:
"For more information check: https://httpstatuses.com/{0.status_code}"
)

request = self._request
if request is None:
raise RuntimeError(
"Cannot call `raise_for_status` as the request "
"instance has not been set on this response."
)

if codes.is_client_error(self.status_code):
message = message.format(self, error_type="Client Error")
raise HTTPStatusError(message, request=self.request, response=self)
raise HTTPStatusError(message, request=request, response=self)
elif codes.is_server_error(self.status_code):
message = message.format(self, error_type="Server Error")
raise HTTPStatusError(message, request=self.request, response=self)
raise HTTPStatusError(message, request=request, response=self)

def json(self, **kwargs: typing.Any) -> typing.Any:
if self.charset_encoding is None and self.content and len(self.content) > 3:
Expand Down Expand Up @@ -882,6 +907,17 @@ def links(self) -> typing.Dict[typing.Optional[str], typing.Dict[str, str]]:
def __repr__(self) -> str:
return f"<Response [{self.status_code} {self.reason_phrase}]>"

@contextlib.contextmanager
def _wrap_decoder_errors(self) -> typing.Iterator[None]:
# If the response has an associated request instance, we want decoding
# errors to be raised as proper `httpx.DecodingError` exceptions.
try:
yield
except ValueError as exc:
if self._request is None:
raise exc
raise DecodingError(message=str(exc), request=self.request) from exc

def read(self) -> bytes:
"""
Read and return the response content.
Expand All @@ -898,28 +934,31 @@ def iter_bytes(self) -> typing.Iterator[bytes]:
if hasattr(self, "_content"):
yield self._content
else:
for chunk in self.iter_raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
with self._wrap_decoder_errors():
for chunk in self.iter_raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()

def iter_text(self) -> typing.Iterator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
decoder = TextDecoder(request=self.request, encoding=self.charset_encoding)
for chunk in self.iter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()
decoder = TextDecoder(encoding=self.charset_encoding)
with self._wrap_decoder_errors():
for chunk in self.iter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()

def iter_lines(self) -> typing.Iterator[str]:
decoder = LineDecoder()
for text in self.iter_text():
for line in decoder.decode(text):
with self._wrap_decoder_errors():
for text in self.iter_text():
for line in decoder.decode(text):
yield line
for line in decoder.flush():
yield line
for line in decoder.flush():
yield line

def iter_raw(self) -> typing.Iterator[bytes]:
"""
Expand All @@ -931,7 +970,7 @@ def iter_raw(self) -> typing.Iterator[bytes]:
raise ResponseClosed()

self.is_stream_consumed = True
with map_exceptions(HTTPCORE_EXC_MAP, request=self.request):
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
for part in self._raw_stream:
yield part
self.close()
Expand All @@ -956,7 +995,8 @@ def close(self) -> None:
"""
if not self.is_closed:
self.is_closed = True
self._elapsed = self.request.timer.elapsed
if self._request is not None:
self._elapsed = self.request.timer.elapsed
self._raw_stream.close()

async def aread(self) -> bytes:
Expand All @@ -975,28 +1015,31 @@ async def aiter_bytes(self) -> typing.AsyncIterator[bytes]:
if hasattr(self, "_content"):
yield self._content
else:
async for chunk in self.aiter_raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()
with self._wrap_decoder_errors():
async for chunk in self.aiter_raw():
yield self.decoder.decode(chunk)
yield self.decoder.flush()

async def aiter_text(self) -> typing.AsyncIterator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
string encoding.
"""
decoder = TextDecoder(request=self.request, encoding=self.charset_encoding)
async for chunk in self.aiter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()
decoder = TextDecoder(encoding=self.charset_encoding)
with self._wrap_decoder_errors():
async for chunk in self.aiter_bytes():
yield decoder.decode(chunk)
yield decoder.flush()

async def aiter_lines(self) -> typing.AsyncIterator[str]:
decoder = LineDecoder()
async for text in self.aiter_text():
for line in decoder.decode(text):
with self._wrap_decoder_errors():
async for text in self.aiter_text():
for line in decoder.decode(text):
yield line
for line in decoder.flush():
yield line
for line in decoder.flush():
yield line

async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
"""
Expand All @@ -1008,7 +1051,7 @@ async def aiter_raw(self) -> typing.AsyncIterator[bytes]:
raise ResponseClosed()

self.is_stream_consumed = True
with map_exceptions(HTTPCORE_EXC_MAP, request=self.request):
with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
async for part in self._raw_stream:
yield part
await self.aclose()
Expand All @@ -1032,7 +1075,8 @@ async def aclose(self) -> None:
"""
if not self.is_closed:
self.is_closed = True
self._elapsed = self.request.timer.elapsed
if self._request is not None:
self._elapsed = self.request.timer.elapsed
await self._raw_stream.aclose()


Expand Down
Loading