Skip to content

Commit

Permalink
Merge pull request #118 from chrisburr/sandbox-api
Browse files Browse the repository at this point in the history
Add API methods for uploading/downloading sandboxes
  • Loading branch information
chaen authored Oct 3, 2023
2 parents 09130c5 + 7877a5c commit 248ae87
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/diracx/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

__all__ = ("jobs",)

from . import jobs
92 changes: 92 additions & 0 deletions src/diracx/api/jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import annotations

__all__ = ("create_sandbox", "download_sandbox")

import hashlib
import logging
import os
import tarfile
import tempfile
from pathlib import Path

import httpx

from diracx.client.aio import DiracClient
from diracx.client.models import SandboxInfo

from .utils import with_client

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Name of the hashlib constructor used to checksum the sandbox archive; the
# same string is passed to the server in SandboxInfo.checksum_algorithm.
SANDBOX_CHECKSUM_ALGORITHM = "sha256"
# Compression name used both in the tarfile write mode ("w|<compression>")
# and in the "tar.<compression>" format string placed in SandboxInfo.format.
SANDBOX_COMPRESSION = "bz2"


@with_client
async def create_sandbox(paths: list[Path], *, client: DiracClient) -> str:
    """Pack the given paths into a compressed tarball and upload it as a sandbox.

    Every path is stored in the archive under its final name component, and
    directories are added recursively. The archive is written to a temporary
    file, checksummed, and announced to the server; the actual upload only
    happens when the server answers with an upload URL.

    The returned value is the PFN of the sandbox in the storage backend and
    can be used to submit jobs.
    """
    with tempfile.TemporaryFile(mode="w+b") as archive:
        with tarfile.open(
            fileobj=archive, mode=f"w|{SANDBOX_COMPRESSION}"
        ) as tarball:
            for entry in paths:
                resolved = entry.resolve()
                logger.debug("Adding %s to sandbox as %s", resolved, entry.name)
                tarball.add(resolved, entry.name, recursive=True)
        archive.seek(0)

        # Hash the finished archive in fixed-size chunks.
        digest = getattr(hashlib, SANDBOX_CHECKSUM_ALGORITHM)()
        for block in iter(lambda: archive.read(512 * 1024), b""):
            digest.update(block)
        checksum = digest.hexdigest()
        # Rewind again so the upload below starts from the first byte.
        archive.seek(0)
        logger.debug("Sandbox checksum is %s", checksum)

        upload = await client.jobs.initiate_sandbox_upload(
            SandboxInfo(
                checksum_algorithm=SANDBOX_CHECKSUM_ALGORITHM,
                checksum=checksum,
                size=os.stat(archive.fileno()).st_size,
                format=f"tar.{SANDBOX_COMPRESSION}",
            )
        )

        # No upload URL means there is nothing to send for this sandbox.
        if not upload.url:
            logger.debug("%s already exists in storage backend", upload.pfn)
            return upload.pfn

        logger.debug("Uploading sandbox for %s", upload.pfn)
        async with httpx.AsyncClient() as httpx_client:
            response = await httpx_client.post(
                upload.url,
                data=upload.fields,
                files={"file": ("file", archive)},
            )
            # TODO: Handle this error better
            response.raise_for_status()

        logger.debug(
            "Sandbox uploaded for %s with status code %s",
            upload.pfn,
            response.status_code,
        )
        return upload.pfn


@with_client
async def download_sandbox(pfn: str, destination: Path, *, client: DiracClient):
    """Download a sandbox from the storage backend to the given destination.

    The download URL for ``pfn`` is requested from the server, the archive is
    spooled into a temporary file and then extracted into ``destination``
    with tarfile's ``"data"`` extraction filter.

    Args:
        pfn: Identifier of the sandbox, as returned by ``create_sandbox``.
        destination: Directory the archive members are extracted into.
        client: Client used to look up the download URL.

    Raises:
        Whatever ``response.raise_for_status()`` raises when the storage
        backend answers with an error status.
    """
    res = await client.jobs.get_sandbox_file(pfn=pfn)
    logger.debug("Downloading sandbox for %s", pfn)
    with tempfile.TemporaryFile(mode="w+b") as fh:
        async with httpx.AsyncClient() as http_client:
            # Stream the response: a plain ``get()`` would read the complete
            # body into memory before the first chunk is written to disk.
            async with http_client.stream("GET", res.url) as response:
                # TODO: Handle this error better
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    fh.write(chunk)
        # Rewind so tarfile reads the archive from the beginning.
        fh.seek(0)
        logger.debug("Sandbox downloaded for %s", pfn)

        with tarfile.open(fileobj=fh) as tf:
            tf.extractall(path=destination, filter="data")
        logger.debug("Extracted %s to %s", pfn, destination)
25 changes: 25 additions & 0 deletions src/diracx/api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

__all__ = ("with_client",)

from functools import wraps

from diracx.client.aio import DiracClient


def with_client(func):
    """Decorator that makes sure an async function receives a ``client``.

    When the caller passes a ``client`` keyword argument it is forwarded
    untouched. Otherwise a DiracClient is opened for the duration of the call
    and handed to the function as the ``client`` keyword argument.
    """

    @wraps(func)
    async def _call_with_client(*args, **kwargs):
        if "client" not in kwargs:
            # No client supplied: open one just for this call.
            async with DiracClient() as fresh_client:
                return await func(*args, client=fresh_client, **kwargs)

        # Caller supplied its own client; pass everything through as-is.
        return await func(*args, **kwargs)

    return _call_with_client
2 changes: 1 addition & 1 deletion src/diracx/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def dotenv_files_from_environment(prefix: str) -> list[str]:
return [v for _, v in sorted(env_files.items())]


def write_credentials(token_response: TokenResponse, location: Path | None = None):
def write_credentials(token_response: TokenResponse, *, location: Path | None = None):
"""Write credentials received in dirax_preferences.credentials_path"""
from diracx.core.preferences import get_diracx_preferences

Expand Down
Empty file added tests/api/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions tests/api/test_jobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import logging
import secrets

from diracx.api.jobs import create_sandbox, download_sandbox


async def test_upload_download_sandbox(tmp_path, with_cli_login, caplog):
    """Upload a sandbox twice, then download it and check the extracted files."""
    caplog.set_level(logging.DEBUG)

    source_dir = tmp_path / "input"
    source_dir.mkdir()

    # One file directly inside the input directory...
    top_level_file = source_dir / "input.dat"
    top_level_file.write_bytes(secrets.token_bytes(512))

    # ...and one several directories below it.
    nested_file = source_dir / "a" / "b" / "c" / "nested.dat"
    nested_file.parent.mkdir(parents=True)
    nested_file.write_bytes(secrets.token_bytes(512))

    sandbox_paths = [top_level_file, nested_file]

    # First upload: the upload branch must be logged.
    caplog.clear()
    first_pfn = await create_sandbox(sandbox_paths)
    assert has_record(caplog.records, "diracx.api.jobs", "Uploading sandbox for")

    # Second upload of the same content: same PFN, no new upload.
    caplog.clear()
    second_pfn = await create_sandbox(sandbox_paths)
    assert first_pfn == second_pfn
    assert has_record(caplog.records, "diracx.api.jobs", "already exists in storage")

    # Download and verify both files were extracted into the output directory.
    output_dir = tmp_path / "output"
    await download_sandbox(first_pfn, output_dir)
    for expected_name in ("input.dat", "nested.dat"):
        assert (output_dir / expected_name).is_file()


def has_record(
    records: list[logging.LogRecord], logger_name: str, message: str
) -> bool:
    """Return True if any record from ``logger_name`` contains ``message``.

    Args:
        records: Log records to search, e.g. ``caplog.records``.
        logger_name: Exact logger name the record must come from.
        message: Substring that must appear in the record's rendered message.

    Returns:
        True when at least one record matches both criteria, else False.
    """
    # getMessage() renders msg % args on demand; the ``message`` attribute
    # only exists once a Formatter has processed the record.
    return any(
        record.name == logger_name and message in record.getMessage()
        for record in records
    )
24 changes: 24 additions & 0 deletions tests/api/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from diracx.api.utils import with_client
from diracx.client.aio import DiracClient


async def test_with_client_default(with_cli_login):
    """Check that with_client supplies a DiracClient when none is passed."""

    @with_client
    async def expects_injected_client(*, client):
        assert isinstance(client, DiracClient)

    # Called without a client keyword, so the decorator has to provide one.
    await expects_injected_client()


async def test_with_client_override():
    """Check that an explicit client keyword argument is passed through as-is."""

    @with_client
    async def expects_given_client(*, client):
        assert client == "foobar"

    # The supplied value must reach the function unchanged.
    await expects_given_client(client="foobar")

0 comments on commit 248ae87

Please sign in to comment.