Skip to content

Commit

Permalink
Add test to verify that we get the same distribution under flatten/un…
Browse files Browse the repository at this point in the history
…flatten logic (#1510)
  • Loading branch information
fehiepsi authored Dec 17, 2022
1 parent ac0e073 commit fd3e3c2
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2621,6 +2621,15 @@ def f(x):
pytest.skip("EulerMaruyama doesn't define flatten/unflatten")
jax.jit(f)(0) # this test for flatten/unflatten
lax.map(f, np.ones(3)) # this test for compatibility w.r.t. scan
# Test that parameters do not change after flattening.
expected_dist = f(0)
actual_dist = jax.jit(f)(0)
expected_sample = expected_dist.sample(random.PRNGKey(0))
actual_sample = actual_dist.sample(random.PRNGKey(0))
expected_log_prob = expected_dist.log_prob(expected_sample)
actual_log_prob = actual_dist.log_prob(actual_sample)
assert_allclose(actual_sample, expected_sample, rtol=1e-6)
assert_allclose(actual_log_prob, expected_log_prob, rtol=2e-6)


@pytest.mark.parametrize(
Expand Down

0 comments on commit fd3e3c2

Please sign in to comment.