From 0d790dffc020027057002724599f507ae8096274 Mon Sep 17 00:00:00 2001 From: Ali Aliyev Date: Wed, 10 Feb 2021 21:30:07 +0400 Subject: [PATCH] validate_request for RefreshTokenGrantType was rewritten in general form (#20) --- src/aioauth/__version__.py | 2 +- src/aioauth/grant_type.py | 28 ++++++++++++++-------------- tests/test_grant_type.py | 6 +++--- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/aioauth/__version__.py b/src/aioauth/__version__.py index 6f09a32..7ed9e76 100644 --- a/src/aioauth/__version__.py +++ b/src/aioauth/__version__.py @@ -1,7 +1,7 @@ __title__ = "aioauth" __description__ = "Asynchronous OAuth 2.0 framework for Python 3." __url__ = "https://github.com/aliev/aioauth" -__version__ = "0.1.5" +__version__ = "0.1.6" __author__ = "Ali Aliyev" __author_email__ = "ali@aliev.me" __license__ = "The MIT License (MIT)" diff --git a/src/aioauth/grant_type.py b/src/aioauth/grant_type.py index 46311a2..31e4a52 100644 --- a/src/aioauth/grant_type.py +++ b/src/aioauth/grant_type.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional from .base.request_validator import BaseRequestValidator from .errors import ( @@ -9,7 +9,7 @@ UnauthorizedClientError, UnsupportedGrantTypeError, ) -from .models import Client, Token +from .models import Client from .requests import Request from .responses import TokenResponse from .types import GrantType, RequestMethod @@ -148,7 +148,16 @@ class RefreshTokenGrantType(GrantTypeBase): async def create_token_response(self, request: Request) -> TokenResponse: """ Validate token request and create token response. """ - client, old_token = await self.validate_request(request) + client = await self.validate_request(request) + + old_token = await self.db.get_token( + request=request, + client_id=client.client_id, + refresh_token=request.post.refresh_token, + ) + + if not old_token or old_token.revoked or old_token.refresh_token_expired: + raise InvalidGrantError(request=request) # Revoke old token await self.db.revoke_token( @@ -178,7 +187,7 @@ async def create_token_response(self, request: Request) -> TokenResponse: token_type=token.token_type, ) - async def validate_request(self, request: Request) -> Tuple[Client, Token]: + async def validate_request(self, request: Request) -> Client: client = await super().validate_request(request) if not request.post.refresh_token: @@ -186,16 +195,7 @@ async def validate_request(self, request: Request) -> Tuple[Client, Token]: request=request, description="Missing refresh token parameter." ) - token = await self.db.get_token( - request=request, - client_id=client.client_id, - refresh_token=request.post.refresh_token, - ) - - if not token or token.revoked or token.refresh_token_expired: - raise InvalidGrantError(request=request) - - return client, token + return client class ClientCredentialsGrantType(GrantTypeBase): diff --git a/tests/test_grant_type.py b/tests/test_grant_type.py index 2160ff1..6764c2e 100644 --- a/tests/test_grant_type.py +++ b/tests/test_grant_type.py @@ -83,7 +83,7 @@ async def test_refresh_token_grant_type( grant_type = RefreshTokenGrantType(db) - client, old_token = await grant_type.validate_request(request) + client = await grant_type.validate_request(request) assert client.client_id == client_id assert client.client_secret == client_secret @@ -92,10 +92,10 @@ async def test_refresh_token_grant_type( # Check that previous token was revoken token_in_db = await db.get_token( - request, client_id, old_token.access_token, old_token.refresh_token + request, client_id, defaults.access_token, defaults.refresh_token ) assert token_in_db.revoked assert token_response.scope == "read" with pytest.raises(InvalidGrantError): - await grant_type.validate_request(request) + token_response = await grant_type.create_token_response(request)