From 3d69d65e8842a8bf4cc9f6c14790e5c014af0eea Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Mon, 2 Oct 2023 15:44:39 +0200 Subject: [PATCH 1/8] Update SandboxMetadataDB interface --- src/diracx/core/models.py | 11 +++ src/diracx/db/sql/sandbox_metadata/db.py | 108 ++++++++++++----------- tests/db/test_sandboxMetadataDB.py | 84 ------------------ tests/db/test_sandbox_metadata.py | 92 +++++++++++++++++++ 4 files changed, 161 insertions(+), 134 deletions(-) delete mode 100644 tests/db/test_sandboxMetadataDB.py create mode 100644 tests/db/test_sandbox_metadata.py diff --git a/src/diracx/core/models.py b/src/diracx/core/models.py index 930b3d5f..21fc53a9 100644 --- a/src/diracx/core/models.py +++ b/src/diracx/core/models.py @@ -116,3 +116,14 @@ class UserInfo(BaseModel): class ChecksumAlgorithm(StrEnum): SHA256 = "sha256" + + +class SandboxFormat(StrEnum): + TAR_BZ2 = "tar.bz2" + + +class SandboxInfo(BaseModel): + checksum_algorithm: ChecksumAlgorithm + checksum: str = Field(pattern=r"^[0-f]{64}$") + size: int = Field(ge=1) + format: SandboxFormat diff --git a/src/diracx/db/sql/sandbox_metadata/db.py b/src/diracx/db/sql/sandbox_metadata/db.py index 6900f58a..a95dd93f 100644 --- a/src/diracx/db/sql/sandbox_metadata/db.py +++ b/src/diracx/db/sql/sandbox_metadata/db.py @@ -1,80 +1,88 @@ -""" SandboxMetadataDB frontend -""" - from __future__ import annotations -import datetime - import sqlalchemy -from diracx.db.sql.utils import BaseSQLDB +from diracx.core.models import SandboxInfo, UserInfo +from diracx.db.sql.utils import BaseSQLDB, utcnow from .schema import Base as SandboxMetadataDBBase from .schema import sb_Owners, sb_SandBoxes +# In legacy DIRAC the SEName column was used to support multiple different +# storage backends. This is no longer the case, so we hardcode the value to +# S3 to represent the new DiracX system. 
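As an aside on the scheme this hunk introduces: sandboxes are stored under content-addressed, user-namespaced PFNs. Below is a minimal standalone sketch of the naming only; the bucket, VO, group and username are hypothetical placeholders, and the actual construction is the `get_pfn` static method added in this patch.

```python
import hashlib

# Hypothetical values, purely for illustration.
bucket_name = "sandboxes"
vo, group, user = "myvo", "myvo_user", "someuser"

data = b"example sandbox content"
checksum = hashlib.sha256(data).hexdigest()

# Content-addressed PFN: identical content always maps to the same path,
# so re-registering the same sandbox is idempotent.
pfn = f"/S3/{bucket_name}/{vo}/{group}/{user}/sha256:{checksum}.tar.bz2"
print(pfn)  # /S3/sandboxes/myvo/myvo_user/someuser/sha256:<64 hex chars>.tar.bz2
```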
+SE_NAME = "ProductionSandboxSE" +PFN_PREFIX = "/S3/" + class SandboxMetadataDB(BaseSQLDB): metadata = SandboxMetadataDBBase.metadata - async def _get_put_owner(self, owner: str, owner_group: str) -> int: - """adds a new owner/ownerGroup pairs, while returning their ID if already existing - - Args: - owner (str): user name - owner_group (str): group of the owner - """ + async def upsert_owner(self, user: UserInfo) -> int: + """Get the id of the owner from the database""" + # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 stmt = sqlalchemy.select(sb_Owners.OwnerID).where( - sb_Owners.Owner == owner, sb_Owners.OwnerGroup == owner_group + sb_Owners.Owner == user.preferred_username, + sb_Owners.OwnerGroup == user.dirac_group, + # TODO: Add VO ) result = await self.conn.execute(stmt) if owner_id := result.scalar_one_or_none(): return owner_id - stmt = sqlalchemy.insert(sb_Owners).values(Owner=owner, OwnerGroup=owner_group) + stmt = sqlalchemy.insert(sb_Owners).values( + Owner=user.preferred_username, + OwnerGroup=user.dirac_group, + ) result = await self.conn.execute(stmt) return result.lastrowid - async def insert( - self, owner: str, owner_group: str, sb_SE: str, se_PFN: str, size: int = 0 - ) -> tuple[int, bool]: - """inserts a new sandbox in SandboxMetadataDB - this is "equivalent" of DIRAC registerAndGetSandbox + @staticmethod + def get_pfn(bucket_name: str, user: UserInfo, sandbox_info: SandboxInfo) -> str: + """Get the sandbox's user namespaced and content addressed PFN""" + parts = [ + "S3", + bucket_name, + user.vo, + user.dirac_group, + user.preferred_username, + f"{sandbox_info.checksum_algorithm}:{sandbox_info.checksum}.{sandbox_info.format}", + ] + return "/" + "/".join(parts) - Args: - owner (str): user name_ - owner_group (str): groupd of the owner - sb_SE (str): _description_ - sb_PFN (str): _description_ - size (int, optional): _description_. Defaults to 0. 
- """ - owner_id = await self._get_put_owner(owner, owner_group) + async def insert_sandbox(self, user: UserInfo, pfn: str, size: int) -> None: + """Add a new sandbox in SandboxMetadataDB""" + # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 + owner_id = await self.upsert_owner(user) stmt = sqlalchemy.insert(sb_SandBoxes).values( - OwnerId=owner_id, SEName=sb_SE, SEPFN=se_PFN, Bytes=size + OwnerId=owner_id, + SEName=SE_NAME, + SEPFN=pfn, + Bytes=size, + RegistrationTime=utcnow(), + LastAccessTime=utcnow(), ) try: result = await self.conn.execute(stmt) - return result.lastrowid except sqlalchemy.exc.IntegrityError: - # it is a duplicate, try to retrieve SBiD - stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.SBId).where( # type: ignore[no-redef] - sb_SandBoxes.SEPFN == se_PFN, - sb_SandBoxes.SEName == sb_SE, - sb_SandBoxes.OwnerId == owner_id, - ) - result = await self.conn.execute(stmt) - sb_ID = result.scalar_one() - stmt: sqlalchemy.Executable = ( # type: ignore[no-redef] - sqlalchemy.update(sb_SandBoxes) - .where(sb_SandBoxes.SBId == sb_ID) - .values(LastAccessTime=datetime.datetime.utcnow()) - ) - await self.conn.execute(stmt) - return sb_ID + await self.update_sandbox_last_access_time(pfn) + else: + assert result.rowcount == 1 - async def delete(self, sandbox_ids: list[int]) -> bool: - stmt: sqlalchemy.Executable = sqlalchemy.delete(sb_SandBoxes).where( - sb_SandBoxes.SBId.in_(sandbox_ids) + async def update_sandbox_last_access_time(self, pfn: str) -> None: + stmt = ( + sqlalchemy.update(sb_SandBoxes) + .where(sb_SandBoxes.SEName == SE_NAME, sb_SandBoxes.SEPFN == pfn) + .values(LastAccessTime=utcnow()) ) - await self.conn.execute(stmt) + result = await self.conn.execute(stmt) + assert result.rowcount == 1 - return True + async def sandbox_is_assigned(self, pfn: str) -> bool: + """Checks if a sandbox exists and has been assigned.""" + stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.Assigned).where( + sb_SandBoxes.SEName == SE_NAME, sb_SandBoxes.SEPFN == pfn + ) + result = await self.conn.execute(stmt) + is_assigned = result.scalar_one() + return is_assigned diff --git a/tests/db/test_sandboxMetadataDB.py b/tests/db/test_sandboxMetadataDB.py deleted file mode 100644 index ffcf91ec..00000000 --- a/tests/db/test_sandboxMetadataDB.py +++ /dev/null @@ -1,84 +0,0 @@ -from __future__ import annotations - -import pytest -import sqlalchemy - -from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB - - -@pytest.fixture -async def sandbox_metadata_db(tmp_path): - sandbox_metadata_db = SandboxMetadataDB("sqlite+aiosqlite:///:memory:") - async with sandbox_metadata_db.engine_context(): - async with sandbox_metadata_db.engine.begin() as conn: - await conn.run_sync(sandbox_metadata_db.metadata.create_all) - yield sandbox_metadata_db - - -async def test__get_put_owner(sandbox_metadata_db): - async with sandbox_metadata_db as sandbox_metadata_db: - result = await sandbox_metadata_db._get_put_owner("owner", "owner_group") - assert result == 1 - result = await sandbox_metadata_db._get_put_owner("owner_2", "owner_group") - assert result == 2 - result = await sandbox_metadata_db._get_put_owner("owner", "owner_group") - assert result == 1 - result = await sandbox_metadata_db._get_put_owner("owner_2", "owner_group") - assert result == 2 - result = await sandbox_metadata_db._get_put_owner("owner_2", "owner_group_2") - assert result == 3 - - -async def test_insert(sandbox_metadata_db): - async with sandbox_metadata_db as sandbox_metadata_db: - result = await 
sandbox_metadata_db.insert( - "owner", - "owner_group", - "sbSE", - "sbPFN", - 123, - ) - assert result == 1 - - result = await sandbox_metadata_db.insert( - "owner", - "owner_group", - "sbSE", - "sbPFN", - 123, - ) - assert result == 1 - - result = await sandbox_metadata_db.insert( - "owner_2", - "owner_group", - "sbSE", - "sbPFN_2", - 123, - ) - assert result == 2 - - # This would be incorrect - with pytest.raises(sqlalchemy.exc.NoResultFound): - await sandbox_metadata_db.insert( - "owner", - "owner_group", - "sbSE", - "sbPFN_2", - 123, - ) - - -async def test_delete(sandbox_metadata_db): - async with sandbox_metadata_db as sandbox_metadata_db: - result = await sandbox_metadata_db.insert( - "owner", - "owner_group", - "sbSE", - "sbPFN", - 123, - ) - assert result == 1 - - result = await sandbox_metadata_db.delete([1]) - assert result diff --git a/tests/db/test_sandbox_metadata.py b/tests/db/test_sandbox_metadata.py new file mode 100644 index 00000000..6757af13 --- /dev/null +++ b/tests/db/test_sandbox_metadata.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import asyncio +import secrets +from datetime import datetime + +import pytest +import sqlalchemy + +from diracx.core.models import SandboxInfo, UserInfo +from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB +from diracx.db.sql.sandbox_metadata.schema import sb_SandBoxes + + +@pytest.fixture +async def sandbox_metadata_db(tmp_path): + sandbox_metadata_db = SandboxMetadataDB("sqlite+aiosqlite:///:memory:") + async with sandbox_metadata_db.engine_context(): + async with sandbox_metadata_db.engine.begin() as conn: + await conn.run_sync(sandbox_metadata_db.metadata.create_all) + yield sandbox_metadata_db + + +def test_get_pfn(sandbox_metadata_db: SandboxMetadataDB): + user_info = UserInfo( + sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo" + ) + sandbox_info = SandboxInfo( + checksum="checksum", + checksum_algorithm="sha256", + format="tar.bz2", + size=100, + ) + pfn = sandbox_metadata_db.get_pfn("bucket1", user_info, sandbox_info) + assert pfn == "/S3/bucket1/vo/group1/user1/sha256:checksum.tar.bz2" + + +async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB): + user_info = UserInfo( + sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo" + ) + pfn1 = secrets.token_hex() + + # Make sure the sandbox doesn't already exist + db_contents = await _dump_db(sandbox_metadata_db) + assert pfn1 not in db_contents + async with sandbox_metadata_db: + with pytest.raises(sqlalchemy.exc.NoResultFound): + await sandbox_metadata_db.sandbox_is_assigned(pfn1) + + # Insert the sandbox + async with sandbox_metadata_db: + await sandbox_metadata_db.insert_sandbox(user_info, pfn1, 100) + db_contents = await _dump_db(sandbox_metadata_db) + owner_id1, last_access_time1 = db_contents[pfn1] + + # Inserting again should update the last access time + await asyncio.sleep(1) # The timestamp only has second precision + async with sandbox_metadata_db: + await sandbox_metadata_db.insert_sandbox(user_info, pfn1, 100) + db_contents = await _dump_db(sandbox_metadata_db) + owner_id2, last_access_time2 = db_contents[pfn1] + assert owner_id1 == owner_id2 + assert last_access_time2 > last_access_time1 + + # The sandbox still hasn't been assigned + async with sandbox_metadata_db: + assert not await sandbox_metadata_db.sandbox_is_assigned(pfn1) + + # Inserting again should update the last access time + await asyncio.sleep(1) # The timestamp only has second precision + last_access_time3 = (await 
_dump_db(sandbox_metadata_db))[pfn1][1] + assert last_access_time2 == last_access_time3 + async with sandbox_metadata_db: + await sandbox_metadata_db.update_sandbox_last_access_time(pfn1) + last_access_time4 = (await _dump_db(sandbox_metadata_db))[pfn1][1] + assert last_access_time2 < last_access_time4 + + +async def _dump_db( + sandbox_metadata_db: SandboxMetadataDB, +) -> dict[str, tuple[int, datetime]]: + """Dump the contents of the sandbox metadata database + + Returns a dict[pfn: str, (owner_id: int, last_access_time: datetime)] + """ + async with sandbox_metadata_db: + stmt = sqlalchemy.select( + sb_SandBoxes.SEPFN, sb_SandBoxes.OwnerId, sb_SandBoxes.LastAccessTime + ) + res = await sandbox_metadata_db.conn.execute(stmt) + return {row.SEPFN: (row.OwnerId, row.LastAccessTime) for row in res} From 3e7f3a7abf9358256bcc4c6f92d66797855bcdf4 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Wed, 27 Sep 2023 11:59:27 +0200 Subject: [PATCH 2/8] Add routes for uploading/downloading sandboxes --- src/diracx/routers/dependencies.py | 4 + src/diracx/routers/job_manager/__init__.py | 2 + src/diracx/routers/job_manager/sandboxes.py | 176 ++++++++++++++++++++ tests/conftest.py | 15 +- tests/routers/jobs/__init__.py | 0 tests/routers/jobs/test_sandboxes.py | 63 +++++++ 6 files changed, 259 insertions(+), 1 deletion(-) create mode 100644 src/diracx/routers/job_manager/sandboxes.py create mode 100644 tests/routers/jobs/__init__.py create mode 100644 tests/routers/jobs/test_sandboxes.py diff --git a/src/diracx/routers/dependencies.py b/src/diracx/routers/dependencies.py index 3ab40361..deb55167 100644 --- a/src/diracx/routers/dependencies.py +++ b/src/diracx/routers/dependencies.py @@ -19,6 +19,7 @@ from diracx.db.sql import AuthDB as _AuthDB from diracx.db.sql import JobDB as _JobDB from diracx.db.sql import JobLoggingDB as _JobLoggingDB +from diracx.db.sql import SandboxMetadataDB as _SandboxMetadataDB T = TypeVar("T") @@ -32,6 +33,9 @@ def add_settings_annotation(cls: T) -> T: AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)] JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)] JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)] +SandboxMetadataDB = Annotated[ + _SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction) +] # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] diff --git a/src/diracx/routers/job_manager/__init__.py b/src/diracx/routers/job_manager/__init__.py index df36e15b..39592c52 100644 --- a/src/diracx/routers/job_manager/__init__.py +++ b/src/diracx/routers/job_manager/__init__.py @@ -29,12 +29,14 @@ from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token from ..dependencies import JobDB, JobLoggingDB from ..fastapi_classes import DiracxRouter +from .sandboxes import router as sandboxes_router MAX_PARAMETRIC_JOBS = 20 logger = logging.getLogger(__name__) router = DiracxRouter(dependencies=[has_properties(NORMAL_USER | JOB_ADMINISTRATOR)]) +router.include_router(sandboxes_router) class JobSummaryParams(BaseModel): diff --git a/src/diracx/routers/job_manager/sandboxes.py b/src/diracx/routers/job_manager/sandboxes.py new file mode 100644 index 00000000..b47af93f --- /dev/null +++ b/src/diracx/routers/job_manager/sandboxes.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import TYPE_CHECKING, Annotated + +import botocore.session +from botocore.config import Config +from botocore.errorfactory import ClientError +from fastapi import Depends, HTTPException, 
Query +from pydantic import BaseModel, PrivateAttr +from sqlalchemy.exc import NoResultFound + +from diracx.core.models import ( + SandboxInfo, +) +from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER +from diracx.core.s3 import ( + generate_presigned_upload, + s3_bucket_exists, + s3_object_exists, +) +from diracx.core.settings import ServiceSettingsBase + +if TYPE_CHECKING: + from mypy_boto3_s3.client import S3Client + +from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token +from ..dependencies import SandboxMetadataDB, add_settings_annotation +from ..fastapi_classes import DiracxRouter + +MAX_SANDBOX_SIZE_BYTES = 100 * 1024 * 1024 +router = DiracxRouter(dependencies=[has_properties(NORMAL_USER | JOB_ADMINISTRATOR)]) + + +@add_settings_annotation +class SandboxStoreSettings(ServiceSettingsBase, env_prefix="DIRACX_SANDBOX_STORE_"): + """Settings for the sandbox store.""" + + bucket_name: str + s3_client_kwargs: dict[str, str] + auto_create_bucket: bool = False + url_validity_seconds: int = 5 * 60 + _client: S3Client = PrivateAttr(None) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # TODO: Use async + session = botocore.session.get_session() + self._client = session.create_client( + "s3", + # endpoint_url=s3_cred["endpoint"], + # aws_access_key_id=s3_cred["access_key_id"], + # aws_secret_access_key=s3_cred["secret_access_key"], + **self.s3_client_kwargs, + config=Config(signature_version="v4"), + ) + if not s3_bucket_exists(self._client, self.bucket_name): + if not self.auto_create_bucket: + raise ValueError( + f"Bucket {self.bucket_name} does not exist and auto_create_bucket is disabled" + ) + try: + self._client.create_bucket(Bucket=self.bucket_name) + except ClientError as e: + raise ValueError(f"Failed to create bucket {self.bucket_name}") from e + + @property + def s3_client(self) -> S3Client: + return self._client + + +class SandboxUploadResponse(BaseModel): + pfn: str + url: str | None = None + fields: dict[str, str] = {} + + +@router.post("/sandbox") +async def initiate_sandbox_upload( + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + sandbox_info: SandboxInfo, + sandbox_metadata_db: SandboxMetadataDB, + settings: SandboxStoreSettings, +) -> SandboxUploadResponse: + """Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + """ + if sandbox_info.size > MAX_SANDBOX_SIZE_BYTES: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Sandbox too large. Max size is {MAX_SANDBOX_SIZE_BYTES} bytes", + ) + + pfn = sandbox_metadata_db.get_pfn(settings.bucket_name, user_info, sandbox_info) + + try: + exists_and_assigned = await sandbox_metadata_db.sandbox_is_assigned(pfn) + except NoResultFound: + # The sandbox doesn't exist in the database + pass + else: + # As sandboxes are registered in the DB before uploading to the storage + # backend we can't on their existence in the database to determine if + # they have been uploaded. Instead we check if the sandbox has been + # assigned to a job. If it has then we know it has been uploaded and we + # can avoid communicating with the storage backend. 
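For context, here is a minimal sketch of the client-side flow this route enables, assuming a deployed instance; the host and bearer token are placeholders, while the request and response field names follow the models defined in this patch.

```python
import hashlib
import requests

API = "https://diracx.example.invalid"  # placeholder endpoint
data = b"example sandbox content"
info = {
    "checksum_algorithm": "sha256",
    "checksum": hashlib.sha256(data).hexdigest(),
    "size": len(data),
    "format": "tar.bz2",
}

resp = requests.post(
    f"{API}/jobs/sandbox",
    json=info,
    headers={"Authorization": "Bearer <token>"},  # placeholder token
).json()

if resp.get("url"):
    # First upload: push the bytes straight to the object store using the
    # presigned POST; the DiracX API never handles the sandbox contents.
    s3_response = requests.post(
        resp["url"], data=resp["fields"], files={"file": ("file", data)}
    )
    s3_response.raise_for_status()
# resp["pfn"] identifies the sandbox whether or not an upload was needed.
```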
+ if exists_and_assigned or s3_object_exists( + settings.s3_client, settings.bucket_name, pfn_to_key(pfn) + ): + await sandbox_metadata_db.update_sandbox_last_access_time(pfn) + return SandboxUploadResponse(pfn=pfn) + + upload_info = generate_presigned_upload( + settings.s3_client, + settings.bucket_name, + pfn_to_key(pfn), + sandbox_info.checksum_algorithm, + sandbox_info.checksum, + sandbox_info.size, + settings.url_validity_seconds, + ) + await sandbox_metadata_db.insert_sandbox(user_info, pfn, sandbox_info.size) + + return SandboxUploadResponse(**upload_info, pfn=pfn) + + +class SandboxDownloadResponse(BaseModel): + url: str + expires_in: int + + +def pfn_to_key(pfn: str) -> str: + """Convert a PFN to a key for S3 + + This removes the leading "/S3/" from the PFN. + """ + return "/".join(pfn.split("/")[3:]) + + +SANDBOX_PFN_REGEX = ( + # Starts with /S3/ + r"^/S3/[a-z0-9\.\-]{3,63}" + # Followed ////:. + r"(?:/[^/]+){3}/[a-z0-9]{3,10}:[0-9a-f]{64}\.[a-z0-9\.]+$" +) + + +@router.get("/sandbox") +async def get_sandbox_file( + pfn: Annotated[str, Query(max_length=256, pattern=SANDBOX_PFN_REGEX)], + settings: SandboxStoreSettings, +) -> SandboxDownloadResponse: + """Get a presigned URL to download a sandbox file + + This route cannot use a redirect response most clients will also send the + authorization header when following a redirect. This is not desirable as + it would leak the authorization token to the storage backend. Additionally, + most storage backends return an error when they receive an authorization + header for a presigned URL. + """ + # TODO: Prevent people from downloading other people's sandboxes? + # TODO: Support by name and by job id? + presigned_url = settings.s3_client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": settings.bucket_name, "Key": pfn_to_key(pfn)}, + ExpiresIn=settings.url_validity_seconds, + ) + return SandboxDownloadResponse( + url=presigned_url, expires_in=settings.url_validity_seconds + ) diff --git a/tests/conftest.py b/tests/conftest.py index 9b4d224c..60ec907f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,12 +14,14 @@ from cryptography.hazmat.primitives.asymmetric import rsa from fastapi.testclient import TestClient from git import Repo +from moto import mock_s3 from diracx.core.config import Config, ConfigSource from diracx.core.preferences import get_diracx_preferences from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.routers import create_app_inner from diracx.routers.auth import AuthSettings, create_token +from diracx.routers.job_manager.sandboxes import SandboxStoreSettings # to get a string like this run: # openssl rand -hex 32 @@ -78,8 +80,18 @@ def test_auth_settings() -> AuthSettings: ) +@pytest.fixture(scope="function") +def test_sandbox_settings() -> SandboxStoreSettings: + with mock_s3(): + yield SandboxStoreSettings( + bucket_name="sandboxes", + s3_client_kwargs={}, + auto_create_bucket=True, + ) + + @pytest.fixture -def with_app(test_auth_settings, with_config_repo): +def with_app(test_auth_settings, test_sandbox_settings, with_config_repo): """ Create a DiracxApp with hard coded configuration for test """ @@ -87,6 +99,7 @@ def with_app(test_auth_settings, with_config_repo): enabled_systems={".well-known", "auth", "config", "jobs"}, all_service_settings=[ test_auth_settings, + test_sandbox_settings, ], database_urls={ "JobDB": "sqlite+aiosqlite:///:memory:", diff --git a/tests/routers/jobs/__init__.py b/tests/routers/jobs/__init__.py new file mode 100644 index 
00000000..e69de29b diff --git a/tests/routers/jobs/test_sandboxes.py b/tests/routers/jobs/test_sandboxes.py new file mode 100644 index 00000000..a683c2b7 --- /dev/null +++ b/tests/routers/jobs/test_sandboxes.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import hashlib +import secrets +from io import BytesIO + +import requests +from fastapi.testclient import TestClient + + +def test_upload_then_download(normal_user_client: TestClient): + data = secrets.token_bytes(512) + checksum = hashlib.sha256(data).hexdigest() + + # Initiate the upload + r = normal_user_client.post( + "/jobs/sandbox", + json={ + "checksum_algorithm": "sha256", + "checksum": checksum, + "size": len(data), + "format": "tar.bz2", + }, + ) + assert r.status_code == 200, r.text + upload_info = r.json() + assert upload_info["url"] + sandbox_pfn = upload_info["pfn"] + assert sandbox_pfn.startswith("/S3/") + + # Actually upload the file + files = {"file": ("file", BytesIO(data))} + r = requests.post(upload_info["url"], data=upload_info["fields"], files=files) + assert r.status_code == 204, r.text + + # Make sure we can download it and get the same data back + r = normal_user_client.get("/jobs/sandbox", params={"pfn": sandbox_pfn}) + assert r.status_code == 200, r.text + download_info = r.json() + assert download_info["expires_in"] > 5 + r = requests.get(download_info["url"]) + assert r.status_code == 200, r.text + assert r.content == data + + +def test_upload_oversized(normal_user_client: TestClient): + data = secrets.token_bytes(512) + checksum = hashlib.sha256(data).hexdigest() + + # Initiate the upload + r = normal_user_client.post( + "/jobs/sandbox", + json={ + "checksum_algorithm": "sha256", + "checksum": checksum, + # We can forge the size here to be larger than the actual data as + # we should get an error and never actually upload the data + "size": 1024 * 1024 * 1024, + "format": "tar.bz2", + }, + ) + assert r.status_code == 400, r.text + assert "Sandbox too large" in r.json()["detail"], r.text From ea886a19534df3ab9cc32f3f17c8c71eb0a48cf8 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 26 Sep 2023 17:54:01 +0200 Subject: [PATCH 3/8] Regenerate client --- .../client/aio/operations/_operations.py | 194 ++++++++++++++ src/diracx/client/models/__init__.py | 18 +- src/diracx/client/models/_enums.py | 36 ++- src/diracx/client/models/_models.py | 135 +++++++++- src/diracx/client/operations/_operations.py | 241 ++++++++++++++++++ 5 files changed, 607 insertions(+), 17 deletions(-) diff --git a/src/diracx/client/aio/operations/_operations.py b/src/diracx/client/aio/operations/_operations.py index 9c3ec627..a1c87d68 100644 --- a/src/diracx/client/aio/operations/_operations.py +++ b/src/diracx/client/aio/operations/_operations.py @@ -37,9 +37,11 @@ build_jobs_delete_bulk_jobs_request, build_jobs_get_job_status_bulk_request, build_jobs_get_job_status_history_bulk_request, + build_jobs_get_sandbox_file_request, build_jobs_get_single_job_request, build_jobs_get_single_job_status_history_request, build_jobs_get_single_job_status_request, + build_jobs_initiate_sandbox_upload_request, build_jobs_kill_bulk_jobs_request, build_jobs_search_request, build_jobs_set_job_status_bulk_request, @@ -819,6 +821,198 @@ def __init__(self, *args, **kwargs) -> None: input_args.pop(0) if input_args else kwargs.pop("deserializer") ) + @distributed_trace_async + async def get_sandbox_file( + self, *, pfn: str, **kwargs: Any + ) -> _models.SandboxDownloadResponse: + """Get Sandbox File. 
+ + Get a presigned URL to download a sandbox file + + This route cannot use a redirect response most clients will also send the + authorization header when following a redirect. This is not desirable as + it would leak the authorization token to the storage backend. Additionally, + most storage backends return an error when they receive an authorization + header for a presigned URL. + + :keyword pfn: Required. + :paramtype pfn: str + :return: SandboxDownloadResponse + :rtype: ~client.models.SandboxDownloadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.SandboxDownloadResponse] = kwargs.pop("cls", None) + + request = build_jobs_get_sandbox_file_request( + pfn=pfn, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("SandboxDownloadResponse", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + + @overload + async def initiate_sandbox_upload( + self, + body: _models.SandboxInfo, + *, + content_type: str = "application/json", + **kwargs: Any + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Required. + :type body: ~client.models.SandboxInfo + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def initiate_sandbox_upload( + self, body: IO, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Required. + :type body: IO + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". 
+ :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def initiate_sandbox_upload( + self, body: Union[_models.SandboxInfo, IO], **kwargs: Any + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Is either a SandboxInfo type or a IO type. Required. + :type body: ~client.models.SandboxInfo or IO + :keyword content_type: Body Parameter content-type. Known values are: 'application/json'. + Default value is None. + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) + cls: ClsType[_models.SandboxUploadResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SandboxInfo") + + request = build_jobs_initiate_sandbox_upload_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("SandboxUploadResponse", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + @overload async def submit_bulk_jobs( self, body: List[str], *, content_type: str = "application/json", **kwargs: Any diff --git a/src/diracx/client/models/__init__.py b/src/diracx/client/models/__init__.py index 893d7c15..205f5098 100644 --- a/src/diracx/client/models/__init__.py +++ b/src/diracx/client/models/__init__.py @@ -16,6 +16,9 @@ from ._models import JobSummaryParams from ._models import JobSummaryParamsSearchItem from ._models import LimitedJobStatusReturn +from ._models import SandboxDownloadResponse +from ._models import SandboxInfo +from ._models import SandboxUploadResponse from ._models import ScalarSearchSpec from ._models import SetJobStatusReturn from ._models import SortSpec @@ -26,14 +29,16 @@ from ._models import ValidationErrorLocItem from ._models import VectorSearchSpec +from ._enums import ChecksumAlgorithm from ._enums import Enum0 from ._enums import Enum1 +from ._enums import Enum10 +from ._enums import Enum11 from ._enums import 
Enum2 from ._enums import Enum3 from ._enums import Enum4 -from ._enums import Enum8 -from ._enums import Enum9 from ._enums import JobStatus +from ._enums import SandboxFormat from ._enums import ScalarSearchOperator from ._enums import VectorSearchOperator from ._patch import __all__ as _patch_all @@ -53,6 +58,9 @@ "JobSummaryParams", "JobSummaryParamsSearchItem", "LimitedJobStatusReturn", + "SandboxDownloadResponse", + "SandboxInfo", + "SandboxUploadResponse", "ScalarSearchSpec", "SetJobStatusReturn", "SortSpec", @@ -62,14 +70,16 @@ "ValidationError", "ValidationErrorLocItem", "VectorSearchSpec", + "ChecksumAlgorithm", "Enum0", "Enum1", + "Enum10", + "Enum11", "Enum2", "Enum3", "Enum4", - "Enum8", - "Enum9", "JobStatus", + "SandboxFormat", "ScalarSearchOperator", "VectorSearchOperator", ] diff --git a/src/diracx/client/models/_enums.py b/src/diracx/client/models/_enums.py index ccabc77c..60628b59 100644 --- a/src/diracx/client/models/_enums.py +++ b/src/diracx/client/models/_enums.py @@ -8,6 +8,12 @@ from azure.core import CaseInsensitiveEnumMeta +class ChecksumAlgorithm(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """An enumeration.""" + + SHA256 = "sha256" + + class Enum0(str, Enum, metaclass=CaseInsensitiveEnumMeta): """Enum0.""" @@ -22,6 +28,18 @@ class Enum1(str, Enum, metaclass=CaseInsensitiveEnumMeta): ) +class Enum10(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Enum10.""" + + ASC = "asc" + + +class Enum11(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Enum11.""" + + DSC = "dsc" + + class Enum2(str, Enum, metaclass=CaseInsensitiveEnumMeta): """Enum2.""" @@ -40,18 +58,6 @@ class Enum4(str, Enum, metaclass=CaseInsensitiveEnumMeta): S256 = "S256" -class Enum8(str, Enum, metaclass=CaseInsensitiveEnumMeta): - """Enum8.""" - - ASC = "asc" - - -class Enum9(str, Enum, metaclass=CaseInsensitiveEnumMeta): - """Enum9.""" - - DSC = "dsc" - - class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): """An enumeration.""" @@ -72,6 +78,12 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """An enumeration.""" + + TAR_BZ2 = "tar.bz2" + + class ScalarSearchOperator(str, Enum, metaclass=CaseInsensitiveEnumMeta): """An enumeration.""" diff --git a/src/diracx/client/models/_models.py b/src/diracx/client/models/_models.py index 5578ad22..716f06f7 100644 --- a/src/diracx/client/models/_models.py +++ b/src/diracx/client/models/_models.py @@ -6,7 +6,7 @@ # -------------------------------------------------------------------------- import datetime -from typing import Any, List, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union from .. import _serialization @@ -509,6 +509,139 @@ def __init__( self.application_status = application_status +class SandboxDownloadResponse(_serialization.Model): + """SandboxDownloadResponse. + + All required parameters must be populated in order to send to Azure. + + :ivar url: Url. Required. + :vartype url: str + :ivar expires_in: Expires In. Required. + :vartype expires_in: int + """ + + _validation = { + "url": {"required": True}, + "expires_in": {"required": True}, + } + + _attribute_map = { + "url": {"key": "url", "type": "str"}, + "expires_in": {"key": "expires_in", "type": "int"}, + } + + def __init__(self, *, url: str, expires_in: int, **kwargs: Any) -> None: + """ + :keyword url: Url. Required. + :paramtype url: str + :keyword expires_in: Expires In. Required. 
+ :paramtype expires_in: int + """ + super().__init__(**kwargs) + self.url = url + self.expires_in = expires_in + + +class SandboxInfo(_serialization.Model): + """SandboxInfo. + + All required parameters must be populated in order to send to Azure. + + :ivar checksum_algorithm: An enumeration. Required. "sha256" + :vartype checksum_algorithm: str or ~client.models.ChecksumAlgorithm + :ivar checksum: Checksum. Required. + :vartype checksum: str + :ivar size: Size. Required. + :vartype size: int + :ivar format: An enumeration. Required. "tar.bz2" + :vartype format: str or ~client.models.SandboxFormat + """ + + _validation = { + "checksum_algorithm": {"required": True}, + "checksum": {"required": True, "pattern": r"^[0-f]{64}$"}, + "size": {"required": True, "minimum": 1}, + "format": {"required": True}, + } + + _attribute_map = { + "checksum_algorithm": {"key": "checksum_algorithm", "type": "str"}, + "checksum": {"key": "checksum", "type": "str"}, + "size": {"key": "size", "type": "int"}, + "format": {"key": "format", "type": "str"}, + } + + def __init__( + self, + *, + checksum_algorithm: Union[str, "_models.ChecksumAlgorithm"], + checksum: str, + size: int, + format: Union[str, "_models.SandboxFormat"], + **kwargs: Any + ) -> None: + """ + :keyword checksum_algorithm: An enumeration. Required. "sha256" + :paramtype checksum_algorithm: str or ~client.models.ChecksumAlgorithm + :keyword checksum: Checksum. Required. + :paramtype checksum: str + :keyword size: Size. Required. + :paramtype size: int + :keyword format: An enumeration. Required. "tar.bz2" + :paramtype format: str or ~client.models.SandboxFormat + """ + super().__init__(**kwargs) + self.checksum_algorithm = checksum_algorithm + self.checksum = checksum + self.size = size + self.format = format + + +class SandboxUploadResponse(_serialization.Model): + """SandboxUploadResponse. + + All required parameters must be populated in order to send to Azure. + + :ivar pfn: Pfn. Required. + :vartype pfn: str + :ivar url: Url. + :vartype url: str + :ivar fields: Fields. + :vartype fields: dict[str, str] + """ + + _validation = { + "pfn": {"required": True}, + } + + _attribute_map = { + "pfn": {"key": "pfn", "type": "str"}, + "url": {"key": "url", "type": "str"}, + "fields": {"key": "fields", "type": "{str}"}, + } + + def __init__( + self, + *, + pfn: str, + url: Optional[str] = None, + fields: Optional[Dict[str, str]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pfn: Pfn. Required. + :paramtype pfn: str + :keyword url: Url. + :paramtype url: str + :keyword fields: Fields. + :paramtype fields: dict[str, str] + """ + super().__init__(**kwargs) + self.pfn = pfn + self.url = url + self.fields = fields + + class ScalarSearchSpec(_serialization.Model): """ScalarSearchSpec. 
diff --git a/src/diracx/client/operations/_operations.py b/src/diracx/client/operations/_operations.py index 72dea1c1..bb2d7807 100644 --- a/src/diracx/client/operations/_operations.py +++ b/src/diracx/client/operations/_operations.py @@ -280,6 +280,55 @@ def build_config_serve_config_request( return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) +def build_jobs_get_sandbox_file_request(*, pfn: str, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/jobs/sandbox" + + # Construct parameters + _params["pfn"] = _SERIALIZER.query( + "pfn", + pfn, + "str", + max_length=256, + pattern=r"^/S3/[a-z0-9\.\-]{3,63}(?:/[^/]+){3}/[a-z0-9]{3,10}:[0-9a-f]{64}\.[a-z0-9\.]+$", + ) + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest( + method="GET", url=_url, params=_params, headers=_headers, **kwargs + ) + + +def build_jobs_initiate_sandbox_upload_request( + **kwargs: Any, +) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/jobs/sandbox" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header( + "content_type", content_type, "str" + ) + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + def build_jobs_submit_bulk_jobs_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) @@ -1324,6 +1373,198 @@ def __init__(self, *args, **kwargs): input_args.pop(0) if input_args else kwargs.pop("deserializer") ) + @distributed_trace + def get_sandbox_file( + self, *, pfn: str, **kwargs: Any + ) -> _models.SandboxDownloadResponse: + """Get Sandbox File. + + Get a presigned URL to download a sandbox file + + This route cannot use a redirect response most clients will also send the + authorization header when following a redirect. This is not desirable as + it would leak the authorization token to the storage backend. Additionally, + most storage backends return an error when they receive an authorization + header for a presigned URL. + + :keyword pfn: Required. 
+ :paramtype pfn: str + :return: SandboxDownloadResponse + :rtype: ~client.models.SandboxDownloadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[_models.SandboxDownloadResponse] = kwargs.pop("cls", None) + + request = build_jobs_get_sandbox_file_request( + pfn=pfn, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("SandboxDownloadResponse", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + + @overload + def initiate_sandbox_upload( + self, + body: _models.SandboxInfo, + *, + content_type: str = "application/json", + **kwargs: Any, + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Required. + :type body: ~client.models.SandboxInfo + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def initiate_sandbox_upload( + self, body: IO, *, content_type: str = "application/json", **kwargs: Any + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. + + :param body: Required. + :type body: IO + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def initiate_sandbox_upload( + self, body: Union[_models.SandboxInfo, IO], **kwargs: Any + ) -> _models.SandboxUploadResponse: + """Initiate Sandbox Upload. + + Get the PFN for the given sandbox, initiate an upload as required. + + If the sandbox already exists in the database then the PFN is returned + and there is no "url" field in the response. + + If the sandbox does not exist in the database then the "url" and "fields" + should be used to upload the sandbox to the storage backend. 
+ + :param body: Is either a SandboxInfo type or a IO type. Required. + :type body: ~client.models.SandboxInfo or IO + :keyword content_type: Body Parameter content-type. Known values are: 'application/json'. + Default value is None. + :paramtype content_type: str + :return: SandboxUploadResponse + :rtype: ~client.models.SandboxUploadResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop( + "content_type", _headers.pop("Content-Type", None) + ) + cls: ClsType[_models.SandboxUploadResponse] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SandboxInfo") + + request = build_jobs_initiate_sandbox_upload_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("SandboxUploadResponse", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + @overload def submit_bulk_jobs( self, body: List[str], *, content_type: str = "application/json", **kwargs: Any From 4cbd4d0cb8fc42392382213c01e8a4adef5c8200 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Wed, 27 Sep 2023 09:05:28 +0200 Subject: [PATCH 4/8] Add API for uploading/downloading sandboxes --- src/diracx/api/__init__.py | 5 +++ src/diracx/api/jobs.py | 92 ++++++++++++++++++++++++++++++++++++++ src/diracx/api/utils.py | 25 +++++++++++ src/diracx/core/utils.py | 2 +- tests/api/__init__.py | 0 tests/api/test_jobs.py | 47 +++++++++++++++++++ tests/api/test_utils.py | 24 ++++++++++ 7 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 src/diracx/api/jobs.py create mode 100644 src/diracx/api/utils.py create mode 100644 tests/api/__init__.py create mode 100644 tests/api/test_jobs.py create mode 100644 tests/api/test_utils.py diff --git a/src/diracx/api/__init__.py b/src/diracx/api/__init__.py index e69de29b..329dd573 100644 --- a/src/diracx/api/__init__.py +++ b/src/diracx/api/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +__all__ = ("jobs",) + +from . 
import jobs diff --git a/src/diracx/api/jobs.py b/src/diracx/api/jobs.py new file mode 100644 index 00000000..a5bd5e45 --- /dev/null +++ b/src/diracx/api/jobs.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +__all__ = ("create_sandbox", "download_sandbox") + +import hashlib +import logging +import os +import tarfile +import tempfile +from pathlib import Path + +import httpx + +from diracx.client.aio import DiracClient +from diracx.client.models import SandboxInfo + +from .utils import with_client + +logger = logging.getLogger(__name__) + +SANDBOX_CHECKSUM_ALGORITHM = "sha256" +SANDBOX_COMPRESSION = "bz2" + + +@with_client +async def create_sandbox(paths: list[Path], *, client: DiracClient) -> str: + """Create a sandbox from the given paths and upload it to the storage backend. + + Any paths that are directories will be added recursively. + The returned value is the PFN of the sandbox in the storage backend and can + be used to submit jobs. + """ + with tempfile.TemporaryFile(mode="w+b") as tar_fh: + with tarfile.open(fileobj=tar_fh, mode=f"w|{SANDBOX_COMPRESSION}") as tf: + for path in paths: + logger.debug("Adding %s to sandbox as %s", path.resolve(), path.name) + tf.add(path.resolve(), path.name, recursive=True) + tar_fh.seek(0) + + hasher = getattr(hashlib, SANDBOX_CHECKSUM_ALGORITHM)() + while data := tar_fh.read(512 * 1024): + hasher.update(data) + checksum = hasher.hexdigest() + tar_fh.seek(0) + logger.debug("Sandbox checksum is %s", checksum) + + sandbox_info = SandboxInfo( + checksum_algorithm=SANDBOX_CHECKSUM_ALGORITHM, + checksum=checksum, + size=os.stat(tar_fh.fileno()).st_size, + format=f"tar.{SANDBOX_COMPRESSION}", + ) + + res = await client.jobs.initiate_sandbox_upload(sandbox_info) + if res.url: + logger.debug("Uploading sandbox for %s", res.pfn) + files = {"file": ("file", tar_fh)} + async with httpx.AsyncClient() as httpx_client: + response = await httpx_client.post( + res.url, data=res.fields, files=files + ) + # TODO: Handle this error better + response.raise_for_status() + + logger.debug( + "Sandbox uploaded for %s with status code %s", + res.pfn, + response.status_code, + ) + else: + logger.debug("%s already exists in storage backend", res.pfn) + return res.pfn + + +@with_client +async def download_sandbox(pfn: str, destination: Path, *, client: DiracClient): + """Download a sandbox from the storage backend to the given destination.""" + res = await client.jobs.get_sandbox_file(pfn=pfn) + logger.debug("Downloading sandbox for %s", pfn) + with tempfile.TemporaryFile(mode="w+b") as fh: + async with httpx.AsyncClient() as http_client: + response = await http_client.get(res.url) + # TODO: Handle this error better + response.raise_for_status() + async for chunk in response.aiter_bytes(): + fh.write(chunk) + fh.seek(0) + logger.debug("Sandbox downloaded for %s", pfn) + + with tarfile.open(fileobj=fh) as tf: + tf.extractall(path=destination, filter="data") + logger.debug("Extracted %s to %s", pfn, destination) diff --git a/src/diracx/api/utils.py b/src/diracx/api/utils.py new file mode 100644 index 00000000..b53338da --- /dev/null +++ b/src/diracx/api/utils.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +__all__ = ("with_client",) + +from functools import wraps + +from diracx.client.aio import DiracClient + + +def with_client(func): + """Decorator to provide a DiracClient to a function. + + If the function already has a `client` keyword argument, it will be used. + Otherwise, a new DiracClient will be created and passed as a keyword argument. 
+ """ + + @wraps(func) + async def wrapper(*args, **kwargs): + if "client" in kwargs: + return await func(*args, **kwargs) + + async with DiracClient() as client: + return await func(*args, **kwargs, client=client) + + return wrapper diff --git a/src/diracx/core/utils.py b/src/diracx/core/utils.py index 2e81ea81..39fb23de 100644 --- a/src/diracx/core/utils.py +++ b/src/diracx/core/utils.py @@ -20,7 +20,7 @@ def dotenv_files_from_environment(prefix: str) -> list[str]: return [v for _, v in sorted(env_files.items())] -def write_credentials(token_response: TokenResponse, location: Path | None = None): +def write_credentials(token_response: TokenResponse, *, location: Path | None = None): """Write credentials received in dirax_preferences.credentials_path""" from diracx.core.preferences import get_diracx_preferences diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/api/test_jobs.py b/tests/api/test_jobs.py new file mode 100644 index 00000000..cf47595c --- /dev/null +++ b/tests/api/test_jobs.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import logging +import secrets + +from diracx.api.jobs import create_sandbox, download_sandbox + + +async def test_upload_download_sandbox(tmp_path, with_cli_login, caplog): + caplog.set_level(logging.DEBUG) + + input_directory = tmp_path / "input" + input_directory.mkdir() + input_files = [] + + input_file = input_directory / "input.dat" + input_file.write_bytes(secrets.token_bytes(512)) + input_files.append(input_file) + + input_file = input_directory / "a" / "b" / "c" / "nested.dat" + input_file.parent.mkdir(parents=True) + input_file.write_bytes(secrets.token_bytes(512)) + input_files.append(input_file) + + # Upload the sandbox + caplog.clear() + pfn = await create_sandbox(input_files) + assert has_record(caplog.records, "diracx.api.jobs", "Uploading sandbox for") + + # Uploading the same sandbox again should return the same PFN + caplog.clear() + pfn2 = await create_sandbox(input_files) + assert pfn == pfn2 + assert has_record(caplog.records, "diracx.api.jobs", "already exists in storage") + + # Download the sandbox + destination = tmp_path / "output" + await download_sandbox(pfn, destination) + assert (destination / "input.dat").is_file() + assert (destination / "nested.dat").is_file() + + +def has_record(records: list[logging.LogRecord], logger_name: str, message: str): + for record in records: + if record.name == logger_name and message in record.message: + return True + return False diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py new file mode 100644 index 00000000..a9778b45 --- /dev/null +++ b/tests/api/test_utils.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from diracx.api.utils import with_client +from diracx.client.aio import DiracClient + + +async def test_with_client_default(with_cli_login): + """Ensure that the with_client decorator provides a DiracClient.""" + + @with_client + async def test_func(*, client): + assert isinstance(client, DiracClient) + + await test_func() + + +async def test_with_client_override(): + """Ensure that the with_client can be overridden by providing a client kwarg.""" + + @with_client + async def test_func(*, client): + assert client == "foobar" + + await test_func(client="foobar") From bc6e5e8c86a36db4fb3d4b7d28ed075585a05088 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Mon, 2 Oct 2023 17:41:06 +0200 Subject: [PATCH 5/8] Add lifetime_function method to ServiceSettingsBase --- 
src/diracx/core/settings.py | 8 +++++++- src/diracx/routers/__init__.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diracx/core/settings.py b/src/diracx/core/settings.py index d98d0d9c..dcc36e61 100644 --- a/src/diracx/core/settings.py +++ b/src/diracx/core/settings.py @@ -6,8 +6,9 @@ "ServiceSettingsBase", ) +import contextlib from pathlib import Path -from typing import TYPE_CHECKING, Any, Self, TypeVar +from typing import TYPE_CHECKING, Any, AsyncIterator, Self, TypeVar from authlib.jose import JsonWebKey from pydantic import AnyUrl, BaseSettings, SecretStr, parse_obj_as @@ -58,3 +59,8 @@ class ServiceSettingsBase(BaseSettings, allow_mutation=False): @classmethod def create(cls) -> Self: raise NotImplementedError("This should never be called") + + @contextlib.asynccontextmanager + async def lifetime_function(self) -> AsyncIterator[None]: + """A context manager that can be used to run code at startup and shutdown.""" + yield diff --git a/src/diracx/routers/__init__.py b/src/diracx/routers/__init__.py index c3b49239..437833d5 100644 --- a/src/diracx/routers/__init__.py +++ b/src/diracx/routers/__init__.py @@ -52,6 +52,7 @@ def create_app_inner( cls = type(service_settings) assert cls not in available_settings_classes available_settings_classes.add(cls) + app.lifetime_functions.append(service_settings.lifetime_function) app.dependency_overrides[cls.create] = partial(lambda x: x, service_settings) # Override the configuration source From 2246bbadc78a67ab6a620c6d754afe7b7dcc58d6 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 3 Oct 2023 09:01:11 +0200 Subject: [PATCH 6/8] Switch to aiobotocore --- .pre-commit-config.yaml | 2 + environment.yml | 8 +-- setup.cfg | 1 + src/diracx/core/s3.py | 24 +++++--- src/diracx/routers/job_manager/sandboxes.py | 47 +++++++------- tests/conftest.py | 33 +++++++--- tests/core/test_s3.py | 68 ++++++++++----------- 7 files changed, 102 insertions(+), 81 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 211d7547..7ca15338 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,4 +33,6 @@ repos: - types-PyYAML - types-cachetools - types-requests + - types-aiobotocore[essential] + - boto3-stubs[essential] exclude: ^(src/diracx/client/|tests/|build) diff --git a/environment.yml b/environment.yml index a10cfad2..a8167b3c 100644 --- a/environment.yml +++ b/environment.yml @@ -51,8 +51,8 @@ dependencies: - types-requests - uvicorn - moto - - mypy-boto3-s3 + - aiobotocore - botocore - - boto3-stubs - # - pip: - # - git+https://github.com/DIRACGrid/DIRAC.git@integration + - pip: + - types-aiobotocore[essential] + - boto3-stubs[essential] diff --git a/setup.cfg b/setup.cfg index abd38583..01c95b0f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ package_dir = = src python_requires = >=3.10 install_requires = + aiobotocore authlib aiohttp aiomysql diff --git a/src/diracx/core/s3.py b/src/diracx/core/s3.py index 34cf7990..6fc56657 100644 --- a/src/diracx/core/s3.py +++ b/src/diracx/core/s3.py @@ -1,7 +1,11 @@ """Utilities for interacting with S3-compatible storage.""" from __future__ import annotations -__all__ = ("s3_bucket_exists", "s3_object_exists", "generate_presigned_upload") +__all__ = ( + "s3_bucket_exists", + "s3_object_exists", + "generate_presigned_upload", +) import base64 from typing import TYPE_CHECKING, TypedDict, cast @@ -11,7 +15,7 @@ from .models import ChecksumAlgorithm if TYPE_CHECKING: - from mypy_boto3_s3.client import S3Client + from types_aiobotocore_s3.client 
import S3Client class S3PresignedPostInfo(TypedDict): @@ -19,19 +23,19 @@ class S3PresignedPostInfo(TypedDict): fields: dict[str, str] -def s3_bucket_exists(s3_client: S3Client, bucket_name: str) -> bool: +async def s3_bucket_exists(s3_client: S3Client, bucket_name: str) -> bool: """Check if a bucket exists in S3.""" - return _s3_exists(s3_client.head_bucket, Bucket=bucket_name) + return await _s3_exists(s3_client.head_bucket, Bucket=bucket_name) -def s3_object_exists(s3_client: S3Client, bucket_name: str, key: str) -> bool: +async def s3_object_exists(s3_client: S3Client, bucket_name: str, key: str) -> bool: """Check if an object exists in an S3 bucket.""" - return _s3_exists(s3_client.head_object, Bucket=bucket_name, Key=key) + return await _s3_exists(s3_client.head_object, Bucket=bucket_name, Key=key) -def _s3_exists(method, **kwargs: str) -> bool: +async def _s3_exists(method, **kwargs: str) -> bool: try: - method(**kwargs) + await method(**kwargs) except ClientError as e: if e.response["Error"]["Code"] != "404": raise @@ -40,7 +44,7 @@ def _s3_exists(method, **kwargs: str) -> bool: return True -def generate_presigned_upload( +async def generate_presigned_upload( s3_client: S3Client, bucket_name: str, key: str, @@ -60,7 +64,7 @@ def generate_presigned_upload( conditions = [["content-length-range", size, size]] + [ {k: v} for k, v in fields.items() ] - result = s3_client.generate_presigned_post( + result = await s3_client.generate_presigned_post( Bucket=bucket_name, Key=key, Fields=fields, diff --git a/src/diracx/routers/job_manager/sandboxes.py b/src/diracx/routers/job_manager/sandboxes.py index b47af93f..0c1c0ee8 100644 --- a/src/diracx/routers/job_manager/sandboxes.py +++ b/src/diracx/routers/job_manager/sandboxes.py @@ -1,9 +1,10 @@ from __future__ import annotations +import contextlib from http import HTTPStatus -from typing import TYPE_CHECKING, Annotated +from typing import TYPE_CHECKING, Annotated, AsyncIterator -import botocore.session +from aiobotocore.session import get_session from botocore.config import Config from botocore.errorfactory import ClientError from fastapi import Depends, HTTPException, Query @@ -22,7 +23,7 @@ from diracx.core.settings import ServiceSettingsBase if TYPE_CHECKING: - from mypy_boto3_s3.client import S3Client + from types_aiobotocore_s3.client import S3Client from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token from ..dependencies import SandboxMetadataDB, add_settings_annotation @@ -42,28 +43,26 @@ class SandboxStoreSettings(ServiceSettingsBase, env_prefix="DIRACX_SANDBOX_STORE url_validity_seconds: int = 5 * 60 _client: S3Client = PrivateAttr(None) - def __init__(self, **kwargs): - super().__init__(**kwargs) - - # TODO: Use async - session = botocore.session.get_session() - self._client = session.create_client( + @contextlib.asynccontextmanager + async def lifetime_function(self) -> AsyncIterator[None]: + async with get_session().create_client( "s3", - # endpoint_url=s3_cred["endpoint"], - # aws_access_key_id=s3_cred["access_key_id"], - # aws_secret_access_key=s3_cred["secret_access_key"], **self.s3_client_kwargs, config=Config(signature_version="v4"), - ) - if not s3_bucket_exists(self._client, self.bucket_name): - if not self.auto_create_bucket: - raise ValueError( - f"Bucket {self.bucket_name} does not exist and auto_create_bucket is disabled" - ) - try: - self._client.create_bucket(Bucket=self.bucket_name) - except ClientError as e: - raise ValueError(f"Failed to create bucket {self.bucket_name}") from e + ) as 
self._client: # type: ignore + if not await s3_bucket_exists(self._client, self.bucket_name): + if not self.auto_create_bucket: + raise ValueError( + f"Bucket {self.bucket_name} does not exist and auto_create_bucket is disabled" + ) + try: + await self._client.create_bucket(Bucket=self.bucket_name) + except ClientError as e: + raise ValueError( + f"Failed to create bucket {self.bucket_name}" + ) from e + + yield @property def s3_client(self) -> S3Client: @@ -116,7 +115,7 @@ async def initiate_sandbox_upload( await sandbox_metadata_db.update_sandbox_last_access_time(pfn) return SandboxUploadResponse(pfn=pfn) - upload_info = generate_presigned_upload( + upload_info = await generate_presigned_upload( settings.s3_client, settings.bucket_name, pfn_to_key(pfn), @@ -166,7 +165,7 @@ async def get_sandbox_file( """ # TODO: Prevent people from downloading other people's sandboxes? # TODO: Support by name and by job id? - presigned_url = settings.s3_client.generate_presigned_url( + presigned_url = await settings.s3_client.generate_presigned_url( ClientMethod="get_object", Params={"Bucket": settings.bucket_name, "Key": pfn_to_key(pfn)}, ExpiresIn=settings.url_validity_seconds, diff --git a/tests/conftest.py b/tests/conftest.py index 60ec907f..87f33410 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ from cryptography.hazmat.primitives.asymmetric import rsa from fastapi.testclient import TestClient from git import Repo -from moto import mock_s3 +from moto.server import ThreadedMotoServer from diracx.core.config import Config, ConfigSource from diracx.core.preferences import get_diracx_preferences @@ -80,14 +80,31 @@ def test_auth_settings() -> AuthSettings: ) +@pytest.fixture(scope="session") +def aio_moto(): + """Start the moto server in a separate thread and return the base URL + + The mocking provided by moto doesn't play nicely with aiobotocore so we use + the server directly. See https://github.com/aio-libs/aiobotocore/issues/755 + """ + port = 27132 + server = ThreadedMotoServer(port=port) + server.start() + yield { + "endpoint_url": f"http://localhost:{port}", + "aws_access_key_id": "testing", + "aws_secret_access_key": "testing", + } + server.stop() + + @pytest.fixture(scope="function") -def test_sandbox_settings() -> SandboxStoreSettings: - with mock_s3(): - yield SandboxStoreSettings( - bucket_name="sandboxes", - s3_client_kwargs={}, - auto_create_bucket=True, - ) +def test_sandbox_settings(aio_moto) -> SandboxStoreSettings: + yield SandboxStoreSettings( + bucket_name="sandboxes", + s3_client_kwargs=aio_moto, + auto_create_bucket=True, + ) @pytest.fixture diff --git a/tests/core/test_s3.py b/tests/core/test_s3.py index c6acb2a3..4c6fad5d 100644 --- a/tests/core/test_s3.py +++ b/tests/core/test_s3.py @@ -4,10 +4,9 @@ import hashlib import secrets -import botocore.exceptions import pytest import requests -from moto import mock_s3 +from aiobotocore.session import get_session from diracx.core.s3 import ( b16_to_b64, @@ -43,42 +42,39 @@ def test_b16_to_b64_random(): @pytest.fixture(scope="function") -def moto_s3(): +async def moto_s3(aio_moto): """Very basic moto-based S3 backend. This is a fixture that can be used to test S3 interactions using moto. Note that this is not a complete S3 backend, in particular authentication and validation of requests is not implemented. 
""" - with mock_s3(): - client = botocore.session.get_session().create_client("s3") - client.create_bucket(Bucket=BUCKET_NAME) - client.create_bucket(Bucket=OTHER_BUCKET_NAME) + async with get_session().create_client("s3", **aio_moto) as client: + await client.create_bucket(Bucket=BUCKET_NAME) + await client.create_bucket(Bucket=OTHER_BUCKET_NAME) yield client -def test_s3_bucket_exists(moto_s3): - assert s3_bucket_exists(moto_s3, BUCKET_NAME) - assert not s3_bucket_exists(moto_s3, MISSING_BUCKET_NAME) +async def test_s3_bucket_exists(moto_s3): + assert await s3_bucket_exists(moto_s3, BUCKET_NAME) + assert not await s3_bucket_exists(moto_s3, MISSING_BUCKET_NAME) -def test_s3_object_exists(moto_s3): - with pytest.raises(botocore.exceptions.ClientError): - s3_object_exists(moto_s3, MISSING_BUCKET_NAME, "key") +async def test_s3_object_exists(moto_s3): + assert not await s3_object_exists(moto_s3, MISSING_BUCKET_NAME, "key") + assert not await s3_object_exists(moto_s3, BUCKET_NAME, "key") + await moto_s3.put_object(Bucket=BUCKET_NAME, Key="key", Body=b"hello") + assert await s3_object_exists(moto_s3, BUCKET_NAME, "key") - assert not s3_object_exists(moto_s3, BUCKET_NAME, "key") - moto_s3.put_object(Bucket=BUCKET_NAME, Key="key", Body=b"hello") - assert s3_object_exists(moto_s3, BUCKET_NAME, "key") - -def test_presigned_upload_moto(moto_s3): +async def test_presigned_upload_moto(moto_s3): """Test the presigned upload with moto This doesn't actually test the signature, see test_presigned_upload_minio """ file_content, checksum = _random_file(128) key = f"{checksum}.dat" - upload_info = generate_presigned_upload( + upload_info = await generate_presigned_upload( moto_s3, BUCKET_NAME, key, "sha256", checksum, len(file_content), 60 ) @@ -89,30 +85,32 @@ def test_presigned_upload_moto(moto_s3): assert r.status_code == 204, r.text # Make sure the object is actually there - obj = moto_s3.get_object(Bucket=BUCKET_NAME, Key=key) - assert obj["Body"].read() == file_content + obj = await moto_s3.get_object(Bucket=BUCKET_NAME, Key=key) + assert (await obj["Body"].read()) == file_content -@pytest.fixture(scope="session") -def minio_client(demo_urls): +@pytest.fixture(scope="function") +async def minio_client(demo_urls): """Create a S3 client that uses minio from the demo as backend""" - yield botocore.session.get_session().create_client( + async with get_session().create_client( "s3", endpoint_url=demo_urls["minio"], aws_access_key_id="console", aws_secret_access_key="console123", - ) + ) as client: + yield client -@pytest.fixture(scope="session") -def test_bucket(minio_client): +@pytest.fixture(scope="function") +async def test_bucket(minio_client): """Create a test bucket that is cleaned up after the test session""" bucket_name = f"dirac-test-{secrets.token_hex(8)}" - minio_client.create_bucket(Bucket=bucket_name) + await minio_client.create_bucket(Bucket=bucket_name) yield bucket_name - for obj in minio_client.list_objects(Bucket=bucket_name)["Contents"]: - minio_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) - minio_client.delete_bucket(Bucket=bucket_name) + objects = await minio_client.list_objects(Bucket=bucket_name) + for obj in objects.get("Contents", []): + await minio_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) + await minio_client.delete_bucket(Bucket=bucket_name) @pytest.mark.parametrize( @@ -127,7 +125,7 @@ def test_bucket(minio_client): [_random_file(128)[0], _random_file(128)[1], 128, "ContentChecksumMismatch"], ], ) -def test_presigned_upload_minio( +async def 
test_presigned_upload_minio( minio_client, test_bucket, content, checksum, size, expected_error ): """Test the presigned upload with Minio @@ -138,7 +136,7 @@ def test_presigned_upload_minio( """ key = f"{checksum}.dat" # Prepare the signed URL - upload_info = generate_presigned_upload( + upload_info = await generate_presigned_upload( minio_client, test_bucket, key, "sha256", checksum, size, 60 ) # Ensure the URL doesn't work @@ -147,8 +145,8 @@ def test_presigned_upload_minio( ) if expected_error is None: assert r.status_code == 204, r.text - assert s3_object_exists(minio_client, test_bucket, key) + assert await s3_object_exists(minio_client, test_bucket, key) else: assert r.status_code == 400, r.text assert expected_error in r.text - assert not s3_object_exists(minio_client, test_bucket, key) + assert not (await s3_object_exists(minio_client, test_bucket, key)) From def84cd46e05c44285cfbd6fd1251457ea1311d1 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 3 Oct 2023 09:02:10 +0200 Subject: [PATCH 7/8] Add S3ClientKwargs annotation --- src/diracx/core/s3.py | 12 +++++++++++- src/diracx/routers/job_manager/sandboxes.py | 5 +++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/diracx/core/s3.py b/src/diracx/core/s3.py index 6fc56657..26e2609a 100644 --- a/src/diracx/core/s3.py +++ b/src/diracx/core/s3.py @@ -2,13 +2,14 @@ from __future__ import annotations __all__ = ( + "S3ClientKwargs", "s3_bucket_exists", "s3_object_exists", "generate_presigned_upload", ) import base64 -from typing import TYPE_CHECKING, TypedDict, cast +from typing import TYPE_CHECKING, NotRequired, TypedDict, cast from botocore.errorfactory import ClientError @@ -18,6 +19,15 @@ from types_aiobotocore_s3.client import S3Client +class S3ClientKwargs(TypedDict): + region_name: NotRequired[str] + use_ssl: NotRequired[bool] + verify: NotRequired[bool | str] + endpoint_url: NotRequired[str] + aws_access_key_id: NotRequired[str] + aws_secret_access_key: NotRequired[str] + + class S3PresignedPostInfo(TypedDict): url: str fields: dict[str, str] diff --git a/src/diracx/routers/job_manager/sandboxes.py b/src/diracx/routers/job_manager/sandboxes.py index 0c1c0ee8..1ddf9ee3 100644 --- a/src/diracx/routers/job_manager/sandboxes.py +++ b/src/diracx/routers/job_manager/sandboxes.py @@ -16,6 +16,7 @@ ) from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.core.s3 import ( + S3ClientKwargs, generate_presigned_upload, s3_bucket_exists, s3_object_exists, @@ -38,7 +39,7 @@ class SandboxStoreSettings(ServiceSettingsBase, env_prefix="DIRACX_SANDBOX_STORE """Settings for the sandbox store.""" bucket_name: str - s3_client_kwargs: dict[str, str] + s3_client_kwargs: S3ClientKwargs auto_create_bucket: bool = False url_validity_seconds: int = 5 * 60 _client: S3Client = PrivateAttr(None) @@ -49,7 +50,7 @@ async def lifetime_function(self) -> AsyncIterator[None]: "s3", **self.s3_client_kwargs, config=Config(signature_version="v4"), - ) as self._client: # type: ignore + ) as self._client: if not await s3_bucket_exists(self._client, self.bucket_name): if not self.auto_create_bucket: raise ValueError( From 4d7f8fea82c2945ad93f521b7c3fbef8d0e1eb77 Mon Sep 17 00:00:00 2001 From: Chris Burr Date: Tue, 3 Oct 2023 08:57:56 +0200 Subject: [PATCH 8/8] WIP: Support pydantic 2 --- .pre-commit-config.yaml | 2 +- environment.yml | 2 +- setup.cfg | 3 ++- src/diracx/core/config/schema.py | 27 ++++++++++++++------------- src/diracx/core/models.py | 16 ++++++++-------- src/diracx/core/preferences.py | 6 ++++-- 
src/diracx/core/settings.py | 18 +++++++----------- 7 files changed, 37 insertions(+), 37 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7ca15338..39ecb475 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,7 +28,7 @@ repos: hooks: - id: mypy additional_dependencies: - - pydantic==1.10.10 + - pydantic - sqlalchemy - types-PyYAML - types-cachetools diff --git a/environment.yml b/environment.yml index a8167b3c..f2f29961 100644 --- a/environment.yml +++ b/environment.yml @@ -31,7 +31,7 @@ dependencies: - isodate - mypy - opensearch-py - - pydantic =1.10.10 + - pydantic >=2 - pyjwt - pytest - pytest-asyncio diff --git a/setup.cfg b/setup.cfg index 01c95b0f..983aeacc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,8 @@ install_requires = isodate mypy opensearch-py - pydantic ==1.10.10 + pydantic >=2 + pydantic-settings >=2 python-dotenv python-jose python-multipart diff --git a/src/diracx/core/config/schema.py b/src/diracx/core/config/schema.py index b06dc296..6945ed32 100644 --- a/src/diracx/core/config/schema.py +++ b/src/diracx/core/config/schema.py @@ -5,13 +5,14 @@ from typing import Any, Optional from pydantic import BaseModel as _BaseModel -from pydantic import EmailStr, PrivateAttr, root_validator +from pydantic import EmailStr, PrivateAttr, model_validator from ..properties import SecurityProperty class BaseModel(_BaseModel, extra="forbid", allow_mutation=False): - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def legacy_adaptor(cls, v): """Applies transformations to interpret the legacy DIRAC CFG format""" if not os.environ.get("DIRAC_COMPAT_ENABLE_CS_CONVERSION"): @@ -37,7 +38,7 @@ class UserConfig(BaseModel): CA: str DN: str PreferedUsername: str - Email: EmailStr | None + Email: EmailStr | None = None Suspended: list[str] = [] Quota: int | None = None # TODO: These should be LHCbDIRAC specific @@ -49,12 +50,12 @@ class GroupConfig(BaseModel): AutoAddVOMS: bool = False AutoUploadPilotProxy: bool = False AutoUploadProxy: bool = False - JobShare: Optional[int] + JobShare: Optional[int] = None Properties: list[SecurityProperty] - Quota: Optional[int] + Quota: Optional[int] = None Users: list[str] AllowBackgroundTQs: bool = False - VOMSRole: Optional[str] + VOMSRole: Optional[str] = None AutoSyncVOMS: bool = False @@ -97,7 +98,7 @@ class JobMonitoringConfig(BaseModel): class ServicesConfig(BaseModel): - Catalogs: dict[str, Any] | None + Catalogs: dict[str, Any] | None = None JobMonitoring: JobMonitoringConfig = JobMonitoringConfig() @@ -138,12 +139,12 @@ class Config(BaseModel): # TODO: Should this be split by vo rather than setup? 
Operations: dict[str, OperationsConfig] - LocalSite: Any - LogLevel: Any - MCTestingDestination: Any - Resources: Any - Systems: Any - WebApp: Any + LocalSite: Any = None + LogLevel: Any = None + MCTestingDestination: Any = None + Resources: Any = None + Systems: Any = None + WebApp: Any = None _hexsha: str = PrivateAttr() _modified: datetime = PrivateAttr() diff --git a/src/diracx/core/models.py b/src/diracx/core/models.py index 21fc53a9..8fd07285 100644 --- a/src/diracx/core/models.py +++ b/src/diracx/core/models.py @@ -46,7 +46,7 @@ class TokenResponse(BaseModel): access_token: str expires_in: int token_type: str = "Bearer" - refresh_token: str | None + refresh_token: str | None = None class JobStatus(StrEnum): @@ -98,13 +98,13 @@ class JobStatusReturn(LimitedJobStatusReturn): class SetJobStatusReturn(BaseModel): - status: JobStatus | None = Field(alias="Status") - minor_status: str | None = Field(alias="MinorStatus") - application_status: str | None = Field(alias="ApplicationStatus") - heartbeat_time: datetime | None = Field(alias="HeartBeatTime") - start_exec_time: datetime | None = Field(alias="StartExecTime") - end_exec_time: datetime | None = Field(alias="EndExecTime") - last_update_time: datetime | None = Field(alias="LastUpdateTime") + status: JobStatus | None = Field(None, alias="Status") + minor_status: str | None = Field(None, alias="MinorStatus") + application_status: str | None = Field(None, alias="ApplicationStatus") + heartbeat_time: datetime | None = Field(None, alias="HeartBeatTime") + start_exec_time: datetime | None = Field(None, alias="StartExecTime") + end_exec_time: datetime | None = Field(None, alias="EndExecTime") + last_update_time: datetime | None = Field(None, alias="LastUpdateTime") class UserInfo(BaseModel): diff --git a/src/diracx/core/preferences.py b/src/diracx/core/preferences.py index 92bcdb82..144d3fa4 100644 --- a/src/diracx/core/preferences.py +++ b/src/diracx/core/preferences.py @@ -8,7 +8,8 @@ from functools import lru_cache from pathlib import Path -from pydantic import AnyHttpUrl, BaseSettings, Field, validator +from pydantic import AnyHttpUrl, Field, field_validator +from pydantic_settings import BaseSettings from .utils import dotenv_files_from_environment @@ -41,7 +42,8 @@ class DiracxPreferences(BaseSettings, env_prefix="DIRACX_"): def from_env(cls): return cls(_env_file=dotenv_files_from_environment("DIRACX_DOTENV")) - @validator("log_level", pre=True) + @field_validator("log_level", mode="before") + @classmethod def validate_log_level(cls, v: str): if isinstance(v, str): return getattr(LogLevels, v.upper()) diff --git a/src/diracx/core/settings.py b/src/diracx/core/settings.py index dcc36e61..1b8daee2 100644 --- a/src/diracx/core/settings.py +++ b/src/diracx/core/settings.py @@ -8,15 +8,11 @@ import contextlib from pathlib import Path -from typing import TYPE_CHECKING, Any, AsyncIterator, Self, TypeVar +from typing import Any, AsyncIterator, Self, TypeVar from authlib.jose import JsonWebKey -from pydantic import AnyUrl, BaseSettings, SecretStr, parse_obj_as - -if TYPE_CHECKING: - from pydantic.config import BaseConfig - from pydantic.fields import ModelField - +from pydantic import AnyUrl, SecretStr, parse_obj_as +from pydantic_settings import BaseSettings T = TypeVar("T") @@ -33,11 +29,11 @@ def __init__(self, data: str): self.jwk = JsonWebKey.import_key(self.get_secret_value()) @classmethod - # TODO: This should return TokenSigningKey but pydantic's type hints are wrong - def validate(cls, value: Any) -> SecretStr: + def validate(cls, 
value: Any) -> TokenSigningKey: """Load private keys from files if needed""" if isinstance(value, str) and not value.strip().startswith("-----BEGIN"): url = parse_obj_as(LocalFileUrl, value) + assert url.path, url.path value = Path(url.path).read_text() return super().validate(value) @@ -48,11 +44,11 @@ class LocalFileUrl(AnyUrl): @classmethod # TODO: This should return LocalFileUrl but pydantic's type hints are wrong - def validate(cls, value: Any, field: ModelField, config: BaseConfig) -> AnyUrl: + def validate(cls, value: Any) -> AnyUrl: """Overrides AnyUrl.validate to add file:// scheme if not present.""" if isinstance(value, str) and "://" not in value: value = f"file://{value}" - return super().validate(value, field, config) + return super().validate(value) class ServiceSettingsBase(BaseSettings, allow_mutation=False):
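
A minimal sketch of the migration pattern the last commit applies, assuming pydantic >= 2 is installed (and pydantic-settings >= 2 for the BaseSettings pieces). The class and field names below are hypothetical and not part of diracx; the sketch only illustrates the three recurring changes in the diff: fields typed "X | None" or "Optional[X]" now need an explicit "= None" default, root_validator(pre=True) becomes model_validator(mode="before") on a classmethod, and validator(..., pre=True) becomes field_validator(..., mode="before") on a classmethod.

    # Illustrative only, not part of the patch series; assumes pydantic >= 2.
    from __future__ import annotations

    from pydantic import BaseModel, field_validator, model_validator


    class ExampleConfig(BaseModel):
        # pydantic 2 no longer treats "X | None" as implicitly optional:
        # without the "= None" default this field would be required.
        Name: str
        Comment: str | None = None

        # Replaces @root_validator(pre=True); pydantic 2 requires the
        # classmethod decorator underneath the validator decorator.
        @model_validator(mode="before")
        @classmethod
        def _strip_strings(cls, values):
            if isinstance(values, dict):
                return {
                    key: value.strip() if isinstance(value, str) else value
                    for key, value in values.items()
                }
            return values

        # Replaces @validator("Name", pre=True); runs before field coercion.
        @field_validator("Name", mode="before")
        @classmethod
        def _coerce_name(cls, value):
            return str(value)


    assert ExampleConfig(Name="  grid  ").Name == "grid"
    assert ExampleConfig(Name=42).Comment is None

The same field_validator shape is what preferences.py now uses for log_level, and BaseSettings itself has moved to the separate pydantic-settings package, which is why setup.cfg gains the pydantic-settings >=2 requirement.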