diff --git a/app/clients/_authenticationclient.py b/app/clients/_authenticationclient.py index f18ec97..e369d10 100644 --- a/app/clients/_authenticationclient.py +++ b/app/clients/_authenticationclient.py @@ -86,7 +86,7 @@ async def check_api_key(self, key: str) -> Optional[User]: Returns: Optional[User]: User object if found, None otherwise """ - user_id = self._api_key_to_user_id(input=key) + user_id = self.api_key_to_user_id(input=key) ttl = -2 # fetch from Redis @@ -121,15 +121,15 @@ async def check_api_key(self, key: str) -> Optional[User]: return user @staticmethod - def _api_key_to_user_id(input: str) -> str: + def api_key_to_user_id(input: str) -> str: """ - Generate a 16 length unique code from an input string using salted SHA-256 hashing. + Generate a 16 length unique user id from an input string using SHA-256 hashing. Args: - input_string (str): The input string to generate the code from. + input_string (str): The input string to generate the user id from. Returns: - tuple[str, bytes]: A tuple containing the generated code and the salt used. + str: The generated user id. """ hash = hashlib.sha256((input).encode()).digest() hash = base64.urlsafe_b64encode(hash).decode() diff --git a/app/helpers/_metricsmiddleware.py b/app/helpers/_metricsmiddleware.py index d4920b4..68ecdb0 100644 --- a/app/helpers/_metricsmiddleware.py +++ b/app/helpers/_metricsmiddleware.py @@ -48,7 +48,7 @@ async def dispatch(self, request: Request, call_next) -> Response: ) if authorization and authorization.startswith("Bearer "): - user_id = AuthenticationClient._api_key_to_user_id(input=authorization.split(sep=" ")[1]) + user_id = AuthenticationClient.api_key_to_user_id(input=authorization.split(sep=" ")[1]) self.http_requests_by_user.labels(user=user_id, endpoint=endpoint[3:], model=model).inc() response = await call_next(request) diff --git a/app/tests/conftest.py b/app/tests/conftest.py index d9c4fce..7e3edef 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -45,8 +45,8 @@ def session_admin(args): @pytest.fixture(scope="session") def cleanup_collections(args, session_user, session_admin): - USER = AuthenticationClient._api_key_to_user_id(input=args["api_key_user"]) - ADMIN = AuthenticationClient._api_key_to_user_id(input=args["api_key_admin"]) + USER = AuthenticationClient.api_key_to_user_id(input=args["api_key_user"]) + ADMIN = AuthenticationClient.api_key_to_user_id(input=args["api_key_admin"]) yield USER, ADMIN diff --git a/app/tests/test_collections.py b/app/tests/test_collections.py index 4a2f480..3604942 100644 --- a/app/tests/test_collections.py +++ b/app/tests/test_collections.py @@ -14,8 +14,8 @@ @pytest.fixture(scope="module") def setup(args, session_user): - USER = AuthenticationClient._api_key_to_user_id(input=args["api_key_user"]) - ADMIN = AuthenticationClient._api_key_to_user_id(input=args["api_key_admin"]) + USER = AuthenticationClient.api_key_to_user_id(input=args["api_key_user"]) + ADMIN = AuthenticationClient.api_key_to_user_id(input=args["api_key_admin"]) logging.info(f"test user ID: {USER}") logging.info(f"test admin ID: {ADMIN}") diff --git a/app/utils/security.py b/app/utils/security.py index a3a94be..34a34cf 100644 --- a/app/utils/security.py +++ b/app/utils/security.py @@ -3,12 +3,11 @@ from fastapi import Depends, Request from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer -from app.schemas.security import User -from app.utils.settings import settings -from app.utils.exceptions import InvalidAPIKeyException, InvalidAuthenticationSchemeException, InsufficientRightsException +from app.clients import AuthenticationClient +from app.schemas.security import Role, User +from app.utils.exceptions import InsufficientRightsException, InvalidAPIKeyException, InvalidAuthenticationSchemeException from app.utils.lifespan import clients -from app.schemas.security import Role - +from app.utils.settings import settings if settings.clients.auth: @@ -57,7 +56,7 @@ def check_api_key(api_key: Optional[str] = None) -> User: return User(id="no-auth", role=Role.ADMIN) -async def check_rate_limit(request: Request) -> Optional[str]: +def check_rate_limit(request: Request) -> Optional[str]: """ Check the rate limit for the user. @@ -71,9 +70,6 @@ async def check_rate_limit(request: Request) -> Optional[str]: authorization = request.headers.get("Authorization") scheme, credentials = authorization.split(" ") if authorization else ("", "") api_key = HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) - user = await check_api_key(api_key=api_key) + user_id = AuthenticationClient.api_key_to_user_id(input=api_key.credentials) - if user.role.value > Role.USER.value: - return None - else: - return user.id + return user_id