Skip to content

Commit

Permalink
Adding configurable list of filters to dataloader
Browse files Browse the repository at this point in the history
Given dataloader:filters as config, the dataloader will:
- Only scan files which are part of its filter set
- Prune objects for which any filter in the provided list
  is not present on the filesystem.
  • Loading branch information
mtauraso committed Aug 27, 2024
1 parent 9cd3680 commit 5277c7f
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 10 deletions.
47 changes: 37 additions & 10 deletions src/fibad/data_loaders/hsc_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def data_set(self):
self.config.get("path", "./data"),
transform=transform,
cutout_shape=self.config.get("crop_to", None),
filters=self.config.get("filters", None),
)

def data_loader(self, data_set):
Expand All @@ -65,7 +66,12 @@ def shape(self):

class HSCDataSet(Dataset):
def __init__(
self, path: Union[Path, str], *, transform=None, cutout_shape: Optional[tuple[int, int]] = None
self,
path: Union[Path, str],
*,
transform=None,
cutout_shape: Optional[tuple[int, int]] = None,
filters: Optional[list[str]] = None,
):
"""Initialize an HSC data set from a path. This involves several filesystem scan operations and will
ultimately open and read the header info of every fits file in the given directory
Expand All @@ -78,19 +84,24 @@ def __init__(
transform : torchvision.transforms.v2.Transform, optional
Transformation to apply to every image in the dataset, by default None
cutout_shape: tuple[int,int], optional
Forces all cutouts to be a particular pixel size. RuntimeError is raised if this size is larger
than the pixel dimension of any cutout in the dataset.
Forces all cutouts to be a particular pixel size. If this size is larger than the pixel dimension
of particular cutouts on the filesystem, those objects are dropped from the data set.
filters: list[str], optional
Forces all cutout tensors provided to be from the list of HSC filters provided. If provided, any
cutouts which do not have fits files corresponding to every filter in the list will be dropped
from the data set. Defaults to None. If not provided, the filters available on the filesystem for
the first object in the directory will be used.
"""
self.path = path
self.transform = transform

self.files = self._scan_file_names()
self.files = self._scan_file_names(filters)
self.dims = self._scan_file_dimensions()

# We choose the first file in the dict as the prototypical set of filters
# Any objects lacking this full set of filters will be pruned by
# _prune_objects
filters_ref = list(list(self.files.values())[0])
# If no filters provided, we choose the first file in the dict as the prototypical set of filters
# Any objects lacking this full set of filters will be pruned by _prune_objects
filters_ref = list(list(self.files.values())[0]) if filters is None else filters

self.num_filters = len(filters_ref)

self.cutout_shape = cutout_shape
Expand All @@ -109,20 +120,36 @@ def __init__(

logger.info(f"HSC Data set loader has {len(self)} objects")

def _scan_file_names(self) -> dict[str, dict[str, str]]:
def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dict[str, str]]:
"""Class initialization helper
Parameters
----------
filters : list[str], optional
If passed, only these filters will be scanned for from the data files. Defaults to None, which
corresponds to the standard set of filters ["HSC-G","HSC-R","HSC-I","HSC-Z","HSC-Y"].
Returns
-------
dict[str,dict[str,str]]
Nested dictionary where the first level maps object_id -> dict, and the second level maps
filter_name -> file name. Corresponds to self.files
"""

object_id_regex = r"[0-9]{17}"
filter_regex = r"HSC-[GRIZY]" if filters is None else "|".join(filters)
full_regex = f"({object_id_regex})_.*_({filter_regex}).fits"

files = {}
# Go scan the path for object ID's so we have a list.
for filepath in Path(self.path).glob("[0-9]*.fits"):
filename = filepath.name
m = re.match(r"([0-9]{17})_.*\_(HSC-[GRIZY]).fits", filename)
m = re.match(full_regex, filename)

# Skip files that don't match the pattern.
if m is None:
continue

object_id = m[1]
filter = m[2]

Expand Down
7 changes: 7 additions & 0 deletions src/fibad/fibad_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ path = "./data"
#
#crop_to = [100,100]

# Limit data loader to only particular filters when there are more in the data set.
#
# When not provided, the set of filters is taken from the first object found in the data set.
# Defaults to not provided.
#
#filters = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"]

# Default PyTorch DataLoader parameters
batch_size = 500
shuffle = true
Expand Down
45 changes: 45 additions & 0 deletions tests/fibad/test_hsc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,48 @@ def test_prune_size(caplog):
# We should warn that we are dropping objects and the reason
assert "Dropping object" in caplog.text
assert "too small" in caplog.text


def test_partial_filter(caplog):
    """Test to ensure when we only load some of the filters, only those filters end up in the dataset.

    The fake files are generated with num_filters=5, but the data set is constructed with a
    two-filter list (HSC-G, HSC-R). The test expects all 10 objects to load, the reported shape
    to have 2 filters, and nothing to be logged at WARNING level or above.
    """
    # Capture WARNING and above so the final assertion can verify that nothing was logged.
    caplog.set_level(logging.WARNING)
    test_files = generate_files(num_objects=10, num_filters=5, shape=(262, 263))
    # NOTE(review): the assertions are assumed to run while the fake filesystem is active - confirm.
    with FakeFitsFS(test_files):
        # The path is never read for real; FakeFitsFS presumably stands in for the filesystem.
        a = HSCDataSet("thispathdoesnotexist", filters=["HSC-G", "HSC-R"])

        # 10 objects should load
        assert len(a) == 10

        # The number of filters (2 requested, not the 5 generated) and image dimensions should be correct
        assert a.shape() == (2, 262, 263)

        # No warnings should be printed
        assert caplog.text == ""


def test_partial_filter_prune_warn_1_percent(caplog):
    """Test to ensure that when the user supplies a filter list and >1% of loaded objects are
    missing a filter, a warning is logged and the resulting dataset drops the objects that
    are missing filters.
    """
    # Capture WARNING and above so the log assertions below can inspect caplog.text.
    caplog.set_level(logging.WARNING)

    # Generate 98 objects with 3 filters each
    # (presumably these include HSC-R and HSC-I - verify against generate_files)
    test_files = generate_files(num_objects=98, num_filters=3, shape=(100, 100))
    # Object 101 is missing the HSC-G and HSC-I filters, we only provide the R filter,
    # so it lacks one of the two filters requested below.
    test_files["00000000000000101_missing_g_HSC-R.fits"] = (100, 100)

    # NOTE(review): the assertions are assumed to run while the fake filesystem is active - confirm.
    with FakeFitsFS(test_files):
        a = HSCDataSet("thispathdoesnotexist", filters=["HSC-R", "HSC-I"])

        # We should have the correct number of objects (the 98 complete ones)
        assert len(a) == 98

        # Object 101 should not be loaded
        assert "00000000000000101" not in a

        # A message should be logged because greater than 1% of the objects were pruned
        assert "Greater than 1% of objects in the data directory were pruned." in caplog.text

        # We should warn that we dropped an object explicitly
        assert "Dropping object" in caplog.text

0 comments on commit 5277c7f

Please sign in to comment.