diff --git a/src/flash/core/data/utils.py b/src/flash/core/data/utils.py index e142615d74..8a6827a145 100644 --- a/src/flash/core/data/utils.py +++ b/src/flash/core/data/utils.py @@ -59,7 +59,7 @@ } -def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: +def download_data(url: str, path: str = "data/", verbose: bool = False, chunk_size: int = 1024) -> None: """Download file with progressbar. # Code adapted from: https://gist.github.com/ruxi/5d6803c116ec1130d484a4ab8c00c603 @@ -78,39 +78,42 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None: [...] """ + local_filename = os.path.join(path, url.split("/")[-1]) + if os.path.exists(local_filename): + if verbose: + print(f"local file already exists: '{local_filename}'") + return + + os.makedirs(path, exist_ok=True) # Disable warning about making an insecure request urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - if not os.path.exists(path): - os.makedirs(path) - local_filename = os.path.join(path, url.split("/")[-1]) r = requests.get(url, stream=True, verify=False) file_size = int(r.headers["Content-Length"]) if "Content-Length" in r.headers else 0 - chunk_size = 1024 num_bars = int(file_size / chunk_size) if verbose: - print({"file_size": file_size}) - print({"num_bars": num_bars}) - - if not os.path.exists(local_filename): - with open(local_filename, "wb") as fp: - for chunk in tq( - r.iter_content(chunk_size=chunk_size), - total=num_bars, - unit="KB", - desc=local_filename, - leave=True, # progressbar stays - ): - fp.write(chunk) # type: ignore - - def extract_tarfile(file_path: str, extract_path: str, mode: str): - if os.path.exists(file_path): - with tarfile.open(file_path, mode=mode) as tar_ref: - for member in tar_ref.getmembers(): - try: - tar_ref.extract(member, path=extract_path, set_attrs=False) - except PermissionError: - raise PermissionError(f"Could not extract tar file {file_path}") + print(f"file size: {file_size}") + print(f"num bars: {num_bars}") + + with open(local_filename, "wb") as fp: + for chunk in tq( + r.iter_content(chunk_size=chunk_size), + total=num_bars, + unit="KB", + desc=local_filename, + leave=True, # progressbar stays + ): + fp.write(chunk) # type: ignore + + def extract_tarfile(file_path: str, extract_path: str, mode: str) -> None: + if not os.path.exists(file_path): + return + with tarfile.open(file_path, mode=mode) as tar_ref: + for member in tar_ref.getmembers(): + try: + tar_ref.extract(member, path=extract_path, set_attrs=False) + except PermissionError: + raise PermissionError(f"Could not extract tar file {file_path}") if ".zip" in local_filename: if os.path.exists(local_filename):