
Commit 063e787

updates

Signed-off-by: Ayush Kamat <[email protected]>
ayushkamat committed Oct 10, 2024
1 parent f10e534 commit 063e787
Showing 6 changed files with 145 additions and 114 deletions.
51 changes: 22 additions & 29 deletions latch_cli/services/cp/download/main.py
@@ -2,7 +2,6 @@
 import queue
 import shutil
 import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from textwrap import dedent
 from typing import Dict, List, Literal, Optional, TypedDict
@@ -11,11 +10,12 @@
 import requests
 import requests.adapters
 import tqdm
+import uvloop
 
-from ....utils import get_auth_header, human_readable_time, urljoins, with_si_suffix
+from ....utils import get_auth_header, human_readable_time, with_si_suffix
 from ....utils.path import normalize_path
 from ..glob import expand_pattern
-from .worker import Work, worker
+from .worker import Work, run_workers
 
 http_session = requests.Session()

@@ -75,7 +75,7 @@ def download(
     from latch.ldata.type import LDataNodeType
 
     all_node_data = _get_node_data(*srcs)
-    work_queue = queue.Queue()
+    work_queue = asyncio.Queue[Work]()
     total = 0
 
     if expand_globs:
@@ -131,10 +131,10 @@ def download(
             try:
                 work_dest.unlink(missing_ok=True)
+                work_queue.put_nowait(Work(gsud["url"], work_dest, chunk_size_mib))
             except OSError:
                 shutil.rmtree(work_dest)
                 click.echo(f"Cannot write file to {work_dest} - directory exists.")
 
-            work_queue.put(Work(gsud["url"], work_dest, chunk_size_mib))
         else:
             gsurd: GetSignedUrlsRecursiveData = json["data"]
             total += len(gsurd["urls"])
@@ -157,20 +157,14 @@ def download(
             for rel, url in gsurd["urls"].items():
                 res = work_dest / rel
 
-                for parent in res.parents:
-                    try:
-                        parent.mkdir(exist_ok=True, parents=True)
-                        break
-                    except NotADirectoryError:  # somewhere up the tree is a file
-                        continue
-                    except FileExistsError:
-                        parent.unlink()
-                        break
-
-                # todo(ayush): use only one mkdir call
-                res.parent.mkdir(exist_ok=True, parents=True)
-
-                work_queue.put(Work(url, work_dest / rel, chunk_size_mib))
+                try:
+                    res.parent.mkdir(exist_ok=True, parents=True)
+                    work_queue.put_nowait(Work(url, work_dest / rel, chunk_size_mib))
+                except (NotADirectoryError, FileExistsError):
+                    click.echo(
+                        f"Cannot write file to {work_dest / rel} - upstream file"
+                        " exists."
+                    )
 
     tbar = tqdm.tqdm(
         total=total,
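
Note on the mkdir change above: it resolves the old todo comment. A single mkdir(exist_ok=True, parents=True) builds the whole directory chain, and the two exceptions cover exactly the cases the deleted parents-walk probed for by hand: NotADirectoryError when a file sits somewhere up the path, and FileExistsError when the leaf itself is an existing file. A minimal sketch of the pattern (the helper name is illustrative, not part of the diff):

from pathlib import Path

def ensure_parent_dir(dest: Path) -> bool:
    try:
        # parents=True creates every missing ancestor in one call
        dest.parent.mkdir(exist_ok=True, parents=True)
        return True
    except (NotADirectoryError, FileExistsError):
        # an upstream *file* blocks the directory chain
        return False
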
@@ -182,16 +176,15 @@
         disable=progress == "none",
     )
 
-    workers = min(total, cores)
-    with ThreadPoolExecutor(workers) as exec:
-        futs = [
-            exec.submit(worker, work_queue, tbar, progress == "tasks", verbose)
-            for _ in range(workers)
-        ]
+    num_workers = min(total, cores)
+    uvloop.install()
+
+    loop = uvloop.new_event_loop()
+    res = loop.run_until_complete(
+        run_workers(work_queue, num_workers, tbar, progress != "none", verbose)
+    )
 
-    total_bytes = 0
-    for fut in as_completed(futs):
-        total_bytes += fut.result()
+    total_bytes = sum(res)
 
     tbar.clear()
     total_time = time.monotonic() - start
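
The net effect of the main.py changes: the ThreadPoolExecutor fan-out is gone, replaced by a single uvloop event loop driving async workers over an asyncio.Queue. A minimal sketch of that driver pattern, reusing Work and run_workers from this diff — the function name and item shape here are illustrative, not the module's actual API:

import asyncio

import tqdm
import uvloop

from .worker import Work, run_workers

def drive_downloads(items, chunk_size_mib: int, cores: int) -> int:
    # fill the queue up front, before any worker runs, so an empty queue
    # unambiguously means "no work left"
    work_queue = asyncio.Queue[Work]()
    for url, dest in items:
        work_queue.put_nowait(Work(url, dest, chunk_size_mib))

    tbar = tqdm.tqdm(total=work_queue.qsize(), unit="file")
    num_workers = min(work_queue.qsize(), cores)

    uvloop.install()
    loop = uvloop.new_event_loop()
    res = loop.run_until_complete(
        run_workers(work_queue, num_workers, tbar, True, False)
    )
    return sum(res)  # run_workers returns one byte total per worker
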
73 changes: 39 additions & 34 deletions latch_cli/services/cp/download/worker.py
@@ -2,14 +2,18 @@
 import os
 import queue
 import shutil
+import time
 from dataclasses import dataclass
+from http import HTTPStatus
 from pathlib import Path
+from typing import Awaitable, List
 
 import aiohttp
 import tqdm
 import uvloop
 
+from latch_cli.services.cp.utils import chunked
+
 from ....constants import Units
 from ..http_utils import RetryClientSession

@@ -37,41 +41,46 @@ async def download_chunk(
     pbar.update(os.pwrite(fd, content, start))
 
 
-async def work_loop(
-    work_queue: queue.Queue,
+async def worker(
+    work_queue: asyncio.Queue[Work],
     tbar: tqdm.tqdm,
     show_task_progress: bool,
     print_file_on_completion: bool,
 ) -> int:
     pbar = tqdm.tqdm(
         total=0,
         leave=False,
         smoothing=0,
         unit="B",
         unit_scale=True,
         disable=not show_task_progress,
     )
 
     total_bytes = 0
 
-    async with RetryClientSession(read_timeout=90, conn_timeout=10) as sess:
-        while True:
-            try:
-                work: Work = work_queue.get_nowait()
-            except queue.Empty:
-                break
+    try:
+        async with RetryClientSession(read_timeout=90, conn_timeout=10) as sess:
+            while True:
+                try:
+                    work: Work = work_queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    break
 
-            try:
+                pbar.reset()
+                pbar.desc = work.dest.name
+
                 if work.dest.exists() and work.dest.is_dir():
                     shutil.rmtree(work.dest)
 
-                async with sess.get(work.url) as res:
-                    total_size = res.content_length
-                    assert total_size is not None
+                res = await sess.get(work.url, headers={"Range": "bytes=0-0"})
 
-                total_bytes += total_size
+                # s3 throws a REQUESTED_RANGE_NOT_SATISFIABLE if the file is empty
+                if res.status == 416:
+                    total_size = 0
+                else:
+                    content_range = res.headers["Content-Range"]
+                    total_size = int(content_range.replace("bytes 0-0/", ""))
+
+                assert total_size is not None
+
+                total_bytes += total_size
                 pbar.total = total_size
-                pbar.desc = work.dest.name
 
                 chunk_size = work.chunk_size_mib * Units.MiB

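The sess.get with a Range: bytes=0-0 header above is a size probe: it transfers at most one byte and reads the true object size out of the Content-Range response header, so the worker can size its progress bar and plan chunks before downloading anything. A standalone sketch of the same trick, assuming any aiohttp-style session (S3 answers 416 REQUESTED_RANGE_NOT_SATISFIABLE for an empty object, hence the special case):

import aiohttp

async def probe_size(sess: aiohttp.ClientSession, url: str) -> int:
    res = await sess.get(url, headers={"Range": "bytes=0-0"})
    if res.status == 416:
        # S3: requesting byte 0 of an empty object is unsatisfiable
        return 0
    # header looks like "Content-Range: bytes 0-0/12345"
    content_range = res.headers["Content-Range"]
    return int(content_range.replace("bytes 0-0/", ""))
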
@@ -90,27 +99,23 @@ async def work_loop(
                 await asyncio.gather(*coros)
 
                 if print_file_on_completion:
-                    pbar.write(work.dest.name)
+                    pbar.write(str(work.dest))
 
-            except Exception as e:
-                raise Exception(f"{work}: {e}")
-
-            pbar.reset()
-            tbar.update(1)
+                tbar.update(1)
 
-    pbar.clear()
-    return total_bytes
+        return total_bytes
+    finally:
+        pbar.clear()
 
 
-def worker(
-    work_queue: queue.Queue,
+async def run_workers(
+    work_queue: asyncio.Queue[Work],
+    num_workers: int,
     tbar: tqdm.tqdm,
     show_task_progress: bool,
     print_file_on_completion: bool,
-) -> int:
-    uvloop.install()
-
-    loop = uvloop.new_event_loop()
-    return loop.run_until_complete(
-        work_loop(work_queue, tbar, show_task_progress, print_file_on_completion)
-    )
+) -> List[int]:
+    return await asyncio.gather(*[
+        worker(work_queue, tbar, show_task_progress, print_file_on_completion)
+        for _ in range(num_workers)
+    ])
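
run_workers is the standard asyncio worker-pool idiom: N coroutines share one queue, each draining it with get_nowait until QueueEmpty, and asyncio.gather collects every worker's return value (and propagates the first exception). Stripped of the download specifics, a generic version looks roughly like this — all names are illustrative:

import asyncio
from typing import List

async def worker(q: asyncio.Queue) -> int:
    done = 0
    while True:
        try:
            item = q.get_nowait()
        except asyncio.QueueEmpty:
            # the queue is fully populated before the pool starts,
            # so empty really does mean finished
            break
        ...  # process item here
        done += 1
    return done

async def run_pool(q: asyncio.Queue, n: int) -> List[int]:
    return await asyncio.gather(*[worker(q) for _ in range(n)])
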
28 changes: 19 additions & 9 deletions latch_cli/services/cp/http_utils.py
@@ -1,8 +1,9 @@
 import asyncio
 from http import HTTPStatus
-from typing import Awaitable, Callable, List, Optional
+from typing import Awaitable, Callable, Dict, List, Optional
 
 import aiohttp
+import aiohttp.typedefs
 from typing_extensions import ParamSpec
 
 P = ParamSpec("P")
@@ -39,14 +40,21 @@ def __init__(
         self.retries = retries
         self.backoff = backoff
 
+        self.semas: Dict[aiohttp.typedefs.StrOrURL, asyncio.BoundedSemaphore] = {
+            "https://nucleus.latch.bio/ldata/start-upload": asyncio.BoundedSemaphore(2),
+            "https://nucleus.latch.bio/ldata/end-upload": asyncio.BoundedSemaphore(2),
+        }
+
         super().__init__(*args, **kwargs)
 
-    async def _with_retry(
+    async def _request(
         self,
-        f: Callable[P, Awaitable[aiohttp.ClientResponse]],
-        *args: P.args,
-        **kwargs: P.kwargs,
+        method: str,
+        str_or_url: aiohttp.typedefs.StrOrURL,
+        **kwargs,
     ) -> aiohttp.ClientResponse:
+        sema = self.semas.get(str_or_url)
+
         error: Optional[Exception] = None
         last_res: Optional[aiohttp.ClientResponse] = None

@@ -58,7 +66,12 @@ async def _with_retry(
             cur += 1
 
             try:
-                res = await f(*args, **kwargs)
+                if sema is None:
+                    res = await super()._request(method, str_or_url, **kwargs)
+                else:
+                    async with sema:
+                        res = await super()._request(method, str_or_url, **kwargs)
+
                 if res.status in self.status_list:
                     last_res = res
                     continue
@@ -76,6 +89,3 @@ async def _with_retry(
 
         # we'll never get here but putting here anyway so the type checker is happy
         raise RetriesExhaustedException
-
-    async def _request(self, *args, **kwargs) -> aiohttp.ClientResponse:
-        return await self._with_retry(super()._request, *args, **kwargs)
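
Two things happen in this file: the retry logic moves from the _with_retry wrapper into a direct override of aiohttp's ClientSession._request, and that override gains per-endpoint throttling — a BoundedSemaphore caps in-flight requests to the start-upload and end-upload URLs at two each. A condensed sketch of just the throttling half, with a placeholder URL; note that _request is a private aiohttp hook, so overriding it can break across aiohttp versions:

import asyncio
from typing import Dict

import aiohttp
import aiohttp.typedefs

class ThrottledSession(aiohttp.ClientSession):
    def __init__(self, *args, **kwargs):
        # at most 2 concurrent requests per registered URL
        self.semas: Dict[aiohttp.typedefs.StrOrURL, asyncio.BoundedSemaphore] = {
            "https://example.com/rate-limited-endpoint": asyncio.BoundedSemaphore(2),
        }
        super().__init__(*args, **kwargs)

    async def _request(self, method, str_or_url, **kwargs):
        sema = self.semas.get(str_or_url)
        if sema is None:
            return await super()._request(method, str_or_url, **kwargs)
        async with sema:
            return await super()._request(method, str_or_url, **kwargs)
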
31 changes: 14 additions & 17 deletions latch_cli/services/cp/upload/main.py
@@ -1,19 +1,17 @@
 import asyncio
 import os
-import queue
 import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
 from typing import List, Literal, Optional
 
 import click
 import tqdm
+import uvloop
 
 from ....utils import human_readable_time, urljoins, with_si_suffix
 from ....utils.path import normalize_path
-from .worker import Work, worker
+from .worker import Work, run_workers
 
 
 def upload(
@@ -36,7 +34,7 @@ def upload(
     from latch.ldata.path import _get_node_data
     from latch.ldata.type import LDataNodeType
 
-    dest_data = _get_node_data(dest).data[dest]
+    dest_data = _get_node_data(dest, allow_resolve_to_parent=True).data[dest]
     dest_is_dir = dest_data.type in {
         LDataNodeType.account_root,
         LDataNodeType.mount,
@@ -45,7 +43,7 @@ def upload(
         LDataNodeType.dir,
     }
 
-    work_queue = queue.Queue()
+    work_queue = asyncio.Queue[Work]()
    total_bytes = 0
    num_files = 0

@@ -56,7 +54,7 @@ def upload(
 
     normalized = normalize_path(dest)
 
-    if not dest_data.exists:
+    if not dest_data.exists():
         root = normalized
     elif src_path.is_dir():
         if not dest_is_dir:
@@ -75,7 +73,7 @@ def upload(
         num_files += 1
         total_bytes += src_path.resolve().stat().st_size
 
-        work_queue.put(Work(src_path, root, chunk_size_mib))
+        work_queue.put_nowait(Work(src_path, root, chunk_size_mib))
     else:
 
         for dir, _, file_names in os.walk(src_path, followlinks=True):
@@ -91,7 +89,7 @@ def upload(
                 num_files += 1
 
                 remote = urljoins(root, str(rel.relative_to(src_path)))
-                work_queue.put(Work(rel, remote, chunk_size_mib))
+                work_queue.put_nowait(Work(rel, remote, chunk_size_mib))
 
     total = tqdm.tqdm(
         total=num_files,
@@ -103,15 +101,14 @@ def upload(
         disable=progress == "none",
     )
 
-    workers = min(cores, num_files)
-    with ThreadPoolExecutor(workers) as exec:
-        futs = [
-            exec.submit(worker, work_queue, total, progress == "tasks", verbose)
-            for _ in range(workers)
-        ]
-
-        for f in as_completed(futs):
-            f.result()
+    num_workers = min(cores, num_files)
+
+    uvloop.install()
+
+    loop = uvloop.new_event_loop()
+    loop.run_until_complete(
+        run_workers(work_queue, num_workers, total, progress == "tasks", verbose)
+    )
 
     total.clear()
     total_time = time.monotonic() - start
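
As in the download path, the switch from put to put_nowait here is forced rather than cosmetic: asyncio.Queue.put is a coroutine and cannot be called from this synchronous setup code, while put_nowait never blocks on an unbounded queue. Populating the queue entirely before the loop starts is also what lets workers treat QueueEmpty as the termination signal. A sketch of the enqueue step under those assumptions (helper name and queued item shape are illustrative):

import asyncio
import os
from pathlib import Path

def enqueue_tree(src_path: Path, work_queue: asyncio.Queue) -> int:
    num_files = 0
    for dir, _, file_names in os.walk(src_path, followlinks=True):
        for f in file_names:
            # safe: the queue is unbounded and no consumer runs yet
            work_queue.put_nowait(Path(dir) / f)
            num_files += 1
    return num_files
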
