From d5506564cad4985ef5ce7b9138a3c7ec2404395e Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 28 May 2024 15:22:20 -0500 Subject: [PATCH] User/tom/fix/s2 pctasks perf (#291) - Share ContainerClient across BlobStorage methods - parallelize walk - Share BlobStorage objects in execute_workflow - Adjust expiry for SAS tokens --- datasets/sentinel-2/Dockerfile | 2 +- datasets/sentinel-2/README.md | 8 + datasets/sentinel-2/dataset.yaml | 8 +- pctasks/core/pctasks/core/storage/blob.py | 66 ++- pctasks/core/tests/storage/test_blob.py | 28 +- pctasks/run/pctasks/run/argo/client.py | 6 + pctasks/run/pctasks/run/settings.py | 1 + .../pctasks/run/workflow/executor/models.py | 4 + .../pctasks/run/workflow/executor/remote.py | 514 +++++++++--------- pctasks/run/tests/workflow/test_remote.py | 6 +- 10 files changed, 353 insertions(+), 290 deletions(-) diff --git a/datasets/sentinel-2/Dockerfile b/datasets/sentinel-2/Dockerfile index faeb20324..05eb2c193 100644 --- a/datasets/sentinel-2/Dockerfile +++ b/datasets/sentinel-2/Dockerfile @@ -29,7 +29,7 @@ RUN curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/downloa ENV PATH /opt/conda/bin:$PATH ENV LD_LIBRARY_PATH /opt/conda/lib/:$LD_LIBRARY_PATH -RUN mamba install -y -c conda-forge python=3.8 gdal=3.3.3 pip setuptools cython numpy==1.21.5 +RUN mamba install -y -c conda-forge python=3.11 gdal pip setuptools cython numpy RUN python -m pip install --upgrade pip diff --git a/datasets/sentinel-2/README.md b/datasets/sentinel-2/README.md index 5e09734f8..d0aa96ef0 100644 --- a/datasets/sentinel-2/README.md +++ b/datasets/sentinel-2/README.md @@ -13,3 +13,11 @@ ```shell az acr build -r {the registry} --subscription {the subscription} -t pctasks-sentinel-2:latest -t pctasks-sentinel-2:{date}.{count} -f datasets/sentinel-2/Dockerfile . ``` + +## Update Workflow + +Created with + +``` +pctasks dataset process-items --is-update-workflow sentinel-2-l2a-update -d datasets/sentinel-2/dataset.yaml +``` \ No newline at end of file diff --git a/datasets/sentinel-2/dataset.yaml b/datasets/sentinel-2/dataset.yaml index 51d45cfb5..dc20ddc47 100644 --- a/datasets/sentinel-2/dataset.yaml +++ b/datasets/sentinel-2/dataset.yaml @@ -1,5 +1,5 @@ id: sentinel-2 -image: ${{ args.registry }}/pctasks-sentinel-2:2023.8.15.0 +image: ${{ args.registry }}/pctasks-sentinel-2:2024.5.28.0 args: - registry @@ -32,8 +32,10 @@ collections: options: # extensions: [.safe] ends_with: manifest.safe - min_depth: 7 - max_depth: 8 + # From the root, we want a depth of 7 + # But we start at depth=2 thanks to the split, so we use a depth of 5 here. 
+ min_depth: 5 + max_depth: 5 chunk_length: 5000 chunk_storage: uri: blob://sentinel2l2a01/sentinel2-l2-info/pctasks-chunks/ diff --git a/pctasks/core/pctasks/core/storage/blob.py b/pctasks/core/pctasks/core/storage/blob.py index 0a6490908..4a2ae17b6 100644 --- a/pctasks/core/pctasks/core/storage/blob.py +++ b/pctasks/core/pctasks/core/storage/blob.py @@ -1,3 +1,5 @@ +import concurrent.futures +import contextlib import logging import multiprocessing import os @@ -252,6 +254,7 @@ def __init__( self.storage_account_name = storage_account_name self.container_name = container_name self.prefix = prefix.strip("/") if prefix is not None else prefix + self._container_client_wrapper: Optional[ContainerClientWrapper] = None def __repr__(self) -> str: prefix_part = "" if self.prefix is None else f"/{self.prefix}" @@ -261,14 +264,17 @@ def __repr__(self) -> str: ) def _get_client(self) -> ContainerClientWrapper: - account_client = BlobServiceClient( - account_url=self.account_url, - credential=self._blob_creds, - ) - - container_client = account_client.get_container_client(self.container_name) + if self._container_client_wrapper is None: + account_client = BlobServiceClient( + account_url=self.account_url, + credential=self._blob_creds, + ) - return ContainerClientWrapper(account_client, container_client) + container_client = account_client.get_container_client(self.container_name) + self._container_client_wrapper = ContainerClientWrapper( + account_client, container_client + ) + return self._container_client_wrapper def _get_name_starts_with( self, additional_prefix: Optional[str] = None @@ -337,7 +343,10 @@ def _generate_container_sas( attached credentials) to generate a container-level SAS token. """ start = Datetime.utcnow() - timedelta(hours=10) - expiry = Datetime.utcnow() + timedelta(hours=24 * 7) + # Chop off a couple hours at the end to avoid any issues with the + # SAS token having too long of a duration. 
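+ # (`start` is backdated 10 hours above; measuring the expiry from `start` keeps the full start-to-expiry window just under seven days.)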
+ # https://github.com/microsoft/planetary-computer-tasks/pull/291#issuecomment-2135599782 + expiry = start + timedelta(hours=(24 * 7) - 2) permission = ContainerSasPermissions( read=read, write=write, @@ -388,7 +397,8 @@ def get_path(self, uri: str) -> str: return blob_uri.blob_name or "" def get_file_info(self, file_path: str) -> StorageFileInfo: - with self._get_client() as client: + client = self._get_client() + with contextlib.nullcontext(): with client.container.get_blob_client(self._add_prefix(file_path)) as blob: try: props = with_backoff(lambda: blob.get_blob_properties()) @@ -397,7 +407,8 @@ def get_file_info(self, file_path: str) -> StorageFileInfo: return StorageFileInfo(size=cast(int, props.size)) def file_exists(self, file_path: str) -> bool: - with self._get_client() as client: + client = self._get_client() + with contextlib.nullcontext(): with client.container.get_blob_client(self._add_prefix(file_path)) as blob: return with_backoff(lambda: blob.exists()) @@ -450,7 +461,8 @@ def fetch_blobs() -> Iterable[str]: for blob_name in page: yield blob_name - with self._get_client() as client: + client = self._get_client() + with contextlib.nullcontext(): return with_backoff(fetch_blobs) def walk( @@ -464,6 +476,7 @@ def walk( matches: Optional[str] = None, walk_limit: Optional[int] = None, file_limit: Optional[int] = None, + max_concurrency: int = 32, ) -> Generator[Tuple[str, List[str], List[str]], None, None]: # Ensure UTC set since_date = map_opt(lambda d: d.replace(tzinfo=timezone.utc), since_date) @@ -488,6 +501,7 @@ def _get_depth(path: Optional[str]) -> int: def _get_prefix_content( full_prefix: Optional[str], ) -> Tuple[List[str], List[str]]: + logger.info("Listing prefix=%s", full_prefix) folders = [] files = [] for item in client.container.walk_blobs(name_starts_with=full_prefix): @@ -514,7 +528,9 @@ def _get_prefix_content( limit_break = False full_prefixes: List[str] = [self._get_name_starts_with(name_starts_with) or ""] - with self._get_client() as client: + client = self._get_client() + pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) + with contextlib.nullcontext(): while full_prefixes: if walk_limit and walk_count >= walk_limit: break @@ -523,6 +539,7 @@ def _get_prefix_content( break next_level_prefixes: List[str] = [] + futures = {} for full_prefix in full_prefixes: if walk_limit and walk_count >= walk_limit: limit_break = True @@ -534,7 +551,12 @@ def _get_prefix_content( if max_depth and prefix_depth > max_depth: break - folders, files = _get_prefix_content(full_prefix) + future = pool.submit(_get_prefix_content, full_prefix) + futures[future] = full_prefix + + for future in concurrent.futures.as_completed(futures): + full_prefix = futures[future] + folders, files = future.result() files = [file for file in files if path_filter(file)] @@ -570,7 +592,8 @@ def download_file( if timeout_seconds is not None: kwargs["timeout"] = timeout_seconds - with self._get_client() as client: + client = self._get_client() + with contextlib.nullcontext(): with client.container.get_blob_client(self._add_prefix(file_path)) as blob: with open(output_path, "wb" if is_binary else "w") as f: try: @@ -585,7 +608,8 @@ def upload_bytes( target_path: str, overwrite: bool = True, ) -> None: - with self._get_client() as client: + client = self._get_client() + with contextlib.nullcontext(): with client.container.get_blob_client( self._add_prefix(target_path) ) as blob: @@ -615,7 +639,8 @@ def upload_file( kwargs = {} if content_type: kwargs["content_settings"] = 
ContentSettings(content_type=content_type) - with self._get_client() as client: + client = self._get_client() + with contextlib.nullcontext(): with client.container.get_blob_client( self._add_prefix(target_path) ) as blob: @@ -629,7 +654,8 @@ def _upload() -> None: def read_bytes(self, file_path: str) -> bytes: try: blob_path = self._add_prefix(file_path) - with self._get_client() as client: + client = self._get_client() + with contextlib.nullcontext(): with client.container.get_blob_client(blob_path) as blob: blob_data = with_backoff( lambda: blob.download_blob( @@ -649,7 +675,8 @@ def read_bytes(self, file_path: str) -> bytes: def write_bytes(self, file_path: str, data: bytes, overwrite: bool = True) -> None: full_path = self._add_prefix(file_path) - with self._get_client() as client: + client = self._get_client() + with contextlib.nullcontext(): with client.container.get_blob_client(full_path) as blob: with_backoff( lambda: blob.upload_blob(data, overwrite=overwrite) # type: ignore @@ -660,7 +687,8 @@ def delete_folder(self, folder_path: Optional[str] = None) -> None: self.delete_file(file_path) def delete_file(self, file_path: str) -> None: - with self._get_client() as client: + client = self._get_client() + with contextlib.nullcontext(): with client.container.get_blob_client(self._add_prefix(file_path)) as blob: try: with_backoff(lambda: blob.delete_blob()) diff --git a/pctasks/core/tests/storage/test_blob.py b/pctasks/core/tests/storage/test_blob.py index 5923ef65a..81ecc7e15 100644 --- a/pctasks/core/tests/storage/test_blob.py +++ b/pctasks/core/tests/storage/test_blob.py @@ -88,21 +88,19 @@ def test_blob_download_timeout(): with temp_azurite_blob_storage( HERE / ".." / "data-files" / "simple-assets" ) as storage: - with storage._get_client() as client: - with client.container.get_blob_client( - storage._add_prefix("a/asset-a-1.json") - ) as blob: - storage_stream_downloader = blob.download_blob(timeout=TIMEOUT_SECONDS) - assert ( - storage_stream_downloader._request_options["timeout"] - == TIMEOUT_SECONDS - ) - - storage_stream_downloader = blob.download_blob() - assert ( - storage_stream_downloader._request_options.pop("timeout", None) - is None - ) + client = storage._get_client() + with client.container.get_blob_client( + storage._add_prefix("a/asset-a-1.json") + ) as blob: + storage_stream_downloader = blob.download_blob(timeout=TIMEOUT_SECONDS) + assert ( + storage_stream_downloader._request_options["timeout"] == TIMEOUT_SECONDS + ) + + storage_stream_downloader = blob.download_blob() + assert ( + storage_stream_downloader._request_options.pop("timeout", None) is None + ) @pytest.mark.parametrize( diff --git a/pctasks/run/pctasks/run/argo/client.py b/pctasks/run/pctasks/run/argo/client.py index d0e4e7769..afce13927 100644 --- a/pctasks/run/pctasks/run/argo/client.py +++ b/pctasks/run/pctasks/run/argo/client.py @@ -8,6 +8,7 @@ import argo_workflows from argo_workflows.api import workflow_service_api from argo_workflows.exceptions import NotFoundException +from argo_workflows.model.capabilities import Capabilities from argo_workflows.model.container import Container from argo_workflows.model.env_var import EnvVar from argo_workflows.model.io_argoproj_workflow_v1alpha1_template import ( @@ -26,6 +27,7 @@ IoArgoprojWorkflowV1alpha1WorkflowTerminateRequest, ) from argo_workflows.model.object_meta import ObjectMeta +from argo_workflows.model.security_context import SecurityContext from argo_workflows.models import ( Affinity, IoArgoprojWorkflowV1alpha1Metadata, @@ -219,6 +221,10 @@ 
def submit_workflow( image_pull_policy=get_pull_policy(runner_image), command=["pctasks"], env=env, + security_context=SecurityContext( + # Enables tools like py-spy for debugging + capabilities=Capabilities(add=["SYS_PTRACE"]) + ), args=[ "-v", "run", diff --git a/pctasks/run/pctasks/run/settings.py b/pctasks/run/pctasks/run/settings.py index 311a2d924..caecd4a59 100644 --- a/pctasks/run/pctasks/run/settings.py +++ b/pctasks/run/pctasks/run/settings.py @@ -52,6 +52,7 @@ class RunSettings(PCTasksSettings): def section_name(cls) -> str: return "run" + max_concurrent_workflow_tasks: int = 120 remote_runner_threads: int = 50 default_task_wait_seconds: int = 60 max_wait_retries: int = 10 diff --git a/pctasks/run/pctasks/run/workflow/executor/models.py b/pctasks/run/pctasks/run/workflow/executor/models.py index 10baa4aa0..f9df958e9 100644 --- a/pctasks/run/pctasks/run/workflow/executor/models.py +++ b/pctasks/run/pctasks/run/workflow/executor/models.py @@ -372,6 +372,10 @@ def partition_id(self) -> str: def run_id(self) -> str: return self.job_part_submit_msg.run_id + @property + def has_next_task(self) -> bool: + return bool(self.task_queue) + def prepare_next_task(self, settings: RunSettings) -> None: next_task_config = next(iter(self.task_queue), None) if next_task_config: diff --git a/pctasks/run/pctasks/run/workflow/executor/remote.py b/pctasks/run/pctasks/run/workflow/executor/remote.py index 8f2d61c4e..089de1063 100644 --- a/pctasks/run/pctasks/run/workflow/executor/remote.py +++ b/pctasks/run/pctasks/run/workflow/executor/remote.py @@ -3,7 +3,7 @@ import random import time from concurrent import futures -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Union from azure.storage.queue import BinaryBase64DecodePolicy, BinaryBase64EncodePolicy @@ -20,8 +20,9 @@ WorkflowRunStatus, ) from pctasks.core.models.task import CompletedTaskResult, FailedTaskResult -from pctasks.core.models.workflow import WorkflowSubmitMessage +from pctasks.core.models.workflow import JobDefinition, WorkflowSubmitMessage from pctasks.core.queues import QueueService +from pctasks.core.storage.blob import BlobStorage from pctasks.core.utils import grouped, map_opt from pctasks.run.constants import TASKS_TEMPLATE_PATH from pctasks.run.dag import sort_jobs @@ -31,7 +32,6 @@ JobPartition, JobPartitionSubmitMessage, PreparedTaskData, - PreparedTaskSubmitMessage, SuccessfulTaskSubmitResult, ) from pctasks.run.settings import WorkflowExecutorConfig @@ -317,30 +317,35 @@ def update_submit_result( ) return TaskRunStatus.SUBMITTED - def complete_job_partitions( + def complete_job_partition_group( self, run_id: str, job_id: str, group_id: str, job_part_states: List[JobPartitionState], container: CosmosDBContainer[JobPartitionRunRecord], + max_concurrent_partition_tasks: int, + is_last_job: bool, + task_io_storage: BlobStorage, + task_log_storage: BlobStorage, ) -> List[Dict[str, Any]]: """Complete job partitions and return the results. This is a blocking loop that is meant to be called on it's own thread. 
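+ `max_concurrent_partition_tasks` caps how many tasks from this group run at once,
+ `is_last_job` lets the loop discard task outputs that no downstream job will read, and
+ `task_io_storage` / `task_log_storage` are shared BlobStorage clients passed in by the caller rather than rebuilt per group.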
""" - task_io_storage = self.config.run_settings.get_task_io_storage() - task_log_storage = self.config.run_settings.get_log_storage() - completed_job_count = 0 + running_task_count = 0 failed_job_count = 0 total_job_count = len(job_part_states) - _jobs_left = lambda: total_job_count - completed_job_count - failed_job_count - _report_status = lambda: logger.info( + _jobs_left = ( + lambda: total_job_count - completed_job_count - failed_job_count + ) # noqa: E731 + _report_status = lambda: logger.info( # noqa: E731 f"{job_id} {group_id} status: " f"{completed_job_count} completed, " f"{failed_job_count} failed, " - f"{_jobs_left()} remaining" + f"{_jobs_left()} remaining, " + f"{running_task_count} tasks running" ) _last_runner_poll_time: float = time.monotonic() @@ -348,6 +353,7 @@ def complete_job_partitions( try: while _jobs_left() > 0: + # Check the task runner for any failed tasks. if ( time.monotonic() - _last_runner_poll_time @@ -411,6 +417,11 @@ def complete_job_partitions( if task_state.status == TaskStateStatus.NEW: # New task, submit it + # wait if max concurrent tasks are already running + if running_task_count >= max_concurrent_partition_tasks: + continue + else: + running_task_count += 1 update_task_run_status( container, @@ -441,6 +452,7 @@ def complete_job_partitions( elif task_state.status == TaskStateStatus.RUNNING: # Job is still running... + pass if not task_state.status_updated: update_task_run_status( container, @@ -454,6 +466,7 @@ def complete_job_partitions( elif task_state.status == TaskStateStatus.WAITING: # If we just moved the job state to waiting, # update the record. + running_task_count -= 1 if not task_state.status_updated: update_task_run_status( @@ -466,6 +479,8 @@ def complete_job_partitions( task_state.status_updated = True elif task_state.status == TaskStateStatus.FAILED: + running_task_count -= 1 + logger.warning( f"Task failed: {job_part_state.job_id} " f"- {task_state.task_id}" @@ -503,6 +518,8 @@ def complete_job_partitions( _report_status() elif task_state.status == TaskStateStatus.COMPLETED: + running_task_count -= 1 + logger.info( f"Task completed: {job_part_state.job_id}:{part_id}" f":{task_state.task_id}" @@ -519,9 +536,13 @@ def complete_job_partitions( ) continue - job_part_state.task_outputs[task_state.task_id] = { - "output": task_state.task_result.output - } + if (not is_last_job) or (job_part_state.has_next_task): + job_part_state.task_outputs[task_state.task_id] = { + "output": task_state.task_result.output + } + else: + # Clear task output to save memory + task_state.task_result.output = {} update_task_run_status( container, @@ -564,6 +585,12 @@ def complete_job_partitions( status=JobPartitionRunStatus.COMPLETED, ) self.handle_job_part_notifications(job_part_state) + + # If this is the last job, clear the task output + # to save memory + if is_last_job: + job_part_state.task_outputs = {} + except Exception: job_part_state.status = ( JobPartitionStateStatus.FAILED @@ -592,9 +619,190 @@ def complete_job_partitions( return [job_state.task_outputs for job_state in job_part_states] + def execute_job_partitions( + self, + run_id: str, + job_id: str, + workflow_run: WorkflowRunRecord, + job_def: JobDefinition, + total_job_part_count: int, + task_data: List[PreparedTaskData], + is_last_job: bool, + job_part_states: List[JobPartitionState], + wf_run_container: CosmosDBContainer[WorkflowRunRecord], + jp_container: CosmosDBContainer[JobPartitionRunRecord], + pool: futures.ThreadPoolExecutor, + task_io_storage: BlobStorage, + task_log_storage: 
BlobStorage, + ) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]: + """Execute job partitions and return the results. + + If return value is None, then the workflow is failed. + The output will be an empty dict if the job is the last job in the workflow. + Otherwise it will be the output of the job partitions. + """ + + try: + # Split the job partitions into groups + # based on the number of threads. + # Each thread will execute and monitor + # a group of job partitions. + grouped_job_partition_states = [ + (list(g), group_num) + for group_num, g in enumerate( + grouped( + job_part_states, + int( + math.ceil( + len(job_part_states) + / self.config.run_settings.remote_runner_threads + ) + ), + ) + ) + ] + + logger.info("Executing job partitions...") + + job_part_futures = { + pool.submit( + self.complete_job_partition_group, + run_id, + job_id, + f"job-part-group-{group_num}", + job_state_group, + jp_container, + int( + math.ceil( + self.config.run_settings.max_concurrent_workflow_tasks + / self.config.run_settings.remote_runner_threads + ) + ), + is_last_job, + task_io_storage=task_io_storage, + task_log_storage=task_log_storage, + ): job_state_group + for ( + job_state_group, + group_num, + ) in grouped_job_partition_states + } + + job_results: List[Dict[str, Any]] = [] + + job_done_count = 0 + failed_job_part_errors: List[str] = [] + job_failed = False + for job_future in futures.as_completed(job_part_futures.keys()): + job_part_states = job_part_futures[job_future] + + if job_future.cancelled(): + job_failed = True + failed_job_part_errors.append( + "Job partitions failed due to thread cancellation." + ) + + future_error = job_future.exception() + if future_error: + job_failed = True + failed_job_part_errors.append( + f"Job partitions thread failed with {future_error}" + ) + + job_done_count += len(job_part_states) + for job_part_state in job_part_states: + if job_part_state.status == JobPartitionStateStatus.FAILED: + job_failed = True + logger.warning( + f"JOB PART FAILED: {job_id} " + f"{job_part_state.partition_id}" + ) + else: + if not is_last_job: + # If this is not the last job in the + # workflow, record job outputs so they + # can be used to template downstream jobs. + # If this is the last job in the workflow, + # don't collect any job results. 
+ job_results.append(job_part_state.task_outputs) + + logger.info( + f"Job {job_id} partition progress: " + f"({job_done_count}/{total_job_part_count})" + ) + + # ## PROCESS JOB PARTITION RESULTS + + if job_failed: + update_job_run_status( + wf_run_container, + workflow_run, + job_def.get_id(), + JobRunStatus.FAILED, + errors=failed_job_part_errors, + ) + return None + else: + result: Union[Dict[str, Any], List[Dict[str, Any]]] + + if is_last_job: + result = {} + elif len(job_results) == 1: + result = {TASKS_TEMPLATE_PATH: job_results[0]} + else: + job_output_entry: List[Dict[str, Any]] = [] + for job_result in job_results: + job_output_entry.append({TASKS_TEMPLATE_PATH: job_result}) + result = job_output_entry + + update_job_run_status( + wf_run_container, + workflow_run, + job_def.get_id(), + JobRunStatus.COMPLETED, + ) + + return result + finally: + logger.info(f"...cleaning up based on task data for job: {job_id}") + self.task_runner.cleanup([d.runner_info for d in task_data]) + def execute_workflow( self, submit_message: WorkflowSubmitMessage, + ) -> Dict[str, Any]: + workflow_id: str = submit_message.workflow.id + run_id = submit_message.run_id + + logger.info(f"*** Workflow started *** {workflow_id=}, {run_id=}") + + # trace_parent: Union[str, None] = None + # trace_state: Union[str, None] = None + # if submit_message.args is not None: + # trace_parent = submit_message.args.get("traceparent") + # trace_state = submit_message.args.get("tracestate") + + # otel_detach_token = otel.attach_to_parent( + # trace_parent, trace_state, operation_id=run_id + # ) + + # with tracer.start_as_current_span(__name__) as span: + # span.set_attribute("operation_id", workflow_id) + # span.set_attribute("run_id", run_id) + try: + result: Dict[str, Any] = self.execute_workflow_internal( + submit_message=submit_message + ) + finally: + pass + # if otel_detach_token: + # otel.detach(otel_detach_token) + + return result + + def execute_workflow_internal( + self, + submit_message: WorkflowSubmitMessage, ) -> Dict[str, Any]: workflow = submit_message.get_workflow_with_templated_args() trigger_event = map_opt(lambda e: e.dict(), submit_message.trigger_event) @@ -613,18 +821,14 @@ def execute_workflow( f"{run_settings.log_blob_container}/{log_path}" ) log_storage = run_settings.get_log_storage() + task_io_storage = run_settings.get_task_io_storage() + task_log_storage = run_settings.get_log_storage() with StorageLogger.from_uri(log_uri, log_storage=log_storage): - logger.info("***********************************") - logger.info(f"Workflow: {submit_message.workflow.id}") - logger.info(f"Run Id: {run_id}") - logger.info("***********************************") - logger.info(f"Logging to: {log_uri}") - logger.info("Creating CosmosDB connections...") - # Create containers + logger.info("Creating CosmosDB connections...") with WorkflowRunsContainer( WorkflowRunRecord, db=self.config.get_cosmosdb() ) as wf_run_container, WorkflowRunsContainer( @@ -649,8 +853,9 @@ def execute_workflow( try: workflow_jobs = list(workflow.definition.jobs.values()) sorted_jobs = sort_jobs(workflow_jobs) + len_jobs = len(sorted_jobs) logger.info(f"Running jobs: {[j.id for j in sorted_jobs]}") - for job_def in sorted_jobs: + for job_idx, job_def in enumerate(sorted_jobs): # For each job, create the job partitions # through the task pool, submit all initial # tasks, and then wait for all tasks to complete @@ -658,6 +863,12 @@ def execute_workflow( job_id = job_def.get_id() job_run = workflow_run.get_job_run(job_id) + # Track if this is the 
last job. + # If so, there's no reason to hold onto + # job outputs. Avoiding this will save + # memory. + is_last_job = job_idx == len_jobs - 1 + if not job_run: raise WorkflowRunRecordError( f"Job run {job_def.get_id()} not found." @@ -864,248 +1075,51 @@ def execute_workflow( workflow_failed = True continue - # ## SUBMIT INITIAL TASKS - - initial_tasks: List[PreparedTaskSubmitMessage] = [] - for job_part_state in job_part_states: - # Case with no tasks has already been handled. - assert job_part_state.current_task - initial_tasks.append( - job_part_state.current_task.prepared_task - ) - - # First tasks in a job are submitted in bulk to minimize - # the number of concurrent API calls. - logger.info(" - Submitting initial tasks...") - try: - submit_results: List[ - Union[ - SuccessfulTaskSubmitResult, FailedTaskSubmitResult - ] - ] = self.task_runner.submit_tasks(initial_tasks) - except Exception as e: - submit_results = [ - FailedTaskSubmitResult(errors=[str(e)]) - for _ in initial_tasks - ] - - logger.info(" - Initial tasks submitted, checking status...") - - try: - - def _update_submit_results( - tup: Tuple[ - JobPartitionState, - Union[ - SuccessfulTaskSubmitResult, - FailedTaskSubmitResult, - ], - ], - _jp_container: CosmosDBContainer[JobPartitionRunRecord], - ) -> TaskRunStatus: - jps, submit_result = tup - assert jps.current_task - - return self.update_submit_result( - jps.current_task, - submit_result, - container=_jp_container, - run_id=jps.run_id, - job_partition_run_id=jps.job_part_run_record_id, - ) - - any_submitted = False - all_submitted = True - for submitted in pool.map( - lambda x: _update_submit_results(x, jp_container), - zip(job_part_states, submit_results), - ): - _this_task_submitted = ( - submitted == TaskRunStatus.SUBMITTED - ) - any_submitted = any_submitted or _this_task_submitted - all_submitted = all_submitted and _this_task_submitted - - if all_submitted: - logger.info( - " - All initial tasks submitted successfully." - ) - elif any_submitted: - logger.info( - " - Some (not all) initial tasks " - "submitted successfully." 
- ) + current_job_outputs = self.execute_job_partitions( + run_id, + job_id, + workflow_run, + job_def, + total_job_part_count, + task_data, + is_last_job, + job_part_states, + wf_run_container, + jp_container, + pool, + task_io_storage, + task_log_storage, + ) - except Exception as e: - logger.error(f"Failed to update task status: {e}") - logger.exception(e) - logger.info( - f"...cleaning up based on task data for job: {job_id}" - ) - self.task_runner.cleanup([d.runner_info for d in task_data]) - update_job_run_status( - wf_run_container, - workflow_run, - job_def.get_id(), - JobRunStatus.FAILED, - ) + if current_job_outputs is None: workflow_failed = True - continue - - if not any_submitted: - logger.error(" - All initial tasks failed to submit.") - logger.info( - f"...cleaning up based on task data for job: {job_id}" - ) - self.task_runner.cleanup([d.runner_info for d in task_data]) - update_job_run_status( - wf_run_container, - workflow_run, - job_def.get_id(), - JobRunStatus.FAILED, - errors=[ - f"Failed to submit tasks for job {job_def.get_id()}" - ], - ) - workflow_failed = True - continue - - # ## MONITOR GROUPED JOB PARTITIONS - try: - # Split the job partitions into groups - # based on number of threads - grouped_job_partition_states = [ - (list(g), group_num) - for group_num, g in enumerate( - grouped( - job_part_states, - int( - math.ceil( - len(job_part_states) - / run_settings.remote_runner_threads - ) - ), - ) - ) - ] - - # Wait for the first task of the job partition to complete, - # and then complete all remaining tasks - - logger.info("Waiting for tasks to complete...") - - job_part_futures = { - pool.submit( - self.complete_job_partitions, - run_id, - job_id, - f"job-part-group-{group_num}", - job_state_group, - jp_container, - ): job_state_group - for ( - job_state_group, - group_num, - ) in grouped_job_partition_states - } - - job_results: List[Dict[str, Any]] = [] - - job_done_count = 0 - failed_job_part_errors: List[str] = [] - job_failed = False - for job_future in futures.as_completed( - job_part_futures.keys() - ): - job_part_states = job_part_futures[job_future] - - if job_future.cancelled(): - job_failed = True - failed_job_part_errors.append( - "Job partitions failed " - "due to thread cancellation." 
- ) - - future_error = job_future.exception() - if future_error: - job_failed = True - failed_job_part_errors.append( - str( - "Job partitions thread failed " - f"with {future_error}" - ) - ) - - job_done_count += len(job_part_states) - for job_part_state in job_part_states: - if ( - job_part_state.status - == JobPartitionStateStatus.FAILED - ): - job_failed = True - logger.warning( - f"JOB PART FAILED: {job_id} " - f"{job_part_state.partition_id}" - ) - else: - job_results.append(job_part_state.task_outputs) - - logger.info( - f"Job {job_id} partition progress: " - f"({job_done_count}/{total_job_part_count})" - ) - - # ## PROCESS JOB PARTITION RESULTS - - if job_failed: - update_job_run_status( - wf_run_container, - workflow_run, - job_def.get_id(), - JobRunStatus.FAILED, - errors=failed_job_part_errors, - ) - workflow_failed = True - else: - if len(job_results) == 1: - job_outputs[job_id] = { - TASKS_TEMPLATE_PATH: job_results[0] - } - else: - job_output_entry: List[Dict[str, Any]] = [] - for job_result in job_results: - job_output_entry.append( - {TASKS_TEMPLATE_PATH: job_result} - ) - job_outputs[job_id] = job_output_entry - - update_job_run_status( - wf_run_container, - workflow_run, - job_def.get_id(), - JobRunStatus.COMPLETED, - ) - finally: - logger.info( - f"...cleaning up based on task data for job: {job_id}" - ) - self.task_runner.cleanup([d.runner_info for d in task_data]) + else: + job_outputs[job_id] = current_job_outputs if workflow_failed: logger.error("Workflow failed!") - update_workflow_run_status( - wf_run_container, workflow_run, WorkflowRunStatus.FAILED - ) + # The workflow will be marked as failed in the except block + # update_workflow_run_status( + # wf_run_container, + # op_container, + # tracking_container, + # workflow_run, + # WorkflowRunStatus.FAILED, + # ) raise WorkflowFailedError(f"Workflow '{workflow.id}' failed.") else: logger.info("Workflow completed!") update_workflow_run_status( - wf_run_container, workflow_run, WorkflowRunStatus.COMPLETED + wf_run_container, + workflow_run, + WorkflowRunStatus.COMPLETED, ) except Exception as e: logger.exception(e) update_workflow_run_status( - wf_run_container, workflow_run, WorkflowRunStatus.FAILED + wf_run_container, + workflow_run, + WorkflowRunStatus.FAILED, ) return job_outputs diff --git a/pctasks/run/tests/workflow/test_remote.py b/pctasks/run/tests/workflow/test_remote.py index 63cf9676f..aa460a171 100644 --- a/pctasks/run/tests/workflow/test_remote.py +++ b/pctasks/run/tests/workflow/test_remote.py @@ -39,6 +39,8 @@ def run_workflow( run_settings = RunSettings.get() run_settings = run_settings.copy(deep=True) run_settings.task_poll_seconds = 5 + run_settings.max_concurrent_workflow_tasks = 5 + run_settings.remote_runner_threads = 2 cosmosdb_settings = CosmosDBSettings.get() with RemoteWorkflowExecutor( @@ -139,7 +141,7 @@ def test_remote_processes_dataset_like_workflow(): args: output_dir: "${{ args.base_output_dir }}/job-1-task-1" options: - num_outputs: 2 + num_outputs: 10 schema_version: 1.0.0 job-2: tasks: @@ -190,7 +192,7 @@ def test_remote_processes_dataset_like_workflow(): storage.list_files(name_starts_with="job-3-task-2/") ) - assert len(last_task_output_paths) == 4 + assert len(last_task_output_paths) == 20 @pytest.mark.usefixtures("cosmosdb_containers")