diff --git a/src/diracx/cli/__init__.py b/src/diracx/cli/__init__.py index 18ad8483b..fffc404df 100644 --- a/src/diracx/cli/__init__.py +++ b/src/diracx/cli/__init__.py @@ -5,7 +5,7 @@ from typer import Option -from diracx.client.aio import Dirac +from diracx.client.aio import DiracClient from . import internal, jobs from .utils import AsyncTyper @@ -28,14 +28,14 @@ async def login( scopes += [f"property:{p}" for p in property] print(f"Logging in with scopes: {scopes}") - async with Dirac() as api: + async with DiracClient() as api: await api.login(scopes) print("\nLogin successful!") @app.async_command() async def logout(): - async with Dirac() as api: + async with DiracClient() as api: await api.logout() print("\nLogout successful!") diff --git a/src/diracx/cli/jobs.py b/src/diracx/cli/jobs.py index 57457ff13..3ece3b9a3 100644 --- a/src/diracx/cli/jobs.py +++ b/src/diracx/cli/jobs.py @@ -11,7 +11,7 @@ from rich.table import Table from typer import FileText, Option -from diracx.client.aio import Dirac +from diracx.client.aio import DiracClient from diracx.core.models import ScalarSearchOperator, SearchSpec, VectorSearchOperator from .utils import AsyncTyper @@ -53,7 +53,7 @@ async def search( condition: Annotated[list[SearchSpec], Option(parser=parse_condition)] = [], all: bool = False, ): - async with Dirac() as api: + async with DiracClient() as api: jobs = await api.jobs.search( parameters=None if all else parameter, search=condition if condition else None, @@ -102,7 +102,7 @@ def display_rich(data, unit: str) -> None: @app.async_command() async def submit(jdl: list[FileText]): - async with Dirac() as api: + async with DiracClient() as api: # api.valid(enforce_https=False) jobs = await api.jobs.submit_bulk_jobs([x.read() for x in jdl]) print( diff --git a/src/diracx/cli/utils.py b/src/diracx/cli/utils.py index e1abeb728..8014b0998 100644 --- a/src/diracx/cli/utils.py +++ b/src/diracx/cli/utils.py @@ -1,15 +1,12 @@ from __future__ import annotations -__all__ = ("AsyncTyper", "CREDENTIALS_PATH") +__all__ = ("AsyncTyper",) from asyncio import run from functools import wraps -from pathlib import Path import typer -CREDENTIALS_PATH = Path.home() / ".cache" / "diracx" / "credentials.json" - class AsyncTyper(typer.Typer): def async_command(self, *args, **kwargs): diff --git a/src/diracx/client/aio/_patch.py b/src/diracx/client/aio/_patch.py index c6167c221..dfeaa9a81 100644 --- a/src/diracx/client/aio/_patch.py +++ b/src/diracx/client/aio/_patch.py @@ -8,11 +8,11 @@ """ import asyncio import json -import logging +from types import TracebackType import requests from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, List +from typing import Any, List, Optional, cast from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest @@ -24,11 +24,10 @@ from ._client import Dirac as DiracGenerated __all__: List[str] = [ - "Dirac", + "DiracClient", ] # Add all objects you want publicly available to users at this package level -CREDENTIALS_PATH = Path.home() / ".cache" / "diracx" / "credentials.json" EXPIRES_GRACE_SECONDS = 15 @@ -44,11 +43,18 @@ def patch_sdk(): class DiracTokenCredential(AsyncTokenCredential): """Tailor get_token() for our context""" - def __init__(self, token_endpoint, client_id) -> None: + def __init__(self, location: Path, token_endpoint: str, client_id: str) -> None: + self.location = location self.token_endpoint = token_endpoint self.client_id = client_id - async def get_token(self, **kwargs: Any) -> AccessToken: + async def get_token( + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + **kwargs: Any, + ) -> AccessToken: """Refresh the access token using the refresh_token flow. :param str scopes: The type of access needed. :keyword str claims: Additional claims required in the token, such as those returned in a resource @@ -81,11 +87,28 @@ async def get_token(self, **kwargs: Any) -> AccessToken: ) write_credentials(token_response) - credentials = json.loads(CREDENTIALS_PATH.read_text()) + credentials = json.loads(self.location.read_text()) return AccessToken( credentials.get("access_token"), credentials.get("expires_on") ) + async def close(self) -> None: + """AsyncTokenCredential is a protocol: we need to 'implement' close()""" + pass + + async def __aenter__(self): + """AsyncTokenCredential is a protocol: we need to 'implement' __aenter__()""" + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None = ..., + exc_value: BaseException | None = ..., + traceback: TracebackType | None = ..., + ) -> None: + """AsyncTokenCredential is a protocol: we need to 'implement' __aexit__()""" + pass + class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom AsyncBearerTokenCredentialPolicy tailored for our use case. @@ -94,6 +117,11 @@ class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): * It does not ensure that an access token is available. """ + def __init__( + self, credential: DiracTokenCredential, *scopes: str, **kwargs: Any + ) -> None: + super().__init__(credential, *scopes, **kwargs) + async def on_request( self, request: "PipelineRequest" ) -> None: # pylint:disable=invalid-overridden-method @@ -102,15 +130,28 @@ async def on_request( :type request: ~azure.core.pipeline.PipelineRequest :raises: :class:`~azure.core.exceptions.ServiceRequestError` """ + self._token: AccessToken + self._credential: DiracTokenCredential + credentials: dict[str, Any] + # If the credentials path does not exist, we suppose it is not needed to perform the request - if not CREDENTIALS_PATH.exists(): + if not self._credential.location.exists(): return + # Load the existing credentials + if not self._token: + credentials = json.loads(self._credential.location.read_text()) + self._token = AccessToken( + cast(str, credentials.get("access_token")), + cast(int, credentials.get("expires_on")), + ) + # Else we check if we need a new access token if self._need_new_token(): - credentials = json.loads(CREDENTIALS_PATH.read_text()) - self._token: AccessToken = await self._credential.get_token( - refresh_token=credentials["refresh_token"] + if not credentials: + credentials = json.loads(self._credential.location.read_text()) + self._token = await self._credential.get_token( + "", refresh_token=credentials["refresh_token"] ) request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" @@ -125,28 +166,38 @@ def _need_new_token(self) -> bool: ) -class Dirac(DiracGenerated): +class DiracClient(DiracGenerated): """This class inherits from the generated Dirac client and adds support for tokens, so that the caller does not need to configure it by itself. """ - def __init__(self, **kwargs: Any) -> None: - endpoint = get_diracx_preferences().url - self._client_id = "myDIRACClientID" + def __init__( + self, endpoint: str | None = None, client_id: str | None = None, **kwargs: Any + ) -> None: + diracx_preferences = get_diracx_preferences() + self._credentials_path = diracx_preferences.credentials_path + + self._endpoint = endpoint or diracx_preferences.url + self._client_id = client_id or "myDIRACClientID" # Get .well-known configuration - response = requests.get(url=f"{endpoint}/.well-known/openid-configuration") + response = requests.get( + url=f"{self._endpoint}/.well-known/openid-configuration" + ) if not response.ok: + print(response.__dict__) raise RuntimeError( "Cannot fetch any information from the .well-known endpoint" ) # Initialize Dirac with a Dirac-specific token credential policy super().__init__( - endpoint=endpoint, + endpoint=self._endpoint, authentication_policy=DiracBearerTokenCredentialPolicy( DiracTokenCredential( - response.json()["token_endpoint"], self._client_id + location=self._credentials_path, + token_endpoint=response.json()["token_endpoint"], + client_id=self._client_id, ), ), **kwargs, @@ -176,15 +227,15 @@ async def login(self, scopes: list[str]): raise RuntimeError("Device authorization flow expired") # Save credentials - CREDENTIALS_PATH.parent.mkdir(parents=True, exist_ok=True) + self._credentials_path.parent.mkdir(parents=True, exist_ok=True) write_credentials(response) - print(f"Saved credentials to {CREDENTIALS_PATH}") + print(f"Saved credentials to {self._credentials_path}") async def logout(self): """Remove credentials""" - if not CREDENTIALS_PATH.exists(): + if not self._credentials_path.exists(): return - credentials = json.loads(CREDENTIALS_PATH.read_text()) + credentials = json.loads(self._credentials_path.read_text()) # Revoke refresh token try: @@ -193,8 +244,13 @@ async def logout(self): pass # Remove credentials - CREDENTIALS_PATH.unlink(missing_ok=True) - print(f"Removed credentials from {CREDENTIALS_PATH}") + self._credentials_path.unlink(missing_ok=True) + print(f"Removed credentials from {self._credentials_path}") + + async def __aenter__(self) -> "DiracClient": + """Redefined to provide the patched Dirac client in the managed context""" + await self._client.__aenter__() + return self def write_credentials(token_response: TokenResponse): @@ -207,4 +263,4 @@ def write_credentials(token_response: TokenResponse): "refresh_token": token_response.refresh_token, "expires_on": int(datetime.timestamp(expires)), } - CREDENTIALS_PATH.write_text(json.dumps(credential_data)) + get_diracx_preferences().credentials_path.write_text(json.dumps(credential_data)) diff --git a/src/diracx/core/preferences.py b/src/diracx/core/preferences.py index 2bfa17afc..f74cc7313 100644 --- a/src/diracx/core/preferences.py +++ b/src/diracx/core/preferences.py @@ -1,5 +1,7 @@ from __future__ import annotations +from pathlib import Path + __all__ = ("DiracxPreferences", "OutputFormats", "get_diracx_preferences") import logging @@ -27,6 +29,7 @@ class DiracxPreferences(BaseSettings, env_prefix="DIRACX_"): url: AnyHttpUrl output_format: OutputFormats = OutputFormats.RICH log_level: LogLevels = LogLevels.INFO + credentials_path: Path = Path.home() / ".cache" / "diracx" / "credentials.json" @classmethod def from_env(cls):