From 533540a0118e6e27d404b8e8a5fbfe52e7ac6f45 Mon Sep 17 00:00:00 2001 From: aldbr Date: Tue, 12 Sep 2023 14:15:03 +0200 Subject: [PATCH] fix --- src/diracx/cli/__init__.py | 6 +++--- src/diracx/cli/jobs.py | 6 +++--- src/diracx/cli/utils.py | 5 +---- src/diracx/client/aio/_patch.py | 37 +++++++++++++++++++++++++++------ 4 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/diracx/cli/__init__.py b/src/diracx/cli/__init__.py index 18ad8483b..fffc404df 100644 --- a/src/diracx/cli/__init__.py +++ b/src/diracx/cli/__init__.py @@ -5,7 +5,7 @@ from typer import Option -from diracx.client.aio import Dirac +from diracx.client.aio import DiracClient from . import internal, jobs from .utils import AsyncTyper @@ -28,14 +28,14 @@ async def login( scopes += [f"property:{p}" for p in property] print(f"Logging in with scopes: {scopes}") - async with Dirac() as api: + async with DiracClient() as api: await api.login(scopes) print("\nLogin successful!") @app.async_command() async def logout(): - async with Dirac() as api: + async with DiracClient() as api: await api.logout() print("\nLogout successful!") diff --git a/src/diracx/cli/jobs.py b/src/diracx/cli/jobs.py index 57457ff13..3ece3b9a3 100644 --- a/src/diracx/cli/jobs.py +++ b/src/diracx/cli/jobs.py @@ -11,7 +11,7 @@ from rich.table import Table from typer import FileText, Option -from diracx.client.aio import Dirac +from diracx.client.aio import DiracClient from diracx.core.models import ScalarSearchOperator, SearchSpec, VectorSearchOperator from .utils import AsyncTyper @@ -53,7 +53,7 @@ async def search( condition: Annotated[list[SearchSpec], Option(parser=parse_condition)] = [], all: bool = False, ): - async with Dirac() as api: + async with DiracClient() as api: jobs = await api.jobs.search( parameters=None if all else parameter, search=condition if condition else None, @@ -102,7 +102,7 @@ def display_rich(data, unit: str) -> None: @app.async_command() async def submit(jdl: list[FileText]): - async with Dirac() as api: + async with DiracClient() as api: # api.valid(enforce_https=False) jobs = await api.jobs.submit_bulk_jobs([x.read() for x in jdl]) print( diff --git a/src/diracx/cli/utils.py b/src/diracx/cli/utils.py index e1abeb728..8014b0998 100644 --- a/src/diracx/cli/utils.py +++ b/src/diracx/cli/utils.py @@ -1,15 +1,12 @@ from __future__ import annotations -__all__ = ("AsyncTyper", "CREDENTIALS_PATH") +__all__ = ("AsyncTyper",) from asyncio import run from functools import wraps -from pathlib import Path import typer -CREDENTIALS_PATH = Path.home() / ".cache" / "diracx" / "credentials.json" - class AsyncTyper(typer.Typer): def async_command(self, *args, **kwargs): diff --git a/src/diracx/client/aio/_patch.py b/src/diracx/client/aio/_patch.py index c6167c221..b89a80e18 100644 --- a/src/diracx/client/aio/_patch.py +++ b/src/diracx/client/aio/_patch.py @@ -8,11 +8,11 @@ """ import asyncio import json -import logging +from types import TracebackType import requests from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Any, List +from typing import Any, List, Optional from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest @@ -24,7 +24,7 @@ from ._client import Dirac as DiracGenerated __all__: List[str] = [ - "Dirac", + "DiracClient", ] # Add all objects you want publicly available to users at this package level @@ -48,7 +48,13 @@ def __init__(self, token_endpoint, client_id) -> None: self.token_endpoint = token_endpoint self.client_id = client_id - async def get_token(self, **kwargs: Any) -> AccessToken: + async def get_token( + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + **kwargs: Any, + ) -> AccessToken: """Refresh the access token using the refresh_token flow. :param str scopes: The type of access needed. :keyword str claims: Additional claims required in the token, such as those returned in a resource @@ -86,6 +92,20 @@ async def get_token(self, **kwargs: Any) -> AccessToken: credentials.get("access_token"), credentials.get("expires_on") ) + async def close(self) -> None: + pass + + async def __aenter__(self): + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): """Custom AsyncBearerTokenCredentialPolicy tailored for our use case. @@ -110,7 +130,7 @@ async def on_request( if self._need_new_token(): credentials = json.loads(CREDENTIALS_PATH.read_text()) self._token: AccessToken = await self._credential.get_token( - refresh_token=credentials["refresh_token"] + "", refresh_token=credentials["refresh_token"] ) request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" @@ -125,7 +145,7 @@ def _need_new_token(self) -> bool: ) -class Dirac(DiracGenerated): +class DiracClient(DiracGenerated): """This class inherits from the generated Dirac client and adds support for tokens, so that the caller does not need to configure it by itself. """ @@ -196,6 +216,11 @@ async def logout(self): CREDENTIALS_PATH.unlink(missing_ok=True) print(f"Removed credentials from {CREDENTIALS_PATH}") + async def __aenter__(self) -> "DiracClient": + """Redefined to provide the patched Dirac client in the managed context""" + await self._client.__aenter__() + return self + def write_credentials(token_response: TokenResponse): """Write credentials received in CREDENTIALS_PATH"""