Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
aldbr committed Sep 12, 2023
1 parent 4a04331 commit 533540a
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 16 deletions.
6 changes: 3 additions & 3 deletions src/diracx/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typer import Option

from diracx.client.aio import Dirac
from diracx.client.aio import DiracClient

from . import internal, jobs
from .utils import AsyncTyper
Expand All @@ -28,14 +28,14 @@ async def login(
scopes += [f"property:{p}" for p in property]

print(f"Logging in with scopes: {scopes}")
async with Dirac() as api:
async with DiracClient() as api:
await api.login(scopes)
print("\nLogin successful!")


@app.async_command()
async def logout():
async with Dirac() as api:
async with DiracClient() as api:
await api.logout()
print("\nLogout successful!")

Expand Down
6 changes: 3 additions & 3 deletions src/diracx/cli/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rich.table import Table
from typer import FileText, Option

from diracx.client.aio import Dirac
from diracx.client.aio import DiracClient
from diracx.core.models import ScalarSearchOperator, SearchSpec, VectorSearchOperator

from .utils import AsyncTyper
Expand Down Expand Up @@ -53,7 +53,7 @@ async def search(
condition: Annotated[list[SearchSpec], Option(parser=parse_condition)] = [],
all: bool = False,
):
async with Dirac() as api:
async with DiracClient() as api:
jobs = await api.jobs.search(
parameters=None if all else parameter,
search=condition if condition else None,
Expand Down Expand Up @@ -102,7 +102,7 @@ def display_rich(data, unit: str) -> None:

@app.async_command()
async def submit(jdl: list[FileText]):
async with Dirac() as api:
async with DiracClient() as api:
# api.valid(enforce_https=False)
jobs = await api.jobs.submit_bulk_jobs([x.read() for x in jdl])
print(
Expand Down
5 changes: 1 addition & 4 deletions src/diracx/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from __future__ import annotations

__all__ = ("AsyncTyper", "CREDENTIALS_PATH")
__all__ = ("AsyncTyper",)

from asyncio import run
from functools import wraps
from pathlib import Path

import typer

CREDENTIALS_PATH = Path.home() / ".cache" / "diracx" / "credentials.json"


class AsyncTyper(typer.Typer):
def async_command(self, *args, **kwargs):
Expand Down
37 changes: 31 additions & 6 deletions src/diracx/client/aio/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
"""
import asyncio
import json
import logging
from types import TracebackType
import requests
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, List
from typing import Any, List, Optional
from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest
Expand All @@ -24,7 +24,7 @@
from ._client import Dirac as DiracGenerated

__all__: List[str] = [
"Dirac",
"DiracClient",
] # Add all objects you want publicly available to users at this package level


Expand All @@ -48,7 +48,13 @@ def __init__(self, token_endpoint, client_id) -> None:
self.token_endpoint = token_endpoint
self.client_id = client_id

async def get_token(self, **kwargs: Any) -> AccessToken:
async def get_token(
self,
*scopes: str,
claims: Optional[str] = None,
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Refresh the access token using the refresh_token flow.
:param str scopes: The type of access needed.
:keyword str claims: Additional claims required in the token, such as those returned in a resource
Expand Down Expand Up @@ -86,6 +92,20 @@ async def get_token(self, **kwargs: Any) -> AccessToken:
credentials.get("access_token"), credentials.get("expires_on")
)

async def close(self) -> None:
pass

async def __aenter__(self):
pass

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
pass


class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
"""Custom AsyncBearerTokenCredentialPolicy tailored for our use case.
Expand All @@ -110,7 +130,7 @@ async def on_request(
if self._need_new_token():
credentials = json.loads(CREDENTIALS_PATH.read_text())
self._token: AccessToken = await self._credential.get_token(
refresh_token=credentials["refresh_token"]
"", refresh_token=credentials["refresh_token"]
)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
Expand All @@ -125,7 +145,7 @@ def _need_new_token(self) -> bool:
)


class Dirac(DiracGenerated):
class DiracClient(DiracGenerated):
"""This class inherits from the generated Dirac client and adds support for tokens,
so that the caller does not need to configure it by itself.
"""
Expand Down Expand Up @@ -196,6 +216,11 @@ async def logout(self):
CREDENTIALS_PATH.unlink(missing_ok=True)
print(f"Removed credentials from {CREDENTIALS_PATH}")

async def __aenter__(self) -> "DiracClient":
"""Redefined to provide the patched Dirac client in the managed context"""
await self._client.__aenter__()
return self


def write_credentials(token_response: TokenResponse):
"""Write credentials received in CREDENTIALS_PATH"""
Expand Down

0 comments on commit 533540a

Please sign in to comment.