diff --git a/httpx/_decoders.py b/httpx/_decoders.py index fa9f8124a2..6c0c4492f9 100644 --- a/httpx/_decoders.py +++ b/httpx/_decoders.py @@ -9,21 +9,13 @@ import chardet -from ._exceptions import DecodingError - try: import brotli except ImportError: # pragma: nocover brotli = None -if typing.TYPE_CHECKING: # pragma: no cover - from ._models import Request - class Decoder: - def __init__(self, request: "Request") -> None: - self.request = request - def decode(self, data: bytes) -> bytes: raise NotImplementedError() # pragma: nocover @@ -50,8 +42,7 @@ class DeflateDecoder(Decoder): See: https://stackoverflow.com/questions/1838699 """ - def __init__(self, request: "Request") -> None: - self.request = request + def __init__(self) -> None: self.first_attempt = True self.decompressor = zlib.decompressobj() @@ -64,13 +55,13 @@ def decode(self, data: bytes) -> bytes: if was_first_attempt: self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS) return self.decode(data) - raise DecodingError(message=str(exc), request=self.request) + raise ValueError(str(exc)) def flush(self) -> bytes: try: return self.decompressor.flush() except zlib.error as exc: # pragma: nocover - raise DecodingError(message=str(exc), request=self.request) + raise ValueError(str(exc)) class GZipDecoder(Decoder): @@ -80,21 +71,20 @@ class GZipDecoder(Decoder): See: https://stackoverflow.com/questions/1838699 """ - def __init__(self, request: "Request") -> None: - self.request = request + def __init__(self) -> None: self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16) def decode(self, data: bytes) -> bytes: try: return self.decompressor.decompress(data) except zlib.error as exc: - raise DecodingError(message=str(exc), request=self.request) + raise ValueError(str(exc)) def flush(self) -> bytes: try: return self.decompressor.flush() except zlib.error as exc: # pragma: nocover - raise DecodingError(message=str(exc), request=self.request) + raise ValueError(str(exc)) class BrotliDecoder(Decoder): @@ -107,7 +97,7 @@ class BrotliDecoder(Decoder): name. The top branches are for 'brotlipy' and bottom branches for 'Brotli' """ - def __init__(self, request: "Request") -> None: + def __init__(self) -> None: if brotli is None: # pragma: nocover raise ImportError( "Using 'BrotliDecoder', but the 'brotlipy' or 'brotli' library " @@ -115,7 +105,6 @@ def __init__(self, request: "Request") -> None: "Make sure to install httpx using `pip install httpx[brotli]`." ) from None - self.request = request self.decompressor = brotli.Decompressor() self.seen_data = False if hasattr(self.decompressor, "decompress"): @@ -130,7 +119,7 @@ def decode(self, data: bytes) -> bytes: try: return self._decompress(data) except brotli.error as exc: - raise DecodingError(message=str(exc), request=self.request) + raise ValueError(str(exc)) def flush(self) -> bytes: if not self.seen_data: @@ -140,7 +129,7 @@ def flush(self) -> bytes: self.decompressor.finish() return b"" except brotli.error as exc: # pragma: nocover - raise DecodingError(message=str(exc), request=self.request) + raise ValueError(str(exc)) class MultiDecoder(Decoder): @@ -173,8 +162,7 @@ class TextDecoder: Handles incrementally decoding bytes into text """ - def __init__(self, request: "Request", encoding: typing.Optional[str] = None): - self.request = request + def __init__(self, encoding: typing.Optional[str] = None): self.decoder: typing.Optional[codecs.IncrementalDecoder] = ( None if encoding is None else codecs.getincrementaldecoder(encoding)() ) @@ -209,7 +197,7 @@ def decode(self, data: bytes) -> str: return text except UnicodeDecodeError as exc: # pragma: nocover - raise DecodingError(message=str(exc), request=self.request) + raise ValueError(str(exc)) def flush(self) -> str: try: @@ -222,14 +210,13 @@ def flush(self) -> str: return self.decoder.decode(b"", True) except UnicodeDecodeError as exc: # pragma: nocover - raise DecodingError(message=str(exc), request=self.request) + raise ValueError(str(exc)) def _detector_result(self) -> str: self.detector.close() result = self.detector.result["encoding"] if not result: # pragma: nocover - message = "Unable to determine encoding of content" - raise DecodingError(message, request=self.request) + raise ValueError("Unable to determine encoding of content") return result diff --git a/httpx/_models.py b/httpx/_models.py index 67cebeb091..4a40263266 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -1,4 +1,5 @@ import cgi +import contextlib import datetime import email.message import json as jsonlib @@ -26,6 +27,7 @@ from ._exceptions import ( HTTPCORE_EXC_MAP, CookieConflict, + DecodingError, HTTPStatusError, InvalidURL, NotRedirectResponse, @@ -689,7 +691,7 @@ def __init__( self, status_code: int, *, - request: Request, + request: Request = None, http_version: str = None, headers: HeaderTypes = None, stream: ContentStream = None, @@ -700,7 +702,8 @@ def __init__( self.http_version = http_version self.headers = Headers(headers) - self.request = request + self._request: typing.Optional[Request] = request + self.call_next: typing.Optional[typing.Callable] = None self.history = [] if history is None else list(history) @@ -726,6 +729,21 @@ def elapsed(self) -> datetime.timedelta: ) return self._elapsed + @property + def request(self) -> Request: + """ + Returns the request instance associated to the current response. + """ + if self._request is None: + raise RuntimeError( + "The request instance has not been set on this response." + ) + return self._request + + @request.setter + def request(self, value: Request) -> None: + self._request = value + @property def reason_phrase(self) -> str: return codes.get_reason_phrase(self.status_code) @@ -811,7 +829,7 @@ def decoder(self) -> Decoder: value = value.strip().lower() try: decoder_cls = SUPPORTED_DECODERS[value] - decoders.append(decoder_cls(request=self.request)) + decoders.append(decoder_cls()) except KeyError: continue @@ -820,7 +838,7 @@ def decoder(self) -> Decoder: elif len(decoders) > 1: self._decoder = MultiDecoder(children=decoders) else: - self._decoder = IdentityDecoder(request=self.request) + self._decoder = IdentityDecoder() return self._decoder @@ -841,12 +859,19 @@ def raise_for_status(self) -> None: "For more information check: https://httpstatuses.com/{0.status_code}" ) + request = self._request + if request is None: + raise RuntimeError( + "Cannot call `raise_for_status` as the request " + "instance has not been set on this response." + ) + if codes.is_client_error(self.status_code): message = message.format(self, error_type="Client Error") - raise HTTPStatusError(message, request=self.request, response=self) + raise HTTPStatusError(message, request=request, response=self) elif codes.is_server_error(self.status_code): message = message.format(self, error_type="Server Error") - raise HTTPStatusError(message, request=self.request, response=self) + raise HTTPStatusError(message, request=request, response=self) def json(self, **kwargs: typing.Any) -> typing.Any: if self.charset_encoding is None and self.content and len(self.content) > 3: @@ -882,6 +907,17 @@ def links(self) -> typing.Dict[typing.Optional[str], typing.Dict[str, str]]: def __repr__(self) -> str: return f"" + @contextlib.contextmanager + def _wrap_decoder_errors(self) -> typing.Iterator[None]: + # If the response has an associated request instance, we want decoding + # errors to be raised as proper `httpx.DecodingError` exceptions. + try: + yield + except ValueError as exc: + if self._request is None: + raise exc + raise DecodingError(message=str(exc), request=self.request) from exc + def read(self) -> bytes: """ Read and return the response content. @@ -898,9 +934,10 @@ def iter_bytes(self) -> typing.Iterator[bytes]: if hasattr(self, "_content"): yield self._content else: - for chunk in self.iter_raw(): - yield self.decoder.decode(chunk) - yield self.decoder.flush() + with self._wrap_decoder_errors(): + for chunk in self.iter_raw(): + yield self.decoder.decode(chunk) + yield self.decoder.flush() def iter_text(self) -> typing.Iterator[str]: """ @@ -908,18 +945,20 @@ def iter_text(self) -> typing.Iterator[str]: that handles both gzip, deflate, etc but also detects the content's string encoding. """ - decoder = TextDecoder(request=self.request, encoding=self.charset_encoding) - for chunk in self.iter_bytes(): - yield decoder.decode(chunk) - yield decoder.flush() + decoder = TextDecoder(encoding=self.charset_encoding) + with self._wrap_decoder_errors(): + for chunk in self.iter_bytes(): + yield decoder.decode(chunk) + yield decoder.flush() def iter_lines(self) -> typing.Iterator[str]: decoder = LineDecoder() - for text in self.iter_text(): - for line in decoder.decode(text): + with self._wrap_decoder_errors(): + for text in self.iter_text(): + for line in decoder.decode(text): + yield line + for line in decoder.flush(): yield line - for line in decoder.flush(): - yield line def iter_raw(self) -> typing.Iterator[bytes]: """ @@ -931,7 +970,7 @@ def iter_raw(self) -> typing.Iterator[bytes]: raise ResponseClosed() self.is_stream_consumed = True - with map_exceptions(HTTPCORE_EXC_MAP, request=self.request): + with map_exceptions(HTTPCORE_EXC_MAP, request=self._request): for part in self._raw_stream: yield part self.close() @@ -956,7 +995,8 @@ def close(self) -> None: """ if not self.is_closed: self.is_closed = True - self._elapsed = self.request.timer.elapsed + if self._request is not None: + self._elapsed = self.request.timer.elapsed self._raw_stream.close() async def aread(self) -> bytes: @@ -975,9 +1015,10 @@ async def aiter_bytes(self) -> typing.AsyncIterator[bytes]: if hasattr(self, "_content"): yield self._content else: - async for chunk in self.aiter_raw(): - yield self.decoder.decode(chunk) - yield self.decoder.flush() + with self._wrap_decoder_errors(): + async for chunk in self.aiter_raw(): + yield self.decoder.decode(chunk) + yield self.decoder.flush() async def aiter_text(self) -> typing.AsyncIterator[str]: """ @@ -985,18 +1026,20 @@ async def aiter_text(self) -> typing.AsyncIterator[str]: that handles both gzip, deflate, etc but also detects the content's string encoding. """ - decoder = TextDecoder(request=self.request, encoding=self.charset_encoding) - async for chunk in self.aiter_bytes(): - yield decoder.decode(chunk) - yield decoder.flush() + decoder = TextDecoder(encoding=self.charset_encoding) + with self._wrap_decoder_errors(): + async for chunk in self.aiter_bytes(): + yield decoder.decode(chunk) + yield decoder.flush() async def aiter_lines(self) -> typing.AsyncIterator[str]: decoder = LineDecoder() - async for text in self.aiter_text(): - for line in decoder.decode(text): + with self._wrap_decoder_errors(): + async for text in self.aiter_text(): + for line in decoder.decode(text): + yield line + for line in decoder.flush(): yield line - for line in decoder.flush(): - yield line async def aiter_raw(self) -> typing.AsyncIterator[bytes]: """ @@ -1008,7 +1051,7 @@ async def aiter_raw(self) -> typing.AsyncIterator[bytes]: raise ResponseClosed() self.is_stream_consumed = True - with map_exceptions(HTTPCORE_EXC_MAP, request=self.request): + with map_exceptions(HTTPCORE_EXC_MAP, request=self._request): async for part in self._raw_stream: yield part await self.aclose() @@ -1032,7 +1075,8 @@ async def aclose(self) -> None: """ if not self.is_closed: self.is_closed = True - self._elapsed = self.request.timer.elapsed + if self._request is not None: + self._elapsed = self.request.timer.elapsed await self._raw_stream.aclose() diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 878bef072d..e9fbeca22d 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -2,6 +2,7 @@ import json from unittest import mock +import brotli import pytest import httpx @@ -31,6 +32,28 @@ def test_response(): assert not response.is_error +def test_raise_for_status(): + # 2xx status codes are not an error. + response = httpx.Response(200, request=REQUEST) + response.raise_for_status() + + # 4xx status codes are a client error. + response = httpx.Response(403, request=REQUEST) + with pytest.raises(httpx.HTTPStatusError): + response.raise_for_status() + + # 5xx status codes are a server error. + response = httpx.Response(500, request=REQUEST) + with pytest.raises(httpx.HTTPStatusError): + response.raise_for_status() + + # Calling .raise_for_status without setting a request instance is + # not valid. Should raise a runtime error. + response = httpx.Response(200) + with pytest.raises(RuntimeError): + response.raise_for_status() + + def test_response_repr(): response = httpx.Response(200, content=b"Hello, world!", request=REQUEST) assert repr(response) == "" @@ -372,7 +395,18 @@ def test_json_without_specified_encoding_decode_error(): response = httpx.Response( 200, content=content, headers=headers, request=REQUEST ) - with pytest.raises(json.JSONDecodeError): + with pytest.raises(json.decoder.JSONDecodeError): + response.json() + + +def test_json_without_specified_encoding_value_error(): + data = {"greeting": "hello", "recipient": "world"} + content = json.dumps(data).encode("utf-32-be") + headers = {"Content-Type": "application/json"} + # force incorrect guess from `guess_json_utf` to trigger error + with mock.patch("httpx._models.guess_json_utf", return_value="utf-32"): + response = httpx.Response(200, content=content, headers=headers) + with pytest.raises(ValueError): response.json() @@ -395,3 +429,45 @@ def test_json_without_specified_encoding_decode_error(): def test_link_headers(headers, expected): response = httpx.Response(200, content=None, headers=headers, request=REQUEST) assert response.links == expected + + +@pytest.mark.parametrize("header_value", (b"deflate", b"gzip", b"br")) +def test_decode_error_with_request(header_value): + headers = [(b"Content-Encoding", header_value)] + body = b"test 123" + compressed_body = brotli.compress(body)[3:] + with pytest.raises(httpx.DecodingError): + httpx.Response(200, headers=headers, content=compressed_body, request=REQUEST) + + +@pytest.mark.parametrize("header_value", (b"deflate", b"gzip", b"br")) +def test_value_error_without_request(header_value): + headers = [(b"Content-Encoding", header_value)] + body = b"test 123" + compressed_body = brotli.compress(body)[3:] + with pytest.raises(ValueError): + httpx.Response(200, headers=headers, content=compressed_body) + + +def test_response_with_unset_request(): + response = httpx.Response(200, content=b"Hello, world!") + + assert response.status_code == 200 + assert response.reason_phrase == "OK" + assert response.text == "Hello, world!" + assert not response.is_error + + +def test_set_request_after_init(): + response = httpx.Response(200, content=b"Hello, world!") + + response.request = REQUEST + + assert response.request == REQUEST + + +def test_cannot_access_unset_request(): + response = httpx.Response(200, content=b"Hello, world!") + + with pytest.raises(RuntimeError): + assert response.request is not None diff --git a/tests/test_decoders.py b/tests/test_decoders.py index ec01d41e4c..abd478e400 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -135,9 +135,9 @@ def test_empty_content(header_value): "decoder", (BrotliDecoder, DeflateDecoder, GZipDecoder, IdentityDecoder) ) def test_decoders_empty_cases(decoder): - request = httpx.Request(method="GET", url="https://www.example.com") - instance = decoder(request) - assert instance.decode(b"") == b"" + response = httpx.Response(content=b"", status_code=200) + instance = decoder() + assert instance.decode(response.content) == b"" assert instance.flush() == b"" @@ -207,12 +207,10 @@ async def iterator(): def test_text_decoder_empty_cases(): - request = httpx.Request(method="GET", url="https://www.example.com") - - decoder = TextDecoder(request=request) + decoder = TextDecoder() assert decoder.flush() == "" - decoder = TextDecoder(request=request) + decoder = TextDecoder() assert decoder.decode(b"") == "" assert decoder.flush() == ""