-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add API for uploading/downloading sandboxes
- Loading branch information
Showing
7 changed files
with
194 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from __future__ import annotations | ||
|
||
__all__ = ("jobs",) | ||
|
||
from . import jobs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from __future__ import annotations | ||
|
||
__all__ = ("create_sandbox", "download_sandbox") | ||
|
||
import hashlib | ||
import logging | ||
import os | ||
import tarfile | ||
import tempfile | ||
from pathlib import Path | ||
|
||
import httpx | ||
|
||
from diracx.client.aio import DiracClient | ||
from diracx.client.models import SandboxInfo | ||
|
||
from .utils import with_client | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
SANDBOX_CHECKSUM_ALGORITHM = "sha256" | ||
SANDBOX_COMPRESSION = "bz2" | ||
|
||
|
||
@with_client | ||
async def create_sandbox(paths: list[Path], *, client: DiracClient) -> str: | ||
"""Create a sandbox from the given paths and upload it to the storage backend. | ||
Any paths that are directories will be added recursively. | ||
The returned value is the PFN of the sandbox in the storage backend and can | ||
be used to submit jobs. | ||
""" | ||
with tempfile.TemporaryFile(mode="w+b") as tar_fh: | ||
with tarfile.open(fileobj=tar_fh, mode=f"w|{SANDBOX_COMPRESSION}") as tf: | ||
for path in paths: | ||
logger.debug("Adding %s to sandbox as %s", path.resolve(), path.name) | ||
tf.add(path.resolve(), path.name, recursive=True) | ||
tar_fh.seek(0) | ||
|
||
hasher = getattr(hashlib, SANDBOX_CHECKSUM_ALGORITHM)() | ||
while data := tar_fh.read(512 * 1024): | ||
hasher.update(data) | ||
checksum = hasher.hexdigest() | ||
tar_fh.seek(0) | ||
logger.debug("Sandbox checksum is %s", checksum) | ||
|
||
sandbox_info = SandboxInfo( | ||
checksum_algorithm=SANDBOX_CHECKSUM_ALGORITHM, | ||
checksum=checksum, | ||
size=os.stat(tar_fh.fileno()).st_size, | ||
format=f"tar.{SANDBOX_COMPRESSION}", | ||
) | ||
|
||
res = await client.jobs.initiate_sandbox_upload(sandbox_info) | ||
if res.url: | ||
logger.debug("Uploading sandbox for %s", res.pfn) | ||
files = {"file": ("file", tar_fh)} | ||
async with httpx.AsyncClient() as httpx_client: | ||
response = await httpx_client.post( | ||
res.url, data=res.fields, files=files | ||
) | ||
# TODO: Handle this error better | ||
response.raise_for_status() | ||
|
||
logger.debug( | ||
"Sandbox uploaded for %s with status code %s", | ||
res.pfn, | ||
response.status_code, | ||
) | ||
else: | ||
logger.debug("%s already exists in storage backend", res.pfn) | ||
return res.pfn | ||
|
||
|
||
@with_client | ||
async def download_sandbox(pfn: str, destination: Path, *, client: DiracClient): | ||
"""Download a sandbox from the storage backend to the given destination.""" | ||
res = await client.jobs.get_sandbox_file(pfn=pfn) | ||
logger.debug("Downloading sandbox for %s", pfn) | ||
with tempfile.TemporaryFile(mode="w+b") as fh: | ||
async with httpx.AsyncClient() as http_client: | ||
response = await http_client.get(res.url) | ||
# TODO: Handle this error better | ||
response.raise_for_status() | ||
async for chunk in response.aiter_bytes(): | ||
fh.write(chunk) | ||
fh.seek(0) | ||
logger.debug("Sandbox downloaded for %s", pfn) | ||
|
||
with tarfile.open(fileobj=fh) as tf: | ||
tf.extractall(path=destination, filter="data") | ||
logger.debug("Extracted %s to %s", pfn, destination) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from __future__ import annotations | ||
|
||
__all__ = ("with_client",) | ||
|
||
from functools import wraps | ||
|
||
from diracx.client.aio import DiracClient | ||
|
||
|
||
def with_client(func): | ||
"""Decorator to provide a DiracClient to a function. | ||
If the function already has a `client` keyword argument, it will be used. | ||
Otherwise, a new DiracClient will be created and passed as a keyword argument. | ||
""" | ||
|
||
@wraps(func) | ||
async def wrapper(*args, **kwargs): | ||
if "client" in kwargs: | ||
return await func(*args, **kwargs) | ||
|
||
async with DiracClient() as client: | ||
return await func(*args, **kwargs, client=client) | ||
|
||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import secrets | ||
|
||
from diracx.api.jobs import create_sandbox, download_sandbox | ||
|
||
|
||
async def test_upload_download_sandbox(tmp_path, with_cli_login, caplog): | ||
caplog.set_level(logging.DEBUG) | ||
|
||
input_directory = tmp_path / "input" | ||
input_directory.mkdir() | ||
input_files = [] | ||
|
||
input_file = input_directory / "input.dat" | ||
input_file.write_bytes(secrets.token_bytes(512)) | ||
input_files.append(input_file) | ||
|
||
input_file = input_directory / "a" / "b" / "c" / "nested.dat" | ||
input_file.parent.mkdir(parents=True) | ||
input_file.write_bytes(secrets.token_bytes(512)) | ||
input_files.append(input_file) | ||
|
||
# Upload the sandbox | ||
caplog.clear() | ||
pfn = await create_sandbox(input_files) | ||
assert has_record(caplog.records, "diracx.api.jobs", "Uploading sandbox for") | ||
|
||
# Uploading the same sandbox again should return the same PFN | ||
caplog.clear() | ||
pfn2 = await create_sandbox(input_files) | ||
assert pfn == pfn2 | ||
assert has_record(caplog.records, "diracx.api.jobs", "already exists in storage") | ||
|
||
# Download the sandbox | ||
destination = tmp_path / "output" | ||
await download_sandbox(pfn, destination) | ||
assert (destination / "input.dat").is_file() | ||
assert (destination / "nested.dat").is_file() | ||
|
||
|
||
def has_record(records: list[logging.LogRecord], logger_name: str, message: str): | ||
for record in records: | ||
if record.name == logger_name and message in record.message: | ||
return True | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from __future__ import annotations | ||
|
||
from diracx.api.utils import with_client | ||
from diracx.client.aio import DiracClient | ||
|
||
|
||
async def test_with_client_default(with_cli_login): | ||
"""Ensure that the with_client decorator provides a DiracClient.""" | ||
|
||
@with_client | ||
async def test_func(*, client): | ||
assert isinstance(client, DiracClient) | ||
|
||
await test_func() | ||
|
||
|
||
async def test_with_client_override(): | ||
"""Ensure that the with_client can be overridden by providing a client kwarg.""" | ||
|
||
@with_client | ||
async def test_func(*, client): | ||
assert client == "foobar" | ||
|
||
await test_func(client="foobar") |