Add support for uploading/downloading sandboxes #110

Closed · wants to merge 10 commits
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
@@ -30,7 +30,7 @@ jobs:
       - name: Prepare environment
         run: |
           pip install typer pyyaml gitpython packaging
-          git clone https://github.com/DIRACGrid/DIRAC.git -b "${{ matrix.dirac-branch }}" /tmp/DIRACRepo
+          git clone https://github.com/chaen/DIRAC.git -b "diracx_sandbox" /tmp/DIRACRepo
           # We need to cd in the directory for the integration_tests.py to work
       - name: Prepare environment
         run: cd /tmp/DIRACRepo && ./integration_tests.py prepare-environment "TEST_DIRACX=Yes" --extra-module "diracx=${GITHUB_WORKSPACE}"
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -52,7 +52,7 @@ jobs:
           pip install .
       - name: Start demo
         run: |
-          git clone https://github.com/DIRACGrid/diracx-charts.git ../diracx-charts
+          git clone https://github.com/chrisburr/diracx-charts.git ../diracx-charts -b update-notes
           ../diracx-charts/run_demo.sh --enable-coverage --exit-when-done $PWD
       - name: Debugging information
         run: |
4 changes: 4 additions & 0 deletions environment.yml
@@ -50,5 +50,9 @@ dependencies:
   - types-PyYAML
   - types-requests
   - uvicorn
+  - moto
+  - mypy-boto3-s3
+  - botocore
+  - boto3-stubs
   # - pip:
   #     - git+https://github.com/DIRACGrid/DIRAC.git@integration
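
Reviewer note: the moto/boto3 additions suggest the sandbox store is S3-backed and exercised against a mocked S3 in the tests. A minimal sketch of what such a test can look like, assuming moto's mock_s3 decorator (mock_aws in moto 5+); the bucket and key names here are hypothetical:

import boto3
from moto import mock_s3


@mock_s3
def test_sandbox_roundtrip():
    # moto intercepts boto3 calls, so no real credentials or network are needed
    s3 = boto3.client("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="sandboxes")
    s3.put_object(Bucket="sandboxes", Key="user/sandbox.tar.bz2", Body=b"payload")
    obj = s3.get_object(Bucket="sandboxes", Key="user/sandbox.tar.bz2")
    assert obj["Body"].read() == b"payload"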
5 changes: 5 additions & 0 deletions src/diracx/api/__init__.py
@@ -0,0 +1,5 @@
from __future__ import annotations

__all__ = ("jobs",)

from . import jobs
92 changes: 92 additions & 0 deletions src/diracx/api/jobs.py
@@ -0,0 +1,92 @@
from __future__ import annotations

__all__ = ("create_sandbox", "download_sandbox")

import hashlib
import logging
import os
import tarfile
import tempfile
from pathlib import Path

import httpx

from diracx.client.aio import DiracClient
from diracx.client.models import SandboxInfo

from .utils import with_client

logger = logging.getLogger(__name__)

SANDBOX_CHECKSUM_ALGORITHM = "sha256"
SANDBOX_COMPRESSION = "bz2"


@with_client
async def create_sandbox(paths: list[Path], *, client: DiracClient) -> str:
"""Create a sandbox from the given paths and upload it to the storage backend.

Any paths that are directories will be added recursively.
The returned value is the PFN of the sandbox in the storage backend and can
be used to submit jobs.
"""
with tempfile.TemporaryFile(mode="w+b") as tar_fh:
with tarfile.open(fileobj=tar_fh, mode=f"w|{SANDBOX_COMPRESSION}") as tf:
for path in paths:
logger.debug("Adding %s to sandbox as %s", path.resolve(), path.name)
tf.add(path.resolve(), path.name, recursive=True)
tar_fh.seek(0)

hasher = getattr(hashlib, SANDBOX_CHECKSUM_ALGORITHM)()
while data := tar_fh.read(512 * 1024):
hasher.update(data)
checksum = hasher.hexdigest()
tar_fh.seek(0)
logger.debug("Sandbox checksum is %s", checksum)

sandbox_info = SandboxInfo(
checksum_algorithm=SANDBOX_CHECKSUM_ALGORITHM,
checksum=checksum,
size=os.stat(tar_fh.fileno()).st_size,
format=f"tar.{SANDBOX_COMPRESSION}",
)

res = await client.jobs.initiate_sandbox_upload(sandbox_info)
if res.url:
logger.debug("Uploading sandbox for %s", res.pfn)
files = {"file": ("file", tar_fh)}
async with httpx.AsyncClient() as httpx_client:
response = await httpx_client.post(
res.url, data=res.fields, files=files
)
# TODO: Handle this error better
response.raise_for_status()

logger.debug(
"Sandbox uploaded for %s with status code %s",
res.pfn,
response.status_code,
)
else:
logger.debug("%s already exists in storage backend", res.pfn)
return res.pfn


@with_client
async def download_sandbox(pfn: str, destination: Path, *, client: DiracClient):
"""Download a sandbox from the storage backend to the given destination."""
res = await client.jobs.get_sandbox_file(pfn=pfn)
logger.debug("Downloading sandbox for %s", pfn)
with tempfile.TemporaryFile(mode="w+b") as fh:
async with httpx.AsyncClient() as http_client:
response = await http_client.get(res.url)
# TODO: Handle this error better
response.raise_for_status()
async for chunk in response.aiter_bytes():
fh.write(chunk)
fh.seek(0)
logger.debug("Sandbox downloaded for %s", pfn)

with tarfile.open(fileobj=fh) as tf:
tf.extractall(path=destination, filter="data")
logger.debug("Extracted %s to %s", pfn, destination)
25 changes: 25 additions & 0 deletions src/diracx/api/utils.py
@@ -0,0 +1,25 @@
from __future__ import annotations

__all__ = ("with_client",)

from functools import wraps

from diracx.client.aio import DiracClient


def with_client(func):
"""Decorator to provide a DiracClient to a function.

If the function already has a `client` keyword argument, it will be used.
Otherwise, a new DiracClient will be created and passed as a keyword argument.
"""

@wraps(func)
async def wrapper(*args, **kwargs):
if "client" in kwargs:
return await func(*args, **kwargs)

async with DiracClient() as client:
return await func(*args, **kwargs, client=client)

return wrapper
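
Reviewer note: because the wrapper only checks for an explicit `client` keyword, callers doing many operations can pass one client in and avoid the per-call setup. A sketch (the helper name is hypothetical):

from diracx.api.jobs import create_sandbox
from diracx.client.aio import DiracClient


async def upload_many(sandbox_path_lists):
    # One shared client for all uploads instead of one per create_sandbox call
    async with DiracClient() as client:
        return [
            await create_sandbox(paths, client=client)
            for paths in sandbox_path_lists
        ]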
194 changes: 194 additions & 0 deletions src/diracx/client/aio/operations/_operations.py
@@ -37,9 +37,11 @@
     build_jobs_delete_bulk_jobs_request,
     build_jobs_get_job_status_bulk_request,
     build_jobs_get_job_status_history_bulk_request,
+    build_jobs_get_sandbox_file_request,
     build_jobs_get_single_job_request,
     build_jobs_get_single_job_status_history_request,
     build_jobs_get_single_job_status_request,
+    build_jobs_initiate_sandbox_upload_request,
     build_jobs_kill_bulk_jobs_request,
     build_jobs_search_request,
     build_jobs_set_job_status_bulk_request,
@@ -819,6 +821,198 @@ def __init__(self, *args, **kwargs) -> None:
            input_args.pop(0) if input_args else kwargs.pop("deserializer")
        )

    @distributed_trace_async
    async def get_sandbox_file(
        self, *, pfn: str, **kwargs: Any
    ) -> _models.SandboxDownloadResponse:
        """Get Sandbox File.

        Get a presigned URL to download a sandbox file.

        This route cannot use a redirect response, as most clients will also send
        the authorization header when following a redirect. This is not desirable
        as it would leak the authorization token to the storage backend.
        Additionally, most storage backends return an error when they receive an
        authorization header for a presigned URL.

        :keyword pfn: Required.
        :paramtype pfn: str
        :return: SandboxDownloadResponse
        :rtype: ~client.models.SandboxDownloadResponse
        :raises ~azure.core.exceptions.HttpResponseError:
        """
        error_map = {
            401: ClientAuthenticationError,
            404: ResourceNotFoundError,
            409: ResourceExistsError,
            304: ResourceNotModifiedError,
        }
        error_map.update(kwargs.pop("error_map", {}) or {})

        _headers = kwargs.pop("headers", {}) or {}
        _params = kwargs.pop("params", {}) or {}

        cls: ClsType[_models.SandboxDownloadResponse] = kwargs.pop("cls", None)

        request = build_jobs_get_sandbox_file_request(
            pfn=pfn,
            headers=_headers,
            params=_params,
        )
        request.url = self._client.format_url(request.url)

        _stream = False
        pipeline_response: PipelineResponse = (
            await self._client._pipeline.run(  # pylint: disable=protected-access
                request, stream=_stream, **kwargs
            )
        )

        response = pipeline_response.http_response

        if response.status_code not in [200]:
            map_error(
                status_code=response.status_code, response=response, error_map=error_map
            )
            raise HttpResponseError(response=response)

        deserialized = self._deserialize("SandboxDownloadResponse", pipeline_response)

        if cls:
            return cls(pipeline_response, deserialized, {})

        return deserialized
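
Reviewer note: as the docstring explains, the presigned URL must be fetched by a client that does not attach the DiracX Authorization header. A sketch of the intended pattern, mirroring download_sandbox in src/diracx/api/jobs.py (the helper name is hypothetical):

import httpx

from diracx.client.aio import DiracClient


async def fetch_sandbox_bytes(pfn: str) -> bytes:
    async with DiracClient() as client:
        res = await client.jobs.get_sandbox_file(pfn=pfn)
    # A plain httpx client carries no Authorization header, so the presigned
    # URL is fetched without leaking the DiracX token to the storage backend
    async with httpx.AsyncClient() as http:
        response = await http.get(res.url)
        response.raise_for_status()
        return response.content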

    @overload
    async def initiate_sandbox_upload(
        self,
        body: _models.SandboxInfo,
        *,
        content_type: str = "application/json",
        **kwargs: Any
    ) -> _models.SandboxUploadResponse:
        """Initiate Sandbox Upload.

        Get the PFN for the given sandbox, initiate an upload as required.

        If the sandbox already exists in the database then the PFN is returned
        and there is no "url" field in the response.

        If the sandbox does not exist in the database then the "url" and "fields"
        should be used to upload the sandbox to the storage backend.

        :param body: Required.
        :type body: ~client.models.SandboxInfo
        :keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
         Default value is "application/json".
        :paramtype content_type: str
        :return: SandboxUploadResponse
        :rtype: ~client.models.SandboxUploadResponse
        :raises ~azure.core.exceptions.HttpResponseError:
        """

    @overload
    async def initiate_sandbox_upload(
        self, body: IO, *, content_type: str = "application/json", **kwargs: Any
    ) -> _models.SandboxUploadResponse:
        """Initiate Sandbox Upload.

        Get the PFN for the given sandbox, initiate an upload as required.

        If the sandbox already exists in the database then the PFN is returned
        and there is no "url" field in the response.

        If the sandbox does not exist in the database then the "url" and "fields"
        should be used to upload the sandbox to the storage backend.

        :param body: Required.
        :type body: IO
        :keyword content_type: Body Parameter content-type. Content type parameter for binary body.
         Default value is "application/json".
        :paramtype content_type: str
        :return: SandboxUploadResponse
        :rtype: ~client.models.SandboxUploadResponse
        :raises ~azure.core.exceptions.HttpResponseError:
        """

    @distributed_trace_async
    async def initiate_sandbox_upload(
        self, body: Union[_models.SandboxInfo, IO], **kwargs: Any
    ) -> _models.SandboxUploadResponse:
        """Initiate Sandbox Upload.

        Get the PFN for the given sandbox, initiate an upload as required.

        If the sandbox already exists in the database then the PFN is returned
        and there is no "url" field in the response.

        If the sandbox does not exist in the database then the "url" and "fields"
        should be used to upload the sandbox to the storage backend.

        :param body: Is either a SandboxInfo type or an IO type. Required.
        :type body: ~client.models.SandboxInfo or IO
        :keyword content_type: Body Parameter content-type. Known values are: 'application/json'.
         Default value is None.
        :paramtype content_type: str
        :return: SandboxUploadResponse
        :rtype: ~client.models.SandboxUploadResponse
        :raises ~azure.core.exceptions.HttpResponseError:
        """
        error_map = {
            401: ClientAuthenticationError,
            404: ResourceNotFoundError,
            409: ResourceExistsError,
            304: ResourceNotModifiedError,
        }
        error_map.update(kwargs.pop("error_map", {}) or {})

        _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
        _params = kwargs.pop("params", {}) or {}

        content_type: Optional[str] = kwargs.pop(
            "content_type", _headers.pop("Content-Type", None)
        )
        cls: ClsType[_models.SandboxUploadResponse] = kwargs.pop("cls", None)

        content_type = content_type or "application/json"
        _json = None
        _content = None
        if isinstance(body, (IOBase, bytes)):
            _content = body
        else:
            _json = self._serialize.body(body, "SandboxInfo")

        request = build_jobs_initiate_sandbox_upload_request(
            content_type=content_type,
            json=_json,
            content=_content,
            headers=_headers,
            params=_params,
        )
        request.url = self._client.format_url(request.url)

        _stream = False
        pipeline_response: PipelineResponse = (
            await self._client._pipeline.run(  # pylint: disable=protected-access
                request, stream=_stream, **kwargs
            )
        )

        response = pipeline_response.http_response

        if response.status_code not in [200]:
            map_error(
                status_code=response.status_code, response=response, error_map=error_map
            )
            raise HttpResponseError(response=response)

        deserialized = self._deserialize("SandboxUploadResponse", pipeline_response)

        if cls:
            return cls(pipeline_response, deserialized, {})

        return deserialized
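
Reviewer note: the service side is not part of this diff, but the url/fields pair in SandboxUploadResponse matches the shape returned by boto3's generate_presigned_post, which the new boto3 dependencies in environment.yml also hint at. A sketch under that assumption (function, bucket, and key naming are hypothetical):

import boto3


def make_upload_response(bucket: str, key: str) -> dict:
    s3 = boto3.client("s3")
    # Returns {"url": ..., "fields": {...}}, the pair that create_sandbox
    # forwards verbatim to httpx_client.post(url, data=fields, files=...)
    return s3.generate_presigned_post(Bucket=bucket, Key=key, ExpiresIn=300)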

    @overload
    async def submit_bulk_jobs(
        self, body: List[str], *, content_type: str = "application/json", **kwargs: Any