Initial version of filter_catalog feature. (#113)
- We can take a fits file as a config
- We filter object_ids out of a big dataset based on it
- We also skip filesystem checks if there is enough info in the filter catalog.
- Lacks any unit testing
- Added the prepare verb, but right now it just gives you the dataset object
  when run from a notebook.
mtauraso authored Nov 8, 2024
1 parent c2e32c7 commit e666dbe
Showing 6 changed files with 141 additions and 11 deletions.
99 changes: 91 additions & 8 deletions src/fibad/data_sets/hsc_data_set.py
@@ -9,6 +9,7 @@
import numpy as np
import torch
from astropy.io import fits
from astropy.table import Table
from torch.utils.data import Dataset
from torchvision.transforms.v2 import CenterCrop, Compose, Lambda

@@ -240,6 +241,10 @@ def __getitem__(self, idx: int) -> torch.Tensor:
return self.data[self.indexes[idx]]


dim_dict = dict[str, list[tuple[int, int]]]
files_dict = dict[str, dict[str, str]]


class HSCDataSetContainer(Dataset):
def __init__(self, config):
# TODO: What will be a reasonable set of transformations?
@@ -250,12 +255,14 @@ def __init__(self, config):

crop_to = config["data_set"]["crop_to"]
filters = config["data_set"]["filters"]
filter_catalog = config["data_set"]["filter_catalog"]

self._init_from_path(
config["general"]["data_dir"],
transform=transform,
cutout_shape=crop_to if crop_to else None,
filters=filters if filters else None,
filter_catalog=Path(filter_catalog) if filter_catalog else None,
)

def _init_from_path(
@@ -265,6 +272,7 @@ def _init_from_path(
transform=None,
cutout_shape: Optional[tuple[int, int]] = None,
filters: Optional[list[str]] = None,
filter_catalog: Optional[Path] = None,
):
"""__init__ helper. Initialize an HSC data set from a path. This involves several filesystem scan
operations and will ultimately open and read the header info of every fits file in the given directory
@@ -284,12 +292,31 @@ def _init_from_path(
cutouts which do not have fits files corresponding to every filter in the list will be dropped
from the data set. Defaults to None. If not provided, the filters available on the filesystem for
the first object in the directory will be used.
filter_catalog: Path, optional
Path to a .fits file which specifies objects and/or files to use directly, bypassing the default
of attempting to use every file in the path.
Columns for this fits file are object_id (required), filter (optional), filename (optional), and
dim (optional tuple of the x/y pixel size of the image).
- Filenames must be relative to the path provided to this function.
- When filter and filename are both provided, initialization skips a directory listing, which
can provide better performance on large datasets.
- When filter, filename, and dim are all specified, we also skip opening the files to get
the dimensions. This can also provide better performance on large datasets.
"""
self.path = path
self.transform = transform

self.files = self._scan_file_names(filters)
self.dims = self._scan_file_dimensions()
self.filter_catalog = self._read_filter_catalog(filter_catalog)
if isinstance(self.filter_catalog, tuple):
self.files = self.filter_catalog[0]
self.dims = self.filter_catalog[1]
print(self.dims)
elif isinstance(self.filter_catalog, dict):
self.files = self.filter_catalog
self.dims = self._scan_file_dimensions()
else:
self.files = self._scan_file_names(filters)
self.dims = self._scan_file_dimensions()

# If no filters provided, we choose the first file in the dict as the prototypical set of filters
# Any objects lacking this full set of filters will be pruned by _prune_objects
@@ -313,7 +340,7 @@

logger.info(f"HSC Data set loader has {len(self)} objects")

def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dict[str, str]]:
def _scan_file_names(self, filters: Optional[list[str]] = None) -> files_dict:
"""Class initialization helper
Parameters
@@ -335,11 +362,17 @@ def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dic

files = {}
# Go scan the path for object ID's so we have a list.
for filepath in Path(self.path).glob("[0-9]*.fits"):
for filepath in Path(self.path).iterdir():
filename = filepath.name
m = re.match(full_regex, filename)

# Skip files that don't match the pattern.
# If we are filtering based on a user-provided catalog of object_ids, filter out any
# object_ids not in the catalog. Do this before the regex match so irrelevant files are
# discarded quickly.
if isinstance(self.filter_catalog, list) and filename[:17] not in self.filter_catalog:
continue

m = re.match(full_regex, filename)
# Skip files that don't allow us to extract both object_id and filter
if m is None:
continue

@@ -359,7 +392,57 @@ def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dic

return files

def _scan_file_dimensions(self) -> dict[str, tuple[int, int]]:
def _read_filter_catalog(
self, filter_catalog_path: Optional[Path]
) -> Optional[Union[list[str], files_dict, tuple[files_dict, dim_dict]]]:
if filter_catalog_path is None:
return None

if not filter_catalog_path.exists():
logger.error(f"Filter catalog file {filter_catalog_path} given in config does not exist.")
return None

table = Table.read(filter_catalog_path, format="fits")
colnames = table.colnames
if "object_id" not in colnames:
logger.error(f"Filter catalog file {filter_catalog_path} has no column object_id")
return None

# We are dealing with just a list of object_ids
if "filter" not in colnames and "filename" not in colnames:
return list(table["object_id"])

# Or a table that has only one of filter and filename
elif "filter" not in colnames or "filename" not in colnames:
msg = f"Filter catalog file {filter_catalog_path} provides one of filter or filename "
msg += "without the other. A filesystem scan will still occur unless both are defined."
logger.warning(msg)
return list(set(table["object_id"]))

# We have filter and filename defined so we can assemble the catalog at file level.
filter_catalog = {}
if "dim" in colnames:
dim_catalog = {}

for row in table:
object_id = row["object_id"]
filter = row["filter"]
filename = row["filename"]

if object_id not in filter_catalog:
filter_catalog[object_id] = {}

filter_catalog[object_id][filter] = filename

# Dimension is optional
if "dim" in colnames:
if object_id not in dim_catalog:
dim_catalog[object_id] = []
dim_catalog[object_id].append(tuple(row["dim"]))

return (filter_catalog, dim_catalog) if "dim" in colnames else filter_catalog

def _scan_file_dimensions(self) -> dim_dict:
# Scan the filesystem to get the widths and heights of all images into a dict
return {
object_id: [self._fits_file_dims(filepath) for filepath in self._object_files(object_id)]
@@ -445,7 +528,7 @@ def _check_file_dimensions(self) -> tuple[int, int]:
The minimum width and height in pixels of the entire dataset. In other words: the maximal image
size in pixels that can be generated from ALL cutout images via cropping.
"""
# Find the makximal cutout size that all images can support
# Find the maximal cutout size that all images can support
all_widths = [shape[0] for shape_list in self.dims.values() for shape in shape_list]
cutout_width = np.min(all_widths)

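As a sketch of the file this new reader expects, here is one hypothetical way to build a filter catalog with astropy. The object IDs, filter names, and filenames are made up, and the column names mirror what the code above reads (the dimension column is read as "dim"):

```python
from astropy.table import Table

# One row per (object_id, filter) pair. Filenames are relative to [general].data_dir;
# "dim" holds the x/y pixel size of each cutout so the dataset can skip opening the files.
catalog = Table(
    {
        "object_id": ["36411448540270969", "36411448540270969"],
        "filter": ["HSC-G", "HSC-R"],
        "filename": [
            "36411448540270969_HSC-G.fits",
            "36411448540270969_HSC-R.fits",
        ],
        "dim": [(64, 64), (64, 64)],
    }
)
catalog.write("filter_catalog.fits", format="fits", overwrite=True)
```

The resulting file would then be referenced from the filter_catalog option under [data_set] in the config (see the default config change below).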
5 changes: 5 additions & 0 deletions src/fibad/downloadCutout/downloadCutout.py
@@ -21,6 +21,8 @@
from collections.abc import Generator
from typing import IO, Any, Callable, Optional, Union, cast

import numpy as np

__all__ = []


@@ -762,6 +764,9 @@ def parse_bool(s: Union[str, bool]) -> bool:
if isinstance(s, bool):
return s

if isinstance(s, np.bool):
return s

return {
"false": False,
"f": False,
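A small illustration of why the extra isinstance branch is presumably needed: boolean values pulled out of numpy structures (for example a FITS table column) are numpy bool scalars, which are not instances of Python bool, so the existing isinstance(s, bool) check misses them. A minimal sketch:

```python
import numpy as np

flag = np.bool_(True)  # the kind of value a numpy/FITS boolean column hands back
print(isinstance(flag, bool))      # False -- numpy bools do not subclass Python bool
print(isinstance(flag, np.bool_))  # True; np.bool in NumPy >= 2.0 names this same scalar type
```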
10 changes: 9 additions & 1 deletion src/fibad/fibad.py
@@ -14,7 +14,7 @@ class Fibad:
CLI functions in fibad_cli are implemented by calling this class
"""

verbs = ["train", "predict", "download"]
verbs = ["train", "predict", "download", "prepare"]

def __init__(self, *, config_file: Union[Path, str] = None, setup_logging: bool = True):
"""Initialize fibad. Always applies the default config, and merges it with any provided config file.
@@ -177,3 +177,11 @@ def predict(self, **kwargs):
from .predict import run

return run(config=self.config, **kwargs)

def prepare(self, **kwargs):
"""
See fibad.prepare.run()
"""
from .prepare import run

return run(config=self.config, **kwargs)
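A hedged sketch of the notebook usage the commit message describes; the top-level import path and the config filename are assumptions, not taken from this diff:

```python
from fibad import Fibad  # assumed import path

# Hypothetical config whose [data_set] table sets filter_catalog to a fits catalog path.
fibad_instance = Fibad(config_file="my_config.toml")

# With this commit, prepare() simply returns the dataset object for inspection.
data_set = fibad_instance.prepare()
```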
4 changes: 4 additions & 0 deletions src/fibad/fibad_default_config.toml
@@ -90,6 +90,10 @@ crop_to = false
#filters = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"]
filters = false

# A fits file which specifies object IDs used to filter down a large dataset in [general].data_dir.
# Implementation is dataset-class dependent. The default is false, meaning no filtering.
filter_catalog = false

[data_loader]
# Default PyTorch DataLoader parameters
batch_size = 32
21 changes: 21 additions & 0 deletions src/fibad/prepare.py
@@ -0,0 +1,21 @@
import logging

from fibad.pytorch_ignite import setup_model_and_dataset

logger = logging.getLogger(__name__)


def run(config):
"""Prepare the dataset for a given model and data loader.
Parameters
----------
config : dict
The parsed config file as a nested
dict
"""

_, data_set = setup_model_and_dataset(config, split=config["train"]["split"])

logger.info("Finished Prepare")
return data_set
13 changes: 11 additions & 2 deletions tests/fibad/test_hsc_dataset.py
@@ -31,7 +31,7 @@ def __init__(self, test_files: dict):
self.test_files = test_files

mock_paths = [Path(x) for x in list(test_files.keys())]
target = "fibad.data_sets.hsc_data_set.Path.glob"
target = "fibad.data_sets.hsc_data_set.Path.iterdir"
self.patchers.append(mock.patch(target, return_value=mock_paths))

mock_fits_open = mock.Mock(side_effect=self._open_file)
@@ -53,7 +53,15 @@ def __exit__(self, *exc):
patcher.stop()


def mkconfig(crop_to=False, filters=False, train_size=0.2, test_size=0.6, validate_size=0, seed=False):
def mkconfig(
crop_to=False,
filters=False,
train_size=0.2,
test_size=0.6,
validate_size=0,
seed=False,
filter_catalog=False,
):
"""Makes a configuration that points at nonexistent path so HSCDataSet.__init__ will create an object,
and our FakeFitsFS shim can be called.
"""
@@ -62,6 +70,7 @@ def mkconfig(crop_to=False, filters=False, train_size=0.2, test_size=0.6, valida
"data_set": {
"crop_to": crop_to,
"filters": filters,
"filter_catalog": filter_catalog,
},
"prepare": {
"seed": seed,
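Since the commit message notes the feature lacks unit tests, here is a rough sketch of one test that could sit alongside those above. It exercises _read_filter_catalog directly against a temporary FITS file rather than going through mkconfig/FakeFitsFS; the test name, the made-up object IDs, and the trick of passing None for self (which only works because the reader never touches instance state) are illustrative:

```python
from astropy.table import Table

from fibad.data_sets.hsc_data_set import HSCDataSetContainer


def test_read_filter_catalog_object_id_only(tmp_path):
    """A catalog with only an object_id column should come back as a plain list of ids."""
    catalog_path = tmp_path / "filter_catalog.fits"
    Table({"object_id": ["00000000000000001", "00000000000000002"]}).write(catalog_path)

    result = HSCDataSetContainer._read_filter_catalog(None, catalog_path)

    assert list(result) == ["00000000000000001", "00000000000000002"]
```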
