Centralize config defaults, add safety for Config updates #68

Merged
merged 2 commits into from
Sep 20, 2024
Changes from 1 commit
177 changes: 161 additions & 16 deletions src/fibad/config_utils.py
@@ -1,3 +1,4 @@
import logging
from pathlib import Path
from typing import Union

@@ -6,6 +7,160 @@
DEFAULT_CONFIG_FILEPATH = Path(__file__).parent.resolve() / "fibad_default_config.toml"
DEFAULT_USER_CONFIG_FILEPATH = Path.cwd() / "fibad_config.toml"

logger = logging.getLogger(__name__)


class ConfigDict(dict):
    """The purpose of this class is to ensure key errors on config dictionaries return something helpful,
    and to discourage mutation actions on config dictionaries that should not happen at runtime.
    """

    # TODO: Should there be some sort of "bake" method which occurs after config processing, and
    # percolates down to nested ConfigDicts and prevents __setitem__ and other mutations of dictionary
    # values? i.e. a method to make a config dictionary fully immutable (or very difficult/annoying to
    # mutate) before we pass control to possibly external module code that is relying on the dictionary
    # to be static throughout the run.
Collaborator

This comment seems like a really good idea. It will be interesting to see, once we have actual users interacting with Fibad, whether there is a desire to modify the merged config prior to running train or predict. But I agree that during runtime the config should be immutable.

Collaborator Author

Agree, I do feel like it's complicated enough that we should hold off for now.
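
For reference, a minimal sketch of what the "bake" step discussed in this thread might look like. It is not part of this PR: `BakeableConfigDict`, `bake()`, and `_refuse_if_baked()` are hypothetical names, and the class is a simplified stand-in for `ConfigDict` rather than the real one.

```python
class BakeableConfigDict(dict):
    """Hypothetical stand-in for ConfigDict that can be frozen after config processing."""

    # A slot holds the frozen flag, so instances still carry no __dict__.
    __slots__ = ("_baked",)

    def __init__(self, mapping: dict, **kwargs):
        super().__init__(mapping, **kwargs)
        self._baked = False
        # Wrap nested dicts so bake() can percolate down to every section.
        for key, value in list(self.items()):
            if isinstance(value, dict) and not isinstance(value, BakeableConfigDict):
                self[key] = BakeableConfigDict(value)

    def bake(self) -> None:
        """Freeze this dictionary and every nested section."""
        for value in self.values():
            if isinstance(value, BakeableConfigDict):
                value.bake()
        self._baked = True

    def _refuse_if_baked(self, action: str) -> None:
        if self._baked:
            raise RuntimeError(f"Cannot {action} a config dictionary after bake() has been called")

    def __setitem__(self, key, value):
        self._refuse_if_baked("assign to")
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        self._refuse_if_baked("update")
        super().update(*args, **kwargs)

    def setdefault(self, key, default=None):
        self._refuse_if_baked("call setdefault() on")
        return super().setdefault(key, default)


config = BakeableConfigDict({"general": {"data_dir": "./data"}})
config["general"]["data_dir"] = "/tmp/data"  # allowed while the config is still being assembled
config.bake()

try:
    config["general"]["data_dir"] = "/elsewhere"
except RuntimeError as err:
    print(f"Rejected: {err}")
```

This only guards the mutators shown. Other paths (for example `|=`, or mutating a list stored as a value) would need their own handling, which is part of why deferring the feature seems reasonable.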


    __slots__ = ()  # we don't need __dict__ on this object at all.

    def __init__(self, map: dict, **kwargs):
        super().__init__(map, **kwargs)

        # Recursively replace nested dict values with ConfigDict instances.
        for key in self:
            if isinstance(self[key], dict) and not isinstance(self[key], ConfigDict):
                self[key] = ConfigDict(map=self[key])

    def __missing__(self, key):
        msg = f"Accessed configuration key/section {key} which has not been defined. "
        msg += f"All configuration keys and sections must be defined in {DEFAULT_CONFIG_FILEPATH}"
        logger.fatal(msg)
        raise RuntimeError(msg)

Codecov / codecov/patch: added lines #L35-L38 of src/fibad/config_utils.py were not covered by tests.

    def get(self, key, default=None):
        """Nonfunctional stub of dict.get() which errors always"""
        msg = f"ConfigDict.get({key},{default}) called. "
        msg += "Please index config dictionaries with [] or __getitem__() only. "
        msg += f"Configuration keys and sections must be defined in {DEFAULT_CONFIG_FILEPATH}"
        logger.fatal(msg)
        raise RuntimeError(msg)

Codecov / codecov/patch: added lines #L42-L46 of src/fibad/config_utils.py were not covered by tests.

    def __delitem__(self, key):
        raise RuntimeError("Removing keys or sections from a ConfigDict using del is not supported")

Codecov / codecov/patch: added line #L49 of src/fibad/config_utils.py was not covered by tests.

    def pop(self, key, default):
        """Nonfunctional stub of dict.pop() which errors always"""
        raise RuntimeError("Removing keys or sections from a ConfigDict using pop() is not supported")

Codecov / codecov/patch: added line #L53 of src/fibad/config_utils.py was not covered by tests.

    def popitem(self):
        """Nonfunctional stub of dict.popitem() which errors always"""
        raise RuntimeError("Removing keys or sections from a ConfigDict using popitem() is not supported")

Codecov / codecov/patch: added line #L57 of src/fibad/config_utils.py was not covered by tests.

    def clear(self):
        """Nonfunctional stub of dict.clear() which errors always"""
        raise RuntimeError("Removing keys or sections from a ConfigDict using clear() is not supported")

Codecov / codecov/patch: added line #L61 of src/fibad/config_utils.py was not covered by tests.
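
Since Codecov flags the error paths of `ConfigDict` as uncovered, here is one way they might be exercised. This is a sketch, not the project's test suite: the class below is a condensed copy of the one in this diff (logging and some stubs omitted), and `raises_runtime_error` is a hypothetical helper standing in for something like `pytest.raises`.

```python
class ConfigDict(dict):
    """Condensed copy of the ConfigDict in this diff, for illustration only."""

    __slots__ = ()

    def __init__(self, map: dict, **kwargs):
        super().__init__(map, **kwargs)
        # Wrap nested dicts so every level gets the same strict behavior.
        for key in self:
            if isinstance(self[key], dict) and not isinstance(self[key], ConfigDict):
                self[key] = ConfigDict(map=self[key])

    def __missing__(self, key):
        raise RuntimeError(f"Accessed configuration key/section {key} which has not been defined.")

    def get(self, key, default=None):
        raise RuntimeError(f"ConfigDict.get({key},{default}) called.")

    def __delitem__(self, key):
        raise RuntimeError("Removing keys or sections from a ConfigDict using del is not supported")

    def pop(self, key, default):
        raise RuntimeError("Removing keys or sections from a ConfigDict using pop() is not supported")


def raises_runtime_error(action) -> bool:
    """Hypothetical helper: True if calling `action` raises RuntimeError."""
    try:
        action()
    except RuntimeError:
        return True
    return False


config = ConfigDict({"data_loader": {"batch_size": 4}})

# Nested sections are wrapped, and normal indexing still works.
assert isinstance(config["data_loader"], ConfigDict)
assert config["data_loader"]["batch_size"] == 4

# Each discouraged access pattern fails loudly instead of returning a silent default.
assert raises_runtime_error(lambda: config["no_such_section"])
assert raises_runtime_error(lambda: config.get("data_loader", {}))
assert raises_runtime_error(lambda: config.__delitem__("data_loader"))
assert raises_runtime_error(lambda: config["data_loader"].pop("batch_size", None))
```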


def validate_runtime_config(runtime_config: ConfigDict):
    """Validates that defaults exist for every config value before we begin to use a config.

    This should be called at the moment the runtime config is fully baked for science calculations, meaning
    that all sources of config info have been combined in `runtime_config` and there are no further
    config altering operations that will be performed.

    Parameters
    ----------
    runtime_config : ConfigDict
        The current runtime config dictionary.

    Raises
    ------
    RuntimeError
        Raised if any config that exists in the runtime config does not have a default defined
    """
    default_config = _read_runtime_config(DEFAULT_CONFIG_FILEPATH)
    _validate_runtime_config(runtime_config, default_config)

Codecov / codecov/patch: added lines #L81-L82 of src/fibad/config_utils.py were not covered by tests.


def _validate_runtime_config(runtime_config: ConfigDict, default_config: ConfigDict):
    """Recursive helper for validate_runtime_config.

    The two arguments passed in must represent the same nesting level of the runtime config and all
    default config parameters respectively.

    Parameters
    ----------
    runtime_config : ConfigDict
        Nested config dictionary representing the runtime config.
    default_config : ConfigDict
        Nested config dictionary representing the defaults

    Raises
    ------
    RuntimeError
        Raised if any config that exists in the runtime config does not have a default defined in
        default_config
    """
    for key in runtime_config:
        if key not in default_config:
            msg = f"Runtime config contains key or section {key} which has no default defined. "
            msg += f"All configuration keys and sections must be defined in {DEFAULT_CONFIG_FILEPATH}"
            raise RuntimeError(msg)

Codecov / codecov/patch: added lines #L104-L108 of src/fibad/config_utils.py were not covered by tests.

        if isinstance(runtime_config[key], dict):
Collaborator

Do you think it would be worth validating that if runtime_config[key] is a dict, that default_config[key] is as well?

Collaborator Author

Yeah this is reasonable paranoia.
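
A rough sketch of the extra check suggested in this thread, written against plain dicts so it runs on its own. The function name `validate_against_defaults`, the `path` parameter, and the error wording are assumptions for illustration, not the code that was ultimately merged.

```python
def validate_against_defaults(runtime_config: dict, default_config: dict, path: str = "") -> None:
    """Check that every runtime key has a default, and that sections line up with sections."""
    for key, value in runtime_config.items():
        full_key = f"{path}.{key}" if path else str(key)

        if key not in default_config:
            raise RuntimeError(f"Runtime config contains key or section {full_key} which has no default defined.")

        default_value = default_config[key]

        # The suggested paranoia: a section (dict) in one config must also be a
        # section in the other, otherwise the recursion below would misbehave.
        if isinstance(value, dict) != isinstance(default_value, dict):
            raise RuntimeError(f"Runtime config {full_key} and its default disagree on whether it is a section.")

        if isinstance(value, dict):
            validate_against_defaults(value, default_value, full_key)


defaults = {"general": {"data_dir": "./data"}, "data_loader": {"batch_size": 4}}

# A runtime config that only overrides known keys passes silently.
validate_against_defaults({"data_loader": {"batch_size": 32}}, defaults)

# A section supplied where the default is a scalar is rejected.
try:
    validate_against_defaults({"data_loader": {"batch_size": {"train": 32}}}, defaults)
except RuntimeError as err:
    print(f"Rejected: {err}")
```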

            _validate_runtime_config(runtime_config[key], default_config[key])

Codecov / codecov/patch: added lines #L110-L111 of src/fibad/config_utils.py were not covered by tests.


def _read_runtime_config(config_filepath: Union[Path, str] = DEFAULT_CONFIG_FILEPATH) -> ConfigDict:
    """Read a single toml file and return a config dictionary

    Parameters
    ----------
    config_filepath : Union[Path, str], optional
        What file is to be read, by default DEFAULT_CONFIG_FILEPATH

    Returns
    -------
    ConfigDict
        The contents of that toml file as nested ConfigDicts
    """
    with open(config_filepath, "r") as f:
        parsed_dict = toml.load(f)
        return ConfigDict(parsed_dict)


def resolve_runtime_config(runtime_config_filepath: Union[Path, str, None] = None) -> Path:
    """Resolve a user-supplied runtime config to where we will actually pull config from.

    1) If a runtime config file is specified, we will use that file
    2) If no file is specified and there is a file named "fibad_config.toml" in the cwd we will use that file
    3) If no file is specified and there is no file named "fibad_config.toml" in the current working directory
       we will exclusively work off the configuration defaults in the packaged "fibad_default_config.toml"
       file.

    Parameters
    ----------
    runtime_config_filepath : Union[Path, str, None], optional
        Location of the supplied config file, by default None

    Returns
    -------
    Path
        Path to the configuration file ultimately used for config resolution. When we fall back to the
        package supplied default config file, the Path to that file is returned.
    """
    if isinstance(runtime_config_filepath, str):
        runtime_config_filepath = Path(runtime_config_filepath)

    # If a named config exists in cwd, and no config specified on cmdline, use cwd.
    if runtime_config_filepath is None and DEFAULT_USER_CONFIG_FILEPATH.exists():
        runtime_config_filepath = DEFAULT_USER_CONFIG_FILEPATH

Codecov / codecov/patch: added line #L157 of src/fibad/config_utils.py was not covered by tests.

    if runtime_config_filepath is None:
        runtime_config_filepath = DEFAULT_CONFIG_FILEPATH

Codecov / codecov/patch: added line #L160 of src/fibad/config_utils.py was not covered by tests.

    return runtime_config_filepath
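
The three-step precedence in the `resolve_runtime_config` docstring can be checked in isolation with a small sketch like the following. `resolve_config_path` is a hypothetical variant that takes the two fallback locations as parameters (instead of the module-level constants) so that it can run against a temporary directory.

```python
import tempfile
from pathlib import Path
from typing import Optional, Union


def resolve_config_path(
    requested: Optional[Union[Path, str]], cwd_config: Path, packaged_default: Path
) -> Path:
    """Hypothetical, parameterized version of the precedence rules in the docstring above."""
    if requested is not None:
        return Path(requested)  # 1) an explicitly specified file always wins
    if cwd_config.exists():
        return cwd_config  # 2) otherwise a config file found in the working directory
    return packaged_default  # 3) otherwise the packaged defaults


with tempfile.TemporaryDirectory() as tmp:
    cwd_config = Path(tmp) / "fibad_config.toml"
    packaged_default = Path(tmp) / "fibad_default_config.toml"

    # No explicit file and nothing in the "cwd": fall back to packaged defaults.
    fallback = resolve_config_path(None, cwd_config, packaged_default)

    # Once a config exists in the "cwd", it is picked up automatically.
    cwd_config.write_text("")
    from_cwd = resolve_config_path(None, cwd_config, packaged_default)

    # An explicit path takes priority over both.
    explicit = resolve_config_path("custom.toml", cwd_config, packaged_default)

print(fallback.name, from_cwd.name, explicit.name)
```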


def get_runtime_config(
    runtime_config_filepath: Union[Path, str, None] = None,
@@ -33,24 +188,14 @@
     The parsed runtime configuration.
     """

-    if isinstance(runtime_config_filepath, str):
-        runtime_config_filepath = Path(runtime_config_filepath)
-
-    with open(default_config_filepath, "r") as f:
-        default_runtime_config = toml.load(f)
-
-    # If a named config exists in cwd, and no config specified on cmdline, use cwd.
-    if runtime_config_filepath is None and DEFAULT_USER_CONFIG_FILEPATH.exists():
-        runtime_config_filepath = DEFAULT_USER_CONFIG_FILEPATH
+    runtime_config_filepath = resolve_runtime_config(runtime_config_filepath)
+    default_runtime_config = _read_runtime_config(default_config_filepath)

-    if runtime_config_filepath is not None:
+    if runtime_config_filepath is not DEFAULT_CONFIG_FILEPATH:
         if not runtime_config_filepath.exists():
             raise FileNotFoundError(f"Runtime configuration file not found: {runtime_config_filepath}")
-
-        with open(runtime_config_filepath, "r") as f:
-            users_runtime_config = toml.load(f)
-
-        final_runtime_config = merge_configs(default_runtime_config, users_runtime_config)
+        users_runtime_config = _read_runtime_config(runtime_config_filepath)
+        final_runtime_config = merge_configs(default_runtime_config, users_runtime_config)
     else:
         final_runtime_config = default_runtime_config

@@ -80,7 +225,7 @@
     final_config = default_config.copy()
     for k, v in user_config.items():
         if k in final_config and isinstance(final_config[k], dict) and isinstance(v, dict):
-            final_config[k] = merge_configs(default_config.get(k, {}), v)
+            final_config[k] = merge_configs(default_config[k], v)
         else:
             final_config[k] = v
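
To illustrate why plain indexing is safe in the changed line: the `k in final_config` guard already guarantees the key exists in the defaults, so the `.get(k, {})` fallback was never reached. Below is a self-contained sketch of the merge on plain dicts. The loop body is taken from the lines visible in this hunk; the signature and the final `return` are assumptions, since the rest of the function is not shown here.

```python
def merge_configs(default_config: dict, user_config: dict) -> dict:
    """Recursively overlay user_config on top of default_config."""
    final_config = default_config.copy()
    for k, v in user_config.items():
        if k in final_config and isinstance(final_config[k], dict) and isinstance(v, dict):
            # k is known to be in the defaults here, so indexing cannot miss.
            final_config[k] = merge_configs(default_config[k], v)
        else:
            final_config[k] = v
    return final_config  # assumed: the end of the function is outside this hunk


defaults = {
    "general": {"data_dir": "./data"},
    "data_loader": {"batch_size": 4, "shuffle": True},
}
user = {"data_loader": {"batch_size": 32}}

merged = merge_configs(defaults, user)
print(merged)
```

One detail that may be worth checking against the full function: `dict.copy()` returns a plain `dict` even when called on a subclass, so a result built this way would not be a `ConfigDict` unless it is wrapped again afterwards.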

2 changes: 1 addition & 1 deletion src/fibad/data_loaders/data_loader_registry.py
@@ -36,7 +36,7 @@
     If no data loader was specified in the runtime configuration.
     """

-    data_loader_config = runtime_config.get("data_loader", {})
+    data_loader_config = runtime_config["data_loader"]

Codecov / codecov/patch: added line #L39 of src/fibad/data_loaders/data_loader_registry.py was not covered by tests.
     data_loader_cls = None

     try:
12 changes: 6 additions & 6 deletions src/fibad/data_loaders/example_cifar_data_loader.py
@@ -9,8 +9,8 @@

 @fibad_data_loader
 class CifarDataLoader:
-    def __init__(self, data_loader_config):
-        self.config = data_loader_config
+    def __init__(self, config):
+        self.config = config

Codecov / codecov/patch: added line #L13 of src/fibad/data_loaders/example_cifar_data_loader.py was not covered by tests.

     def shape(self):
         return (3, 32, 32)
@@ -31,13 +31,13 @@
         )

         return torchvision.datasets.CIFAR10(
-            root=self.config.get("path", "./data"), train=True, download=True, transform=transform
+            root=self.config["general"]["data_dir"], train=True, download=True, transform=transform
         )

     def data_loader(self, data_set):
         return torch.utils.data.DataLoader(
             data_set,
-            batch_size=self.config.get("batch_size", 4),
-            shuffle=self.config.get("shuffle", True),
-            num_workers=self.config.get("num_workers", 2),
+            batch_size=self.config["data_loader"]["batch_size"],
+            shuffle=self.config["data_loader"]["shuffle"],
+            num_workers=self.config["data_loader"]["num_workers"],
         )
21 changes: 11 additions & 10 deletions src/fibad/data_loaders/hsc_data_loader.py
@@ -18,8 +18,8 @@

 @fibad_data_loader
 class HSCDataLoader:
-    def __init__(self, data_loader_config):
-        self.config = data_loader_config
+    def __init__(self, config):
+        self.config = config

Codecov / codecov/patch: added line #L22 of src/fibad/data_loaders/hsc_data_loader.py was not covered by tests.
         self._data_set = self.data_set()

     def get_data_loader(self):
@@ -37,27 +37,28 @@
         if self.__dict__.get("_data_set", None) is not None:
             return self._data_set

-        self.config.get("path", "./data")
-
         # TODO: What will be a reasonable set of transformations?
         # For now tanh all the values so they end up in [-1,1]
         # Another option might be sinh, but we'd need to mess with the example autoencoder module
         # Because it goes from unbounded NN output space -> [-1,1] with tanh in its decode step.
         transform = Lambda(lambd=np.tanh)

+        crop_to = self.config["data_loader"]["crop_to"]
+        filters = self.config["data_loader"]["filters"]

Codecov / codecov/patch: added lines #L46-L47 of src/fibad/data_loaders/hsc_data_loader.py were not covered by tests.

         return HSCDataSet(
-            self.config.get("path", "./data"),
+            self.config["general"]["data_dir"],
             transform=transform,
-            cutout_shape=self.config.get("crop_to", None),
-            filters=self.config.get("filters", None),
+            cutout_shape=crop_to if crop_to else None,
+            filters=filters if filters else None,
         )

     def data_loader(self, data_set):
         return torch.utils.data.DataLoader(
             data_set,
-            batch_size=self.config.get("batch_size", 4),
-            shuffle=self.config.get("shuffle", True),
-            num_workers=self.config.get("num_workers", 2),
+            batch_size=self.config["data_loader"]["batch_size"],
+            shuffle=self.config["data_loader"]["shuffle"],
+            num_workers=self.config["data_loader"]["num_workers"],
         )

     def shape(self):