Skip to content

Commit

Permalink
Make random wrapping more robust to changes in jax.random (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann authored Oct 13, 2021
1 parent b7fe608 commit 6d9f135
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions scico/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,22 +167,16 @@ def _wrap(fun):
return fun_wrapped


exceptions = [ # these do not take key and shape
"PRNGKey",
"double_sided_maxwell",
"fold_in",
"permutation",
"shuffle",
"split",
"weibull_min",
"threefry_2x32",
]
def _is_wrappable(fun):
params = inspect.signature(getattr(jax.random, fun)).parameters
return list(params.keys())[0] == "key" and "shape" in params.keys()


func_names = [
t[0] for t in inspect.getmembers(jax.random, inspect.isfunction) if t[0] not in exceptions
wrappable_func_names = [
t[0] for t in inspect.getmembers(jax.random, inspect.isfunction) if _is_wrappable(t[0])
]

for name in func_names:
for name in wrappable_func_names:
setattr(sys.modules[__name__], name, _wrap(getattr(jax.random, name)))


Expand Down

0 comments on commit 6d9f135

Please sign in to comment.