RNN Module #2604

Merged · 1 commit · Feb 28, 2023
2 changes: 2 additions & 0 deletions docs/api_reference/flax.linen.rst
@@ -268,3 +268,5 @@ RNN primitives
LSTMCell
OptimizedLSTMCell
GRUCell
RNNCellBase
RNN
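
For orientation, a minimal usage sketch of the two symbols added to the API reference above, assuming the constructor arguments used in this diff (a cell instance plus its cell size) and the default batch-major layout; the shapes and the hidden size 64 are illustrative only:

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 10, 8))                        # (batch, time, features)
rnn = nn.RNN(nn.LSTMCell(), 64)                 # scan an RNNCellBase subclass over the time axis
variables = rnn.init(jax.random.PRNGKey(0), x)
y = rnn.apply(variables, x)                     # y.shape == (2, 10, 64)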
126 changes: 23 additions & 103 deletions examples/seq2seq/models.py
@@ -25,78 +25,19 @@
import jax.numpy as jnp
import numpy as np

Array = Any
PRNGKey = Any
Array = jax.Array
PRNGKey = jax.random.KeyArray


class EncoderLSTM(nn.Module):
"""EncoderLSTM Module wrapped in a lifted scan transform."""
eos_id: int

@functools.partial(
nn.scan,
variable_broadcast='params',
in_axes=1,
out_axes=1,
split_rngs={'params': False})
@nn.compact
def __call__(self, carry: Tuple[Array, Array],
x: Array) -> Tuple[Tuple[Array, Array], Array]:
"""Applies the module."""
lstm_state, is_eos = carry
new_lstm_state, y = nn.LSTMCell()(lstm_state, x)
# Pass forward the previous state if EOS has already been reached.
def select_carried_state(new_state, old_state):
return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
# LSTM state is a tuple (c, h).
carried_lstm_state = tuple(
select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
# Update `is_eos`.
is_eos = jnp.logical_or(is_eos, x[:, self.eos_id])
return (carried_lstm_state, is_eos), y

@staticmethod
def initialize_carry(batch_size: int, hidden_size: int):
# Use a dummy key since the default state init fn is just zeros.
return nn.LSTMCell.initialize_carry(
jax.random.PRNGKey(0), (batch_size,), hidden_size)


class Encoder(nn.Module):
"""LSTM encoder, returning state after finding the EOS token in the input."""
hidden_size: int
eos_id: int

@nn.compact
def __call__(self, inputs: Array):
# inputs.shape = (batch_size, seq_length, vocab_size).
batch_size = inputs.shape[0]
lstm = EncoderLSTM(name='encoder_lstm', eos_id=self.eos_id)
init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size)
# We use the `is_eos` array to determine whether the encoder should carry
# over the last lstm state, or apply the LSTM cell on the previous state.
init_is_eos = jnp.zeros(batch_size, dtype=bool)
init_carry = (init_lstm_state, init_is_eos)
(final_state, _), _ = lstm(init_carry, inputs)
return final_state


class DecoderLSTM(nn.Module):
class DecoderLSTMCell(nn.RNNCellBase):
"""DecoderLSTM Module wrapped in a lifted scan transform.

Attributes:
teacher_force: See docstring on Seq2seq module.
vocab_size: Size of the vocabulary.
"""
teacher_force: bool
vocab_size: int

@functools.partial(
nn.scan,
variable_broadcast='params',
in_axes=1,
out_axes=1,
split_rngs={'params': False, 'lstm': True})
@nn.compact
def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
"""Applies the DecoderLSTM model."""
@@ -116,40 +57,6 @@ def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array:
return (lstm_state, prediction), (logits, prediction)


class Decoder(nn.Module):
"""LSTM decoder.

Attributes:
init_state: [batch_size, hidden_size]
Initial state of the decoder (i.e., the final state of the encoder).
teacher_force: See docstring on Seq2seq module.
vocab_size: Size of the vocabulary.
"""
init_state: Tuple[Any]
teacher_force: bool
vocab_size: int

@nn.compact
def __call__(self, inputs: Array) -> Tuple[Array, Array]:
"""Applies the decoder model.

Args:
inputs: [batch_size, max_output_len-1, vocab_size]
Contains the inputs to the decoder at each time step (only used when not
using teacher forcing). Since each token at position i is fed as input
to the decoder at position i+1, the last token is not provided.

Returns:
Pair (logits, predictions), which are two arrays of respectively decoded
logits and predictions (in one hot-encoding format).
"""
lstm = DecoderLSTM(teacher_force=self.teacher_force,
vocab_size=self.vocab_size)
init_carry = (self.init_state, inputs[:, 0])
_, (logits, predictions) = lstm(init_carry, inputs)
return logits, predictions


class Seq2seq(nn.Module):
"""Sequence-to-sequence class using encoder/decoder architecture.

@@ -189,12 +96,25 @@ def __call__(self, encoder_inputs: Array,
encoding format).
"""
# Encode inputs.
init_decoder_state = Encoder(
hidden_size=self.hidden_size, eos_id=self.eos_id)(encoder_inputs)
# Decode outputs.
logits, predictions = Decoder(
init_state=init_decoder_state,
teacher_force=self.teacher_force,
vocab_size=self.vocab_size)(decoder_inputs[:, :-1])
encoder = nn.RNN(nn.LSTMCell(), self.hidden_size, return_carry=True, name='encoder')
decoder = nn.RNN(DecoderLSTMCell(self.teacher_force, self.vocab_size), decoder_inputs.shape[-1],
split_rngs={'params': False, 'lstm': True}, name='decoder')

segmentation_mask = self.get_segmentation_mask(encoder_inputs)

encoder_state, _ = encoder(encoder_inputs, segmentation_mask=segmentation_mask)
logits, predictions = decoder(decoder_inputs[:, :-1], initial_carry=(encoder_state, decoder_inputs[:, 0]))

return logits, predictions

def get_segmentation_mask(self, inputs: Array) -> Array:
"""Get segmentation mask for inputs."""
# undo one-hot encoding
inputs = jnp.argmax(inputs, axis=-1)
# calculate eos index
eos_idx = jnp.argmax(inputs == self.eos_id, axis=-1, keepdims=True)
# create index array
indexes = jnp.arange(inputs.shape[1])
indexes = jnp.broadcast_to(indexes, inputs.shape[:2])
# return mask
return indexes < eos_idx
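
For intuition, a self-contained sketch (toy data, not part of the PR) of what get_segmentation_mask computes: positions strictly before the first EOS token are True, and everything from the EOS token onward is False.

import jax
import jax.numpy as jnp

eos_id = 0                                       # hypothetical EOS token id
tokens = jnp.array([[4, 3, 0, 1],
                    [2, 2, 1, 0]])               # EOS at positions 2 and 3
inputs = jax.nn.one_hot(tokens, 5)               # (batch, time, vocab)

ids = jnp.argmax(inputs, axis=-1)                # undo the one-hot encoding
eos_idx = jnp.argmax(ids == eos_id, axis=-1, keepdims=True)   # first EOS per row
mask = jnp.arange(ids.shape[1]) < eos_idx        # (batch, time) boolean mask
# mask -> [[ True,  True, False, False],
#          [ True,  True,  True, False]]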
4 changes: 3 additions & 1 deletion flax/linen/__init__.py
@@ -120,7 +120,9 @@
ConvLSTMCell as ConvLSTMCell,
GRUCell as GRUCell,
LSTMCell as LSTMCell,
OptimizedLSTMCell as OptimizedLSTMCell
OptimizedLSTMCell as OptimizedLSTMCell,
RNNCellBase as RNNCellBase,
RNN as RNN,
)
from .stochastic import Dropout as Dropout
from .transforms import (
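
Taken together, the new exports support the encoder/decoder handoff used in the seq2seq example above. A simplified sketch, using a plain LSTMCell on both sides and hypothetical shapes, with return_carry on the encoder and initial_carry on the decoder:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinySeq2Seq(nn.Module):
  hidden_size: int = 16

  @nn.compact
  def __call__(self, enc_x, dec_x):
    # The encoder returns its final carry, which seeds the decoder's scan.
    carry, _ = nn.RNN(nn.LSTMCell(), self.hidden_size, return_carry=True, name='encoder')(enc_x)
    return nn.RNN(nn.LSTMCell(), self.hidden_size, name='decoder')(dec_x, initial_carry=carry)

model = TinySeq2Seq()
enc_x = jnp.ones((2, 5, 8))                      # (batch, enc_time, features)
dec_x = jnp.ones((2, 7, 8))                      # (batch, dec_time, features)
y, variables = model.init_with_output(jax.random.PRNGKey(0), enc_x, dec_x)
# y.shape == (2, 7, 16)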