Skip to content

Commit

Permalink
Breaking setup_model_and_dataset into two functions. (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag authored Dec 13, 2024
1 parent 4d60235 commit 3ac3130
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 21 deletions.
10 changes: 8 additions & 2 deletions src/fibad/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from torch import Tensor

from fibad.config_utils import ConfigDict, create_results_dir, log_runtime_config
from fibad.pytorch_ignite import create_evaluator, dist_data_loader, setup_model_and_dataset
from fibad.pytorch_ignite import (
create_evaluator,
dist_data_loader,
setup_dataset,
setup_model,
)

logger = logging.getLogger(__name__)

Expand All @@ -19,7 +24,8 @@ def run(config: ConfigDict):
The parsed config file as a nested dict
"""

model, data_set = setup_model_and_dataset(config, split=config["predict"]["split"])
data_set = setup_dataset(config, split=config["predict"]["split"])
model = setup_model(config, data_set)
logger.info(f"data set has length {len(data_set)}")
data_loader = dist_data_loader(data_set, config)

Expand Down
4 changes: 2 additions & 2 deletions src/fibad/prepare.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from fibad.pytorch_ignite import setup_model_and_dataset
from fibad.pytorch_ignite import setup_dataset

logger = logging.getLogger(__name__)

Expand All @@ -15,7 +15,7 @@ def run(config):
dict
"""

_, data_set = setup_model_and_dataset(config, split=config["train"]["split"])
data_set = setup_dataset(config, split=config["train"]["split"])

logger.info("Finished Prepare")
return data_set
44 changes: 31 additions & 13 deletions src/fibad/pytorch_ignite.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import logging
from pathlib import Path
from typing import Any, Callable
from typing import Any, Callable, Union

import ignite.distributed as idist
import torch
Expand All @@ -18,33 +18,51 @@
logger = logging.getLogger(__name__)


def setup_model_and_dataset(config: ConfigDict, split: str) -> tuple:
"""
Construct the dataset and the model according to configuration.
Primarily exists so the train and predict actions do this the same way.
def setup_dataset(config: ConfigDict, split: Union[str, bool] = False) -> Dataset:
"""Create a dataset object based on the configuration.
Parameters
----------
config : ConfigDict
The entire runtime config
split : str
The name of the split we want to use from the data set.
The entire runtime configuration
split : Union[str,bool], optional
The name of the split that we want to use. If False, use the entire
dataset, by default False
Returns
-------
tuple
(model object, data loader object)
Dataset
An instance of the dataset class specified in the configuration
"""

# Fetch data loader class specified in config and create an instance of it
data_set_cls = fetch_data_set_class(config)
data_set = data_set_cls(config, split)

return data_set


def setup_model(config: ConfigDict, dataset: Dataset) -> torch.nn.Module:
"""Create a model object based on the configuration.
Parameters
----------
config : ConfigDict
The entire runtime configuration
dataset : Dataset
Only used to determine the input shape of the data
Returns
-------
torch.nn.Module
An instance of the model class specified in the configuration
"""

# Fetch model class specified in config and create an instance of it
model_cls = fetch_model_class(config)
model = model_cls(config=config, shape=data_set.shape())
model = model_cls(config=config, shape=dataset.shape())

return model, data_set
return model


def dist_data_loader(data_set: Dataset, config: ConfigDict, split: str):
Expand Down
4 changes: 2 additions & 2 deletions src/fibad/rebuild_manifest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from fibad.pytorch_ignite import setup_model_and_dataset
from fibad.pytorch_ignite import setup_dataset

logger = logging.getLogger(__name__)

Expand All @@ -17,7 +17,7 @@ def run(config):

config["rebuild_manifest"] = True

_, data_set = setup_model_and_dataset(config, split=config["train"]["split"])
data_set = setup_dataset(config, split=config["train"]["split"])

logger.info("Starting rebuild of manifest")

Expand Down
11 changes: 9 additions & 2 deletions src/fibad/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

from fibad.config_utils import create_results_dir, log_runtime_config
from fibad.gpu_monitor import GpuMonitor
from fibad.pytorch_ignite import create_trainer, create_validator, dist_data_loader, setup_model_and_dataset
from fibad.pytorch_ignite import (
create_trainer,
create_validator,
dist_data_loader,
setup_dataset,
setup_model,
)

logger = logging.getLogger(__name__)

Expand All @@ -26,7 +32,8 @@ def run(config):
tensorboardx_logger = SummaryWriter(log_dir=results_dir)

# Instantiate the model and dataset
model, data_set = setup_model_and_dataset(config, split=config["train"]["split"])
data_set = setup_dataset(config, split=config["train"]["split"])
model = setup_model(config, data_set)

# Create a data loader for the training set
train_data_loader = dist_data_loader(data_set, config, "train")
Expand Down

0 comments on commit 3ac3130

Please sign in to comment.