Skip to content

Commit

Permalink
test custom PRNG impl construction round trip
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Oct 6, 2023
1 parent 2052673 commit 4e03faa
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/extend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

from absl.testing import absltest

import jax
import jax.extend as jex
import jax.numpy as jnp

from jax._src import abstract_arrays
from jax._src import linear_util
Expand All @@ -26,6 +28,7 @@


class ExtendTest(jtu.JaxTestCase):

def test_symbols(self):
# Assume these are tested in random_test.py, only check equivalence
self.assertIs(jex.random.PRNGImpl, prng.PRNGImpl)
Expand All @@ -47,5 +50,33 @@ def test_symbols(self):
self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)


class RandomTest(jtu.JaxTestCase):

def test_key_make_with_custom_impl(self):
shape = (4, 2, 7)

def seed_rule(_):
return jnp.ones(shape, dtype=jnp.dtype('uint32'))

def no_rule(*args, **kwargs):
assert False, 'unreachable'

impl = jex.random.PRNGImpl(shape, seed_rule, no_rule, no_rule, no_rule)
k = jax.random.key(42, impl=impl)
self.assertEquals(k.shape, ())
self.assertEquals(impl, jax.random.key_impl(k)._impl)

def test_key_wrap_with_custom_impl(self):
def no_rule(*args, **kwargs):
assert False, 'unreachable'

shape = (4, 2, 7)
impl = jex.random.PRNGImpl(shape, no_rule, no_rule, no_rule, no_rule)
data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32'))
k = jax.random.wrap_key_data(data, impl=impl)
self.assertEquals(k.shape, (3,))
self.assertEquals(impl, jax.random.key_impl(k)._impl)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 4e03faa

Please sign in to comment.