Skip to content

Commit

Permalink
Merge pull request #1 from cgarciae/add-jit-test
Browse files Browse the repository at this point in the history
Add jit test
  • Loading branch information
cgarciae authored Feb 28, 2023
2 parents 20d65ba + 29b7f43 commit c9d327f
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ class Foo(Pytree):
with pytest.raises(AttributeError, match="cannot assign to field"):
pytree.x = 4

def test_jit(self):
@dataclasses.dataclass
class Foo(Pytree):
a: int
b: int = static_field()

module = Foo(a=1, b=2)

@jax.jit
def f(m: Foo):
return m.a + m.b

assert f(module) == 3


class TestMutablePytree:
def test_pytree(self):
Expand Down

0 comments on commit c9d327f

Please sign in to comment.