nnx using vmap to create multiple models #4048

Answered by cgarciae
JeyRunner asked this question in Q&A

Hey @JeyRunner! nnx.vmap has some support for splitting RNG state automatically. Try:

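from flax import nnx

def make_model(rngs):
  return nnx.Linear(2, 3, rngs=rngs)

rngs = nnx.Rngs(0)
# axis_size=5 builds 5 copies; split_rngs=True splits the RNG state across them
model = nnx.vmap(make_model, split_rngs=True, axis_size=5)(rngs)

print(model)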

Output:

Linear(
  bias=Param(
    value=Array(shape=(5, 3), dtype=float32)
  ),
  bias_init=<function zeros at 0x11ee95f30>,
  dot_general=<function dot_general at 0x11e933910>,
  dtype=None,
  in_features=2,
  kernel=Param(
    value=Array(shape=(5, 2, 3), dtype=float32)
  ),
  kernel_init=<function variance_scaling.<locals>.init at 0x11fa8fe20>,
  out_features=3,
  param_dtype=<class 'jax.numpy.float32'>,
  precision=None,
  use_bias=True
)
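
The leading axis of size 5 on `kernel` and `bias` above is what you get when a constructor is vmapped over split RNG keys. Here is a minimal sketch of the same idea in plain JAX, for intuition only. It is not how nnx.vmap is implemented, and `init_linear` and `apply_linear` are hypothetical stand-ins for `nnx.Linear`, with a simplified initializer:

```python
import jax
import jax.numpy as jnp

def init_linear(key, in_features=2, out_features=3):
    # Hypothetical stand-in for nnx.Linear: scaled-normal kernel, zero bias.
    kernel = jax.random.normal(key, (in_features, out_features)) / jnp.sqrt(in_features)
    bias = jnp.zeros((out_features,))
    return {"kernel": kernel, "bias": bias}

def apply_linear(params, x):
    return x @ params["kernel"] + params["bias"]

# Split one key into 5 so each model copy gets its own randomness.
keys = jax.random.split(jax.random.PRNGKey(0), 5)

# vmap over the keys: every parameter gains a leading axis of size 5.
params = jax.vmap(init_linear)(keys)
print(jax.tree_util.tree_map(lambda p: p.shape, params))

# Apply each of the 5 models to its own batch of 4 inputs.
x = jnp.ones((5, 4, 2))
y = jax.vmap(apply_linear)(params, x)
print(y.shape)
```

Because each copy is initialized from a different key, the 5 kernels differ from one another. Vmapping the constructor with a single, unsplit key would instead give 5 identical models.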

Answer selected by JeyRunner