Parallelize remote file downloads within each iree_tests directory. (#297)

Progress on #285

This is a simple improvement over the previous serial processing; there is
still room for further optimization.

Looks like this shaves ~10 seconds off runs in this repo:
* Before (45s):
  https://github.com/nod-ai/SHARK-TestSuite/actions/runs/9996869262/job/27632145027#step:6:15
* After (35s):
  https://github.com/nod-ai/SHARK-TestSuite/actions/runs/9997455742/job/27634060037?pr=297#step:6:15

I saw 2m+ runs in IREE, so hopefully this helps there too. It should be
possible to get the total time down to 10-20 seconds.
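
The parallelism is controlled by the new -j/--jobs flag (default: 8). A
hypothetical invocation, assuming the script is run from the repository root
and any other arguments keep their defaults:

    python iree_tests/download_remote_files.py --jobs 16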
ScottTodd authored Jul 22, 2024
1 parent da6b621 commit 6c1ff8f
Showing 1 changed file with 36 additions and 11 deletions.
47 changes: 36 additions & 11 deletions iree_tests/download_remote_files.py
@@ -5,7 +5,9 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from azure.storage.blob import BlobClient, BlobProperties
+from functools import partial
 from huggingface_hub import hf_hub_download
+from multiprocessing import Pool
 from pathlib import Path
 from typing import Optional
 import argparse
@@ -211,29 +213,44 @@ def download_generic_remote_file(
     raise NotImplementedError("generic remote file downloads not implemented yet")
 
 
+def download_file(remote_file: str, test_dir: Path, cache_dir: Optional[Path]):
+    """
+    Downloads a file from URL into test_dir, if the URL schema is supported.
+    If cache_dir is set, downloads there instead, creating a symlink from
+    test_dir/file_name to cache_dir/file_name.
+    """
+    if "blob.core.windows.net" in remote_file:
+        download_azure_remote_file(remote_file, test_dir, cache_dir)
+    elif "huggingface" in remote_file:
+        download_huggingface_remote_file(remote_file, test_dir, cache_dir)
+    else:
+        download_generic_remote_file(remote_file, test_dir, cache_dir)
+
+
 def download_files_for_test_case(
-    test_case_json: dict, test_dir: Path, cache_dir: Optional[Path]
+    test_case_json: dict, test_dir: Path, jobs: int, cache_dir: Optional[Path]
 ):
     if "remote_files" not in test_case_json:
         return
 
-    # This is naive (greedy, serial) for now. We could batch downloads that
-    # share a source:
+    # This is naive for now. We could further optimize with batching:
     #   * Iterate over all files (across all included paths), building a list
     #     of files to download (checking hashes / local references before
     #     adding to the list)
     #   * (Optionally) Determine disk space needed/available and ask before
     #     continuing
     #   * Group files based on source (e.g. Azure container)
     #   * Start batched/parallel downloads
 
-    for remote_file in test_case_json["remote_files"]:
-        if "blob.core.windows.net" in remote_file:
-            download_azure_remote_file(remote_file, test_dir, cache_dir)
-        elif "huggingface" in remote_file:
-            download_huggingface_remote_file(remote_file, test_dir, cache_dir)
-        else:
-            download_generic_remote_file(remote_file, test_dir, cache_dir)
+    with Pool(jobs) as pool:
+        pool.map(
+            partial(
+                download_file,
+                test_dir=test_dir,
+                cache_dir=cache_dir,
+            ),
+            test_case_json["remote_files"],
+        )
 
 
 if __name__ == "__main__":
@@ -249,6 +266,13 @@ def download_files_for_test_case(
         help="Local cache directory to download into. If set, symlinks will be created pointing to "
         "this location",
     )
+    parser.add_argument(
+        "-j",
+        "--jobs",
+        type=int,
+        default=8,
+        help="Number of parallel processes to use when downloading files",
+    )
     args = parser.parse_args()
 
     # Adjust logging levels.
@@ -287,5 +311,6 @@ def download_files_for_test_case(
         download_files_for_test_case(
             test_case_json=test_case_json,
             test_dir=test_dir,
+            jobs=args.jobs,
             cache_dir=cache_dir_for_test,
         )
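
For context on the pattern used above: functools.partial binds the shared
test_dir and cache_dir arguments, producing a single-argument callable that
Pool.map fans out across worker processes, one call per entry in
remote_files. A minimal, self-contained sketch of the same idea (the function
name, URLs, and paths here are illustrative placeholders, not from this
commit):

    from functools import partial
    from multiprocessing import Pool
    from typing import Optional


    def fetch(remote_file: str, test_dir: str, cache_dir: Optional[str]):
        # Stand-in for the real per-file download logic.
        print(f"downloading {remote_file} into {test_dir} (cache: {cache_dir})")


    if __name__ == "__main__":
        # Placeholder URLs for illustration only.
        remote_files = [
            "https://example.com/a.npy",
            "https://example.com/b.npy",
        ]
        # partial() pins the shared arguments; Pool.map supplies remote_file,
        # one call per list entry, distributed over 4 worker processes.
        with Pool(4) as pool:
            pool.map(
                partial(fetch, test_dir="iree_tests/foo", cache_dir=None),
                remote_files,
            )

For I/O-bound downloads a thread pool would likely perform just as well;
worker processes simply reuse the existing blocking download helpers
unchanged.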
