
RNNs redesign #2500

Open · wants to merge 11 commits into base: master
Conversation

CarloLucibello (Member) commented Oct 14, 2024

A complete rework of our recurrent layers, making them more similar to their PyTorch counterparts.
This is in line with the proposal in #1365 and should allow hooking into the cuDNN machinery (future PR).
Hopefully, this puts an end to the endless stream of trouble the recurrent layers have been.

  • Recur is no more. Mutating its internal state was a source of problems for AD (explicit differentiation for RNN gives wrong results, #2185).
  • Now RNNCell is exported and takes care of the minimal recursion step, i.e. a single time step (see the usage sketch after this list):
    • has forward pass cell(x, h)
    • x can be of size in or in x batch_size
    • h can be of size out or out x batch_size
    • returns hnew of size out or out x batch_size
  • RNN instead takes in a (batched) sequence and a (batched) hidden state and returns the hidden states for the whole sequence:
    • has forward pass rnn(x, h)
    • x can be of size in x len or in x len x batch_size
    • h can be of size out or out x batch_size
    • returns hnew of size out x len or out x len x batch_size
  • LSTM and GRU are similarly changed.
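
A minimal usage sketch of the two call forms above. The in => out constructor style is an assumption based on the convention of other Flux layers such as Dense; the array sizes and the zero default for an omitted hidden state follow the description in this PR.

```julia
using Flux

in_dim, out_dim, len, batch = 3, 5, 7, 4

# Cell-level API: one time step at a time.
cell = RNNCell(in_dim => out_dim)        # constructor form assumed to mirror Dense(in => out)
x_t  = rand(Float32, in_dim, batch)      # input for a single step, size in x batch_size
h    = zeros(Float32, out_dim, batch)    # hidden state, size out x batch_size
h    = cell(x_t, h)                      # new hidden state, size out x batch_size

# Sequence-level API: the layer iterates over the time dimension internally.
rnn = RNN(in_dim => out_dim)
x   = rand(Float32, in_dim, len, batch)  # whole sequence, size in x len x batch_size
h0  = zeros(Float32, out_dim, batch)
y   = rnn(x, h0)                         # hidden states for every step, size out x len x batch_size
y2  = rnn(x)                             # h omitted: initial state assumed to be zero (see checklist)
```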

Close #2185, close #2341, close #2258, close #1547, close #807, close #1329

Related to #1678

PR Checklist

  • CPU tests
  • GPU tests
  • If the hidden state is not given as input, it is assumed to be zero
  • Port LSTM and GRU
  • Entry in NEWS.md
  • Remove reset! (the state is now passed explicitly; see the sketch after this list)
  • Docstrings
  • Benchmarks
  • Use cuDNN (future PR)
  • Implement the num_layers argument for stacked RNNs (future PR)
  • Revisit the whole documentation (future PR)
  • Add dropout (future PR)
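
Since the state is now passed in and returned explicitly rather than mutated inside a Recur wrapper, there is nothing to reset! between sequences. The sketch below illustrates this with a hypothetical unroll helper (not part of this PR) that loops an RNNCell over the time dimension, which is conceptually what the sequence-level RNN layer does.

```julia
using Flux

# Hypothetical helper, not part of this PR: explicitly unroll an RNNCell over
# the time dimension. The state is threaded through as a plain value, so there
# is no hidden mutation and no reset! is needed between sequences.
function unroll(cell, x::AbstractArray{<:Any,3}, h)
    hs = map(1:size(x, 2)) do t
        h = cell(x[:, t, :], h)   # one step: in x batch_size -> out x batch_size
        h
    end
    return stack(hs; dims = 2)    # out x len x batch_size, like the RNN layer's output
end

cell = RNNCell(3 => 5)            # constructor form assumed, as in the sketch above
x    = rand(Float32, 3, 7, 4)     # in x len x batch_size
h0   = zeros(Float32, 5, 4)       # zero initial state, matching the default when h is omitted
y    = unroll(cell, x, h0)
size(y)                           # (5, 7, 4)
```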

CarloLucibello changed the title from RNN redesign to RNNs redesign on Oct 17, 2024