Skip to content

Commit

Permalink
fix: auth
Browse files Browse the repository at this point in the history
  • Loading branch information
leoguillaumegouv committed Dec 20, 2024
1 parent fac466e commit 5bd3ff1
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 21 deletions.
10 changes: 5 additions & 5 deletions app/clients/_authenticationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async def check_api_key(self, key: str) -> Optional[User]:
Returns:
Optional[User]: User object if found, None otherwise
"""
user_id = self._api_key_to_user_id(input=key)
user_id = self.api_key_to_user_id(input=key)
ttl = -2

# fetch from Redis
Expand Down Expand Up @@ -121,15 +121,15 @@ async def check_api_key(self, key: str) -> Optional[User]:
return user

@staticmethod
def _api_key_to_user_id(input: str) -> str:
def api_key_to_user_id(input: str) -> str:
"""
Generate a 16 length unique code from an input string using salted SHA-256 hashing.
Generate a 16 length unique user id from an input string using SHA-256 hashing.
Args:
input_string (str): The input string to generate the code from.
input_string (str): The input string to generate the user id from.
Returns:
tuple[str, bytes]: A tuple containing the generated code and the salt used.
str: The generated user id.
"""
hash = hashlib.sha256((input).encode()).digest()
hash = base64.urlsafe_b64encode(hash).decode()
Expand Down
2 changes: 1 addition & 1 deletion app/helpers/_metricsmiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def dispatch(self, request: Request, call_next) -> Response:
)

if authorization and authorization.startswith("Bearer "):
user_id = AuthenticationClient._api_key_to_user_id(input=authorization.split(sep=" ")[1])
user_id = AuthenticationClient.api_key_to_user_id(input=authorization.split(sep=" ")[1])
self.http_requests_by_user.labels(user=user_id, endpoint=endpoint[3:], model=model).inc()

response = await call_next(request)
Expand Down
4 changes: 2 additions & 2 deletions app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def session_admin(args):

@pytest.fixture(scope="session")
def cleanup_collections(args, session_user, session_admin):
USER = AuthenticationClient._api_key_to_user_id(input=args["api_key_user"])
ADMIN = AuthenticationClient._api_key_to_user_id(input=args["api_key_admin"])
USER = AuthenticationClient.api_key_to_user_id(input=args["api_key_user"])
ADMIN = AuthenticationClient.api_key_to_user_id(input=args["api_key_admin"])

yield USER, ADMIN

Expand Down
4 changes: 2 additions & 2 deletions app/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

@pytest.fixture(scope="module")
def setup(args, session_user):
USER = AuthenticationClient._api_key_to_user_id(input=args["api_key_user"])
ADMIN = AuthenticationClient._api_key_to_user_id(input=args["api_key_admin"])
USER = AuthenticationClient.api_key_to_user_id(input=args["api_key_user"])
ADMIN = AuthenticationClient.api_key_to_user_id(input=args["api_key_admin"])
logging.info(f"test user ID: {USER}")
logging.info(f"test admin ID: {ADMIN}")

Expand Down
18 changes: 7 additions & 11 deletions app/utils/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from fastapi import Depends, Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.schemas.security import User
from app.utils.settings import settings
from app.utils.exceptions import InvalidAPIKeyException, InvalidAuthenticationSchemeException, InsufficientRightsException
from app.clients import AuthenticationClient
from app.schemas.security import Role, User
from app.utils.exceptions import InsufficientRightsException, InvalidAPIKeyException, InvalidAuthenticationSchemeException
from app.utils.lifespan import clients
from app.schemas.security import Role

from app.utils.settings import settings

if settings.clients.auth:

Expand Down Expand Up @@ -57,7 +56,7 @@ def check_api_key(api_key: Optional[str] = None) -> User:
return User(id="no-auth", role=Role.ADMIN)


async def check_rate_limit(request: Request) -> Optional[str]:
def check_rate_limit(request: Request) -> Optional[str]:
"""
Check the rate limit for the user.
Expand All @@ -71,9 +70,6 @@ async def check_rate_limit(request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, credentials = authorization.split(" ") if authorization else ("", "")
api_key = HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
user = await check_api_key(api_key=api_key)
user_id = AuthenticationClient.api_key_to_user_id(input=api_key.credentials)

if user.role.value > Role.USER.value:
return None
else:
return user.id
return user_id

0 comments on commit 5bd3ff1

Please sign in to comment.