From bd4bbc65a2dc2e801649b0f49cb94eb68c206abb Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Mon, 25 Sep 2023 07:22:11 +0200 Subject: [PATCH 1/7] Add CLI tests --- tests/cli/__init__.py | 0 tests/cli/test_jobs.py | 13 ++++++ tests/cli/test_login.py | 94 +++++++++++++++++++++++++++++++++++++++++ tests/conftest.py | 58 +++++++++++++++++++++++-- 4 files changed, 162 insertions(+), 3 deletions(-) create mode 100644 tests/cli/__init__.py create mode 100644 tests/cli/test_jobs.py create mode 100644 tests/cli/test_login.py diff --git a/tests/cli/__init__.py b/tests/cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/cli/test_jobs.py b/tests/cli/test_jobs.py new file mode 100644 index 00000000..3875fa7f --- /dev/null +++ b/tests/cli/test_jobs.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import json + +from diracx import cli + + +async def test_search(with_cli_login, capfd): + await cli.jobs.search() + cap = capfd.readouterr() + assert cap.err == "" + # By default the output should be in JSON format as capfd is not a TTY + json.loads(cap.out) diff --git a/tests/cli/test_login.py b/tests/cli/test_login.py new file mode 100644 index 00000000..554766be --- /dev/null +++ b/tests/cli/test_login.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import asyncio +import re +from html.parser import HTMLParser +from pathlib import Path +from urllib.parse import urljoin + +import requests + +from diracx import cli + + +def do_device_flow_with_dex(url: str) -> None: + """Do the device flow with dex""" + + class DexLoginFormParser(HTMLParser): + def handle_starttag(self, tag, attrs): + nonlocal action_url + if "form" in str(tag): + assert action_url is None + action_url = urljoin(login_page_url, dict(attrs)["action"]) + + # Get the login page + r = requests.get(url) + r.raise_for_status() + login_page_url = r.url # This is not the same as URL as we redirect to dex + login_page_body = r.text + + # Search the page for the login form so we know where to post the credentials + action_url = None + DexLoginFormParser().feed(login_page_body) + assert action_url is not None, login_page_body + + # Do the actual login + r = requests.post( + action_url, + data={"login": "admin@example.com", "password": "password"}, + ) + r.raise_for_status() + # This should have redirected to the DiracX page that shows the login is complete + assert "Please close the window" in r.text + + +async def test_login(monkeypatch, capfd, cli_env): + poll_attempts = 0 + + def fake_sleep(*args, **kwargs): + nonlocal poll_attempts + + # Keep track of the number of times this is called + poll_attempts += 1 + + # After polling 5 times, do the actual login + if poll_attempts == 5: + # The login URL should have been printed to stdout + captured = capfd.readouterr() + match = re.search(rf"{cli_env['DIRACX_URL']}[^\n]+", captured.out) + assert match, captured + + do_device_flow_with_dex(match.group()) + + # Ensure we don't poll forever + assert poll_attempts <= 100 + + # Reduce the sleep duration to zero to speed up the test + return unpatched_sleep(0) + + # We monkeypatch asyncio.sleep to provide a hook to run the actions that + # would normally be done by a user. This includes capturing the login URL + # and doing the actual device flow with dex. + unpatched_sleep = asyncio.sleep + + expected_credentials_path = Path( + cli_env["HOME"], ".cache", "diracx", "credentials.json" + ) + + # Ensure the credentials file does not exist before logging in + assert not expected_credentials_path.exists() + + # Run the login command + with monkeypatch.context() as m: + m.setattr("asyncio.sleep", fake_sleep) + await cli.login(vo="diracAdmin", group=None, property=None) + captured = capfd.readouterr() + assert "Login successful!" in captured.out + assert captured.err == "" + + # Ensure the credentials file exists after logging in + assert expected_credentials_path.exists() + + # Return the credentials so this test can also be used by the + # "with_cli_login" fixture + return expected_credentials_path.read_text() diff --git a/tests/conftest.py b/tests/conftest.py index b3a74318..9b4d224c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,12 +8,15 @@ from uuid import uuid4 import pytest +import requests +import yaml from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from fastapi.testclient import TestClient from git import Repo from diracx.core.config import Config, ConfigSource +from diracx.core.preferences import get_diracx_preferences from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.routers import create_app_inner from diracx.routers.auth import AuthSettings, create_token @@ -82,7 +85,9 @@ def with_app(test_auth_settings, with_config_repo): """ app = create_app_inner( enabled_systems={".well-known", "auth", "config", "jobs"}, - all_service_settings=[test_auth_settings], + all_service_settings=[ + test_auth_settings, + ], database_urls={ "JobDB": "sqlite+aiosqlite:///:memory:", "JobLoggingDB": "sqlite+aiosqlite:///:memory:", @@ -222,13 +227,23 @@ def admin_user_client(test_client, test_auth_settings): @pytest.fixture(scope="session") -def demo_kubectl_env(request): - """Get the dictionary of environment variables for kubectl to control the demo""" +def demo_dir(request) -> Path: demo_dir = request.config.getoption("--demo-dir") if demo_dir is None: pytest.skip("Requires a running instance of the DiracX demo") demo_dir = (demo_dir / ".demo").resolve() + yield demo_dir + + +@pytest.fixture(scope="session") +def demo_urls(demo_dir): + helm_values = yaml.safe_load((demo_dir / "values.yaml").read_text()) + yield helm_values["developer"]["urls"] + +@pytest.fixture(scope="session") +def demo_kubectl_env(demo_dir): + """Get the dictionary of environment variables for kubectl to control the demo""" kube_conf = demo_dir / "kube.conf" if not kube_conf.exists(): raise RuntimeError(f"Could not find {kube_conf}, is the demo running?") @@ -246,3 +261,40 @@ def demo_kubectl_env(request): assert "diracx" in pods_result yield env + + +@pytest.fixture +def cli_env(monkeypatch, tmp_path, demo_urls): + """Set up the environment for the CLI""" + diracx_url = demo_urls["diracx"] + + # Ensure the demo is working + r = requests.get(f"{diracx_url}/openapi.json") + r.raise_for_status() + assert r.json()["info"]["title"] == "Dirac" + + env = { + "DIRACX_URL": diracx_url, + "HOME": str(tmp_path), + } + for key, value in env.items(): + monkeypatch.setenv(key, value) + yield env + + # The DiracX preferences are cached however when testing this cache is invalid + get_diracx_preferences.cache_clear() + + +@pytest.fixture +async def with_cli_login(monkeypatch, capfd, cli_env, tmp_path): + from .cli.test_login import test_login + + try: + credentials = await test_login(monkeypatch, capfd, cli_env) + except Exception: + pytest.skip("Login failed, fix test_login to re-enable this test") + + credentials_path = tmp_path / "credentials.json" + credentials_path.write_text(credentials) + monkeypatch.setenv("DIRACX_CREDENTIALS_PATH", str(credentials_path)) + yield From 7397619dc8b10945000cbcbb73a0b745bd3f3f28 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Wed, 27 Sep 2023 11:58:03 +0200 Subject: [PATCH 2/7] Add diracx.core.s3 with utility functions --- environment.yml | 4 + src/diracx/core/models.py | 4 + src/diracx/core/s3.py | 75 +++++++++++++++++++ tests/core/test_s3.py | 154 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 237 insertions(+) create mode 100644 src/diracx/core/s3.py create mode 100644 tests/core/test_s3.py diff --git a/environment.yml b/environment.yml index d5e11411..a10cfad2 100644 --- a/environment.yml +++ b/environment.yml @@ -50,5 +50,9 @@ dependencies: - types-PyYAML - types-requests - uvicorn + - moto + - mypy-boto3-s3 + - botocore + - boto3-stubs # - pip: # - git+https://github.com/DIRACGrid/DIRAC.git@integration diff --git a/src/diracx/core/models.py b/src/diracx/core/models.py index 330ed1f2..930b3d5f 100644 --- a/src/diracx/core/models.py +++ b/src/diracx/core/models.py @@ -112,3 +112,7 @@ class UserInfo(BaseModel): preferred_username: str dirac_group: str vo: str + + +class ChecksumAlgorithm(StrEnum): + SHA256 = "sha256" diff --git a/src/diracx/core/s3.py b/src/diracx/core/s3.py new file mode 100644 index 00000000..34cf7990 --- /dev/null +++ b/src/diracx/core/s3.py @@ -0,0 +1,75 @@ +"""Utilities for interacting with S3-compatible storage.""" +from __future__ import annotations + +__all__ = ("s3_bucket_exists", "s3_object_exists", "generate_presigned_upload") + +import base64 +from typing import TYPE_CHECKING, TypedDict, cast + +from botocore.errorfactory import ClientError + +from .models import ChecksumAlgorithm + +if TYPE_CHECKING: + from mypy_boto3_s3.client import S3Client + + +class S3PresignedPostInfo(TypedDict): + url: str + fields: dict[str, str] + + +def s3_bucket_exists(s3_client: S3Client, bucket_name: str) -> bool: + """Check if a bucket exists in S3.""" + return _s3_exists(s3_client.head_bucket, Bucket=bucket_name) + + +def s3_object_exists(s3_client: S3Client, bucket_name: str, key: str) -> bool: + """Check if an object exists in an S3 bucket.""" + return _s3_exists(s3_client.head_object, Bucket=bucket_name, Key=key) + + +def _s3_exists(method, **kwargs: str) -> bool: + try: + method(**kwargs) + except ClientError as e: + if e.response["Error"]["Code"] != "404": + raise + return False + else: + return True + + +def generate_presigned_upload( + s3_client: S3Client, + bucket_name: str, + key: str, + checksum_algorithm: ChecksumAlgorithm, + checksum: str, + size: int, + validity_seconds: int, +) -> S3PresignedPostInfo: + """Generate a presigned URL and fields for uploading a file to S3 + + The signature is restricted to only accept data with the given checksum and size. + """ + fields = { + "x-amz-checksum-algorithm": checksum_algorithm, + f"x-amz-checksum-{checksum_algorithm}": b16_to_b64(checksum), + } + conditions = [["content-length-range", size, size]] + [ + {k: v} for k, v in fields.items() + ] + result = s3_client.generate_presigned_post( + Bucket=bucket_name, + Key=key, + Fields=fields, + Conditions=conditions, + ExpiresIn=validity_seconds, + ) + return cast(S3PresignedPostInfo, result) + + +def b16_to_b64(hex_string: str) -> str: + """Convert hexadecimal encoded data to base64 encoded data""" + return base64.b64encode(base64.b16decode(hex_string.upper())).decode() diff --git a/tests/core/test_s3.py b/tests/core/test_s3.py new file mode 100644 index 00000000..c6acb2a3 --- /dev/null +++ b/tests/core/test_s3.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +import base64 +import hashlib +import secrets + +import botocore.exceptions +import pytest +import requests +from moto import mock_s3 + +from diracx.core.s3 import ( + b16_to_b64, + generate_presigned_upload, + s3_bucket_exists, + s3_object_exists, +) + +BUCKET_NAME = "test_bucket" +OTHER_BUCKET_NAME = "other_bucket" +MISSING_BUCKET_NAME = "missing_bucket" +INVALID_BUCKET_NAME = ".." + + +def _random_file(size_bytes: int): + file_content = secrets.token_bytes(size_bytes) + checksum = hashlib.sha256(file_content).hexdigest() + return file_content, checksum + + +def test_b16_to_b64_hardcoded(): + assert b16_to_b64("25") == "JQ==", "%" + # Make sure we're using the URL-safe variant of base64 + assert b16_to_b64("355b3e51473f") == "NVs+UUc/", "5[>QG?" + + +def test_b16_to_b64_random(): + data = secrets.token_bytes() + input_hex = data.hex() + expected = base64.b64encode(data).decode() + actual = b16_to_b64(input_hex) + assert actual == expected, data.hex() + + +@pytest.fixture(scope="function") +def moto_s3(): + """Very basic moto-based S3 backend. + + This is a fixture that can be used to test S3 interactions using moto. + Note that this is not a complete S3 backend, in particular authentication + and validation of requests is not implemented. + """ + with mock_s3(): + client = botocore.session.get_session().create_client("s3") + client.create_bucket(Bucket=BUCKET_NAME) + client.create_bucket(Bucket=OTHER_BUCKET_NAME) + yield client + + +def test_s3_bucket_exists(moto_s3): + assert s3_bucket_exists(moto_s3, BUCKET_NAME) + assert not s3_bucket_exists(moto_s3, MISSING_BUCKET_NAME) + + +def test_s3_object_exists(moto_s3): + with pytest.raises(botocore.exceptions.ClientError): + s3_object_exists(moto_s3, MISSING_BUCKET_NAME, "key") + + assert not s3_object_exists(moto_s3, BUCKET_NAME, "key") + moto_s3.put_object(Bucket=BUCKET_NAME, Key="key", Body=b"hello") + assert s3_object_exists(moto_s3, BUCKET_NAME, "key") + + +def test_presigned_upload_moto(moto_s3): + """Test the presigned upload with moto + + This doesn't actually test the signature, see test_presigned_upload_minio + """ + file_content, checksum = _random_file(128) + key = f"{checksum}.dat" + upload_info = generate_presigned_upload( + moto_s3, BUCKET_NAME, key, "sha256", checksum, len(file_content), 60 + ) + + # Upload the file + r = requests.post( + upload_info["url"], data=upload_info["fields"], files={"file": file_content} + ) + assert r.status_code == 204, r.text + + # Make sure the object is actually there + obj = moto_s3.get_object(Bucket=BUCKET_NAME, Key=key) + assert obj["Body"].read() == file_content + + +@pytest.fixture(scope="session") +def minio_client(demo_urls): + """Create a S3 client that uses minio from the demo as backend""" + yield botocore.session.get_session().create_client( + "s3", + endpoint_url=demo_urls["minio"], + aws_access_key_id="console", + aws_secret_access_key="console123", + ) + + +@pytest.fixture(scope="session") +def test_bucket(minio_client): + """Create a test bucket that is cleaned up after the test session""" + bucket_name = f"dirac-test-{secrets.token_hex(8)}" + minio_client.create_bucket(Bucket=bucket_name) + yield bucket_name + for obj in minio_client.list_objects(Bucket=bucket_name)["Contents"]: + minio_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) + minio_client.delete_bucket(Bucket=bucket_name) + + +@pytest.mark.parametrize( + "content,checksum,size,expected_error", + [ + # Make sure a valid request works + [*_random_file(128), 128, None], + # Check with invalid sizes + [*_random_file(128), 127, "exceeds the maximum"], + [*_random_file(128), 129, "smaller than the minimum"], + # Check with invalid checksum + [_random_file(128)[0], _random_file(128)[1], 128, "ContentChecksumMismatch"], + ], +) +def test_presigned_upload_minio( + minio_client, test_bucket, content, checksum, size, expected_error +): + """Test the presigned upload with Minio + + This is a more complete test that checks that the presigned upload works + and is properly validated by Minio. This is not possible with moto as it + doesn't actually validate the signature. + """ + key = f"{checksum}.dat" + # Prepare the signed URL + upload_info = generate_presigned_upload( + minio_client, test_bucket, key, "sha256", checksum, size, 60 + ) + # Ensure the URL doesn't work + r = requests.post( + upload_info["url"], data=upload_info["fields"], files={"file": content} + ) + if expected_error is None: + assert r.status_code == 204, r.text + assert s3_object_exists(minio_client, test_bucket, key) + else: + assert r.status_code == 400, r.text + assert expected_error in r.text + assert not s3_object_exists(minio_client, test_bucket, key) From cdbbc4466d17d06711e7706c712213fef70c59e5 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Mon, 2 Oct 2023 15:44:39 +0200 Subject: [PATCH 3/7] Update SandboxMetadataDB interface --- src/diracx/core/models.py | 11 +++ src/diracx/db/sql/sandbox_metadata/db.py | 108 ++++++++++++----------- tests/db/test_sandboxMetadataDB.py | 84 ------------------ tests/db/test_sandbox_metadata.py | 92 +++++++++++++++++++ 4 files changed, 161 insertions(+), 134 deletions(-) delete mode 100644 tests/db/test_sandboxMetadataDB.py create mode 100644 tests/db/test_sandbox_metadata.py diff --git a/src/diracx/core/models.py b/src/diracx/core/models.py index 930b3d5f..21fc53a9 100644 --- a/src/diracx/core/models.py +++ b/src/diracx/core/models.py @@ -116,3 +116,14 @@ class UserInfo(BaseModel): class ChecksumAlgorithm(StrEnum): SHA256 = "sha256" + + +class SandboxFormat(StrEnum): + TAR_BZ2 = "tar.bz2" + + +class SandboxInfo(BaseModel): + checksum_algorithm: ChecksumAlgorithm + checksum: str = Field(pattern=r"^[0-f]{64}$") + size: int = Field(ge=1) + format: SandboxFormat diff --git a/src/diracx/db/sql/sandbox_metadata/db.py b/src/diracx/db/sql/sandbox_metadata/db.py index 6900f58a..a95dd93f 100644 --- a/src/diracx/db/sql/sandbox_metadata/db.py +++ b/src/diracx/db/sql/sandbox_metadata/db.py @@ -1,80 +1,88 @@ -""" SandboxMetadataDB frontend -""" - from __future__ import annotations -import datetime - import sqlalchemy -from diracx.db.sql.utils import BaseSQLDB +from diracx.core.models import SandboxInfo, UserInfo +from diracx.db.sql.utils import BaseSQLDB, utcnow from .schema import Base as SandboxMetadataDBBase from .schema import sb_Owners, sb_SandBoxes +# In legacy DIRAC the SEName column was used to support multiple different +# storage backends. This is no longer the case, so we hardcode the value to +# S3 to represent the new DiracX system. +SE_NAME = "ProductionSandboxSE" +PFN_PREFIX = "/S3/" + class SandboxMetadataDB(BaseSQLDB): metadata = SandboxMetadataDBBase.metadata - async def _get_put_owner(self, owner: str, owner_group: str) -> int: - """adds a new owner/ownerGroup pairs, while returning their ID if already existing - - Args: - owner (str): user name - owner_group (str): group of the owner - """ + async def upsert_owner(self, user: UserInfo) -> int: + """Get the id of the owner from the database""" + # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 stmt = sqlalchemy.select(sb_Owners.OwnerID).where( - sb_Owners.Owner == owner, sb_Owners.OwnerGroup == owner_group + sb_Owners.Owner == user.preferred_username, + sb_Owners.OwnerGroup == user.dirac_group, + # TODO: Add VO ) result = await self.conn.execute(stmt) if owner_id := result.scalar_one_or_none(): return owner_id - stmt = sqlalchemy.insert(sb_Owners).values(Owner=owner, OwnerGroup=owner_group) + stmt = sqlalchemy.insert(sb_Owners).values( + Owner=user.preferred_username, + OwnerGroup=user.dirac_group, + ) result = await self.conn.execute(stmt) return result.lastrowid - async def insert( - self, owner: str, owner_group: str, sb_SE: str, se_PFN: str, size: int = 0 - ) -> tuple[int, bool]: - """inserts a new sandbox in SandboxMetadataDB - this is "equivalent" of DIRAC registerAndGetSandbox + @staticmethod + def get_pfn(bucket_name: str, user: UserInfo, sandbox_info: SandboxInfo) -> str: + """Get the sandbox's user namespaced and content addressed PFN""" + parts = [ + "S3", + bucket_name, + user.vo, + user.dirac_group, + user.preferred_username, + f"{sandbox_info.checksum_algorithm}:{sandbox_info.checksum}.{sandbox_info.format}", + ] + return "/" + "/".join(parts) - Args: - owner (str): user name_ - owner_group (str): groupd of the owner - sb_SE (str): _description_ - sb_PFN (str): _description_ - size (int, optional): _description_. Defaults to 0. - """ - owner_id = await self._get_put_owner(owner, owner_group) + async def insert_sandbox(self, user: UserInfo, pfn: str, size: int) -> None: + """Add a new sandbox in SandboxMetadataDB""" + # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 + owner_id = await self.upsert_owner(user) stmt = sqlalchemy.insert(sb_SandBoxes).values( - OwnerId=owner_id, SEName=sb_SE, SEPFN=se_PFN, Bytes=size + OwnerId=owner_id, + SEName=SE_NAME, + SEPFN=pfn, + Bytes=size, + RegistrationTime=utcnow(), + LastAccessTime=utcnow(), ) try: result = await self.conn.execute(stmt) - return result.lastrowid except sqlalchemy.exc.IntegrityError: - # it is a duplicate, try to retrieve SBiD - stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.SBId).where( # type: ignore[no-redef] - sb_SandBoxes.SEPFN == se_PFN, - sb_SandBoxes.SEName == sb_SE, - sb_SandBoxes.OwnerId == owner_id, - ) - result = await self.conn.execute(stmt) - sb_ID = result.scalar_one() - stmt: sqlalchemy.Executable = ( # type: ignore[no-redef] - sqlalchemy.update(sb_SandBoxes) - .where(sb_SandBoxes.SBId == sb_ID) - .values(LastAccessTime=datetime.datetime.utcnow()) - ) - await self.conn.execute(stmt) - return sb_ID + await self.update_sandbox_last_access_time(pfn) + else: + assert result.rowcount == 1 - async def delete(self, sandbox_ids: list[int]) -> bool: - stmt: sqlalchemy.Executable = sqlalchemy.delete(sb_SandBoxes).where( - sb_SandBoxes.SBId.in_(sandbox_ids) + async def update_sandbox_last_access_time(self, pfn: str) -> None: + stmt = ( + sqlalchemy.update(sb_SandBoxes) + .where(sb_SandBoxes.SEName == SE_NAME, sb_SandBoxes.SEPFN == pfn) + .values(LastAccessTime=utcnow()) ) - await self.conn.execute(stmt) + result = await self.conn.execute(stmt) + assert result.rowcount == 1 - return True + async def sandbox_is_assigned(self, pfn: str) -> bool: + """Checks if a sandbox exists and has been assigned.""" + stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.Assigned).where( + sb_SandBoxes.SEName == SE_NAME, sb_SandBoxes.SEPFN == pfn + ) + result = await self.conn.execute(stmt) + is_assigned = result.scalar_one() + return is_assigned diff --git a/tests/db/test_sandboxMetadataDB.py b/tests/db/test_sandboxMetadataDB.py deleted file mode 100644 index ffcf91ec..00000000 --- a/tests/db/test_sandboxMetadataDB.py +++ /dev/null @@ -1,84 +0,0 @@ -from __future__ import annotations - -import pytest -import sqlalchemy - -from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB - - -@pytest.fixture -async def sandbox_metadata_db(tmp_path): - sandbox_metadata_db = SandboxMetadataDB("sqlite+aiosqlite:///:memory:") - async with sandbox_metadata_db.engine_context(): - async with sandbox_metadata_db.engine.begin() as conn: - await conn.run_sync(sandbox_metadata_db.metadata.create_all) - yield sandbox_metadata_db - - -async def test__get_put_owner(sandbox_metadata_db): - async with sandbox_metadata_db as sandbox_metadata_db: - result = await sandbox_metadata_db._get_put_owner("owner", "owner_group") - assert result == 1 - result = await sandbox_metadata_db._get_put_owner("owner_2", "owner_group") - assert result == 2 - result = await sandbox_metadata_db._get_put_owner("owner", "owner_group") - assert result == 1 - result = await sandbox_metadata_db._get_put_owner("owner_2", "owner_group") - assert result == 2 - result = await sandbox_metadata_db._get_put_owner("owner_2", "owner_group_2") - assert result == 3 - - -async def test_insert(sandbox_metadata_db): - async with sandbox_metadata_db as sandbox_metadata_db: - result = await sandbox_metadata_db.insert( - "owner", - "owner_group", - "sbSE", - "sbPFN", - 123, - ) - assert result == 1 - - result = await sandbox_metadata_db.insert( - "owner", - "owner_group", - "sbSE", - "sbPFN", - 123, - ) - assert result == 1 - - result = await sandbox_metadata_db.insert( - "owner_2", - "owner_group", - "sbSE", - "sbPFN_2", - 123, - ) - assert result == 2 - - # This would be incorrect - with pytest.raises(sqlalchemy.exc.NoResultFound): - await sandbox_metadata_db.insert( - "owner", - "owner_group", - "sbSE", - "sbPFN_2", - 123, - ) - - -async def test_delete(sandbox_metadata_db): - async with sandbox_metadata_db as sandbox_metadata_db: - result = await sandbox_metadata_db.insert( - "owner", - "owner_group", - "sbSE", - "sbPFN", - 123, - ) - assert result == 1 - - result = await sandbox_metadata_db.delete([1]) - assert result diff --git a/tests/db/test_sandbox_metadata.py b/tests/db/test_sandbox_metadata.py new file mode 100644 index 00000000..6757af13 --- /dev/null +++ b/tests/db/test_sandbox_metadata.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import asyncio +import secrets +from datetime import datetime + +import pytest +import sqlalchemy + +from diracx.core.models import SandboxInfo, UserInfo +from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB +from diracx.db.sql.sandbox_metadata.schema import sb_SandBoxes + + +@pytest.fixture +async def sandbox_metadata_db(tmp_path): + sandbox_metadata_db = SandboxMetadataDB("sqlite+aiosqlite:///:memory:") + async with sandbox_metadata_db.engine_context(): + async with sandbox_metadata_db.engine.begin() as conn: + await conn.run_sync(sandbox_metadata_db.metadata.create_all) + yield sandbox_metadata_db + + +def test_get_pfn(sandbox_metadata_db: SandboxMetadataDB): + user_info = UserInfo( + sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo" + ) + sandbox_info = SandboxInfo( + checksum="checksum", + checksum_algorithm="sha256", + format="tar.bz2", + size=100, + ) + pfn = sandbox_metadata_db.get_pfn("bucket1", user_info, sandbox_info) + assert pfn == "/S3/bucket1/vo/group1/user1/sha256:checksum.tar.bz2" + + +async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB): + user_info = UserInfo( + sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo" + ) + pfn1 = secrets.token_hex() + + # Make sure the sandbox doesn't already exist + db_contents = await _dump_db(sandbox_metadata_db) + assert pfn1 not in db_contents + async with sandbox_metadata_db: + with pytest.raises(sqlalchemy.exc.NoResultFound): + await sandbox_metadata_db.sandbox_is_assigned(pfn1) + + # Insert the sandbox + async with sandbox_metadata_db: + await sandbox_metadata_db.insert_sandbox(user_info, pfn1, 100) + db_contents = await _dump_db(sandbox_metadata_db) + owner_id1, last_access_time1 = db_contents[pfn1] + + # Inserting again should update the last access time + await asyncio.sleep(1) # The timestamp only has second precision + async with sandbox_metadata_db: + await sandbox_metadata_db.insert_sandbox(user_info, pfn1, 100) + db_contents = await _dump_db(sandbox_metadata_db) + owner_id2, last_access_time2 = db_contents[pfn1] + assert owner_id1 == owner_id2 + assert last_access_time2 > last_access_time1 + + # The sandbox still hasn't been assigned + async with sandbox_metadata_db: + assert not await sandbox_metadata_db.sandbox_is_assigned(pfn1) + + # Inserting again should update the last access time + await asyncio.sleep(1) # The timestamp only has second precision + last_access_time3 = (await _dump_db(sandbox_metadata_db))[pfn1][1] + assert last_access_time2 == last_access_time3 + async with sandbox_metadata_db: + await sandbox_metadata_db.update_sandbox_last_access_time(pfn1) + last_access_time4 = (await _dump_db(sandbox_metadata_db))[pfn1][1] + assert last_access_time2 < last_access_time4 + + +async def _dump_db( + sandbox_metadata_db: SandboxMetadataDB, +) -> dict[str, tuple[int, datetime]]: + """Dump the contents of the sandbox metadata database + + Returns a dict[pfn: str, (owner_id: int, last_access_time: datetime)] + """ + async with sandbox_metadata_db: + stmt = sqlalchemy.select( + sb_SandBoxes.SEPFN, sb_SandBoxes.OwnerId, sb_SandBoxes.LastAccessTime + ) + res = await sandbox_metadata_db.conn.execute(stmt) + return {row.SEPFN: (row.OwnerId, row.LastAccessTime) for row in res} From a3b0ebbfe4b749bfb8c68e9736308905ca9c851a Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Wed, 27 Sep 2023 11:59:27 +0200 Subject: [PATCH 4/7] Add routes for uploading/downloading sandboxes --- src/diracx/routers/dependencies.py | 4 + src/diracx/routers/job_manager/__init__.py | 2 + src/diracx/routers/job_manager/sandboxes.py | 176 ++++++++++++++++++++ tests/conftest.py | 15 +- tests/routers/jobs/__init__.py | 0 tests/routers/jobs/test_sandboxes.py | 63 +++++++ 6 files changed, 259 insertions(+), 1 deletion(-) create mode 100644 src/diracx/routers/job_manager/sandboxes.py create mode 100644 tests/routers/jobs/__init__.py create mode 100644 tests/routers/jobs/test_sandboxes.py diff --git a/src/diracx/routers/dependencies.py b/src/diracx/routers/dependencies.py index 3ab40361..deb55167 100644 --- a/src/diracx/routers/dependencies.py +++ b/src/diracx/routers/dependencies.py @@ -19,6 +19,7 @@ from diracx.db.sql import AuthDB as _AuthDB from diracx.db.sql import JobDB as _JobDB from diracx.db.sql import JobLoggingDB as _JobLoggingDB +from diracx.db.sql import SandboxMetadataDB as _SandboxMetadataDB T = TypeVar("T") @@ -32,6 +33,9 @@ def add_settings_annotation(cls: T) -> T: AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)] JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)] JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)] +SandboxMetadataDB = Annotated[ + _SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction) +] # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] diff --git a/src/diracx/routers/job_manager/__init__.py b/src/diracx/routers/job_manager/__init__.py index df36e15b..39592c52 100644 --- a/src/diracx/routers/job_manager/__init__.py +++ b/src/diracx/routers/job_manager/__init__.py @@ -29,12 +29,14 @@ from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token from ..dependencies import JobDB, JobLoggingDB from ..fastapi_classes import DiracxRouter +from .sandboxes import router as sandboxes_router MAX_PARAMETRIC_JOBS = 20 logger = logging.getLogger(__name__) router = DiracxRouter(dependencies=[has_properties(NORMAL_USER | JOB_ADMINISTRATOR)]) +router.include_router(sandboxes_router) class JobSummaryParams(BaseModel): diff --git a/src/diracx/routers/job_manager/sandboxes.py b/src/diracx/routers/job_manager/sandboxes.py new file mode 100644 index 00000000..b47af93f --- /dev/null +++ b/src/diracx/routers/job_manager/sandboxes.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import TYPE_CHECKING, Annotated + +import botocore.session +from botocore.config import Config +from botocore.errorfactory import ClientError +from fastapi import Depends, HTTPException, Query +from pydantic import BaseModel, PrivateAttr +from sqlalchemy.exc import NoResultFound + +from diracx.core.models import ( + SandboxInfo, +) +from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER +from diracx.core.s3 import ( + generate_presigned_upload, + s3_bucket_exists, + s3_object_exists, +) +from diracx.core.settings import ServiceSettingsBase + +if TYPE_CHECKING: + from mypy_boto3_s3.client import S3Client + +from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token +from ..dependencies import SandboxMetadataDB, add_settings_annotation +from ..fastapi_classes import DiracxRouter + +MAX_SANDBOX_SIZE_BYTES = 100 * 1024 * 1024 +router = DiracxRouter(dependencies=[has_properties(NORMAL_USER | JOB_ADMINISTRATOR)]) + + +@add_settings_annotation +class SandboxStoreSettings(ServiceSettingsBase, env_prefix="DIRACX_SANDBOX_STORE_"): + """Settings for the sandbox store.""" + + bucket_name: str + s3_client_kwargs: dict[str, str] + auto_create_bucket: bool = False + url_validity_seconds: int = 5 * 60 + _client: S3Client = PrivateAttr(None) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # TODO: Use async + session = botocore.session.get_session() + self._client = session.create_client( + "s3", + # endpoint_url=s3_cred["endpoint"], + # aws_access_key_id=s3_cred["access_key_id"], + # aws_secret_access_key=s3_cred["secret_access_key"], + **self.s3_client_kwargs, + config=Config(signature_version="v4"), + ) + if not s3_bucket_exists(self._client, self.bucket_name): + if not self.auto_create_bucket: + raise ValueError( + f"Bucket {self.bucket_name} does not exist and auto_create_bucket is disabled" + ) + try: + self._client.create_bucket(Bucket=self.bucket_name) + except ClientError as e: + raise ValueError(f"Failed to create bucket {self.bucket_name}") from e + + @property + def s3_client(self) -> S3Client: + return self._client + + +class SandboxUploadResponse(BaseModel): + pfn: str + url: str | None = None + fields: dict[str, str] = {} + + +@router.post("/sandbox") +async def initiate_sandbox_upload( + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + sandbox_info: SandboxInfo, + sandbox_metadata_db: SandboxMetadataDB, + settings: SandboxStoreSettings, +) -> SandboxUploadResponse: + """Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + """ + if sandbox_info.size > MAX_SANDBOX_SIZE_BYTES: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Sandbox too large. Max size is {MAX_SANDBOX_SIZE_BYTES} bytes", + ) + + pfn = sandbox_metadata_db.get_pfn(settings.bucket_name, user_info, sandbox_info) + + try: + exists_and_assigned = await sandbox_metadata_db.sandbox_is_assigned(pfn) + except NoResultFound: + # The sandbox doesn't exist in the database + pass + else: + # As sandboxes are registered in the DB before uploading to the storage + # backend we can't on their existence in the database to determine if + # they have been uploaded. Instead we check if the sandbox has been + # assigned to a job. If it has then we know it has been uploaded and we + # can avoid communicating with the storage backend. + if exists_and_assigned or s3_object_exists( + settings.s3_client, settings.bucket_name, pfn_to_key(pfn) + ): + await sandbox_metadata_db.update_sandbox_last_access_time(pfn) + return SandboxUploadResponse(pfn=pfn) + + upload_info = generate_presigned_upload( + settings.s3_client, + settings.bucket_name, + pfn_to_key(pfn), + sandbox_info.checksum_algorithm, + sandbox_info.checksum, + sandbox_info.size, + settings.url_validity_seconds, + ) + await sandbox_metadata_db.insert_sandbox(user_info, pfn, sandbox_info.size) + + return SandboxUploadResponse(**upload_info, pfn=pfn) + + +class SandboxDownloadResponse(BaseModel): + url: str + expires_in: int + + +def pfn_to_key(pfn: str) -> str: + """Convert a PFN to a key for S3 + + This removes the leading "/S3/" from the PFN. + """ + return "/".join(pfn.split("/")[3:]) + + +SANDBOX_PFN_REGEX = ( + # Starts with /S3/ + r"^/S3/[a-z0-9\.\-]{3,63}" + # Followed ////:. + r"(?:/[^/]+){3}/[a-z0-9]{3,10}:[0-9a-f]{64}\.[a-z0-9\.]+$" +) + + +@router.get("/sandbox") +async def get_sandbox_file( + pfn: Annotated[str, Query(max_length=256, pattern=SANDBOX_PFN_REGEX)], + settings: SandboxStoreSettings, +) -> SandboxDownloadResponse: + """Get a presigned URL to download a sandbox file + + This route cannot use a redirect response most clients will also send the + authorization header when following a redirect. This is not desirable as + it would leak the authorization token to the storage backend. Additionally, + most storage backends return an error when they receive an authorization + header for a presigned URL. + """ + # TODO: Prevent people from downloading other people's sandboxes? + # TODO: Support by name and by job id? + presigned_url = settings.s3_client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": settings.bucket_name, "Key": pfn_to_key(pfn)}, + ExpiresIn=settings.url_validity_seconds, + ) + return SandboxDownloadResponse( + url=presigned_url, expires_in=settings.url_validity_seconds + ) diff --git a/tests/conftest.py b/tests/conftest.py index 9b4d224c..60ec907f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,12 +14,14 @@ from cryptography.hazmat.primitives.asymmetric import rsa from fastapi.testclient import TestClient from git import Repo +from moto import mock_s3 from diracx.core.config import Config, ConfigSource from diracx.core.preferences import get_diracx_preferences from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.routers import create_app_inner from diracx.routers.auth import AuthSettings, create_token +from diracx.routers.job_manager.sandboxes import SandboxStoreSettings # to get a string like this run: # openssl rand -hex 32 @@ -78,8 +80,18 @@ def test_auth_settings() -> AuthSettings: ) +@pytest.fixture(scope="function") +def test_sandbox_settings() -> SandboxStoreSettings: + with mock_s3(): + yield SandboxStoreSettings( + bucket_name="sandboxes", + s3_client_kwargs={}, + auto_create_bucket=True, + ) + + @pytest.fixture -def with_app(test_auth_settings, with_config_repo): +def with_app(test_auth_settings, test_sandbox_settings, with_config_repo): """ Create a DiracxApp with hard coded configuration for test """ @@ -87,6 +99,7 @@ def with_app(test_auth_settings, with_config_repo): enabled_systems={".well-known", "auth", "config", "jobs"}, all_service_settings=[ test_auth_settings, + test_sandbox_settings, ], database_urls={ "JobDB": "sqlite+aiosqlite:///:memory:", diff --git a/tests/routers/jobs/__init__.py b/tests/routers/jobs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/routers/jobs/test_sandboxes.py b/tests/routers/jobs/test_sandboxes.py new file mode 100644 index 00000000..a683c2b7 --- /dev/null +++ b/tests/routers/jobs/test_sandboxes.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import hashlib +import secrets +from io import BytesIO + +import requests +from fastapi.testclient import TestClient + + +def test_upload_then_download(normal_user_client: TestClient): + data = secrets.token_bytes(512) + checksum = hashlib.sha256(data).hexdigest() + + # Initiate the upload + r = normal_user_client.post( + "/jobs/sandbox", + json={ + "checksum_algorithm": "sha256", + "checksum": checksum, + "size": len(data), + "format": "tar.bz2", + }, + ) + assert r.status_code == 200, r.text + upload_info = r.json() + assert upload_info["url"] + sandbox_pfn = upload_info["pfn"] + assert sandbox_pfn.startswith("/S3/") + + # Actually upload the file + files = {"file": ("file", BytesIO(data))} + r = requests.post(upload_info["url"], data=upload_info["fields"], files=files) + assert r.status_code == 204, r.text + + # Make sure we can download it and get the same data back + r = normal_user_client.get("/jobs/sandbox", params={"pfn": sandbox_pfn}) + assert r.status_code == 200, r.text + download_info = r.json() + assert download_info["expires_in"] > 5 + r = requests.get(download_info["url"]) + assert r.status_code == 200, r.text + assert r.content == data + + +def test_upload_oversized(normal_user_client: TestClient): + data = secrets.token_bytes(512) + checksum = hashlib.sha256(data).hexdigest() + + # Initiate the upload + r = normal_user_client.post( + "/jobs/sandbox", + json={ + "checksum_algorithm": "sha256", + "checksum": checksum, + # We can forge the size here to be larger than the actual data as + # we should get an error and never actually upload the data + "size": 1024 * 1024 * 1024, + "format": "tar.bz2", + }, + ) + assert r.status_code == 400, r.text + assert "Sandbox too large" in r.json()["detail"], r.text From ad27b5188b10128d29462a4c97a166511c9aaeea Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 26 Sep 2023 17:54:01 +0200 Subject: [PATCH 5/7] Regenerate client --- .../client/aio/operations/_operations.py | 194 ++++++++++++++ src/diracx/client/models/__init__.py | 18 +- src/diracx/client/models/_enums.py | 36 ++- src/diracx/client/models/_models.py | 135 +++++++++- src/diracx/client/operations/_operations.py | 241 ++++++++++++++++++ 5 files changed, 607 insertions(+), 17 deletions(-) diff --git a/src/diracx/client/aio/operations/_operations.py b/src/diracx/client/aio/operations/_operations.py index 9c3ec627..a1c87d68 100644 --- a/src/diracx/client/aio/operations/_operations.py +++ b/src/diracx/client/aio/operations/_operations.py @@ -37,9 +37,11 @@ build_jobs_delete_bulk_jobs_request, build_jobs_get_job_status_bulk_request, build_jobs_get_job_status_history_bulk_request, + build_jobs_get_sandbox_file_request, build_jobs_get_single_job_request, build_jobs_get_single_job_status_history_request, build_jobs_get_single_job_status_request, + build_jobs_initiate_sandbox_upload_request, build_jobs_kill_bulk_jobs_request, build_jobs_search_request, build_jobs_set_job_status_bulk_request, @@ -819,6 +821,198 @@ def __init__(self, *args, **kwargs) -> None: input_args.pop(0) if input_args else kwargs.pop("deserializer") ) + @distributed_trace_async + async def get_sandbox_file( + self, *, pfn: str, **kwargs: Any + ) -> _models.SandboxDownloadResponse: + """Get Sandbox File. + + Get a presigned URL to download a sandbox file + + This route cannot use a redirect response most clients will also send the + authorization header when following a redirect. This is not desirable as + it would leak the authorization token to the storage backend. Additionally, + most storage backends return an error when they receive an authorization + header for a presigned URL. + + :keyword pfn: Required. + :paramtype pfn: str + :return: SandboxDownloadResponse + :rtype: ~client.models.SandboxDownloadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.SandboxDownloadResponse] = kwargs.pop("cls", None) + + request = build_jobs_get_sandbox_file_request( + pfn=pfn, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("SandboxDownloadResponse", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + + @overload + async def initiate_sandbox_upload( + self, + body: _models.SandboxInfo, + *, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Required. + :type body: ~client.models.SandboxInfo + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def initiate_sandbox_upload( + self, body: IO, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Required. + :type body: IO + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def initiate_sandbox_upload( + self, body: Union[_models.SandboxInfo, IO], **kwargs: Any + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Is either a SandboxInfo type or a IO type. Required. + :type body: ~client.models.SandboxInfo or IO + :keyword content_type: Body Parameter content-type. Known values are: 'application/json'. + Default value is None. + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) + cls: ClsType[_models.SandboxUploadResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SandboxInfo") + + request = build_jobs_initiate_sandbox_upload_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("SandboxUploadResponse", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + @overload async def submit_bulk_jobs( self, body: List[str], *, content_type: str = "application/json", **kwargs: Any diff --git a/src/diracx/client/models/__init__.py b/src/diracx/client/models/__init__.py index 893d7c15..205f5098 100644 --- a/src/diracx/client/models/__init__.py +++ b/src/diracx/client/models/__init__.py @@ -16,6 +16,9 @@ from ._models import JobSummaryParams from ._models import JobSummaryParamsSearchItem from ._models import LimitedJobStatusReturn +from ._models import SandboxDownloadResponse +from ._models import SandboxInfo +from ._models import SandboxUploadResponse from ._models import ScalarSearchSpec from ._models import SetJobStatusReturn from ._models import SortSpec @@ -26,14 +29,16 @@ from ._models import ValidationErrorLocItem from ._models import VectorSearchSpec +from ._enums import ChecksumAlgorithm from ._enums import Enum0 from ._enums import Enum1 +from ._enums import Enum10 +from ._enums import Enum11 from ._enums import Enum2 from ._enums import Enum3 from ._enums import Enum4 -from ._enums import Enum8 -from ._enums import Enum9 from ._enums import JobStatus +from ._enums import SandboxFormat from ._enums import ScalarSearchOperator from ._enums import VectorSearchOperator from ._patch import __all__ as _patch_all @@ -53,6 +58,9 @@ "JobSummaryParams", "JobSummaryParamsSearchItem", "LimitedJobStatusReturn", + "SandboxDownloadResponse", + "SandboxInfo", + "SandboxUploadResponse", "ScalarSearchSpec", "SetJobStatusReturn", "SortSpec", @@ -62,14 +70,16 @@ "ValidationError", "ValidationErrorLocItem", "VectorSearchSpec", + "ChecksumAlgorithm", "Enum0", "Enum1", + "Enum10", + "Enum11", "Enum2", "Enum3", "Enum4", - "Enum8", - "Enum9", "JobStatus", + "SandboxFormat", "ScalarSearchOperator", "VectorSearchOperator", ] diff --git a/src/diracx/client/models/_enums.py b/src/diracx/client/models/_enums.py index ccabc77c..60628b59 100644 --- a/src/diracx/client/models/_enums.py +++ b/src/diracx/client/models/_enums.py @@ -8,6 +8,12 @@ from azure.core import CaseInsensitiveEnumMeta +class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """An enumeration.""" + + SHA256 = "sha256" + + class Enum0(str, Enum, metaclass=CaseInsensitiveEnumMeta): """Enum0.""" @@ -22,6 +28,18 @@ class Enum1(str, Enum, metaclass=CaseInsensitiveEnumMeta): ) +class Enum10(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Enum10.""" + + ASC = "asc" + + +class Enum11(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Enum11.""" + + DSC = "dsc" + + class Enum2(str, Enum, metaclass=CaseInsensitiveEnumMeta): """Enum2.""" @@ -40,18 +58,6 @@ class Enum4(str, Enum, metaclass=CaseInsensitiveEnumMeta): S256 = "S256" -class Enum8(str, Enum, metaclass=CaseInsensitiveEnumMeta): - """Enum8.""" - - ASC = "asc" - - -class Enum9(str, Enum, metaclass=CaseInsensitiveEnumMeta): - """Enum9.""" - - DSC = "dsc" - - class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): """An enumeration.""" @@ -72,6 +78,12 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """An enumeration.""" + + TAR_BZ2 = "tar.bz2" + + class ScalarSearchOperator(str, Enum, metaclass=CaseInsensitiveEnumMeta): """An enumeration.""" diff --git a/src/diracx/client/models/_models.py b/src/diracx/client/models/_models.py index 5578ad22..716f06f7 100644 --- a/src/diracx/client/models/_models.py +++ b/src/diracx/client/models/_models.py @@ -6,7 +6,7 @@ # -------------------------------------------------------------------------- import datetime -from typing import Any, List, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union from .. import _serialization @@ -509,6 +509,139 @@ def __init__( self.application_status = application_status +class SandboxDownloadResponse(_serialization.Model): + """SandboxDownloadResponse. + + All required parameters must be populated in order to send to Azure. + + :ivar url: Url. Required. + :vartype url: str + :ivar expires_in: Expires In. Required. + :vartype expires_in: int + """ + + _validation = { + "url": {"required": True}, + "expires_in": {"required": True}, + } + + _attribute_map = { + "url": {"key": "url", "type": "str"}, + "expires_in": {"key": "expires_in", "type": "int"}, + } + + def __init__(self, *, url: str, expires_in: int, **kwargs: Any) -> None: + """ + :keyword url: Url. Required. + :paramtype url: str + :keyword expires_in: Expires In. Required. + :paramtype expires_in: int + """ + super().__init__(**kwargs) + self.url = url + self.expires_in = expires_in + + +class SandboxInfo(_serialization.Model): + """SandboxInfo. + + All required parameters must be populated in order to send to Azure. + + :ivar checksum_algorithm: An enumeration. Required. "sha256" + :vartype checksum_algorithm: str or ~client.models.ChecksumAlgorithm + :ivar checksum: Checksum. Required. + :vartype checksum: str + :ivar size: Size. Required. + :vartype size: int + :ivar format: An enumeration. Required. "tar.bz2" + :vartype format: str or ~client.models.SandboxFormat + """ + + _validation = { + "checksum_algorithm": {"required": True}, + "checksum": {"required": True, "pattern": r"^[0-f]{64}$"}, + "size": {"required": True, "minimum": 1}, + "format": {"required": True}, + } + + _attribute_map = { + "checksum_algorithm": {"key": "checksum_algorithm", "type": "str"}, + "checksum": {"key": "checksum", "type": "str"}, + "size": {"key": "size", "type": "int"}, + "format": {"key": "format", "type": "str"}, + } + + def __init__( + self, + *, + checksum_algorithm: Union[str, "_models.ChecksumAlgorithm"], + checksum: str, + size: int, + format: Union[str, "_models.SandboxFormat"], + **kwargs: Any + ) -> None: + """ + :keyword checksum_algorithm: An enumeration. Required. "sha256" + :paramtype checksum_algorithm: str or ~client.models.ChecksumAlgorithm + :keyword checksum: Checksum. Required. + :paramtype checksum: str + :keyword size: Size. Required. + :paramtype size: int + :keyword format: An enumeration. Required. "tar.bz2" + :paramtype format: str or ~client.models.SandboxFormat + """ + super().__init__(**kwargs) + self.checksum_algorithm = checksum_algorithm + self.checksum = checksum + self.size = size + self.format = format + + +class SandboxUploadResponse(_serialization.Model): + """SandboxUploadResponse. + + All required parameters must be populated in order to send to Azure. + + :ivar pfn: Pfn. Required. + :vartype pfn: str + :ivar url: Url. + :vartype url: str + :ivar fields: Fields. + :vartype fields: dict[str, str] + """ + + _validation = { + "pfn": {"required": True}, + } + + _attribute_map = { + "pfn": {"key": "pfn", "type": "str"}, + "url": {"key": "url", "type": "str"}, + "fields": {"key": "fields", "type": "{str}"}, + } + + def __init__( + self, + *, + pfn: str, + url: Optional[str] = None, + fields: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pfn: Pfn. Required. + :paramtype pfn: str + :keyword url: Url. + :paramtype url: str + :keyword fields: Fields. + :paramtype fields: dict[str, str] + """ + super().__init__(**kwargs) + self.pfn = pfn + self.url = url + self.fields = fields + + class ScalarSearchSpec(_serialization.Model): """ScalarSearchSpec. diff --git a/src/diracx/client/operations/_operations.py b/src/diracx/client/operations/_operations.py index 72dea1c1..bb2d7807 100644 --- a/src/diracx/client/operations/_operations.py +++ b/src/diracx/client/operations/_operations.py @@ -280,6 +280,55 @@ def build_config_serve_config_request( return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) +def build_jobs_get_sandbox_file_request(*, pfn: str, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/jobs/sandbox" + + # Construct parameters + _params["pfn"] = _SERIALIZER.query( + "pfn", + pfn, + "str", + max_length=256, + pattern=r"^/S3/[a-z0-9\.\-]{3,63}(?:/[^/]+){3}/[a-z0-9]{3,10}:[0-9a-f]{64}\.[a-z0-9\.]+$", + ) + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest( + method="GET", url=_url, params=_params, headers=_headers, **kwargs + ) + + +def build_jobs_initiate_sandbox_upload_request( + **kwargs: Any, +) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/jobs/sandbox" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header( + "content_type", content_type, "str" + ) + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + def build_jobs_submit_bulk_jobs_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) @@ -1324,6 +1373,198 @@ def __init__(self, *args, **kwargs): input_args.pop(0) if input_args else kwargs.pop("deserializer") ) + @distributed_trace + def get_sandbox_file( + self, *, pfn: str, **kwargs: Any + ) -> _models.SandboxDownloadResponse: + """Get Sandbox File. + + Get a presigned URL to download a sandbox file + + This route cannot use a redirect response most clients will also send the + authorization header when following a redirect. This is not desirable as + it would leak the authorization token to the storage backend. Additionally, + most storage backends return an error when they receive an authorization + header for a presigned URL. + + :keyword pfn: Required. + :paramtype pfn: str + :return: SandboxDownloadResponse + :rtype: ~client.models.SandboxDownloadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.SandboxDownloadResponse] = kwargs.pop("cls", None) + + request = build_jobs_get_sandbox_file_request( + pfn=pfn, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("SandboxDownloadResponse", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + + @overload + def initiate_sandbox_upload( + self, + body: _models.SandboxInfo, + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Required. + :type body: ~client.models.SandboxInfo + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def initiate_sandbox_upload( + self, body: IO, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Required. + :type body: IO + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def initiate_sandbox_upload( + self, body: Union[_models.SandboxInfo, IO], **kwargs: Any + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Is either a SandboxInfo type or a IO type. Required. + :type body: ~client.models.SandboxInfo or IO + :keyword content_type: Body Parameter content-type. Known values are: 'application/json'. + Default value is None. + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) + cls: ClsType[_models.SandboxUploadResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SandboxInfo") + + request = build_jobs_initiate_sandbox_upload_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("SandboxUploadResponse", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + @overload def submit_bulk_jobs( self, body: List[str], *, content_type: str = "application/json", **kwargs: Any From 87b0997db36b1aed1f2a3c1e6ab493d58d7e5e44 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Wed, 27 Sep 2023 09:05:28 +0200 Subject: [PATCH 6/7] Add API for uploading/downloading sandboxes --- src/diracx/api/__init__.py | 5 +++ src/diracx/api/jobs.py | 92 ++++++++++++++++++++++++++++++++++++++ src/diracx/api/utils.py | 25 +++++++++++ src/diracx/core/utils.py | 2 +- tests/api/__init__.py | 0 tests/api/test_jobs.py | 47 +++++++++++++++++++ tests/api/test_utils.py | 24 ++++++++++ 7 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 src/diracx/api/jobs.py create mode 100644 src/diracx/api/utils.py create mode 100644 tests/api/__init__.py create mode 100644 tests/api/test_jobs.py create mode 100644 tests/api/test_utils.py diff --git a/src/diracx/api/__init__.py b/src/diracx/api/__init__.py index e69de29b..329dd573 100644 --- a/src/diracx/api/__init__.py +++ b/src/diracx/api/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +__all__ = ("jobs",) + +from . import jobs diff --git a/src/diracx/api/jobs.py b/src/diracx/api/jobs.py new file mode 100644 index 00000000..a5bd5e45 --- /dev/null +++ b/src/diracx/api/jobs.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +__all__ = ("create_sandbox", "download_sandbox") + +import hashlib +import logging +import os +import tarfile +import tempfile +from pathlib import Path + +import httpx + +from diracx.client.aio import DiracClient +from diracx.client.models import SandboxInfo + +from .utils import with_client + +logger = logging.getLogger(__name__) + +SANDBOX_CHECKSUM_ALGORITHM = "sha256" +SANDBOX_COMPRESSION = "bz2" + + +@with_client +async def create_sandbox(paths: list[Path], *, client: DiracClient) -> str: + """Create a sandbox from the given paths and upload it to the storage backend. + + Any paths that are directories will be added recursively. + The returned value is the PFN of the sandbox in the storage backend and can + be used to submit jobs. + """ + with tempfile.TemporaryFile(mode="w+b") as tar_fh: + with tarfile.open(fileobj=tar_fh, mode=f"w|{SANDBOX_COMPRESSION}") as tf: + for path in paths: + logger.debug("Adding %s to sandbox as %s", path.resolve(), path.name) + tf.add(path.resolve(), path.name, recursive=True) + tar_fh.seek(0) + + hasher = getattr(hashlib, SANDBOX_CHECKSUM_ALGORITHM)() + while data := tar_fh.read(512 * 1024): + hasher.update(data) + checksum = hasher.hexdigest() + tar_fh.seek(0) + logger.debug("Sandbox checksum is %s", checksum) + + sandbox_info = SandboxInfo( + checksum_algorithm=SANDBOX_CHECKSUM_ALGORITHM, + checksum=checksum, + size=os.stat(tar_fh.fileno()).st_size, + format=f"tar.{SANDBOX_COMPRESSION}", + ) + + res = await client.jobs.initiate_sandbox_upload(sandbox_info) + if res.url: + logger.debug("Uploading sandbox for %s", res.pfn) + files = {"file": ("file", tar_fh)} + async with httpx.AsyncClient() as httpx_client: + response = await httpx_client.post( + res.url, data=res.fields, files=files + ) + # TODO: Handle this error better + response.raise_for_status() + + logger.debug( + "Sandbox uploaded for %s with status code %s", + res.pfn, + response.status_code, + ) + else: + logger.debug("%s already exists in storage backend", res.pfn) + return res.pfn + + +@with_client +async def download_sandbox(pfn: str, destination: Path, *, client: DiracClient): + """Download a sandbox from the storage backend to the given destination.""" + res = await client.jobs.get_sandbox_file(pfn=pfn) + logger.debug("Downloading sandbox for %s", pfn) + with tempfile.TemporaryFile(mode="w+b") as fh: + async with httpx.AsyncClient() as http_client: + response = await http_client.get(res.url) + # TODO: Handle this error better + response.raise_for_status() + async for chunk in response.aiter_bytes(): + fh.write(chunk) + fh.seek(0) + logger.debug("Sandbox downloaded for %s", pfn) + + with tarfile.open(fileobj=fh) as tf: + tf.extractall(path=destination, filter="data") + logger.debug("Extracted %s to %s", pfn, destination) diff --git a/src/diracx/api/utils.py b/src/diracx/api/utils.py new file mode 100644 index 00000000..b53338da --- /dev/null +++ b/src/diracx/api/utils.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +__all__ = ("with_client",) + +from functools import wraps + +from diracx.client.aio import DiracClient + + +def with_client(func): + """Decorator to provide a DiracClient to a function. + + If the function already has a `client` keyword argument, it will be used. + Otherwise, a new DiracClient will be created and passed as a keyword argument. + """ + + @wraps(func) + async def wrapper(*args, **kwargs): + if "client" in kwargs: + return await func(*args, **kwargs) + + async with DiracClient() as client: + return await func(*args, **kwargs, client=client) + + return wrapper diff --git a/src/diracx/core/utils.py b/src/diracx/core/utils.py index 2e81ea81..39fb23de 100644 --- a/src/diracx/core/utils.py +++ b/src/diracx/core/utils.py @@ -20,7 +20,7 @@ def dotenv_files_from_environment(prefix: str) -> list[str]: return [v for _, v in sorted(env_files.items())] -def write_credentials(token_response: TokenResponse, location: Path | None = None): +def write_credentials(token_response: TokenResponse, *, location: Path | None = None): """Write credentials received in dirax_preferences.credentials_path""" from diracx.core.preferences import get_diracx_preferences diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/api/test_jobs.py b/tests/api/test_jobs.py new file mode 100644 index 00000000..cf47595c --- /dev/null +++ b/tests/api/test_jobs.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import logging +import secrets + +from diracx.api.jobs import create_sandbox, download_sandbox + + +async def test_upload_download_sandbox(tmp_path, with_cli_login, caplog): + caplog.set_level(logging.DEBUG) + + input_directory = tmp_path / "input" + input_directory.mkdir() + input_files = [] + + input_file = input_directory / "input.dat" + input_file.write_bytes(secrets.token_bytes(512)) + input_files.append(input_file) + + input_file = input_directory / "a" / "b" / "c" / "nested.dat" + input_file.parent.mkdir(parents=True) + input_file.write_bytes(secrets.token_bytes(512)) + input_files.append(input_file) + + # Upload the sandbox + caplog.clear() + pfn = await create_sandbox(input_files) + assert has_record(caplog.records, "diracx.api.jobs", "Uploading sandbox for") + + # Uploading the same sandbox again should return the same PFN + caplog.clear() + pfn2 = await create_sandbox(input_files) + assert pfn == pfn2 + assert has_record(caplog.records, "diracx.api.jobs", "already exists in storage") + + # Download the sandbox + destination = tmp_path / "output" + await download_sandbox(pfn, destination) + assert (destination / "input.dat").is_file() + assert (destination / "nested.dat").is_file() + + +def has_record(records: list[logging.LogRecord], logger_name: str, message: str): + for record in records: + if record.name == logger_name and message in record.message: + return True + return False diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py new file mode 100644 index 00000000..a9778b45 --- /dev/null +++ b/tests/api/test_utils.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from diracx.api.utils import with_client +from diracx.client.aio import DiracClient + + +async def test_with_client_default(with_cli_login): + """Ensure that the with_client decorator provides a DiracClient.""" + + @with_client + async def test_func(*, client): + assert isinstance(client, DiracClient) + + await test_func() + + +async def test_with_client_override(): + """Ensure that the with_client can be overridden by providing a client kwarg.""" + + @with_client + async def test_func(*, client): + assert client == "foobar" + + await test_func(client="foobar") From 1d6f1084a2e54f96c45c797176d716db30afc6a0 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Mon, 2 Oct 2023 17:41:06 +0200 Subject: [PATCH 7/7] Add lifetime_function method to ServiceSettingsBase --- src/diracx/core/settings.py | 8 +++++++- src/diracx/routers/__init__.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diracx/core/settings.py b/src/diracx/core/settings.py index d98d0d9c..dcc36e61 100644 --- a/src/diracx/core/settings.py +++ b/src/diracx/core/settings.py @@ -6,8 +6,9 @@ "ServiceSettingsBase", ) +import contextlib from pathlib import Path -from typing import TYPE_CHECKING, Any, Self, TypeVar +from typing import TYPE_CHECKING, Any, AsyncIterator, Self, TypeVar from authlib.jose import JsonWebKey from pydantic import AnyUrl, BaseSettings, SecretStr, parse_obj_as @@ -58,3 +59,8 @@ class ServiceSettingsBase(BaseSettings, allow_mutation=False): @classmethod def create(cls) -> Self: raise NotImplementedError("This should never be called") + + @contextlib.asynccontextmanager + async def lifetime_function(self) -> AsyncIterator[None]: + """A context manager that can be used to run code at startup and shutdown.""" + yield diff --git a/src/diracx/routers/__init__.py b/src/diracx/routers/__init__.py index c3b49239..437833d5 100644 --- a/src/diracx/routers/__init__.py +++ b/src/diracx/routers/__init__.py @@ -52,6 +52,7 @@ def create_app_inner( cls = type(service_settings) assert cls not in available_settings_classes available_settings_classes.add(cls) + app.lifetime_functions.append(service_settings.lifetime_function) app.dependency_overrides[cls.create] = partial(lambda x: x, service_settings) # Override the configuration source