From 063e787a6695e07179099da2fdf6e34b17989393 Mon Sep 17 00:00:00 2001
From: Ayush Kamat
Date: Thu, 10 Oct 2024 15:15:36 -0700
Subject: [PATCH] updates

Replace the thread-pool workers in latch cp with asyncio workers run on a
uvloop event loop. Downloads now size each object with a ranged GET instead
of reading the Content-Length of a full GET, uploads batch their part
uploads through a shared chunked() helper, and the retry session gains
per-endpoint semaphores that rate-limit the nucleus start-upload and
end-upload endpoints.

Signed-off-by: Ayush Kamat
---
 latch_cli/services/cp/download/main.py   | 51 +++++++----------
 latch_cli/services/cp/download/worker.py | 73 +++++++++++++-----------
 latch_cli/services/cp/http_utils.py      | 28 ++++++---
 latch_cli/services/cp/upload/main.py     | 31 +++++-----
 latch_cli/services/cp/upload/worker.py   | 52 ++++++++++-------
 latch_cli/services/cp/utils.py           | 24 +++++++-
 6 files changed, 145 insertions(+), 114 deletions(-)

diff --git a/latch_cli/services/cp/download/main.py b/latch_cli/services/cp/download/main.py
index a17daff1..c40f452e 100644
--- a/latch_cli/services/cp/download/main.py
+++ b/latch_cli/services/cp/download/main.py
@@ -2,7 +2,6 @@
 import queue
 import shutil
 import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from textwrap import dedent
 from typing import Dict, List, Literal, Optional, TypedDict
@@ -11,11 +10,12 @@
 import requests
 import requests.adapters
 import tqdm
+import uvloop
 
-from ....utils import get_auth_header, human_readable_time, urljoins, with_si_suffix
+from ....utils import get_auth_header, human_readable_time, with_si_suffix
 from ....utils.path import normalize_path
 from ..glob import expand_pattern
-from .worker import Work, worker
+from .worker import Work, run_workers
 
 http_session = requests.Session()
 
@@ -75,7 +75,7 @@ def download(
     from latch.ldata.type import LDataNodeType
 
     all_node_data = _get_node_data(*srcs)
-    work_queue = queue.Queue()
+    work_queue = asyncio.Queue[Work]()
     total = 0
 
     if expand_globs:
@@ -131,10 +131,10 @@ def download(
 
             try:
                 work_dest.unlink(missing_ok=True)
+                work_queue.put_nowait(Work(gsud["url"], work_dest, chunk_size_mib))
             except OSError:
-                shutil.rmtree(work_dest)
+                click.echo(f"Cannot write file to {work_dest} - directory exists.")
 
-            work_queue.put(Work(gsud["url"], work_dest, chunk_size_mib))
         else:
             gsurd: GetSignedUrlsRecursiveData = json["data"]
             total += len(gsurd["urls"])
@@ -157,20 +157,14 @@ def download(
             for rel, url in gsurd["urls"].items():
                 res = work_dest / rel
 
-                for parent in res.parents:
-                    try:
-                        parent.mkdir(exist_ok=True, parents=True)
-                        break
-                    except NotADirectoryError:  # somewhere up the tree is a file
-                        continue
-                    except FileExistsError:
-                        parent.unlink()
-                        break
-
-                # todo(ayush): use only one mkdir call
-                res.parent.mkdir(exist_ok=True, parents=True)
-
-                work_queue.put(Work(url, work_dest / rel, chunk_size_mib))
+                try:
+                    res.parent.mkdir(exist_ok=True, parents=True)
+                    work_queue.put_nowait(Work(url, work_dest / rel, chunk_size_mib))
+                except (NotADirectoryError, FileExistsError):
+                    click.echo(
+                        f"Cannot write file to {work_dest / rel} - upstream file"
+                        " exists."
+                    )
 
     tbar = tqdm.tqdm(
         total=total,
@@ -182,16 +176,15 @@ def download(
         disable=progress == "none",
     )
 
-    workers = min(total, cores)
-    with ThreadPoolExecutor(workers) as exec:
-        futs = [
-            exec.submit(worker, work_queue, tbar, progress == "tasks", verbose)
-            for _ in range(workers)
-        ]
+    num_workers = min(total, cores)
+    uvloop.install()
+
+    loop = uvloop.new_event_loop()
+    res = loop.run_until_complete(
+        run_workers(work_queue, num_workers, tbar, progress != "none", verbose)
+    )
 
-    total_bytes = 0
-    for fut in as_completed(futs):
-        total_bytes += fut.result()
+    total_bytes = sum(res)
 
     tbar.clear()
     total_time = time.monotonic() - start
diff --git a/latch_cli/services/cp/download/worker.py b/latch_cli/services/cp/download/worker.py
index ac94497d..ea232163 100644
--- a/latch_cli/services/cp/download/worker.py
+++ b/latch_cli/services/cp/download/worker.py
@@ -2,7 +2,9 @@
 import os
 import queue
 import shutil
+import time
 from dataclasses import dataclass
+from http import HTTPStatus
 from pathlib import Path
 from typing import Awaitable, List
 
@@ -10,6 +12,8 @@
 import tqdm
 import uvloop
 
+from latch_cli.services.cp.utils import chunked
+
 from ....constants import Units
 from ..http_utils import RetryClientSession
 
@@ -37,8 +41,8 @@ async def download_chunk(
         pbar.update(os.pwrite(fd, content, start))
 
 
-async def work_loop(
-    work_queue: queue.Queue,
+async def worker(
+    work_queue: asyncio.Queue[Work],
     tbar: tqdm.tqdm,
     show_task_progress: bool,
     print_file_on_completion: bool,
@@ -46,32 +50,37 @@ async def work_loop(
     pbar = tqdm.tqdm(
         total=0,
         leave=False,
+        smoothing=0,
         unit="B",
         unit_scale=True,
         disable=not show_task_progress,
     )
 
     total_bytes = 0
 
-    async with RetryClientSession(read_timeout=90, conn_timeout=10) as sess:
-        while True:
-            try:
-                work: Work = work_queue.get_nowait()
-            except queue.Empty:
-                break
-
-            try:
-                if work.dest.exists() and work.dest.is_dir():
-                    shutil.rmtree(work.dest)
-
-                async with sess.get(work.url) as res:
-                    total_size = res.content_length
-                    assert total_size is not None
-
-                    total_bytes += total_size
-
-                    pbar.total = total_size
-                    pbar.desc = work.dest.name
+    try:
+        async with RetryClientSession(read_timeout=90, conn_timeout=10) as sess:
+            while True:
+                try:
+                    work: Work = work_queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    break
+
+                pbar.reset()
+                pbar.desc = work.dest.name
+
+                res = await sess.get(work.url, headers={"Range": "bytes=0-0"})
+
+                # s3 throws a REQUESTED_RANGE_NOT_SATISFIABLE if the file is empty
+                if res.status == HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE:
+                    total_size = 0
+                else:
+                    content_range = res.headers["Content-Range"]
+                    total_size = int(content_range.replace("bytes 0-0/", ""))
+
+                assert total_size is not None
+                total_bytes += total_size
+                pbar.total = total_size
 
                 chunk_size = work.chunk_size_mib * Units.MiB
 
@@ -90,27 +99,23 @@ async def work_loop(
                 await asyncio.gather(*coros)
 
                 if print_file_on_completion:
-                    pbar.write(work.dest.name)
-
-            except Exception as e:
-                raise Exception(f"{work}: {e}")
+                    pbar.write(str(work.dest))
 
-            pbar.reset()
-            tbar.update(1)
+                tbar.update(1)
 
-    pbar.clear()
-    return total_bytes
+        return total_bytes
+    finally:
+        pbar.clear()
 
 
-def worker(
-    work_queue: queue.Queue,
+async def run_workers(
+    work_queue: asyncio.Queue[Work],
+    num_workers: int,
     tbar: tqdm.tqdm,
     show_task_progress: bool,
     print_file_on_completion: bool,
-) -> int:
-    uvloop.install()
-
-    loop = uvloop.new_event_loop()
-    return loop.run_until_complete(
-        work_loop(work_queue, tbar, show_task_progress, print_file_on_completion)
-    )
+) -> List[int]:
+    return await asyncio.gather(*[
+        worker(work_queue, tbar, show_task_progress, print_file_on_completion)
+        for _ in range(num_workers)
+    ])
diff --git a/latch_cli/services/cp/http_utils.py b/latch_cli/services/cp/http_utils.py
index c77b337e..56570857 100644
--- a/latch_cli/services/cp/http_utils.py
+++ b/latch_cli/services/cp/http_utils.py
@@ -1,8 +1,9 @@
 import asyncio
 from http import HTTPStatus
-from typing import Awaitable, Callable, List, Optional
+from typing import Awaitable, Callable, Dict, List, Optional
 
 import aiohttp
+import aiohttp.typedefs
 from typing_extensions import ParamSpec
 
 P = ParamSpec("P")
@@ -39,14 +40,21 @@ def __init__(
         self.retries = retries
         self.backoff = backoff
 
+        self.semas: Dict[aiohttp.typedefs.StrOrURL, asyncio.BoundedSemaphore] = {
+            "https://nucleus.latch.bio/ldata/start-upload": asyncio.BoundedSemaphore(2),
+            "https://nucleus.latch.bio/ldata/end-upload": asyncio.BoundedSemaphore(2),
+        }
+
         super().__init__(*args, **kwargs)
 
-    async def _with_retry(
+    async def _request(
         self,
-        f: Callable[P, Awaitable[aiohttp.ClientResponse]],
-        *args: P.args,
-        **kwargs: P.kwargs,
+        method: str,
+        str_or_url: aiohttp.typedefs.StrOrURL,
+        **kwargs,
     ) -> aiohttp.ClientResponse:
+        sema = self.semas.get(str_or_url)
+
         error: Optional[Exception] = None
         last_res: Optional[aiohttp.ClientResponse] = None
 
@@ -58,7 +66,12 @@ async def _with_retry(
             cur += 1
 
             try:
-                res = await f(*args, **kwargs)
+                if sema is None:
+                    res = await super()._request(method, str_or_url, **kwargs)
+                else:
+                    async with sema:
+                        res = await super()._request(method, str_or_url, **kwargs)
+
                 if res.status in self.status_list:
                     last_res = res
                     continue
@@ -76,6 +89,3 @@ async def _with_retry(
 
     # we'll never get here but putting here anyway so the type checker is happy
     raise RetriesExhaustedException
-
-    async def _request(self, *args, **kwargs) -> aiohttp.ClientResponse:
-        return await self._with_retry(super()._request, *args, **kwargs)
diff --git a/latch_cli/services/cp/upload/main.py b/latch_cli/services/cp/upload/main.py
index 4dc4e02a..29f2326f 100644
--- a/latch_cli/services/cp/upload/main.py
+++ b/latch_cli/services/cp/upload/main.py
@@ -1,19 +1,17 @@
 import asyncio
 import os
-import queue
 import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
 from typing import List, Literal, Optional
 
 import click
 import tqdm
+import uvloop
 
 from ....utils import human_readable_time, urljoins, with_si_suffix
 from ....utils.path import normalize_path
-from .worker import Work, worker
+from .worker import Work, run_workers
 
 
 def upload(
@@ -36,7 +34,7 @@ def upload(
     from latch.ldata.path import _get_node_data
     from latch.ldata.type import LDataNodeType
 
-    dest_data = _get_node_data(dest).data[dest]
+    dest_data = _get_node_data(dest, allow_resolve_to_parent=True).data[dest]
     dest_is_dir = dest_data.type in {
         LDataNodeType.account_root,
         LDataNodeType.mount,
@@ -45,7 +43,7 @@ def upload(
         LDataNodeType.dir,
     }
 
-    work_queue = queue.Queue()
+    work_queue = asyncio.Queue[Work]()
     total_bytes = 0
     num_files = 0
 
@@ -56,7 +54,7 @@ def upload(
 
         normalized = normalize_path(dest)
 
-        if not dest_data.exists:
+        if not dest_data.exists():
            root = normalized
         elif src_path.is_dir():
            if not dest_is_dir:
@@ -75,7 +73,7 @@ def upload(
             num_files += 1
             total_bytes += src_path.resolve().stat().st_size
 
-            work_queue.put(Work(src_path, root, chunk_size_mib))
+            work_queue.put_nowait(Work(src_path, root, chunk_size_mib))
 
         else:
             for dir, _, file_names in os.walk(src_path, followlinks=True):
@@ -91,7 +89,7 @@ def upload(
                     num_files += 1
 
                     remote = urljoins(root, str(rel.relative_to(src_path)))
-                    work_queue.put(Work(rel, remote, chunk_size_mib))
+                    work_queue.put_nowait(Work(rel, remote, chunk_size_mib))
 
     total = tqdm.tqdm(
         total=num_files,
@@ -103,15 +101,14 @@ def upload(
         disable=progress == "none",
     )
 
-    workers = min(cores, num_files)
-    with ThreadPoolExecutor(workers) as exec:
-        futs = [
-            exec.submit(worker, work_queue, total, progress == "tasks", verbose)
-            for _ in range(workers)
-        ]
+    num_workers = min(cores, num_files)
 
-        for f in as_completed(futs):
-            f.result()
+    uvloop.install()
+
+    loop = uvloop.new_event_loop()
+    loop.run_until_complete(
+        run_workers(work_queue, num_workers, total, progress == "tasks", verbose)
+    )
 
     total.clear()
     total_time = time.monotonic() - start
diff --git a/latch_cli/services/cp/upload/worker.py b/latch_cli/services/cp/upload/worker.py
index a452e5fb..19cf58e8 100644
--- a/latch_cli/services/cp/upload/worker.py
+++ b/latch_cli/services/cp/upload/worker.py
@@ -6,7 +6,7 @@
 import random
 from dataclasses import dataclass
 from pathlib import Path
-from typing import List, TypedDict
+from typing import Iterable, List, TypedDict, TypeVar
 
 import aiohttp
 import click
@@ -17,6 +17,7 @@
 from latch_cli.utils import get_auth_header, with_si_suffix
 
 from ..http_utils import RateLimitExceeded, RetryClientSession
+from ..utils import chunked
 
 
 @dataclass
@@ -67,9 +68,12 @@ async def upload_chunk(
 
 
 min_part_size = 5 * Units.MiB
 
+start_upload_sema = asyncio.BoundedSemaphore(2)
+end_upload_sema = asyncio.BoundedSemaphore(2)
+
 
-async def work_loop(
-    work_queue: queue.Queue,
+async def worker(
+    work_queue: asyncio.Queue[Work],
     total_pbar: tqdm.tqdm,
     show_task_progress: bool,
     print_file_on_completion: bool,
@@ -77,6 +81,7 @@ async def work_loop(
     pbar = tqdm.tqdm(
         total=0,
         leave=False,
+        smoothing=0,
         unit="B",
         unit_scale=True,
         disable=not show_task_progress,
@@ -85,8 +90,8 @@ async def work_loop(
     async with RetryClientSession(read_timeout=90, conn_timeout=10) as sess:
         while True:
             try:
-                work: Work = work_queue.get_nowait()
-            except queue.Empty:
+                work = work_queue.get_nowait()
+            except asyncio.QueueEmpty:
                 break
 
             resolved = work.src
@@ -131,9 +136,6 @@ async def work_loop(
             pbar.desc = resolved.name
             pbar.total = file_size
 
-            # jitter to not dos nuc-data
-            await asyncio.sleep(0.1 * random.random())
-
             resp = await sess.post(
                 "https://nucleus.latch.bio/ldata/start-upload",
                 headers={"Authorization": get_auth_header()},
@@ -143,6 +145,7 @@ async def work_loop(
                     "part_count": part_count,
                 },
             )
+
             if resp.status == 429:
                 raise RateLimitExceeded(
                     "The service is currently under load and could not complete your"
@@ -159,16 +162,21 @@ async def work_loop(
                 # file is empty - nothing to do
                 continue
 
+            parts: List[CompletedPart] = []
             try:
-                parts = await asyncio.gather(*[
-                    upload_chunk(sess, resolved, url, index, part_size, pbar)
-                    for index, url in enumerate(data["urls"])
-                ])
-            except TimeoutError:
-                work_queue.put(Work(work.src, work.dest, work.chunk_size_mib // 2))
+                for pairs in chunked(enumerate(data["urls"])):
+                    parts.extend(
+                        await asyncio.gather(*[
+                            upload_chunk(sess, resolved, url, index, part_size, pbar)
+                            for index, url in pairs
+                        ])
+                    )
+            except (TimeoutError, asyncio.TimeoutError):
+                await work_queue.put(
+                    Work(work.src, work.dest, work.chunk_size_mib // 2)
+                )
                 continue
 
-            # exception handling
             resp = await sess.post(
                 "https://nucleus.latch.bio/ldata/end-upload",
                 headers={"Authorization": get_auth_header()},
@@ -184,6 +192,7 @@ async def work_loop(
                     ],
                 },
             )
+
             if resp.status == 429:
                 raise RateLimitExceeded(
                     "The service is currently under load and could not complete your"
@@ -201,15 +210,14 @@ async def work_loop(
     pbar.clear()
 
 
-def worker(
-    work_queue: queue.Queue,
+async def run_workers(
+    work_queue: asyncio.Queue[Work],
+    num_workers: int,
     total: tqdm.tqdm,
     show_task_progress: bool,
     print_file_on_completion: bool,
 ):
-    uvloop.install()
-
-    loop = uvloop.new_event_loop()
-    loop.run_until_complete(
-        work_loop(work_queue, total, show_task_progress, print_file_on_completion)
-    )
+    await asyncio.gather(*[
+        worker(work_queue, total, show_task_progress, print_file_on_completion)
+        for _ in range(num_workers)
+    ])
diff --git a/latch_cli/services/cp/utils.py b/latch_cli/services/cp/utils.py
index 3d6bc858..6d16b64d 100644
--- a/latch_cli/services/cp/utils.py
+++ b/latch_cli/services/cp/utils.py
@@ -1,8 +1,9 @@
-from typing import List, TypedDict
+import sys
+from typing import Iterable, List, TypedDict, TypeVar
 
-try:
+if sys.version_info >= (3, 9):
     from functools import cache
-except ImportError:
+else:
     from functools import lru_cache as cache
 
 import gql
@@ -162,3 +163,21 @@ def _get_known_domains_for_account() -> List[str]:
     res.extend(f"{x}.mount" for x in buckets)
 
     return res
+
+
+chunk_batch_size = 3
+
+T = TypeVar("T")
+
+
+def chunked(it: Iterable[T]) -> Iterable[List[T]]:
+    chunk = []
+    for x in it:
+        if len(chunk) == chunk_batch_size:
+            yield chunk
+            chunk = []
+
+        chunk.append(x)
+
+    if len(chunk) > 0:
+        yield chunk
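--
Reviewer note: the new transfer path combines three pieces -- an asyncio.Queue
that is fully populated before the workers start and drained with get_nowait,
a fixed pool of worker coroutines fanned out with asyncio.gather on a uvloop
event loop, and BoundedSemaphore throttles that cap concurrency against the
hot nucleus endpoints. The sketch below exercises the same shape standalone;
worker, run_all, and the sleep stand-in are illustrative names, not code from
this patch.

    import asyncio
    from typing import List

    import uvloop


    async def worker(
        q: "asyncio.Queue[int]", throttle: asyncio.BoundedSemaphore
    ) -> int:
        done = 0
        while True:
            try:
                # the queue is filled before the workers start, so an empty
                # queue means there is no more work
                item = q.get_nowait()
            except asyncio.QueueEmpty:
                return done
            async with throttle:  # cap concurrency, like the per-endpoint semas
                await asyncio.sleep(0.01)  # stand-in for the real chunk I/O
            done += 1


    async def run_all(items: List[int], num_workers: int) -> List[int]:
        q: "asyncio.Queue[int]" = asyncio.Queue()
        for item in items:
            q.put_nowait(item)

        # create the semaphore inside the running loop; on Python <= 3.9 the
        # sync primitives bind to whichever loop is current at construction
        throttle = asyncio.BoundedSemaphore(2)
        return await asyncio.gather(*[
            worker(q, throttle) for _ in range(num_workers)
        ])


    uvloop.install()
    loop = uvloop.new_event_loop()
    per_worker = loop.run_until_complete(run_all(list(range(10)), 4))
    print(per_worker, sum(per_worker))  # e.g. [3, 3, 2, 2] 10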