From 5277c7f0c187d07a4ed6f54ff67741a56bbaa4f8 Mon Sep 17 00:00:00 2001
From: Michael Tauraso
Date: Tue, 27 Aug 2024 16:16:01 -0700
Subject: [PATCH] Adding configurable list of filters to dataloader

Given dataloader:filters as config, the dataloader will:
- Only scan files which are part of its filter set
- Prune objects where the full list of filters provided is not present on
  the filesystem.
---
 src/fibad/data_loaders/hsc_data_loader.py | 47 ++++++++++++++++++-----
 src/fibad/fibad_default_config.toml       |  7 ++++
 tests/fibad/test_hsc_dataset.py           | 45 ++++++++++++++++++++++
 3 files changed, 89 insertions(+), 10 deletions(-)

diff --git a/src/fibad/data_loaders/hsc_data_loader.py b/src/fibad/data_loaders/hsc_data_loader.py
index 1546403..6c61eb1 100644
--- a/src/fibad/data_loaders/hsc_data_loader.py
+++ b/src/fibad/data_loaders/hsc_data_loader.py
@@ -49,6 +49,7 @@ def data_set(self):
             self.config.get("path", "./data"),
             transform=transform,
             cutout_shape=self.config.get("crop_to", None),
+            filters=self.config.get("filters", None),
         )

     def data_loader(self, data_set):
@@ -65,7 +66,12 @@ def shape(self):

 class HSCDataSet(Dataset):
     def __init__(
-        self, path: Union[Path, str], *, transform=None, cutout_shape: Optional[tuple[int, int]] = None
+        self,
+        path: Union[Path, str],
+        *,
+        transform=None,
+        cutout_shape: Optional[tuple[int, int]] = None,
+        filters: Optional[list[str]] = None,
     ):
         """Initialize an HSC data set from a path. This involves several filesystem scan operations and will
         ultimately open and read the header info of every fits file in the given directory
@@ -78,19 +84,24 @@ def __init__(
         transform : torchvision.transforms.v2.Transform, optional
             Transformation to apply to every image in the dataset, by default None
         cutout_shape: tuple[int,int], optional
-            Forces all cutouts to be a particular pixel size. RuntimeError is raised if this size is larger
-            than the pixel dimension of any cutout in the dataset.
+            Forces all cutouts to be a particular pixel size. If this size is larger than the pixel
+            dimensions of particular cutouts on the filesystem, those objects are dropped from the data set.
+        filters: list[str], optional
+            Restricts the data set to the HSC filters in this list. If provided, any objects which do not
+            have fits files corresponding to every filter in the list will be dropped from the data set.
+            Defaults to None, in which case the filters available on the filesystem for the first object
+            in the directory will be used.
         """
         self.path = path
         self.transform = transform
-        self.files = self._scan_file_names()
+        self.files = self._scan_file_names(filters)
         self.dims = self._scan_file_dimensions()

-        # We choose the first file in the dict as the prototypical set of filters
-        # Any objects lacking this full set of filters will be pruned by
-        # _prune_objects
-        filters_ref = list(list(self.files.values())[0])
+        # If no filters are provided, we choose the first file in the dict as the prototypical set of filters
+        # Any objects lacking this full set of filters will be pruned by _prune_objects
+        filters_ref = list(list(self.files.values())[0]) if filters is None else filters
+
         self.num_filters = len(filters_ref)

         self.cutout_shape = cutout_shape
@@ -109,20 +120,36 @@ def __init__(

         logger.info(f"HSC Data set loader has {len(self)} objects")

-    def _scan_file_names(self) -> dict[str, dict[str, str]]:
+    def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dict[str, str]]:
         """Class initialization helper

+        Parameters
+        ----------
+        filters : list[str], optional
+            If passed, data files will only be scanned for these filters. Defaults to None, which
+            corresponds to the standard set of filters ["HSC-G","HSC-R","HSC-I","HSC-Z","HSC-Y"].
+
         Returns
         -------
         dict[str,dict[str,str]]
             Nested dictionary where the first level maps object_id -> dict, and the second level maps
             filter_name -> file name. Corresponds to self.files
         """
+
+        object_id_regex = r"[0-9]{17}"
+        filter_regex = r"HSC-[GRIZY]" if filters is None else "|".join(filters)
+        full_regex = f"({object_id_regex})_.*_({filter_regex}).fits"
+
         files = {}
         # Go scan the path for object ID's so we have a list.
         for filepath in Path(self.path).glob("[0-9]*.fits"):
             filename = filepath.name
-            m = re.match(r"([0-9]{17})_.*\_(HSC-[GRIZY]).fits", filename)
+            m = re.match(full_regex, filename)
+
+            # Skip files that don't match the pattern.
+            if m is None:
+                continue
+
             object_id = m[1]
             filter = m[2]

diff --git a/src/fibad/fibad_default_config.toml b/src/fibad/fibad_default_config.toml
index a4b7195..04659ae 100644
--- a/src/fibad/fibad_default_config.toml
+++ b/src/fibad/fibad_default_config.toml
@@ -66,6 +66,13 @@ path = "./data"
 #
 #crop_to = [100,100]

+# Limit the data loader to only particular filters when there are more in the data set.
+#
+# When not provided, the filters will be gleaned automatically from the first object in the data set.
+# Defaults to not provided.
+#
+#filters = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"]
+
 # Default PyTorch DataLoader parameters
 batch_size = 500
 shuffle = true
diff --git a/tests/fibad/test_hsc_dataset.py b/tests/fibad/test_hsc_dataset.py
index 5ce74c9..8feda14 100644
--- a/tests/fibad/test_hsc_dataset.py
+++ b/tests/fibad/test_hsc_dataset.py
@@ -285,3 +285,48 @@ def test_prune_size(caplog):
         # We should warn that we are dropping objects and the reason
         assert "Dropping object" in caplog.text
         assert "too small" in caplog.text
+
+
+def test_partial_filter(caplog):
+    """Test to ensure that when we only load some of the filters, only those filters end up in the dataset"""
+    caplog.set_level(logging.WARNING)
+    test_files = generate_files(num_objects=10, num_filters=5, shape=(262, 263))
+    with FakeFitsFS(test_files):
+        a = HSCDataSet("thispathdoesnotexist", filters=["HSC-G", "HSC-R"])
+
+        # 10 objects should load
+        assert len(a) == 10
+
+        # The number of filters and image dimensions should be correct
+        assert a.shape() == (2, 262, 263)
+
+        # No warnings should be printed
+        assert caplog.text == ""
+
+
+def test_partial_filter_prune_warn_1_percent(caplog):
+    """Test to ensure that when the user supplies a filter list and >1% of loaded objects are
+    missing a filter, a warning is logged and the resulting dataset drops the objects that
+    are missing filters.
+    """
+    caplog.set_level(logging.WARNING)
+
+    # Generate 98 objects which have all of the filters we will request
+    test_files = generate_files(num_objects=98, num_filters=3, shape=(100, 100))
+    # Object 101 is missing the HSC-G and HSC-I filters; we only provide the R filter
+    test_files["00000000000000101_missing_g_HSC-R.fits"] = (100, 100)
+
+    with FakeFitsFS(test_files):
+        a = HSCDataSet("thispathdoesnotexist", filters=["HSC-R", "HSC-I"])
+
+        # We should have the correct number of objects
+        assert len(a) == 98
+
+        # Object 101 should not be loaded
+        assert "00000000000000101" not in a
+
+        # We should log an error because greater than 1% of the objects were pruned
+        assert "Greater than 1% of objects in the data directory were pruned." in caplog.text
+
+        # We should warn that we dropped an object explicitly
+        assert "Dropping object" in caplog.text
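
For illustration, a minimal usage sketch of the new option (not taken from the patch itself; the import path
follows the module touched above, and the TOML table holding the filters key is omitted because it is not
shown in the hunk):

    # In the fibad TOML config, alongside path/crop_to:
    #   filters = ["HSC-G", "HSC-R"]
    #
    # Direct construction of the data set with the same filters:
    from fibad.data_loaders.hsc_data_loader import HSCDataSet

    # Only files named like "<17-digit object id>_*_HSC-G.fits" or "<...>_HSC-R.fits" are scanned;
    # objects missing either filter are pruned during __init__.
    dataset = HSCDataSet("./data", filters=["HSC-G", "HSC-R"])

    print(len(dataset))     # number of objects with both filters present
    print(dataset.shape())  # (2, height, width) -- one channel per requested filter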