diff --git a/httpx/__init__.py b/httpx/__init__.py
index 8aca4f7cee..15f2223557 100644
--- a/httpx/__init__.py
+++ b/httpx/__init__.py
@@ -19,8 +19,8 @@
     ProxyError,
     ReadError,
     ReadTimeout,
-    RedirectError,
     RequestBodyUnavailable,
+    RequestError,
     RequestNotRead,
     ResponseClosed,
     ResponseNotRead,
@@ -28,6 +28,7 @@
     StreamError,
     TimeoutException,
     TooManyRedirects,
+    TransportError,
     WriteError,
     WriteTimeout,
 )
@@ -76,7 +77,7 @@
     "ProtocolError",
     "ReadError",
     "ReadTimeout",
-    "RedirectError",
+    "RequestError",
     "RequestBodyUnavailable",
     "ResponseClosed",
     "ResponseNotRead",
@@ -87,6 +88,7 @@
     "ProxyError",
     "TimeoutException",
     "TooManyRedirects",
+    "TransportError",
     "WriteError",
     "WriteTimeout",
     "URL",
diff --git a/httpx/_auth.py b/httpx/_auth.py
index 6940019e7e..571584593b 100644
--- a/httpx/_auth.py
+++ b/httpx/_auth.py
@@ -116,20 +116,24 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non
             # need to build an authenticated request.
             return
 
-        header = response.headers["www-authenticate"]
-        challenge = self._parse_challenge(header)
+        challenge = self._parse_challenge(request, response)
         request.headers["Authorization"] = self._build_auth_header(request, challenge)
         yield request
 
-    def _parse_challenge(self, header: str) -> "_DigestAuthChallenge":
+    def _parse_challenge(
+        self, request: Request, response: Response
+    ) -> "_DigestAuthChallenge":
         """
         Returns a challenge from a Digest WWW-Authenticate header.
         These take the form of:
         `Digest realm="realm@host.com",qop="auth,auth-int",nonce="abc",opaque="xyz"`
         """
+        header = response.headers["www-authenticate"]
+
         scheme, _, fields = header.partition(" ")
         if scheme.lower() != "digest":
-            raise ProtocolError("Header does not start with 'Digest'")
+            message = "Header does not start with 'Digest'"
+            raise ProtocolError(message, request=request)
 
         header_dict: typing.Dict[str, str] = {}
         for field in parse_http_list(fields):
@@ -146,7 +150,8 @@ def _parse_challenge(self, header: str) -> "_DigestAuthChallenge":
                 realm=realm, nonce=nonce, qop=qop, opaque=opaque, algorithm=algorithm
             )
         except KeyError as exc:
-            raise ProtocolError("Malformed Digest WWW-Authenticate header") from exc
+            message = "Malformed Digest WWW-Authenticate header"
+            raise ProtocolError(message, request=request) from exc
 
     def _build_auth_header(
         self, request: Request, challenge: "_DigestAuthChallenge"
@@ -171,7 +176,7 @@ def digest(data: bytes) -> bytes:
         if challenge.algorithm.lower().endswith("-sess"):
             HA1 = digest(b":".join((HA1, challenge.nonce, cnonce)))
 
-        qop = self._resolve_qop(challenge.qop)
+        qop = self._resolve_qop(challenge.qop, request=request)
         if qop is None:
             digest_data = [HA1, challenge.nonce, HA2]
         else:
@@ -221,7 +226,9 @@ def _get_header_value(self, header_fields: typing.Dict[str, bytes]) -> str:
 
         return header_value
 
-    def _resolve_qop(self, qop: typing.Optional[bytes]) -> typing.Optional[bytes]:
+    def _resolve_qop(
+        self, qop: typing.Optional[bytes], request: Request
+    ) -> typing.Optional[bytes]:
         if qop is None:
             return None
         qops = re.split(b", ?", qop)
@@ -231,7 +238,8 @@ def _resolve_qop(self, qop: typing.Optional[bytes]) -> typing.Optional[bytes]:
         if qops == [b"auth-int"]:
             raise NotImplementedError("Digest auth-int support is not yet implemented")
 
-        raise ProtocolError(f'Unexpected qop value "{qop!r}" in digest auth')
+        message = f'Unexpected qop value "{qop!r}" in digest auth'
+        raise ProtocolError(message, request=request)
 
 
 class _DigestAuthChallenge:
diff --git a/httpx/_client.py b/httpx/_client.py
index 4a7fc030a0..a718649e4a 100644
--- a/httpx/_client.py
+++ b/httpx/_client.py
@@ -324,7 +324,8 @@ def _redirect_url(self, request: Request, response: Response) -> URL:
 
         # Check that we can handle the scheme
         if url.scheme and url.scheme not in ("http", "https"):
-            raise InvalidURL(f'Scheme "{url.scheme}" not supported.')
+            message = f'Scheme "{url.scheme}" not supported.'
+            raise InvalidURL(message, request=request)
 
         # Handle malformed 'Location' headers that are "absolute" form, have no host.
         # See: https://github.com/encode/httpx/issues/771
@@ -537,12 +538,13 @@ def _init_proxy_transport(
             http2=http2,
         )
 
-    def _transport_for_url(self, url: URL) -> httpcore.SyncHTTPTransport:
+    def _transport_for_url(self, request: Request) -> httpcore.SyncHTTPTransport:
         """
         Returns the transport instance that should be used for a given URL.
         This will either be the standard connection pool, or a proxy.
         """
-        enforce_http_url(url)
+        url = request.url
+        enforce_http_url(request)
 
         if self._proxies and not should_not_be_proxied(url):
             for matcher, transport in self._proxies.items():
@@ -590,7 +592,8 @@ def send(
         timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET,
     ) -> Response:
         if request.url.scheme not in ("http", "https"):
-            raise InvalidURL('URL scheme must be "http" or "https".')
+            message = 'URL scheme must be "http" or "https".'
+            raise InvalidURL(message, request=request)
 
         timeout = self.timeout if isinstance(timeout, UnsetType) else Timeout(timeout)
 
@@ -682,7 +685,7 @@ def _send_single_request(self, request: Request, timeout: Timeout) -> Response:
         """
         Sends a single request, without handling any redirections.
         """
-        transport = self._transport_for_url(request.url)
+        transport = self._transport_for_url(request)
 
         with map_exceptions(HTTPCORE_EXC_MAP, request=request):
             (
@@ -1059,12 +1062,13 @@ def _init_proxy_transport(
             http2=http2,
         )
 
-    def _transport_for_url(self, url: URL) -> httpcore.AsyncHTTPTransport:
+    def _transport_for_url(self, request: Request) -> httpcore.AsyncHTTPTransport:
         """
         Returns the transport instance that should be used for a given URL.
         This will either be the standard connection pool, or a proxy.
         """
-        enforce_http_url(url)
+        url = request.url
+        enforce_http_url(request)
 
         if self._proxies and not should_not_be_proxied(url):
             for matcher, transport in self._proxies.items():
@@ -1204,7 +1208,7 @@ async def _send_single_request(
         """
         Sends a single request, without handling any redirections.
         """
-        transport = self._transport_for_url(request.url)
+        transport = self._transport_for_url(request)
 
         with map_exceptions(HTTPCORE_EXC_MAP, request=request):
             (
diff --git a/httpx/_decoders.py b/httpx/_decoders.py
index d1c60fb267..c3959f8467 100644
--- a/httpx/_decoders.py
+++ b/httpx/_decoders.py
@@ -16,8 +16,14 @@
 except ImportError:  # pragma: nocover
     brotli = None
 
+if typing.TYPE_CHECKING:  # pragma: no cover
+    from ._models import Request
+
 
 class Decoder:
+    def __init__(self, request: "Request") -> None:
+        self.request = request
+
     def decode(self, data: bytes) -> bytes:
         raise NotImplementedError()  # pragma: nocover
 
@@ -44,7 +50,8 @@ class DeflateDecoder(Decoder):
     See: https://stackoverflow.com/questions/1838699
     """
 
-    def __init__(self) -> None:
+    def __init__(self, request: "Request") -> None:
+        self.request = request
         self.first_attempt = True
         self.decompressor = zlib.decompressobj()
 
@@ -57,13 +64,13 @@ def decode(self, data: bytes) -> bytes:
             if was_first_attempt:
                 self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
                 return self.decode(data)
-            raise DecodingError from exc
+            raise DecodingError(message=str(exc), request=self.request)
 
     def flush(self) -> bytes:
         try:
             return self.decompressor.flush()
         except zlib.error as exc:  # pragma: nocover
-            raise DecodingError from exc
+            raise DecodingError(message=str(exc), request=self.request)
 
 
 class GZipDecoder(Decoder):
@@ -73,20 +80,21 @@ class GZipDecoder(Decoder):
     See: https://stackoverflow.com/questions/1838699
     """
 
-    def __init__(self) -> None:
+    def __init__(self, request: "Request") -> None:
+        self.request = request
         self.decompressor = zlib.decompressobj(zlib.MAX_WBITS | 16)
 
     def decode(self, data: bytes) -> bytes:
         try:
             return self.decompressor.decompress(data)
         except zlib.error as exc:
-            raise DecodingError from exc
+            raise DecodingError(message=str(exc), request=self.request)
 
     def flush(self) -> bytes:
         try:
             return self.decompressor.flush()
         except zlib.error as exc:  # pragma: nocover
-            raise DecodingError from exc
+            raise DecodingError(message=str(exc), request=self.request)
 
 
 class BrotliDecoder(Decoder):
@@ -99,10 +107,11 @@ class BrotliDecoder(Decoder):
     name. The top branches are for 'brotlipy' and bottom branches for 'Brotli'
     """
 
-    def __init__(self) -> None:
+    def __init__(self, request: "Request") -> None:
         assert (
             brotli is not None
         ), "The 'brotlipy' or 'brotli' library must be installed to use 'BrotliDecoder'"
+        self.request = request
         self.decompressor = brotli.Decompressor()
         self.seen_data = False
         if hasattr(self.decompressor, "decompress"):
@@ -117,7 +126,7 @@ def decode(self, data: bytes) -> bytes:
         try:
             return self._decompress(data)
         except brotli.error as exc:
-            raise DecodingError from exc
+            raise DecodingError(message=str(exc), request=self.request)
 
     def flush(self) -> bytes:
         if not self.seen_data:
@@ -127,7 +136,7 @@ def flush(self) -> bytes:
                 self.decompressor.finish()
             return b""
         except brotli.error as exc:  # pragma: nocover
-            raise DecodingError from exc
+            raise DecodingError(message=str(exc), request=self.request)
 
 
 class MultiDecoder(Decoder):
@@ -160,7 +169,8 @@ class TextDecoder:
     Handles incrementally decoding bytes into text
     """
 
-    def __init__(self, encoding: typing.Optional[str] = None):
+    def __init__(self, request: "Request", encoding: typing.Optional[str] = None):
+        self.request = request
         self.decoder: typing.Optional[codecs.IncrementalDecoder] = (
             None if encoding is None else codecs.getincrementaldecoder(encoding)()
         )
@@ -194,8 +204,8 @@ def decode(self, data: bytes) -> str:
                     self.buffer = None
 
             return text
-        except UnicodeDecodeError:  # pragma: nocover
-            raise DecodingError() from None
+        except UnicodeDecodeError as exc:  # pragma: nocover
+            raise DecodingError(message=str(exc), request=self.request)
 
     def flush(self) -> str:
         try:
@@ -207,14 +217,15 @@ def flush(self) -> str:
                 return bytes(self.buffer).decode(self._detector_result())
 
             return self.decoder.decode(b"", True)
-        except UnicodeDecodeError:  # pragma: nocover
-            raise DecodingError() from None
+        except UnicodeDecodeError as exc:  # pragma: nocover
+            raise DecodingError(message=str(exc), request=self.request)
 
     def _detector_result(self) -> str:
         self.detector.close()
         result = self.detector.result["encoding"]
         if not result:  # pragma: nocover
-            raise DecodingError("Unable to determine encoding of content")
+            message = "Unable to determine encoding of content"
+            raise DecodingError(message, request=self.request)
 
         return result
diff --git a/httpx/_exceptions.py b/httpx/_exceptions.py
index a5271c37c4..36fa7d1139 100644
--- a/httpx/_exceptions.py
+++ b/httpx/_exceptions.py
@@ -1,3 +1,33 @@
+"""
+Our exception hierarchy:
+
+* RequestError
+  + TransportError
+    - TimeoutException
+      · ConnectTimeout
+      · ReadTimeout
+      · WriteTimeout
+      · PoolTimeout
+    - NetworkError
+      · ConnectError
+      · ReadError
+      · WriteError
+      · CloseError
+    - ProxyError
+    - ProtocolError
+  + DecodingError
+  + TooManyRedirects
+  + RequestBodyUnavailable
+  + InvalidURL
+* HTTPStatusError
+* NotRedirectResponse
+* CookieConflict
+* StreamError
+  + StreamConsumed
+  + ResponseNotRead
+  + RequestNotRead
+  + ResponseClosed
+"""
 import contextlib
 import typing
 
@@ -7,30 +37,26 @@
     from ._models import Request, Response  # pragma: nocover
 
 
-class HTTPError(Exception):
+class RequestError(Exception):
     """
-    Base class for all HTTPX exceptions.
+    Base class for all exceptions that may occur when issuing a `.request()`.
     """
 
-    def __init__(
-        self, *args: typing.Any, request: "Request" = None, response: "Response" = None
-    ) -> None:
-        super().__init__(*args)
-        self._request = request or (response.request if response is not None else None)
-        self.response = response
+    def __init__(self, message: str, *, request: "Request") -> None:
+        super().__init__(message)
+        self.request = request
 
-    @property
-    def request(self) -> "Request":
-        # NOTE: this property exists so that a `Request` is exposed to type
-        # checkers, instead of `Optional[Request]`.
-        assert self._request is not None  # Populated by the client.
-        return self._request
+
+class TransportError(RequestError):
+    """
+    Base class for all exceptions that are mapped from the httpcore API.
+    """
 
 
 # Timeout exceptions...
 
 
-class TimeoutException(HTTPError):
+class TimeoutException(TransportError):
     """
     The base class for timeout errors.
 
@@ -65,7 +91,7 @@ class PoolTimeout(TimeoutException):
 # Core networking exceptions...
 
 
-class NetworkError(HTTPError):
+class NetworkError(TransportError):
     """
     The base class for network-related errors.
 
@@ -100,63 +126,94 @@ class CloseError(NetworkError):
 # Other transport exceptions...
 
 
-class ProxyError(HTTPError):
+class ProxyError(TransportError):
     """
     An error occurred while proxying a request.
     """
 
 
-class ProtocolError(HTTPError):
+class ProtocolError(TransportError):
     """
     A protocol was violated by the server.
     """
 
 
-# HTTP exceptions...
+# Other request exceptions...
 
 
-class DecodingError(HTTPError):
+class DecodingError(RequestError):
    """
     Decoding of the response failed.
     """
 
 
-class HTTPStatusError(HTTPError):
+class TooManyRedirects(RequestError):
     """
-    Response sent an error HTTP status.
+    Too many redirects.
     """
 
-    def __init__(self, *args: typing.Any, response: "Response") -> None:
-        super().__init__(*args)
-        self._request = response.request
-        self.response = response
 
 
-# Redirect exceptions...
+class RequestBodyUnavailable(RequestError):
+    """
+    Had to send the request again, but the request body was streaming, and is
+    no longer available.
+    """
 
 
-class RedirectError(HTTPError):
+class InvalidURL(RequestError):
     """
-    Base class for HTTP redirect errors.
+    URL was missing a hostname, or was not one of HTTP/HTTPS.
     """
 
 
-class TooManyRedirects(RedirectError):
+# Client errors
+
+
+class HTTPStatusError(Exception):
     """
-    Too many redirects.
+    Response sent an error HTTP status.
+
+    May be raised when calling `response.raise_for_status()`
     """
 
+    def __init__(
+        self, message: str, *, request: "Request", response: "Response"
+    ) -> None:
+        super().__init__(message)
+        self.request = request
+        self.response = response
 
-class NotRedirectResponse(RedirectError):
+
+class NotRedirectResponse(Exception):
     """
     Response was not a redirect response.
+
+    May be raised if `response.next()` is called without first
+    properly checking `response.is_redirect`.
     """
 
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+
+
+class CookieConflict(Exception):
+    """
+    Attempted to lookup a cookie by name, but multiple cookies existed.
+
+    Can occur when calling `response.cookies.get(...)`.
+    """
+
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+
 
 # Stream exceptions...
+# These may occur as the result of a programming error, by accessing
+# the request/response stream in an invalid manner.
 
-
-class StreamError(HTTPError):
+
+class StreamError(Exception):
     """
     The base class for stream exceptions.
 
@@ -164,12 +221,8 @@
     an invalid way.
     """
 
-
-class RequestBodyUnavailable(StreamError):
-    """
-    Had to send the request again, but the request body was streaming, and is
-    no longer available.
-    """
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
 
 
 class StreamConsumed(StreamError):
@@ -178,6 +231,13 @@ class StreamConsumed(StreamError):
     been streamed.
     """
 
+    def __init__(self) -> None:
+        message = (
+            "Attempted to read or stream response content, but the content has "
+            "already been streamed."
+        )
+        super().__init__(message)
+
 
 class ResponseNotRead(StreamError):
     """
@@ -185,12 +245,23 @@ class ResponseNotRead(StreamError):
     after a streaming response.
     """
 
+    def __init__(self) -> None:
+        message = (
+            "Attempted to access response content, without having called `read()` "
+            "after a streaming response."
+        )
+        super().__init__(message)
+
 
 class RequestNotRead(StreamError):
     """
     Attempted to access request content, without having called `read()`.
     """
 
+    def __init__(self) -> None:
+        message = "Attempted to access request content, without having called `read()`."
+        super().__init__(message)
+
 
 class ResponseClosed(StreamError):
     """
@@ -198,20 +269,17 @@ class ResponseClosed(StreamError):
     closed.
     """
 
-
-# Other cases...
-
-
-class InvalidURL(HTTPError):
-    """
-    URL was missing a hostname, or was not one of HTTP/HTTPS.
-    """
+    def __init__(self) -> None:
+        message = (
+            "Attempted to read or stream response content, but the request has "
+            "been closed."
+        )
+        super().__init__(message)
 
 
-class CookieConflict(HTTPError):
-    """
-    Attempted to lookup a cookie by name, but multiple cookies existed.
-    """
+# We're continuing to expose this earlier naming at the moment.
+# It is due to be deprecated. Don't use it.
+HTTPError = RequestError
 
 
 @contextlib.contextmanager
diff --git a/httpx/_models.py b/httpx/_models.py
index 4a81e5965d..36345e9f78 100644
--- a/httpx/_models.py
+++ b/httpx/_models.py
@@ -798,16 +798,16 @@ def decoder(self) -> Decoder:
                 value = value.strip().lower()
                 try:
                     decoder_cls = SUPPORTED_DECODERS[value]
-                    decoders.append(decoder_cls())
+                    decoders.append(decoder_cls(request=self.request))
                 except KeyError:
                     continue
 
             if len(decoders) == 1:
                 self._decoder = decoders[0]
             elif len(decoders) > 1:
-                self._decoder = MultiDecoder(decoders)
+                self._decoder = MultiDecoder(children=decoders)
             else:
-                self._decoder = IdentityDecoder()
+                self._decoder = IdentityDecoder(request=self.request)
 
         return self._decoder
 
@@ -830,10 +830,10 @@ def raise_for_status(self) -> None:
 
         if codes.is_client_error(self.status_code):
             message = message.format(self, error_type="Client Error")
-            raise HTTPStatusError(message, response=self)
+            raise HTTPStatusError(message, request=self.request, response=self)
         elif codes.is_server_error(self.status_code):
             message = message.format(self, error_type="Server Error")
-            raise HTTPStatusError(message, response=self)
+            raise HTTPStatusError(message, request=self.request, response=self)
 
     def json(self, **kwargs: typing.Any) -> typing.Any:
         if self.charset_encoding is None and self.content and len(self.content) > 3:
@@ -895,7 +895,7 @@ def iter_text(self) -> typing.Iterator[str]:
         that handles both gzip, deflate, etc but also detects the content's
         string encoding.
         """
-        decoder = TextDecoder(encoding=self.charset_encoding)
+        decoder = TextDecoder(request=self.request, encoding=self.charset_encoding)
         for chunk in self.iter_bytes():
             yield decoder.decode(chunk)
         yield decoder.flush()
@@ -927,7 +927,11 @@ def next(self) -> "Response":
         Get the next response from a redirect response.
         """
         if not self.is_redirect:
-            raise NotRedirectResponse()
+            message = (
+                "Called .next(), but the response was not a redirect. "
+                "Calling code should check `response.is_redirect` first."
+            )
+            raise NotRedirectResponse(message)
         assert self.call_next is not None
         return self.call_next()
 
@@ -968,7 +972,7 @@ async def aiter_text(self) -> typing.AsyncIterator[str]:
         that handles both gzip, deflate, etc but also detects the content's
         string encoding.
         """
-        decoder = TextDecoder(encoding=self.charset_encoding)
+        decoder = TextDecoder(request=self.request, encoding=self.charset_encoding)
         async for chunk in self.aiter_bytes():
             yield decoder.decode(chunk)
         yield decoder.flush()
@@ -1000,7 +1004,10 @@ async def anext(self) -> "Response":
         Get the next response from a redirect response.
         """
         if not self.is_redirect:
-            raise NotRedirectResponse()
+            raise NotRedirectResponse(
+                "Called .anext(), but the response was not a redirect. "
+                "Calling code should check `response.is_redirect` first."
+            )
         assert self.call_next is not None
         return await self.call_next()
 
diff --git a/httpx/_utils.py b/httpx/_utils.py
index 2533a86d2a..41b71c248f 100644
--- a/httpx/_utils.py
+++ b/httpx/_utils.py
@@ -18,7 +18,7 @@
 from ._types import PrimitiveData
 
 if typing.TYPE_CHECKING:  # pragma: no cover
-    from ._models import URL
+    from ._models import URL, Request
 
 
 _HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"}
@@ -265,16 +265,21 @@ def trace(message: str, *args: typing.Any, **kwargs: typing.Any) -> None:
     return typing.cast(Logger, logger)
 
 
-def enforce_http_url(url: "URL") -> None:
+def enforce_http_url(request: "Request") -> None:
     """
     Raise an appropriate InvalidURL for any non-HTTP URLs.
     """
+    url = request.url
+
     if not url.scheme:
-        raise InvalidURL("No scheme included in URL.")
+        message = "No scheme included in URL."
+        raise InvalidURL(message, request=request)
     if not url.host:
-        raise InvalidURL("No host included in URL.")
+        message = "No host included in URL."
+        raise InvalidURL(message, request=request)
     if url.scheme not in ("http", "https"):
-        raise InvalidURL('URL scheme must be "http" or "https".')
+        message = 'URL scheme must be "http" or "https".'
+        raise InvalidURL(message, request=request)
 
 
 def port_or_default(url: "URL") -> typing.Optional[int]:
diff --git a/tests/client/test_proxies.py b/tests/client/test_proxies.py
index d361372c1f..8d012fe668 100644
--- a/tests/client/test_proxies.py
+++ b/tests/client/test_proxies.py
@@ -94,7 +94,8 @@ def test_proxies_parameter(proxies, expected_proxies):
 )
 def test_transport_for_request(url, proxies, expected):
     client = httpx.AsyncClient(proxies=proxies)
-    transport = client._transport_for_url(httpx.URL(url))
+    request = httpx.Request(method="GET", url=url)
+    transport = client._transport_for_url(request)
 
     if expected is None:
         assert transport is client._transport
@@ -141,7 +142,8 @@ def test_proxies_environ(monkeypatch, client_class, url, env, expected):
         monkeypatch.setenv(name, value)
 
     client = client_class()
-    transport = client._transport_for_url(httpx.URL(url))
+    request = httpx.Request(method="GET", url=url)
+    transport = client._transport_for_url(request)
 
     if expected is None:
         assert transport == client._transport
diff --git a/tests/test_asgi.py b/tests/test_asgi.py
index 3584b97573..c59ef7c30e 100644
--- a/tests/test_asgi.py
+++ b/tests/test_asgi.py
@@ -66,8 +66,8 @@ async def test_asgi_exc():
 
 @pytest.mark.usefixtures("async_environment")
 async def test_asgi_http_error():
-    client = httpx.AsyncClient(app=partial(raise_exc, exc=httpx.HTTPError))
-    with pytest.raises(httpx.HTTPError):
+    client = httpx.AsyncClient(app=partial(raise_exc, exc=RuntimeError))
+    with pytest.raises(RuntimeError):
         await client.get("http://www.example.org/")
 
 
diff --git a/tests/test_decoders.py b/tests/test_decoders.py
index 6b7993109a..ec01d41e4c 100644
--- a/tests/test_decoders.py
+++ b/tests/test_decoders.py
@@ -135,7 +135,8 @@ def test_empty_content(header_value):
     "decoder", (BrotliDecoder, DeflateDecoder, GZipDecoder, IdentityDecoder)
 )
 def test_decoders_empty_cases(decoder):
-    instance = decoder()
+    request = httpx.Request(method="GET", url="https://www.example.com")
+    instance = decoder(request)
     assert instance.decode(b"") == b""
     assert instance.flush() == b""
 
@@ -206,10 +207,12 @@ async def iterator():
 
 
 def test_text_decoder_empty_cases():
-    decoder = TextDecoder()
+    request = httpx.Request(method="GET", url="https://www.example.com")
+
+    decoder = TextDecoder(request=request)
     assert decoder.flush() == ""
 
-    decoder = TextDecoder()
+    decoder = TextDecoder(request=request)
     assert decoder.decode(b"") == ""
     assert decoder.flush() == ""
 
diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py
index 18a272c36f..e002133939 100644
--- a/tests/test_wsgi.py
+++ b/tests/test_wsgi.py
@@ -97,8 +97,8 @@ def test_wsgi_exc():
 
 
 def test_wsgi_http_error():
-    client = httpx.Client(app=partial(raise_exc, exc=httpx.HTTPError))
-    with pytest.raises(httpx.HTTPError):
+    client = httpx.Client(app=partial(raise_exc, exc=RuntimeError))
+    with pytest.raises(RuntimeError):
         client.get("http://www.example.org/")
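
Reviewer note, not part of the patch: a minimal sketch of how calling code sees the reorganised hierarchy. It assumes the names re-exported in `httpx/__init__.py` above, plus `HTTPStatusError` (whose export this diff does not touch); the URL is hypothetical.

```python
import httpx

# Hypothetical endpoint, purely for illustration.
url = "https://example.invalid/resource"

try:
    response = httpx.get(url)
    response.raise_for_status()
except httpx.TransportError as exc:
    # Timeouts, network failures, proxy and protocol errors: everything
    # mapped from the httpcore API. `.request` is now always populated.
    print(f"Transport failure for {exc.request.url!r}: {exc}")
except httpx.RequestError as exc:
    # The remaining request-side errors: DecodingError, TooManyRedirects,
    # RequestBodyUnavailable, InvalidURL.
    print(f"Request error for {exc.request.url!r}: {exc}")
except httpx.HTTPStatusError as exc:
    # Raised by `raise_for_status()`. No longer part of the request-error
    # tree, but it now carries both `.request` and `.response`.
    print(f"Got {exc.response.status_code} from {exc.request.url!r}")
```

Because `TransportError` subclasses `RequestError`, the narrower clause has to come first; code that doesn't care about the distinction can catch `httpx.RequestError` alone.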
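
Also worth flagging: the `HTTPError = RequestError` alias at the bottom of `httpx/_exceptions.py` keeps old-style `except httpx.HTTPError` blocks working for request errors, but not for status errors. A hedged sketch, assuming `HTTPError` and `HTTPStatusError` remain re-exported (this diff leaves those imports alone):

```python
import httpx

# The old base name now aliases the new request-error class...
assert httpx.HTTPError is httpx.RequestError
# ...so transport-level errors are still caught by it...
assert issubclass(httpx.ConnectTimeout, httpx.HTTPError)
# ...but status errors raised by `raise_for_status()` no longer are.
assert not issubclass(httpx.HTTPStatusError, httpx.HTTPError)
```

This is presumably also why the ASGI/WSGI tests above switch the app-raised exception from `httpx.HTTPError` to a plain `RuntimeError`: the renamed class can no longer be constructed without a `request` argument.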