From 188699eaeb1c1c7d1e4223ccd64641dcfc7037f9 Mon Sep 17 00:00:00 2001 From: Dany Gielow Date: Tue, 2 Mar 2021 16:08:33 +0100 Subject: [PATCH] allow client_id and client_secret in POST, check for credentials in POST and auth header (#24) --- src/aioauth/grant_type.py | 13 +++++- src/aioauth/requests.py | 2 + tests/test_flow.py | 83 +++++++++++++++++++++++++++++++++++++++ tests/utils.py | 30 ++++++++++++++ 4 files changed, 126 insertions(+), 2 deletions(-) diff --git a/src/aioauth/grant_type.py b/src/aioauth/grant_type.py index 31e4a52..2aa9144 100644 --- a/src/aioauth/grant_type.py +++ b/src/aioauth/grant_type.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple from .base.request_validator import BaseRequestValidator from .errors import ( @@ -41,7 +41,7 @@ async def create_token_response(self, request: Request) -> TokenResponse: async def validate_request(self, request: Request) -> Client: await super().validate_request(request) - client_id, client_secret = decode_auth_headers(request) + client_id, client_secret = self.get_client_credentials(request) client = await self.db.get_client( request, client_id=client_id, client_secret=client_secret @@ -68,6 +68,15 @@ async def validate_request(self, request: Request) -> Client: return client + def get_client_credentials(self, request: Request) -> Tuple[str, str]: + client_id = request.post.client_id + client_secret = request.post.client_secret + + if client_id is None or client_secret is None: + client_id, client_secret = decode_auth_headers(request) + + return client_id, client_secret + class AuthorizationCodeGrantType(GrantTypeBase): grant_type: GrantType = GrantType.TYPE_AUTHORIZATION_CODE diff --git a/src/aioauth/requests.py b/src/aioauth/requests.py index 6da85bc..b55823a 100644 --- a/src/aioauth/requests.py +++ b/src/aioauth/requests.py @@ -17,6 +17,8 @@ class Query(NamedTuple): class Post(NamedTuple): grant_type: Optional[GrantType] = None + client_id: Optional[str] = None + client_secret: Optional[str] = None redirect_uri: Optional[str] = None scope: str = "" username: Optional[str] = None diff --git a/tests/test_flow.py b/tests/test_flow.py index dfcf9ce..7f3d5d2 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -287,3 +287,86 @@ async def test_authorization_code_flow(server: AuthorizationServer, defaults: De response = await server.create_token_response(request) assert response.status_code == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_authorization_code_flow_credentials_in_post( + server: AuthorizationServer, defaults: Defaults +): + client_id = defaults.client_id + client_secret = defaults.client_secret + request_url = "https://localhost" + user = "username" + + query = Query( + client_id=defaults.client_id, + response_type=ResponseType.TYPE_CODE, + redirect_uri=defaults.redirect_uri, + scope=defaults.scope, + state=generate_token(10), + ) + + request = Request( + url=request_url, query=query, method=RequestMethod.GET, user=user, + ) + + response = await server.create_authorization_response(request) + assert response.status_code == HTTPStatus.FOUND + + location = response.headers["location"] + location = urlparse(location) + query = dict(parse_qsl(location.query)) + code = query["code"] + + post = Post( + grant_type=GrantType.TYPE_AUTHORIZATION_CODE, + client_id=client_id, + client_secret=client_secret, + redirect_uri=defaults.redirect_uri, + code=code, + ) + + request = Request(url=request_url, post=post, method=RequestMethod.POST,) + + response = await server.create_token_response(request) + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_client_credentials_flow_post_data( + server: AuthorizationServer, defaults: Defaults +): + request_url = "https://localhost" + + post = Post( + grant_type=GrantType.TYPE_CLIENT_CREDENTIALS, + client_id=defaults.client_id, + client_secret=defaults.client_secret, + scope=defaults.scope, + ) + + request = Request(url=request_url, post=post, method=RequestMethod.POST) + + response = await server.create_token_response(request) + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_client_credentials_flow_auth_header( + server: AuthorizationServer, defaults: Defaults +): + request_url = "https://localhost" + + post = Post(grant_type=GrantType.TYPE_CLIENT_CREDENTIALS, scope=defaults.scope,) + + request = Request( + url=request_url, + post=post, + method=RequestMethod.POST, + headers=encode_auth_headers( + client_id=defaults.client_id, client_secret=defaults.client_secret + ), + ) + + response = await server.create_token_response(request) + assert response.status_code == HTTPStatus.OK diff --git a/tests/utils.py b/tests/utils.py index aa383b0..17951d0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -79,6 +79,20 @@ status_code=HTTPStatus.BAD_REQUEST, headers=default_headers, ), + "client_id": Response( + content=ErrorResponse( + error=ErrorType.INVALID_GRANT, description="Invalid credentials given.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "client_secret": Response( + content=ErrorResponse( + error=ErrorType.INVALID_GRANT, description="Invalid credentials given.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), "username": Response( content=ErrorResponse( error=ErrorType.INVALID_GRANT, description="Invalid credentials given.", @@ -167,6 +181,22 @@ status_code=HTTPStatus.BAD_REQUEST, headers=default_headers, ), + "client_id": Response( + content=ErrorResponse( + error=ErrorType.INVALID_GRANT, + description="Invalid client_id parameter value.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "client_secret": Response( + content=ErrorResponse( + error=ErrorType.INVALID_GRANT, + description="Invalid client_secret parameter value.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), "username": Response( content=ErrorResponse( error=ErrorType.INVALID_GRANT, description="Invalid credentials given.",