diff --git a/src/fibad/download.py b/src/fibad/download.py
index faeaf7c..3f636f0 100644
--- a/src/fibad/download.py
+++ b/src/fibad/download.py
@@ -1,24 +1,21 @@
 import contextlib
 import datetime
+import itertools
 import logging
 import os
 import urllib.request
 from pathlib import Path
-from typing import Union
+from threading import Thread
+from typing import Optional, Union
 
-import numpy as np
 from astropy.table import Table, hstack
 
 import fibad.downloadCutout.downloadCutout as dC
 
-# These are the fields that are allowed to vary across the locations
-# input from the catalog fits file. Other values for HSC cutout server
-# must be provided by config.
-variable_fields = ["tract", "ra", "dec"]
-
 logger = logging.getLogger(__name__)
 
 
+# TODO: Alter downloadCutout.py such that this is unnecessary
 @contextlib.contextmanager
 def working_directory(path: Path):
     """
@@ -38,143 +35,430 @@ def working_directory(path: Path):
         os.chdir(old_cwd)
 
 
-def run(config):
+# TODO: Remove this in favor of itertools.batched() when we no longer support python < 3.12.
+def _batched(iterable, n):
+    """Brazenly copied from the Python 3.12 documentation: a stopgap for
+    itertools.batched(), which is only available in Python >= 3.12.
     """
-    Main entrypoint for downloading cutouts from HSC for use with fibad
+    if n < 1:
+        raise ValueError("n must be at least one")
+    iterator = iter(iterable)
+    while batch := tuple(itertools.islice(iterator, n)):
+        yield batch
 
-    Parameters
-    ----------
-    config : dict
-        Runtime configuration as a nested dictionary
-    """
-    config = config.get("download", {})
 
+class Downloader:
+    """Class with primarily static methods that namespaces downloader-related constants and functions."""
 
-    logger.info("Download command Start")
+    # These are the fields that are allowed to vary across the locations
+    # input from the catalog fits file. Other values for HSC cutout server
+    # must be provided by config.
+ VARIABLE_FIELDS = ["tract", "ra", "dec"] - fits_file = config.get("fits_file", "") - logger.info(f"Reading in fits catalog: {fits_file}") - # Filter the fits file for the fields we want - column_names = ["object_id"] + variable_fields - locations = filterfits(fits_file, column_names) + # These are the column names we retain when writing a rect out to the manifest.fits file + RECT_COLUMN_NAMES = VARIABLE_FIELDS + ["filter", "sw", "sh", "rerun", "type"] - # TODO slice up the locations to multiplex across connections if necessary, but for now - # we simply mask off a few - offset = config.get("offset", 0) - end = offset + config.get("num_sources", 10) - locations = locations[offset:end] + MANIFEST_FILE_NAME = "manifest.fits" - # Make a list of rects to pass to downloadCutout - rects = create_rects(locations, offset=0, default=rect_from_config(config)) + @staticmethod + def run(config): + """ + Main entrypoint for downloading cutouts from HSC for use with fibad - # Configure global parameters for the downloader - dC.set_max_connections(num=config.get("max_connections", 2)) + Parameters + ---------- + config : dict + Runtime configuration as a nested dictionary + """ - logger.info("Requesting cutouts") - # pass the rects to the cutout downloader - download_cutout_group( - rects=rects, - cutout_dir=config.get("cutout_dir"), - user=config["username"], - password=config["password"], - retrywait=config.get("retry_wait", 30), - retries=config.get("retries", 3), - timeout=config.get("timeout", 3600), - chunksize=config.get("chunk_size", 990), - ) + config = config.get("download", {}) - logger.info("Done") + logger.info("Download command Start") + fits_file = config.get("fits_file", "") + logger.info(f"Reading in fits catalog: {fits_file}") + # Filter the fits file for the fields we want + column_names = ["object_id"] + Downloader.VARIABLE_FIELDS + locations = Downloader.filterfits(fits_file, column_names) -# TODO add error checking -def filterfits(filename: str, column_names: list[str]) -> Table: - """Read a fits file with the required column names for making cutouts + # If offet/length specified, filter to that length + offset = config.get("offset", 0) + end = offset + config.get("num_sources", None) + if end is not None: + locations = locations[offset:end] - The easiest way to make such a fits file is to select from the main HSC catalog + # Make a list of rects to pass to downloadCutout + rects = Downloader.create_rects(locations, offset=0, default=Downloader.rect_from_config(config)) - Parameters - ---------- - filename : str - The fits file to read in - column_names : list[str] - The columns that are filtered out - - Returns - ------- - Table - Returns an astropy table containing only the fields specified in column_names - """ - t = Table.read(filename) - columns = [t[column] for column in column_names] - return hstack(columns, uniq_col_name="{table_name}", table_names=column_names) + # Prune any previously downloaded rects from our list using the manifest from the previous download + cutout_path = Path(config.get("cutout_dir")) + rects = Downloader._prune_downloaded_rects(cutout_path, rects) + # Early return if there is nothing to download. 
+        if len(rects) == 0:
+            logger.info("Download already completed according to manifest.")
+            return
 
-def rect_from_config(config: dict) -> dC.Rect:
-    """Takes our runtime config and loads cutout config
-    common to all cutouts into a prototypical Rect for downloading
+        # Create thread objects for each of our worker threads
+        num_threads = config.get("concurrent_connections", 2)
+        if num_threads > 5:
+            raise RuntimeError("This client only opens 5 connections or fewer.")
 
-    Parameters
-    ----------
-    config : dict
-        Runtime config, only the download section
+        # If we are using more than one connection, cut the list of rectangles into
+        # one batch per thread. Clamp the thread count and use ceiling division so we
+        # never produce more batches than threads, or a batch size of zero.
+        num_threads = min(num_threads, len(rects))
+        batch_size = -(-len(rects) // num_threads)
+        thread_rects = list(_batched(rects, batch_size)) if num_threads != 1 else [rects]
 
-    Returns
-    -------
-    dC.Rect
-        A single rectangle with fields `sw`, `sh`, `filter`, `rerun`, and `type` populated from the config
-    """
-    return dC.Rect.create(
-        sw=config["sw"],
-        sh=config["sh"],
-        filter=config["filter"],
-        rerun=config["rerun"],
-        type=config["type"],
-    )
+        # Empty dictionaries for the threads to create download manifests in
+        thread_manifests = [dict() for _ in range(num_threads)]
+
+        shared_thread_args = (
+            config["username"],
+            config["password"],
+        )
+
+        shared_thread_kwargs = {
+            "retrywait": config.get("retry_wait", 30),
+            "retries": config.get("retries", 3),
+            "timeout": config.get("timeout", 3600),
+            "chunksize": config.get("chunk_size", 990),
+        }
+        download_threads = []
+
+        try:
+            # downloadCutouts.py works with relative paths so we set cwd here.
+            # Since cwd is process-level, and our threads all share this state, we need to
+            # keep cwd the same until all threads finish.
+            with working_directory(cutout_path):
+                download_threads = [
+                    Thread(
+                        target=Downloader.download_thread,
+                        name=f"thread_{i}",
+                        args=(thread_rects[i],)  # rects
+                        + shared_thread_args  # username, password
+                        + (i, thread_manifests[i]),  # thread_num, manifest
+                        kwargs=shared_thread_kwargs,
+                    )
+                    for i in range(num_threads)
+                ]
+
+                logger.info(f"Started {len(download_threads)} request threads")
+                [thread.start() for thread in download_threads]
+                [thread.join() for thread in download_threads]
+
+        finally:  # Ensure this occurs even when we get a KeyboardInterrupt during download
+            Downloader.write_manifest(thread_manifests, cutout_path)
+
+        logger.info("Done")
+
+    @staticmethod
+    def _prune_downloaded_rects(cutout_path: Path, rects: list[dC.Rect]) -> list[dC.Rect]:
+        """Prunes already downloaded rects using the manifest in `cutout_path`. The `rects` list
+        passed in is mutated by this operation.
+
+        Parameters
+        ----------
+        cutout_path : Path
+            Where on the filesystem to find the manifest
+        rects : list[dC.Rect]
+            List of rects from which we want to prune previously downloaded rects
+
+        Returns
+        -------
+        list[dC.Rect]
+            Returns `rects` that was passed in. This is only to enable explicit style at the call site.
+            `rects` is mutated by this function.
+
+        Raises
+        ------
+        RuntimeError
+            When there is an issue reading the manifest file, or the manifest file corresponds to a different
+            set of cutouts than the current download being attempted
+        """
+        # Read in any prior manifest.
+        prior_manifest = Downloader.read_manifest(cutout_path)
+
+        # If we found a manifest, we are resuming a download
+        if len(prior_manifest) != 0:
+            # Filter rects to figure out which ones are completely downloaded.
+            # This operation consumes prior_manifest in the process
+            rects[:] = [rect for rect in rects if Downloader._keep_rect(rect, prior_manifest)]
+
+            # If prior_manifest was not completely consumed, then the earlier download attempted
+            # some sky locations which would not be included in the current download, and we have
+            # a problem.
+            if len(prior_manifest) != 0:
+                raise RuntimeError(
+                    f"""{cutout_path/Downloader.MANIFEST_FILE_NAME} describes a download with
+sky locations that would not be downloaded in the download currently being attempted. Are you sure you are
+resuming the correct download? Deleting the manifest and cutout files will start the download from scratch."""
+                )
+
+        return rects
+
+    @staticmethod
+    def _keep_rect(location_rect: dC.Rect, prior_manifest: dict[dC.Rect, str]) -> bool:
+        """Private helper to _prune_downloaded_rects which runs the inner loop of the prune
+        operation, allowing it to be written as a list comprehension.
+
+        Given the manifest from a prior download, this function decides for each rect we want
+        to download whether that rect was already completely downloaded in the prior run.
+
+        Parameters
+        ----------
+        location_rect : dC.Rect
+            A rectangle on the sky that we are considering downloading.
+
+        prior_manifest : dict[dC.Rect,str]
+            The manifest of the prior download. This object is slowly consumed by repeated calls
+            to this function. When the return value is False, all manifest entries corresponding to the
+            passed-in location_rect have been removed.
+
+        Returns
+        -------
+        bool
+            Whether this sky location `location_rect` should be included in the download
+        """
+        # Keep any location rect if the manifest passed has nothing in it.
+        if len(prior_manifest) == 0:
+            return True
+
+        keep_rect = False
+        for filter_rect in location_rect.explode():
+            # Consume any matching manifest entry, keep the rect if
+            # 1) The manifest entry doesn't exist -> pop returns None
+            # 2) The manifest entry contains "Attempted" for the filename -> The corresponding file wasn't
+            #    successfully downloaded
+            matching_manifest_entry = prior_manifest.pop(filter_rect, None)
+            if matching_manifest_entry is None or matching_manifest_entry == "Attempted":
+                keep_rect = True
+
+        return keep_rect
+
+    @staticmethod
+    def write_manifest(thread_manifests: list[dict[dC.Rect, str]], file_path: Path):
+        """Write out a manifest fits file that is an inventory of the download.
+        The manifest fits file has columns object_id, ra, dec, tract, filter, filename
+
+        If the filename is "Attempted", a download attempt was made but did not succeed.
+        If the object is not present in the manifest, no download was attempted.
+        If the object is present in the manifest with any other filename, that file exists
+        and was downloaded successfully.
+
+        This file respects the existence of other manifest files in the directory and operates additively.
+        If a manifest file is present from an earlier download, this function will read that manifest in,
+        and include the entire content of that manifest in addition to the manifests passed in.
-    Offset here is to allow multiple downloads on different sections of the source list
-    without file name clobbering during the download phase. The offset is intended to be
-    the index of the start of the locations table within some larger fits file.
+        Parameters
+        ----------
+        thread_manifests : list[dict[dC.Rect,str]]
+            Manifests mapping rects -> filename or status message. Each manifest came from a separate thread.
 
-    Parameters
-    ----------
-    locations : Table
-        Table containing ra, dec locations in the sky
-    offset : int, optional
-        Index to start the `lineno` field in the rects at, by default 0. The purpose of this is to allow
-        multiple downloads on different sections of a larger source list without file name clobbering during
-        the download phase. This is important because `lineno` in a rect can becomes a file name parameter
-        The offset is intended to be the index of the start of the locations table within some larger fits
-        file.
-    default : dC.Rect, optional
-        The default Rect that contains properties common to all sky locations, by default None
-
-    Returns
-    -------
-    list[dC.Rect]
-        Rects populated with sky locations from the table
-    """
-    rects = []
-    for index, location in enumerate(locations):
-        args = {field: location[field] for field in variable_fields}
-        args["lineno"] = index + offset
-        args["tract"] = str(args["tract"])
-        # Sets the file name on the rect to be the object_id, also includes other rect fields
-        # which are interpolated at save time, and are native fields of dc.Rect.
-        #
-        # This name is also parsed by FailedChunkCollector.hook to identify the object_id, so don't
-        # change it without updating code there too.
-        args["name"] = f"{location['object_id']}_{{type}}_{{ra:.5f}}_{{dec:+.5f}}_{{tract}}_{{filter}}"
-        rect = dC.Rect.create(default=default, **args)
-        rects.append(rect)
-
-    return rects
+        file_path : Path
+            Full path to the location where the manifest file ought to be written. The manifest file will
+            be named manifest.fits
+        """
+        logger.info("Assembling download manifest")
+        # Start building a combined manifest from all threads from the ground truth of the prior manifest
+        # in this directory, which we will be overwriting.
+        combined_manifest = Downloader.read_manifest(file_path)
+
+        # Combine all thread manifests with the prior manifest, so that the current status of a downloaded
+        # rect overwrites any status from the prior run (which is no longer relevant.)
+        for manifest in thread_manifests:
+            combined_manifest.update(manifest)
+
+        logger.info(f"Writing out download manifest with {len(combined_manifest)} entries.")
+
+        # Convert the combined manifest into an astropy table by building a dict of {column_name: column_data}
+        # for all the fields in a rect, plus our object_id and filename.
+        column_names = Downloader.RECT_COLUMN_NAMES + ["filename", "object_id"]
+        columns = {column_name: [] for column_name in column_names}
+
+        for rect, msg in combined_manifest.items():
+            # This parsing relies on the name format set up in create_rects to work properly
+            object_id = int(rect.name.split("_")[0])
+            columns["object_id"].append(object_id)
+            columns["filename"].append(msg)
+
+            for key in Downloader.RECT_COLUMN_NAMES:
+                columns[key].append(rect.__dict__[key])
+
+        manifest_table = Table(columns)
+        manifest_table.write(file_path / Downloader.MANIFEST_FILE_NAME, overwrite=True, format="fits")
+
+    @staticmethod
+    def read_manifest(file_path: Path) -> dict[dC.Rect, str]:
+        """Read the manifest.fits file from the given directory and return its contents as a dictionary
+        with downloadCutout.Rect objects as keys and filenames as values.
+
+        If no manifest file is found, an empty dict is returned.
+
+        Parameters
+        ----------
+        file_path : Path
+            Where to find the manifest file
+
+        Returns
+        -------
+        dict[dC.Rect, str]
+            A dictionary containing all the rects in the manifest and all the filenames, or empty dict if no
+            manifest is found.
+        """
+        filename = file_path / Downloader.MANIFEST_FILE_NAME
+        if filename.exists():
+            manifest_table = Table.read(filename, format="fits")
+            rects = Downloader.create_rects(locations=manifest_table, fields=Downloader.RECT_COLUMN_NAMES)
+            return {rect: fname for rect, fname in zip(rects, manifest_table["filename"])}
+        else:
+            return {}
+
+    @staticmethod
+    def download_thread(
+        rects: list[dC.Rect],
+        user: str,
+        password: str,
+        thread_num: int,
+        manifest: dict[dC.Rect, str],
+        **kwargs,
+    ):
+        """Download cutouts to the given directory. Called in its own thread with an id number.
+
+        Calls downloadCutout.download, so supports long lists of rects beyond the limits of the HSC web API
+
+        Parameters
+        ----------
+        rects : list[dC.Rect]
+            The rects we would like to download
+        user : string
+            Username for HSC's download service to use
+        password : string
+            Password for HSC's download service to use
+        thread_num : int
+            The ID number of this thread, sequential from zero to num_threads-1
+        manifest : dict[dC.Rect, str]
+            A dictionary from dC.Rect to filename which we will fill in as we download rects. This is the
+            chief returned piece of data from each thread.
+        **kwargs: dict
+            Additional arguments for downloadCutout.download. See downloadCutout.download for details
+        """
+        logger.info(f"Thread {thread_num} got {len(rects)} rects")
+        with DownloadStats(thread_num) as stats_hook:
+            dC.download(
+                rects,
+                user=user,
+                password=password,
+                onmemory=False,
+                request_hook=stats_hook,
+                manifest=manifest,
+                **kwargs,
+            )
+
+    # TODO add error checking
+    @staticmethod
+    def filterfits(filename: str, column_names: list[str]) -> Table:
+        """Read a fits file with the required column names for making cutouts
+
+        The easiest way to make such a fits file is to select from the main HSC catalog
+
+        Parameters
+        ----------
+        filename : str
+            The fits file to read in
+        column_names : list[str]
+            The columns that are selected from the file and returned in the astropy Table.
+
+        Returns
+        -------
+        Table
+            Returns an astropy table containing only the fields specified in column_names
+        """
+        t = Table.read(filename)
+        columns = [t[column] for column in column_names]
+        return hstack(columns, uniq_col_name="{table_name}", table_names=column_names)
+
+    @staticmethod
+    def rect_from_config(config: dict) -> dC.Rect:
+        """Takes our runtime config and loads cutout config
+        common to all cutouts into a prototypical Rect for downloading
+
+        Parameters
+        ----------
+        config : dict
+            Runtime config, only the download section
+
+        Returns
+        -------
+        dC.Rect
+            A single rectangle with fields `sw`, `sh`, `filter`, `rerun`, and `type` populated from the config
+        """
+        return dC.Rect.create(
+            sw=config["sw"],
+            sh=config["sh"],
+            filter=config["filter"],
+            rerun=config["rerun"],
+            type=config["type"],
+        )
+
+    @staticmethod
+    def create_rects(
+        locations: Table, offset: int = 0, default: dC.Rect = None, fields: Optional[list[str]] = None
+    ) -> list[dC.Rect]:
+        """Create the rects we will need to pass to the downloader.
+        One Rect per location in our list of sky locations.
+
+        Rects are created with all fields in the default rect pre-filled
+
+        Offset here is to allow multiple downloads on different sections of the source list
+        without file name clobbering during the download phase. The offset is intended to be
+        the index of the start of the locations table within some larger fits file.
+
+        Parameters
+        ----------
+        locations : Table
+            Table containing ra, dec locations in the sky
+        offset : int, optional
+            Index to start the `lineno` field in the rects at, by default 0. The purpose of this is to allow
+            multiple downloads on different sections of a larger source list without file name clobbering
+            during the download phase. This is important because `lineno` in a rect can become a file name
+            parameter. The offset is intended to be the index of the start of the locations table within
+            some larger fits file.
+        default : dC.Rect, optional
+            The default Rect that contains properties common to all sky locations, by default None
+
+        fields : list[str], optional
+            The fields to pull from the locations table. If not provided, defaults to
+            ["tract", "ra", "dec"]
+
+        Returns
+        -------
+        list[dC.Rect]
+            Rects populated with sky locations from the table
+        """
+        rects = []
+        fields = fields if fields else Downloader.VARIABLE_FIELDS
+        for index, location in enumerate(locations):
+            args = {field: location[field] for field in fields}
+            args["lineno"] = index + offset
+            args["tract"] = str(args["tract"])
+            # Sets the file name on the rect to be the object_id, also includes other rect fields
+            # which are interpolated at save time, and are native fields of dc.Rect.
+            #
+            # This name is also parsed by write_manifest to identify the object_id, so don't
+            # change it without updating code there too.
+            args["name"] = f"{location['object_id']}_{{type}}_{{ra:.5f}}_{{dec:+.5f}}_{{tract}}_{{filter}}"
+            rect = dC.Rect.create(default=default, **args)
+            rects.append(rect)
+
+        return rects
 
 
 class DownloadStats:
@@ -185,7 +469,7 @@ class DownloadStats:
     Can be used as a context manager for pretty printing.
""" - def __init__(self): + def __init__(self, tid): self.stats = { "request_duration": datetime.timedelta(), # Time from request sent to first byte from the server "response_duration": datetime.timedelta(), # Total time spent recieving and processing a response @@ -193,6 +477,7 @@ def __init__(self): "response_size_bytes": 0, # Total size of all responses "snapshots": 0, # Number of fits snapshots downloaded } + self.tid = tid def __enter__(self): return self.hook @@ -231,7 +516,8 @@ def _print_stats(self, log_level): snapshot_rate = self.stats["snapshots"] / total_dur_s if total_dur_s != 0 else 0 - stats_message = f"Stats: Duration: {total_dur_s:.2f} s, " + stats_message = f"Thread {self.tid} " + stats_message += f"Stats: Duration: {total_dur_s:.2f} s, " stats_message += f"Files: {self.stats['snapshots']}, " stats_message += f"Upload: {up_rate_mb_s:.2f} MB/s, " stats_message += f"Download: {down_rate_mb_s:.2f} MB/s, " @@ -274,138 +560,3 @@ def hook( self._stat_accumulate("snapshots", chunk_size) self._print_stats(logging.INFO) - - -class FailedChunkCollector: - """Collection system for chunks of sky locations where the request for a chunk of cutouts failed. - - Keeps track of all variable_fields plus object_id for failed chunks - - save() dumps these chunks using astropy.table.Table.write() - - """ - - def __init__(self, filepath: Path, **kwargs): - """_summary_ - - Parameters - ---------- - filepath : Path - File to read in if we are resuming a download, and where to save the failed chunks after. - If the file does not exist yet an empty state is initialized. - - **kwargs : dict - Keyword args passed to astropy.table.Table.read() and write() in the case that a file is used. - Should only be used to control file format, not read/write semantics - """ - self.__dict__.update({key: [] for key in variable_fields + ["object_id"]}) - self.seen_object_ids = set() - self.filepath = filepath.resolve() - self.format_kwargs = kwargs - - # If there is a failed chunk file from a previous run, - # Read it in to initialize us - if filepath.exists(): - prev_failed_chunks = Table.read(filepath) - for key in variable_fields + ["object_id"]: - column_as_list = prev_failed_chunks[key].data.tolist() - self.__dict__[key] += column_as_list - logger.debug(f"Adding object ID :{self.object_id} to failed list") - - self.seen_object_ids = {id for id in self.object_id} - - self.count = len(self.seen_object_ids) - logger.debug(f"Failed chunk handler initialized with {self.count} objects") - - def __enter__(self): - return self.hook - - def __exit__(self, exc_type, exc_value, traceback): - self.save() - - def hook(self, rects: list[dC.Rect], exception: Exception, attempts: int): - """Called when dc.Download fails to download a chunk of rects - - Parameters - ---------- - rects : list[dC.Rect] - The list of rect objects that were requested from the server - exception : Exception - The exception that was thrown on the final attempt to request this chunk - attempts : int - The number of attempts that were made to request this chunk - - """ - - for rect in rects: - # Relies on the name format set up in create_rects to work properly - object_id = int(rect.name.split("_")[0]) - - if object_id not in self.seen_object_ids: - self.seen_object_ids.add(object_id) - - self.object_id.append(object_id) - - for key in variable_fields: - self.__dict__[key].append(rect.__dict__[key]) - - self.count += 1 - logger.debug(f"Failed chunk handler processed {len(rects)} rects and is now of size {self.count}") - - def save(self): - """ - 
-        Saves the current set of failed locations to the path specified.
-        If no failed locations were saved by the hook, this function does nothing.
-        """
-        if self.count == 0:
-            return
-        else:
-            # convert our class-member-based representation to an astropy table.
-            for key in variable_fields + ["object_id"]:
-                self.__dict__[key] = np.array(self.__dict__[key])
-
-            missed = Table({key: self.__dict__[key] for key in variable_fields + ["object_id"]})
-
-            # note that the choice to do overwrite=True here and to read in the entire fits file in
-            # ___init__() is necessary because snapshots corresponding to the same object may cross
-            # chunk boundaries decided by dC.download.
-            #
-            # Since we are de-duplicating rects by object_id, we need to read in all rects from a prior
-            # run, and we therefore replace the file we were passed.
-            missed.write(self.filepath, overwrite=True, **self.format_kwargs)
-
-
-def download_cutout_group(rects: list[dC.Rect], cutout_dir: Union[str, Path], user, password, **kwargs):
-    """Download cutouts to the given directory
-
-    Calls downloadCutout.download, so supports long lists of rects beyond the limits of the HSC web API
-
-    Parameters
-    ----------
-    rects : list[dC.Rect]
-        The rects we would like to download
-    cutout_dir : Union[str, Path]
-        The directory to put the files
-    user : string
-        Username for HSC's download service to use
-    password : string
-        Password for HSC's download service to use
-    **kwargs: dict
-        Additonal arguments for downloadCutout.download. See downloadCutout.download for details
-    """
-
-    with working_directory(Path(cutout_dir)):
-        with (
-            DownloadStats() as stats_hook,
-            FailedChunkCollector(Path("failed_locations.fits"), format="fits") as failed_chunk_hook,
-        ):
-            dC.download(
-                rects,
-                user=user,
-                password=password,
-                onmemory=False,
-                request_hook=stats_hook,
-                failed_chunk_hook=failed_chunk_hook,
-                resume=True,
-                **kwargs,
-            )
diff --git a/src/fibad/downloadCutout/downloadCutout.py b/src/fibad/downloadCutout/downloadCutout.py
index d20a46d..067ce68 100644
--- a/src/fibad/downloadCutout/downloadCutout.py
+++ b/src/fibad/downloadCutout/downloadCutout.py
@@ -7,7 +7,6 @@
 import datetime
 import errno
 import getpass
-import hashlib
 import io
 import json
 import logging
@@ -21,11 +20,8 @@
 import urllib.request
 import urllib.response
 from collections.abc import Generator
-from pathlib import Path
 from typing import IO, Any, Callable, Optional, Union, cast
 
-import toml
-
 __all__ = []
 
 
@@ -477,6 +473,41 @@ def explode(self) -> list["Rect"]:
         else:
             return [Rect.create(default=self)]
 
+    # Static field list used by __eq__ and __hash__
+    immutable_fields = ["ra", "dec", "sw", "sh", "filter", "type", "rerun", "image", "variance", "mask"]
+
+    def __eq__(self, obj) -> bool:
+        """Define equality on Rects by sky location, size, filter, type, rerun, and image/mask/variance
+        state. This allows Rects to be used as keys in dictionaries while ignoring transient fields such
+        as lineno, or fields that may be incorrect or changed during the download process, such as tract
+        or name.
+
+        This is a compromise between
+        1) Dataclass's unsafe_hash=True, which would hash all fields,
+        and
+        2) Making the dataclass frozen, which would affect some of the mutability used to alter
+           lineno, tract, and name
+
+        Note that this means equality on Rects is "the cutout API should return the same data",
+        rather than "literally all data members are the same"
+
+        Parameters
+        ----------
+        obj : Rect
+            The rect to compare to self
+
+        Returns
+        -------
+        bool
+            True if the Rects are equal
+        """
+        return all([self.__dict__[field] == obj.__dict__[field] for field in Rect.immutable_fields])
+
+    def __hash__(self):
+        """Define a hash function on Rects. Outside of hash collisions, this function attempts to have the
+        same semantics as Rect.__eq__(). See Rect.__eq__() for further details.
+        """
+        return hash(tuple([self.__dict__[field] for field in Rect.immutable_fields]))
+
 
 class RectEncoder(json.JSONEncoder):
     # TODO this needs to be implemented on a subclass of JSONEncoder
@@ -1008,9 +1039,6 @@ def download(
     Some important (but entirely optional!) keyword args processed later in the download callstack
     are listed below. Anything urllib.request.urlopen will accept is fair game too!
 
-    resume : bool
-        Whether to attempt to resume an ongoing download from filesystem data in onmemory=False mode.
-        Default: False. See _download() for greater detail.
     chunksize : int
         The number of rects to include in a single http request. Default 990 rects. See _download() for
         greater detail.
@@ -1061,10 +1089,9 @@ def _download(
     *,
     onmemory: bool,
     chunksize: int = 990,
-    resume: bool = False,
     retries: int = 3,
     retrywait: int = 30,
-    failed_chunk_hook: Optional[Callable[[list[Rect], Exception, int], Any]] = None,
     **kwargs__download_chunk,
 ) -> Optional[list[list]]:
     """
@@ -1089,26 +1116,12 @@ def _download(
         If `onmemory` is False, downloaded cut-outs are written to files in the current working directory.
     chunksize: int, optional
         Number of cutout lines to pack into a single request. Defaults to 990 if unspecified.
-    resume: bool, optional
-        When `onmemory == True`, uses resume data in the current working directory continue a failed download.
-        Noop when onmemory=False. Defaults to False if unspecified.
-
-        Passing resume=True is safe when no resume data exists.
-        _download() will simply start downloading from the beginning of rects.
     retries: int, optional
         Number of attempts to make to fetch each chunk. Defaults to 3 if unspecified.
     retrywait: int, optional
         Base number of seconds to wait between retries. Retry waits are computed using an exponential
        backoff where the retry time for attempts is calculated as retrywait * (2 ** attempt) seconds,
        with attempt=0 for the first wait.
-
-    failed_chunk_hook: Callable[[list[Rect], Exception, int], Any]
-        Hook which is called every time a chunk fails `retries` time. The arguments to the hook are
-        the rects in the failed chunk, the exception encountered while making the last request, and
-        the number of attempts.
-
-        If this function raises, the entire download stops, but otherwise the download will ocntinue
-
     kwargs__download_chunk: dict, optional
         Additional keyword args are passed through to _download_chunk()
 
@@ -1130,6 +1143,7 @@ def _download(
     exploded_rects: list[tuple[Rect, int]] = []
     for index, rect in enumerate(rects):
         exploded_rects.extend((r, index) for r in rect.explode())
 
     # Sort the rects so that the server can use cache
     # as frequently as possible.
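Reviewer's note: a minimal sketch (not part of the patch) of what the new __eq__/__hash__ semantics buy us. It assumes Rect.create() accepts the same keyword overrides that create_rects() passes in this patch; the field values are placeholders, not real HSC parameters.

    import fibad.downloadCutout.downloadCutout as dC

    # Two rects that differ only in "transient" fields (lineno, name)
    a = dC.Rect.create(ra=150.0, dec=2.0, sw="22asec", sh="22asec",
                       filter="HSC-I", rerun="pdr3_wide", type="coadd")
    b = dC.Rect.create(default=a, lineno=99, name="another_name")

    assert a == b and hash(a) == hash(b)  # transient fields are ignored

    # ...so a manifest keyed by Rect in a prior run still matches on resume:
    manifest = {a: "Attempted"}
    assert manifest.pop(b, None) == "Attempted"

This is what lets _keep_rect() consume prior-manifest entries with dict.pop() even though lineno and name are regenerated on every run.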
@@ -1148,174 +1162,57 @@ def _download(
 
     datalist: list[tuple[int, dict, bytes]] = []
 
-    failed_rect_index = 0
-
-    start_rect_index = 0
-    if not onmemory and resume:
-        start_rect_index = _read_resume_data(exploded_rects)
+    # Chunk loop
+    for i in range(0, len(exploded_rects), chunksize):
+        # Retry loop
+        for attempt in range(0, retries):
+            try:
+                ret = _download_chunk(
+                    exploded_rects[i : i + chunksize],
+                    user,
+                    password,
+                    onmemory=onmemory,
+                    **kwargs__download_chunk,
+                )
+                break
+            except KeyboardInterrupt:
+                logger.critical("Keyboard Interrupt received.")
+                raise
+            except Exception as exception:
+                # Humans count attempts from 1, this loop counts from zero.
+                logger.warning(
+                    f"Attempt {attempt + 1} of {retries} to request rects [{i}:{i+chunksize}] has error:"
+                )
+                logger.warning(exception)
 
-    try:
-        # Chunk loop
-        for i in range(start_rect_index, len(exploded_rects), chunksize):
-            # Retry loop
-            for attempt in range(0, retries):
-                try:
-                    ret = _download_chunk(
-                        exploded_rects[i : i + chunksize],
-                        user,
-                        password,
-                        onmemory=onmemory,
-                        **kwargs__download_chunk,
-                    )
+                # If the final attempt on this chunk fails, we move on.
+                if attempt + 1 == retries:
                     break
-                except KeyboardInterrupt:
-                    logger.critical("Keyboard Interrupt recieved.")
-                    failed_rect_index = i
-                    raise
-                except Exception as exception:
-                    # Humans count attempts from 1, this loop counts from zero.
-                    logger.warning(
-                        f"Attempt {attempt + 1} of {retries} to request rects [{i}:{i+chunksize}] has error:"
-                    )
-                    logger.warning(exception)
-
-                    # If the final attempt on this chunk fails, we try to call the failed_chunk_hook
-                    if attempt + 1 == retries:
-                        if failed_chunk_hook is not None:
-                            rect_chunk = [rect for rect, idx in exploded_rects[i : i + chunksize]]
-                            failed_chunk_hook(rects=rect_chunk, exception=exception, attempts=retries)
-                        # If no hook provided, or if the provided hook doesn't raise, we continue the download
-                        break
-                    # Otherwise do exponential backoff and try again
-                    else:
-                        backoff = retrywait * (2**attempt)
-                        if backoff != 0:
-                            logger.info(f"Retrying in {backoff} seconds... ")
-                            time.sleep(backoff)
-                        logger.info("Retrying now")
-                        continue
-            if onmemory:
-                datalist += cast(list, ret)
-
-    # Retries have failed or we are being killed
-    except (Exception, KeyboardInterrupt):
-        # Write out resume data if we're saving to filesystem and there's been any progress
-        if (not onmemory) and failed_rect_index != 0:
-            _write_resume_data(exploded_rects, failed_rect_index)
-
-        # Reraise so exception can reach top level, very important for KeyboardInterrupt
-        raise
+                # Otherwise wait for exponential backoff and try again
+                else:
+                    backoff = retrywait * (2**attempt)
+                    if backoff != 0:
+                        logger.info(f"Retrying in {backoff} seconds... ")
") + time.sleep(backoff) + logger.info("Retrying now") + continue + if onmemory: + datalist += cast(list, ret) if onmemory: returnedlist: list[list[tuple[dict, bytes]]] = [[] for i in range(len(rects))] for index, metadata, data in datalist: returnedlist[index].append((metadata, data)) - # On success we remove resume data - if not onmemory and resume and os.path.exists(resume_data_filename): - os.remove(resume_data_filename) - return returnedlist if onmemory else None -# TODO multiple connections resume data will need to be instanced by connection -# That will require some interface so the connection number can make it here -resume_data_filename = "resume_download.toml" - - -def _read_resume_data(rects: list[Rect]) -> int: - """Read the resume data from the current working directory - - Parameters - ---------- - rects : list[Rect] - List of rects we intend to process, needed for checksum to ensure the download we are resuming - is the same one that output resume data. - - Returns - ------- - Returns an integer specifying what index in the rect list the resumeing download should start. - If no resume data is found, 0 is returned. - - Raises - ------ - RuntimeError - "No resume data found in " when the resume file could not be found in cwd. - RuntimeError - "Resume data in corrupt" when the file is not a toml file containing keys - 'checksum' and 'start_rect_index' - RuntimeError - "Resume data failed checksum ..." when the rect list has changed from when the resume data file was - written - """ - # Load resume data so we start at the appropriate chunk. - if not os.path.exists(resume_data_filename): - return 0 - - logger.info(f"Resuming failed download from {Path.cwd() / resume_data_filename}") - with open(resume_data_filename, "r") as f: - resumedata = toml.load(f) - if "start_rect_index" not in resumedata or "checksum" not in resumedata: - raise RuntimeError(f"Resume data in {Path.cwd() / resume_data_filename} corrupt.") - - start_rect_index = resumedata["start_rect_index"] - - checksum = _calc_rect_list_checksum(rects[0:start_rect_index]) - if resumedata["checksum"] != checksum: - message = f"""Resume data failed checksum. - Has the list of sky locations changed? If so, remove {Path.cwd() / resume_data_filename}""" - raise RuntimeError(message) - - return start_rect_index - - -def _write_resume_data(rects: list[Rect], failed_rect_index: int) -> None: - """Write resume data - - Parameters - ---------- - rects : list[Rect] - List of Rects we were intending to download, needed to write the checksum into the resume data - failed_rect_index : int - The index of the beginning of the first chunk of rects to fail. - """ - logger.info("Writing resume data") - # Output enough information that we can retry/resume assuming same dir but, - # whatever was DL'ed in current chunk is corrupt - resumedata = { - "start_rect_index": failed_rect_index, - "checksum": _calc_rect_list_checksum(rects[0:failed_rect_index]), - } - with open(resume_data_filename, mode="w") as f: - toml.dump(resumedata, f) - logger.info("Done writing resume data") - - -def _calc_rect_list_checksum(rects: list[Rect]) -> str: - """ - Calculate a sha256 checksum of a list of Rects for the purpose of identifying tha list in the context of - a resumed download - - The method is to dump the list of Rects to JSON and sha256 the JSON. - - Parameters - ---------- - rects : list[Rect] - List of rects that we will checksum - - Returns - ------- - str - Sha256 hex digest of the list of rects. 
- """ - byte_string = json.dumps(rects, sort_keys=True, cls=RectEncoder).encode("utf-8") - return hashlib.sha256(byte_string).hexdigest() - - def _download_chunk( rects: list[tuple[Rect, Any]], user: str, password: str, + manifest: Optional[dict[Rect, str]], *, onmemory: bool, request_hook: Optional[ @@ -1338,6 +1235,9 @@ def _download_chunk( Username. password Password. + manifest + Dictionary from Rect to filename. If Provided, this function will fill in as it downloads. + If download of a file fails, the file's entry will read "Attempted" for the filename. onmemory Return `datalist` on memory. If `onmemory` is False, downloaded cut-outs are written to files. @@ -1388,6 +1288,11 @@ def _download_chunk( # Set timeout to 1 hour if no timout was set higher up kwargs_urlopen.setdefault("timeout", 3600) + # Set all manifest entries to indicate an attempt was made. + if manifest is not None: + for rect, _ in rects: + manifest[rect] = "Attempted" + with get_connection_semaphore(): request_started = datetime.datetime.now() with urllib.request.urlopen(req, **kwargs_urlopen) as fin: @@ -1420,6 +1325,8 @@ def _download_chunk( os.makedirs(dirname, exist_ok=True) with open(filename, "wb") as fout: _splice(fitem, fout) + if manifest is not None: + manifest[rect] = filename if request_hook: request_hook(req, request_started, response_started, response_size, len(rects)) diff --git a/src/fibad/fibad.py b/src/fibad/fibad.py index 4635634..07389fd 100644 --- a/src/fibad/fibad.py +++ b/src/fibad/fibad.py @@ -147,9 +147,9 @@ def download(self, **kwargs): """ See Fibad.download.run() """ - from .download import run + from .download import Downloader - return run(config=self.config, **kwargs) + return Downloader.run(config=self.config, **kwargs) def predict(self, **kwargs): """