Skip to content

Commit

Permalink
Merge pull request DIRACGrid#299 from natthan-pigoux/feat/lock_refres…
Browse files Browse the repository at this point in the history
…h_token

feat: lock file while read, write and refresh token
  • Loading branch information
chaen authored Dec 6, 2024
2 parents 5421e3f + 8e5fea6 commit ad49f49
Show file tree
Hide file tree
Showing 6 changed files with 576 additions and 114 deletions.
6 changes: 3 additions & 3 deletions diracx-cli/src/diracx/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from diracx.client.models import DeviceFlowErrorResponse
from diracx.core.extensions import select_from_extension
from diracx.core.preferences import get_diracx_preferences
from diracx.core.utils import write_credentials
from diracx.core.utils import read_credentials, write_credentials

from .utils import AsyncTyper

Expand Down Expand Up @@ -116,11 +116,11 @@ async def logout():
async with DiracClient() as api:
credentials_path = get_diracx_preferences().credentials_path
if credentials_path.exists():
credentials = json.loads(credentials_path.read_text())
credentials = read_credentials(credentials_path)

# Revoke refresh token
try:
await api.auth.revoke_refresh_token(credentials["refresh_token"])
await api.auth.revoke_refresh_token(credentials.refresh_token)
except Exception as e:
print(f"Error revoking the refresh token {e!r}")
pass
Expand Down
53 changes: 19 additions & 34 deletions diracx-client/src/diracx/client/patches/aio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from __future__ import annotations

import abc
import json
from importlib.metadata import PackageNotFoundError, distribution
from types import TracebackType
from pathlib import Path
from typing import Any, List, Optional, Self

from typing import Any, List, Optional, cast

from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest
Expand All @@ -24,8 +25,6 @@
from ..utils import (
get_openid_configuration,
get_token,
refresh_token,
is_refresh_token_valid,
)

__all__: List[str] = [
Expand Down Expand Up @@ -56,20 +55,12 @@ async def get_token(
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Refresh the access token using the refresh_token flow.
:param str scopes: The type of access needed.
:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
:keyword str tenant_id: Optional tenant to include in the token request.
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
"""
return refresh_token(
return get_token(
self.location,
kwargs.get("token"),
self.token_endpoint,
self.client_id,
kwargs["refresh_token"],
verify=self.verify,
self.verify,
)

async def close(self) -> None:
Expand Down Expand Up @@ -97,6 +88,9 @@ class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
* It does not ensure that an access token is available.
"""

# Make mypy happy
_token: Optional[AccessToken] = None

def __init__(
self, credential: DiracTokenCredential, *scopes: str, **kwargs: Any
) -> None:
Expand All @@ -110,28 +104,19 @@ async def on_request(
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._token: AccessToken | None
self._credential: DiracTokenCredential
credentials: dict[str, Any]
try:
self._token = get_token(self._credential.location, self._token)
except RuntimeError:
# If we are here, it means the credentials path does not exist
# Make mypy happy
if not isinstance(self._credential, AsyncTokenCredential):
return

self._token = await self._credential.get_token("", token=self._token)
if not self._token.token:
# If we are here, it means the token is not available
# we suppose it is not needed to perform the request
return

if not self._token:
credentials = json.loads(self._credential.location.read_text())
refresh_token = credentials["refresh_token"]
if not is_refresh_token_valid(refresh_token):
# If we are here, it means the refresh token is not valid anymore
# we suppose it is not needed to perform the request
return
self._token = await self._credential.get_token(
"", refresh_token=refresh_token
)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
request.http_request.headers["Authorization"] = (
"Bearer " + cast(AccessToken, self._token).token
)


class DiracClientMixin(metaclass=abc.ABCMeta):
Expand Down
Loading

0 comments on commit ad49f49

Please sign in to comment.