Skip to content

Commit

Permalink
feat(job_attachments)!: add mechanism to cancel file download (#3)
Browse files Browse the repository at this point in the history
BREAKING CHANGE:
- `on_downloading_files` must now return a bool indicating whether the download(s) should continue (`True`) or be cancelled (`False`) for the functions listed below:
  - AssetSync.sync_inputs(), and
  - download_files_in_directory(), download_files_from_manifest(), download_job_output() and mount_vfs_from_manifest() in download.py
  • Loading branch information
gahyusuh authored May 10, 2023
1 parent 413a207 commit ac44c2a
Show file tree
Hide file tree
Showing 9 changed files with 287 additions and 40 deletions.
168 changes: 168 additions & 0 deletions examples/download_cancel_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.

#! /usr/bin/env python3
import argparse
import pathlib
from tempfile import TemporaryDirectory
import time
from threading import Thread

from bealine_job_attachments.asset_sync import AssetSync
from bealine_job_attachments.aws.bealine import get_job, get_queue
from bealine_job_attachments.download import download_job_output
from bealine_job_attachments.errors import AssetSyncCancelledError

# A testing script to simulate cancellation of (1) syncing inputs, and (2) downloading outputs.
#
# How to test:
# 1. Run the script with the following command for each test:
# (1) To test canceling syncing inputs, run the following command:
# $ python3 download_cancel_test.py sync_inputs -f <farm_id> -q <queue_id> -j <job_id>
# (2) To test canceling downloading outputs, run the following command:
# $ python3 download_cancel_test.py download_outputs -f <farm_id> -q <queue_id> -j <job_id>
# 2. In the middle of downloading files, you can send a cancel signal by pressing the 'k' key
# and then the 'Enter' key in succession. Confirm that cancelling works as expected.

# Instructions printed at start-up and reused as the argparse description.
MESSAGE_HOW_TO_CANCEL = "To stop the download process, please hit 'k' key and then 'Enter' key in succession.\n"

# Value handed back by the progress callback; flipped to False by set_cancelled().
continue_reporting = True

# Set to True by the test functions when they finish; polled by the stdin listener loop.
main_terminated = False


def run():
    """
    Parse the command line and run the selected cancellation test.

    Prints the cancel instructions, reads the positional `test_to_run`
    argument plus the required farm, queue and job IDs, and forwards the IDs
    to `test_sync_inputs` or `test_download_outputs`.
    """
    print(MESSAGE_HOW_TO_CANCEL)

    parser = argparse.ArgumentParser(description=MESSAGE_HOW_TO_CANCEL)
    parser.add_argument(
        "test_to_run",
        choices=["sync_inputs", "download_outputs"],
        help="Test to run. ('sync_inputs' or 'download_outputs')",
    )

    # The three ID options differ only in their flags and help text.
    id_options = (
        ("-f", "--farm-id", "Bealine Farm to download assets from."),
        ("-q", "--queue-id", "Bealine Queue to download assets from."),
        ("-j", "--job-id", "Bealine Job to download assets from."),
    )
    for short_flag, long_flag, help_text in id_options:
        parser.add_argument(short_flag, long_flag, type=str, help=help_text, required=True)

    parsed = parser.parse_args()

    handlers = {
        "sync_inputs": test_sync_inputs,
        "download_outputs": test_download_outputs,
    }
    handler = handlers.get(parsed.test_to_run)
    if handler is not None:
        handler(farm_id=parsed.farm_id, queue_id=parsed.queue_id, job_id=parsed.job_id)


def test_sync_inputs(
    farm_id: str,
    queue_id: str,
    job_id: str,
):
    """
    Tests cancellation during execution of the `sync_inputs` function.

    Looks up the queue and job, then calls `AssetSync.sync_inputs` with a
    temporary directory as the session directory and
    `mock_on_downloading_files` as the progress callback. If the sync raises
    `AssetSyncCancelledError`, the error and its `summary_statistics` are
    printed instead of the normal summary.

    Args:
        farm_id: the ID of the farm that owns the queue.
        queue_id: the ID of the queue that owns the job.
        job_id: the ID of the job whose inputs are synced.
    """
    global main_terminated
    start_time = time.perf_counter()

    try:
        with TemporaryDirectory() as temp_root_dir:
            print(f"Created a temporary directory for the test: {temp_root_dir}")

            queue = get_queue(farm_id=farm_id, queue_id=queue_id)
            job = get_job(farm_id=farm_id, queue_id=queue_id, job_id=job_id)

            print("Starting test to sync inputs...")
            asset_sync = AssetSync()

            try:
                download_start = time.perf_counter()
                (summary_statistics, local_roots) = asset_sync.sync_inputs(
                    s3_settings=queue.jobAttachmentSettings,
                    ja_settings=job.attachmentSettings,
                    queue_id=queue_id,
                    job_id=job_id,
                    session_dir=pathlib.Path(temp_root_dir),
                    on_downloading_files=mock_on_downloading_files,
                )
                print(f"Download Summary Statistics:\n{summary_statistics}")
                print(
                    f"Finished downloading after {time.perf_counter() - download_start} seconds, returned:\n{local_roots}"
                )

            except AssetSyncCancelledError as asce:
                print(f"AssetSyncCancelledError: {asce}")
                print(f"payload: {asce.summary_statistics}")

            print(f"\nTotal test runtime: {time.perf_counter() - start_time}")

        print(f"Cleaned up the temporary directory: {temp_root_dir}")
    finally:
        # Always signal completion, even when an unexpected error escapes, so
        # the stdin listener loop stops instead of waiting for 'k' forever.
        main_terminated = True


def test_download_outputs(
    farm_id: str,
    queue_id: str,
    job_id: str,
):
    """
    Tests cancellation during execution of the `download_job_output` function.

    Looks up the queue, then calls `download_job_output` with
    `mock_on_downloading_files` as the progress callback. If the download
    raises `AssetSyncCancelledError`, the error and its `summary_statistics`
    are printed instead of the normal summary.

    Args:
        farm_id: the ID of the farm that owns the queue.
        queue_id: the ID of the queue whose job attachment settings are used.
        job_id: the ID of the job whose outputs are downloaded.
    """
    global main_terminated
    start_time = time.perf_counter()

    try:
        queue = get_queue(farm_id=farm_id, queue_id=queue_id)

        print("Starting test to download outputs...")

        try:
            download_start = time.perf_counter()
            summary_statistics = download_job_output(
                s3_settings=queue.jobAttachmentSettings,
                job_id=job_id,
                on_downloading_files=mock_on_downloading_files,
            )
            print(f"Download Summary Statistics:\n{summary_statistics}")
            print(f"Finished downloading after {time.perf_counter() - download_start} seconds")

        except AssetSyncCancelledError as asce:
            print(f"AssetSyncCancelledError: {asce}")
            print(f"payload: {asce.summary_statistics}")

        print(f"\nTotal test runtime: {time.perf_counter() - start_time}")
    finally:
        # Always signal completion, even when an unexpected error escapes, so
        # the stdin listener loop stops instead of waiting for 'k' forever.
        main_terminated = True


def mock_on_downloading_files(metadata):
    """
    Progress callback passed as `on_downloading_files` by the tests.

    Prints the received progress metadata and returns the result of
    `mock_on_cancellation_check()` as the callback's answer.
    """
    print(metadata)
    should_continue = mock_on_cancellation_check()
    return should_continue


def mock_on_cancellation_check():
    """
    Return the current value of the module-level `continue_reporting` flag.

    The flag starts as True and is set to False by `set_cancelled()`.
    """
    current_flag = continue_reporting
    return current_flag


def wait_for_cancellation_input():
    """
    Read lines from stdin until the user requests cancellation.

    Loops while `main_terminated` is False. Each iteration blocks on
    `input()`. A line equal to "k" triggers `set_cancelled()` and ends the
    loop; any other line is ignored. If stdin is closed or exhausted, the
    loop ends quietly because no cancel request can arrive any more.
    """
    while not main_terminated:
        try:
            ch = input()
        except EOFError:
            # No more input is possible (closed or piped stdin); stop listening
            # rather than letting the thread die with a traceback.
            break
        if ch == "k":
            set_cancelled()
            break


def set_cancelled():
    """
    Record a cancel request.

    Sets the module-level `continue_reporting` flag to False and prints a
    confirmation message.
    """
    global continue_reporting
    continue_reporting = False
    print("Canceled the process.")


if __name__ == "__main__":
    # Listen for the cancel key on a background thread while the selected test
    # runs on the main thread. The listener is a daemon thread because it
    # blocks in input(): a non-daemon thread would keep the process alive
    # after run() returns until the user pressed Enter.
    input_listener = Thread(target=wait_for_cancellation_input, daemon=True)
    input_listener.start()
    run()
4 changes: 3 additions & 1 deletion hatch.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ dependencies = [

[envs.default.scripts]
test = "pytest --cov-config pyproject.toml {args:test/bealine test/bealine_job_attachments/unit}"
integtest = "pytest {args:test/bealine_job_attachments/integ}"
typing = "mypy {args:src test}"
style = [
"ruff {args:.}",
Expand All @@ -41,6 +40,9 @@ lint = [
[[envs.all.matrix]]
python = ["3.7", "3.9", "3.10", "3.11"]

[envs.integ.scripts]
test = "pytest {args:test/bealine_job_attachments/integ} -vvv --numprocesses=1"

[envs.default.env-vars]
PIP_INDEX_URL="https://aws:{env:CODEARTIFACT_AUTH_TOKEN}@{env:CODEARTIFACT_DOMAIN}-{env:CODEARTIFACT_ACCOUNT_ID}.d.codeartifact.{env:CODEARTIFACT_REGION}.amazonaws.com/pypi/{env:CODEARTIFACT_REPOSITORY}/simple/"

Expand Down
18 changes: 15 additions & 3 deletions src/bealine_job_attachments/asset_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,24 @@ def sync_inputs(
queue_id: str,
job_id: str,
session_dir: Path,
on_downloading_files: Optional[Callable] = None,
on_downloading_files: Optional[Callable[[ProgressReportMetadata], bool]] = None,
) -> Tuple[DownloadSummaryStatistics, List[Dict[str, str]]]:
"""
Downloads a manifest file and corresponding input files, if found.
Returns a tuple of (1) final summary statistics for file downloads, and
(2) a list of local roots for each asset root, used for path mapping.
Args:
s3_settings: S3-specific Job Attachment settings.
ja_settings: Job Attachment settings.
queue_id: the ID of the queue.
job_id: the ID of the job.
session_dir: the directory that the session is going to use.
on_downloading_files: a function that will be called with a ProgressReportMetadata object
for each file being downloaded. If the function returns False, the download will be
cancelled. If it returns True, the download will continue.
Returns:
a tuple of (1) final summary statistics for file downloads, and
(2) a list of local roots for each asset root, used for path mapping.
"""
if not s3_settings:
logger.info(
Expand Down
Loading

0 comments on commit ac44c2a

Please sign in to comment.