From bad200e799663deda7c007a16a4a57463680355a Mon Sep 17 00:00:00 2001 From: Ber Zoidberg Date: Wed, 9 Sep 2020 21:59:31 -0700 Subject: [PATCH 1/6] py3 is no longer just async --- .github/workflows/python.yml | 6 +++--- tox.ini | 13 +++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 42de65de..4db46eb9 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -24,11 +24,11 @@ jobs: - version: 2.7 toxenv: py27,py27-flask - version: 3.6 - toxenv: py36,flask,django,async + toxenv: py36,flask,django,py3 - version: 3.7 - toxenv: py37,flask,django,async + toxenv: py37,flask,django,py3 - version: 3.8 - toxenv: py38,flask,django,async + toxenv: py38,flask,django,py3 steps: - uses: actions/checkout@v2 diff --git a/tox.ini b/tox.ini index 587d38bc..bd37557e 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = py{27,36,37,38} - {py36,py37,py38}-async + {py36,py37,py38} {py27,py36,py37,py38}-flask {py36,py37,py38}-django coverage @@ -12,10 +12,11 @@ deps = py27: unittest2 flask: Flask flask: Flask-SQLAlchemy - async: httpx==0.14.1 - async: pytest-asyncio - async: starlette - async: itsdangerous + py3: httpx==0.14.1 + py3: pytest-asyncio + py3: starlette + py3: itsdangerous + py3: werkzeug django: Django django: pytest-django @@ -23,7 +24,7 @@ setenv = TESTPATH=tests/core RCFILE=setup.cfg py27: RCFILE=.py27conf - async: TESTPATH=tests/py3 + py3: TESTPATH=tests/py3 flask: TESTPATH=tests/flask django: TESTPATH=tests/django commands = From 12496deda44b8f048efb7bf4c014b5a5ad83553d Mon Sep 17 00:00:00 2001 From: Ber Zoidberg Date: Wed, 9 Sep 2020 22:01:04 -0700 Subject: [PATCH 2/6] add support for httpx sync APIs; fix bug in httpx oauth1 support where content-length header si not specific --- authlib/integrations/httpx_client/__init__.py | 6 +- .../httpx_client/assertion_client.py | 46 ++- .../httpx_client/oauth1_client.py | 55 ++- .../httpx_client/oauth2_client.py | 100 ++++- .../test_assertion_client.py | 64 +++ .../test_async_assertion_client.py | 6 +- .../test_async_oauth1_client.py | 16 +- .../test_async_oauth2_client.py | 32 +- .../test_httpx_client/test_oauth1_client.py | 157 ++++++++ .../test_httpx_client/test_oauth2_client.py | 374 ++++++++++++++++++ .../test_oauth_client.py | 16 +- .../test_starlette_client/test_user_mixin.py | 6 +- tests/py3/utils.py | 75 +++- 13 files changed, 901 insertions(+), 52 deletions(-) create mode 100644 tests/py3/test_httpx_client/test_assertion_client.py create mode 100644 tests/py3/test_httpx_client/test_oauth1_client.py create mode 100644 tests/py3/test_httpx_client/test_oauth2_client.py diff --git a/authlib/integrations/httpx_client/__init__.py b/authlib/integrations/httpx_client/__init__.py index 21d641e8..6b4b9d67 100644 --- a/authlib/integrations/httpx_client/__init__.py +++ b/authlib/integrations/httpx_client/__init__.py @@ -6,12 +6,12 @@ SIGNATURE_TYPE_QUERY, SIGNATURE_TYPE_BODY, ) -from .oauth1_client import OAuth1Auth, AsyncOAuth1Client +from .oauth1_client import OAuth1Auth, AsyncOAuth1Client, OAuth1Client from .oauth2_client import ( - OAuth2Auth, OAuth2ClientAuth, + OAuth2Auth, OAuth2Client, OAuth2ClientAuth, AsyncOAuth2Client, ) -from .assertion_client import AsyncAssertionClient +from .assertion_client import AssertionClient, AsyncAssertionClient from ..base_client import OAuthError diff --git a/authlib/integrations/httpx_client/assertion_client.py b/authlib/integrations/httpx_client/assertion_client.py index 17228d32..7e5a75be 100644 --- a/authlib/integrations/httpx_client/assertion_client.py +++ b/authlib/integrations/httpx_client/assertion_client.py @@ -1,4 +1,4 @@ -from httpx import AsyncClient +from httpx import AsyncClient, Client from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient from authlib.oauth2.rfc7523 import JWTBearerGrant from authlib.oauth2 import OAuth2Error @@ -51,3 +51,47 @@ async def _refresh_token(self, data): ) self.token = token return self.token + +class AssertionClient(_AssertionClient, Client): + token_auth_class = OAuth2Auth + JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE + ASSERTION_METHODS = { + JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign, + } + DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE + + def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, + claims=None, token_placement='header', scope=None, **kwargs): + + client_kwargs = extract_client_kwargs(kwargs) + Client.__init__(self, **client_kwargs) + + _AssertionClient.__init__( + self, session=None, + token_endpoint=token_endpoint, issuer=issuer, subject=subject, + audience=audience, grant_type=grant_type, claims=claims, + token_placement=token_placement, scope=scope, **kwargs + ) + + def request(self, method, url, withhold_token=False, auth=None, **kwargs): + """Send request with auto refresh token feature.""" + if not withhold_token and auth is None: + if not self.token or self.token.is_expired(): + self.refresh_token() + + auth = self.token_auth + return super(AssertionClient, self).request( + method, url, auth=auth, **kwargs) + + def _refresh_token(self, data): + resp = self.request( + 'POST', self.token_endpoint, data=data, withhold_token=True) + + token = resp.json() + if 'error' in token: + raise OAuth2Error( + error=token['error'], + description=token.get('error_description') + ) + self.token = token + return self.token diff --git a/authlib/integrations/httpx_client/oauth1_client.py b/authlib/integrations/httpx_client/oauth1_client.py index 7bb1ccb0..13483ec4 100644 --- a/authlib/integrations/httpx_client/oauth1_client.py +++ b/authlib/integrations/httpx_client/oauth1_client.py @@ -1,5 +1,5 @@ import typing -from httpx import AsyncClient, Auth, Request, Response +from httpx import AsyncClient, Auth, Client, Request, Response from authlib.oauth1 import ( SIGNATURE_HMAC_SHA1, SIGNATURE_TYPE_HEADER, @@ -18,6 +18,7 @@ class OAuth1Auth(Auth, ClientAuth): def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: url, headers, body = self.prepare( request.method, str(request.url), request.headers, request.content) + headers['Content-Length'] = str(len(body)) yield Request(method=request.method, url=url, headers=headers, data=body) @@ -72,3 +73,55 @@ async def _fetch_token(self, url, **kwargs): @staticmethod def handle_error(error_type, error_description): raise OAuthError(error_type, error_description) + +class OAuth1Client(_OAuth1Client, Client): + auth_class = OAuth1Auth + + def __init__(self, client_id, client_secret=None, + token=None, token_secret=None, + redirect_uri=None, rsa_key=None, verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + force_include_body=False, **kwargs): + + _client_kwargs = extract_client_kwargs(kwargs) + Client.__init__(self, **_client_kwargs) + + _OAuth1Client.__init__( + self, None, + client_id=client_id, client_secret=client_secret, + token=token, token_secret=token_secret, + redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier, + signature_method=signature_method, signature_type=signature_type, + force_include_body=force_include_body, **kwargs) + + def fetch_access_token(self, url, verifier=None, **kwargs): + """Method for fetching an access token from the token endpoint. + + This is the final step in the OAuth 1 workflow. An access token is + obtained using all previously obtained credentials, including the + verifier from the authorization step. + + :param url: Access Token endpoint. + :param verifier: A verifier string to prove authorization was granted. + :param kwargs: Extra parameters to include for fetching access token. + :return: A token dict. + """ + if verifier: + self.auth.verifier = verifier + if not self.auth.verifier: + self.handle_error('missing_verifier', 'Missing "verifier" value') + token = self._fetch_token(url, **kwargs) + self.auth.verifier = None + return token + + def _fetch_token(self, url, **kwargs): + resp = self.post(url, **kwargs) + text = resp.read() + token = self.parse_response_token(resp.status_code, to_unicode(text)) + self.token = token + return token + + @staticmethod + def handle_error(error_type, error_description): + raise OAuthError(error_type, error_description) diff --git a/authlib/integrations/httpx_client/oauth2_client.py b/authlib/integrations/httpx_client/oauth2_client.py index 9d88f127..83277abf 100644 --- a/authlib/integrations/httpx_client/oauth2_client.py +++ b/authlib/integrations/httpx_client/oauth2_client.py @@ -1,5 +1,5 @@ import typing -from httpx import AsyncClient, Auth, Request, Response +from httpx import AsyncClient, Auth, Client, Request, Response from authlib.common.urls import url_decode from authlib.oauth2.client import OAuth2Client as _OAuth2Client from authlib.oauth2.auth import ClientAuth, TokenAuth @@ -137,3 +137,101 @@ def _http_post(self, url, body=None, auth=None, headers=None, **kwargs): return self.post( url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs) + +class OAuth2Client(_OAuth2Client, Client): + SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS + + client_auth_class = OAuth2ClientAuth + token_auth_class = OAuth2Auth + + def __init__(self, client_id=None, client_secret=None, + token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, redirect_uri=None, + token=None, token_placement='header', + update_token=None, **kwargs): + + # extract httpx.Client kwargs + client_kwargs = self._extract_session_request_params(kwargs) + Client.__init__(self, **client_kwargs) + + _OAuth2Client.__init__( + self, session=None, + client_id=client_id, client_secret=client_secret, + token_endpoint_auth_method=token_endpoint_auth_method, + revocation_endpoint_auth_method=revocation_endpoint_auth_method, + scope=scope, redirect_uri=redirect_uri, + token=token, token_placement=token_placement, + update_token=update_token, **kwargs + ) + + @staticmethod + def handle_error(error_type, error_description): + raise OAuthError(error_type, error_description) + + def request(self, method, url, withhold_token=False, auth=None, **kwargs): + if not withhold_token and auth is None: + if not self.token: + raise MissingTokenError() + + if self.token.is_expired(): + self.ensure_active_token() + + auth = self.token_auth + + return super(OAuth2Client, self).request( + method, url, auth=auth, **kwargs) + + def ensure_active_token(self): + refresh_token = self.token.get('refresh_token') + url = self.metadata.get('token_endpoint') + if refresh_token and url: + self.refresh_token(url, refresh_token=refresh_token) + elif self.metadata.get('grant_type') == 'client_credentials': + access_token = self.token['access_token'] + token = self.fetch_token(url, grant_type='client_credentials') + if self.update_token: + self.update_token(token, access_token=access_token) + else: + raise InvalidTokenError() + + def _fetch_token(self, url, body='', headers=None, auth=None, + method='POST', **kwargs): + if method.upper() == 'POST': + resp = self.post( + url, data=dict(url_decode(body)), headers=headers, + auth=auth, **kwargs) + else: + if '?' in url: + url = '&'.join([url, body]) + else: + url = '?'.join([url, body]) + resp = self.get(url, headers=headers, auth=auth, **kwargs) + + for hook in self.compliance_hook['access_token_response']: + resp = hook(resp) + + return self.parse_response_token(resp.json()) + + def _refresh_token(self, url, refresh_token=None, body='', + headers=None, auth=None, **kwargs): + resp = self.post( + url, data=dict(url_decode(body)), headers=headers, + auth=auth, **kwargs) + + for hook in self.compliance_hook['refresh_token_response']: + resp = hook(resp) + + token = self.parse_response_token(resp.json()) + if 'refresh_token' not in token: + self.token['refresh_token'] = refresh_token + + if self.update_token: + self.update_token(self.token, refresh_token=refresh_token) + + return self.token + + def _http_post(self, url, body=None, auth=None, headers=None, **kwargs): + return self.post( + url, data=dict(url_decode(body)), + headers=headers, auth=auth, **kwargs) diff --git a/tests/py3/test_httpx_client/test_assertion_client.py b/tests/py3/test_httpx_client/test_assertion_client.py new file mode 100644 index 00000000..1d01cdb5 --- /dev/null +++ b/tests/py3/test_httpx_client/test_assertion_client.py @@ -0,0 +1,64 @@ +import time +import pytest +from authlib.integrations.httpx_client import AssertionClient +from tests.py3.utils import MockDispatch + + +default_token = { + 'token_type': 'Bearer', + 'access_token': 'a', + 'refresh_token': 'b', + 'expires_in': '3600', + 'expires_at': int(time.time()) + 3600, +} + + +@pytest.mark.asyncio +def test_refresh_token(): + def verifier(request): + content = request.form + if str(request.url) == 'https://i.b/token': + assert 'assertion' in content + + with AssertionClient( + 'https://i.b/token', + grant_type=AssertionClient.JWT_BEARER_GRANT_TYPE, + issuer='foo', + subject='foo', + audience='foo', + alg='HS256', + key='secret', + app=MockDispatch(default_token, assert_func=verifier) + ) as client: + client.get('https://i.b') + + # trigger more case + now = int(time.time()) + with AssertionClient( + 'https://i.b/token', + issuer='foo', + subject=None, + audience='foo', + issued_at=now, + expires_at=now + 3600, + header={'alg': 'HS256'}, + key='secret', + scope='email', + claims={'test_mode': 'true'}, + app=MockDispatch(default_token, assert_func=verifier) + ) as client: + client.get('https://i.b') + client.get('https://i.b') + + +@pytest.mark.asyncio +def test_without_alg(): + with AssertionClient( + 'https://i.b/token', + issuer='foo', + subject='foo', + audience='foo', + key='secret', + ) as client: + with pytest.raises(ValueError): + client.get('https://i.b') diff --git a/tests/py3/test_httpx_client/test_async_assertion_client.py b/tests/py3/test_httpx_client/test_async_assertion_client.py index fa51e14a..fb25d9aa 100644 --- a/tests/py3/test_httpx_client/test_async_assertion_client.py +++ b/tests/py3/test_httpx_client/test_async_assertion_client.py @@ -1,7 +1,7 @@ import time import pytest from authlib.integrations.httpx_client import AsyncAssertionClient -from tests.py3.utils import MockDispatch +from tests.py3.utils import AsyncMockDispatch default_token = { @@ -28,7 +28,7 @@ async def verifier(request): audience='foo', alg='HS256', key='secret', - app=MockDispatch(default_token, assert_func=verifier) + app=AsyncMockDispatch(default_token, assert_func=verifier) ) as client: await client.get('https://i.b') @@ -45,7 +45,7 @@ async def verifier(request): key='secret', scope='email', claims={'test_mode': 'true'}, - app=MockDispatch(default_token, assert_func=verifier) + app=AsyncMockDispatch(default_token, assert_func=verifier) ) as client: await client.get('https://i.b') await client.get('https://i.b') diff --git a/tests/py3/test_httpx_client/test_async_oauth1_client.py b/tests/py3/test_httpx_client/test_async_oauth1_client.py index bf8542d8..75703567 100644 --- a/tests/py3/test_httpx_client/test_async_oauth1_client.py +++ b/tests/py3/test_httpx_client/test_async_oauth1_client.py @@ -5,7 +5,7 @@ SIGNATURE_TYPE_BODY, SIGNATURE_TYPE_QUERY, ) -from tests.py3.utils import MockDispatch +from tests.py3.utils import AsyncMockDispatch oauth_url = 'https://example.com/oauth' @@ -19,7 +19,7 @@ async def assert_func(request): assert 'oauth_consumer_key="id"' in auth_header assert 'oauth_signature=' in auth_header - app = MockDispatch(request_token, assert_func=assert_func) + app = AsyncMockDispatch(request_token, assert_func=assert_func) async with AsyncOAuth1Client('id', 'secret', app=app) as client: response = await client.fetch_request_token(oauth_url) @@ -38,7 +38,7 @@ async def assert_func(request): assert b'oauth_consumer_key=id' in content assert b'&oauth_signature=' in content - mock_response = MockDispatch(request_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(request_token, assert_func=assert_func) async with AsyncOAuth1Client( 'id', 'secret', signature_type=SIGNATURE_TYPE_BODY, @@ -61,7 +61,7 @@ async def assert_func(request): assert 'oauth_consumer_key=id' in url assert '&oauth_signature=' in url - mock_response = MockDispatch(request_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(request_token, assert_func=assert_func) async with AsyncOAuth1Client( 'id', 'secret', signature_type=SIGNATURE_TYPE_QUERY, @@ -83,7 +83,7 @@ async def assert_func(request): assert 'oauth_consumer_key="id"' in auth_header assert 'oauth_signature=' in auth_header - mock_response = MockDispatch(request_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(request_token, assert_func=assert_func) async with AsyncOAuth1Client( 'id', 'secret', token='foo', token_secret='bar', app=mock_response, @@ -98,7 +98,7 @@ async def assert_func(request): @pytest.mark.asyncio async def test_get_via_header(): - mock_response = MockDispatch(b'hello') + mock_response = AsyncMockDispatch(b'hello') async with AsyncOAuth1Client( 'id', 'secret', token='foo', token_secret='bar', app=mock_response, @@ -121,7 +121,7 @@ async def assert_func(request): assert b'oauth_consumer_key=id' in content assert b'oauth_signature=' in content - mock_response = MockDispatch(b'hello', assert_func=assert_func) + mock_response = AsyncMockDispatch(b'hello', assert_func=assert_func) async with AsyncOAuth1Client( 'id', 'secret', token='foo', token_secret='bar', signature_type=SIGNATURE_TYPE_BODY, @@ -138,7 +138,7 @@ async def assert_func(request): @pytest.mark.asyncio async def test_get_via_query(): - mock_response = MockDispatch(b'hello') + mock_response = AsyncMockDispatch(b'hello') async with AsyncOAuth1Client( 'id', 'secret', token='foo', token_secret='bar', signature_type=SIGNATURE_TYPE_QUERY, diff --git a/tests/py3/test_httpx_client/test_async_oauth2_client.py b/tests/py3/test_httpx_client/test_async_oauth2_client.py index 54b3a693..3c2613df 100644 --- a/tests/py3/test_httpx_client/test_async_oauth2_client.py +++ b/tests/py3/test_httpx_client/test_async_oauth2_client.py @@ -8,7 +8,7 @@ OAuthError, AsyncOAuth2Client, ) -from tests.py3.utils import MockDispatch +from tests.py3.utils import AsyncMockDispatch default_token = { @@ -27,7 +27,7 @@ async def assert_func(request): auth_header = request.headers.get('authorization') assert auth_header == token - mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + mock_response = AsyncMockDispatch({'a': 'a'}, assert_func=assert_func) async with AsyncOAuth2Client( 'foo', token=default_token, @@ -45,7 +45,7 @@ async def assert_func(request): content = await request.body() assert default_token['access_token'] in content.decode() - mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + mock_response = AsyncMockDispatch({'a': 'a'}, assert_func=assert_func) async with AsyncOAuth2Client( 'foo', token=default_token, @@ -63,7 +63,7 @@ async def test_add_token_to_uri(): async def assert_func(request): assert default_token['access_token'] in str(request.url) - mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + mock_response = AsyncMockDispatch({'a': 'a'}, assert_func=assert_func) async with AsyncOAuth2Client( 'foo', token=default_token, @@ -123,7 +123,7 @@ async def assert_func(request): assert 'client_id=' in content assert 'grant_type=authorization_code' in content - mock_response = MockDispatch(default_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(default_token, assert_func=assert_func) async with AsyncOAuth2Client('foo', app=mock_response) as client: token = await client.fetch_token(url, authorization_response='https://i.b/?code=v') assert token == default_token @@ -136,7 +136,7 @@ async def assert_func(request): token = await client.fetch_token(url, code='v') assert token == default_token - mock_response = MockDispatch({'error': 'invalid_request'}) + mock_response = AsyncMockDispatch({'error': 'invalid_request'}) async with AsyncOAuth2Client('foo', app=mock_response) as client: with pytest.raises(OAuthError): await client.fetch_token(url) @@ -152,7 +152,7 @@ async def assert_func(request): assert 'client_id=' in url assert 'grant_type=authorization_code' in url - mock_response = MockDispatch(default_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(default_token, assert_func=assert_func) async with AsyncOAuth2Client('foo', app=mock_response) as client: authorization_response = 'https://i.b/?code=v' token = await client.fetch_token( @@ -183,7 +183,7 @@ async def assert_func(request): assert 'client_secret=bar' in content assert 'grant_type=authorization_code' in content - mock_response = MockDispatch(default_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(default_token, assert_func=assert_func) async with AsyncOAuth2Client( 'foo', 'bar', token_endpoint_auth_method='client_secret_post', @@ -203,7 +203,7 @@ def _access_token_response_hook(resp): return resp access_token_response_hook = mock.Mock(side_effect=_access_token_response_hook) - app = MockDispatch(default_token) + app = AsyncMockDispatch(default_token) async with AsyncOAuth2Client('foo', token=default_token, app=app) as sess: sess.register_compliance_hook( 'access_token_response', @@ -224,7 +224,7 @@ async def assert_func(request): assert 'scope=profile' in content assert 'grant_type=password' in content - app = MockDispatch(default_token, assert_func=assert_func) + app = AsyncMockDispatch(default_token, assert_func=assert_func) async with AsyncOAuth2Client('foo', scope='profile', app=app) as sess: token = await sess.fetch_token(url, username='v', password='v') assert token == default_token @@ -244,7 +244,7 @@ async def assert_func(request): assert 'scope=profile' in content assert 'grant_type=client_credentials' in content - app = MockDispatch(default_token, assert_func=assert_func) + app = AsyncMockDispatch(default_token, assert_func=assert_func) async with AsyncOAuth2Client('foo', scope='profile', app=app) as sess: token = await sess.fetch_token(url) assert token == default_token @@ -262,7 +262,7 @@ async def test_cleans_previous_token_before_fetching_new_one(): new_token['expires_at'] = now + 3600 url = 'https://example.com/token' - app = MockDispatch(new_token) + app = AsyncMockDispatch(new_token) with mock.patch('time.time', lambda: now): async with AsyncOAuth2Client('foo', token=default_token, app=app) as sess: assert await sess.fetch_token(url) == new_token @@ -288,7 +288,7 @@ async def _update_token(token, refresh_token=None, access_token=None): token_type='bearer', expires_at=100 ) - app = MockDispatch(default_token) + app = AsyncMockDispatch(default_token) async with AsyncOAuth2Client( 'foo', token=old_token, token_endpoint='https://i.b/token', update_token=update_token, app=app @@ -324,7 +324,7 @@ async def _update_token(token, refresh_token=None, access_token=None): expires_at=100 ) - app = MockDispatch(default_token) + app = AsyncMockDispatch(default_token) async with AsyncOAuth2Client( 'foo', token=old_token, @@ -358,7 +358,7 @@ async def _update_token(token, refresh_token=None, access_token=None): expires_at=100 ) - app = MockDispatch(default_token) + app = AsyncMockDispatch(default_token) async with AsyncOAuth2Client( 'foo', token=old_token, token_endpoint='https://i.b/token', @@ -372,7 +372,7 @@ async def _update_token(token, refresh_token=None, access_token=None): @pytest.mark.asyncio async def test_revoke_token(): answer = {'status': 'ok'} - app = MockDispatch(answer) + app = AsyncMockDispatch(answer) async with AsyncOAuth2Client('a', app=app) as sess: resp = await sess.revoke_token('https://i.b/token', 'hi') diff --git a/tests/py3/test_httpx_client/test_oauth1_client.py b/tests/py3/test_httpx_client/test_oauth1_client.py new file mode 100644 index 00000000..a5f34df3 --- /dev/null +++ b/tests/py3/test_httpx_client/test_oauth1_client.py @@ -0,0 +1,157 @@ +import pytest +from authlib.integrations.httpx_client import ( + OAuthError, + OAuth1Client, + SIGNATURE_TYPE_BODY, + SIGNATURE_TYPE_QUERY, +) +from tests.py3.utils import MockDispatch + +oauth_url = 'https://example.com/oauth' + + +@pytest.mark.asyncio +def test_fetch_request_token_via_header(): + request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} + + def assert_func(request): + auth_header = request.headers.get('authorization') + assert 'oauth_consumer_key="id"' in auth_header + assert 'oauth_signature=' in auth_header + + app = MockDispatch(request_token, assert_func=assert_func) + with OAuth1Client('id', 'secret', app=app) as client: + response = client.fetch_request_token(oauth_url) + + assert response == request_token + + +@pytest.mark.asyncio +def test_fetch_request_token_via_body(): + request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} + + def assert_func(request): + auth_header = request.headers.get('authorization') + assert auth_header is None + + content = request.form + assert content.get('oauth_consumer_key') == 'id' + assert 'oauth_signature' in content + + mock_response = MockDispatch(request_token, assert_func=assert_func) + + with OAuth1Client( + 'id', 'secret', signature_type=SIGNATURE_TYPE_BODY, + app=mock_response, + ) as client: + response = client.fetch_request_token(oauth_url) + + assert response == request_token + + +@pytest.mark.asyncio +def test_fetch_request_token_via_query(): + request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} + + def assert_func(request): + auth_header = request.headers.get('authorization') + assert auth_header is None + + url = str(request.url) + assert 'oauth_consumer_key=id' in url + assert '&oauth_signature=' in url + + mock_response = MockDispatch(request_token, assert_func=assert_func) + + with OAuth1Client( + 'id', 'secret', signature_type=SIGNATURE_TYPE_QUERY, + app=mock_response, + ) as client: + response = client.fetch_request_token(oauth_url) + + assert response == request_token + + +@pytest.mark.asyncio +def test_fetch_access_token(): + request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} + + def assert_func(request): + auth_header = request.headers.get('authorization') + assert 'oauth_verifier="d"' in auth_header + assert 'oauth_token="foo"' in auth_header + assert 'oauth_consumer_key="id"' in auth_header + assert 'oauth_signature=' in auth_header + + mock_response = MockDispatch(request_token, assert_func=assert_func) + with OAuth1Client( + 'id', 'secret', token='foo', token_secret='bar', + app=mock_response, + ) as client: + with pytest.raises(OAuthError): + client.fetch_access_token(oauth_url) + + response = client.fetch_access_token(oauth_url, verifier='d') + + assert response == request_token + + +@pytest.mark.asyncio +def test_get_via_header(): + mock_response = MockDispatch(b'hello') + with OAuth1Client( + 'id', 'secret', token='foo', token_secret='bar', + app=mock_response, + ) as client: + response = client.get('https://example.com/') + + assert response.content == b'hello' + request = response.request + auth_header = request.headers.get('authorization') + assert 'oauth_token="foo"' in auth_header + assert 'oauth_consumer_key="id"' in auth_header + assert 'oauth_signature=' in auth_header + + +@pytest.mark.asyncio +def test_get_via_body(): + def assert_func(request): + content = request.form + assert content.get('oauth_token') == 'foo' + assert content.get('oauth_consumer_key') == 'id' + assert 'oauth_signature' in content + + mock_response = MockDispatch(b'hello', assert_func=assert_func) + with OAuth1Client( + 'id', 'secret', token='foo', token_secret='bar', + signature_type=SIGNATURE_TYPE_BODY, + app=mock_response, + ) as client: + response = client.post('https://example.com/') + + assert response.content == b'hello' + + request = response.request + auth_header = request.headers.get('authorization') + assert auth_header is None + + +@pytest.mark.asyncio +def test_get_via_query(): + mock_response = MockDispatch(b'hello') + with OAuth1Client( + 'id', 'secret', token='foo', token_secret='bar', + signature_type=SIGNATURE_TYPE_QUERY, + app=mock_response, + ) as client: + response = client.get('https://example.com/') + + assert response.content == b'hello' + request = response.request + auth_header = request.headers.get('authorization') + assert auth_header is None + + url = str(request.url) + assert 'oauth_token=foo' in url + assert 'oauth_consumer_key=id' in url + assert 'oauth_signature=' in url diff --git a/tests/py3/test_httpx_client/test_oauth2_client.py b/tests/py3/test_httpx_client/test_oauth2_client.py new file mode 100644 index 00000000..e5bc6b35 --- /dev/null +++ b/tests/py3/test_httpx_client/test_oauth2_client.py @@ -0,0 +1,374 @@ +import mock +import time +import pytest +from copy import deepcopy +from authlib.common.security import generate_token +from authlib.common.urls import url_encode +from authlib.integrations.httpx_client import ( + OAuthError, + OAuth2Client, +) +from tests.py3.utils import MockDispatch + + +default_token = { + 'token_type': 'Bearer', + 'access_token': 'a', + 'refresh_token': 'b', + 'expires_in': '3600', + 'expires_at': int(time.time()) + 3600, +} + + +def test_add_token_to_header(): + def assert_func(request): + token = 'Bearer ' + default_token['access_token'] + auth_header = request.headers.get('authorization') + assert auth_header == token + + mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + with OAuth2Client( + 'foo', + token=default_token, + app=mock_response + ) as client: + resp = client.get('https://i.b') + + data = resp.json() + assert data['a'] == 'a' + + +def test_add_token_to_body(): + def assert_func(request): + content = request.data + content = content.decode() + assert content == 'access_token=%s' % default_token['access_token'] + + mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + with OAuth2Client( + 'foo', + token=default_token, + token_placement='body', + app=mock_response + ) as client: + resp = client.get('https://i.b') + + data = resp.json() + assert data['a'] == 'a' + + +def test_add_token_to_uri(): + def assert_func(request): + assert default_token['access_token'] in str(request.url) + + mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + with OAuth2Client( + 'foo', + token=default_token, + token_placement='uri', + app=mock_response + ) as client: + resp = client.get('https://i.b') + + data = resp.json() + assert data['a'] == 'a' + + +def test_create_authorization_url(): + url = 'https://example.com/authorize?foo=bar' + + sess = OAuth2Client(client_id='foo') + auth_url, state = sess.create_authorization_url(url) + assert state in auth_url + assert 'client_id=foo' in auth_url + assert 'response_type=code' in auth_url + + sess = OAuth2Client(client_id='foo', prompt='none') + auth_url, state = sess.create_authorization_url( + url, state='foo', redirect_uri='https://i.b', scope='profile') + assert state == 'foo' + assert 'i.b' in auth_url + assert 'profile' in auth_url + assert 'prompt=none' in auth_url + + +def test_code_challenge(): + sess = OAuth2Client('foo', code_challenge_method='S256') + + url = 'https://example.com/authorize' + auth_url, _ = sess.create_authorization_url( + url, code_verifier=generate_token(48)) + assert 'code_challenge=' in auth_url + assert 'code_challenge_method=S256' in auth_url + + +def test_token_from_fragment(): + sess = OAuth2Client('foo') + response_url = 'https://i.b/callback#' + url_encode(default_token.items()) + assert sess.token_from_fragment(response_url) == default_token + token = sess.fetch_token(authorization_response=response_url) + assert token == default_token + + +def test_fetch_token_post(): + url = 'https://example.com/token' + + def assert_func(request): + content = request.form + assert content.get('code') == 'v' + assert content.get('client_id') == 'foo' + assert content.get('grant_type') == 'authorization_code' + + mock_response = MockDispatch(default_token, assert_func=assert_func) + with OAuth2Client('foo', app=mock_response) as client: + token = client.fetch_token(url, authorization_response='https://i.b/?code=v') + assert token == default_token + + with OAuth2Client( + 'foo', + token_endpoint_auth_method='none', + app=mock_response + ) as client: + token = client.fetch_token(url, code='v') + assert token == default_token + + mock_response = MockDispatch({'error': 'invalid_request'}) + with OAuth2Client('foo', app=mock_response) as client: + with pytest.raises(OAuthError): + client.fetch_token(url) + + +def test_fetch_token_get(): + url = 'https://example.com/token' + + def assert_func(request): + url = str(request.url) + assert 'code=v' in url + assert 'client_id=' in url + assert 'grant_type=authorization_code' in url + + mock_response = MockDispatch(default_token, assert_func=assert_func) + with OAuth2Client('foo', app=mock_response) as client: + authorization_response = 'https://i.b/?code=v' + token = client.fetch_token( + url, authorization_response=authorization_response, method='GET') + assert token == default_token + + with OAuth2Client( + 'foo', + token_endpoint_auth_method='none', + app=mock_response + ) as client: + token = client.fetch_token(url, code='v', method='GET') + assert token == default_token + + token = client.fetch_token(url + '?q=a', code='v', method='GET') + assert token == default_token + + +def test_token_auth_method_client_secret_post(): + url = 'https://example.com/token' + + def assert_func(request): + content = request.form + assert content.get('code') == 'v' + assert content.get('client_id') == 'foo' + assert content.get('client_secret') == 'bar' + assert content.get('grant_type') == 'authorization_code' + + mock_response = MockDispatch(default_token, assert_func=assert_func) + with OAuth2Client( + 'foo', 'bar', + token_endpoint_auth_method='client_secret_post', + app=mock_response + ) as client: + token = client.fetch_token(url, code='v') + + assert token == default_token + + +def test_access_token_response_hook(): + url = 'https://example.com/token' + + def _access_token_response_hook(resp): + assert resp.json() == default_token + return resp + + access_token_response_hook = mock.Mock(side_effect=_access_token_response_hook) + app = MockDispatch(default_token) + with OAuth2Client('foo', token=default_token, app=app) as sess: + sess.register_compliance_hook( + 'access_token_response', + access_token_response_hook + ) + assert sess.fetch_token(url) == default_token + assert access_token_response_hook.called is True + + +def test_password_grant_type(): + url = 'https://example.com/token' + + def assert_func(request): + content = request.form + assert content.get('username') == 'v' + assert content.get('scope') == 'profile' + assert content.get('grant_type') == 'password' + + app = MockDispatch(default_token, assert_func=assert_func) + with OAuth2Client('foo', scope='profile', app=app) as sess: + token = sess.fetch_token(url, username='v', password='v') + assert token == default_token + + token = sess.fetch_token( + url, username='v', password='v', grant_type='password') + assert token == default_token + + +def test_client_credentials_type(): + url = 'https://example.com/token' + + def assert_func(request): + content = request.form + assert content.get('scope') == 'profile' + assert content.get('grant_type') == 'client_credentials' + + app = MockDispatch(default_token, assert_func=assert_func) + with OAuth2Client('foo', scope='profile', app=app) as sess: + token = sess.fetch_token(url) + assert token == default_token + + token = sess.fetch_token(url, grant_type='client_credentials') + assert token == default_token + + +def test_cleans_previous_token_before_fetching_new_one(): + now = int(time.time()) + new_token = deepcopy(default_token) + past = now - 7200 + default_token['expires_at'] = past + new_token['expires_at'] = now + 3600 + url = 'https://example.com/token' + + app = MockDispatch(new_token) + with mock.patch('time.time', lambda: now): + with OAuth2Client('foo', token=default_token, app=app) as sess: + assert sess.fetch_token(url) == new_token + + +def test_token_status(): + token = dict(access_token='a', token_type='bearer', expires_at=100) + sess = OAuth2Client('foo', token=token) + assert sess.token.is_expired() is True + + +def test_auto_refresh_token(): + + def _update_token(token, refresh_token=None, access_token=None): + assert refresh_token == 'b' + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict( + access_token='a', refresh_token='b', + token_type='bearer', expires_at=100 + ) + + app = MockDispatch(default_token) + with OAuth2Client( + 'foo', token=old_token, token_endpoint='https://i.b/token', + update_token=update_token, app=app + ) as sess: + sess.get('https://i.b/user') + assert update_token.called is True + + old_token = dict( + access_token='a', + token_type='bearer', + expires_at=100 + ) + with OAuth2Client( + 'foo', token=old_token, token_endpoint='https://i.b/token', + update_token=update_token, app=app + ) as sess: + with pytest.raises(OAuthError): + sess.get('https://i.b/user') + + +def test_auto_refresh_token2(): + + def _update_token(token, refresh_token=None, access_token=None): + assert access_token == 'a' + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict( + access_token='a', + token_type='bearer', + expires_at=100 + ) + + app = MockDispatch(default_token) + + with OAuth2Client( + 'foo', token=old_token, + token_endpoint='https://i.b/token', + grant_type='client_credentials', + app=app, + ) as client: + client.get('https://i.b/user') + assert update_token.called is False + + with OAuth2Client( + 'foo', token=old_token, token_endpoint='https://i.b/token', + update_token=update_token, grant_type='client_credentials', + app=app, + ) as client: + client.get('https://i.b/user') + assert update_token.called is True + + +def test_auto_refresh_token3(): + def _update_token(token, refresh_token=None, access_token=None): + assert access_token == 'a' + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict( + access_token='a', + token_type='bearer', + expires_at=100 + ) + + app = MockDispatch(default_token) + + with OAuth2Client( + 'foo', token=old_token, token_endpoint='https://i.b/token', + update_token=update_token, grant_type='client_credentials', + app=app, + ) as client: + client.post('https://i.b/user', json={'foo': 'bar'}) + assert update_token.called is True + + +def test_revoke_token(): + answer = {'status': 'ok'} + app = MockDispatch(answer) + + with OAuth2Client('a', app=app) as sess: + resp = sess.revoke_token('https://i.b/token', 'hi') + assert resp.json() == answer + + resp = sess.revoke_token( + 'https://i.b/token', 'hi', + token_type_hint='access_token' + ) + assert resp.json() == answer + + +def test_request_without_token(): + with OAuth2Client('a') as client: + with pytest.raises(OAuthError): + client.get('https://i.b/token') diff --git a/tests/py3/test_starlette_client/test_oauth_client.py b/tests/py3/test_starlette_client/test_oauth_client.py index 4722e4ed..68654fc8 100644 --- a/tests/py3/test_starlette_client/test_oauth_client.py +++ b/tests/py3/test_starlette_client/test_oauth_client.py @@ -2,7 +2,7 @@ from starlette.config import Config from starlette.requests import Request from authlib.integrations.starlette_client import OAuth -from tests.py3.utils import PathMapDispatch +from tests.py3.utils import AsyncPathMapDispatch from tests.client_base import get_bearer_token @@ -39,7 +39,7 @@ def test_register_with_overwrite(): @pytest.mark.asyncio async def test_oauth1_authorize(): oauth = OAuth() - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/request-token': {'body': 'oauth_token=foo&oauth_verifier=baz'}, '/token': {'body': 'oauth_token=a&oauth_token_secret=b'}, }) @@ -74,7 +74,7 @@ async def test_oauth1_authorize(): @pytest.mark.asyncio async def test_oauth2_authorize(): oauth = OAuth() - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/token': {'body': get_bearer_token()} }) client = oauth.register( @@ -113,7 +113,7 @@ async def test_oauth2_authorize(): @pytest.mark.asyncio async def test_oauth2_authorize_code_challenge(): - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/token': {'body': get_bearer_token()} }) oauth = OAuth() @@ -163,7 +163,7 @@ async def test_with_fetch_token_in_register(): async def fetch_token(request): return {'access_token': 'dev', 'token_type': 'bearer'} - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/user': {'body': {'sub': '123'}} }) oauth = OAuth() @@ -191,7 +191,7 @@ async def test_with_fetch_token_in_oauth(): async def fetch_token(name, request): return {'access_token': 'dev', 'token_type': 'bearer'} - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/user': {'body': {'sub': '123'}} }) oauth = OAuth(fetch_token=fetch_token) @@ -216,7 +216,7 @@ async def fetch_token(name, request): @pytest.mark.asyncio async def test_request_withhold_token(): oauth = OAuth() - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/user': {'body': {'sub': '123'}} }) client = oauth.register( @@ -252,7 +252,7 @@ async def test_oauth2_authorize_with_metadata(): await client.create_authorization_url(req) - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/.well-known/openid-configuration': {'body': { 'authorization_endpoint': 'https://i.b/authorize' }} diff --git a/tests/py3/test_starlette_client/test_user_mixin.py b/tests/py3/test_starlette_client/test_user_mixin.py index 2e348015..f9e32b56 100644 --- a/tests/py3/test_starlette_client/test_user_mixin.py +++ b/tests/py3/test_starlette_client/test_user_mixin.py @@ -5,7 +5,7 @@ from authlib.jose.errors import InvalidClaimError from authlib.oidc.core.grants.util import generate_id_token from tests.util import read_file_path -from tests.py3.utils import PathMapDispatch +from tests.py3.utils import AsyncPathMapDispatch from tests.client_base import get_bearer_token @@ -15,7 +15,7 @@ async def run_fetch_userinfo(payload, compliance_fix=None): async def fetch_token(request): return get_bearer_token() - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/userinfo': {'body': payload} }) @@ -125,7 +125,7 @@ async def test_force_fetch_jwks_uri(): aud='dev', exp=3600, nonce='n', ) - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/jwks': {'body': read_file_path('jwks_public.json')} }) diff --git a/tests/py3/utils.py b/tests/py3/utils.py index 807cea07..9416e100 100644 --- a/tests/py3/utils.py +++ b/tests/py3/utils.py @@ -1,9 +1,11 @@ import json -from starlette.requests import Request -from starlette.responses import Response +from starlette.requests import Request as ASGIRequest +from starlette.responses import Response as ASGIResponse +from werkzeug.wrappers import Request as WSGIRequest +from werkzeug.wrappers import Response as WSGIResponse -class MockDispatch: +class AsyncMockDispatch: def __init__(self, body=b'', status_code=200, headers=None, assert_func=None): if headers is None: @@ -22,12 +24,12 @@ def __init__(self, body=b'', status_code=200, headers=None, self.assert_func = assert_func async def __call__(self, scope, receive, send): - request = Request(scope, receive=receive) + request = ASGIRequest(scope, receive=receive) if self.assert_func: await self.assert_func(request) - response = Response( + response = ASGIResponse( status_code=self.status_code, content=self.body, headers=self.headers, @@ -35,12 +37,12 @@ async def __call__(self, scope, receive, send): await response(scope, receive, send) -class PathMapDispatch: +class AsyncPathMapDispatch: def __init__(self, path_maps): self.path_maps = path_maps async def __call__(self, scope, receive, send): - request = Request(scope, receive=receive) + request = ASGIRequest(scope, receive=receive) rv = self.path_maps[request.url.path] status_code = rv.get('status_code', 200) @@ -54,9 +56,66 @@ async def __call__(self, scope, receive, send): body = body.encode() headers['Content-Type'] = 'application/x-www-form-urlencoded' - response = Response( + response = ASGIResponse( status_code=status_code, content=body, headers=headers, ) await response(scope, receive, send) + +class MockDispatch: + def __init__(self, body=b'', status_code=200, headers=None, + assert_func=None): + if headers is None: + headers = {} + if isinstance(body, dict): + body = json.dumps(body).encode() + headers['Content-Type'] = 'application/json' + else: + if isinstance(body, str): + body = body.encode() + headers['Content-Type'] = 'application/x-www-form-urlencoded' + + self.body = body + self.status_code = status_code + self.headers = headers + self.assert_func = assert_func + + def __call__(self, environ, start_response): + request = WSGIRequest(environ) + + if self.assert_func: + self.assert_func(request) + + response = WSGIResponse( + status=self.status_code, + response=self.body, + headers=self.headers, + ) + return response(environ, start_response) + + +class PathMapDispatch: + def __init__(self, path_maps): + self.path_maps = path_maps + + def __call__(self, environ, start_response): + request = WSGIRequest(environ) + + rv = self.path_maps[request.url.path] + status_code = rv.get('status_code', 200) + body = rv.get('body', b'') + headers = rv.get('headers', {}) + if isinstance(body, dict): + body = json.dumps(body).encode() + headers['Content-Type'] = 'application/json' + else: + if isinstance(body, str): + body = body.encode() + headers['Content-Type'] = 'application/x-www-form-urlencoded' + response = WSGIResponse( + status=status_code, + response=body, + headers=headers, + ) + return response(environ, start_response) From 31ec53768cb60215f83e281fda2efd01d9d7e260 Mon Sep 17 00:00:00 2001 From: Ber Zoidberg Date: Thu, 10 Sep 2020 00:15:32 -0700 Subject: [PATCH 3/6] update to httpx 0.14.3 --- authlib/integrations/httpx_client/assertion_client.py | 5 +++-- authlib/integrations/httpx_client/oauth2_client.py | 5 +++-- tests/py3/test_httpx_client/test_assertion_client.py | 1 + tests/py3/test_httpx_client/test_async_assertion_client.py | 1 + tests/py3/test_httpx_client/test_async_oauth2_client.py | 2 +- tests/py3/test_httpx_client/test_oauth2_client.py | 2 +- tox.ini | 2 +- 7 files changed, 11 insertions(+), 7 deletions(-) diff --git a/authlib/integrations/httpx_client/assertion_client.py b/authlib/integrations/httpx_client/assertion_client.py index 7e5a75be..144f685f 100644 --- a/authlib/integrations/httpx_client/assertion_client.py +++ b/authlib/integrations/httpx_client/assertion_client.py @@ -1,4 +1,5 @@ from httpx import AsyncClient, Client +from httpx._config import UnsetType from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient from authlib.oauth2.rfc7523 import JWTBearerGrant from authlib.oauth2 import OAuth2Error @@ -31,7 +32,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No async def request(self, method, url, withhold_token=False, auth=None, **kwargs): """Send request with auto refresh token feature.""" - if not withhold_token and auth is None: + if not withhold_token and isinstance(auth, UnsetType): if not self.token or self.token.is_expired(): await self.refresh_token() @@ -75,7 +76,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No def request(self, method, url, withhold_token=False, auth=None, **kwargs): """Send request with auto refresh token feature.""" - if not withhold_token and auth is None: + if not withhold_token and isinstance(auth, UnsetType): if not self.token or self.token.is_expired(): self.refresh_token() diff --git a/authlib/integrations/httpx_client/oauth2_client.py b/authlib/integrations/httpx_client/oauth2_client.py index 83277abf..09769995 100644 --- a/authlib/integrations/httpx_client/oauth2_client.py +++ b/authlib/integrations/httpx_client/oauth2_client.py @@ -1,5 +1,6 @@ import typing from httpx import AsyncClient, Auth, Client, Request, Response +from httpx._config import UnsetType from authlib.common.urls import url_decode from authlib.oauth2.client import OAuth2Client as _OAuth2Client from authlib.oauth2.auth import ClientAuth, TokenAuth @@ -72,7 +73,7 @@ def handle_error(error_type, error_description): raise OAuthError(error_type, error_description) async def request(self, method, url, withhold_token=False, auth=None, **kwargs): - if not withhold_token and auth is None: + if not withhold_token and isinstance(auth, UnsetType): if not self.token: raise MissingTokenError() @@ -170,7 +171,7 @@ def handle_error(error_type, error_description): raise OAuthError(error_type, error_description) def request(self, method, url, withhold_token=False, auth=None, **kwargs): - if not withhold_token and auth is None: + if not withhold_token and isinstance(auth, UnsetType): if not self.token: raise MissingTokenError() diff --git a/tests/py3/test_httpx_client/test_assertion_client.py b/tests/py3/test_httpx_client/test_assertion_client.py index 1d01cdb5..91b05297 100644 --- a/tests/py3/test_httpx_client/test_assertion_client.py +++ b/tests/py3/test_httpx_client/test_assertion_client.py @@ -59,6 +59,7 @@ def test_without_alg(): subject='foo', audience='foo', key='secret', + app=MockDispatch(default_token) ) as client: with pytest.raises(ValueError): client.get('https://i.b') diff --git a/tests/py3/test_httpx_client/test_async_assertion_client.py b/tests/py3/test_httpx_client/test_async_assertion_client.py index fb25d9aa..46286bff 100644 --- a/tests/py3/test_httpx_client/test_async_assertion_client.py +++ b/tests/py3/test_httpx_client/test_async_assertion_client.py @@ -59,6 +59,7 @@ async def test_without_alg(): subject='foo', audience='foo', key='secret', + app=AsyncMockDispatch() ) as client: with pytest.raises(ValueError): await client.get('https://i.b') diff --git a/tests/py3/test_httpx_client/test_async_oauth2_client.py b/tests/py3/test_httpx_client/test_async_oauth2_client.py index 3c2613df..b4aa9b7e 100644 --- a/tests/py3/test_httpx_client/test_async_oauth2_client.py +++ b/tests/py3/test_httpx_client/test_async_oauth2_client.py @@ -387,6 +387,6 @@ async def test_revoke_token(): @pytest.mark.asyncio async def test_request_without_token(): - async with AsyncOAuth2Client('a') as client: + async with AsyncOAuth2Client('a', app=AsyncMockDispatch()) as client: with pytest.raises(OAuthError): await client.get('https://i.b/token') diff --git a/tests/py3/test_httpx_client/test_oauth2_client.py b/tests/py3/test_httpx_client/test_oauth2_client.py index e5bc6b35..7bd39387 100644 --- a/tests/py3/test_httpx_client/test_oauth2_client.py +++ b/tests/py3/test_httpx_client/test_oauth2_client.py @@ -369,6 +369,6 @@ def test_revoke_token(): def test_request_without_token(): - with OAuth2Client('a') as client: + with OAuth2Client('a', app=MockDispatch()) as client: with pytest.raises(OAuthError): client.get('https://i.b/token') diff --git a/tox.ini b/tox.ini index bd37557e..a8c5a354 100644 --- a/tox.ini +++ b/tox.ini @@ -12,7 +12,7 @@ deps = py27: unittest2 flask: Flask flask: Flask-SQLAlchemy - py3: httpx==0.14.1 + py3: httpx==0.14.3 py3: pytest-asyncio py3: starlette py3: itsdangerous From 99f6cb5dc8e0c42455504ff38488dba8e9a23d55 Mon Sep 17 00:00:00 2001 From: Ber Zoidberg Date: Thu, 10 Sep 2020 00:24:43 -0700 Subject: [PATCH 4/6] update to work with latest master --- docs/client/httpx.rst | 3 +++ tests/py3/test_httpx_client/test_async_oauth2_client.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/client/httpx.rst b/docs/client/httpx.rst index d97f6bd2..48412f1f 100644 --- a/docs/client/httpx.rst +++ b/docs/client/httpx.rst @@ -15,6 +15,9 @@ OAuth for HTTPX HTTPX is a next-generation HTTP client for Python. Authlib enables OAuth 1.0 and OAuth 2.0 for HTTPX with its async versions: +* :class:`OAuth1Client` +* :class:`OAuth2Client` +* :class:`AssertionClient` * :class:`AsyncOAuth1Client` * :class:`AsyncOAuth2Client` * :class:`AsyncAssertionClient` diff --git a/tests/py3/test_httpx_client/test_async_oauth2_client.py b/tests/py3/test_httpx_client/test_async_oauth2_client.py index e977067d..2333d2e5 100644 --- a/tests/py3/test_httpx_client/test_async_oauth2_client.py +++ b/tests/py3/test_httpx_client/test_async_oauth2_client.py @@ -382,7 +382,7 @@ async def _update_token(token, refresh_token=None, access_token=None): expires_at=100 ) - app = MockDispatch(default_token) + app = AsyncMockDispatch(default_token) async with AsyncOAuth2Client( 'foo', token=old_token, token_endpoint='https://i.b/token', From f0250428775ea75a8ac52012e04fe77233750e44 Mon Sep 17 00:00:00 2001 From: Ber Zoidberg Date: Wed, 9 Sep 2020 21:59:31 -0700 Subject: [PATCH 5/6] py3 is no longer just async add support for httpx sync APIs; fix bug in httpx oauth1 support where content-length header si not specific update to httpx 0.14.3 Remove deprecated and purely cosmetical argument providing_args in Django integrations add test and fix for unintentional parallel token refreshes update to work with latest master --- .github/workflows/python.yml | 6 +- .../integrations/django_client/integration.py | 2 +- authlib/integrations/django_oauth2/signals.py | 6 +- authlib/integrations/httpx_client/__init__.py | 6 +- .../httpx_client/assertion_client.py | 49 ++- .../httpx_client/oauth1_client.py | 55 ++- .../httpx_client/oauth2_client.py | 136 ++++++- docs/client/httpx.rst | 3 + .../test_assertion_client.py | 65 +++ .../test_async_assertion_client.py | 7 +- .../test_async_oauth1_client.py | 16 +- .../test_async_oauth2_client.py | 58 ++- .../test_httpx_client/test_oauth1_client.py | 157 ++++++++ .../test_httpx_client/test_oauth2_client.py | 374 ++++++++++++++++++ .../test_oauth_client.py | 16 +- .../test_starlette_client/test_user_mixin.py | 6 +- tests/py3/utils.py | 75 +++- tox.ini | 13 +- 18 files changed, 972 insertions(+), 78 deletions(-) create mode 100644 tests/py3/test_httpx_client/test_assertion_client.py create mode 100644 tests/py3/test_httpx_client/test_oauth1_client.py create mode 100644 tests/py3/test_httpx_client/test_oauth2_client.py diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 42de65de..4db46eb9 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -24,11 +24,11 @@ jobs: - version: 2.7 toxenv: py27,py27-flask - version: 3.6 - toxenv: py36,flask,django,async + toxenv: py36,flask,django,py3 - version: 3.7 - toxenv: py37,flask,django,async + toxenv: py37,flask,django,py3 - version: 3.8 - toxenv: py38,flask,django,async + toxenv: py38,flask,django,py3 steps: - uses: actions/checkout@v2 diff --git a/authlib/integrations/django_client/integration.py b/authlib/integrations/django_client/integration.py index ea1dc4be..79d7dbde 100644 --- a/authlib/integrations/django_client/integration.py +++ b/authlib/integrations/django_client/integration.py @@ -5,7 +5,7 @@ from ..requests_client import OAuth1Session, OAuth2Session -token_update = Signal(providing_args=['name', 'token', 'refresh_token', 'access_token']) +token_update = Signal() class DjangoIntegration(FrameworkIntegration): diff --git a/authlib/integrations/django_oauth2/signals.py b/authlib/integrations/django_oauth2/signals.py index 76f448cf..0e9c2659 100644 --- a/authlib/integrations/django_oauth2/signals.py +++ b/authlib/integrations/django_oauth2/signals.py @@ -2,10 +2,10 @@ #: signal when client is authenticated -client_authenticated = Signal(providing_args=['client', 'grant']) +client_authenticated = Signal() #: signal when token is revoked -token_revoked = Signal(providing_args=['token', 'client']) +token_revoked = Signal() #: signal when token is authenticated -token_authenticated = Signal(providing_args=['token']) +token_authenticated = Signal() diff --git a/authlib/integrations/httpx_client/__init__.py b/authlib/integrations/httpx_client/__init__.py index 21d641e8..6b4b9d67 100644 --- a/authlib/integrations/httpx_client/__init__.py +++ b/authlib/integrations/httpx_client/__init__.py @@ -6,12 +6,12 @@ SIGNATURE_TYPE_QUERY, SIGNATURE_TYPE_BODY, ) -from .oauth1_client import OAuth1Auth, AsyncOAuth1Client +from .oauth1_client import OAuth1Auth, AsyncOAuth1Client, OAuth1Client from .oauth2_client import ( - OAuth2Auth, OAuth2ClientAuth, + OAuth2Auth, OAuth2Client, OAuth2ClientAuth, AsyncOAuth2Client, ) -from .assertion_client import AsyncAssertionClient +from .assertion_client import AssertionClient, AsyncAssertionClient from ..base_client import OAuthError diff --git a/authlib/integrations/httpx_client/assertion_client.py b/authlib/integrations/httpx_client/assertion_client.py index 17228d32..144f685f 100644 --- a/authlib/integrations/httpx_client/assertion_client.py +++ b/authlib/integrations/httpx_client/assertion_client.py @@ -1,4 +1,5 @@ -from httpx import AsyncClient +from httpx import AsyncClient, Client +from httpx._config import UnsetType from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient from authlib.oauth2.rfc7523 import JWTBearerGrant from authlib.oauth2 import OAuth2Error @@ -31,7 +32,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No async def request(self, method, url, withhold_token=False, auth=None, **kwargs): """Send request with auto refresh token feature.""" - if not withhold_token and auth is None: + if not withhold_token and isinstance(auth, UnsetType): if not self.token or self.token.is_expired(): await self.refresh_token() @@ -51,3 +52,47 @@ async def _refresh_token(self, data): ) self.token = token return self.token + +class AssertionClient(_AssertionClient, Client): + token_auth_class = OAuth2Auth + JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE + ASSERTION_METHODS = { + JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign, + } + DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE + + def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None, + claims=None, token_placement='header', scope=None, **kwargs): + + client_kwargs = extract_client_kwargs(kwargs) + Client.__init__(self, **client_kwargs) + + _AssertionClient.__init__( + self, session=None, + token_endpoint=token_endpoint, issuer=issuer, subject=subject, + audience=audience, grant_type=grant_type, claims=claims, + token_placement=token_placement, scope=scope, **kwargs + ) + + def request(self, method, url, withhold_token=False, auth=None, **kwargs): + """Send request with auto refresh token feature.""" + if not withhold_token and isinstance(auth, UnsetType): + if not self.token or self.token.is_expired(): + self.refresh_token() + + auth = self.token_auth + return super(AssertionClient, self).request( + method, url, auth=auth, **kwargs) + + def _refresh_token(self, data): + resp = self.request( + 'POST', self.token_endpoint, data=data, withhold_token=True) + + token = resp.json() + if 'error' in token: + raise OAuth2Error( + error=token['error'], + description=token.get('error_description') + ) + self.token = token + return self.token diff --git a/authlib/integrations/httpx_client/oauth1_client.py b/authlib/integrations/httpx_client/oauth1_client.py index 7bb1ccb0..13483ec4 100644 --- a/authlib/integrations/httpx_client/oauth1_client.py +++ b/authlib/integrations/httpx_client/oauth1_client.py @@ -1,5 +1,5 @@ import typing -from httpx import AsyncClient, Auth, Request, Response +from httpx import AsyncClient, Auth, Client, Request, Response from authlib.oauth1 import ( SIGNATURE_HMAC_SHA1, SIGNATURE_TYPE_HEADER, @@ -18,6 +18,7 @@ class OAuth1Auth(Auth, ClientAuth): def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]: url, headers, body = self.prepare( request.method, str(request.url), request.headers, request.content) + headers['Content-Length'] = str(len(body)) yield Request(method=request.method, url=url, headers=headers, data=body) @@ -72,3 +73,55 @@ async def _fetch_token(self, url, **kwargs): @staticmethod def handle_error(error_type, error_description): raise OAuthError(error_type, error_description) + +class OAuth1Client(_OAuth1Client, Client): + auth_class = OAuth1Auth + + def __init__(self, client_id, client_secret=None, + token=None, token_secret=None, + redirect_uri=None, rsa_key=None, verifier=None, + signature_method=SIGNATURE_HMAC_SHA1, + signature_type=SIGNATURE_TYPE_HEADER, + force_include_body=False, **kwargs): + + _client_kwargs = extract_client_kwargs(kwargs) + Client.__init__(self, **_client_kwargs) + + _OAuth1Client.__init__( + self, None, + client_id=client_id, client_secret=client_secret, + token=token, token_secret=token_secret, + redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier, + signature_method=signature_method, signature_type=signature_type, + force_include_body=force_include_body, **kwargs) + + def fetch_access_token(self, url, verifier=None, **kwargs): + """Method for fetching an access token from the token endpoint. + + This is the final step in the OAuth 1 workflow. An access token is + obtained using all previously obtained credentials, including the + verifier from the authorization step. + + :param url: Access Token endpoint. + :param verifier: A verifier string to prove authorization was granted. + :param kwargs: Extra parameters to include for fetching access token. + :return: A token dict. + """ + if verifier: + self.auth.verifier = verifier + if not self.auth.verifier: + self.handle_error('missing_verifier', 'Missing "verifier" value') + token = self._fetch_token(url, **kwargs) + self.auth.verifier = None + return token + + def _fetch_token(self, url, **kwargs): + resp = self.post(url, **kwargs) + text = resp.read() + token = self.parse_response_token(resp.status_code, to_unicode(text)) + self.token = token + return token + + @staticmethod + def handle_error(error_type, error_description): + raise OAuthError(error_type, error_description) diff --git a/authlib/integrations/httpx_client/oauth2_client.py b/authlib/integrations/httpx_client/oauth2_client.py index 9d88f127..6d1411a2 100644 --- a/authlib/integrations/httpx_client/oauth2_client.py +++ b/authlib/integrations/httpx_client/oauth2_client.py @@ -1,5 +1,7 @@ +import asyncio import typing -from httpx import AsyncClient, Auth, Request, Response +from httpx import AsyncClient, Auth, Client, Request, Response +from httpx._config import UnsetType from authlib.common.urls import url_decode from authlib.oauth2.client import OAuth2Client as _OAuth2Client from authlib.oauth2.auth import ClientAuth, TokenAuth @@ -57,6 +59,11 @@ def __init__(self, client_id=None, client_secret=None, client_kwargs = self._extract_session_request_params(kwargs) AsyncClient.__init__(self, **client_kwargs) + # We use a "reverse" Event to synchronize coroutines to prevent + # multiple concurrent attempts to refresh the same token + self._token_refresh_event = asyncio.Event() + self._token_refresh_event.set() + _OAuth2Client.__init__( self, session=None, client_id=client_id, client_secret=client_secret, @@ -72,7 +79,7 @@ def handle_error(error_type, error_description): raise OAuthError(error_type, error_description) async def request(self, method, url, withhold_token=False, auth=None, **kwargs): - if not withhold_token and auth is None: + if not withhold_token and isinstance(auth, UnsetType): if not self.token: raise MissingTokenError() @@ -85,22 +92,127 @@ async def request(self, method, url, withhold_token=False, auth=None, **kwargs): method, url, auth=auth, **kwargs) async def ensure_active_token(self): + if self._token_refresh_event.is_set(): + # Unset the event so other coroutines don't try to update the token + self._token_refresh_event.clear() + refresh_token = self.token.get('refresh_token') + url = self.metadata.get('token_endpoint') + if refresh_token and url: + await self.refresh_token(url, refresh_token=refresh_token) + elif self.metadata.get('grant_type') == 'client_credentials': + access_token = self.token['access_token'] + token = await self.fetch_token(url, grant_type='client_credentials') + if self.update_token: + await self.update_token(token, access_token=access_token) + else: + raise InvalidTokenError() + # Notify coroutines that token is refreshed + self._token_refresh_event.set() + return + await self._token_refresh_event.wait() # wait until the token is ready + + async def _fetch_token(self, url, body='', headers=None, auth=None, + method='POST', **kwargs): + if method.upper() == 'POST': + resp = await self.post( + url, data=dict(url_decode(body)), headers=headers, + auth=auth, **kwargs) + else: + if '?' in url: + url = '&'.join([url, body]) + else: + url = '?'.join([url, body]) + resp = await self.get(url, headers=headers, auth=auth, **kwargs) + + for hook in self.compliance_hook['access_token_response']: + resp = hook(resp) + + return self.parse_response_token(resp.json()) + + async def _refresh_token(self, url, refresh_token=None, body='', + headers=None, auth=None, **kwargs): + resp = await self.post( + url, data=dict(url_decode(body)), headers=headers, + auth=auth, **kwargs) + + for hook in self.compliance_hook['refresh_token_response']: + resp = hook(resp) + + token = self.parse_response_token(resp.json()) + if 'refresh_token' not in token: + self.token['refresh_token'] = refresh_token + + if self.update_token: + await self.update_token(self.token, refresh_token=refresh_token) + + return self.token + + def _http_post(self, url, body=None, auth=None, headers=None, **kwargs): + return self.post( + url, data=dict(url_decode(body)), + headers=headers, auth=auth, **kwargs) + +class OAuth2Client(_OAuth2Client, Client): + SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS + + client_auth_class = OAuth2ClientAuth + token_auth_class = OAuth2Auth + + def __init__(self, client_id=None, client_secret=None, + token_endpoint_auth_method=None, + revocation_endpoint_auth_method=None, + scope=None, redirect_uri=None, + token=None, token_placement='header', + update_token=None, **kwargs): + + # extract httpx.Client kwargs + client_kwargs = self._extract_session_request_params(kwargs) + Client.__init__(self, **client_kwargs) + + _OAuth2Client.__init__( + self, session=None, + client_id=client_id, client_secret=client_secret, + token_endpoint_auth_method=token_endpoint_auth_method, + revocation_endpoint_auth_method=revocation_endpoint_auth_method, + scope=scope, redirect_uri=redirect_uri, + token=token, token_placement=token_placement, + update_token=update_token, **kwargs + ) + + @staticmethod + def handle_error(error_type, error_description): + raise OAuthError(error_type, error_description) + + def request(self, method, url, withhold_token=False, auth=None, **kwargs): + if not withhold_token and isinstance(auth, UnsetType): + if not self.token: + raise MissingTokenError() + + if self.token.is_expired(): + self.ensure_active_token() + + auth = self.token_auth + + return super(OAuth2Client, self).request( + method, url, auth=auth, **kwargs) + + def ensure_active_token(self): refresh_token = self.token.get('refresh_token') url = self.metadata.get('token_endpoint') if refresh_token and url: - await self.refresh_token(url, refresh_token=refresh_token) + self.refresh_token(url, refresh_token=refresh_token) elif self.metadata.get('grant_type') == 'client_credentials': access_token = self.token['access_token'] - token = await self.fetch_token(url, grant_type='client_credentials') + token = self.fetch_token(url, grant_type='client_credentials') if self.update_token: - await self.update_token(token, access_token=access_token) + self.update_token(token, access_token=access_token) else: raise InvalidTokenError() - async def _fetch_token(self, url, body='', headers=None, auth=None, - method='POST', **kwargs): + def _fetch_token(self, url, body='', headers=None, auth=None, + method='POST', **kwargs): if method.upper() == 'POST': - resp = await self.post( + resp = self.post( url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs) else: @@ -108,16 +220,16 @@ async def _fetch_token(self, url, body='', headers=None, auth=None, url = '&'.join([url, body]) else: url = '?'.join([url, body]) - resp = await self.get(url, headers=headers, auth=auth, **kwargs) + resp = self.get(url, headers=headers, auth=auth, **kwargs) for hook in self.compliance_hook['access_token_response']: resp = hook(resp) return self.parse_response_token(resp.json()) - async def _refresh_token(self, url, refresh_token=None, body='', + def _refresh_token(self, url, refresh_token=None, body='', headers=None, auth=None, **kwargs): - resp = await self.post( + resp = self.post( url, data=dict(url_decode(body)), headers=headers, auth=auth, **kwargs) @@ -129,7 +241,7 @@ async def _refresh_token(self, url, refresh_token=None, body='', self.token['refresh_token'] = refresh_token if self.update_token: - await self.update_token(self.token, refresh_token=refresh_token) + self.update_token(self.token, refresh_token=refresh_token) return self.token diff --git a/docs/client/httpx.rst b/docs/client/httpx.rst index d97f6bd2..48412f1f 100644 --- a/docs/client/httpx.rst +++ b/docs/client/httpx.rst @@ -15,6 +15,9 @@ OAuth for HTTPX HTTPX is a next-generation HTTP client for Python. Authlib enables OAuth 1.0 and OAuth 2.0 for HTTPX with its async versions: +* :class:`OAuth1Client` +* :class:`OAuth2Client` +* :class:`AssertionClient` * :class:`AsyncOAuth1Client` * :class:`AsyncOAuth2Client` * :class:`AsyncAssertionClient` diff --git a/tests/py3/test_httpx_client/test_assertion_client.py b/tests/py3/test_httpx_client/test_assertion_client.py new file mode 100644 index 00000000..91b05297 --- /dev/null +++ b/tests/py3/test_httpx_client/test_assertion_client.py @@ -0,0 +1,65 @@ +import time +import pytest +from authlib.integrations.httpx_client import AssertionClient +from tests.py3.utils import MockDispatch + + +default_token = { + 'token_type': 'Bearer', + 'access_token': 'a', + 'refresh_token': 'b', + 'expires_in': '3600', + 'expires_at': int(time.time()) + 3600, +} + + +@pytest.mark.asyncio +def test_refresh_token(): + def verifier(request): + content = request.form + if str(request.url) == 'https://i.b/token': + assert 'assertion' in content + + with AssertionClient( + 'https://i.b/token', + grant_type=AssertionClient.JWT_BEARER_GRANT_TYPE, + issuer='foo', + subject='foo', + audience='foo', + alg='HS256', + key='secret', + app=MockDispatch(default_token, assert_func=verifier) + ) as client: + client.get('https://i.b') + + # trigger more case + now = int(time.time()) + with AssertionClient( + 'https://i.b/token', + issuer='foo', + subject=None, + audience='foo', + issued_at=now, + expires_at=now + 3600, + header={'alg': 'HS256'}, + key='secret', + scope='email', + claims={'test_mode': 'true'}, + app=MockDispatch(default_token, assert_func=verifier) + ) as client: + client.get('https://i.b') + client.get('https://i.b') + + +@pytest.mark.asyncio +def test_without_alg(): + with AssertionClient( + 'https://i.b/token', + issuer='foo', + subject='foo', + audience='foo', + key='secret', + app=MockDispatch(default_token) + ) as client: + with pytest.raises(ValueError): + client.get('https://i.b') diff --git a/tests/py3/test_httpx_client/test_async_assertion_client.py b/tests/py3/test_httpx_client/test_async_assertion_client.py index fa51e14a..46286bff 100644 --- a/tests/py3/test_httpx_client/test_async_assertion_client.py +++ b/tests/py3/test_httpx_client/test_async_assertion_client.py @@ -1,7 +1,7 @@ import time import pytest from authlib.integrations.httpx_client import AsyncAssertionClient -from tests.py3.utils import MockDispatch +from tests.py3.utils import AsyncMockDispatch default_token = { @@ -28,7 +28,7 @@ async def verifier(request): audience='foo', alg='HS256', key='secret', - app=MockDispatch(default_token, assert_func=verifier) + app=AsyncMockDispatch(default_token, assert_func=verifier) ) as client: await client.get('https://i.b') @@ -45,7 +45,7 @@ async def verifier(request): key='secret', scope='email', claims={'test_mode': 'true'}, - app=MockDispatch(default_token, assert_func=verifier) + app=AsyncMockDispatch(default_token, assert_func=verifier) ) as client: await client.get('https://i.b') await client.get('https://i.b') @@ -59,6 +59,7 @@ async def test_without_alg(): subject='foo', audience='foo', key='secret', + app=AsyncMockDispatch() ) as client: with pytest.raises(ValueError): await client.get('https://i.b') diff --git a/tests/py3/test_httpx_client/test_async_oauth1_client.py b/tests/py3/test_httpx_client/test_async_oauth1_client.py index bf8542d8..75703567 100644 --- a/tests/py3/test_httpx_client/test_async_oauth1_client.py +++ b/tests/py3/test_httpx_client/test_async_oauth1_client.py @@ -5,7 +5,7 @@ SIGNATURE_TYPE_BODY, SIGNATURE_TYPE_QUERY, ) -from tests.py3.utils import MockDispatch +from tests.py3.utils import AsyncMockDispatch oauth_url = 'https://example.com/oauth' @@ -19,7 +19,7 @@ async def assert_func(request): assert 'oauth_consumer_key="id"' in auth_header assert 'oauth_signature=' in auth_header - app = MockDispatch(request_token, assert_func=assert_func) + app = AsyncMockDispatch(request_token, assert_func=assert_func) async with AsyncOAuth1Client('id', 'secret', app=app) as client: response = await client.fetch_request_token(oauth_url) @@ -38,7 +38,7 @@ async def assert_func(request): assert b'oauth_consumer_key=id' in content assert b'&oauth_signature=' in content - mock_response = MockDispatch(request_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(request_token, assert_func=assert_func) async with AsyncOAuth1Client( 'id', 'secret', signature_type=SIGNATURE_TYPE_BODY, @@ -61,7 +61,7 @@ async def assert_func(request): assert 'oauth_consumer_key=id' in url assert '&oauth_signature=' in url - mock_response = MockDispatch(request_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(request_token, assert_func=assert_func) async with AsyncOAuth1Client( 'id', 'secret', signature_type=SIGNATURE_TYPE_QUERY, @@ -83,7 +83,7 @@ async def assert_func(request): assert 'oauth_consumer_key="id"' in auth_header assert 'oauth_signature=' in auth_header - mock_response = MockDispatch(request_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(request_token, assert_func=assert_func) async with AsyncOAuth1Client( 'id', 'secret', token='foo', token_secret='bar', app=mock_response, @@ -98,7 +98,7 @@ async def assert_func(request): @pytest.mark.asyncio async def test_get_via_header(): - mock_response = MockDispatch(b'hello') + mock_response = AsyncMockDispatch(b'hello') async with AsyncOAuth1Client( 'id', 'secret', token='foo', token_secret='bar', app=mock_response, @@ -121,7 +121,7 @@ async def assert_func(request): assert b'oauth_consumer_key=id' in content assert b'oauth_signature=' in content - mock_response = MockDispatch(b'hello', assert_func=assert_func) + mock_response = AsyncMockDispatch(b'hello', assert_func=assert_func) async with AsyncOAuth1Client( 'id', 'secret', token='foo', token_secret='bar', signature_type=SIGNATURE_TYPE_BODY, @@ -138,7 +138,7 @@ async def assert_func(request): @pytest.mark.asyncio async def test_get_via_query(): - mock_response = MockDispatch(b'hello') + mock_response = AsyncMockDispatch(b'hello') async with AsyncOAuth1Client( 'id', 'secret', token='foo', token_secret='bar', signature_type=SIGNATURE_TYPE_QUERY, diff --git a/tests/py3/test_httpx_client/test_async_oauth2_client.py b/tests/py3/test_httpx_client/test_async_oauth2_client.py index 54b3a693..2333d2e5 100644 --- a/tests/py3/test_httpx_client/test_async_oauth2_client.py +++ b/tests/py3/test_httpx_client/test_async_oauth2_client.py @@ -1,3 +1,4 @@ +import asyncio import mock import time import pytest @@ -8,7 +9,7 @@ OAuthError, AsyncOAuth2Client, ) -from tests.py3.utils import MockDispatch +from tests.py3.utils import AsyncMockDispatch default_token = { @@ -27,7 +28,7 @@ async def assert_func(request): auth_header = request.headers.get('authorization') assert auth_header == token - mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + mock_response = AsyncMockDispatch({'a': 'a'}, assert_func=assert_func) async with AsyncOAuth2Client( 'foo', token=default_token, @@ -45,7 +46,7 @@ async def assert_func(request): content = await request.body() assert default_token['access_token'] in content.decode() - mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + mock_response = AsyncMockDispatch({'a': 'a'}, assert_func=assert_func) async with AsyncOAuth2Client( 'foo', token=default_token, @@ -63,7 +64,7 @@ async def test_add_token_to_uri(): async def assert_func(request): assert default_token['access_token'] in str(request.url) - mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + mock_response = AsyncMockDispatch({'a': 'a'}, assert_func=assert_func) async with AsyncOAuth2Client( 'foo', token=default_token, @@ -123,7 +124,7 @@ async def assert_func(request): assert 'client_id=' in content assert 'grant_type=authorization_code' in content - mock_response = MockDispatch(default_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(default_token, assert_func=assert_func) async with AsyncOAuth2Client('foo', app=mock_response) as client: token = await client.fetch_token(url, authorization_response='https://i.b/?code=v') assert token == default_token @@ -136,7 +137,7 @@ async def assert_func(request): token = await client.fetch_token(url, code='v') assert token == default_token - mock_response = MockDispatch({'error': 'invalid_request'}) + mock_response = AsyncMockDispatch({'error': 'invalid_request'}) async with AsyncOAuth2Client('foo', app=mock_response) as client: with pytest.raises(OAuthError): await client.fetch_token(url) @@ -152,7 +153,7 @@ async def assert_func(request): assert 'client_id=' in url assert 'grant_type=authorization_code' in url - mock_response = MockDispatch(default_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(default_token, assert_func=assert_func) async with AsyncOAuth2Client('foo', app=mock_response) as client: authorization_response = 'https://i.b/?code=v' token = await client.fetch_token( @@ -183,7 +184,7 @@ async def assert_func(request): assert 'client_secret=bar' in content assert 'grant_type=authorization_code' in content - mock_response = MockDispatch(default_token, assert_func=assert_func) + mock_response = AsyncMockDispatch(default_token, assert_func=assert_func) async with AsyncOAuth2Client( 'foo', 'bar', token_endpoint_auth_method='client_secret_post', @@ -203,7 +204,7 @@ def _access_token_response_hook(resp): return resp access_token_response_hook = mock.Mock(side_effect=_access_token_response_hook) - app = MockDispatch(default_token) + app = AsyncMockDispatch(default_token) async with AsyncOAuth2Client('foo', token=default_token, app=app) as sess: sess.register_compliance_hook( 'access_token_response', @@ -224,7 +225,7 @@ async def assert_func(request): assert 'scope=profile' in content assert 'grant_type=password' in content - app = MockDispatch(default_token, assert_func=assert_func) + app = AsyncMockDispatch(default_token, assert_func=assert_func) async with AsyncOAuth2Client('foo', scope='profile', app=app) as sess: token = await sess.fetch_token(url, username='v', password='v') assert token == default_token @@ -244,7 +245,7 @@ async def assert_func(request): assert 'scope=profile' in content assert 'grant_type=client_credentials' in content - app = MockDispatch(default_token, assert_func=assert_func) + app = AsyncMockDispatch(default_token, assert_func=assert_func) async with AsyncOAuth2Client('foo', scope='profile', app=app) as sess: token = await sess.fetch_token(url) assert token == default_token @@ -262,7 +263,7 @@ async def test_cleans_previous_token_before_fetching_new_one(): new_token['expires_at'] = now + 3600 url = 'https://example.com/token' - app = MockDispatch(new_token) + app = AsyncMockDispatch(new_token) with mock.patch('time.time', lambda: now): async with AsyncOAuth2Client('foo', token=default_token, app=app) as sess: assert await sess.fetch_token(url) == new_token @@ -288,7 +289,7 @@ async def _update_token(token, refresh_token=None, access_token=None): token_type='bearer', expires_at=100 ) - app = MockDispatch(default_token) + app = AsyncMockDispatch(default_token) async with AsyncOAuth2Client( 'foo', token=old_token, token_endpoint='https://i.b/token', update_token=update_token, app=app @@ -324,7 +325,7 @@ async def _update_token(token, refresh_token=None, access_token=None): expires_at=100 ) - app = MockDispatch(default_token) + app = AsyncMockDispatch(default_token) async with AsyncOAuth2Client( 'foo', token=old_token, @@ -358,7 +359,7 @@ async def _update_token(token, refresh_token=None, access_token=None): expires_at=100 ) - app = MockDispatch(default_token) + app = AsyncMockDispatch(default_token) async with AsyncOAuth2Client( 'foo', token=old_token, token_endpoint='https://i.b/token', @@ -368,11 +369,34 @@ async def _update_token(token, refresh_token=None, access_token=None): await client.post('https://i.b/user', json={'foo': 'bar'}) assert update_token.called is True +@pytest.mark.asyncio +async def test_auto_refresh_token4(): + async def _update_token(token, refresh_token=None, access_token=None): + await asyncio.sleep(0.1) # artificial sleep to force other coroutines to wake + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict( + access_token='a', + token_type='bearer', + expires_at=100 + ) + + app = AsyncMockDispatch(default_token) + + async with AsyncOAuth2Client( + 'foo', token=old_token, token_endpoint='https://i.b/token', + update_token=update_token, grant_type='client_credentials', + app=app, + ) as client: + coroutines = [client.get('https://i.b/user') for x in range(10)] + await asyncio.gather(*coroutines) + update_token.assert_called_once() @pytest.mark.asyncio async def test_revoke_token(): answer = {'status': 'ok'} - app = MockDispatch(answer) + app = AsyncMockDispatch(answer) async with AsyncOAuth2Client('a', app=app) as sess: resp = await sess.revoke_token('https://i.b/token', 'hi') @@ -387,6 +411,6 @@ async def test_revoke_token(): @pytest.mark.asyncio async def test_request_without_token(): - async with AsyncOAuth2Client('a') as client: + async with AsyncOAuth2Client('a', app=AsyncMockDispatch()) as client: with pytest.raises(OAuthError): await client.get('https://i.b/token') diff --git a/tests/py3/test_httpx_client/test_oauth1_client.py b/tests/py3/test_httpx_client/test_oauth1_client.py new file mode 100644 index 00000000..a5f34df3 --- /dev/null +++ b/tests/py3/test_httpx_client/test_oauth1_client.py @@ -0,0 +1,157 @@ +import pytest +from authlib.integrations.httpx_client import ( + OAuthError, + OAuth1Client, + SIGNATURE_TYPE_BODY, + SIGNATURE_TYPE_QUERY, +) +from tests.py3.utils import MockDispatch + +oauth_url = 'https://example.com/oauth' + + +@pytest.mark.asyncio +def test_fetch_request_token_via_header(): + request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} + + def assert_func(request): + auth_header = request.headers.get('authorization') + assert 'oauth_consumer_key="id"' in auth_header + assert 'oauth_signature=' in auth_header + + app = MockDispatch(request_token, assert_func=assert_func) + with OAuth1Client('id', 'secret', app=app) as client: + response = client.fetch_request_token(oauth_url) + + assert response == request_token + + +@pytest.mark.asyncio +def test_fetch_request_token_via_body(): + request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} + + def assert_func(request): + auth_header = request.headers.get('authorization') + assert auth_header is None + + content = request.form + assert content.get('oauth_consumer_key') == 'id' + assert 'oauth_signature' in content + + mock_response = MockDispatch(request_token, assert_func=assert_func) + + with OAuth1Client( + 'id', 'secret', signature_type=SIGNATURE_TYPE_BODY, + app=mock_response, + ) as client: + response = client.fetch_request_token(oauth_url) + + assert response == request_token + + +@pytest.mark.asyncio +def test_fetch_request_token_via_query(): + request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} + + def assert_func(request): + auth_header = request.headers.get('authorization') + assert auth_header is None + + url = str(request.url) + assert 'oauth_consumer_key=id' in url + assert '&oauth_signature=' in url + + mock_response = MockDispatch(request_token, assert_func=assert_func) + + with OAuth1Client( + 'id', 'secret', signature_type=SIGNATURE_TYPE_QUERY, + app=mock_response, + ) as client: + response = client.fetch_request_token(oauth_url) + + assert response == request_token + + +@pytest.mark.asyncio +def test_fetch_access_token(): + request_token = {'oauth_token': '1', 'oauth_token_secret': '2'} + + def assert_func(request): + auth_header = request.headers.get('authorization') + assert 'oauth_verifier="d"' in auth_header + assert 'oauth_token="foo"' in auth_header + assert 'oauth_consumer_key="id"' in auth_header + assert 'oauth_signature=' in auth_header + + mock_response = MockDispatch(request_token, assert_func=assert_func) + with OAuth1Client( + 'id', 'secret', token='foo', token_secret='bar', + app=mock_response, + ) as client: + with pytest.raises(OAuthError): + client.fetch_access_token(oauth_url) + + response = client.fetch_access_token(oauth_url, verifier='d') + + assert response == request_token + + +@pytest.mark.asyncio +def test_get_via_header(): + mock_response = MockDispatch(b'hello') + with OAuth1Client( + 'id', 'secret', token='foo', token_secret='bar', + app=mock_response, + ) as client: + response = client.get('https://example.com/') + + assert response.content == b'hello' + request = response.request + auth_header = request.headers.get('authorization') + assert 'oauth_token="foo"' in auth_header + assert 'oauth_consumer_key="id"' in auth_header + assert 'oauth_signature=' in auth_header + + +@pytest.mark.asyncio +def test_get_via_body(): + def assert_func(request): + content = request.form + assert content.get('oauth_token') == 'foo' + assert content.get('oauth_consumer_key') == 'id' + assert 'oauth_signature' in content + + mock_response = MockDispatch(b'hello', assert_func=assert_func) + with OAuth1Client( + 'id', 'secret', token='foo', token_secret='bar', + signature_type=SIGNATURE_TYPE_BODY, + app=mock_response, + ) as client: + response = client.post('https://example.com/') + + assert response.content == b'hello' + + request = response.request + auth_header = request.headers.get('authorization') + assert auth_header is None + + +@pytest.mark.asyncio +def test_get_via_query(): + mock_response = MockDispatch(b'hello') + with OAuth1Client( + 'id', 'secret', token='foo', token_secret='bar', + signature_type=SIGNATURE_TYPE_QUERY, + app=mock_response, + ) as client: + response = client.get('https://example.com/') + + assert response.content == b'hello' + request = response.request + auth_header = request.headers.get('authorization') + assert auth_header is None + + url = str(request.url) + assert 'oauth_token=foo' in url + assert 'oauth_consumer_key=id' in url + assert 'oauth_signature=' in url diff --git a/tests/py3/test_httpx_client/test_oauth2_client.py b/tests/py3/test_httpx_client/test_oauth2_client.py new file mode 100644 index 00000000..7bd39387 --- /dev/null +++ b/tests/py3/test_httpx_client/test_oauth2_client.py @@ -0,0 +1,374 @@ +import mock +import time +import pytest +from copy import deepcopy +from authlib.common.security import generate_token +from authlib.common.urls import url_encode +from authlib.integrations.httpx_client import ( + OAuthError, + OAuth2Client, +) +from tests.py3.utils import MockDispatch + + +default_token = { + 'token_type': 'Bearer', + 'access_token': 'a', + 'refresh_token': 'b', + 'expires_in': '3600', + 'expires_at': int(time.time()) + 3600, +} + + +def test_add_token_to_header(): + def assert_func(request): + token = 'Bearer ' + default_token['access_token'] + auth_header = request.headers.get('authorization') + assert auth_header == token + + mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + with OAuth2Client( + 'foo', + token=default_token, + app=mock_response + ) as client: + resp = client.get('https://i.b') + + data = resp.json() + assert data['a'] == 'a' + + +def test_add_token_to_body(): + def assert_func(request): + content = request.data + content = content.decode() + assert content == 'access_token=%s' % default_token['access_token'] + + mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + with OAuth2Client( + 'foo', + token=default_token, + token_placement='body', + app=mock_response + ) as client: + resp = client.get('https://i.b') + + data = resp.json() + assert data['a'] == 'a' + + +def test_add_token_to_uri(): + def assert_func(request): + assert default_token['access_token'] in str(request.url) + + mock_response = MockDispatch({'a': 'a'}, assert_func=assert_func) + with OAuth2Client( + 'foo', + token=default_token, + token_placement='uri', + app=mock_response + ) as client: + resp = client.get('https://i.b') + + data = resp.json() + assert data['a'] == 'a' + + +def test_create_authorization_url(): + url = 'https://example.com/authorize?foo=bar' + + sess = OAuth2Client(client_id='foo') + auth_url, state = sess.create_authorization_url(url) + assert state in auth_url + assert 'client_id=foo' in auth_url + assert 'response_type=code' in auth_url + + sess = OAuth2Client(client_id='foo', prompt='none') + auth_url, state = sess.create_authorization_url( + url, state='foo', redirect_uri='https://i.b', scope='profile') + assert state == 'foo' + assert 'i.b' in auth_url + assert 'profile' in auth_url + assert 'prompt=none' in auth_url + + +def test_code_challenge(): + sess = OAuth2Client('foo', code_challenge_method='S256') + + url = 'https://example.com/authorize' + auth_url, _ = sess.create_authorization_url( + url, code_verifier=generate_token(48)) + assert 'code_challenge=' in auth_url + assert 'code_challenge_method=S256' in auth_url + + +def test_token_from_fragment(): + sess = OAuth2Client('foo') + response_url = 'https://i.b/callback#' + url_encode(default_token.items()) + assert sess.token_from_fragment(response_url) == default_token + token = sess.fetch_token(authorization_response=response_url) + assert token == default_token + + +def test_fetch_token_post(): + url = 'https://example.com/token' + + def assert_func(request): + content = request.form + assert content.get('code') == 'v' + assert content.get('client_id') == 'foo' + assert content.get('grant_type') == 'authorization_code' + + mock_response = MockDispatch(default_token, assert_func=assert_func) + with OAuth2Client('foo', app=mock_response) as client: + token = client.fetch_token(url, authorization_response='https://i.b/?code=v') + assert token == default_token + + with OAuth2Client( + 'foo', + token_endpoint_auth_method='none', + app=mock_response + ) as client: + token = client.fetch_token(url, code='v') + assert token == default_token + + mock_response = MockDispatch({'error': 'invalid_request'}) + with OAuth2Client('foo', app=mock_response) as client: + with pytest.raises(OAuthError): + client.fetch_token(url) + + +def test_fetch_token_get(): + url = 'https://example.com/token' + + def assert_func(request): + url = str(request.url) + assert 'code=v' in url + assert 'client_id=' in url + assert 'grant_type=authorization_code' in url + + mock_response = MockDispatch(default_token, assert_func=assert_func) + with OAuth2Client('foo', app=mock_response) as client: + authorization_response = 'https://i.b/?code=v' + token = client.fetch_token( + url, authorization_response=authorization_response, method='GET') + assert token == default_token + + with OAuth2Client( + 'foo', + token_endpoint_auth_method='none', + app=mock_response + ) as client: + token = client.fetch_token(url, code='v', method='GET') + assert token == default_token + + token = client.fetch_token(url + '?q=a', code='v', method='GET') + assert token == default_token + + +def test_token_auth_method_client_secret_post(): + url = 'https://example.com/token' + + def assert_func(request): + content = request.form + assert content.get('code') == 'v' + assert content.get('client_id') == 'foo' + assert content.get('client_secret') == 'bar' + assert content.get('grant_type') == 'authorization_code' + + mock_response = MockDispatch(default_token, assert_func=assert_func) + with OAuth2Client( + 'foo', 'bar', + token_endpoint_auth_method='client_secret_post', + app=mock_response + ) as client: + token = client.fetch_token(url, code='v') + + assert token == default_token + + +def test_access_token_response_hook(): + url = 'https://example.com/token' + + def _access_token_response_hook(resp): + assert resp.json() == default_token + return resp + + access_token_response_hook = mock.Mock(side_effect=_access_token_response_hook) + app = MockDispatch(default_token) + with OAuth2Client('foo', token=default_token, app=app) as sess: + sess.register_compliance_hook( + 'access_token_response', + access_token_response_hook + ) + assert sess.fetch_token(url) == default_token + assert access_token_response_hook.called is True + + +def test_password_grant_type(): + url = 'https://example.com/token' + + def assert_func(request): + content = request.form + assert content.get('username') == 'v' + assert content.get('scope') == 'profile' + assert content.get('grant_type') == 'password' + + app = MockDispatch(default_token, assert_func=assert_func) + with OAuth2Client('foo', scope='profile', app=app) as sess: + token = sess.fetch_token(url, username='v', password='v') + assert token == default_token + + token = sess.fetch_token( + url, username='v', password='v', grant_type='password') + assert token == default_token + + +def test_client_credentials_type(): + url = 'https://example.com/token' + + def assert_func(request): + content = request.form + assert content.get('scope') == 'profile' + assert content.get('grant_type') == 'client_credentials' + + app = MockDispatch(default_token, assert_func=assert_func) + with OAuth2Client('foo', scope='profile', app=app) as sess: + token = sess.fetch_token(url) + assert token == default_token + + token = sess.fetch_token(url, grant_type='client_credentials') + assert token == default_token + + +def test_cleans_previous_token_before_fetching_new_one(): + now = int(time.time()) + new_token = deepcopy(default_token) + past = now - 7200 + default_token['expires_at'] = past + new_token['expires_at'] = now + 3600 + url = 'https://example.com/token' + + app = MockDispatch(new_token) + with mock.patch('time.time', lambda: now): + with OAuth2Client('foo', token=default_token, app=app) as sess: + assert sess.fetch_token(url) == new_token + + +def test_token_status(): + token = dict(access_token='a', token_type='bearer', expires_at=100) + sess = OAuth2Client('foo', token=token) + assert sess.token.is_expired() is True + + +def test_auto_refresh_token(): + + def _update_token(token, refresh_token=None, access_token=None): + assert refresh_token == 'b' + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict( + access_token='a', refresh_token='b', + token_type='bearer', expires_at=100 + ) + + app = MockDispatch(default_token) + with OAuth2Client( + 'foo', token=old_token, token_endpoint='https://i.b/token', + update_token=update_token, app=app + ) as sess: + sess.get('https://i.b/user') + assert update_token.called is True + + old_token = dict( + access_token='a', + token_type='bearer', + expires_at=100 + ) + with OAuth2Client( + 'foo', token=old_token, token_endpoint='https://i.b/token', + update_token=update_token, app=app + ) as sess: + with pytest.raises(OAuthError): + sess.get('https://i.b/user') + + +def test_auto_refresh_token2(): + + def _update_token(token, refresh_token=None, access_token=None): + assert access_token == 'a' + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict( + access_token='a', + token_type='bearer', + expires_at=100 + ) + + app = MockDispatch(default_token) + + with OAuth2Client( + 'foo', token=old_token, + token_endpoint='https://i.b/token', + grant_type='client_credentials', + app=app, + ) as client: + client.get('https://i.b/user') + assert update_token.called is False + + with OAuth2Client( + 'foo', token=old_token, token_endpoint='https://i.b/token', + update_token=update_token, grant_type='client_credentials', + app=app, + ) as client: + client.get('https://i.b/user') + assert update_token.called is True + + +def test_auto_refresh_token3(): + def _update_token(token, refresh_token=None, access_token=None): + assert access_token == 'a' + assert token == default_token + + update_token = mock.Mock(side_effect=_update_token) + + old_token = dict( + access_token='a', + token_type='bearer', + expires_at=100 + ) + + app = MockDispatch(default_token) + + with OAuth2Client( + 'foo', token=old_token, token_endpoint='https://i.b/token', + update_token=update_token, grant_type='client_credentials', + app=app, + ) as client: + client.post('https://i.b/user', json={'foo': 'bar'}) + assert update_token.called is True + + +def test_revoke_token(): + answer = {'status': 'ok'} + app = MockDispatch(answer) + + with OAuth2Client('a', app=app) as sess: + resp = sess.revoke_token('https://i.b/token', 'hi') + assert resp.json() == answer + + resp = sess.revoke_token( + 'https://i.b/token', 'hi', + token_type_hint='access_token' + ) + assert resp.json() == answer + + +def test_request_without_token(): + with OAuth2Client('a', app=MockDispatch()) as client: + with pytest.raises(OAuthError): + client.get('https://i.b/token') diff --git a/tests/py3/test_starlette_client/test_oauth_client.py b/tests/py3/test_starlette_client/test_oauth_client.py index 4722e4ed..68654fc8 100644 --- a/tests/py3/test_starlette_client/test_oauth_client.py +++ b/tests/py3/test_starlette_client/test_oauth_client.py @@ -2,7 +2,7 @@ from starlette.config import Config from starlette.requests import Request from authlib.integrations.starlette_client import OAuth -from tests.py3.utils import PathMapDispatch +from tests.py3.utils import AsyncPathMapDispatch from tests.client_base import get_bearer_token @@ -39,7 +39,7 @@ def test_register_with_overwrite(): @pytest.mark.asyncio async def test_oauth1_authorize(): oauth = OAuth() - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/request-token': {'body': 'oauth_token=foo&oauth_verifier=baz'}, '/token': {'body': 'oauth_token=a&oauth_token_secret=b'}, }) @@ -74,7 +74,7 @@ async def test_oauth1_authorize(): @pytest.mark.asyncio async def test_oauth2_authorize(): oauth = OAuth() - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/token': {'body': get_bearer_token()} }) client = oauth.register( @@ -113,7 +113,7 @@ async def test_oauth2_authorize(): @pytest.mark.asyncio async def test_oauth2_authorize_code_challenge(): - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/token': {'body': get_bearer_token()} }) oauth = OAuth() @@ -163,7 +163,7 @@ async def test_with_fetch_token_in_register(): async def fetch_token(request): return {'access_token': 'dev', 'token_type': 'bearer'} - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/user': {'body': {'sub': '123'}} }) oauth = OAuth() @@ -191,7 +191,7 @@ async def test_with_fetch_token_in_oauth(): async def fetch_token(name, request): return {'access_token': 'dev', 'token_type': 'bearer'} - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/user': {'body': {'sub': '123'}} }) oauth = OAuth(fetch_token=fetch_token) @@ -216,7 +216,7 @@ async def fetch_token(name, request): @pytest.mark.asyncio async def test_request_withhold_token(): oauth = OAuth() - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/user': {'body': {'sub': '123'}} }) client = oauth.register( @@ -252,7 +252,7 @@ async def test_oauth2_authorize_with_metadata(): await client.create_authorization_url(req) - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/.well-known/openid-configuration': {'body': { 'authorization_endpoint': 'https://i.b/authorize' }} diff --git a/tests/py3/test_starlette_client/test_user_mixin.py b/tests/py3/test_starlette_client/test_user_mixin.py index 2e348015..f9e32b56 100644 --- a/tests/py3/test_starlette_client/test_user_mixin.py +++ b/tests/py3/test_starlette_client/test_user_mixin.py @@ -5,7 +5,7 @@ from authlib.jose.errors import InvalidClaimError from authlib.oidc.core.grants.util import generate_id_token from tests.util import read_file_path -from tests.py3.utils import PathMapDispatch +from tests.py3.utils import AsyncPathMapDispatch from tests.client_base import get_bearer_token @@ -15,7 +15,7 @@ async def run_fetch_userinfo(payload, compliance_fix=None): async def fetch_token(request): return get_bearer_token() - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/userinfo': {'body': payload} }) @@ -125,7 +125,7 @@ async def test_force_fetch_jwks_uri(): aud='dev', exp=3600, nonce='n', ) - app = PathMapDispatch({ + app = AsyncPathMapDispatch({ '/jwks': {'body': read_file_path('jwks_public.json')} }) diff --git a/tests/py3/utils.py b/tests/py3/utils.py index 807cea07..9416e100 100644 --- a/tests/py3/utils.py +++ b/tests/py3/utils.py @@ -1,9 +1,11 @@ import json -from starlette.requests import Request -from starlette.responses import Response +from starlette.requests import Request as ASGIRequest +from starlette.responses import Response as ASGIResponse +from werkzeug.wrappers import Request as WSGIRequest +from werkzeug.wrappers import Response as WSGIResponse -class MockDispatch: +class AsyncMockDispatch: def __init__(self, body=b'', status_code=200, headers=None, assert_func=None): if headers is None: @@ -22,12 +24,12 @@ def __init__(self, body=b'', status_code=200, headers=None, self.assert_func = assert_func async def __call__(self, scope, receive, send): - request = Request(scope, receive=receive) + request = ASGIRequest(scope, receive=receive) if self.assert_func: await self.assert_func(request) - response = Response( + response = ASGIResponse( status_code=self.status_code, content=self.body, headers=self.headers, @@ -35,12 +37,12 @@ async def __call__(self, scope, receive, send): await response(scope, receive, send) -class PathMapDispatch: +class AsyncPathMapDispatch: def __init__(self, path_maps): self.path_maps = path_maps async def __call__(self, scope, receive, send): - request = Request(scope, receive=receive) + request = ASGIRequest(scope, receive=receive) rv = self.path_maps[request.url.path] status_code = rv.get('status_code', 200) @@ -54,9 +56,66 @@ async def __call__(self, scope, receive, send): body = body.encode() headers['Content-Type'] = 'application/x-www-form-urlencoded' - response = Response( + response = ASGIResponse( status_code=status_code, content=body, headers=headers, ) await response(scope, receive, send) + +class MockDispatch: + def __init__(self, body=b'', status_code=200, headers=None, + assert_func=None): + if headers is None: + headers = {} + if isinstance(body, dict): + body = json.dumps(body).encode() + headers['Content-Type'] = 'application/json' + else: + if isinstance(body, str): + body = body.encode() + headers['Content-Type'] = 'application/x-www-form-urlencoded' + + self.body = body + self.status_code = status_code + self.headers = headers + self.assert_func = assert_func + + def __call__(self, environ, start_response): + request = WSGIRequest(environ) + + if self.assert_func: + self.assert_func(request) + + response = WSGIResponse( + status=self.status_code, + response=self.body, + headers=self.headers, + ) + return response(environ, start_response) + + +class PathMapDispatch: + def __init__(self, path_maps): + self.path_maps = path_maps + + def __call__(self, environ, start_response): + request = WSGIRequest(environ) + + rv = self.path_maps[request.url.path] + status_code = rv.get('status_code', 200) + body = rv.get('body', b'') + headers = rv.get('headers', {}) + if isinstance(body, dict): + body = json.dumps(body).encode() + headers['Content-Type'] = 'application/json' + else: + if isinstance(body, str): + body = body.encode() + headers['Content-Type'] = 'application/x-www-form-urlencoded' + response = WSGIResponse( + status=status_code, + response=body, + headers=headers, + ) + return response(environ, start_response) diff --git a/tox.ini b/tox.ini index 587d38bc..a8c5a354 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = py{27,36,37,38} - {py36,py37,py38}-async + {py36,py37,py38} {py27,py36,py37,py38}-flask {py36,py37,py38}-django coverage @@ -12,10 +12,11 @@ deps = py27: unittest2 flask: Flask flask: Flask-SQLAlchemy - async: httpx==0.14.1 - async: pytest-asyncio - async: starlette - async: itsdangerous + py3: httpx==0.14.3 + py3: pytest-asyncio + py3: starlette + py3: itsdangerous + py3: werkzeug django: Django django: pytest-django @@ -23,7 +24,7 @@ setenv = TESTPATH=tests/core RCFILE=setup.cfg py27: RCFILE=.py27conf - async: TESTPATH=tests/py3 + py3: TESTPATH=tests/py3 flask: TESTPATH=tests/flask django: TESTPATH=tests/django commands = From 340168b5c015562a15e472f5b41f50b9c2c109b1 Mon Sep 17 00:00:00 2001 From: Ber Zoidberg Date: Thu, 10 Sep 2020 01:03:00 -0700 Subject: [PATCH 6/6] remove unnecessary code duplication, use UNSET instead of UnsetType for cleaner code --- .../httpx_client/assertion_client.py | 22 ++------- .../httpx_client/oauth1_client.py | 29 +---------- .../httpx_client/oauth2_client.py | 49 ++----------------- 3 files changed, 10 insertions(+), 90 deletions(-) diff --git a/authlib/integrations/httpx_client/assertion_client.py b/authlib/integrations/httpx_client/assertion_client.py index 144f685f..62f81b79 100644 --- a/authlib/integrations/httpx_client/assertion_client.py +++ b/authlib/integrations/httpx_client/assertion_client.py @@ -1,5 +1,5 @@ from httpx import AsyncClient, Client -from httpx._config import UnsetType +from httpx._config import UNSET from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient from authlib.oauth2.rfc7523 import JWTBearerGrant from authlib.oauth2 import OAuth2Error @@ -32,7 +32,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No async def request(self, method, url, withhold_token=False, auth=None, **kwargs): """Send request with auto refresh token feature.""" - if not withhold_token and isinstance(auth, UnsetType): + if not withhold_token and auth is UNSET: if not self.token or self.token.is_expired(): await self.refresh_token() @@ -53,6 +53,7 @@ async def _refresh_token(self, data): self.token = token return self.token + class AssertionClient(_AssertionClient, Client): token_auth_class = OAuth2Auth JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE @@ -68,7 +69,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No Client.__init__(self, **client_kwargs) _AssertionClient.__init__( - self, session=None, + self, session=self, token_endpoint=token_endpoint, issuer=issuer, subject=subject, audience=audience, grant_type=grant_type, claims=claims, token_placement=token_placement, scope=scope, **kwargs @@ -76,23 +77,10 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No def request(self, method, url, withhold_token=False, auth=None, **kwargs): """Send request with auto refresh token feature.""" - if not withhold_token and isinstance(auth, UnsetType): + if not withhold_token and auth is UNSET: if not self.token or self.token.is_expired(): self.refresh_token() auth = self.token_auth return super(AssertionClient, self).request( method, url, auth=auth, **kwargs) - - def _refresh_token(self, data): - resp = self.request( - 'POST', self.token_endpoint, data=data, withhold_token=True) - - token = resp.json() - if 'error' in token: - raise OAuth2Error( - error=token['error'], - description=token.get('error_description') - ) - self.token = token - return self.token diff --git a/authlib/integrations/httpx_client/oauth1_client.py b/authlib/integrations/httpx_client/oauth1_client.py index 13483ec4..6d755bb1 100644 --- a/authlib/integrations/httpx_client/oauth1_client.py +++ b/authlib/integrations/httpx_client/oauth1_client.py @@ -88,40 +88,13 @@ def __init__(self, client_id, client_secret=None, Client.__init__(self, **_client_kwargs) _OAuth1Client.__init__( - self, None, + self, self, client_id=client_id, client_secret=client_secret, token=token, token_secret=token_secret, redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier, signature_method=signature_method, signature_type=signature_type, force_include_body=force_include_body, **kwargs) - def fetch_access_token(self, url, verifier=None, **kwargs): - """Method for fetching an access token from the token endpoint. - - This is the final step in the OAuth 1 workflow. An access token is - obtained using all previously obtained credentials, including the - verifier from the authorization step. - - :param url: Access Token endpoint. - :param verifier: A verifier string to prove authorization was granted. - :param kwargs: Extra parameters to include for fetching access token. - :return: A token dict. - """ - if verifier: - self.auth.verifier = verifier - if not self.auth.verifier: - self.handle_error('missing_verifier', 'Missing "verifier" value') - token = self._fetch_token(url, **kwargs) - self.auth.verifier = None - return token - - def _fetch_token(self, url, **kwargs): - resp = self.post(url, **kwargs) - text = resp.read() - token = self.parse_response_token(resp.status_code, to_unicode(text)) - self.token = token - return token - @staticmethod def handle_error(error_type, error_description): raise OAuthError(error_type, error_description) diff --git a/authlib/integrations/httpx_client/oauth2_client.py b/authlib/integrations/httpx_client/oauth2_client.py index 6d1411a2..387560b9 100644 --- a/authlib/integrations/httpx_client/oauth2_client.py +++ b/authlib/integrations/httpx_client/oauth2_client.py @@ -1,7 +1,7 @@ import asyncio import typing from httpx import AsyncClient, Auth, Client, Request, Response -from httpx._config import UnsetType +from httpx._config import UNSET from authlib.common.urls import url_decode from authlib.oauth2.client import OAuth2Client as _OAuth2Client from authlib.oauth2.auth import ClientAuth, TokenAuth @@ -79,7 +79,7 @@ def handle_error(error_type, error_description): raise OAuthError(error_type, error_description) async def request(self, method, url, withhold_token=False, auth=None, **kwargs): - if not withhold_token and isinstance(auth, UnsetType): + if not withhold_token and auth is UNSET: if not self.token: raise MissingTokenError() @@ -170,7 +170,7 @@ def __init__(self, client_id=None, client_secret=None, Client.__init__(self, **client_kwargs) _OAuth2Client.__init__( - self, session=None, + self, session=self, client_id=client_id, client_secret=client_secret, token_endpoint_auth_method=token_endpoint_auth_method, revocation_endpoint_auth_method=revocation_endpoint_auth_method, @@ -184,7 +184,7 @@ def handle_error(error_type, error_description): raise OAuthError(error_type, error_description) def request(self, method, url, withhold_token=False, auth=None, **kwargs): - if not withhold_token and isinstance(auth, UnsetType): + if not withhold_token and auth is UNSET: if not self.token: raise MissingTokenError() @@ -208,44 +208,3 @@ def ensure_active_token(self): self.update_token(token, access_token=access_token) else: raise InvalidTokenError() - - def _fetch_token(self, url, body='', headers=None, auth=None, - method='POST', **kwargs): - if method.upper() == 'POST': - resp = self.post( - url, data=dict(url_decode(body)), headers=headers, - auth=auth, **kwargs) - else: - if '?' in url: - url = '&'.join([url, body]) - else: - url = '?'.join([url, body]) - resp = self.get(url, headers=headers, auth=auth, **kwargs) - - for hook in self.compliance_hook['access_token_response']: - resp = hook(resp) - - return self.parse_response_token(resp.json()) - - def _refresh_token(self, url, refresh_token=None, body='', - headers=None, auth=None, **kwargs): - resp = self.post( - url, data=dict(url_decode(body)), headers=headers, - auth=auth, **kwargs) - - for hook in self.compliance_hook['refresh_token_response']: - resp = hook(resp) - - token = self.parse_response_token(resp.json()) - if 'refresh_token' not in token: - self.token['refresh_token'] = refresh_token - - if self.update_token: - self.update_token(self.token, refresh_token=refresh_token) - - return self.token - - def _http_post(self, url, body=None, auth=None, headers=None, **kwargs): - return self.post( - url, data=dict(url_decode(body)), - headers=headers, auth=auth, **kwargs)