Skip to content

Commit

Permalink
Merge pull request #416 from JaxGaussianProcesses/fix_meta_inheritanc…
Browse files Browse the repository at this point in the history
…e_bug

Fix bug and add test.
  • Loading branch information
daniel-dodd authored Nov 27, 2023
2 parents 615051d + 9ffac5d commit 2d2f451
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
11 changes: 10 additions & 1 deletion gpjax/base/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,21 @@ def static_field( # noqa: PLR0913
)


def _inherited_metadata(cls: type) -> Dict[str]:
meta_data = dict()
for parent_class in cls.mro():
if parent_class is not cls and parent_class is not Module:
if issubclass(parent_class, Module):
meta_data.update(parent_class._pytree__meta)
return meta_data


class Module(Pytree):
_pytree__meta: Dict[str, Any] = static_field()

def __init_subclass__(cls, mutable: bool = False):
cls._pytree__meta = {}
super().__init_subclass__(mutable=mutable)
cls._pytree__meta = _inherited_metadata(cls)
class_vars = vars(cls)
for field, value in class_vars.items():
if (
Expand Down
63 changes: 63 additions & 0 deletions tests/test_base/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,3 +876,66 @@ class Foo(Module, mutable=True):
# test mutation
pytree.x = 4
assert pytree.x == 4


@pytest.mark.parametrize("is_dataclass", [True, False])
@pytest.mark.parametrize("iterable", [list, tuple])
def test_inheritance_different_meta(is_dataclass, iterable):
class Tree(Module):
a: int = param_field(bijector=tfb.Identity(), default=1)
b: int = param_field(bijector=tfb.Softplus(), default=2)
c: int = param_field(bijector=tfb.Tanh(), default=0, trainable=False)

def __init__(self, a=1.0, b=2.0, c=0.0):
self.a = a
self.b = b
self.c = c

if is_dataclass:
Tree = dataclass(Tree)

class SubTree(Tree):
pass

tree = SubTree()

assert isinstance(tree, Module)
assert isinstance(tree, Pytree)

assert tree.a == 1.0
assert tree.b == 2.0
assert tree.c == 0.0

meta_tree = meta(tree)

assert isinstance(meta_tree, Module)
assert isinstance(meta_tree, Pytree)

assert isinstance(meta_tree.a["bijector"], tfb.Identity)
assert meta_tree.a["trainable"] is True
assert isinstance(meta_tree.b["bijector"], tfb.Softplus)
assert meta_tree.b["trainable"] is True
assert isinstance(meta_tree.c["bijector"], tfb.Tanh)
assert meta_tree.c["trainable"] is False

# Test constrain and unconstrain

constrained_tree = tree.constrain()
unconstrained_tree = tree.unconstrain()

assert jtu.tree_structure(unconstrained_tree) == jtu.tree_structure(tree)
assert jtu.tree_structure(constrained_tree) == jtu.tree_structure(tree)

assert isinstance(constrained_tree, Module)
assert isinstance(constrained_tree, Pytree)

assert isinstance(unconstrained_tree, Module)
assert isinstance(unconstrained_tree, Pytree)

assert constrained_tree.a == tfb.Identity().forward(1.0)
assert constrained_tree.b == tfb.Softplus().forward(2.0)
assert constrained_tree.c == tfb.Tanh().forward(0.0)

assert unconstrained_tree.a == tfb.Identity().inverse(1.0)
assert unconstrained_tree.b == tfb.Softplus().inverse(2.0)
assert unconstrained_tree.c == tfb.Tanh().inverse(0.0)

0 comments on commit 2d2f451

Please sign in to comment.