Is nnx.vmap a replacement for jax.vmap? Getting different behavior #4143

Answered by cgarciae
jlperla asked this question in General

Am I misunderstanding this function? I am not entirely sure I need to use it in these circumstances, but I am having trouble forming a mental model of why these calls behave differently.

I'm very happy we made the change to JAX-style transforms: starting from 0.9.0, the mental model for Flax transforms is the same as for JAX transforms, which should help users a lot.

Answer selected by jlperla