diff --git a/src/diracx/routers/dependencies.py b/src/diracx/routers/dependencies.py index 3ab40361..deb55167 100644 --- a/src/diracx/routers/dependencies.py +++ b/src/diracx/routers/dependencies.py @@ -19,6 +19,7 @@ from diracx.db.sql import AuthDB as _AuthDB from diracx.db.sql import JobDB as _JobDB from diracx.db.sql import JobLoggingDB as _JobLoggingDB +from diracx.db.sql import SandboxMetadataDB as _SandboxMetadataDB T = TypeVar("T") @@ -32,6 +33,9 @@ def add_settings_annotation(cls: T) -> T: AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)] JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)] JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)] +SandboxMetadataDB = Annotated[ + _SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction) +] # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] diff --git a/src/diracx/routers/job_manager/__init__.py b/src/diracx/routers/job_manager/__init__.py index df36e15b..39592c52 100644 --- a/src/diracx/routers/job_manager/__init__.py +++ b/src/diracx/routers/job_manager/__init__.py @@ -29,12 +29,14 @@ from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token from ..dependencies import JobDB, JobLoggingDB from ..fastapi_classes import DiracxRouter +from .sandboxes import router as sandboxes_router MAX_PARAMETRIC_JOBS = 20 logger = logging.getLogger(__name__) router = DiracxRouter(dependencies=[has_properties(NORMAL_USER | JOB_ADMINISTRATOR)]) +router.include_router(sandboxes_router) class JobSummaryParams(BaseModel): diff --git a/src/diracx/routers/job_manager/sandboxes.py b/src/diracx/routers/job_manager/sandboxes.py new file mode 100644 index 00000000..b47af93f --- /dev/null +++ b/src/diracx/routers/job_manager/sandboxes.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import TYPE_CHECKING, Annotated + +import botocore.session +from botocore.config import Config +from botocore.errorfactory import ClientError +from fastapi import Depends, HTTPException, Query +from pydantic import BaseModel, PrivateAttr +from sqlalchemy.exc import NoResultFound + +from diracx.core.models import ( + SandboxInfo, +) +from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER +from diracx.core.s3 import ( + generate_presigned_upload, + s3_bucket_exists, + s3_object_exists, +) +from diracx.core.settings import ServiceSettingsBase + +if TYPE_CHECKING: + from mypy_boto3_s3.client import S3Client + +from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token +from ..dependencies import SandboxMetadataDB, add_settings_annotation +from ..fastapi_classes import DiracxRouter + +MAX_SANDBOX_SIZE_BYTES = 100 * 1024 * 1024 +router = DiracxRouter(dependencies=[has_properties(NORMAL_USER | JOB_ADMINISTRATOR)]) + + +@add_settings_annotation +class SandboxStoreSettings(ServiceSettingsBase, env_prefix="DIRACX_SANDBOX_STORE_"): + """Settings for the sandbox store.""" + + bucket_name: str + s3_client_kwargs: dict[str, str] + auto_create_bucket: bool = False + url_validity_seconds: int = 5 * 60 + _client: S3Client = PrivateAttr(None) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # TODO: Use async + session = botocore.session.get_session() + self._client = session.create_client( + "s3", + # endpoint_url=s3_cred["endpoint"], + # aws_access_key_id=s3_cred["access_key_id"], + # aws_secret_access_key=s3_cred["secret_access_key"], + **self.s3_client_kwargs, + config=Config(signature_version="v4"), + ) + if not 
s3_bucket_exists(self._client, self.bucket_name):
+            if not self.auto_create_bucket:
+                raise ValueError(
+                    f"Bucket {self.bucket_name} does not exist and auto_create_bucket is disabled"
+                )
+            try:
+                self._client.create_bucket(Bucket=self.bucket_name)
+            except ClientError as e:
+                raise ValueError(f"Failed to create bucket {self.bucket_name}") from e
+
+    @property
+    def s3_client(self) -> S3Client:
+        return self._client
+
+
+class SandboxUploadResponse(BaseModel):
+    pfn: str
+    url: str | None = None
+    fields: dict[str, str] = {}
+
+
+@router.post("/sandbox")
+async def initiate_sandbox_upload(
+    user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
+    sandbox_info: SandboxInfo,
+    sandbox_metadata_db: SandboxMetadataDB,
+    settings: SandboxStoreSettings,
+) -> SandboxUploadResponse:
+    """Get the PFN for the given sandbox, initiate an upload as required.
+
+    If the sandbox already exists in the database then the PFN is returned
+    and there is no "url" field in the response.
+
+    If the sandbox does not exist in the database then the "url" and "fields"
+    should be used to upload the sandbox to the storage backend.
+    """
+    if sandbox_info.size > MAX_SANDBOX_SIZE_BYTES:
+        raise HTTPException(
+            status_code=HTTPStatus.BAD_REQUEST,
+            detail=f"Sandbox too large. Max size is {MAX_SANDBOX_SIZE_BYTES} bytes",
+        )
+
+    pfn = sandbox_metadata_db.get_pfn(settings.bucket_name, user_info, sandbox_info)
+
+    try:
+        exists_and_assigned = await sandbox_metadata_db.sandbox_is_assigned(pfn)
+    except NoResultFound:
+        # The sandbox doesn't exist in the database
+        pass
+    else:
+        # As sandboxes are registered in the DB before uploading to the storage
+        # backend we can't rely on their existence in the database to determine
+        # if they have been uploaded. Instead we check if the sandbox has been
+        # assigned to a job. If it has then we know it has been uploaded and we
+        # can avoid communicating with the storage backend.
+        if exists_and_assigned or s3_object_exists(
+            settings.s3_client, settings.bucket_name, pfn_to_key(pfn)
+        ):
+            await sandbox_metadata_db.update_sandbox_last_access_time(pfn)
+            return SandboxUploadResponse(pfn=pfn)
+
+    upload_info = generate_presigned_upload(
+        settings.s3_client,
+        settings.bucket_name,
+        pfn_to_key(pfn),
+        sandbox_info.checksum_algorithm,
+        sandbox_info.checksum,
+        sandbox_info.size,
+        settings.url_validity_seconds,
+    )
+    await sandbox_metadata_db.insert_sandbox(user_info, pfn, sandbox_info.size)
+
+    return SandboxUploadResponse(**upload_info, pfn=pfn)
+
+
+class SandboxDownloadResponse(BaseModel):
+    url: str
+    expires_in: int
+
+
+def pfn_to_key(pfn: str) -> str:
+    """Convert a PFN to a key for S3
+
+    This removes the leading "/S3/<bucket_name>" from the PFN.
+    """
+    return "/".join(pfn.split("/")[3:])
+
+
+SANDBOX_PFN_REGEX = (
+    # Starts with /S3/<bucket_name>
+    r"^/S3/[a-z0-9\.\-]{3,63}"
+    # Followed by /<vo>/<group>/<username>/<checksum_algorithm>:<checksum>.<format>
+    r"(?:/[^/]+){3}/[a-z0-9]{3,10}:[0-9a-f]{64}\.[a-z0-9\.]+$"
+)
+
+
+@router.get("/sandbox")
+async def get_sandbox_file(
+    pfn: Annotated[str, Query(max_length=256, pattern=SANDBOX_PFN_REGEX)],
+    settings: SandboxStoreSettings,
+) -> SandboxDownloadResponse:
+    """Get a presigned URL to download a sandbox file
+
+    This route cannot use a redirect response as most clients will also send
+    the authorization header when following a redirect. This is not desirable
+    as it would leak the authorization token to the storage backend.
+    Additionally, most storage backends return an error when they receive an
+    authorization header for a presigned URL.
+ """ + # TODO: Prevent people from downloading other people's sandboxes? + # TODO: Support by name and by job id? + presigned_url = settings.s3_client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": settings.bucket_name, "Key": pfn_to_key(pfn)}, + ExpiresIn=settings.url_validity_seconds, + ) + return SandboxDownloadResponse( + url=presigned_url, expires_in=settings.url_validity_seconds + ) diff --git a/tests/conftest.py b/tests/conftest.py index 9b4d224c..60ec907f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,12 +14,14 @@ from cryptography.hazmat.primitives.asymmetric import rsa from fastapi.testclient import TestClient from git import Repo +from moto import mock_s3 from diracx.core.config import Config, ConfigSource from diracx.core.preferences import get_diracx_preferences from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.routers import create_app_inner from diracx.routers.auth import AuthSettings, create_token +from diracx.routers.job_manager.sandboxes import SandboxStoreSettings # to get a string like this run: # openssl rand -hex 32 @@ -78,8 +80,18 @@ def test_auth_settings() -> AuthSettings: ) +@pytest.fixture(scope="function") +def test_sandbox_settings() -> SandboxStoreSettings: + with mock_s3(): + yield SandboxStoreSettings( + bucket_name="sandboxes", + s3_client_kwargs={}, + auto_create_bucket=True, + ) + + @pytest.fixture -def with_app(test_auth_settings, with_config_repo): +def with_app(test_auth_settings, test_sandbox_settings, with_config_repo): """ Create a DiracxApp with hard coded configuration for test """ @@ -87,6 +99,7 @@ def with_app(test_auth_settings, with_config_repo): enabled_systems={".well-known", "auth", "config", "jobs"}, all_service_settings=[ test_auth_settings, + test_sandbox_settings, ], database_urls={ "JobDB": "sqlite+aiosqlite:///:memory:", diff --git a/tests/routers/jobs/__init__.py b/tests/routers/jobs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/routers/jobs/test_sandboxes.py b/tests/routers/jobs/test_sandboxes.py new file mode 100644 index 00000000..a683c2b7 --- /dev/null +++ b/tests/routers/jobs/test_sandboxes.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import hashlib +import secrets +from io import BytesIO + +import requests +from fastapi.testclient import TestClient + + +def test_upload_then_download(normal_user_client: TestClient): + data = secrets.token_bytes(512) + checksum = hashlib.sha256(data).hexdigest() + + # Initiate the upload + r = normal_user_client.post( + "/jobs/sandbox", + json={ + "checksum_algorithm": "sha256", + "checksum": checksum, + "size": len(data), + "format": "tar.bz2", + }, + ) + assert r.status_code == 200, r.text + upload_info = r.json() + assert upload_info["url"] + sandbox_pfn = upload_info["pfn"] + assert sandbox_pfn.startswith("/S3/") + + # Actually upload the file + files = {"file": ("file", BytesIO(data))} + r = requests.post(upload_info["url"], data=upload_info["fields"], files=files) + assert r.status_code == 204, r.text + + # Make sure we can download it and get the same data back + r = normal_user_client.get("/jobs/sandbox", params={"pfn": sandbox_pfn}) + assert r.status_code == 200, r.text + download_info = r.json() + assert download_info["expires_in"] > 5 + r = requests.get(download_info["url"]) + assert r.status_code == 200, r.text + assert r.content == data + + +def test_upload_oversized(normal_user_client: TestClient): + data = secrets.token_bytes(512) + checksum = 
hashlib.sha256(data).hexdigest() + + # Initiate the upload + r = normal_user_client.post( + "/jobs/sandbox", + json={ + "checksum_algorithm": "sha256", + "checksum": checksum, + # We can forge the size here to be larger than the actual data as + # we should get an error and never actually upload the data + "size": 1024 * 1024 * 1024, + "format": "tar.bz2", + }, + ) + assert r.status_code == 400, r.text + assert "Sandbox too large" in r.json()["detail"], r.text
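
Note for reviewers (not part of the diff): the sandbox PFN layout is only implied by `SANDBOX_PFN_REGEX` and `pfn_to_key`. Below is a minimal sketch of what the regex accepts and how the S3 object key is derived from a PFN; the VO/group/username naming of the three middle path components is an assumption made for illustration, the diff itself only enforces their shape.

```python
import re

SANDBOX_PFN_REGEX = (
    # Starts with /S3/<bucket_name>
    r"^/S3/[a-z0-9\.\-]{3,63}"
    # Followed by /<vo>/<group>/<username>/<checksum_algorithm>:<checksum>.<format>
    r"(?:/[^/]+){3}/[a-z0-9]{3,10}:[0-9a-f]{64}\.[a-z0-9\.]+$"
)


def pfn_to_key(pfn: str) -> str:
    # Drop the empty leading component, "S3" and the bucket name; the rest is the key
    return "/".join(pfn.split("/")[3:])


# Hypothetical PFN; "myvo/mygroup/myuser" is an assumed example, not a value from the diff
pfn = "/S3/sandboxes/myvo/mygroup/myuser/sha256:" + "0" * 64 + ".tar.bz2"
assert re.fullmatch(SANDBOX_PFN_REGEX, pfn)
assert pfn_to_key(pfn) == "myvo/mygroup/myuser/sha256:" + "0" * 64 + ".tar.bz2"
```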
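For context, here is a rough end-to-end client sketch of the intended flow, mirroring `test_upload_then_download`. `DIRACX_URL` and `ACCESS_TOKEN` are placeholders for a deployed diracx instance and a valid bearer token; neither is defined by this change.

```python
import hashlib
import os
from io import BytesIO

import requests

DIRACX_URL = os.environ["DIRACX_URL"]  # placeholder: base URL of a diracx deployment
HEADERS = {"Authorization": f"Bearer {os.environ['ACCESS_TOKEN']}"}  # placeholder token

data = b"example sandbox contents"
checksum = hashlib.sha256(data).hexdigest()

# 1. Ask diracx for the sandbox PFN and, if needed, a presigned upload URL
r = requests.post(
    f"{DIRACX_URL}/jobs/sandbox",
    headers=HEADERS,
    json={
        "checksum_algorithm": "sha256",
        "checksum": checksum,
        "size": len(data),
        "format": "tar.bz2",
    },
)
r.raise_for_status()
upload_info = r.json()

# 2. If "url" is present the sandbox is not yet in the store: upload it directly
#    to the storage backend using the presigned POST fields (no token involved)
if upload_info.get("url"):
    files = {"file": ("file", BytesIO(data))}
    requests.post(
        upload_info["url"], data=upload_info["fields"], files=files
    ).raise_for_status()

# 3. Later, exchange the PFN for a short-lived presigned download URL
r = requests.get(
    f"{DIRACX_URL}/jobs/sandbox",
    headers=HEADERS,
    params={"pfn": upload_info["pfn"]},
)
r.raise_for_status()
download = requests.get(r.json()["url"])
download.raise_for_status()
assert download.content == data
```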