diff --git a/tests/extend_test.py b/tests/extend_test.py index 141b53332fb9..a902b543543e 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -14,7 +14,9 @@ from absl.testing import absltest +import jax import jax.extend as jex +import jax.numpy as jnp from jax._src import abstract_arrays from jax._src import linear_util @@ -26,6 +28,7 @@ class ExtendTest(jtu.JaxTestCase): + def test_symbols(self): # Assume these are tested in random_test.py, only check equivalence self.assertIs(jex.random.PRNGImpl, prng.PRNGImpl) @@ -47,5 +50,33 @@ def test_symbols(self): self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init) +class RandomTest(jtu.JaxTestCase): + + def test_key_make_with_custom_impl(self): + shape = (4, 2, 7) + + def seed_rule(_): + return jnp.ones(shape, dtype=jnp.dtype('uint32')) + + def no_rule(*args, **kwargs): + assert False, 'unreachable' + + impl = jex.random.PRNGImpl(shape, seed_rule, no_rule, no_rule, no_rule) + k = jax.random.key(42, impl=impl) + self.assertEquals(k.shape, ()) + self.assertEquals(impl, jax.random.key_impl(k)._impl) + + def test_key_wrap_with_custom_impl(self): + def no_rule(*args, **kwargs): + assert False, 'unreachable' + + shape = (4, 2, 7) + impl = jex.random.PRNGImpl(shape, no_rule, no_rule, no_rule, no_rule) + data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32')) + k = jax.random.wrap_key_data(data, impl=impl) + self.assertEquals(k.shape, (3,)) + self.assertEquals(impl, jax.random.key_impl(k)._impl) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())