From 406c0dce8c43873cd6ab608ffa9f04ea4a926818 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 2 Sep 2020 14:22:27 +0100 Subject: [PATCH 1/8] encode -> encode_request_body --- httpx/_content_streams.py | 2 +- httpx/_models.py | 4 ++-- tests/test_content_streams.py | 24 ++++++++++++------------ tests/test_multipart.py | 12 ++++++------ 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py index 402fa959c8..6b8d47a2be 100644 --- a/httpx/_content_streams.py +++ b/httpx/_content_streams.py @@ -370,7 +370,7 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: yield chunk -def encode( +def encode_request_body( data: RequestData = None, files: RequestFiles = None, json: typing.Any = None, diff --git a/httpx/_models.py b/httpx/_models.py index 4a40263266..c73063bc09 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -15,7 +15,7 @@ import rfc3986.exceptions from .__version__ import __version__ -from ._content_streams import ByteStream, ContentStream, encode +from ._content_streams import ByteStream, ContentStream, encode_request_body from ._decoders import ( SUPPORTED_DECODERS, Decoder, @@ -609,7 +609,7 @@ def __init__( if stream is not None: self.stream = stream else: - self.stream = encode(data, files, json) + self.stream = encode_request_body(data, files, json) self.timer = ElapsedTimer() self.prepare() diff --git a/tests/test_content_streams.py b/tests/test_content_streams.py index 140aa8d2af..c72e01d2a9 100644 --- a/tests/test_content_streams.py +++ b/tests/test_content_streams.py @@ -3,7 +3,7 @@ import pytest from httpx import StreamConsumed -from httpx._content_streams import ContentStream, encode +from httpx._content_streams import ContentStream, encode_request_body @pytest.mark.asyncio @@ -20,7 +20,7 @@ async def test_base_content(): @pytest.mark.asyncio async def test_empty_content(): - stream = encode() + stream = encode_request_body() sync_content = b"".join([part for part in stream]) async_content = b"".join([part async for part in stream]) @@ -32,7 +32,7 @@ async def test_empty_content(): @pytest.mark.asyncio async def test_bytes_content(): - stream = encode(data=b"Hello, world!") + stream = encode_request_body(data=b"Hello, world!") sync_content = b"".join([part for part in stream]) async_content = b"".join([part async for part in stream]) @@ -48,7 +48,7 @@ def hello_world(): yield b"Hello, " yield b"world!" - stream = encode(data=hello_world()) + stream = encode_request_body(data=hello_world()) content = b"".join([part for part in stream]) assert not stream.can_replay() @@ -68,7 +68,7 @@ async def hello_world(): yield b"Hello, " yield b"world!" - stream = encode(data=hello_world()) + stream = encode_request_body(data=hello_world()) content = b"".join([part async for part in stream]) assert not stream.can_replay() @@ -84,7 +84,7 @@ async def hello_world(): @pytest.mark.asyncio async def test_json_content(): - stream = encode(json={"Hello": "world!"}) + stream = encode_request_body(json={"Hello": "world!"}) sync_content = b"".join([part for part in stream]) async_content = b"".join([part async for part in stream]) @@ -99,7 +99,7 @@ async def test_json_content(): @pytest.mark.asyncio async def test_urlencoded_content(): - stream = encode(data={"Hello": "world!"}) + stream = encode_request_body(data={"Hello": "world!"}) sync_content = b"".join([part for part in stream]) async_content = b"".join([part async for part in stream]) @@ -115,7 +115,7 @@ async def test_urlencoded_content(): @pytest.mark.asyncio async def test_multipart_files_content(): files = {"file": io.BytesIO(b"")} - stream = encode(files=files, boundary=b"+++") + stream = encode_request_body(files=files, boundary=b"+++") sync_content = b"".join([part for part in stream]) async_content = b"".join([part async for part in stream]) @@ -150,7 +150,7 @@ async def test_multipart_files_content(): async def test_multipart_data_and_files_content(): data = {"message": "Hello, world!"} files = {"file": io.BytesIO(b"")} - stream = encode(data=data, files=files, boundary=b"+++") + stream = encode_request_body(data=data, files=files, boundary=b"+++") sync_content = b"".join([part for part in stream]) async_content = b"".join([part async for part in stream]) @@ -191,7 +191,7 @@ async def test_multipart_data_and_files_content(): @pytest.mark.asyncio async def test_empty_request(): - stream = encode(data={}, files={}) + stream = encode_request_body(data={}, files={}) sync_content = b"".join([part for part in stream]) async_content = b"".join([part async for part in stream]) @@ -203,7 +203,7 @@ async def test_empty_request(): def test_invalid_argument(): with pytest.raises(TypeError): - encode(123) # type: ignore + encode_request_body(123) # type: ignore @pytest.mark.asyncio @@ -212,7 +212,7 @@ async def test_multipart_multiple_files_single_input_content(): ("file", io.BytesIO(b"")), ("file", io.BytesIO(b"")), ] - stream = encode(files=files, boundary=b"+++") + stream = encode_request_body(files=files, boundary=b"+++") sync_content = b"".join([part for part in stream]) async_content = b"".join([part async for part in stream]) diff --git a/tests/test_multipart.py b/tests/test_multipart.py index f4962daba0..824a0c071a 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -8,7 +8,7 @@ import pytest import httpx -from httpx._content_streams import MultipartStream, encode +from httpx._content_streams import MultipartStream, encode_request_body from httpx._utils import format_form_param @@ -126,7 +126,7 @@ def test_multipart_encode(tmp_path: typing.Any) -> None: with mock.patch("os.urandom", return_value=os.urandom(16)): boundary = os.urandom(16).hex() - stream = encode(data=data, files=files) + stream = encode_request_body(data=data, files=files) assert isinstance(stream, MultipartStream) assert stream.can_replay() @@ -153,7 +153,7 @@ def test_multipart_encode_files_allows_filenames_as_none() -> None: with mock.patch("os.urandom", return_value=os.urandom(16)): boundary = os.urandom(16).hex() - stream = encode(data={}, files=files) + stream = encode_request_body(data={}, files=files) assert isinstance(stream, MultipartStream) assert stream.can_replay() @@ -180,7 +180,7 @@ def test_multipart_encode_files_guesses_correct_content_type( with mock.patch("os.urandom", return_value=os.urandom(16)): boundary = os.urandom(16).hex() - stream = encode(data={}, files=files) + stream = encode_request_body(data={}, files=files) assert isinstance(stream, MultipartStream) assert stream.can_replay() @@ -204,7 +204,7 @@ def test_multipart_encode_files_allows_bytes_or_str_content( with mock.patch("os.urandom", return_value=os.urandom(16)): boundary = os.urandom(16).hex() - stream = encode(data={}, files=files) + stream = encode_request_body(data={}, files=files) assert isinstance(stream, MultipartStream) assert stream.can_replay() @@ -242,7 +242,7 @@ def data() -> typing.Iterator[bytes]: fileobj: typing.Any = IteratorIO(data()) files = {"file": fileobj} - stream = encode(files=files, boundary=b"+++") + stream = encode_request_body(files=files, boundary=b"+++") assert not stream.can_replay() content = ( From 786258d6ca0a411b6c85d0d87539000577f82361 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 2 Sep 2020 14:26:30 +0100 Subject: [PATCH 2/8] Use encode_response_body for Response(content=...) case --- httpx/_content_streams.py | 4 ++++ httpx/_models.py | 9 +++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py index 6b8d47a2be..3fc55fd3a5 100644 --- a/httpx/_content_streams.py +++ b/httpx/_content_streams.py @@ -402,3 +402,7 @@ def encode_request_body( return IteratorStream(iterator=data) raise TypeError(f"Unexpected type for 'data', {type(data)!r}") + + +def encode_response_body(content: bytes = None) -> ContentStream: + return ByteStream(body=content or b"") diff --git a/httpx/_models.py b/httpx/_models.py index c73063bc09..206f0022b1 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -15,7 +15,12 @@ import rfc3986.exceptions from .__version__ import __version__ -from ._content_streams import ByteStream, ContentStream, encode_request_body +from ._content_streams import ( + ByteStream, + ContentStream, + encode_request_body, + encode_response_body, +) from ._decoders import ( SUPPORTED_DECODERS, Decoder, @@ -713,7 +718,7 @@ def __init__( if stream is not None: self._raw_stream = stream else: - self._raw_stream = ByteStream(body=content or b"") + self._raw_stream = encode_response_body(content) self.read() @property From bd1438614030d7acf754bf1c11b4a2902ef96693 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 2 Sep 2020 14:36:00 +0100 Subject: [PATCH 3/8] Auto set appropriate response headers --- httpx/_models.py | 2 ++ tests/models/test_responses.py | 1 + 2 files changed, 3 insertions(+) diff --git a/httpx/_models.py b/httpx/_models.py index 206f0022b1..fb00bc041f 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -719,6 +719,8 @@ def __init__( self._raw_stream = stream else: self._raw_stream = encode_response_body(content) + for key, value in self._raw_stream.get_headers().items(): + self.headers.setdefault(key, value) self.read() @property diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 32163a6fc8..770a98b80b 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -32,6 +32,7 @@ def test_response(): assert response.request.method == "GET" assert response.request.url == "https://example.org" assert response.elapsed >= datetime.timedelta(0) + assert response.headers == httpx.Headers({"Content-Length": "13"}) assert not response.is_error From f457af0a52343fb64c46cd35fbcac61c92a7036b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 2 Sep 2020 14:40:39 +0100 Subject: [PATCH 4/8] Add support for json=... --- httpx/_content_streams.py | 6 +++++- httpx/_models.py | 7 ++++--- tests/models/test_responses.py | 10 ++++++++++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py index 3fc55fd3a5..99461adc35 100644 --- a/httpx/_content_streams.py +++ b/httpx/_content_streams.py @@ -404,5 +404,9 @@ def encode_request_body( raise TypeError(f"Unexpected type for 'data', {type(data)!r}") -def encode_response_body(content: bytes = None) -> ContentStream: +def encode_response_body( + content: bytes = None, json: typing.Any = None +) -> ContentStream: + if json is not None: + return JSONStream(json=json) return ByteStream(body=content or b"") diff --git a/httpx/_models.py b/httpx/_models.py index fb00bc041f..2c91cac373 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -696,11 +696,12 @@ def __init__( self, status_code: int, *, - request: Request = None, http_version: str = None, headers: HeaderTypes = None, - stream: ContentStream = None, content: bytes = None, + json: typing.Any = None, + stream: ContentStream = None, + request: Request = None, history: typing.List["Response"] = None, ): self.status_code = status_code @@ -718,7 +719,7 @@ def __init__( if stream is not None: self._raw_stream = stream else: - self._raw_stream = encode_response_body(content) + self._raw_stream = encode_response_body(content=content, json=json) for key, value in self._raw_stream.get_headers().items(): self.headers.setdefault(key, value) self.read() diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 770a98b80b..9bc68160d0 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -36,6 +36,16 @@ def test_response(): assert not response.is_error +def test_json_response(): + response = httpx.Response(200, json={"Hello": "World!"}) + + assert response.status_code == 200 + assert response.text == '{"Hello": "World!"}' + assert response.headers == httpx.Headers( + {"Content-Type": "application/json", "Content-Length": "19"} + ) + + def test_raise_for_status(): request = httpx.Request("GET", "https://example.org") From b34da1842dcec10178081f7da2ee39e6b677f271 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 2 Sep 2020 15:00:35 +0100 Subject: [PATCH 5/8] Add text/plain and text/html responses --- httpx/_content_streams.py | 48 ++++++++++++++++++++++++++++++++-- httpx/_models.py | 6 ++++- tests/models/test_responses.py | 20 ++++++++++++++ tests/test_content_streams.py | 36 ++++++++++++++++++++++++- 4 files changed, 106 insertions(+), 4 deletions(-) diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py index 99461adc35..ac2d7901e8 100644 --- a/httpx/_content_streams.py +++ b/httpx/_content_streams.py @@ -370,6 +370,46 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: yield chunk +class TextStream(ContentStream): + """ + Response content as plain text. + """ + + def __init__(self, text: str) -> None: + self.body = text.encode("utf-8") + + def get_headers(self) -> typing.Dict[str, str]: + content_length = str(len(self.body)) + content_type = "text/plain; charset=utf-8" + return {"Content-Length": content_length, "Content-Type": content_type} + + def __iter__(self) -> typing.Iterator[bytes]: + yield self.body + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + yield self.body + + +class HTMLStream(ContentStream): + """ + Response content as HTML. + """ + + def __init__(self, html: str) -> None: + self.body = html.encode("utf-8") + + def get_headers(self) -> typing.Dict[str, str]: + content_length = str(len(self.body)) + content_type = "text/html; charset=utf-8" + return {"Content-Length": content_length, "Content-Type": content_type} + + def __iter__(self) -> typing.Iterator[bytes]: + yield self.body + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + yield self.body + + def encode_request_body( data: RequestData = None, files: RequestFiles = None, @@ -405,8 +445,12 @@ def encode_request_body( def encode_response_body( - content: bytes = None, json: typing.Any = None + content: bytes = None, text: str = None, html: str = None, json: typing.Any = None ) -> ContentStream: - if json is not None: + if text is not None: + return TextStream(text=text) + elif html is not None: + return HTMLStream(html=html) + elif json is not None: return JSONStream(json=json) return ByteStream(body=content or b"") diff --git a/httpx/_models.py b/httpx/_models.py index 2c91cac373..d952ea12d2 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -699,6 +699,8 @@ def __init__( http_version: str = None, headers: HeaderTypes = None, content: bytes = None, + text: str = None, + html: str = None, json: typing.Any = None, stream: ContentStream = None, request: Request = None, @@ -719,7 +721,9 @@ def __init__( if stream is not None: self._raw_stream = stream else: - self._raw_stream = encode_response_body(content=content, json=json) + self._raw_stream = encode_response_body( + content=content, text=text, html=html, json=json + ) for key, value in self._raw_stream.get_headers().items(): self.headers.setdefault(key, value) self.read() diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 9bc68160d0..93798c2787 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -36,6 +36,26 @@ def test_response(): assert not response.is_error +def test_text_response(): + response = httpx.Response(200, text="Hello, world!") + + assert response.status_code == 200 + assert response.text == "Hello, world!" + assert response.headers == httpx.Headers( + {"Content-Type": "text/plain; charset=utf-8", "Content-Length": "13"} + ) + + +def test_html_response(): + response = httpx.Response(200, html="

Hello, world!

") + + assert response.status_code == 200 + assert response.text == "

Hello, world!

" + assert response.headers == httpx.Headers( + {"Content-Type": "text/html; charset=utf-8", "Content-Length": "35"} + ) + + def test_json_response(): response = httpx.Response(200, json={"Hello": "World!"}) diff --git a/tests/test_content_streams.py b/tests/test_content_streams.py index c72e01d2a9..bedb24ca04 100644 --- a/tests/test_content_streams.py +++ b/tests/test_content_streams.py @@ -3,7 +3,11 @@ import pytest from httpx import StreamConsumed -from httpx._content_streams import ContentStream, encode_request_body +from httpx._content_streams import ( + ContentStream, + encode_request_body, + encode_response_body, +) @pytest.mark.asyncio @@ -251,3 +255,33 @@ async def test_multipart_multiple_files_single_input_content(): b"--+++--\r\n", ] ) + + +@pytest.mark.asyncio +async def test_text_content(): + stream = encode_response_body(text="Hello, world!") + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) + + assert stream.can_replay() + assert stream.get_headers() == { + "Content-Length": "13", + "Content-Type": "text/plain; charset=utf-8", + } + assert sync_content == b"Hello, world!" + assert async_content == b"Hello, world!" + + +@pytest.mark.asyncio +async def test_html_content(): + stream = encode_response_body(html="

Hello, world!

") + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) + + assert stream.can_replay() + assert stream.get_headers() == { + "Content-Length": "35", + "Content-Type": "text/html; charset=utf-8", + } + assert sync_content == b"

Hello, world!

" + assert async_content == b"

Hello, world!

" From d03b972a2a157b4e0f7b0fb4447d8d8bae4929bd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 2 Sep 2020 15:15:09 +0100 Subject: [PATCH 6/8] Support Response(content=) --- httpx/_content_streams.py | 40 +++++++++++++++++++++---------- httpx/_models.py | 8 +++++-- httpx/_types.py | 1 + tests/test_content_streams.py | 45 ++++++++++++++++++++++++++++++++++- 4 files changed, 79 insertions(+), 15 deletions(-) diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py index ac2d7901e8..fa0df50478 100644 --- a/httpx/_content_streams.py +++ b/httpx/_content_streams.py @@ -8,7 +8,7 @@ import httpcore from ._exceptions import StreamConsumed -from ._types import FileContent, FileTypes, RequestData, RequestFiles +from ._types import FileContent, FileTypes, RequestData, RequestFiles, ResponseContent from ._utils import ( format_form_param, guess_content_type, @@ -420,13 +420,12 @@ def encode_request_body( Handles encoding the given `data`, `files`, and `json`, returning a `ContentStream` implementation. """ - if not data: + if data is None: if json is not None: return JSONStream(json=json) elif files: return MultipartStream(data={}, files=files, boundary=boundary) - else: - return ByteStream(body=b"") + return ByteStream(body=b"") elif isinstance(data, dict): if files: return MultipartStream(data=data, files=files, boundary=boundary) @@ -445,12 +444,29 @@ def encode_request_body( def encode_response_body( - content: bytes = None, text: str = None, html: str = None, json: typing.Any = None + content: ResponseContent = None, + text: str = None, + html: str = None, + json: typing.Any = None, ) -> ContentStream: - if text is not None: - return TextStream(text=text) - elif html is not None: - return HTMLStream(html=html) - elif json is not None: - return JSONStream(json=json) - return ByteStream(body=content or b"") + if content is None: + if text is not None: + return TextStream(text=text) + elif html is not None: + return HTMLStream(html=html) + elif json is not None: + return JSONStream(json=json) + return ByteStream(b"") + elif isinstance(content, bytes): + return ByteStream(body=content) + elif hasattr(content, "__aiter__"): + content = typing.cast(typing.AsyncIterator[bytes], content) + return AsyncIteratorStream(aiterator=content) + elif hasattr(content, "__iter__"): + content = typing.cast(typing.Iterator[bytes], content) + return IteratorStream(iterator=content) + + raise TypeError( + f"Unexpected type for 'content', should be bytes or " + f"byte iterator {type(content)!r}" + ) diff --git a/httpx/_models.py b/httpx/_models.py index d952ea12d2..8c87465405 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -50,6 +50,7 @@ QueryParamTypes, RequestData, RequestFiles, + ResponseContent, URLTypes, ) from ._utils import ( @@ -698,7 +699,7 @@ def __init__( *, http_version: str = None, headers: HeaderTypes = None, - content: bytes = None, + content: ResponseContent = None, text: str = None, html: str = None, json: typing.Any = None, @@ -726,7 +727,10 @@ def __init__( ) for key, value in self._raw_stream.get_headers().items(): self.headers.setdefault(key, value) - self.read() + + if content is None or isinstance(content, bytes): + # Load the response body, except for streaming content. + self.read() @property def elapsed(self) -> datetime.timedelta: diff --git a/httpx/_types.py b/httpx/_types.py index 3a90ee42e7..9cd85a04b6 100644 --- a/httpx/_types.py +++ b/httpx/_types.py @@ -64,6 +64,7 @@ ] RequestData = Union[dict, str, bytes, Iterator[bytes], AsyncIterator[bytes]] +ResponseContent = Union[bytes, Iterator[bytes], AsyncIterator[bytes]] FileContent = Union[IO[str], IO[bytes], str, bytes] FileTypes = Union[ diff --git a/tests/test_content_streams.py b/tests/test_content_streams.py index bedb24ca04..c93a370a53 100644 --- a/tests/test_content_streams.py +++ b/tests/test_content_streams.py @@ -195,7 +195,7 @@ async def test_multipart_data_and_files_content(): @pytest.mark.asyncio async def test_empty_request(): - stream = encode_request_body(data={}, files={}) + stream = encode_request_body() sync_content = b"".join([part for part in stream]) async_content = b"".join([part async for part in stream]) @@ -209,6 +209,9 @@ def test_invalid_argument(): with pytest.raises(TypeError): encode_request_body(123) # type: ignore + with pytest.raises(TypeError): + encode_response_body(123) # type: ignore + @pytest.mark.asyncio async def test_multipart_multiple_files_single_input_content(): @@ -285,3 +288,43 @@ async def test_html_content(): } assert sync_content == b"

Hello, world!

" assert async_content == b"

Hello, world!

" + + +@pytest.mark.asyncio +async def test_iterator_response_content(): + def hello_world(): + yield b"Hello, " + yield b"world!" + + stream = encode_response_body(content=hello_world()) + content = b"".join([part for part in stream]) + + assert not stream.can_replay() + assert stream.get_headers() == {"Transfer-Encoding": "chunked"} + assert content == b"Hello, world!" + + with pytest.raises(RuntimeError): + [part async for part in stream] + + with pytest.raises(StreamConsumed): + [part for part in stream] + + +@pytest.mark.asyncio +async def test_aiterator_response_content(): + async def hello_world(): + yield b"Hello, " + yield b"world!" + + stream = encode_response_body(content=hello_world()) + content = b"".join([part async for part in stream]) + + assert not stream.can_replay() + assert stream.get_headers() == {"Transfer-Encoding": "chunked"} + assert content == b"Hello, world!" + + with pytest.raises(RuntimeError): + [part for part in stream] + + with pytest.raises(StreamConsumed): + [part async for part in stream] From 76d360405094bbb03451516928e69c60d41c9265 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 2 Sep 2020 15:37:34 +0100 Subject: [PATCH 7/8] Drop unneeded close_func from iterator content streams --- httpx/_content_streams.py | 18 ++----------- httpx/_transports/urllib3.py | 26 +++++++++++------- tests/client/test_queryparams.py | 8 ++++-- tests/models/test_responses.py | 46 ++++++++------------------------ 4 files changed, 36 insertions(+), 62 deletions(-) diff --git a/httpx/_content_streams.py b/httpx/_content_streams.py index fa0df50478..0ff15f2298 100644 --- a/httpx/_content_streams.py +++ b/httpx/_content_streams.py @@ -72,11 +72,8 @@ class IteratorStream(ContentStream): Request content encoded as plain bytes, using an byte iterator. """ - def __init__( - self, iterator: typing.Iterator[bytes], close_func: typing.Callable = None - ) -> None: + def __init__(self, iterator: typing.Iterator[bytes]) -> None: self.iterator = iterator - self.close_func = close_func self.is_stream_consumed = False def can_replay(self) -> bool: @@ -95,21 +92,14 @@ def __iter__(self) -> typing.Iterator[bytes]: def __aiter__(self) -> typing.AsyncIterator[bytes]: raise RuntimeError("Attempted to call a async iterator on an sync stream.") - def close(self) -> None: - if self.close_func is not None: - self.close_func() - class AsyncIteratorStream(ContentStream): """ Request content encoded as plain bytes, using an async byte iterator. """ - def __init__( - self, aiterator: typing.AsyncIterator[bytes], close_func: typing.Callable = None - ) -> None: + def __init__(self, aiterator: typing.AsyncIterator[bytes]) -> None: self.aiterator = aiterator - self.close_func = close_func self.is_stream_consumed = False def can_replay(self) -> bool: @@ -128,10 +118,6 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: async for part in self.aiterator: yield part - async def aclose(self) -> None: - if self.close_func is not None: - await self.close_func() - class JSONStream(ContentStream): """ diff --git a/httpx/_transports/urllib3.py b/httpx/_transports/urllib3.py index c5b7af6cc2..bc8dfde664 100644 --- a/httpx/_transports/urllib3.py +++ b/httpx/_transports/urllib3.py @@ -4,7 +4,7 @@ import httpcore from .._config import create_ssl_context -from .._content_streams import ByteStream, IteratorStream +from .._content_streams import ByteStream from .._exceptions import NetworkError, map_exceptions from .._types import CertTypes, VerifyTypes @@ -15,6 +15,21 @@ urllib3 = None +class URLLib3ByteStream(httpcore.SyncByteStream): + def __init__(self, conn: urllib3.HTTPResponse) -> None: + self._conn = conn + + def __iter__(self) -> Iterator[bytes]: + try: + for chunk in self._conn.stream(4096, decode_content=False): + yield chunk + except socket.error as exc: + raise httpcore.NetworkError(exc) from exc + + def close(self) -> None: + self._conn.release_conn() + + class URLLib3Transport(httpcore.SyncHTTPTransport): def __init__( self, @@ -104,16 +119,9 @@ def request( pool_timeout=timeout.get("pool"), ) - def response_bytes() -> Iterator[bytes]: - with map_exceptions({socket.error: NetworkError}): - for chunk in conn.stream(4096, decode_content=False): - yield chunk - status_code = conn.status headers = list(conn.headers.items()) - response_stream = IteratorStream( - iterator=response_bytes(), close_func=conn.release_conn - ) + response_stream = URLLib3ByteStream(conn=conn) return (b"HTTP/1.1", status_code, conn.reason, headers, response_stream) def close(self) -> None: diff --git a/tests/client/test_queryparams.py b/tests/client/test_queryparams.py index 22f715dadc..0daa393ae4 100644 --- a/tests/client/test_queryparams.py +++ b/tests/client/test_queryparams.py @@ -3,7 +3,7 @@ import httpcore import httpx -from httpx._content_streams import ContentStream, JSONStream +from httpx._content_streams import JSONStream class MockTransport(httpcore.SyncHTTPTransport): @@ -15,7 +15,11 @@ def request( stream: httpcore.SyncByteStream = None, timeout: typing.Mapping[str, typing.Optional[float]] = None, ) -> typing.Tuple[ - bytes, int, bytes, typing.List[typing.Tuple[bytes, bytes]], ContentStream + bytes, + int, + bytes, + typing.List[typing.Tuple[bytes, bytes]], + httpcore.SyncByteStream, ]: body = JSONStream({"ok": "ok"}) return b"HTTP/1.1", 200, b"OK", [], body diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index 93798c2787..1db122181b 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -6,7 +6,6 @@ import pytest import httpx -from httpx._content_streams import AsyncIteratorStream, IteratorStream def streaming_body(): @@ -248,10 +247,9 @@ async def test_aread(): def test_iter_raw(): - stream = IteratorStream(iterator=streaming_body()) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) raw = b"" @@ -262,10 +260,9 @@ def test_iter_raw(): @pytest.mark.asyncio async def test_aiter_raw(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) raw = b"" @@ -350,10 +347,9 @@ async def test_aiter_lines(): def test_sync_streaming_response(): - stream = IteratorStream(iterator=streaming_body()) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) assert response.status_code == 200 @@ -368,10 +364,9 @@ def test_sync_streaming_response(): @pytest.mark.asyncio async def test_async_streaming_response(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) assert response.status_code == 200 @@ -385,10 +380,9 @@ async def test_async_streaming_response(): def test_cannot_read_after_stream_consumed(): - stream = IteratorStream(iterator=streaming_body()) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) content = b"" @@ -401,10 +395,9 @@ def test_cannot_read_after_stream_consumed(): @pytest.mark.asyncio async def test_cannot_aread_after_stream_consumed(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) content = b"" @@ -416,20 +409,13 @@ async def test_cannot_aread_after_stream_consumed(): def test_cannot_read_after_response_closed(): - is_closed = False - - def close_func(): - nonlocal is_closed - is_closed = True - - stream = IteratorStream(iterator=streaming_body(), close_func=close_func) response = httpx.Response( 200, - stream=stream, + content=streaming_body(), ) response.close() - assert is_closed + assert response.is_closed with pytest.raises(httpx.ResponseClosed): response.read() @@ -437,22 +423,13 @@ def close_func(): @pytest.mark.asyncio async def test_cannot_aread_after_response_closed(): - is_closed = False - - async def close_func(): - nonlocal is_closed - is_closed = True - - stream = AsyncIteratorStream( - aiterator=async_streaming_body(), close_func=close_func - ) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) await response.aclose() - assert is_closed + assert response.is_closed with pytest.raises(httpx.ResponseClosed): await response.aread() @@ -460,10 +437,9 @@ async def close_func(): @pytest.mark.asyncio async def test_elapsed_not_available_until_closed(): - stream = AsyncIteratorStream(aiterator=async_streaming_body()) response = httpx.Response( 200, - stream=stream, + content=async_streaming_body(), ) with pytest.raises(RuntimeError): From 8e7b8d1810c6e083ee3e96ab4f157f6d212ad212 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 2 Sep 2020 16:07:42 +0100 Subject: [PATCH 8/8] Use Response(content=...) where possible --- tests/test_decoders.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/test_decoders.py b/tests/test_decoders.py index dbbaac5450..7dfca9ef50 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -4,7 +4,6 @@ import pytest import httpx -from httpx._content_streams import AsyncIteratorStream from httpx._decoders import ( BrotliDecoder, DeflateDecoder, @@ -130,11 +129,10 @@ async def compress(body): yield compressor.flush() headers = [(b"Content-Encoding", b"gzip")] - stream = AsyncIteratorStream(aiterator=compress(body)) response = httpx.Response( 200, headers=headers, - stream=stream, + content=compress(body), ) assert not hasattr(response, "body") assert await response.aread() == body @@ -199,19 +197,17 @@ async def iterator(): yield chunk # Accessing `.text` on a read response. - stream = AsyncIteratorStream(aiterator=iterator()) response = httpx.Response( 200, - stream=stream, + content=iterator(), ) await response.aread() assert response.text == (b"".join(data)).decode(encoding) # Streaming `.aiter_text` iteratively. - stream = AsyncIteratorStream(aiterator=iterator()) response = httpx.Response( 200, - stream=stream, + content=iterator(), ) text = "".join([part async for part in response.aiter_text()]) assert text == (b"".join(data)).decode(encoding) @@ -224,11 +220,10 @@ async def iterator(): yield b"\x83" yield b"\x89\x83x\x83\x8b" - stream = AsyncIteratorStream(aiterator=iterator()) response = httpx.Response( 200, headers=[(b"Content-Type", b"text/html; charset=shift-jis")], - stream=stream, + content=iterator(), ) await response.aread()