
How to implement recurrent dropout with NNX? #4397

Answered by cgarciae
sgkouzias asked this question in Q&A


Hey @sgkouzias, you can implement recurrent dropout by passing an rng_collection other than 'dropout' to the Dropout constructor (e.g. 'recurrent_dropout') and then broadcasting the state of that RNG stream during scan (RNN does this internally). However, the current RNN API is missing some options, so I've created #4407 to address this. Here's a demo of the working solution:

class LSTMWithRecurrentDropout(nnx.OptimizedLSTMCell):
  def __init__(
    self,
    *,
    rngs: nnx.Rngs,
    in_features: int,
    hidden_features: int,
    dropout_rate: float,
    **kwargs,
  ):
    super().__init__(
      in_features=in_features,
      hidden_features=hidden_features,
      rngs=rngs,
      **kwargs,…
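As a framework-agnostic illustration of what broadcasting the RNG stream during scan is meant to achieve, here is a minimal plain-JAX sketch. It is not the NNX API and not the code from #4407. It assumes the goal is variational-style recurrent dropout: a single dropout mask is sampled once per sequence and reused at every time step. The helper names `recurrent_dropout_mask` and `simple_rnn` are hypothetical, and a toy tanh cell stands in for the LSTM.

```python
import jax
import jax.numpy as jnp


def recurrent_dropout_mask(key, shape, rate):
    """Sample one inverted-dropout mask (assumes 0 <= rate < 1).

    Kept units are scaled by 1 / (1 - rate) so the expected activation
    is unchanged; dropped units are 0.
    """
    keep = jax.random.bernoulli(key, 1.0 - rate, shape)
    return keep.astype(jnp.float32) / (1.0 - rate)


def simple_rnn(params, xs, h0, key, rate):
    """Toy tanh RNN over xs with shape (time, batch, in_features).

    The mask on the recurrent hidden state is sampled once, outside the
    scan, and closed over by `step`, so every time step uses the same
    mask. This mirrors the effect of broadcasting (not splitting) the
    recurrent-dropout RNG stream across scan iterations.
    """
    w_in, w_h = params
    mask = recurrent_dropout_mask(key, h0.shape, rate)

    def step(h, x):
        # Dropout is applied to the hidden state fed back into the cell.
        h = jnp.tanh(x @ w_in + (h * mask) @ w_h)
        return h, h

    h_last, hs = jax.lax.scan(step, h0, xs)
    return h_last, hs
```

By contrast, ordinary dropout inside the loop would split the key at each step (for example by scanning over `jax.random.split(key, T)`), giving a different mask per time step.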

Replies: 1 comment, 3 replies (from @sgkouzias, @cgarciae, @sgkouzias)
Answer selected by sgkouzias
Category: Q&A
Labels: None yet
Participants: 2