User/tom/fix/s2 pctasks perf (#291)
- Share ContainerClient across BlobStorage methods
- Parallelize walk
- Share BlobStorage objects in execute_workflow
- Adjust expiry for SAS tokens
Tom Augspurger authored May 28, 2024
1 parent 71bf445 commit d550656
Showing 10 changed files with 353 additions and 290 deletions.
2 changes: 1 addition & 1 deletion datasets/sentinel-2/Dockerfile
@@ -29,7 +29,7 @@ RUN curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/downloa
ENV PATH /opt/conda/bin:$PATH
ENV LD_LIBRARY_PATH /opt/conda/lib/:$LD_LIBRARY_PATH

RUN mamba install -y -c conda-forge python=3.8 gdal=3.3.3 pip setuptools cython numpy==1.21.5
RUN mamba install -y -c conda-forge python=3.11 gdal pip setuptools cython numpy

RUN python -m pip install --upgrade pip

8 changes: 8 additions & 0 deletions datasets/sentinel-2/README.md
@@ -13,3 +13,11 @@
```shell
az acr build -r {the registry} --subscription {the subscription} -t pctasks-sentinel-2:latest -t pctasks-sentinel-2:{date}.{count} -f datasets/sentinel-2/Dockerfile .
```

## Update Workflow

Created with

```
pctasks dataset process-items --is-update-workflow sentinel-2-l2a-update -d datasets/sentinel-2/dataset.yaml
```
8 changes: 5 additions & 3 deletions datasets/sentinel-2/dataset.yaml
@@ -1,5 +1,5 @@
id: sentinel-2
image: ${{ args.registry }}/pctasks-sentinel-2:2023.8.15.0
image: ${{ args.registry }}/pctasks-sentinel-2:2024.5.28.0

args:
- registry
@@ -32,8 +32,10 @@ collections:
options:
# extensions: [.safe]
ends_with: manifest.safe
min_depth: 7
max_depth: 8
# From the root, we want a depth of 7
# But we start at depth=2 thanks to the split, so we use a depth of 5 here.
min_depth: 5
max_depth: 5
chunk_length: 5000
chunk_storage:
uri: blob://sentinel2l2a01/sentinel2-l2-info/pctasks-chunks/
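
The depth change above follows from where the walk starts. A minimal sketch (made-up blob names and an assumed split prefix) of why the relative depth works out to 5:

```python
# Illustrative only: depth counts path segments below the prefix where the
# walk starts. The manifest sits 7 segments below the container root, but the
# split starts the walk 2 segments down, so the walk sees it at 7 - 2 = 5.
def depth_below(blob_name: str, start_prefix: str) -> int:
    relative = blob_name[len(start_prefix):].strip("/")
    return len(relative.split("/"))

start = "Sentinel-2/MSI/"                      # assumed 2-segment split prefix
blob = start + "L2A/2024/05/28/manifest.safe"  # 7 segments from the root
print(depth_below(blob, start))                # -> 5, matching min/max_depth
```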
66 changes: 47 additions & 19 deletions pctasks/core/pctasks/core/storage/blob.py
@@ -1,3 +1,5 @@
import concurrent.futures
import contextlib
import logging
import multiprocessing
import os
@@ -252,6 +254,7 @@ def __init__(
self.storage_account_name = storage_account_name
self.container_name = container_name
self.prefix = prefix.strip("/") if prefix is not None else prefix
self._container_client_wrapper: Optional[ContainerClientWrapper] = None

def __repr__(self) -> str:
prefix_part = "" if self.prefix is None else f"/{self.prefix}"
@@ -261,14 +264,17 @@ def __repr__(self) -> str:
)

def _get_client(self) -> ContainerClientWrapper:
account_client = BlobServiceClient(
account_url=self.account_url,
credential=self._blob_creds,
)

container_client = account_client.get_container_client(self.container_name)
if self._container_client_wrapper is None:
account_client = BlobServiceClient(
account_url=self.account_url,
credential=self._blob_creds,
)

return ContainerClientWrapper(account_client, container_client)
container_client = account_client.get_container_client(self.container_name)
self._container_client_wrapper = ContainerClientWrapper(
account_client, container_client
)
return self._container_client_wrapper

def _get_name_starts_with(
self, additional_prefix: Optional[str] = None
@@ -337,7 +343,10 @@ def _generate_container_sas(
attached credentials) to generate a container-level SAS token.
"""
start = Datetime.utcnow() - timedelta(hours=10)
expiry = Datetime.utcnow() + timedelta(hours=24 * 7)
# Chop off a couple hours at the end to avoid any issues with the
# SAS token having too long of a duration.
# https://github.com/microsoft/planetary-computer-tasks/pull/291#issuecomment-2135599782
expiry = start + timedelta(hours=(24 * 7) - 2)
permission = ContainerSasPermissions(
read=read,
write=write,
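
For reference, a sketch of the resulting token window (simplified names; the 7-day figure assumes the user-delegation SAS lifetime limit the linked comment alludes to):

```python
from datetime import datetime, timedelta

# The start is backdated 10 hours to absorb clock skew; the expiry is now
# measured from that backdated start (not from "now") and trimmed by 2 hours,
# keeping the total lifetime at 166 hours, just under 7 days.
start = datetime.utcnow() - timedelta(hours=10)
expiry = start + timedelta(hours=24 * 7 - 2)
print(expiry - start)              # 6 days, 22:00:00
print(expiry - datetime.utcnow())  # roughly 156 hours of validity remaining
```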
@@ -388,7 +397,8 @@ def get_path(self, uri: str) -> str:
return blob_uri.blob_name or ""

def get_file_info(self, file_path: str) -> StorageFileInfo:
with self._get_client() as client:
client = self._get_client()
with contextlib.nullcontext():
with client.container.get_blob_client(self._add_prefix(file_path)) as blob:
try:
props = with_backoff(lambda: blob.get_blob_properties())
@@ -397,7 +407,8 @@ def get_file_info(self, file_path: str) -> StorageFileInfo:
return StorageFileInfo(size=cast(int, props.size))

def file_exists(self, file_path: str) -> bool:
with self._get_client() as client:
client = self._get_client()
with contextlib.nullcontext():
with client.container.get_blob_client(self._add_prefix(file_path)) as blob:
return with_backoff(lambda: blob.exists())

@@ -450,7 +461,8 @@ def fetch_blobs() -> Iterable[str]:
for blob_name in page:
yield blob_name

with self._get_client() as client:
client = self._get_client()
with contextlib.nullcontext():
return with_backoff(fetch_blobs)

def walk(
@@ -464,6 +476,7 @@ def walk(
matches: Optional[str] = None,
walk_limit: Optional[int] = None,
file_limit: Optional[int] = None,
max_concurrency: int = 32,
) -> Generator[Tuple[str, List[str], List[str]], None, None]:
# Ensure UTC set
since_date = map_opt(lambda d: d.replace(tzinfo=timezone.utc), since_date)
@@ -488,6 +501,7 @@ def _get_depth(path: Optional[str]) -> int:
def _get_prefix_content(
full_prefix: Optional[str],
) -> Tuple[List[str], List[str]]:
logger.info("Listing prefix=%s", full_prefix)
folders = []
files = []
for item in client.container.walk_blobs(name_starts_with=full_prefix):
@@ -514,7 +528,9 @@ def _get_prefix_content(
limit_break = False

full_prefixes: List[str] = [self._get_name_starts_with(name_starts_with) or ""]
with self._get_client() as client:
client = self._get_client()
pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency)
with contextlib.nullcontext():
while full_prefixes:
if walk_limit and walk_count >= walk_limit:
break
@@ -523,6 +539,7 @@ def _get_prefix_content(
break

next_level_prefixes: List[str] = []
futures = {}
for full_prefix in full_prefixes:
if walk_limit and walk_count >= walk_limit:
limit_break = True
@@ -534,7 +551,12 @@ def _get_prefix_content(
if max_depth and prefix_depth > max_depth:
break

folders, files = _get_prefix_content(full_prefix)
future = pool.submit(_get_prefix_content, full_prefix)
futures[future] = full_prefix

for future in concurrent.futures.as_completed(futures):
full_prefix = futures[future]
folders, files = future.result()

files = [file for file in files if path_filter(file)]
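
The loop above fans each level of prefixes out to a thread pool. A condensed sketch of the same pattern, with a hypothetical `list_prefix` callable standing in for `_get_prefix_content`:

```python
import concurrent.futures

def walk_level(list_prefix, prefixes, max_concurrency=32):
    # List every prefix at the current level in parallel; the folders found
    # become the next level's prefixes in the caller's breadth-first loop.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as pool:
        futures = {pool.submit(list_prefix, p): p for p in prefixes}
        for future in concurrent.futures.as_completed(futures):
            prefix = futures[future]
            folders, files = future.result()
            yield prefix, folders, files
```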

@@ -570,7 +592,8 @@ def download_file(
if timeout_seconds is not None:
kwargs["timeout"] = timeout_seconds

with self._get_client() as client:
client = self._get_client()
with contextlib.nullcontext():
with client.container.get_blob_client(self._add_prefix(file_path)) as blob:
with open(output_path, "wb" if is_binary else "w") as f:
try:
@@ -585,7 +608,8 @@ def upload_bytes(
target_path: str,
overwrite: bool = True,
) -> None:
with self._get_client() as client:
client = self._get_client()
with contextlib.nullcontext():
with client.container.get_blob_client(
self._add_prefix(target_path)
) as blob:
@@ -615,7 +639,8 @@ def upload_file(
kwargs = {}
if content_type:
kwargs["content_settings"] = ContentSettings(content_type=content_type)
with self._get_client() as client:
client = self._get_client()
with contextlib.nullcontext():
with client.container.get_blob_client(
self._add_prefix(target_path)
) as blob:
@@ -629,7 +654,8 @@ def _upload() -> None:
def read_bytes(self, file_path: str) -> bytes:
try:
blob_path = self._add_prefix(file_path)
with self._get_client() as client:
client = self._get_client()
with contextlib.nullcontext():
with client.container.get_blob_client(blob_path) as blob:
blob_data = with_backoff(
lambda: blob.download_blob(
@@ -649,7 +675,8 @@ def read_bytes(self, file_path: str) -> bytes:

def write_bytes(self, file_path: str, data: bytes, overwrite: bool = True) -> None:
full_path = self._add_prefix(file_path)
with self._get_client() as client:
client = self._get_client()
with contextlib.nullcontext():
with client.container.get_blob_client(full_path) as blob:
with_backoff(
lambda: blob.upload_blob(data, overwrite=overwrite) # type: ignore
@@ -660,7 +687,8 @@ def delete_folder(self, folder_path: Optional[str] = None) -> None:
self.delete_file(file_path)

def delete_file(self, file_path: str) -> None:
with self._get_client() as client:
client = self._get_client()
with contextlib.nullcontext():
with client.container.get_blob_client(self._add_prefix(file_path)) as blob:
try:
with_backoff(lambda: blob.delete_blob())
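
The repeated `client = self._get_client()` / `with contextlib.nullcontext():` edits throughout this file follow one pattern: reuse a single cached client instead of building and closing one per call, with a no-op context manager preserving the original indentation. A stripped-down sketch (hypothetical class and `download` method, not the real BlobStorage API):

```python
import contextlib

class Storage:
    def __init__(self, make_client):
        self._make_client = make_client
        self._client = None

    def _get_client(self):
        if self._client is None:        # lazily create once, then reuse
            self._client = self._make_client()
        return self._client

    def read(self, path):
        client = self._get_client()
        with contextlib.nullcontext():  # no-op: nothing is closed per call
            return client.download(path)
```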
28 changes: 13 additions & 15 deletions pctasks/core/tests/storage/test_blob.py
@@ -88,21 +88,19 @@ def test_blob_download_timeout():
with temp_azurite_blob_storage(
HERE / ".." / "data-files" / "simple-assets"
) as storage:
with storage._get_client() as client:
with client.container.get_blob_client(
storage._add_prefix("a/asset-a-1.json")
) as blob:
storage_stream_downloader = blob.download_blob(timeout=TIMEOUT_SECONDS)
assert (
storage_stream_downloader._request_options["timeout"]
== TIMEOUT_SECONDS
)

storage_stream_downloader = blob.download_blob()
assert (
storage_stream_downloader._request_options.pop("timeout", None)
is None
)
client = storage._get_client()
with client.container.get_blob_client(
storage._add_prefix("a/asset-a-1.json")
) as blob:
storage_stream_downloader = blob.download_blob(timeout=TIMEOUT_SECONDS)
assert (
storage_stream_downloader._request_options["timeout"] == TIMEOUT_SECONDS
)

storage_stream_downloader = blob.download_blob()
assert (
storage_stream_downloader._request_options.pop("timeout", None) is None
)


@pytest.mark.parametrize(
6 changes: 6 additions & 0 deletions pctasks/run/pctasks/run/argo/client.py
@@ -8,6 +8,7 @@
import argo_workflows
from argo_workflows.api import workflow_service_api
from argo_workflows.exceptions import NotFoundException
from argo_workflows.model.capabilities import Capabilities
from argo_workflows.model.container import Container
from argo_workflows.model.env_var import EnvVar
from argo_workflows.model.io_argoproj_workflow_v1alpha1_template import (
@@ -26,6 +27,7 @@
IoArgoprojWorkflowV1alpha1WorkflowTerminateRequest,
)
from argo_workflows.model.object_meta import ObjectMeta
from argo_workflows.model.security_context import SecurityContext
from argo_workflows.models import (
Affinity,
IoArgoprojWorkflowV1alpha1Metadata,
@@ -219,6 +221,10 @@ def submit_workflow(
image_pull_policy=get_pull_policy(runner_image),
command=["pctasks"],
env=env,
security_context=SecurityContext(
# Enables tools like py-spy for debugging
capabilities=Capabilities(add=["SYS_PTRACE"])
),
args=[
"-v",
"run",
1 change: 1 addition & 0 deletions pctasks/run/pctasks/run/settings.py
@@ -52,6 +52,7 @@ class RunSettings(PCTasksSettings):
def section_name(cls) -> str:
return "run"

max_concurrent_workflow_tasks: int = 120
remote_runner_threads: int = 50
default_task_wait_seconds: int = 60
max_wait_retries: int = 10
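
A hypothetical illustration (not the pctasks wiring, which this diff doesn't show) of the kind of cap a setting like `max_concurrent_workflow_tasks` usually drives, i.e. a bound on how many tasks run at once:

```python
import asyncio

async def run_all(tasks, max_concurrent=120):
    # Bound concurrency with a semaphore; `tasks` is an iterable of async callables.
    semaphore = asyncio.Semaphore(max_concurrent)

    async def run_one(task):
        async with semaphore:
            return await task()

    return await asyncio.gather(*(run_one(t) for t in tasks))
```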
4 changes: 4 additions & 0 deletions pctasks/run/pctasks/run/workflow/executor/models.py
@@ -372,6 +372,10 @@ def partition_id(self) -> str:
def run_id(self) -> str:
return self.job_part_submit_msg.run_id

@property
def has_next_task(self) -> bool:
return bool(self.task_queue)

def prepare_next_task(self, settings: RunSettings) -> None:
next_task_config = next(iter(self.task_queue), None)
if next_task_config:
(diffs for the remaining 2 changed files not shown)