Hi, I want to create multiple instances of an nnx.Module (each initialized with a different key). Without nnx, I would do something like:

    def make(rng):
        m = my_module.init(rng, dummy_input)
        return ...

    rngs = jax.random.split(jax.random.PRNGKey(0), num=5)
    models = jax.vmap(make)(rngs)

How can I achieve the same with nnx? I tried:

    def make_model(rngs):
        return nnx.Sequential(
            nnx.Linear(..., rngs=rngs),
            ...
        )

    init_keys = jax.random.split(jax.random.PRNGKey(0), num=5)
    rngs = nnx.Rngs(init_keys)
    model = jax.vmap(task.make_model)(rngs)

But I get an error.
Answered by cgarciae, Jul 3, 2024
Hey @JeyRunner! nnx.vmap has some support for splitting RNG state automatically, try:

    def make_model(rngs):
        return nnx.Linear(2, 3, rngs=rngs)

    rngs = nnx.Rngs(0)
    model = nnx.vmap(make_model, split_rngs=True, axis_size=5)(rngs)
    print(model)

Output:

    Linear(
      bias=Param(
        value=Array(shape=(5, 3), dtype=float32)
      ),
      bias_init=<function zeros at 0x11ee95f30>,
      dot_general=<function dot_general at 0x11e933910>,
      dtype=None,
      in_features=2,
      kernel=Param(
        value=Array(shape=(5, 2, 3), dtype=float32)
      ),
      kernel_init=<function variance_scaling.<locals>.init at 0x11fa8fe20>,
      out_features=3,
      param_dtype=<class 'jax.numpy.float32'>,
      precision=None,
      use_bias=True
    )

Each parameter gets a leading axis of size 5, one slice per instance: the kernel has shape (5, 2, 3) and the bias has shape (5, 3).
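For intuition about where those stacked shapes come from, here is a minimal plain-JAX sketch of the same idea. It is not how nnx implements it: `init_linear` and `apply_linear` are hypothetical stand-ins for a linear layer, and the scaled-normal initializer is an assumption chosen only for illustration. Mapping an init function over 5 split keys yields one parameter tree whose leaves carry a leading axis of size 5.

```python
import jax
import jax.numpy as jnp


def init_linear(key, in_features=2, out_features=3):
    # Hypothetical initializer (an assumption, not nnx.Linear's actual init):
    # scaled normal kernel and zero bias.
    kernel = jax.random.normal(key, (in_features, out_features)) / jnp.sqrt(in_features)
    bias = jnp.zeros((out_features,))
    return {"kernel": kernel, "bias": bias}


def apply_linear(params, x):
    # Plain affine map for a single (unstacked) set of parameters.
    return x @ params["kernel"] + params["bias"]


# One key per instance; vmap maps init_linear over the leading key axis.
keys = jax.random.split(jax.random.PRNGKey(0), num=5)
stacked = jax.vmap(init_linear)(keys)

print(stacked["kernel"].shape)  # (5, 2, 3)
print(stacked["bias"].shape)    # (5, 3)

# Run all 5 instances on the same input: map over params, broadcast x.
x = jnp.ones((2,))
ys = jax.vmap(apply_linear, in_axes=(0, None))(stacked, x)
print(ys.shape)  # (5, 3)
```

The same pattern applies to the forward pass: map over the stacked parameters (axis 0) and broadcast or map the input, depending on whether every instance should see the same data.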
Answer selected by JeyRunner