
davisyoshida/jax_single_use_rng


jax_single_use_rng

A simple wrapper I made after the 73rd time I accidentally re-used an RNG key.
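For context, JAX's PRNG is stateless: passing the same key to a sampling function twice returns exactly the same values, so a reused key silently produces correlated "random" numbers. The snippet below is a minimal illustration using only the standard `jax.random` API (`PRNGKey`, `split`, `uniform`) and is independent of this package:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(12345)

# Bug: reusing the same key gives identical draws
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(key, (3,))

# Correct pattern: split the key and use each subkey exactly once
key, subkey1, subkey2 = jax.random.split(key, 3)
c = jax.random.uniform(subkey1, (3,))
d = jax.random.uniform(subkey2, (3,))

print(bool(jnp.array_equal(a, b)))  # identical because the key was reused
print(bool(jnp.array_equal(c, d)))  # different subkeys, different samples
```

Nothing in JAX itself stops you from writing the first pattern; that is the gap this wrapper is meant to close.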

Installation

pip install -e .

Usage

import jax
from jax_safe_prng import SafePRNGKey

rng = SafePRNGKey(12345)
jax.random.uniform(rng.key)  # Works
jax.random.uniform(rng.key)  # Error: this key was already used

rng1, rng2 = rng.split()  # New SafePRNGKey instances
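One way such a wrapper could work is sketched below. This is a hypothetical illustration, not this package's actual code: the names `SingleUseKey` and `KeyReuseError`, the exception type, and the choice to have `split()` also consume the parent key are all assumptions made for the example.

```python
import jax


class KeyReuseError(RuntimeError):
    """Hypothetical exception raised when a consumed key is accessed again."""


class SingleUseKey:
    """Hypothetical sketch: a wrapper whose raw JAX key can be read only once."""

    def __init__(self, seed_or_key):
        # Accept either an integer seed or an existing raw JAX key array
        if isinstance(seed_or_key, int):
            self._key = jax.random.PRNGKey(seed_or_key)
        else:
            self._key = seed_or_key
        self._used = False

    def _consume(self):
        # Mark the key as used; any later access raises instead of silently reusing it
        if self._used:
            raise KeyReuseError("This RNG key has already been used")
        self._used = True
        return self._key

    @property
    def key(self):
        # Reading .key hands out the raw key exactly once
        return self._consume()

    def split(self, num=2):
        # Assumption: splitting also consumes this key; the children are fresh wrappers
        return tuple(SingleUseKey(k) for k in jax.random.split(self._consume(), num))


# Usage
rng = SingleUseKey(12345)
x = jax.random.uniform(rng.key)  # first use is fine
try:
    jax.random.uniform(rng.key)  # second use raises KeyReuseError
except KeyReuseError as err:
    print(err)

rng1, rng2 = SingleUseKey(0).split()  # two new single-use wrappers
```

Consuming the parent in `split()` follows JAX's general guidance not to keep using a key after it has been split; a real implementation might make a different choice.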

About

Simple wrapper to make RNG re-use bugs less likely
