This PR addresses several issues:
The existing RNN layer does not train properly because a fresh StatelessScope is created inside the jax.lax.scan loop, which causes all the trainable variables to lose their mapping to the actual values in the training loop. Update the layer to reuse the parent StatelessScope when one is present. This fixes the training issue reported in #322 (nlp/lstm_seq2seq.py doesn't train with JAX backend).

RNN layers with dropout update the RNG seed inside the step function, which jax.lax.scan does not allow. We noticed this because the updated seed is traced as a non-trainable variable and raises an error when we try to apply a sharding constraint for distribution. Added a new method to pre-populate the dropout mask on the layer, so the inner loop stays stateless.
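To illustrate the constraint behind the second fix, here is a minimal sketch (not the actual Keras code) of pre-populating a dropout mask before entering jax.lax.scan, so the step function stays pure and no RNG state is updated per step; the function and variable names are hypothetical:

```python
import jax
import jax.numpy as jnp

def rnn_with_dropout(x_seq, h0, key, rate=0.5):
    # Pre-populate the dropout mask OUTSIDE the scan: jax.lax.scan requires a
    # pure step function, so the RNG seed cannot be advanced inside the loop.
    mask = jax.random.bernoulli(key, 1.0 - rate, h0.shape) / (1.0 - rate)

    def step(h, x):
        # The precomputed mask is closed over; no RNG state mutates here.
        h_new = jnp.tanh(x + h * mask)
        return h_new, h_new  # (carry, per-step output)

    h_final, outputs = jax.lax.scan(step, h0, x_seq)
    return h_final, outputs

key = jax.random.PRNGKey(0)
x_seq = jnp.ones((4, 3))  # 4 timesteps, feature dim 3
h0 = jnp.zeros((3,))
h_final, outputs = rnn_with_dropout(x_seq, h0, key)
print(outputs.shape)  # (4, 3)
```

The same mask is applied at every timestep, which matches the usual recurrent-dropout convention of sampling the mask once per sequence rather than once per step.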
During unit testing, I noticed that StackedRNNCells doesn't work with the existing RNN cells, since it unwraps the list for the state; update the call function to keep the state as a list when the input state is a list.
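The state-structure fix can be sketched as follows; this is a hypothetical minimal cell (not the Keras implementation) whose only point is to return the new state in the same structure the caller passed in:

```python
def cell_call(inputs, states):
    # Keras cells receive `states` as a list; some custom cells pass a bare
    # value instead. Preserve whichever structure the caller used so that a
    # stacked-cells wrapper, which nests per-cell states, round-trips cleanly.
    state_is_list = isinstance(states, (list, tuple))
    h = states[0] if state_is_list else states
    new_h = h + inputs  # placeholder for the real cell computation
    # Return the new state wrapped as a list only if the input state was one.
    return new_h, [new_h] if state_is_list else new_h

out, new_states = cell_call(2.0, [1.0])
print(out, new_states)  # 3.0 [3.0]
```

Keeping the output structure mirrored to the input structure avoids the unwrap/rewrap mismatch described above.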
Expose the SimpleRNN/GRU/LSTM cells in `__init__.py`, since they are public API.