Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#268 #269 Improve HTTPX Support #270

Merged
merged 8 commits into from
Sep 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ jobs:
- version: 2.7
toxenv: py27,py27-flask
- version: 3.6
toxenv: py36,flask,django,async
toxenv: py36,flask,django,py3
- version: 3.7
toxenv: py37,flask,django,async
toxenv: py37,flask,django,py3
- version: 3.8
toxenv: py38,flask,django,async
toxenv: py38,flask,django,py3

steps:
- uses: actions/checkout@v2
Expand Down
6 changes: 3 additions & 3 deletions authlib/integrations/httpx_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
SIGNATURE_TYPE_QUERY,
SIGNATURE_TYPE_BODY,
)
from .oauth1_client import OAuth1Auth, AsyncOAuth1Client
from .oauth1_client import OAuth1Auth, AsyncOAuth1Client, OAuth1Client
from .oauth2_client import (
OAuth2Auth, OAuth2ClientAuth,
OAuth2Auth, OAuth2Client, OAuth2ClientAuth,
AsyncOAuth2Client,
)
from .assertion_client import AsyncAssertionClient
from .assertion_client import AssertionClient, AsyncAssertionClient
from ..base_client import OAuthError


Expand Down
37 changes: 35 additions & 2 deletions authlib/integrations/httpx_client/assertion_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from httpx import AsyncClient
from httpx import AsyncClient, Client
from httpx._config import UNSET
from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient
from authlib.oauth2.rfc7523 import JWTBearerGrant
from authlib.oauth2 import OAuth2Error
Expand Down Expand Up @@ -31,7 +32,7 @@ def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=No

async def request(self, method, url, withhold_token=False, auth=None, **kwargs):
"""Send request with auto refresh token feature."""
if not withhold_token and auth is None:
if not withhold_token and auth is UNSET:
if not self.token or self.token.is_expired():
await self.refresh_token()

Expand All @@ -51,3 +52,35 @@ async def _refresh_token(self, data):
)
self.token = token
return self.token


class AssertionClient(_AssertionClient, Client):
dustydecapod marked this conversation as resolved.
Show resolved Hide resolved
token_auth_class = OAuth2Auth
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE

def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, **kwargs):

client_kwargs = extract_client_kwargs(kwargs)
Client.__init__(self, **client_kwargs)

_AssertionClient.__init__(
self, session=self,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, **kwargs
)

def request(self, method, url, withhold_token=False, auth=None, **kwargs):
"""Send request with auto refresh token feature."""
if not withhold_token and auth is UNSET:
if not self.token or self.token.is_expired():
self.refresh_token()

auth = self.token_auth
return super(AssertionClient, self).request(
method, url, auth=auth, **kwargs)
28 changes: 27 additions & 1 deletion authlib/integrations/httpx_client/oauth1_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing
from httpx import AsyncClient, Auth, Request, Response
from httpx import AsyncClient, Auth, Client, Request, Response
from authlib.oauth1 import (
SIGNATURE_HMAC_SHA1,
SIGNATURE_TYPE_HEADER,
Expand All @@ -18,6 +18,7 @@ class OAuth1Auth(Auth, ClientAuth):
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
url, headers, body = self.prepare(
request.method, str(request.url), request.headers, request.content)
headers['Content-Length'] = str(len(body))
yield Request(method=request.method, url=url, headers=headers, data=body)


Expand Down Expand Up @@ -72,3 +73,28 @@ async def _fetch_token(self, url, **kwargs):
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)

class OAuth1Client(_OAuth1Client, Client):
auth_class = OAuth1Auth

def __init__(self, client_id, client_secret=None,
token=None, token_secret=None,
redirect_uri=None, rsa_key=None, verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False, **kwargs):

_client_kwargs = extract_client_kwargs(kwargs)
Client.__init__(self, **_client_kwargs)

_OAuth1Client.__init__(
self, self,
client_id=client_id, client_secret=client_secret,
token=token, token_secret=token_secret,
redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier,
signature_method=signature_method, signature_type=signature_type,
force_include_body=force_include_body, **kwargs)

@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)
62 changes: 60 additions & 2 deletions authlib/integrations/httpx_client/oauth2_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import typing
from httpx import AsyncClient, Auth, Request, Response
from httpx import AsyncClient, Auth, Client, Request, Response
from httpx._config import UNSET
from authlib.common.urls import url_decode
from authlib.oauth2.client import OAuth2Client as _OAuth2Client
from authlib.oauth2.auth import ClientAuth, TokenAuth
Expand Down Expand Up @@ -78,7 +79,7 @@ def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)

async def request(self, method, url, withhold_token=False, auth=None, **kwargs):
if not withhold_token and auth is None:
if not withhold_token and auth is UNSET:
if not self.token:
raise MissingTokenError()

Expand Down Expand Up @@ -150,3 +151,60 @@ def _http_post(self, url, body=None, auth=None, headers=None, **kwargs):
return self.post(
url, data=dict(url_decode(body)),
headers=headers, auth=auth, **kwargs)

class OAuth2Client(_OAuth2Client, Client):
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS

client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth

def __init__(self, client_id=None, client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None, redirect_uri=None,
token=None, token_placement='header',
update_token=None, **kwargs):

# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
Client.__init__(self, **client_kwargs)

_OAuth2Client.__init__(
self, session=self,
client_id=client_id, client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope, redirect_uri=redirect_uri,
token=token, token_placement=token_placement,
update_token=update_token, **kwargs
)

@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)

def request(self, method, url, withhold_token=False, auth=None, **kwargs):
if not withhold_token and auth is UNSET:
if not self.token:
raise MissingTokenError()

if self.token.is_expired():
self.ensure_active_token()

auth = self.token_auth

return super(OAuth2Client, self).request(
method, url, auth=auth, **kwargs)

def ensure_active_token(self):
refresh_token = self.token.get('refresh_token')
url = self.metadata.get('token_endpoint')
if refresh_token and url:
self.refresh_token(url, refresh_token=refresh_token)
elif self.metadata.get('grant_type') == 'client_credentials':
access_token = self.token['access_token']
token = self.fetch_token(url, grant_type='client_credentials')
if self.update_token:
self.update_token(token, access_token=access_token)
else:
raise InvalidTokenError()
3 changes: 3 additions & 0 deletions docs/client/httpx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ OAuth for HTTPX
HTTPX is a next-generation HTTP client for Python. Authlib enables OAuth 1.0
and OAuth 2.0 for HTTPX with its async versions:

* :class:`OAuth1Client`
* :class:`OAuth2Client`
* :class:`AssertionClient`
* :class:`AsyncOAuth1Client`
* :class:`AsyncOAuth2Client`
* :class:`AsyncAssertionClient`
Expand Down
65 changes: 65 additions & 0 deletions tests/py3/test_httpx_client/test_assertion_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import time
import pytest
from authlib.integrations.httpx_client import AssertionClient
from tests.py3.utils import MockDispatch


default_token = {
'token_type': 'Bearer',
'access_token': 'a',
'refresh_token': 'b',
'expires_in': '3600',
'expires_at': int(time.time()) + 3600,
}


@pytest.mark.asyncio
def test_refresh_token():
def verifier(request):
content = request.form
if str(request.url) == 'https://i.b/token':
assert 'assertion' in content

with AssertionClient(
'https://i.b/token',
grant_type=AssertionClient.JWT_BEARER_GRANT_TYPE,
issuer='foo',
subject='foo',
audience='foo',
alg='HS256',
key='secret',
app=MockDispatch(default_token, assert_func=verifier)
) as client:
client.get('https://i.b')

# trigger more case
now = int(time.time())
with AssertionClient(
'https://i.b/token',
issuer='foo',
subject=None,
audience='foo',
issued_at=now,
expires_at=now + 3600,
header={'alg': 'HS256'},
key='secret',
scope='email',
claims={'test_mode': 'true'},
app=MockDispatch(default_token, assert_func=verifier)
) as client:
client.get('https://i.b')
client.get('https://i.b')


@pytest.mark.asyncio
def test_without_alg():
with AssertionClient(
'https://i.b/token',
issuer='foo',
subject='foo',
audience='foo',
key='secret',
app=MockDispatch(default_token)
) as client:
with pytest.raises(ValueError):
client.get('https://i.b')
7 changes: 4 additions & 3 deletions tests/py3/test_httpx_client/test_async_assertion_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import pytest
from authlib.integrations.httpx_client import AsyncAssertionClient
from tests.py3.utils import MockDispatch
from tests.py3.utils import AsyncMockDispatch


default_token = {
Expand All @@ -28,7 +28,7 @@ async def verifier(request):
audience='foo',
alg='HS256',
key='secret',
app=MockDispatch(default_token, assert_func=verifier)
app=AsyncMockDispatch(default_token, assert_func=verifier)
) as client:
await client.get('https://i.b')

Expand All @@ -45,7 +45,7 @@ async def verifier(request):
key='secret',
scope='email',
claims={'test_mode': 'true'},
app=MockDispatch(default_token, assert_func=verifier)
app=AsyncMockDispatch(default_token, assert_func=verifier)
) as client:
await client.get('https://i.b')
await client.get('https://i.b')
Expand All @@ -59,6 +59,7 @@ async def test_without_alg():
subject='foo',
audience='foo',
key='secret',
app=AsyncMockDispatch()
) as client:
with pytest.raises(ValueError):
await client.get('https://i.b')
16 changes: 8 additions & 8 deletions tests/py3/test_httpx_client/test_async_oauth1_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
SIGNATURE_TYPE_BODY,
SIGNATURE_TYPE_QUERY,
)
from tests.py3.utils import MockDispatch
from tests.py3.utils import AsyncMockDispatch

oauth_url = 'https://example.com/oauth'

Expand All @@ -19,7 +19,7 @@ async def assert_func(request):
assert 'oauth_consumer_key="id"' in auth_header
assert 'oauth_signature=' in auth_header

app = MockDispatch(request_token, assert_func=assert_func)
app = AsyncMockDispatch(request_token, assert_func=assert_func)
async with AsyncOAuth1Client('id', 'secret', app=app) as client:
response = await client.fetch_request_token(oauth_url)

Expand All @@ -38,7 +38,7 @@ async def assert_func(request):
assert b'oauth_consumer_key=id' in content
assert b'&oauth_signature=' in content

mock_response = MockDispatch(request_token, assert_func=assert_func)
mock_response = AsyncMockDispatch(request_token, assert_func=assert_func)

async with AsyncOAuth1Client(
'id', 'secret', signature_type=SIGNATURE_TYPE_BODY,
Expand All @@ -61,7 +61,7 @@ async def assert_func(request):
assert 'oauth_consumer_key=id' in url
assert '&oauth_signature=' in url

mock_response = MockDispatch(request_token, assert_func=assert_func)
mock_response = AsyncMockDispatch(request_token, assert_func=assert_func)

async with AsyncOAuth1Client(
'id', 'secret', signature_type=SIGNATURE_TYPE_QUERY,
Expand All @@ -83,7 +83,7 @@ async def assert_func(request):
assert 'oauth_consumer_key="id"' in auth_header
assert 'oauth_signature=' in auth_header

mock_response = MockDispatch(request_token, assert_func=assert_func)
mock_response = AsyncMockDispatch(request_token, assert_func=assert_func)
async with AsyncOAuth1Client(
'id', 'secret', token='foo', token_secret='bar',
app=mock_response,
Expand All @@ -98,7 +98,7 @@ async def assert_func(request):

@pytest.mark.asyncio
async def test_get_via_header():
mock_response = MockDispatch(b'hello')
mock_response = AsyncMockDispatch(b'hello')
async with AsyncOAuth1Client(
'id', 'secret', token='foo', token_secret='bar',
app=mock_response,
Expand All @@ -121,7 +121,7 @@ async def assert_func(request):
assert b'oauth_consumer_key=id' in content
assert b'oauth_signature=' in content

mock_response = MockDispatch(b'hello', assert_func=assert_func)
mock_response = AsyncMockDispatch(b'hello', assert_func=assert_func)
async with AsyncOAuth1Client(
'id', 'secret', token='foo', token_secret='bar',
signature_type=SIGNATURE_TYPE_BODY,
Expand All @@ -138,7 +138,7 @@ async def assert_func(request):

@pytest.mark.asyncio
async def test_get_via_query():
mock_response = MockDispatch(b'hello')
mock_response = AsyncMockDispatch(b'hello')
async with AsyncOAuth1Client(
'id', 'secret', token='foo', token_secret='bar',
signature_type=SIGNATURE_TYPE_QUERY,
Expand Down
Loading