diff --git a/gpjax/base/module.py b/gpjax/base/module.py index bcc7b66af..3c4419007 100644 --- a/gpjax/base/module.py +++ b/gpjax/base/module.py @@ -87,12 +87,21 @@ def static_field( # noqa: PLR0913 ) +def _inherited_metadata(cls: type) -> Dict[str]: + meta_data = dict() + for parent_class in cls.mro(): + if parent_class is not cls and parent_class is not Module: + if issubclass(parent_class, Module): + meta_data.update(parent_class._pytree__meta) + return meta_data + + class Module(Pytree): _pytree__meta: Dict[str, Any] = static_field() def __init_subclass__(cls, mutable: bool = False): - cls._pytree__meta = {} super().__init_subclass__(mutable=mutable) + cls._pytree__meta = _inherited_metadata(cls) class_vars = vars(cls) for field, value in class_vars.items(): if ( diff --git a/tests/test_base/test_module.py b/tests/test_base/test_module.py index e9d33e686..8878660ea 100644 --- a/tests/test_base/test_module.py +++ b/tests/test_base/test_module.py @@ -876,3 +876,66 @@ class Foo(Module, mutable=True): # test mutation pytree.x = 4 assert pytree.x == 4 + + +@pytest.mark.parametrize("is_dataclass", [True, False]) +@pytest.mark.parametrize("iterable", [list, tuple]) +def test_inheritance_different_meta(is_dataclass, iterable): + class Tree(Module): + a: int = param_field(bijector=tfb.Identity(), default=1) + b: int = param_field(bijector=tfb.Softplus(), default=2) + c: int = param_field(bijector=tfb.Tanh(), default=0, trainable=False) + + def __init__(self, a=1.0, b=2.0, c=0.0): + self.a = a + self.b = b + self.c = c + + if is_dataclass: + Tree = dataclass(Tree) + + class SubTree(Tree): + pass + + tree = SubTree() + + assert isinstance(tree, Module) + assert isinstance(tree, Pytree) + + assert tree.a == 1.0 + assert tree.b == 2.0 + assert tree.c == 0.0 + + meta_tree = meta(tree) + + assert isinstance(meta_tree, Module) + assert isinstance(meta_tree, Pytree) + + assert isinstance(meta_tree.a["bijector"], tfb.Identity) + assert meta_tree.a["trainable"] is True + assert isinstance(meta_tree.b["bijector"], tfb.Softplus) + assert meta_tree.b["trainable"] is True + assert isinstance(meta_tree.c["bijector"], tfb.Tanh) + assert meta_tree.c["trainable"] is False + + # Test constrain and unconstrain + + constrained_tree = tree.constrain() + unconstrained_tree = tree.unconstrain() + + assert jtu.tree_structure(unconstrained_tree) == jtu.tree_structure(tree) + assert jtu.tree_structure(constrained_tree) == jtu.tree_structure(tree) + + assert isinstance(constrained_tree, Module) + assert isinstance(constrained_tree, Pytree) + + assert isinstance(unconstrained_tree, Module) + assert isinstance(unconstrained_tree, Pytree) + + assert constrained_tree.a == tfb.Identity().forward(1.0) + assert constrained_tree.b == tfb.Softplus().forward(2.0) + assert constrained_tree.c == tfb.Tanh().forward(0.0) + + assert unconstrained_tree.a == tfb.Identity().inverse(1.0) + assert unconstrained_tree.b == tfb.Softplus().inverse(2.0) + assert unconstrained_tree.c == tfb.Tanh().inverse(0.0)