Hey, in NNX, random state is just a regular kind of state: `Rngs` internally holds `RngKey` and `RngCount` variables, which are subtypes of `RngState` (itself a `Variable`). Starting from `flax>=0.9.0`, NNX no longer treats random state in a special way (see JAX-style NNX Transforms). To handle RNG state under a transform, you can either split the RNG keys passed to `Rngs` yourself or use the new `nnx.split_rngs` API (easier):
`split_rngs` temporarily lifts (splits) the `RngState` into one key per batch element and lowers (restores) it afterwards. Note that instead of using `None` to broadcast the model state, you have to use `StateAxes` to vectorize the `RngState` on axis 0 while broadcasting all other state.
Hi,
If I run something like:
the same random key is used for every element of the vectorized batch, so stochastic functions behave identically across the batch.
Thanks!