Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test jax.extend custom PRNG construction round trip #17988

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/extend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

from absl.testing import absltest

import jax
import jax.extend as jex
import jax.numpy as jnp

from jax._src import abstract_arrays
from jax._src import linear_util
Expand All @@ -26,6 +28,7 @@


class ExtendTest(jtu.JaxTestCase):

def test_symbols(self):
# Assume these are tested in random_test.py, only check equivalence
self.assertIs(jex.random.PRNGImpl, prng.PRNGImpl)
Expand All @@ -47,5 +50,33 @@ def test_symbols(self):
self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init)


class RandomTest(jtu.JaxTestCase):

def test_key_make_with_custom_impl(self):
shape = (4, 2, 7)

def seed_rule(_):
return jnp.ones(shape, dtype=jnp.dtype('uint32'))

def no_rule(*args, **kwargs):
assert False, 'unreachable'

impl = jex.random.PRNGImpl(shape, seed_rule, no_rule, no_rule, no_rule)
k = jax.random.key(42, impl=impl)
self.assertEqual(k.shape, ())
self.assertEqual(impl, jax.random.key_impl(k)._impl)

def test_key_wrap_with_custom_impl(self):
def no_rule(*args, **kwargs):
assert False, 'unreachable'

shape = (4, 2, 7)
impl = jex.random.PRNGImpl(shape, no_rule, no_rule, no_rule, no_rule)
data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32'))
k = jax.random.wrap_key_data(data, impl=impl)
self.assertEqual(k.shape, (3,))
self.assertEqual(impl, jax.random.key_impl(k)._impl)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())