Skip to content

Commit

Permalink
feat: move CREDENTIALS_PATH to preferences
Browse files Browse the repository at this point in the history
  • Loading branch information
aldbr committed Sep 14, 2023
1 parent 4a04331 commit cee3c34
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 35 deletions.
6 changes: 3 additions & 3 deletions src/diracx/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typer import Option

from diracx.client.aio import Dirac
from diracx.client.aio import DiracClient

from . import internal, jobs
from .utils import AsyncTyper
Expand All @@ -28,14 +28,14 @@ async def login(
scopes += [f"property:{p}" for p in property]

print(f"Logging in with scopes: {scopes}")
async with Dirac() as api:
async with DiracClient() as api:
await api.login(scopes)
print("\nLogin successful!")

Check warning on line 33 in src/diracx/cli/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/cli/__init__.py#L31-L33

Added lines #L31 - L33 were not covered by tests


@app.async_command()
async def logout():
async with Dirac() as api:
async with DiracClient() as api:
await api.logout()
print("\nLogout successful!")

Check warning on line 40 in src/diracx/cli/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/cli/__init__.py#L38-L40

Added lines #L38 - L40 were not covered by tests

Expand Down
6 changes: 3 additions & 3 deletions src/diracx/cli/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rich.table import Table
from typer import FileText, Option

from diracx.client.aio import Dirac
from diracx.client.aio import DiracClient
from diracx.core.models import ScalarSearchOperator, SearchSpec, VectorSearchOperator

from .utils import AsyncTyper
Expand Down Expand Up @@ -53,7 +53,7 @@ async def search(
condition: Annotated[list[SearchSpec], Option(parser=parse_condition)] = [],
all: bool = False,
):
async with Dirac() as api:
async with DiracClient() as api:

Check warning on line 56 in src/diracx/cli/jobs.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/cli/jobs.py#L56

Added line #L56 was not covered by tests
jobs = await api.jobs.search(
parameters=None if all else parameter,
search=condition if condition else None,
Expand Down Expand Up @@ -102,7 +102,7 @@ def display_rich(data, unit: str) -> None:

@app.async_command()
async def submit(jdl: list[FileText]):
async with Dirac() as api:
async with DiracClient() as api:

Check warning on line 105 in src/diracx/cli/jobs.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/cli/jobs.py#L105

Added line #L105 was not covered by tests
# api.valid(enforce_https=False)
jobs = await api.jobs.submit_bulk_jobs([x.read() for x in jdl])

Check warning on line 107 in src/diracx/cli/jobs.py

View check run for this annotation

Codecov / codecov/patch

src/diracx/cli/jobs.py#L107

Added line #L107 was not covered by tests
print(
Expand Down
5 changes: 1 addition & 4 deletions src/diracx/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from __future__ import annotations

__all__ = ("AsyncTyper", "CREDENTIALS_PATH")
__all__ = ("AsyncTyper",)

from asyncio import run
from functools import wraps
from pathlib import Path

import typer

CREDENTIALS_PATH = Path.home() / ".cache" / "diracx" / "credentials.json"


class AsyncTyper(typer.Typer):
def async_command(self, *args, **kwargs):
Expand Down
106 changes: 81 additions & 25 deletions src/diracx/client/aio/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
"""
import asyncio
import json
import logging
from types import TracebackType
import requests
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, List
from typing import Any, List, Optional, cast
from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.pipeline import PipelineRequest
Expand All @@ -24,11 +24,10 @@
from ._client import Dirac as DiracGenerated

__all__: List[str] = [
"Dirac",
"DiracClient",
] # Add all objects you want publicly available to users at this package level


CREDENTIALS_PATH = Path.home() / ".cache" / "diracx" / "credentials.json"
EXPIRES_GRACE_SECONDS = 15


Expand All @@ -44,11 +43,18 @@ def patch_sdk():
class DiracTokenCredential(AsyncTokenCredential):
"""Tailor get_token() for our context"""

def __init__(self, token_endpoint, client_id) -> None:
def __init__(self, location: Path, token_endpoint: str, client_id: str) -> None:
self.location = location
self.token_endpoint = token_endpoint
self.client_id = client_id

async def get_token(self, **kwargs: Any) -> AccessToken:
async def get_token(
self,
*scopes: str,
claims: Optional[str] = None,
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Refresh the access token using the refresh_token flow.
:param str scopes: The type of access needed.
:keyword str claims: Additional claims required in the token, such as those returned in a resource
Expand Down Expand Up @@ -81,11 +87,28 @@ async def get_token(self, **kwargs: Any) -> AccessToken:
)

write_credentials(token_response)
credentials = json.loads(CREDENTIALS_PATH.read_text())
credentials = json.loads(self.location.read_text())
return AccessToken(
credentials.get("access_token"), credentials.get("expires_on")
)

async def close(self) -> None:
"""AsyncTokenCredential is a protocol: we need to 'implement' close()"""
pass

async def __aenter__(self):
"""AsyncTokenCredential is a protocol: we need to 'implement' __aenter__()"""
pass

async def __aexit__(
self,
exc_type: type[BaseException] | None = ...,
exc_value: BaseException | None = ...,
traceback: TracebackType | None = ...,
) -> None:
"""AsyncTokenCredential is a protocol: we need to 'implement' __aexit__()"""
pass


class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
"""Custom AsyncBearerTokenCredentialPolicy tailored for our use case.
Expand All @@ -94,6 +117,11 @@ class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
* It does not ensure that an access token is available.
"""

def __init__(
self, credential: DiracTokenCredential, *scopes: str, **kwargs: Any
) -> None:
super().__init__(credential, *scopes, **kwargs)

async def on_request(
self, request: "PipelineRequest"
) -> None: # pylint:disable=invalid-overridden-method
Expand All @@ -102,15 +130,28 @@ async def on_request(
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._token: AccessToken
self._credential: DiracTokenCredential
credentials: dict[str, Any]

# If the credentials path does not exist, we suppose it is not needed to perform the request
if not CREDENTIALS_PATH.exists():
if not self._credential.location.exists():
return

# Load the existing credentials
if not self._token:
credentials = json.loads(self._credential.location.read_text())
self._token = AccessToken(
cast(str, credentials.get("access_token")),
cast(int, credentials.get("expires_on")),
)

# Else we check if we need a new access token
if self._need_new_token():
credentials = json.loads(CREDENTIALS_PATH.read_text())
self._token: AccessToken = await self._credential.get_token(
refresh_token=credentials["refresh_token"]
if not credentials:
credentials = json.loads(self._credential.location.read_text())
self._token = await self._credential.get_token(
"", refresh_token=credentials["refresh_token"]
)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
Expand All @@ -125,28 +166,38 @@ def _need_new_token(self) -> bool:
)


class Dirac(DiracGenerated):
class DiracClient(DiracGenerated):
"""This class inherits from the generated Dirac client and adds support for tokens,
so that the caller does not need to configure it by itself.
"""

def __init__(self, **kwargs: Any) -> None:
endpoint = get_diracx_preferences().url
self._client_id = "myDIRACClientID"
def __init__(
self, endpoint: str | None = None, client_id: str | None = None, **kwargs: Any
) -> None:
diracx_preferences = get_diracx_preferences()
self._credentials_path = diracx_preferences.credentials_path

self._endpoint = endpoint or diracx_preferences.url
self._client_id = client_id or "myDIRACClientID"

# Get .well-known configuration
response = requests.get(url=f"{endpoint}/.well-known/openid-configuration")
response = requests.get(
url=f"{self._endpoint}/.well-known/openid-configuration"
)
if not response.ok:
print(response.__dict__)
raise RuntimeError(
"Cannot fetch any information from the .well-known endpoint"
)

# Initialize Dirac with a Dirac-specific token credential policy
super().__init__(
endpoint=endpoint,
endpoint=self._endpoint,
authentication_policy=DiracBearerTokenCredentialPolicy(
DiracTokenCredential(
response.json()["token_endpoint"], self._client_id
location=self._credentials_path,
token_endpoint=response.json()["token_endpoint"],
client_id=self._client_id,
),
),
**kwargs,
Expand Down Expand Up @@ -176,15 +227,15 @@ async def login(self, scopes: list[str]):
raise RuntimeError("Device authorization flow expired")

# Save credentials
CREDENTIALS_PATH.parent.mkdir(parents=True, exist_ok=True)
self._credentials_path.parent.mkdir(parents=True, exist_ok=True)
write_credentials(response)
print(f"Saved credentials to {CREDENTIALS_PATH}")
print(f"Saved credentials to {self._credentials_path}")

async def logout(self):
"""Remove credentials"""
if not CREDENTIALS_PATH.exists():
if not self._credentials_path.exists():
return
credentials = json.loads(CREDENTIALS_PATH.read_text())
credentials = json.loads(self._credentials_path.read_text())

# Revoke refresh token
try:
Expand All @@ -193,8 +244,13 @@ async def logout(self):
pass

# Remove credentials
CREDENTIALS_PATH.unlink(missing_ok=True)
print(f"Removed credentials from {CREDENTIALS_PATH}")
self._credentials_path.unlink(missing_ok=True)
print(f"Removed credentials from {self._credentials_path}")

async def __aenter__(self) -> "DiracClient":
"""Redefined to provide the patched Dirac client in the managed context"""
await self._client.__aenter__()
return self


def write_credentials(token_response: TokenResponse):
Expand All @@ -207,4 +263,4 @@ def write_credentials(token_response: TokenResponse):
"refresh_token": token_response.refresh_token,
"expires_on": int(datetime.timestamp(expires)),
}
CREDENTIALS_PATH.write_text(json.dumps(credential_data))
get_diracx_preferences().credentials_path.write_text(json.dumps(credential_data))
3 changes: 3 additions & 0 deletions src/diracx/core/preferences.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from pathlib import Path

__all__ = ("DiracxPreferences", "OutputFormats", "get_diracx_preferences")

import logging
Expand Down Expand Up @@ -27,6 +29,7 @@ class DiracxPreferences(BaseSettings, env_prefix="DIRACX_"):
url: AnyHttpUrl
output_format: OutputFormats = OutputFormats.RICH
log_level: LogLevels = LogLevels.INFO
credentials_path: Path = Path.home() / ".cache" / "diracx" / "credentials.json"

@classmethod
def from_env(cls):
Expand Down

0 comments on commit cee3c34

Please sign in to comment.