Switching to Dataset based modularity (#76)
- data_loader_registry is now data_set_registry
- Every custom data loader is now a data set (see the contract sketch below)
- The [data_loader] config is now read in only one place (when we make the pytorch ignite data loader)
- The [data_set] config selects which data set you are using
- The train and predict verbs work
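
Not part of this commit, but a minimal sketch of the contract the change implies for a custom data set, assuming only what the registry decorator in this diff checks for. The class RandomDataSet and everything in it is hypothetical:

    import numpy as np

    from fibad.data_sets.data_set_registry import fibad_data_set


    @fibad_data_set
    class RandomDataSet:
        """Hypothetical data set: 100 random 3x32x32 cutouts squashed into [-1, 1]."""

        def __init__(self, config):
            # Data sets are constructed from the full runtime config;
            # this toy example ignores it.
            rng = np.random.default_rng(0)
            self.cutouts = np.tanh(rng.normal(size=(100, 3, 32, 32)).astype("float32"))

        def shape(self):
            # Used by setup_model_and_dataset to size the model.
            return (3, 32, 32)

        def __getitem__(self, idx):
            return self.cutouts[idx]

        def __len__(self):
            return len(self.cutouts)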
mtauraso authored Sep 27, 2024
1 parent 734a4e2 commit 5b601db
Showing 10 changed files with 87 additions and 120 deletions.
5 changes: 0 additions & 5 deletions src/fibad/data_loaders/__init__.py

This file was deleted.

43 changes: 0 additions & 43 deletions src/fibad/data_loaders/example_cifar_data_loader.py

This file was deleted.

5 changes: 5 additions & 0 deletions src/fibad/data_sets/__init__.py
@@ -0,0 +1,5 @@
+from .data_set_registry import DATA_SET_REGISTRY, fibad_data_set
+from .example_cifar_data_set import CifarDataSet
+from .hsc_data_set import HSCDataSet
+
+__all__ = ["fibad_data_set", "DATA_SET_REGISTRY", "CifarDataSet", "HSCDataSet"]
src/fibad/data_loaders/data_loader_registry.py → src/fibad/data_sets/data_set_registry.py
@@ -1,21 +1,30 @@
 import logging
 
 from fibad.plugin_utils import get_or_load_class, update_registry
 
-DATA_LOADER_REGISTRY = {}
+logger = logging.getLogger(__name__)
+
+DATA_SET_REGISTRY = {}
 
 
-def fibad_data_loader(cls):
+def fibad_data_set(cls):
     """Decorator to register a data loader with the registry.
     Returns
     -------
     type
         The original, unmodified class.
     """
-    update_registry(DATA_LOADER_REGISTRY, cls.__name__, cls)
+    required_methods = ["shape", "__getitem__", "__len__"]
+    for name in required_methods:
+        if not hasattr(cls, name):
+            logger.error(f"Fibad data set {cls.__name__} missing required method {name}.")
+
+    update_registry(DATA_SET_REGISTRY, cls.__name__, cls)
     return cls
 
 
-def fetch_data_loader_class(runtime_config: dict) -> type:
+def fetch_data_set_class(runtime_config: dict) -> type:
     """Fetch the data loader class from the registry.
     Parameters
@@ -36,12 +45,12 @@ def fetch_data_loader_class(runtime_config: dict) -> type:
         If no data loader was specified in the runtime configuration.
     """
 
-    data_loader_config = runtime_config["data_loader"]
-    data_loader_cls = None
+    data_set_config = runtime_config["data_set"]
+    data_set_cls = None
 
     try:
-        data_loader_cls = get_or_load_class(data_loader_config, DATA_LOADER_REGISTRY)
+        data_set_cls = get_or_load_class(data_set_config, DATA_SET_REGISTRY)
     except ValueError as exc:
-        raise ValueError("Error fetching data loader class") from exc
+        raise ValueError("Error fetching data set class") from exc
 
-    return data_loader_cls
+    return data_set_cls
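
A sketch of how the new fetch path might be exercised, assuming (as the default TOML comments below suggest) that get_or_load_class resolves the "name" key against the registry. The config dict here is illustrative, not from the commit:

    from fibad.data_sets import CifarDataSet  # importing the package registers the built-ins
    from fibad.data_sets.data_set_registry import fetch_data_set_class

    config = {
        "general": {"data_dir": "./data"},
        "data_set": {"name": "CifarDataSet"},
    }

    data_set_cls = fetch_data_set_class(config)  # looks up DATA_SET_REGISTRY
    data_set = data_set_cls(config)              # CifarDataSet downloads CIFAR-10 here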
21 changes: 21 additions & 0 deletions src/fibad/data_sets/example_cifar_data_set.py
@@ -0,0 +1,21 @@
+# ruff: noqa: D101, D102
+import torchvision.transforms as transforms
+from torchvision.datasets import CIFAR10
+
+from .data_set_registry import fibad_data_set
+
+
+@fibad_data_set
+class CifarDataSet(CIFAR10):
+    """This is simply a version of CIFAR10 that has our needed shape method, and is initialized using
+    FIBAD config with a transformation that works well for example code.
+    """
+
+    def __init__(self, config):
+        transform = transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+        )
+        super().__init__(root=config["general"]["data_dir"], train=True, download=True, transform=transform)
+
+    def shape(self):
+        return (3, 32, 32)
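
Since the class is itself a torchvision dataset, usage is then direct. A quick hedged example (the data_dir path is illustrative, and the first call downloads CIFAR-10):

    from fibad.data_sets import CifarDataSet

    ds = CifarDataSet({"general": {"data_dir": "./data"}})
    print(len(ds))        # 50000 training images
    print(ds.shape())     # (3, 32, 32)
    image, label = ds[0]  # CIFAR10.__getitem__ yields (transformed image, class index)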
src/fibad/data_sets/hsc_data_set.py
@@ -11,71 +11,40 @@
 from torch.utils.data import Dataset
 from torchvision.transforms.v2 import CenterCrop, Compose, Lambda
 
-from .data_loader_registry import fibad_data_loader
+from .data_set_registry import fibad_data_set
 
 logger = logging.getLogger(__name__)
 
 
-@fibad_data_loader
-class HSCDataLoader:
+@fibad_data_set
+class HSCDataSet(Dataset):
     def __init__(self, config):
-        self.config = config
-        self._data_set = self.data_set()
-
-    def get_data_loader(self):
-        """This is the primary method for this class.
-        Returns
-        -------
-        torch.utils.data.DataLoader
-            The dataloader to use for training.
-        """
-        return self.data_loader(self.data_set())
-
-    def data_set(self):
-        # Only construct a data set once per loader object, since it involves a filesystem scan.
-        if self.__dict__.get("_data_set", None) is not None:
-            return self._data_set
-
         # TODO: What will be a reasonable set of tranformations?
         # For now tanh all the values so they end up in [-1,1]
         # Another option might be sinh, but we'd need to mess with the example autoencoder module
         # Because it goes from unbounded NN output space -> [-1,1] with tanh in its decode step.
         transform = Lambda(lambd=np.tanh)
 
-        crop_to = self.config["data_loader"]["crop_to"]
-        filters = self.config["data_loader"]["filters"]
+        crop_to = config["data_loader"]["crop_to"]
+        filters = config["data_loader"]["filters"]
 
-        return HSCDataSet(
-            self.config["general"]["data_dir"],
+        self._init_from_path(
+            config["general"]["data_dir"],
             transform=transform,
             cutout_shape=crop_to if crop_to else None,
            filters=filters if filters else None,
         )
 
-    def data_loader(self, data_set):
-        return torch.utils.data.DataLoader(
-            data_set,
-            batch_size=self.config["data_loader"]["batch_size"],
-            shuffle=self.config["data_loader"]["shuffle"],
-            num_workers=self.config["data_loader"]["num_workers"],
-        )
-
-    def shape(self):
-        return self.data_set().shape()
-
-
-class HSCDataSet(Dataset):
-    def __init__(
+    def _init_from_path(
         self,
         path: Union[Path, str],
         *,
         transform=None,
         cutout_shape: Optional[tuple[int, int]] = None,
         filters: Optional[list[str]] = None,
     ):
-        """Initialize an HSC data set from a path. This involves several filesystem scan operations and will
-        ultimately open and read the header info of every fits file in the given directory
+        """__init__ helper. Initialize an HSC data set from a path. This involves several filesystem scan
+        operations and will ultimately open and read the header info of every fits file in the given directory
         Parameters
         ----------
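The tanh transform in the TODO above is the whole normalization story for now: it squashes arbitrarily large pixel fluxes into [-1, 1]. A quick illustration, not part of the commit:

    import numpy as np

    # tanh is roughly linear near zero but saturates for large |values|,
    # so every pixel lands in [-1, 1] regardless of dynamic range.
    pixels = np.array([-1000.0, -1.0, 0.0, 0.5, 3.0, 1e6])
    print(np.tanh(pixels))  # approximately [-1. -0.762 0. 0.462 0.995 1.]
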
3 changes: 2 additions & 1 deletion src/fibad/fibad_default_config.toml
@@ -61,11 +61,12 @@ epochs = 10
 base_channel_size = 32
 latent_dim =64
 
-[data_loader]
+[data_set]
 # Name of the built-in data loader to use or the libpath to an external data loader
 # e.g. "user_package.submodule.ExternalDataLoader" or "HSCDataLoader"
 name = "HSCDataLoader"
 
+[data_loader]
 # Pixel dimensions used to crop all images prior to loading. Will prune any images that are too small.
 #
 # If not provided by user, the default of 'false' scans the directory for the smallest dimensioned files, and
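After this change the runtime config therefore carries two separate tables. A hypothetical Python view of the relevant slice (key names are taken from the diffs in this commit; the values are purely illustrative):

    runtime_config = {
        "general": {"data_dir": "/path/to/cutouts"},
        # [data_set] selects which registered data set class to construct.
        "data_set": {"name": "HSCDataSet"},
        # [data_loader] holds loader tuning plus the HSC-specific
        # crop_to/filters options read in HSCDataSet.__init__.
        "data_loader": {
            "crop_to": False,
            "filters": False,
            "batch_size": 32,
            "shuffle": True,
            "num_workers": 2,
        },
    }
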
2 changes: 1 addition & 1 deletion src/fibad/models/example_cnn_classifier.py
@@ -16,7 +16,7 @@

 @fibad_model
 class ExampleCNN(nn.Module):
-    def __init__(self, config, shape):
+    def __init__(self, config, _):
         super().__init__()
         self.conv1 = nn.Conv2d(3, 6, 5)
         self.pool = nn.MaxPool2d(2, 2)
13 changes: 5 additions & 8 deletions src/fibad/pytorch_ignite.py
@@ -9,15 +9,15 @@
 from torch.utils.data import Dataset
 
 from fibad.config_utils import ConfigDict
-from fibad.data_loaders.data_loader_registry import fetch_data_loader_class
+from fibad.data_sets.data_set_registry import fetch_data_set_class
 from fibad.models.model_registry import fetch_model_class
 
 logger = logging.getLogger(__name__)
 
 
 def setup_model_and_dataset(config: ConfigDict) -> tuple:
     """
-    Construct the data loader and the model according to configuration.
+    Construct the dataset and the model according to configuration.
     Primarily exists so the train and predict actions do this the same way.
@@ -32,15 +32,12 @@ def setup_model_and_dataset(config: ConfigDict) -> tuple:
         (model object, data loader object)
     """
     # Fetch data loader class specified in config and create an instance of it
-    data_loader_cls = fetch_data_loader_class(config)
-    data_loader = data_loader_cls(config)
+    data_set_cls = fetch_data_set_class(config)
+    data_set = data_set_cls(config)
 
     # Fetch model class specified in config and create an instance of it
     model_cls = fetch_model_class(config)
-    model = model_cls(config=config, shape=data_loader.shape())
-
-    # Get the pytorch.dataset from dataloader, and use it to create a distributed dataloader
-    data_set = data_loader.data_set()
+    model = model_cls(config=config, shape=data_set.shape())
 
     return model, data_set

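Per the commit message, the [data_loader] table is now consumed in exactly one place, when the pytorch ignite data loader is built around the returned data set. A sketch of that pattern, reusing the kwargs from the deleted HSCDataLoader.data_loader method; the stand-in data set and config values here are illustrative, not the actual ignite wiring:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Stand-in for whichever fibad data set the [data_set] table selected.
    data_set = TensorDataset(torch.randn(100, 3, 32, 32))

    config = {"data_loader": {"batch_size": 32, "shuffle": True, "num_workers": 2}}

    data_loader = DataLoader(
        data_set,
        batch_size=config["data_loader"]["batch_size"],
        shuffle=config["data_loader"]["shuffle"],
        num_workers=config["data_loader"]["num_workers"],
    )

    for (batch,) in data_loader:
        print(batch.shape)  # torch.Size([32, 3, 32, 32]); the last batch may be smaller
        break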