From be809fba517d3b813c9d83b6a7effdb555727644 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Tue, 25 Aug 2020 16:49:17 +0200 Subject: [PATCH 1/7] Add support for async auth flows --- httpx/_auth.py | 26 ++++++++++++++++++++++ httpx/_client.py | 8 +++---- tests/client/test_auth.py | 45 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 4 deletions(-) diff --git a/httpx/_auth.py b/httpx/_auth.py index 571584593b..ab2082a79a 100644 --- a/httpx/_auth.py +++ b/httpx/_auth.py @@ -17,6 +17,12 @@ class Auth: To implement a custom authentication scheme, subclass `Auth` and override the `.auth_flow()` method. + + If the authentication scheme does I/O, such as disk access or network calls, or uses + synchronization primitives such as locks, you should override `.async_auth_flow()` + to provide an async-friendly implementation that will be used by the `AsyncClient`. + Usage of sync I/O within an async codebase would block the event loop, and could + cause performance issues. """ requires_request_body = False @@ -46,6 +52,26 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non """ yield request + async def async_auth_flow( + self, request: Request + ) -> typing.AsyncGenerator[Request, Response]: + """ + Execute the authentication flow asynchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O, such as disk access or network calls, + or uses concurrency primitives such as locks. + """ + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + try: + request = flow.send(response) + except StopIteration: + break + class FunctionAuth(Auth): """ diff --git a/httpx/_client.py b/httpx/_client.py index 2d2ca9ac16..939d50ae39 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -1365,15 +1365,15 @@ async def _send_handling_auth( if auth.requires_request_body: await request.aread() - auth_flow = auth.auth_flow(request) - request = next(auth_flow) + auth_flow = auth.async_auth_flow(request) + request = await auth_flow.__anext__() while True: response = await self._send_single_request(request, timeout) if auth.requires_response_body: await response.aread() try: - next_request = auth_flow.send(response) - except StopIteration: + next_request = await auth_flow.asend(response) + except StopAsyncIteration: return response except BaseException as exc: await response.aclose() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index edfccf0a70..52f18f0ad7 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1,5 +1,7 @@ +import asyncio import hashlib import os +import threading import typing import httpcore @@ -184,6 +186,29 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non yield request +class SyncOrAsyncAuth(Auth): + """ + A mock authentication scheme that uses a different implementation for the + sync and async cases. + """ + + def __init__(self): + self._lock = threading.Lock() + self._async_lock = asyncio.Lock() + + def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + with self._lock: + request.headers["Authorization"] = "sync-auth" + yield request + + async def async_auth_flow( + self, request: Request + ) -> typing.AsyncGenerator[Request, Response]: + async with self._async_lock: + request.headers["Authorization"] = "async-auth" + yield request + + @pytest.mark.asyncio async def test_basic_auth() -> None: url = "https://example.org/" @@ -641,3 +666,23 @@ def test_sync_auth_reads_response_body() -> None: response = client.get(url, auth=auth) assert response.status_code == 200 assert response.json() == {"auth": '{"auth": "xyz"}'} + + +@pytest.mark.asyncio +async def test_sync_async_auth() -> None: + """ + Test that we can use a different auth flow implementation in the async case, to + support cases that require performing I/O or using concurrency primitives (such + as checking a disk-based cache or fetching a token from a remote auth server). + """ + url = "https://example.org/" + auth = SyncOrAsyncAuth() + + client = AsyncClient(transport=AsyncMockTransport()) + response = await client.get(url, auth=auth) + assert response.status_code == 200 + assert response.json() == {"auth": "async-auth"} + + response = Client(transport=SyncMockTransport()).get(url, auth=auth) + assert response.status_code == 200 + assert response.json() == {"auth": "sync-auth"} From 3432c142ac480f30dbbd0840ac1518264fa6d6cf Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 27 Aug 2020 15:35:44 +0200 Subject: [PATCH 2/7] Move body logic to Auth, add sync_auth_flow, add NoAuth --- httpx/_auth.py | 43 ++++++++++++++++++++++++++++++++++++------- httpx/_client.py | 18 ++++++------------ 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/httpx/_auth.py b/httpx/_auth.py index ab2082a79a..40a0e0f199 100644 --- a/httpx/_auth.py +++ b/httpx/_auth.py @@ -18,11 +18,10 @@ class Auth: To implement a custom authentication scheme, subclass `Auth` and override the `.auth_flow()` method. - If the authentication scheme does I/O, such as disk access or network calls, or uses - synchronization primitives such as locks, you should override `.async_auth_flow()` - to provide an async-friendly implementation that will be used by the `AsyncClient`. - Usage of sync I/O within an async codebase would block the event loop, and could - cause performance issues. + If the authentication scheme does I/O such as disk access or network calls, or uses + synchronization primitives such as locks, you should override `.sync_auth_flow()` + and/or `.async_auth_flow()` instead of `.auth_flow()` to provide specialized + implementations that will be used by `Client` and `AsyncClient` respectively. """ requires_request_body = False @@ -52,6 +51,31 @@ def auth_flow(self, request: Request) -> typing.Generator[Request, Response, Non """ yield request + def sync_auth_flow( + self, request: Request + ) -> typing.Generator[Request, Response, None]: + """ + Execute the authentication flow synchronously. + + By default, this defers to `.auth_flow()`. You should override this method + when the authentication scheme does I/O and/or uses concurrency primitives. + """ + if self.requires_request_body: + request.read() + + flow = self.auth_flow(request) + request = next(flow) + + while True: + response = yield request + if self.requires_response_body: + response.read() + + try: + request = flow.send(response) + except StopIteration: + break + async def async_auth_flow( self, request: Request ) -> typing.AsyncGenerator[Request, Response]: @@ -59,14 +83,19 @@ async def async_auth_flow( Execute the authentication flow asynchronously. By default, this defers to `.auth_flow()`. You should override this method - when the authentication scheme does I/O, such as disk access or network calls, - or uses concurrency primitives such as locks. + when the authentication scheme does I/O and/or uses concurrency primitives. """ + if self.requires_request_body: + await request.aread() + flow = self.auth_flow(request) request = next(flow) while True: response = yield request + if self.requires_response_body: + await response.aread() + try: request = flow.send(response) except StopIteration: diff --git a/httpx/_client.py b/httpx/_client.py index 939d50ae39..4737401596 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -760,15 +760,12 @@ def _send_handling_auth( auth: Auth, timeout: Timeout, ) -> Response: - if auth.requires_request_body: - request.read() + auth_flow = auth.sync_auth_flow(request) + request = auth_flow.send(None) # type: ignore - auth_flow = auth.auth_flow(request) - request = next(auth_flow) while True: response = self._send_single_request(request, timeout) - if auth.requires_response_body: - response.read() + try: next_request = auth_flow.send(response) except StopIteration: @@ -1362,15 +1359,12 @@ async def _send_handling_auth( auth: Auth, timeout: Timeout, ) -> Response: - if auth.requires_request_body: - await request.aread() - auth_flow = auth.async_auth_flow(request) - request = await auth_flow.__anext__() + request = await auth_flow.asend(None) # type: ignore + while True: response = await self._send_single_request(request, timeout) - if auth.requires_response_body: - await response.aread() + try: next_request = await auth_flow.asend(response) except StopAsyncIteration: From d672659ed43860a1ccc8fc5339aa26e2f5047de8 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Sat, 29 Aug 2020 20:29:12 +0200 Subject: [PATCH 3/7] Update tests --- tests/client/test_auth.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 4533291c8e..e14d8b2229 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -192,11 +192,13 @@ class SyncOrAsyncAuth(Auth): sync and async cases. """ - def __init__(self): + def __init__(self) -> None: self._lock = threading.Lock() self._async_lock = asyncio.Lock() - def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: + def sync_auth_flow( + self, request: Request + ) -> typing.Generator[Request, Response, None]: with self._lock: request.headers["Authorization"] = "sync-auth" yield request @@ -669,20 +671,31 @@ def test_sync_auth_reads_response_body() -> None: @pytest.mark.asyncio -async def test_sync_async_auth() -> None: +async def test_async_auth() -> None: """ - Test that we can use a different auth flow implementation in the async case, to + Test that we can use an auth implementation specific to the async case, to support cases that require performing I/O or using concurrency primitives (such as checking a disk-based cache or fetching a token from a remote auth server). """ url = "https://example.org/" auth = SyncOrAsyncAuth() - client = AsyncClient(transport=AsyncMockTransport()) - response = await client.get(url, auth=auth) + async with AsyncClient(transport=AsyncMockTransport()) as client: + response = await client.get(url, auth=auth) + assert response.status_code == 200 assert response.json() == {"auth": "async-auth"} - response = Client(transport=SyncMockTransport()).get(url, auth=auth) + +def test_sync_auth() -> None: + """ + Test that we can use an auth implementation specific to the sync case. + """ + url = "https://example.org/" + auth = SyncOrAsyncAuth() + + with Client(transport=SyncMockTransport()) as client: + response = client.get(url, auth=auth) + assert response.status_code == 200 assert response.json() == {"auth": "sync-auth"} From 8965feaa46ed1186c0d3ed4f1ff7ef3fcbf53fb5 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Sat, 29 Aug 2020 20:30:37 +0200 Subject: [PATCH 4/7] Stick to next() / __anext__() --- httpx/_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/httpx/_client.py b/httpx/_client.py index 41eaf56303..eaa80fb3b3 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -761,7 +761,7 @@ def _send_handling_auth( timeout: Timeout, ) -> Response: auth_flow = auth.sync_auth_flow(request) - request = auth_flow.send(None) # type: ignore + request = next(auth_flow) while True: response = self._send_single_request(request, timeout) @@ -1367,7 +1367,7 @@ async def _send_handling_auth( timeout: Timeout, ) -> Response: auth_flow = auth.async_auth_flow(request) - request = await auth_flow.asend(None) # type: ignore + request = await auth_flow.__anext__() while True: response = await self._send_single_request(request, timeout) From 6321d8f2bbcf56c77a963a26fa785a94a4c7d1ca Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Wed, 2 Sep 2020 00:01:40 +0200 Subject: [PATCH 5/7] Fix undefined name errors --- tests/client/test_auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 9b26be5885..e8daf9353c 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -703,7 +703,7 @@ async def test_async_auth() -> None: url = "https://example.org/" auth = SyncOrAsyncAuth() - async with AsyncClient(transport=AsyncMockTransport()) as client: + async with httpx.AsyncClient(transport=AsyncMockTransport()) as client: response = await client.get(url, auth=auth) assert response.status_code == 200 @@ -717,7 +717,7 @@ def test_sync_auth() -> None: url = "https://example.org/" auth = SyncOrAsyncAuth() - with Client(transport=SyncMockTransport()) as client: + with httpx.Client(transport=SyncMockTransport()) as client: response = client.get(url, auth=auth) assert response.status_code == 200 From 1ace944cc6068392d4fc6f0aa1b54448beda7bb3 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Fri, 4 Sep 2020 23:53:03 +0200 Subject: [PATCH 6/7] Add docs --- docs/advanced.md | 49 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/docs/advanced.md b/docs/advanced.md index b2a07df371..0f0b2ddf72 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -724,6 +724,55 @@ class MyCustomAuth(httpx.Auth): ... ``` +If you _do_ need to perform I/O other than HTTP requests, such as accessing a disk-based cache, or you need to use concurrency primitives, such as locks, then you should override `.sync_auth_flow()` and `.async_auth_flow()` (instead of `.auth_flow()`). The former will be used by `httpx.Client`, while the latter will be used by `httpx.AsyncClient`. + +```python +import asyncio +import threading +import httpx + + +class MyCustomAuth(httpx.Auth): + def __init__(self): + self._sync_lock = threading.RLock() + self._async_lock = asyncio.Lock() + + def sync_get_token(self): + with self._sync_lock: + ... + + def sync_auth_flow(self, request): + token = self.sync_get_token() + request.headers["Authorization"] = f"Token {token}" + yield request + + async def async_get_token(self): + async with self._async_lock: + ... + + async def async_auth_flow(self, request): + token = await self.async_get_token() + request.headers["Authorization"] = f"Token {token}" + yield request +``` + +If you only want to support one of the two methods, then you should still override it, but raise an explicit `RuntimeError`. + +```python +import httpx +import sync_only_library + + +class MyCustomAuth(httpx.Auth): + def sync_auth_flow(self, request): + token = sync_only_library.get_token(...) + request.headers["Authorization"] = f"Token {token}" + yield request + + async def async_auth_flow(self, request): + raise RuntimeError("Cannot use a sync authentication class with httpx.AsyncClient") +``` + ## SSL certificates When making a request over HTTPS, HTTPX needs to verify the identity of the requested host. To do this, it uses a bundle of SSL certificates (a.k.a. CA bundle) delivered by a trusted certificate authority (CA). From 5eec7a827cc45d8b94f18cb35a05ecb7af2e9c0d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 9 Sep 2020 14:23:07 +0100 Subject: [PATCH 7/7] Add unit tests for auth classes --- tests/client/test_auth.py | 5 ++++ tests/test_auth.py | 63 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 tests/test_auth.py diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index e8daf9353c..c6c6d979ac 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1,3 +1,8 @@ +""" +Integration tests for authentication. + +Unit tests for auth classes also exist in tests/test_auth.py +""" import asyncio import hashlib import os diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000000..20c666a88c --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,63 @@ +""" +Unit tests for auth classes. + +Integration tests also exist in tests/client/test_auth.py +""" +import pytest + +import httpx + + +def test_basic_auth(): + auth = httpx.BasicAuth(username="user", password="pass") + request = httpx.Request("GET", "https://www.example.com") + + # The initial request should include a basic auth header. + flow = auth.sync_auth_flow(request) + request = next(flow) + assert request.headers["Authorization"].startswith("Basic") + + # No other requests are made. + response = httpx.Response(content=b"Hello, world!", status_code=200) + with pytest.raises(StopIteration): + flow.send(response) + + +def test_digest_auth_with_200(): + auth = httpx.DigestAuth(username="user", password="pass") + request = httpx.Request("GET", "https://www.example.com") + + # The initial request should not include an auth header. + flow = auth.sync_auth_flow(request) + request = next(flow) + assert "Authorization" not in request.headers + + # If a 200 response is returned, then no other requests are made. + response = httpx.Response(content=b"Hello, world!", status_code=200) + with pytest.raises(StopIteration): + flow.send(response) + + +def test_digest_auth_with_401(): + auth = httpx.DigestAuth(username="user", password="pass") + request = httpx.Request("GET", "https://www.example.com") + + # The initial request should not include an auth header. + flow = auth.sync_auth_flow(request) + request = next(flow) + assert "Authorization" not in request.headers + + # If a 401 response is returned, then a digest auth request is made. + headers = { + "WWW-Authenticate": 'Digest realm="...", qop="auth", nonce="...", opaque="..."' + } + response = httpx.Response( + content=b"Auth required", status_code=401, headers=headers + ) + request = flow.send(response) + assert request.headers["Authorization"].startswith("Digest") + + # No other requests are made. + response = httpx.Response(content=b"Hello, world!", status_code=200) + with pytest.raises(StopIteration): + flow.send(response)