diff --git a/bg_atlasapi/bg_atlas.py b/bg_atlasapi/bg_atlas.py index 33eca11e..13200e63 100644 --- a/bg_atlasapi/bg_atlas.py +++ b/bg_atlasapi/bg_atlas.py @@ -38,6 +38,9 @@ class BrainGlobeAtlas(core.Atlas): instantiation and to suppress warnings. print_authors : bool (optional) If true, disable default listing of the atlas reference. + fn_update : Callable + Handler function to update during download. Takes completed and total + bytes. """ @@ -51,8 +54,10 @@ def __init__( interm_download_dir=None, check_latest=True, config_dir=None, + fn_update=None, ): self.atlas_name = atlas_name + self.fn_update = fn_update # Read BrainGlobe configuration file: conf = config.read_config(config_dir) @@ -156,7 +161,9 @@ def download_extract_file(self): destination_path = self.interm_download_dir / COMPRESSED_FILENAME # Try to download atlas data - utils.retrieve_over_http(self.remote_url, destination_path) + utils.retrieve_over_http( + self.remote_url, destination_path, self.fn_update + ) # Uncompress in brainglobe path: tar = tarfile.open(destination_path) diff --git a/bg_atlasapi/utils.py b/bg_atlasapi/utils.py index bd76fc17..6b10b62d 100644 --- a/bg_atlasapi/utils.py +++ b/bg_atlasapi/utils.py @@ -1,6 +1,7 @@ import configparser import json import logging +from typing import Callable import requests import tifffile @@ -128,7 +129,9 @@ def check_internet_connection( return False -def retrieve_over_http(url, output_file_path): +def retrieve_over_http( + url, output_file_path, fn_update: Callable[[int, int], None] = None +): """Download file from remote location, with progress bar. Parameters @@ -137,6 +140,9 @@ def retrieve_over_http(url, output_file_path): Remote URL. output_file_path : str or Path Full file destination for download. + fn_update : Callable + Handler function to update during download. Takes completed and total + bytes. """ # Make Rich progress bar @@ -157,17 +163,25 @@ def retrieve_over_http(url, output_file_path): try: with progress: + tot = int(response.headers.get("content-length", 0)) task_id = progress.add_task( "download", filename=output_file_path.name, start=True, - total=int(response.headers.get("content-length", 0)), + total=tot, ) with open(output_file_path, "wb") as fout: + advanced = 0 for chunk in response.iter_content(chunk_size=CHUNK_SIZE): fout.write(chunk) - progress.update(task_id, advance=len(chunk), refresh=True) + adv = len(chunk) + progress.update(task_id, advance=adv, refresh=True) + + if fn_update: + # update handler with completed and total bytes + advanced += adv + fn_update(advanced, tot) except requests.exceptions.ConnectionError: output_file_path.unlink()