From e90b2ccf2a5a9eec63a26ba6784ac9a32b9bc13b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 7 Sep 2020 09:36:37 +0100 Subject: [PATCH 1/8] Support Response(content=) --- httpx/_content_streams.py | 35 +++++++++-------- httpx/_models.py | 11 ++++-- httpx/_types.py | 2 + tests/models/test_responses.py | 46 +++++----------------- tests/test_content_streams.py | 71 +++++++++++++++++++++++++++++++++- tests/test_decoders.py | 13 ++----- 6 files changed, 110 insertions(+), 68 deletions(-) diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py index 402fa959c8..3cd2196ab4 100644 --- a/httpx/_content_streams.py +++ b/httpx/_content_streams.py @@ -8,7 +8,7 @@ import httpcore from ._exceptions import StreamConsumed -from ._types import FileContent, FileTypes, RequestData, RequestFiles +from ._types import FileContent, FileTypes, RequestData, RequestFiles, ResponseContent from ._utils import ( format_form_param, guess_content_type, @@ -72,11 +72,8 @@ class IteratorStream(ContentStream): Request content encoded as plain bytes, using an byte iterator. """ - def __init__( - self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None - ) -> None: + def __init__(self, iterator: typing.Iterator[bytes]) -> None: self.iterator = iterator - self.close_func = close_func self.is_stream_consumed = False def can_replay(self) -> bool: @@ -95,21 +92,14 @@ def __iter__(self) -> typing.Iterator[bytes]: def __aiter__(self) -> typing.AsyncIterator[bytes]: raise RuntimeError("Attempted to call a async iterator on an sync stream.") - def close(self) -> None: - if self.close_func is not None: - self.close_func() - class AsyncIteratorStream(ContentStream): """ Request content encoded as plain bytes, using an async byte iterator. """ - def __init__( - self, aiterator: typing.AsyncIterator[bytes], close_func: typing.Callable = None - ) -> None: + def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None: self.aiterator = aiterator - self.close_func = close_func self.is_stream_consumed = False def can_replay(self) -> bool: @@ -128,10 +118,6 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: async for part in self.aiterator: yield part - async def aclose(self) -> None: - if self.close_func is not None: - await self.close_func() - class JSONStream(ContentStream): """ @@ -402,3 +388,18 @@ def encode( return IteratorStream(iterator=data) raise TypeError(f"Unexpected type for 'data', {type(data)!r}") + + +def encode_response(content: ResponseContent = None) -> ContentStream: + if content is None: + return ByteStream(b"") + elif isinstance(content, bytes): + return ByteStream(body=content) + elif hasattr(content, "__aiter__"): + content = typing.cast(typing.AsyncIterator[bytes], content) + return AsyncIteratorStream(aiterator=content) + elif hasattr(content, "__iter__"): + content = typing.cast(typing.Iterator[bytes], content) + return IteratorStream(iterator=content) + + raise TypeError(f"Unexpected type for 'content', {type(content)!r}") diff --git a/httpx/_models.py b/httpx/_models.py index 713281e662..694e520c2c 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -14,7 +14,7 @@ import rfc3986 import rfc3986.exceptions -from ._content_streams import ByteStream, ContentStream, encode +from ._content_streams import ByteStream, ContentStream, encode, encode_response from ._decoders import ( SUPPORTED_DECODERS, Decoder, @@ -44,6 +44,7 @@ QueryParamTypes, RequestData, RequestFiles, + ResponseContent, URLTypes, ) from ._utils import ( @@ -674,7 +675,7 @@ def __init__( http_version: str = None, headers: HeaderTypes = None, stream: ContentStream = None, - content: bytes = None, + content: ResponseContent = None, history: typing.List["Response"] = None, elapsed_func: typing.Callable = None, ): @@ -694,8 +695,10 @@ def __init__( if stream is not None: self._raw_stream = stream else: - self._raw_stream = ByteStream(body=content or b"") - self.read() + self._raw_stream = encode_response(content) + if content is None or isinstance(content, bytes): + # Load the response body, except for streaming content. + self.read() @property def elapsed(self) -> datetime.timedelta: diff --git a/httpx/_types.py b/httpx/_types.py index 3a90ee42e7..8989b2826c 100644 --- a/httpx/_types.py +++ b/httpx/_types.py @@ -63,6 +63,8 @@ None, ] +ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]] + RequestData = Union[dict, str, bytes, Iterator[bytes], AsyncIterator[bytes]] FileContent = Union[IO[str], IO[bytes], str, bytes] diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 2b07a27040..9c4d285091 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -5,7 +5,6 @@ import pytest import httpx -from httpx._content_streams import AsyncIteratorStream, IteratorStream def streaming_body(): @@ -215,10 +214,9 @@ async def test_aread(): def test_iter_raw(): - stream = IteratorStream(iterator=streaming_body()) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) raw = b"" @@ -229,10 +227,9 @@ def test_iter_raw(): @pytest.mark.asyncio async def test_aiter_raw(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) raw = b"" @@ -317,10 +314,9 @@ async def test_aiter_lines(): def test_sync_streaming_response(): - stream = IteratorStream(iterator=streaming_body()) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) assert response.status_code == 200 @@ -335,10 +331,9 @@ def test_sync_streaming_response(): @pytest.mark.asyncio async def test_async_streaming_response(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) assert response.status_code == 200 @@ -352,10 +347,9 @@ async def test_async_streaming_response(): def test_cannot_read_after_stream_consumed(): - stream = IteratorStream(iterator=streaming_body()) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) content = b"" @@ -368,10 +362,9 @@ def test_cannot_read_after_stream_consumed(): @pytest.mark.asyncio async def test_cannot_aread_after_stream_consumed(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) content = b"" @@ -383,54 +376,33 @@ async def test_cannot_aread_after_stream_consumed(): def test_cannot_read_after_response_closed(): - is_closed = False - - def close_func(): - nonlocal is_closed - is_closed = True - - stream = IteratorStream(iterator=streaming_body(), close_func=close_func) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) response.close() - assert is_closed - with pytest.raises(httpx.ResponseClosed): response.read() @pytest.mark.asyncio async def test_cannot_aread_after_response_closed(): - is_closed = False - - async def close_func(): - nonlocal is_closed - is_closed = True - - stream = AsyncIteratorStream( - aiterator=async_streaming_body(), close_func=close_func - ) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) await response.aclose() - assert is_closed - with pytest.raises(httpx.ResponseClosed): await response.aread() @pytest.mark.asyncio async def test_elapsed_not_available_until_closed(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) with pytest.raises(RuntimeError): diff --git a/tests/test_content_streams.py b/tests/test_content_streams.py index 140aa8d2af..2d1de1f1c0 100644 --- a/tests/test_content_streams.py +++ b/tests/test_content_streams.py @@ -3,7 +3,7 @@ import pytest from httpx import StreamConsumed -from httpx._content_streams import ContentStream, encode +from httpx._content_streams import ContentStream, encode, encode_response @pytest.mark.asyncio @@ -251,3 +251,72 @@ async def test_multipart_multiple_files_single_input_content(): b"--+++--\r\n", ] ) + + +@pytest.mark.asyncio +async def test_response_empty_content(): + stream = encode_response() + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) + + assert stream.can_replay() + assert stream.get_headers() == {} + assert sync_content == b"" + assert async_content == b"" + + +@pytest.mark.asyncio +async def test_response_bytes_content(): + stream = encode_response(content=b"Hello, world!") + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) + + assert stream.can_replay() + assert stream.get_headers() == {"Content-Length": "13"} + assert sync_content == b"Hello, world!" + assert async_content == b"Hello, world!" + + +@pytest.mark.asyncio +async def test_response_iterator_content(): + def hello_world(): + yield b"Hello, " + yield b"world!" + + stream = encode_response(content=hello_world()) + content = b"".join([part for part in stream]) + + assert not stream.can_replay() + assert stream.get_headers() == {"Transfer-Encoding": "chunked"} + assert content == b"Hello, world!" + + with pytest.raises(RuntimeError): + [part async for part in stream] + + with pytest.raises(StreamConsumed): + [part for part in stream] + + +@pytest.mark.asyncio +async def test_response_aiterator_content(): + async def hello_world(): + yield b"Hello, " + yield b"world!" + + stream = encode_response(content=hello_world()) + content = b"".join([part async for part in stream]) + + assert not stream.can_replay() + assert stream.get_headers() == {"Transfer-Encoding": "chunked"} + assert content == b"Hello, world!" + + with pytest.raises(RuntimeError): + [part for part in stream] + + with pytest.raises(StreamConsumed): + [part async for part in stream] + + +def test_response_invalid_argument(): + with pytest.raises(TypeError): + encode_response(123) # type: ignore diff --git a/tests/test_decoders.py b/tests/test_decoders.py index dbbaac5450..7dfca9ef50 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -4,7 +4,6 @@ import pytest import httpx -from httpx._content_streams import AsyncIteratorStream from httpx._decoders import ( BrotliDecoder, DeflateDecoder, @@ -130,11 +129,10 @@ async def compress(body): yield compressor.flush() headers = [(b"Content-Encoding", b"gzip")] - stream = AsyncIteratorStream(aiterator=compress(body)) response = httpx.Response( 200, headers=headers, - stream=stream, + content=compress(body), ) assert not hasattr(response, "body") assert await response.aread() == body @@ -199,19 +197,17 @@ async def iterator(): yield chunk # Accessing `.text` on a read response. - stream = AsyncIteratorStream(aiterator=iterator()) response = httpx.Response( 200, - stream=stream, + content=iterator(), ) await response.aread() assert response.text == (b"".join(data)).decode(encoding) # Streaming `.aiter_text` iteratively. - stream = AsyncIteratorStream(aiterator=iterator()) response = httpx.Response( 200, - stream=stream, + content=iterator(), ) text = "".join([part async for part in response.aiter_text()]) assert text == (b"".join(data)).decode(encoding) @@ -224,11 +220,10 @@ async def iterator(): yield b"\x83" yield b"\x89\x83x\x83\x8b" - stream = AsyncIteratorStream(aiterator=iterator()) response = httpx.Response( 200, headers=[(b"Content-Type", b"text/html; charset=shift-jis")], - stream=stream, + content=iterator(), ) await response.aread() From 6af5e50bec87d8bf7cac00eb7c66cfdffd311378 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 10 Sep 2020 20:02:51 +0100 Subject: [PATCH 2/8] Update test for merged master --- tests/models/test_responses.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 5517ab8efc..b52e4846f3 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -226,12 +226,7 @@ def test_iter_raw(): def test_iter_raw_increments_updates_counter(): - stream = IteratorStream(iterator=streaming_body()) - - response = httpx.Response( - 200, - stream=stream, - ) + response = httpx.Response(200, content=streaming_body()) num_downloaded = response.num_bytes_downloaded for part in response.iter_raw(): @@ -241,10 +236,7 @@ def test_iter_raw_increments_updates_counter(): @pytest.mark.asyncio async def test_aiter_raw(): - response = httpx.Response( - 200, - content=async_streaming_body(), - ) + response = httpx.Response(200, content=async_streaming_body()) raw = b"" async for part in response.aiter_raw(): @@ -254,12 +246,7 @@ async def test_aiter_raw(): @pytest.mark.asyncio async def test_aiter_raw_increments_updates_counter(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) - - response = httpx.Response( - 200, - stream=stream, - ) + response = httpx.Response(200, content=async_streaming_body()) num_downloaded = response.num_bytes_downloaded async for part in response.aiter_raw(): From f930214c80845061f152973a3a6bb39e7c3203a3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 11 Sep 2020 10:56:07 +0100 Subject: [PATCH 3/8] Add MockTransport for test cases --- httpx/_models.py | 1 - tests/client/test_cookies.py | 62 ++++++++++---------------------- tests/client/test_headers.py | 44 ++++++++--------------- tests/client/test_queryparams.py | 28 ++++----------- tests/test_multipart.py | 30 ++++------------ tests/utils.py | 52 +++++++++++++++++++++++++++ 6 files changed, 100 insertions(+), 117 deletions(-) diff --git a/httpx/_models.py b/httpx/_models.py index 526ee2cebf..29b3e67ac0 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -605,7 +605,6 @@ def __init__( self.stream = stream else: self.stream = encode(data, files, json) - self.prepare() def prepare(self) -> None: diff --git a/tests/client/test_cookies.py b/tests/client/test_cookies.py index 8cd6be8394..af614effb6 100644 --- a/tests/client/test_cookies.py +++ b/tests/client/test_cookies.py @@ -1,43 +1,19 @@ -import typing +import json from http.cookiejar import Cookie, CookieJar -import httpcore - import httpx -from httpx._content_streams import ByteStream, ContentStream, JSONStream - - -def get_header_value(headers, key, default=None): - lookup = key.encode("ascii").lower() - for header_key, header_value in headers: - if header_key.lower() == lookup: - return header_value.decode("ascii") - return default - - -class MockTransport(httpcore.SyncHTTPTransport): - def request( - self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]] = None, - stream: httpcore.SyncByteStream = None, - timeout: typing.Mapping[str, typing.Optional[float]] = None, - ) -> typing.Tuple[ - bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream - ]: - host, scheme, port, path = url - body: ContentStream - if path.startswith(b"/echo_cookies"): - cookie = get_header_value(headers, "cookie") - body = JSONStream({"cookies": cookie}) - return b"HTTP/1.1", 200, b"OK", [], body - elif path.startswith(b"/set_cookie"): - headers = [(b"set-cookie", b"example-name=example-value")] - body = ByteStream(b"") - return b"HTTP/1.1", 200, b"OK", headers, body - else: - raise NotImplementedError() # pragma: no cover +from tests.utils import MockTransport + + +def get_and_set_cookies(request: httpx.Request) -> httpx.Response: + if request.url.path == "/echo_cookies": + data = {"cookies": request.headers.get("cookie")} + content = json.dumps(data).encode("utf-8") + return httpx.Response(200, content=content) + elif request.url.path == "/set_cookie": + return httpx.Response(200, headers={"set-cookie": "example-name=example-value"}) + else: + raise NotImplementedError() # pragma: no cover def test_set_cookie() -> None: @@ -47,7 +23,7 @@ def test_set_cookie() -> None: url = "http://example.org/echo_cookies" cookies = {"example-name": "example-value"} - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(get_and_set_cookies)) response = client.get(url, cookies=cookies) assert response.status_code == 200 @@ -82,7 +58,7 @@ def test_set_cookie_with_cookiejar() -> None: ) cookies.set_cookie(cookie) - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(get_and_set_cookies)) response = client.get(url, cookies=cookies) assert response.status_code == 200 @@ -117,7 +93,7 @@ def test_setting_client_cookies_to_cookiejar() -> None: ) cookies.set_cookie(cookie) - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(get_and_set_cookies)) client.cookies = cookies # type: ignore response = client.get(url) @@ -134,7 +110,7 @@ def test_set_cookie_with_cookies_model() -> None: cookies = httpx.Cookies() cookies["example-name"] = "example-value" - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(get_and_set_cookies)) response = client.get(url, cookies=cookies) assert response.status_code == 200 @@ -144,7 +120,7 @@ def test_set_cookie_with_cookies_model() -> None: def test_get_cookie() -> None: url = "http://example.org/set_cookie" - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(get_and_set_cookies)) response = client.get(url) assert response.status_code == 200 @@ -156,7 +132,7 @@ def test_cookie_persistence() -> None: """ Ensure that Client instances persist cookies between requests. """ - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(get_and_set_cookies)) response = client.get("http://example.org/echo_cookies") assert response.status_code == 200 diff --git a/tests/client/test_headers.py b/tests/client/test_headers.py index c86eae33c1..d968616f4e 100755 --- a/tests/client/test_headers.py +++ b/tests/client/test_headers.py @@ -1,31 +1,17 @@ #!/usr/bin/env python3 -import typing +import json -import httpcore import pytest import httpx -from httpx._content_streams import ContentStream, JSONStream - - -class MockTransport(httpcore.SyncHTTPTransport): - def request( - self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]] = None, - stream: httpcore.SyncByteStream = None, - timeout: typing.Mapping[str, typing.Optional[float]] = None, - ) -> typing.Tuple[ - bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream - ]: - assert headers is not None - headers_dict = { - key.decode("ascii"): value.decode("ascii") for key, value in headers - } - body = JSONStream({"headers": headers_dict}) - return b"HTTP/1.1", 200, b"OK", [], body +from tests.utils import MockTransport + + +def echo_headers(request: httpx.Request) -> httpx.Response: + data = {"headers": dict(request.headers)} + content = json.dumps(data).encode("utf-8") + return httpx.Response(200, content=content) def test_client_header(): @@ -35,7 +21,7 @@ def test_client_header(): url = "http://example.org/echo_headers" headers = {"Example-Header": "example-value"} - client = httpx.Client(transport=MockTransport(), headers=headers) + client = httpx.Client(transport=MockTransport(echo_headers), headers=headers) response = client.get(url) assert response.status_code == 200 @@ -55,7 +41,7 @@ def test_header_merge(): url = "http://example.org/echo_headers" client_headers = {"User-Agent": "python-myclient/0.2.1"} request_headers = {"X-Auth-Token": "FooBarBazToken"} - client = httpx.Client(transport=MockTransport(), headers=client_headers) + client = httpx.Client(transport=MockTransport(echo_headers), headers=client_headers) response = client.get(url, headers=request_headers) assert response.status_code == 200 @@ -75,7 +61,7 @@ def test_header_merge_conflicting_headers(): url = "http://example.org/echo_headers" client_headers = {"X-Auth-Token": "FooBar"} request_headers = {"X-Auth-Token": "BazToken"} - client = httpx.Client(transport=MockTransport(), headers=client_headers) + client = httpx.Client(transport=MockTransport(echo_headers), headers=client_headers) response = client.get(url, headers=request_headers) assert response.status_code == 200 @@ -93,7 +79,7 @@ def test_header_merge_conflicting_headers(): def test_header_update(): url = "http://example.org/echo_headers" - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(echo_headers)) first_response = client.get(url) client.headers.update( {"User-Agent": "python-myclient/0.2.1", "Another-Header": "AThing"} @@ -130,7 +116,7 @@ def test_remove_default_header(): """ url = "http://example.org/echo_headers" - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(echo_headers)) del client.headers["User-Agent"] response = client.get(url) @@ -160,7 +146,7 @@ def test_host_with_auth_and_port_in_url(): """ url = "http://username:password@example.org:80/echo_headers" - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(echo_headers)) response = client.get(url) assert response.status_code == 200 @@ -183,7 +169,7 @@ def test_host_with_non_default_port_in_url(): """ url = "http://username:password@example.org:123/echo_headers" - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(echo_headers)) response = client.get(url) assert response.status_code == 200 diff --git a/tests/client/test_queryparams.py b/tests/client/test_queryparams.py index 22f715dadc..a14d1f4e30 100644 --- a/tests/client/test_queryparams.py +++ b/tests/client/test_queryparams.py @@ -1,24 +1,5 @@ -import typing - -import httpcore - import httpx -from httpx._content_streams import ContentStream, JSONStream - - -class MockTransport(httpcore.SyncHTTPTransport): - def request( - self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]] = None, - stream: httpcore.SyncByteStream = None, - timeout: typing.Mapping[str, typing.Optional[float]] = None, - ) -> typing.Tuple[ - bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream - ]: - body = JSONStream({"ok": "ok"}) - return b"HTTP/1.1", 200, b"OK", [], body +from tests.utils import MockTransport def test_client_queryparams(): @@ -39,10 +20,15 @@ def test_client_queryparams_string(): def test_client_queryparams_echo(): + def hello_world(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=b"Hello, world") + url = "http://example.org/echo_queryparams" client_queryparams = "first=str" request_queryparams = {"second": "dict"} - client = httpx.Client(transport=MockTransport(), params=client_queryparams) + client = httpx.Client( + transport=MockTransport(hello_world), params=client_queryparams + ) response = client.get(url, params=request_queryparams) assert response.status_code == 200 diff --git a/tests/test_multipart.py b/tests/test_multipart.py index f4962daba0..d10c39038d 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -4,37 +4,21 @@ import typing from unittest import mock -import httpcore import pytest import httpx from httpx._content_streams import MultipartStream, encode from httpx._utils import format_form_param +from tests.utils import MockTransport -class MockTransport(httpcore.SyncHTTPTransport): - def request( - self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]] = None, - stream: httpcore.SyncByteStream = None, - timeout: typing.Mapping[str, typing.Optional[float]] = None, - ) -> typing.Tuple[ - bytes, - int, - bytes, - typing.List[typing.Tuple[bytes, bytes]], - httpcore.SyncByteStream, - ]: - assert stream is not None - content = httpcore.IteratorByteStream(iterator=(part for part in stream)) - return b"HTTP/1.1", 200, b"OK", [], content +def echo_request_content(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=request.content) @pytest.mark.parametrize(("value,output"), (("abc", b"abc"), (b"abc", b"abc"))) def test_multipart(value, output): - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(echo_request_content)) # Test with a single-value 'data' argument, and a plain file 'files' argument. data = {"text": value} @@ -60,7 +44,7 @@ def test_multipart(value, output): @pytest.mark.parametrize(("key"), (b"abc", 1, 2.3, None)) def test_multipart_invalid_key(key): - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(echo_request_content)) data = {key: "abc"} files = {"file": io.BytesIO(b"")} @@ -75,7 +59,7 @@ def test_multipart_invalid_key(key): @pytest.mark.parametrize(("value"), (1, 2.3, None, [None, "abc"], {None: "abc"})) def test_multipart_invalid_value(value): - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(echo_request_content)) data = {"text": value} files = {"file": io.BytesIO(b"")} @@ -85,7 +69,7 @@ def test_multipart_invalid_value(value): def test_multipart_file_tuple(): - client = httpx.Client(transport=MockTransport()) + client = httpx.Client(transport=MockTransport(echo_request_content)) # Test with a list of values 'data' argument, # and a tuple style 'files' argument. diff --git a/tests/utils.py b/tests/utils.py index e2636a535c..5b10287523 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,11 @@ import contextlib import logging import os +from typing import Callable, List, Mapping, Optional, Tuple +import httpcore + +import httpx from httpx import _utils @@ -18,3 +22,51 @@ def override_log_level(log_level: str): finally: # Reset the logger so we don't have verbose output in all unit tests logging.getLogger("httpx").handlers = [] + + +class MockTransport(httpcore.SyncHTTPTransport): + def __init__(self, handler: Callable) -> None: + self.handler = handler + + def request( + self, + method: bytes, + url: Tuple[bytes, bytes, Optional[int], bytes], + headers: List[Tuple[bytes, bytes]] = None, + stream: httpcore.SyncByteStream = None, + timeout: Mapping[str, Optional[float]] = None, + ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.SyncByteStream]: + raw_scheme, raw_host, port, raw_path = url + scheme = raw_scheme.decode("ascii") + host = raw_host.decode("ascii") + port_str = "" if port is None else f":{port}" + path = raw_path.decode("ascii") + + request_headers = httpx.Headers(headers) + data = ( + (item for item in stream) + if stream + and ( + "Content-Length" in request_headers + or "Transfer-Encoding" in request_headers + ) + else None + ) + + request = httpx.Request( + method=method.decode("ascii"), + url=f"{scheme}://{host}{port_str}{path}", + headers=request_headers, + data=data, + ) + request.read() + response = self.handler(request) + return ( + response.http_version.encode("ascii") + if response.http_version + else b"HTTP/1.1", + response.status_code, + response.reason_phrase.encode("ascii"), + response.headers.raw, + response._raw_stream, + ) From 6ae0c41a0b1e19f25abe21e110db4fd0a7773b0f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 11 Sep 2020 12:37:36 +0100 Subject: [PATCH 4/8] Use MockTransport for redirect tests --- tests/client/test_redirects.py | 408 ++++++++++++++------------------- tests/utils.py | 39 ++++ 2 files changed, 209 insertions(+), 238 deletions(-) diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index 4b00133e31..cac94bf2aa 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -1,180 +1,123 @@ import json -import typing -from urllib.parse import parse_qs import httpcore import pytest import httpx -from httpx._content_streams import ByteStream, ContentStream, IteratorStream - - -def get_header_value(headers, key, default=None): - lookup = key.encode("ascii").lower() - for header_key, header_value in headers: - if header_key.lower() == lookup: - return header_value.decode("ascii") - return default - - -class MockTransport: - def _request( - self, - method: bytes, - url: typing.Tuple[bytes, bytes, int, bytes], - headers: typing.List[typing.Tuple[bytes, bytes]], - stream: ContentStream, - timeout: typing.Mapping[str, typing.Optional[float]] = None, - ) -> typing.Tuple[ - bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream - ]: - scheme, host, port, path = url - if scheme not in (b"http", b"https"): - raise httpcore.UnsupportedProtocol(f"Scheme {scheme!r} not supported.") - - path, _, query = path.partition(b"?") - if path == b"/no_redirect": - return b"HTTP/1.1", httpx.codes.OK, b"OK", [], ByteStream(b"") - - elif path == b"/redirect_301": - - def body(): - yield b"here" - - status_code = httpx.codes.MOVED_PERMANENTLY - headers = [(b"location", b"https://example.org/")] - stream = IteratorStream(iterator=body()) - return b"HTTP/1.1", status_code, b"Moved Permanently", headers, stream - - elif path == b"/redirect_302": - status_code = httpx.codes.FOUND - headers = [(b"location", b"https://example.org/")] - return b"HTTP/1.1", status_code, b"Found", headers, ByteStream(b"") - - elif path == b"/redirect_303": - status_code = httpx.codes.SEE_OTHER - headers = [(b"location", b"https://example.org/")] - return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"") - - elif path == b"/relative_redirect": - status_code = httpx.codes.SEE_OTHER - headers = [(b"location", b"/")] - return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"") - - elif path == b"/malformed_redirect": - status_code = httpx.codes.SEE_OTHER - headers = [(b"location", b"https://:443/")] - return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"") - - elif path == b"/invalid_redirect": - status_code = httpx.codes.SEE_OTHER - headers = [(b"location", "https://😇/".encode("utf-8"))] - return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"") - - elif path == b"/no_scheme_redirect": - status_code = httpx.codes.SEE_OTHER - headers = [(b"location", b"//example.org/")] - return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"") - - elif path == b"/multiple_redirects": - params = parse_qs(query.decode("ascii")) - count = int(params.get("count", "0")[0]) - redirect_count = count - 1 - code = httpx.codes.SEE_OTHER if count else httpx.codes.OK - phrase = b"See Other" if count else b"OK" - location = b"/multiple_redirects" +from tests.utils import AsyncMockTransport, MockTransport + + +def redirects(request: httpx.Request) -> httpx.Response: + if request.url.scheme not in ("http", "https"): + raise httpcore.UnsupportedProtocol( + f"Scheme {request.url.scheme!r} not supported." + ) + + if request.url.path == "/no_redirect": + return httpx.Response(200) + + elif request.url.path == "/redirect_301": + status_code = httpx.codes.MOVED_PERMANENTLY + content = b"here" + headers = {"location": "https://example.org/"} + return httpx.Response(status_code, headers=headers, content=content) + + elif request.url.path == "/redirect_302": + status_code = httpx.codes.FOUND + headers = {"location": "https://example.org/"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/redirect_303": + status_code = httpx.codes.SEE_OTHER + headers = {"location": "https://example.org/"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/relative_redirect": + status_code = httpx.codes.SEE_OTHER + headers = {"location": "/"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/malformed_redirect": + status_code = httpx.codes.SEE_OTHER + headers = {"location": "https://:443/"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/invalid_redirect": + status_code = httpx.codes.SEE_OTHER + raw_headers = [(b"location", "https://😇/".encode("utf-8"))] + return httpx.Response(status_code, headers=raw_headers) + + elif request.url.path == "/no_scheme_redirect": + status_code = httpx.codes.SEE_OTHER + headers = {"location": "//example.org/"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/multiple_redirects": + params = httpx.QueryParams(request.url.query) + count = int(params.get("count", "0")) + redirect_count = count - 1 + status_code = httpx.codes.SEE_OTHER if count else httpx.codes.OK + if count: + location = "/multiple_redirects" if redirect_count: - location += b"?count=" + str(redirect_count).encode("ascii") - headers = [(b"location", location)] if count else [] - return b"HTTP/1.1", code, phrase, headers, ByteStream(b"") - - if path == b"/redirect_loop": - code = httpx.codes.SEE_OTHER - headers = [(b"location", b"/redirect_loop")] - return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"") - - elif path == b"/cross_domain": - code = httpx.codes.SEE_OTHER - headers = [(b"location", b"https://example.org/cross_domain_target")] - return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"") - - elif path == b"/cross_domain_target": - headers_dict = { - key.decode("ascii"): value.decode("ascii") for key, value in headers - } - stream = ByteStream(json.dumps({"headers": headers_dict}).encode()) - return b"HTTP/1.1", 200, b"OK", [], stream - - elif path == b"/redirect_body": - code = httpx.codes.PERMANENT_REDIRECT - headers = [(b"location", b"/redirect_body_target")] - return b"HTTP/1.1", code, b"Permanent Redirect", headers, ByteStream(b"") - - elif path == b"/redirect_no_body": - code = httpx.codes.SEE_OTHER - headers = [(b"location", b"/redirect_body_target")] - return b"HTTP/1.1", code, b"See Other", headers, ByteStream(b"") - - elif path == b"/redirect_body_target": - content = b"".join(stream) - headers_dict = { - key.decode("ascii"): value.decode("ascii") for key, value in headers - } - stream = ByteStream( - json.dumps({"body": content.decode(), "headers": headers_dict}).encode() - ) - return b"HTTP/1.1", 200, b"OK", [], stream - - elif path == b"/cross_subdomain": - host = get_header_value(headers, "host") - if host != "www.example.org": - headers = [(b"location", b"https://www.example.org/cross_subdomain")] - return ( - b"HTTP/1.1", - httpx.codes.PERMANENT_REDIRECT, - b"Permanent Redirect", - headers, - ByteStream(b""), - ) - else: - return b"HTTP/1.1", 200, b"OK", [], ByteStream(b"Hello, world!") - - elif path == b"/redirect_custom_scheme": - status_code = httpx.codes.MOVED_PERMANENTLY - headers = [(b"location", b"market://details?id=42")] - return ( - b"HTTP/1.1", - status_code, - b"Moved Permanently", - headers, - ByteStream(b""), - ) - - stream = ByteStream(b"Hello, world!") if method != b"HEAD" else ByteStream(b"") - - return b"HTTP/1.1", 200, b"OK", [], stream - - -class AsyncMockTransport(MockTransport, httpcore.AsyncHTTPTransport): - async def request( - self, *args, **kwargs - ) -> typing.Tuple[ - bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream - ]: - return self._request(*args, **kwargs) - - -class SyncMockTransport(MockTransport, httpcore.SyncHTTPTransport): - def request( - self, *args, **kwargs - ) -> typing.Tuple[ - bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream - ]: - return self._request(*args, **kwargs) + location += f"?count={redirect_count}" + headers = {"location": location} + else: + headers = {} + return httpx.Response(status_code, headers=headers) + + if request.url.path == "/redirect_loop": + status_code = httpx.codes.SEE_OTHER + headers = {"location": "/redirect_loop"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/cross_domain": + status_code = httpx.codes.SEE_OTHER + headers = {"location": "https://example.org/cross_domain_target"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/cross_domain_target": + status_code = httpx.codes.OK + content = json.dumps({"headers": dict(request.headers)}).encode("utf-8") + return httpx.Response(status_code, content=content) + + elif request.url.path == "/redirect_body": + status_code = httpx.codes.PERMANENT_REDIRECT + headers = {"location": "/redirect_body_target"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/redirect_no_body": + status_code = httpx.codes.SEE_OTHER + headers = {"location": "/redirect_body_target"} + return httpx.Response(status_code, headers=headers) + + elif request.url.path == "/redirect_body_target": + content = json.dumps( + {"body": request.content.decode("ascii"), "headers": dict(request.headers)} + ).encode("utf-8") + return httpx.Response(200, content=content) + + elif request.url.path == "/cross_subdomain": + if request.headers["Host"] != "www.example.org": + status_code = httpx.codes.PERMANENT_REDIRECT + headers = {"location": "https://www.example.org/cross_subdomain"} + return httpx.Response(status_code, headers=headers) + else: + return httpx.Response(200, content=b"Hello, world!") + + elif request.url.path == "/redirect_custom_scheme": + status_code = httpx.codes.MOVED_PERMANENTLY + headers = {"location": "market://details?id=42"} + return httpx.Response(status_code, headers=headers) + + if request.method == "HEAD": + return httpx.Response(200) + + return httpx.Response(200, content=b"Hello, world!") def test_no_redirect(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.com/no_redirect" response = client.get(url) assert response.status_code == 200 @@ -183,7 +126,7 @@ def test_no_redirect(): def test_redirect_301(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) response = client.post("https://example.org/redirect_301") assert response.status_code == httpx.codes.OK assert response.url == "https://example.org/" @@ -191,7 +134,7 @@ def test_redirect_301(): def test_redirect_302(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) response = client.post("https://example.org/redirect_302") assert response.status_code == httpx.codes.OK assert response.url == "https://example.org/" @@ -199,7 +142,7 @@ def test_redirect_302(): def test_redirect_303(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) response = client.get("https://example.org/redirect_303") assert response.status_code == httpx.codes.OK assert response.url == "https://example.org/" @@ -207,7 +150,7 @@ def test_redirect_303(): def test_disallow_redirects(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) response = client.post("https://example.org/redirect_303", allow_redirects=False) assert response.status_code == httpx.codes.SEE_OTHER assert response.url == "https://example.org/redirect_303" @@ -225,7 +168,7 @@ def test_head_redirect(): """ Contrary to Requests, redirects remain enabled by default for HEAD requests. """ - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) response = client.head("https://example.org/redirect_302") assert response.status_code == httpx.codes.OK assert response.url == "https://example.org/" @@ -235,7 +178,7 @@ def test_head_redirect(): def test_relative_redirect(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) response = client.get("https://example.org/relative_redirect") assert response.status_code == httpx.codes.OK assert response.url == "https://example.org/" @@ -244,7 +187,7 @@ def test_relative_redirect(): def test_malformed_redirect(): # https://github.com/encode/httpx/issues/771 - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) response = client.get("http://example.org/malformed_redirect") assert response.status_code == httpx.codes.OK assert response.url == "https://example.org:443/" @@ -252,13 +195,13 @@ def test_malformed_redirect(): def test_invalid_redirect(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) with pytest.raises(httpx.RemoteProtocolError): client.get("http://example.org/invalid_redirect") def test_no_scheme_redirect(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) response = client.get("https://example.org/no_scheme_redirect") assert response.status_code == httpx.codes.OK assert response.url == "https://example.org/" @@ -266,7 +209,7 @@ def test_no_scheme_redirect(): def test_fragment_redirect(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) response = client.get("https://example.org/relative_redirect#fragment") assert response.status_code == httpx.codes.OK assert response.url == "https://example.org/#fragment" @@ -274,7 +217,7 @@ def test_fragment_redirect(): def test_multiple_redirects(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) response = client.get("https://example.org/multiple_redirects?count=20") assert response.status_code == httpx.codes.OK assert response.url == "https://example.org/multiple_redirects" @@ -287,14 +230,14 @@ def test_multiple_redirects(): @pytest.mark.usefixtures("async_environment") async def test_async_too_many_redirects(): - async with httpx.AsyncClient(transport=AsyncMockTransport()) as client: + async with httpx.AsyncClient(transport=AsyncMockTransport(redirects)) as client: with pytest.raises(httpx.TooManyRedirects): await client.get("https://example.org/multiple_redirects?count=21") @pytest.mark.usefixtures("async_environment") async def test_async_too_many_redirects_calling_next(): - async with httpx.AsyncClient(transport=AsyncMockTransport()) as client: + async with httpx.AsyncClient(transport=AsyncMockTransport(redirects)) as client: url = "https://example.org/multiple_redirects?count=21" response = await client.get(url, allow_redirects=False) with pytest.raises(httpx.TooManyRedirects): @@ -303,13 +246,13 @@ async def test_async_too_many_redirects_calling_next(): def test_sync_too_many_redirects(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) with pytest.raises(httpx.TooManyRedirects): client.get("https://example.org/multiple_redirects?count=21") def test_sync_too_many_redirects_calling_next(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.org/multiple_redirects?count=21" response = client.get(url, allow_redirects=False) with pytest.raises(httpx.TooManyRedirects): @@ -318,13 +261,13 @@ def test_sync_too_many_redirects_calling_next(): def test_redirect_loop(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) with pytest.raises(httpx.TooManyRedirects): client.get("https://example.org/redirect_loop") def test_cross_domain_redirect_with_auth_header(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.com/cross_domain" headers = {"Authorization": "abc"} response = client.get(url, headers=headers) @@ -333,7 +276,7 @@ def test_cross_domain_redirect_with_auth_header(): def test_cross_domain_redirect_with_auth(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.com/cross_domain" response = client.get(url, auth=("user", "pass")) assert response.url == "https://example.org/cross_domain_target" @@ -341,7 +284,7 @@ def test_cross_domain_redirect_with_auth(): def test_same_domain_redirect(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.org/cross_domain" headers = {"Authorization": "abc"} response = client.get(url, headers=headers) @@ -353,7 +296,7 @@ def test_body_redirect(): """ A 308 redirect should preserve the request body. """ - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.org/redirect_body" data = b"Example request body" response = client.post(url, data=data) @@ -366,7 +309,7 @@ def test_no_body_redirect(): """ A 303 redirect should remove the request body. """ - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.org/redirect_no_body" data = b"Example request body" response = client.post(url, data=data) @@ -376,7 +319,7 @@ def test_no_body_redirect(): def test_can_stream_if_no_redirect(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.org/redirect_301" with client.stream("GET", url, allow_redirects=False) as response: assert not response.is_closed @@ -385,7 +328,7 @@ def test_can_stream_if_no_redirect(): def test_cannot_redirect_streaming_body(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.org/redirect_body" def streaming_body(): @@ -396,64 +339,53 @@ def streaming_body(): def test_cross_subdomain_redirect(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) url = "https://example.com/cross_subdomain" response = client.get(url) assert response.url == "https://www.example.org/cross_subdomain" -class MockCookieTransport(httpcore.SyncHTTPTransport): - def request( - self, - method: bytes, - url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]] = None, - stream: httpcore.SyncByteStream = None, - timeout: typing.Mapping[str, typing.Optional[float]] = None, - ) -> typing.Tuple[ - bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream - ]: - scheme, host, port, path = url - if path == b"/": - cookie = get_header_value(headers, "Cookie") - if cookie is not None: - content = b"Logged in" - else: - content = b"Not logged in" - return b"HTTP/1.1", 200, b"OK", [], ByteStream(content) - - elif path == b"/login": - status_code = httpx.codes.SEE_OTHER - headers = [ - (b"location", b"/"), +def cookie_sessions(request: httpx.Request) -> httpx.Response: + if request.url.path == "/": + cookie = request.headers.get("Cookie") + if cookie is not None: + content = b"Logged in" + else: + content = b"Not logged in" + return httpx.Response(200, content=content) + + elif request.url.path == "/login": + status_code = httpx.codes.SEE_OTHER + headers = [ + (b"location", b"/"), + ( + b"set-cookie", ( - b"set-cookie", - ( - b"session=eyJ1c2VybmFtZSI6ICJ0b21; path=/; Max-Age=1209600; " - b"httponly; samesite=lax" - ), + b"session=eyJ1c2VybmFtZSI6ICJ0b21; path=/; Max-Age=1209600; " + b"httponly; samesite=lax" ), - ] - return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"") - - else: - assert path == b"/logout" - status_code = httpx.codes.SEE_OTHER - headers = [ - (b"location", b"/"), + ), + ] + return httpx.Response(status_code, headers=headers) + + else: + assert request.url.path == "/logout" + status_code = httpx.codes.SEE_OTHER + headers = [ + (b"location", b"/"), + ( + b"set-cookie", ( - b"set-cookie", - ( - b"session=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; " - b"httponly; samesite=lax" - ), + b"session=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; " + b"httponly; samesite=lax" ), - ] - return b"HTTP/1.1", status_code, b"See Other", headers, ByteStream(b"") + ), + ] + return httpx.Response(status_code, headers=headers) def test_redirect_cookie_behavior(): - client = httpx.Client(transport=MockCookieTransport()) + client = httpx.Client(transport=MockTransport(cookie_sessions)) # The client is not logged in. response = client.get("https://example.com/") @@ -482,7 +414,7 @@ def test_redirect_cookie_behavior(): def test_redirect_custom_scheme(): - client = httpx.Client(transport=SyncMockTransport()) + client = httpx.Client(transport=MockTransport(redirects)) with pytest.raises(httpx.UnsupportedProtocol) as e: client.post("https://example.org/redirect_custom_scheme") - assert str(e.value) == "Scheme b'market' not supported." + assert str(e.value) == "Scheme 'market' not supported." diff --git a/tests/utils.py b/tests/utils.py index 5b10287523..ee319e0010 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -70,3 +70,42 @@ def request( response.headers.raw, response._raw_stream, ) + + +class AsyncMockTransport(httpcore.AsyncHTTPTransport): + def __init__(self, handler: Callable) -> None: + self.impl = MockTransport(handler) + + async def request( + self, + method: bytes, + url: Tuple[bytes, bytes, Optional[int], bytes], + headers: List[Tuple[bytes, bytes]] = None, + stream: httpcore.AsyncByteStream = None, + timeout: Mapping[str, Optional[float]] = None, + ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]: + content = ( + httpcore.PlainByteStream(b"".join([part async for part in stream])) + if stream + else httpcore.PlainByteStream(b"") + ) + + ( + http_version, + status_code, + reason_phrase, + headers, + response_stream, + ) = self.impl.request( + method, url, headers=headers, stream=content, timeout=timeout + ) + + content = httpcore.PlainByteStream(b"".join([part for part in response_stream])) + + return ( + http_version, + status_code, + reason_phrase, + headers, + content, + ) From bab03ea94078bdad284d807f91e053cc92ca116d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 11 Sep 2020 12:39:59 +0100 Subject: [PATCH 5/8] Reduce change footprint --- httpx/_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/httpx/_models.py b/httpx/_models.py index 29b3e67ac0..526ee2cebf 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -605,6 +605,7 @@ def __init__( self.stream = stream else: self.stream = encode(data, files, json) + self.prepare() def prepare(self) -> None: From 7f31c2a50ea8727aafe5d34ab0f2af7896bea58a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 11 Sep 2020 12:41:06 +0100 Subject: [PATCH 6/8] Reduce change footprint --- tests/client/test_queryparams.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/client/test_queryparams.py b/tests/client/test_queryparams.py index a14d1f4e30..39731d5bb0 100644 --- a/tests/client/test_queryparams.py +++ b/tests/client/test_queryparams.py @@ -2,6 +2,10 @@ from tests.utils import MockTransport +def hello_world(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=b"Hello, world") + + def test_client_queryparams(): client = httpx.Client(params={"a": "b"}) assert isinstance(client.params, httpx.QueryParams) @@ -20,9 +24,6 @@ def test_client_queryparams_string(): def test_client_queryparams_echo(): - def hello_world(request: httpx.Request) -> httpx.Response: - return httpx.Response(200, content=b"Hello, world") - url = "http://example.org/echo_queryparams" client_queryparams = "first=str" request_queryparams = {"second": "dict"} From be04d4e024345dc90be3c97030a4477134af1d1b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 11 Sep 2020 12:46:05 +0100 Subject: [PATCH 7/8] Clean up headers slightly --- tests/client/test_redirects.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index cac94bf2aa..63fcd32087 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -356,31 +356,25 @@ def cookie_sessions(request: httpx.Request) -> httpx.Response: elif request.url.path == "/login": status_code = httpx.codes.SEE_OTHER - headers = [ - (b"location", b"/"), - ( - b"set-cookie", - ( - b"session=eyJ1c2VybmFtZSI6ICJ0b21; path=/; Max-Age=1209600; " - b"httponly; samesite=lax" - ), + headers = { + "location": "/", + "set-cookie": ( + "session=eyJ1c2VybmFtZSI6ICJ0b21; path=/; Max-Age=1209600; " + "httponly; samesite=lax" ), - ] + } return httpx.Response(status_code, headers=headers) else: assert request.url.path == "/logout" status_code = httpx.codes.SEE_OTHER - headers = [ - (b"location", b"/"), - ( - b"set-cookie", - ( - b"session=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; " - b"httponly; samesite=lax" - ), + headers = { + "location": "/", + "set-cookie": ( + "session=null; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT; " + "httponly; samesite=lax" ), - ] + } return httpx.Response(status_code, headers=headers) From b7b8a3464a8bb0949ddc41d98bb15765425f52f8 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 12 Sep 2020 11:11:40 +0100 Subject: [PATCH 8/8] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b871b15cdb..037fb668c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,7 +18,7 @@ black==20.8b1 cryptography flake8 flake8-bugbear -flake8-pie +flake8-pie==0.5.* isort==5.* mypy pytest==5.*