diff --git a/tests/test_pytree.py b/tests/test_pytree.py index 46a2a06..77c9bb5 100644 --- a/tests/test_pytree.py +++ b/tests/test_pytree.py @@ -55,6 +55,20 @@ class Foo(Pytree): with pytest.raises(AttributeError, match="cannot assign to field"): pytree.x = 4 + def test_jit(self): + @dataclasses.dataclass + class Foo(Pytree): + a: int + b: int = static_field() + + module = Foo(a=1, b=2) + + @jax.jit + def f(m: Foo): + return m.a + m.b + + assert f(module) == 3 + class TestMutablePytree: def test_pytree(self):