diff --git a/src/geant4_python_application/application.py b/src/geant4_python_application/application.py
index 81374ea..00c3017 100644
--- a/src/geant4_python_application/application.py
+++ b/src/geant4_python_application/application.py
@@ -43,7 +43,7 @@ def _start_application(pipe: multiprocessing.Pipe):
 
 class Application:
     def __init__(self):
-        geant4_python_application.datasets.install_datasets(show_progress=False)
+        geant4_python_application.datasets.install_datasets(show_progress=True)
 
         self._detector = geant4_python_application.Detector(self)
         self._pipe, child_pipe = multiprocessing.Pipe()
diff --git a/src/geant4_python_application/datasets.py b/src/geant4_python_application/datasets.py
index 30cf7a2..45aa837 100644
--- a/src/geant4_python_application/datasets.py
+++ b/src/geant4_python_application/datasets.py
@@ -9,7 +9,7 @@
 from collections import namedtuple
 
 import requests
-import tqdm
+from tqdm import tqdm
 
 url = "https://cern.ch/geant4-data/datasets"
 data_dir = os.path.join(os.path.dirname(__file__), "geant4/data")
@@ -98,25 +98,35 @@
 )
 
 
-def _download_extract_dataset(
-    dataset: Dataset, progress_bar_position: int = 0, show_progress: bool = True
-):
+def _dataset_url(dataset: Dataset) -> str:
+    return f"{url}/{dataset.filename}.{dataset.version}.tar.gz"
+
+
+def _get_dataset_download_size(dataset: Dataset) -> int:
+    r = requests.head(_dataset_url(dataset))
+    r.raise_for_status()
+    return int(r.headers.get("content-length", 0))
+
+
+def _get_total_download_size(datasets_to_download: list[Dataset] = datasets) -> int:
+    with concurrent.futures.ThreadPoolExecutor(
+        max_workers=len(datasets_to_download)
+    ) as executor:
+        futures = [
+            executor.submit(_get_dataset_download_size, dataset)
+            for dataset in datasets_to_download
+        ]
+        return sum(f.result() for f in concurrent.futures.as_completed(futures))
+
+
+def _download_extract_dataset(dataset: Dataset, pbar: tqdm):
     filename = dataset.filename
     urlpath = f"{url}/{filename}.{dataset.version}.tar.gz"
     r = requests.get(urlpath, stream=True)
     r.raise_for_status()
-    # Get the total file size for tqdm
-    total_size = int(r.headers.get("content-length", 0))
-    chunk_size = 4096
-    with tempfile.TemporaryFile() as f, tqdm.tqdm(
-        total=total_size,
-        unit="B",
-        unit_scale=True,
-        desc=f"Downloading and Extracting {filename}",
-        position=progress_bar_position,
-        disable=not show_progress,
-    ) as pbar:
+    chunk_size = 1024
+    with tempfile.TemporaryFile() as f:
         for chunk in r.iter_content(chunk_size=chunk_size):
             f.write(chunk)
             pbar.update(chunk_size)
 
@@ -131,7 +141,6 @@ def _download_extract_dataset(
         f.seek(0)
         with tarfile.open(fileobj=f, mode="r:gz") as tar:
             tar.extractall(data_dir)
-        pbar.update(total_size)
 
 
 def install_datasets(force: bool = False, show_progress: bool = True):
@@ -148,18 +157,24 @@ def install_datasets(force: bool = False, show_progress: bool = True):
     if len(datasets_to_download) == 0:
         return
 
-    with concurrent.futures.ThreadPoolExecutor(
-        max_workers=len(datasets_to_download)
-    ) as executor:
-        futures = [
-            executor.submit(_download_extract_dataset, dataset, i, show_progress)
-            for i, dataset in enumerate(datasets_to_download)
-        ]
-        concurrent.futures.wait(futures)
-
+    with tqdm(
+        total=_get_total_download_size(datasets_to_download),
+        desc="Downloading Geant4 datasets",
+        disable=not show_progress,
+        unit="B",
+        unit_scale=True,
+    ) as pbar:
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=len(datasets_to_download)
+        ) as executor:
+            futures = [
+                executor.submit(_download_extract_dataset, dataset, pbar)
+                for dataset in datasets_to_download
+            ]
+            concurrent.futures.wait(futures)
 
-def reinstall_datasets():
-    install_datasets(force=True)
+    if show_progress:
+        print(f"Geant4 datasets installed to {data_dir}")
 
 
 def uninstall_datasets():
@@ -167,7 +182,13 @@ def uninstall_datasets():
     package_dir = os.path.dirname(__file__)
 
     if not os.path.relpath(package_dir, dir_to_remove).startswith(".."):
+        # make sure we don't accidentally delete something important
         raise RuntimeError(
             f"Refusing to remove {dir_to_remove} because it is not a subdirectory of {package_dir}"
         )
     shutil.rmtree(dir_to_remove, ignore_errors=True)
+
+
+def reinstall_datasets():
+    uninstall_datasets()
+    install_datasets(force=True)
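For context on the pattern this diff adopts: the per-dataset progress bars are replaced by a single byte-granular bar shared by all download threads, with the total pre-computed from the Content-Length headers of HEAD requests. Below is a minimal, self-contained sketch of that pattern, not part of the PR; fake_sizes and fake_download are illustrative stand-ins for the HEAD-request results and the real downloader, and the sizes are made up.

import concurrent.futures
import time

from tqdm import tqdm

# Stand-ins for the Content-Length values the PR fetches via requests.head().
fake_sizes = {"G4NDL": 400_000, "G4EMLOW": 250_000, "G4PII": 150_000}


def fake_download(size: int, pbar: tqdm) -> None:
    # Each worker advances the shared bar as "bytes" arrive. tqdm.update()
    # synchronizes on an internal lock, so the bar can be shared across
    # threads without extra locking.
    chunk_size = 1024
    for _ in range(0, size, chunk_size):
        time.sleep(0.0001)  # simulate network latency
        pbar.update(chunk_size)


# One bar whose total is the sum of all downloads, as in install_datasets().
with tqdm(total=sum(fake_sizes.values()), unit="B", unit_scale=True) as pbar:
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(fake_sizes)) as ex:
        futures = [ex.submit(fake_download, size, pbar) for size in fake_sizes.values()]
        concurrent.futures.wait(futures)

Note that, like the PR itself, the sketch advances the bar by chunk_size rather than by the actual length of the last chunk, so the bar can overshoot its total by a fraction of a chunk per download; tqdm tolerates this.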