Skip to content

Commit

Permalink
update datasets installation
Browse files Browse the repository at this point in the history
  • Loading branch information
lobis committed Dec 7, 2023
1 parent 3c827d9 commit 5df4c7f
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/geant4_python_application/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _start_application(pipe: multiprocessing.Pipe):

class Application:
def __init__(self):
geant4_python_application.datasets.install_datasets(show_progress=False)
geant4_python_application.datasets.install_datasets(show_progress=True)
self._detector = geant4_python_application.Detector(self)

self._pipe, child_pipe = multiprocessing.Pipe()
Expand Down
75 changes: 48 additions & 27 deletions src/geant4_python_application/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections import namedtuple

import requests
import tqdm
from tqdm import tqdm

url = "https://cern.ch/geant4-data/datasets"
data_dir = os.path.join(os.path.dirname(__file__), "geant4/data")
Expand Down Expand Up @@ -98,25 +98,35 @@
)


def _download_extract_dataset(
dataset: Dataset, progress_bar_position: int = 0, show_progress: bool = True
):
def _dataset_url(dataset: Dataset) -> str:
    """Build the download URL of the ``.tar.gz`` archive for *dataset*.

    The archive name is ``<filename>.<version>.tar.gz`` and it is looked up
    directly under the module-level base ``url``.
    """
    archive_name = f"{dataset.filename}.{dataset.version}.tar.gz"
    return f"{url}/{archive_name}"


def _get_dataset_download_size(dataset: Dataset) -> int:
    """Return the size in bytes of the archive for *dataset*.

    The size is read from the ``content-length`` header of a HEAD request
    against the dataset URL. If the server does not send that header, 0 is
    returned.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond in time.
    """
    # requests.head() does NOT follow redirects by default (unlike
    # requests.get()). Without allow_redirects=True a redirecting URL would
    # report the size of the redirect response instead of the archive, and
    # raise_for_status() would not complain because 3xx is not an error.
    # The timeout keeps a stalled server from blocking the caller forever.
    r = requests.head(_dataset_url(dataset), allow_redirects=True, timeout=30)
    r.raise_for_status()
    return int(r.headers.get("content-length", 0))


def _get_total_download_size(datasets_to_download: list[Dataset] = datasets) -> int:
    """Return the combined download size in bytes of the given datasets.

    The individual sizes are queried concurrently, one worker thread per
    dataset, and summed. Any exception raised while querying a dataset
    propagates to the caller.

    Args:
        datasets_to_download: Datasets whose archive sizes are summed.
            Defaults to the module-level ``datasets`` collection.

    Returns:
        The total size in bytes; 0 when there is nothing to download.
    """
    if not datasets_to_download:
        # ThreadPoolExecutor rejects max_workers=0 with ValueError, so an
        # empty request is answered directly.
        return 0

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=len(datasets_to_download)
    ) as executor:
        # The sum does not depend on completion order, so map() is enough.
        return sum(executor.map(_get_dataset_download_size, datasets_to_download))


def _download_extract_dataset(dataset: Dataset, pbar: tqdm):
filename = dataset.filename
urlpath = f"{url}/{filename}.{dataset.version}.tar.gz"
r = requests.get(urlpath, stream=True)
r.raise_for_status()

# Get the total file size for tqdm
total_size = int(r.headers.get("content-length", 0))
chunk_size = 4096
with tempfile.TemporaryFile() as f, tqdm.tqdm(
total=total_size,
unit="B",
unit_scale=True,
desc=f"Downloading and Extracting {filename}",
position=progress_bar_position,
disable=not show_progress,
) as pbar:
chunk_size = 1024
with tempfile.TemporaryFile() as f:
for chunk in r.iter_content(chunk_size=chunk_size):
f.write(chunk)
pbar.update(chunk_size)
Expand All @@ -131,7 +141,6 @@ def _download_extract_dataset(
f.seek(0)
with tarfile.open(fileobj=f, mode="r:gz") as tar:
tar.extractall(data_dir)
pbar.update(total_size)


def install_datasets(force: bool = False, show_progress: bool = True):
Expand All @@ -148,26 +157,38 @@ def install_datasets(force: bool = False, show_progress: bool = True):
if len(datasets_to_download) == 0:
return

with concurrent.futures.ThreadPoolExecutor(
max_workers=len(datasets_to_download)
) as executor:
futures = [
executor.submit(_download_extract_dataset, dataset, i, show_progress)
for i, dataset in enumerate(datasets_to_download)
]
concurrent.futures.wait(futures)

with tqdm(
total=_get_total_download_size(datasets_to_download),
desc="Downloading Geant4 datasets",
disable=not show_progress,
unit="B",
unit_scale=True,
) as pbar:
with concurrent.futures.ThreadPoolExecutor(
max_workers=len(datasets_to_download)
) as executor:
futures = [
executor.submit(_download_extract_dataset, dataset, pbar)
for i, dataset in enumerate(datasets_to_download)
]
concurrent.futures.wait(futures)

def reinstall_datasets():
install_datasets(force=True)
if show_progress:
print(f"Geant4 datasets installed to {data_dir}")


def uninstall_datasets():
    """Delete the directory tree that contains the installed datasets.

    The directory removed is the parent of the module-level ``data_dir``.
    As a safety net, removal is refused unless that directory lies strictly
    inside the directory containing this module.

    Raises:
        RuntimeError: If the directory to remove is not a strict
            subdirectory of the package directory.
    """
    dir_to_remove = os.path.dirname(data_dir)
    package_dir = os.path.dirname(__file__)

    # Compare normalised absolute paths so that "..", "." or trailing
    # separators cannot defeat the containment check.
    target = os.path.normcase(os.path.abspath(dir_to_remove))
    root = os.path.normcase(os.path.abspath(package_dir))
    try:
        # Strictly inside: shares the package directory as common prefix but
        # is not the package directory itself. A relpath-starts-with-".."
        # test is not sufficient, because it is also true for unrelated
        # directories elsewhere on the filesystem.
        is_inside_package = (
            target != root and os.path.commonpath([target, root]) == root
        )
    except ValueError:
        # commonpath() raises when the paths cannot be compared, e.g. when
        # they are on different drives on Windows.
        is_inside_package = False

    if not is_inside_package:
        # make sure we don't accidentally delete something important
        raise RuntimeError(
            f"Refusing to remove {dir_to_remove} because it is not a subdirectory of {package_dir}"
        )

    # Best effort: failures while removing (including a directory that does
    # not exist) are deliberately ignored.
    shutil.rmtree(dir_to_remove, ignore_errors=True)


def reinstall_datasets():
    """Remove the installed datasets and install them again from scratch."""
    # Remove whatever is currently on disk first ...
    uninstall_datasets()
    # ... then run the installation again, with force=True.
    install_datasets(force=True)

0 comments on commit 5df4c7f

Please sign in to comment.