Skip to content

Commit

Permalink
Add a new hydra util
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Sep 26, 2023
1 parent 16bb2ab commit 6c7229f
Show file tree
Hide file tree
Showing 2 changed files with 311 additions and 2 deletions.
106 changes: 104 additions & 2 deletions ranzen/hydra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
from __future__ import annotations
from collections.abc import MutableMapping
from contextlib import contextmanager
from dataclasses import asdict
import dataclasses
from dataclasses import MISSING, Field, asdict, is_dataclass
from enum import Enum
import shlex
from typing import Any, Iterator, Sequence
from typing import Any, Dict, Final, Iterator, Sequence, Union, cast
from typing_extensions import deprecated

import attrs
from attrs import NOTHING, Attribute
from hydra.core.config_store import ConfigStore
from hydra.core.hydra_config import HydraConfig
from hydra.utils import instantiate
Expand All @@ -20,8 +24,13 @@
"as_pretty_dict",
"reconstruct_cmd",
"recursively_instantiate",
"prepare_for_logging",
"register_hydra_config",
]

# Shared fragments of the `ValueError` messages built in `register_hydra_config`;
# they are combined as f"{IF} <kind of type>, {NEED} <requirement>: `<field name>`".
NEED: Final = "there should be"
IF: Final = "if an entry has"


def _clean_up_dict(obj: Any) -> Any:
"""Convert enums to strings and filter out _target_."""
Expand All @@ -47,6 +56,7 @@ def reconstruct_cmd() -> str:
return shlex.join([program] + OmegaConf.to_container(args)) # type: ignore[operator]


@deprecated("Use _recursive_=True instead.")
def recursively_instantiate(
hydra_config: DictConfig, *, keys_to_exclude: Sequence[str] = ()
) -> dict[str, Any]:
Expand Down Expand Up @@ -101,3 +111,95 @@ def __init__(self, cs: ConfigStore, *, group_name: str, package: str):
def add_option(self, config_class: type, *, name: str) -> None:
    """Store ``config_class`` in the config store as the option ``name`` of this group.

    :param config_class: The schema to register as the node.
    :param name: The name under which the option is stored.
    """
    self._cs.store(
        node=config_class,
        name=name,
        group=self._group_name,
        package=self._package,
    )


def prepare_for_logging(hydra_config: DictConfig, *, enum_to_str: bool = True) -> dict[str, Any]:
    """Turn a hydra config into a plain dictionary that is nicer to log.

    The config is converted to a primitive container with references resolved. Every top-level
    entry that is itself a ``DictConfig`` gets the name of its type appended to its key,
    separated by a slash (e.g. ``"model/Model"``); all other keys are kept unchanged.

    :param hydra_config: The config to convert.
    :param enum_to_str: Forwarded to ``OmegaConf.to_container``; controls whether enums are
        turned into strings.
    :returns: A dictionary with the (possibly re-labelled) top-level keys of the config.
    """
    container = OmegaConf.to_container(
        hydra_config, throw_on_missing=True, enum_to_str=enum_to_str, resolve=True
    )
    assert isinstance(container, dict)
    container = cast(Dict[str, Any], container)

    prepared = {}
    for key, value in container.items():
        # Look at the original node (not the converted value) to find out its type.
        node = hydra_config[key]
        if isinstance(node, DictConfig):
            label = f"{key}/{OmegaConf.get_type(node).__name__}"  # type: ignore
        else:
            label = key
        prepared[label] = value
    return prepared


def _is_any_annotation(annotation: Any) -> bool:
    """Tell whether a field annotation denotes ``typing.Any``."""
    if isinstance(annotation, str):
        # With postponed evaluation of annotations the type is stored as a string; accept both
        # the bare and the module-qualified spelling.
        return annotation in ("Any", "typing.Any")
    return annotation == Any


def _has_default(config: Any, absent: Any) -> bool:
    """Tell whether a dataclass/attrs field has a default value or a default factory."""
    if config.default is not absent:
        return True
    # Only dataclass fields keep the factory in a separate attribute.
    return isinstance(config, Field) and config.default_factory is not absent


def register_hydra_config(main_cls: type, groups: dict[str, dict[str, type]]) -> None:
    """Check the given config and store everything in the ConfigStore.

    This function performs two tasks: 1) make the necessary calls to `ConfigStore`
    and 2) run some checks over the given config and if there are problems, try to give a nice
    error message.

    The checks applied to every field of ``main_cls`` are:

    * a field of type ``Any`` must have variants in ``groups`` and must not have a default;
    * a field of any other type must not have variants in ``groups`` and must have a default.

    All checks run before anything is stored.

    :param main_cls: The main config class; can be dataclass or attrs.
    :param groups: A dictionary that defines all the variants. The keys of the top level of the
        dictionary should correspond to the group names, and the keys in the nested dictionaries
        should correspond to the names of the options.
    :raises ValueError: If ``main_cls`` is neither a dataclass nor an attrs class, or if one of
        its fields violates the rules above.
    :raises RuntimeError: If storing one of the variants in the ``ConfigStore`` fails.

    :example:
        .. code-block:: python

            @dataclass
            class DataModule:
                root: Path = Path()

            @dataclass
            class LinearModel:
                dim: int = 256

            @dataclass
            class CNNModel:
                kernel: int = 3

            @dataclass
            class Config:
                # fields without a default have to come first in a dataclass
                model: Any
                dm: DataModule = dataclasses.field(default_factory=DataModule)

            groups = {"model": {"linear": LinearModel, "cnn": CNNModel}}
            register_hydra_config(Config, groups)
    """
    configs: Union[tuple[Attribute, ...], tuple[Field, ...]]
    is_dc = is_dataclass(main_cls)
    if is_dc:
        configs = dataclasses.fields(main_cls)
    elif attrs.has(main_cls):
        configs = attrs.fields(main_cls)
    else:
        raise ValueError("The given class is neither a dataclass nor an attrs class.")
    # Each library has its own sentinel for "no default given".
    absent = MISSING if is_dc else NOTHING

    for config in configs:
        has_variants = config.name in groups
        has_default = _has_default(config, absent)
        if _is_any_annotation(config.type):
            if not has_variants:
                raise ValueError(f"{IF} type Any, {NEED} variants: `{config.name}`")
            if has_default:
                raise ValueError(f"{IF} type Any, {NEED} no default value: `{config.name}`")
        else:
            if has_variants:
                raise ValueError(f"{IF} a real type, {NEED} no variants: `{config.name}`")
            if not has_default:
                raise ValueError(f"{IF} a real type, {NEED} a default value: `{config.name}`")

    cs = ConfigStore.instance()
    cs.store(node=main_cls, name="config_schema")
    for group, entries in groups.items():
        for name, node in entries.items():
            try:
                cs.store(node=node, name=name, group=group)
            except Exception as exc:
                # Re-raise with enough context to identify the offending entry.
                raise RuntimeError(f"{main_cls=}, {node=}, {name=}, {group=}") from exc
207 changes: 207 additions & 0 deletions tests/hydra_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import dataclasses
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any

import attrs
from attrs import define
from omegaconf import MISSING, DictConfig, MissingMandatoryValue, OmegaConf
import pytest

from ranzen.hydra import register_hydra_config, prepare_for_logging


def test_dataclass_no_default() -> None:
    """This isn't so much wrong as just clumsy."""

    @dataclass
    class DataModule:
        root: Path

    @dataclass
    class Config:
        dm: DataModule

    # A field with a real type but without a default is rejected, with or without variants.
    for groups in ({}, {"dm": {"base": DataModule}}):
        with pytest.raises(ValueError):
            register_hydra_config(Config, groups)


def test_dataclass_any() -> None:
    """A dataclass field of type ``Any`` is only accepted together with variants."""

    @dataclass
    class DataModule:
        root: Path

    @dataclass
    class Config:
        dm: Any

    # we're assuming that the only reason you want to use Any is that
    # you want to use variants
    with pytest.raises(ValueError):
        register_hydra_config(Config, {})

    register_hydra_config(Config, {"dm": {"base": DataModule}})


def test_dataclass_any_with_default() -> None:
    """An Any field with default is completely out."""

    @dataclass
    class Model:
        layers: int = 1

    @dataclass
    class Config:
        model: Any = dataclasses.field(default_factory=Model)

    # Neither the absence nor the presence of variants makes this acceptable.
    for groups in ({}, {"model": {"base": Model}}):
        with pytest.raises(ValueError):
            register_hydra_config(Config, groups)


def test_dataclass_with_default() -> None:
    """A normal field with a default should not have variants."""

    @dataclass
    class Model:
        layers: int = 1

    @dataclass
    class Config:
        model: Model = dataclasses.field(default_factory=Model)

    # Without variants the config is fine ...
    register_hydra_config(Config, {})

    # ... but supplying variants for the typed field is an error.
    with pytest.raises(ValueError):
        register_hydra_config(Config, {"model": {"base": Model}})


def test_attrs_no_default() -> None:
    """This isn't so much wrong as just clumsy."""

    @define
    class DataModule:
        root: Path

    @define
    class Config:
        dm: DataModule

    # A field with a real type but without a default is rejected, with or without variants.
    for groups in ({}, {"dm": {"base": DataModule}}):
        with pytest.raises(ValueError):
            register_hydra_config(Config, groups)


def test_attrs_any() -> None:
    """An attrs field of type ``Any`` is only accepted together with variants."""

    @define
    class DataModule:
        root: Path

    @define
    class Config:
        dm: Any

    # we're assuming that the only reason you want to use Any is that
    # you want to use variants
    with pytest.raises(ValueError):
        register_hydra_config(Config, {})

    register_hydra_config(Config, {"dm": {"base": DataModule}})


def test_attrs_any_with_default() -> None:
    """An Any field with default is completely out."""

    @define
    class Model:
        layers: int = 1

    @define
    class Config:
        # it should of course be `factory` and not `default` here,
        # but OmegaConf is stupid as always
        model: Any = attrs.field(default=Model)

    # Neither the absence nor the presence of variants makes this acceptable.
    for groups in ({}, {"model": {"base": Model}}):
        with pytest.raises(ValueError):
            register_hydra_config(Config, groups)


def test_attrs_with_default() -> None:
    """A normal field with a default should not have variants."""

    @define
    class Model:
        layers: int = 1

    @define
    class Config:
        # it should of course be `factory` and not `default` here,
        # but OmegaConf is stupid as always
        model: Model = attrs.field(default=Model)

    # Without variants the config is fine ...
    register_hydra_config(Config, {})

    # ... but supplying variants for the typed field is an error.
    with pytest.raises(ValueError):
        register_hydra_config(Config, {"model": {"base": Model}})


def test_logging_dict() -> None:
    """Check the dictionary produced by `prepare_for_logging` for a small structured config."""

    class TrainingType(Enum):
        iter = auto()
        epoch = auto()

    @dataclass
    class DataModule:
        root: Path = MISSING

    @dataclass
    class Model:
        layers: int = 1

    @dataclass
    class Config:
        dm: DataModule = dataclasses.field(default_factory=DataModule)
        model: Model = dataclasses.field(default_factory=Model)
        train: TrainingType = TrainingType.iter

    cfg: DictConfig = OmegaConf.structured(Config)
    cfg.model.layers = 3

    with pytest.raises(MissingMandatoryValue):  # `dm.root` is missing
        prepare_for_logging(cfg)

    # Once the missing value is filled in, the conversion goes through.
    cfg.dm.root = "."
    expected = {
        "dm/DataModule": {"root": Path()},
        "model/Model": {"layers": 3},
        "train": "iter",
    }
    assert prepare_for_logging(cfg) == expected

0 comments on commit 6c7229f

Please sign in to comment.