diff --git a/docs/api_reference/flax.linen.rst b/docs/api_reference/flax.linen.rst index 0fe78e2a6d..6fa2731b84 100644 --- a/docs/api_reference/flax.linen.rst +++ b/docs/api_reference/flax.linen.rst @@ -268,3 +268,5 @@ RNN primitives LSTMCell OptimizedLSTMCell GRUCell + RNNCellBase + RNN diff --git a/examples/seq2seq/models.py b/examples/seq2seq/models.py index 28a67000cd..67c9c4f67d 100644 --- a/examples/seq2seq/models.py +++ b/examples/seq2seq/models.py @@ -25,65 +25,12 @@ import jax.numpy as jnp import numpy as np -Array = Any -PRNGKey = Any +Array = jax.Array +PRNGKey = jax.random.KeyArray -class EncoderLSTM(nn.Module): - """EncoderLSTM Module wrapped in a lifted scan transform.""" - eos_id: int - - @functools.partial( - nn.scan, - variable_broadcast='params', - in_axes=1, - out_axes=1, - split_rngs={'params': False}) - @nn.compact - def __call__(self, carry: Tuple[Array, Array], - x: Array) -> Tuple[Tuple[Array, Array], Array]: - """Applies the module.""" - lstm_state, is_eos = carry - new_lstm_state, y = nn.LSTMCell()(lstm_state, x) - # Pass forward the previous state if EOS has already been reached. - def select_carried_state(new_state, old_state): - return jnp.where(is_eos[:, np.newaxis], old_state, new_state) - # LSTM state is a tuple (c, h). - carried_lstm_state = tuple( - select_carried_state(*s) for s in zip(new_lstm_state, lstm_state)) - # Update `is_eos`. - is_eos = jnp.logical_or(is_eos, x[:, self.eos_id]) - return (carried_lstm_state, is_eos), y - - @staticmethod - def initialize_carry(batch_size: int, hidden_size: int): - # Use a dummy key since the default state init fn is just zeros. - return nn.LSTMCell.initialize_carry( - jax.random.PRNGKey(0), (batch_size,), hidden_size) - - -class Encoder(nn.Module): - """LSTM encoder, returning state after finding the EOS token in the input.""" - hidden_size: int - eos_id: int - - @nn.compact - def __call__(self, inputs: Array): - # inputs.shape = (batch_size, seq_length, vocab_size). - batch_size = inputs.shape[0] - lstm = EncoderLSTM(name='encoder_lstm', eos_id=self.eos_id) - init_lstm_state = lstm.initialize_carry(batch_size, self.hidden_size) - # We use the `is_eos` array to determine whether the encoder should carry - # over the last lstm state, or apply the LSTM cell on the previous state. - init_is_eos = jnp.zeros(batch_size, dtype=bool) - init_carry = (init_lstm_state, init_is_eos) - (final_state, _), _ = lstm(init_carry, inputs) - return final_state - - -class DecoderLSTM(nn.Module): +class DecoderLSTMCell(nn.RNNCellBase): """DecoderLSTM Module wrapped in a lifted scan transform. - Attributes: teacher_force: See docstring on Seq2seq module. vocab_size: Size of the vocabulary. @@ -91,12 +38,6 @@ class DecoderLSTM(nn.Module): teacher_force: bool vocab_size: int - @functools.partial( - nn.scan, - variable_broadcast='params', - in_axes=1, - out_axes=1, - split_rngs={'params': False, 'lstm': True}) @nn.compact def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array: """Applies the DecoderLSTM model.""" @@ -116,40 +57,6 @@ def __call__(self, carry: Tuple[Array, Array], x: Array) -> Array: return (lstm_state, prediction), (logits, prediction) -class Decoder(nn.Module): - """LSTM decoder. - - Attributes: - init_state: [batch_size, hidden_size] - Initial state of the decoder (i.e., the final state of the encoder). - teacher_force: See docstring on Seq2seq module. - vocab_size: Size of the vocabulary. 
- """ - init_state: Tuple[Any] - teacher_force: bool - vocab_size: int - - @nn.compact - def __call__(self, inputs: Array) -> Tuple[Array, Array]: - """Applies the decoder model. - - Args: - inputs: [batch_size, max_output_len-1, vocab_size] - Contains the inputs to the decoder at each time step (only used when not - using teacher forcing). Since each token at position i is fed as input - to the decoder at position i+1, the last token is not provided. - - Returns: - Pair (logits, predictions), which are two arrays of respectively decoded - logits and predictions (in one hot-encoding format). - """ - lstm = DecoderLSTM(teacher_force=self.teacher_force, - vocab_size=self.vocab_size) - init_carry = (self.init_state, inputs[:, 0]) - _, (logits, predictions) = lstm(init_carry, inputs) - return logits, predictions - - class Seq2seq(nn.Module): """Sequence-to-sequence class using encoder/decoder architecture. @@ -189,12 +96,25 @@ def __call__(self, encoder_inputs: Array, encoding format). """ # Encode inputs. - init_decoder_state = Encoder( - hidden_size=self.hidden_size, eos_id=self.eos_id)(encoder_inputs) - # Decode outputs. - logits, predictions = Decoder( - init_state=init_decoder_state, - teacher_force=self.teacher_force, - vocab_size=self.vocab_size)(decoder_inputs[:, :-1]) + encoder = nn.RNN(nn.LSTMCell(), self.hidden_size, return_carry=True, name='encoder') + decoder = nn.RNN(DecoderLSTMCell(self.teacher_force, self.vocab_size), decoder_inputs.shape[-1], + split_rngs={'params': False, 'lstm': True}, name='decoder') + + segmentation_mask = self.get_segmentation_mask(encoder_inputs) + + encoder_state, _ = encoder(encoder_inputs, segmentation_mask=segmentation_mask) + logits, predictions = decoder(decoder_inputs[:, :-1], initial_carry=(encoder_state, decoder_inputs[:, 0])) return logits, predictions + + def get_segmentation_mask(self, inputs: Array) -> Array: + """Get segmentation mask for inputs.""" + # undo one-hot encoding + inputs = jnp.argmax(inputs, axis=-1) + # calculate eos index + eos_idx = jnp.argmax(inputs == self.eos_id, axis=-1, keepdims=True) + # create index array + indexes = jnp.arange(inputs.shape[1]) + indexes = jnp.broadcast_to(indexes, inputs.shape[:2]) + # return mask + return indexes < eos_idx diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index dd03d25604..44a04cc993 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -120,7 +120,9 @@ ConvLSTMCell as ConvLSTMCell, GRUCell as GRUCell, LSTMCell as LSTMCell, - OptimizedLSTMCell as OptimizedLSTMCell + OptimizedLSTMCell as OptimizedLSTMCell, + RNNCellBase as RNNCellBase, + RNN as RNN, ) from .stochastic import Dropout as Dropout from .transforms import ( diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 277c833fde..810722b767 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -18,9 +18,12 @@ see: https://flax.readthedocs.io/en/latest/advanced_topics/lift.html. 
""" +from abc import abstractmethod import abc from functools import partial # pylint: disable=g-importing-member -from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union, TypeVar +from typing_extensions import Protocol +from absl import logging from flax.linen.activation import sigmoid from flax.linen.activation import tanh @@ -35,11 +38,19 @@ from jax import numpy as jnp from jax import random import numpy as np +from flax.core import lift +from flax.core.frozen_dict import FrozenDict +from flax.linen import transforms +import jax +A = TypeVar('A') PRNGKey = Any Shape = Tuple[int, ...] Dtype = Any # this could be a real type? Array = Any +Carry = Any +CarryHistory = Any +Output = Any class RNNCellBase(Module): @@ -498,3 +509,229 @@ def initialize_carry(rng, batch_dims, size, init_fn=initializers.zeros_init()): key1, key2 = random.split(rng) mem_shape = batch_dims + size return init_fn(key1, mem_shape), init_fn(key2, mem_shape) + +class RNN(Module): + """The ``RNN`` module takes any :class:`RNNCellBase` instance and applies it over a sequence + using :func:`flax.linen.scan`. + + Example:: + + >>> import jax.numpy as jnp + >>> import jax + >>> import flax.linen as nn + ... + >>> x = jnp.ones((10, 50, 32)) # (batch, time, features) + >>> lstm = nn.RNN(nn.LSTMCell(), cell_size=64) + >>> variables = lstm.init(jax.random.PRNGKey(0), x) + >>> y = lstm.apply(variables, x) + >>> y.shape # (batch, time, cell_size) + (10, 50, 64) + + As shown above, RNN uses the ``cell_size`` argument to set the ``size`` argument for the cell's + ``initialize_carry`` method, in practice this is typically the number of hidden units you want + for the cell. However, this may vary depending on the cell you are using, for example the + :class:`ConvLSTMCell` requires a ``size`` argument of the form + ``(kernel_height, kernel_width, features)``:: + + >>> x = jnp.ones((10, 50, 32, 32, 3)) # (batch, time, height, width, features) + >>> conv_lstm = nn.RNN(nn.ConvLSTMCell(64, kernel_size=(3, 3)), cell_size=(32, 32, 64)) + >>> y, variables = conv_lstm.init_with_output(jax.random.PRNGKey(0), x) + >>> y.shape # (batch, time, height, width, features) + (10, 50, 32, 32, 64) + + By default RNN expect the time dimension after the batch dimension (``(*batch, time, *features)``), + if you set ``time_major=True`` RNN will instead expect the time dimesion to be at the beginning + (``(time, *batch, *features)``):: + + >>> x = jnp.ones((50, 10, 32)) # (time, batch, features) + >>> lstm = nn.RNN(nn.LSTMCell(), cell_size=64, time_major=True) + >>> variables = lstm.init(jax.random.PRNGKey(0), x) + >>> y = lstm.apply(variables, x) + >>> y.shape # (time, batch, cell_size) + (50, 10, 64) + + The output is an array of shape ``(*batch, time, *cell_size)`` by default (typically), however + if you set ``return_carry=True`` it will instead return a tuple of the final carry and the output:: + + >>> x = jnp.ones((10, 50, 32)) # (batch, time, features) + >>> lstm = nn.RNN(nn.LSTMCell(), cell_size=64, return_carry=True) + >>> variables = lstm.init(jax.random.PRNGKey(0), x) + >>> carry, y = lstm.apply(variables, x) + >>> jax.tree_map(jnp.shape, carry) # ((batch, cell_size), (batch, cell_size)) + ((10, 64), (10, 64)) + >>> y.shape # (batch, time, cell_size) + (10, 50, 64) + + To support variable length sequences, you can pass a ``segmentation_mask`` which is an integer + array of shape ``(*batch, time)``, where a 1 indicates the element is part of the 
sequence and a 0 indicates
+  a padding element. Sequences must be padded to the right, i.e. all elements of a sequence must be
+  contiguous and padded elements must be to the right of the sequence. For example::
+
+    >>> # 3 sequences with max length 5
+    >>> segmentation_mask = jnp.array([
+    ...   [1, 1, 1, 0, 0], # length 3
+    ...   [1, 1, 0, 0, 0], # length 2
+    ...   [1, 1, 1, 1, 1], # length 5
+    ... ])
+
+  We use this integer mask format because it's compatible with sequence packing, which might be
+  implemented in the future. The output elements corresponding to padding elements are NOT
+  zeroed out. If ``return_carry`` is set to ``True``, the carry will be the state of the last
+  valid element of each sequence.
+
+  ``RNN`` also accepts some of the arguments of :func:`flax.linen.scan`; by default they are set to
+  work with cells like :class:`LSTMCell` and :class:`GRUCell`, but they can be overridden as needed.
+  Overriding the scan defaults looks like this::
+
+    >>> lstm = nn.RNN(
+    ...   nn.LSTMCell(), cell_size=64,
+    ...   unroll=1, variable_axes={}, variable_broadcast='params',
+    ...   variable_carry=False, split_rngs={'params': False})
+
+  Attributes:
+    cell: an instance of :class:`RNNCellBase`.
+    cell_size: the size of the cell as requested by :meth:`RNNCellBase.initialize_carry`;
+      it can be an integer or a tuple of integers.
+    time_major: if ``time_major=False`` (default) it will expect inputs with shape
+      ``(*batch, time, *features)``, else it will expect inputs with shape ``(time, *batch, *features)``.
+    return_carry: if ``return_carry=False`` (default) only the output sequence is returned,
+      else it will return a tuple of the final carry and the output sequence.
+    unroll: how many scan iterations to unroll within a single iteration of a loop,
+      defaults to 1. This argument will be passed to ``nn.scan``.
+    variable_axes: a dictionary mapping each collection to either an integer ``i`` (meaning we scan over
+      dimension ``i``) or ``None`` (replicate rather than scan). This argument is forwarded to ``nn.scan``.
+    variable_broadcast: specifies the broadcasted variable collections. A
+      broadcasted variable should not depend on any computation that cannot be
+      lifted out of the loop. This is typically used to define shared parameters
+      inside the fn. This argument is forwarded to ``nn.scan``.
+    variable_carry: specifies the variable collections that are carried through
+      the loop. Mutations to these variables are carried to the next iteration
+      and will be preserved when the scan finishes. This argument is forwarded to
+      ``nn.scan``.
+    split_rngs: a mapping from ``PRNGSequenceFilter`` to bool specifying whether a collection's
+      PRNG key should be split such that its values are different at each step, or replicated
+      such that its values remain the same at each step. This argument is forwarded to
+      ``nn.scan``.
+  """
+  cell: RNNCellBase
+  cell_size: Union[int, Tuple[int, ...]]
+  time_major: bool = False
+  return_carry: bool = False
+  unroll: int = 1
+  variable_axes: Mapping[lift.CollectionFilter, lift.InOutScanAxis] = FrozenDict()
+  variable_broadcast: lift.CollectionFilter = 'params'
+  variable_carry: lift.CollectionFilter = False
+  split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict({'params': False})
+
+  def __call__(
+    self,
+    inputs: jax.Array,
+    *,
+    initial_carry: Optional[Carry] = None,
+    init_key: Optional[random.KeyArray] = None,
+    segmentation_mask: Optional[Array] = None,
+    return_carry: Optional[bool] = None,
+    time_major: Optional[bool] = None
+  ) -> Union[Output, Tuple[Carry, Output]]:
+    """Applies the RNN to the inputs.
+
+    ``__call__`` allows you to optionally override some attributes like ``return_carry``
+    and ``time_major`` defined in the constructor.
+
+    Args:
+      inputs: the input sequence.
+      initial_carry: the initial carry; if not provided it will be initialized
+        using the cell's :meth:`RNNCellBase.initialize_carry` method.
+      init_key: a PRNG key used to initialize the carry; if not provided
+        ``jax.random.PRNGKey(0)`` will be used. Most cells will ignore this
+        argument.
+      segmentation_mask: an integer array of shape ``(*batch, time)`` indicating
+        which elements are part of the sequence and which are padding elements.
+      return_carry: if given, overrides the ``return_carry`` attribute set in the
+        constructor.
+      time_major: if given, overrides the ``time_major`` attribute set in the
+        constructor.
+
+    Returns:
+      If ``return_carry=False`` (default) only the output sequence is returned,
+      else a tuple of the final carry and the output sequence is returned.
+    """
+
+    if return_carry is None:
+      return_carry = self.return_carry
+
+    if time_major is None:
+      time_major = self.time_major
+
+    # Infer the number of batch dimensions from the input shape.
+    # Cells like ConvLSTM have additional spatial dimensions.
+    num_features_dims = 1 if isinstance(self.cell_size, int) else len(self.cell_size)
+    time_axis = 0 if time_major else inputs.ndim - num_features_dims - 1
+    if time_major:
+      batch_dims = inputs.shape[1:-num_features_dims]
+    else:
+      batch_dims = inputs.shape[:time_axis]
+
+    carry: Carry
+    if initial_carry is None:
+      if init_key is None:
+        init_key = random.PRNGKey(0)
+      carry = self.cell.initialize_carry(
+        init_key, batch_dims=batch_dims, size=self.cell_size)
+    else:
+      carry = initial_carry
+
+    def scan_fn(
+      cell: RNNCellBase, carry: Carry, x: Array
+    ) -> Union[Tuple[Carry, Array], Tuple[Carry, Tuple[Carry, Array]]]:
+      carry, y = cell(carry, x)
+      # When we have a segmentation mask we return the carry as an output
+      # so that we can select the last carry for each sequence later.
+      # This uses more memory but is faster than using jnp.where at each
+      # iteration. As a small optimization, we only do this when it is really needed.
+ if segmentation_mask is not None and return_carry: + return carry, (carry, y) + else: + return carry, y + + scan = transforms.scan( + scan_fn, + in_axes=time_axis, + out_axes=time_axis if segmentation_mask is None else (0, time_axis), + unroll=self.unroll, + variable_axes=self.variable_axes, + variable_broadcast=self.variable_broadcast, + variable_carry=self.variable_carry, + split_rngs=self.split_rngs, + ) + + scan_output = scan(self.cell, carry, inputs) + + # Next we select the final carry. If a segmentation mask was provided and + # return_carry is True we slice the carry history and select the last valid + # carry for each sequence. Otherwise we just use the last carry. + if segmentation_mask is not None and return_carry: + _, (carries, outputs) = scan_output + # segmentation_mask[None] expands the shape of the mask to match the + # number of dimensions of the carry. + carry = _select_last(carries, segmentation_mask[None], axis=0) + else: + carry, outputs = scan_output + + if return_carry: + return carry, outputs + else: + return outputs + +def _select_last(sequence: A, segmentation_mask: jnp.ndarray, axis: int) -> A: + last_idx = segmentation_mask.sum(axis=-1) - 1 + + def _slice_array(x: jnp.ndarray): + _last_idx = _expand_dims_for(last_idx, to_target=x) + x = jnp.take_along_axis(x, _last_idx, axis=axis) + return x.squeeze(axis=axis) + + return jax.tree_map(_slice_array, sequence) + +def _expand_dims_for(x, *, to_target): + """Expands the shape of 'x' to match those of 'to_target' by adding singleton dimensions.""" + return x.reshape(x.shape + (1,) * (to_target.ndim - x.ndim)) \ No newline at end of file diff --git a/tests/linen/linen_recurrent_test.py b/tests/linen/linen_recurrent_test.py new file mode 100644 index 0000000000..f4cd397b6d --- /dev/null +++ b/tests/linen/linen_recurrent_test.py @@ -0,0 +1,321 @@ + +# Copyright 2022 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Recurrent tests.""" + + +from absl.testing import absltest +import jax +import jax.numpy as jnp +import numpy as np +from flax import errors +from flax import linen as nn +import pytest +import einops +from flax.linen.recurrent import _select_last + +# Parse absl flags test_srcdir and test_tmpdir. 
+jax.config.parse_flags_with_absl() + + +class RNNTest(absltest.TestCase): + def test_rnn_basic_forward(self): + batch_size = 10 + seq_len = 40 + channels_in = 5 + channels_out = 15 + + rnn = nn.RNN(nn.LSTMCell(), channels_out, return_carry=True) + + xs = jnp.ones((batch_size, seq_len, channels_in)) + variables = rnn.init(jax.random.PRNGKey(0), xs) + ys: jnp.ndarray + carry, ys = rnn.apply(variables, xs) + + self.assertEqual(ys.shape, (batch_size, seq_len, channels_out)) + + for layer_params in variables['params']['cell'].values(): + if 'bias' in layer_params: + self.assertEqual(layer_params['bias'].shape, (channels_out,)) + self.assertIn(layer_params['kernel'].shape[0], [channels_in, channels_out]) + self.assertEqual(layer_params['kernel'].shape[1], channels_out) + + def test_rnn_multiple_batch_dims(self): + batch_dims = (10, 11) + seq_len = 40 + channels_in = 5 + channels_out = 15 + + rnn = nn.RNN(nn.LSTMCell(), channels_out, return_carry=True) + + xs = jnp.ones((*batch_dims, seq_len, channels_in)) + variables = rnn.init(jax.random.PRNGKey(0), xs) + ys: jnp.ndarray + carry, ys = rnn.apply(variables, xs) + + self.assertEqual(ys.shape, (*batch_dims, seq_len, channels_out)) + + for layer_params in variables['params']['cell'].values(): + if 'bias' in layer_params: + self.assertEqual(layer_params['bias'].shape, (channels_out,)) + self.assertIn(layer_params['kernel'].shape[0], [channels_in, channels_out]) + self.assertEqual(layer_params['kernel'].shape[1], channels_out) + + def test_rnn_unroll(self): + batch_size = 10 + seq_len = 40 + channels_in = 5 + channels_out = 15 + + rnn = nn.RNN(nn.LSTMCell(), channels_out, unroll=10, return_carry=True) + + xs = jnp.ones((batch_size, seq_len, channels_in)) + variables = rnn.init(jax.random.PRNGKey(0), xs) + ys: jnp.ndarray + carry, ys = rnn.apply(variables, xs) + + self.assertEqual(ys.shape, (batch_size, seq_len, channels_out)) + + for layer_params in variables['params']['cell'].values(): + if 'bias' in layer_params: + self.assertEqual(layer_params['bias'].shape, (channels_out,)) + self.assertIn(layer_params['kernel'].shape[0], [channels_in, channels_out]) + self.assertEqual(layer_params['kernel'].shape[1], channels_out) + + def test_rnn_time_major(self): + seq_len = 40 + batch_size = 10 + channels_in = 5 + channels_out = 15 + + rnn = nn.RNN(nn.LSTMCell(), channels_out, time_major=True, return_carry=True) + + xs = jnp.ones((seq_len, batch_size, channels_in)) + variables = rnn.init(jax.random.PRNGKey(0), xs) + + ys: jnp.ndarray + carry, ys = rnn.apply(variables, xs) + + # carry state should not be zeros after apply + for leaf in jax.tree_util.tree_leaves(carry): + assert not np.allclose(leaf, jnp.zeros_like(leaf)) + self.assertEqual(leaf.shape, (batch_size, channels_out)) + + self.assertEqual(ys.shape, (seq_len, batch_size, channels_out)) + + for layer_params in variables['params']['cell'].values(): + if 'bias' in layer_params: + self.assertEqual(layer_params['bias'].shape, (channels_out,)) + self.assertIn(layer_params['kernel'].shape[0], [channels_in, channels_out]) + self.assertEqual(layer_params['kernel'].shape[1], channels_out) + + def test_rnn_with_spatial_dimensions(self): + batch_size = 10 + seq_len = 40 + kernel_size = (3, 3) + image_size = (32, 32) + channels_in = 5 + channels_out = 15 + + rnn = nn.RNN( + nn.ConvLSTMCell(channels_out, kernel_size), + cell_size=(*image_size, channels_out), + ) + + xs = jnp.ones((batch_size, seq_len, *image_size, channels_in)) + variables = rnn.init(jax.random.PRNGKey(0), xs) + + ys: jnp.ndarray + carry, ys = 
rnn.apply(variables, xs, return_carry=True) + + # carry state should not be zeros after apply + for leaf in jax.tree_util.tree_leaves(carry): + assert not np.allclose(leaf, jnp.zeros_like(leaf)) + self.assertEqual(leaf.shape[:-1], (batch_size, *image_size)) + self.assertIn(leaf.shape[-1], [channels_in, channels_out]) + + self.assertEqual(ys.shape, (batch_size, seq_len, *image_size, channels_out)) + + for layer_params in variables['params']['cell'].values(): + if 'bias' in layer_params: + self.assertEqual(layer_params['bias'].shape, (channels_out * 4,)) + self.assertIn(layer_params['kernel'].shape[2], [channels_in, channels_out, channels_out * 4]) + self.assertEqual(layer_params['kernel'].shape[3], channels_out * 4) + + @pytest.mark.skip(reason='TODO: discuss supporting reverse instead of flip_sequences') + def test_go_backwards(self): + batch_size = 10 + seq_len = 40 + channels_in = 5 + channels_out = 15 + + rnn = nn.RNN(nn.LSTMCell(), channels_out, reverse=True) + + xs = jnp.ones((batch_size, seq_len, channels_in)) + variables = rnn.init(jax.random.PRNGKey(0), xs) + + # carry state should be zeros on init + for carry in jax.tree_leaves(variables['memory']['carry']): + assert np.allclose(carry, jnp.zeros_like(carry)) + self.assertEqual(carry.shape, (batch_size, channels_out)) + + ys: jnp.ndarray + ys, updates = rnn.apply(variables, xs, mutable=['memory']) + variables = variables.copy(updates) + + # carry state should not be zeros after apply + for carry in jax.tree_leaves(variables['memory']['carry']): + assert not np.allclose(carry, jnp.zeros_like(carry)) + self.assertEqual(carry.shape, (batch_size, channels_out)) + + self.assertEqual(ys.shape, (batch_size, seq_len, channels_out)) + + for layer_params in variables['params']['cell'].values(): + if 'bias' in layer_params: + self.assertEqual(layer_params['bias'].shape, (channels_out,)) + self.assertIn(layer_params['kernel'].shape[0], [channels_in, channels_out]) + self.assertEqual(layer_params['kernel'].shape[1], channels_out) + + def test_numerical_equivalence(self): + batch_size = 3 + seq_len = 4 + channels_in = 5 + channels_out = 6 + + rnn = nn.RNN(nn.LSTMCell(), channels_out, return_carry=True) + + xs = jnp.ones((batch_size, seq_len, channels_in)) + ys: jnp.ndarray + (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs) + + cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), channels_out) + cell_params = variables['params']['cell'] + + for i in range(seq_len): + cell_carry, y = rnn.cell.apply({'params': cell_params}, cell_carry, xs[:, i, :]) + np.testing.assert_allclose(y, ys[:, i, :], rtol=1e-5) + + np.testing.assert_allclose(cell_carry, carry, rtol=1e-5) + + def test_numerical_equivalence_with_mask(self): + batch_size = 3 + seq_len = 4 + channels_in = 5 + channels_out = 6 + + key = jax.random.PRNGKey(0) + seq_lengths = jax.random.randint(key, (batch_size,), minval=1, maxval=seq_len + 1) + segmentation_mask = einops.repeat( + jnp.arange(seq_len), 'time -> batch time', batch=batch_size) + segmentation_mask = (segmentation_mask < seq_lengths[:, None]).astype(jnp.int32) + + rnn = nn.RNN(nn.LSTMCell(), channels_out, return_carry=True) + + xs = jnp.ones((batch_size, seq_len, channels_in)) + ys: jnp.ndarray + (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs, segmentation_mask=segmentation_mask) + + cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), channels_out) + cell_params = variables['params']['cell'] + carries = [] + + for i in range(seq_len): + 
cell_carry, y = rnn.cell.apply({'params': cell_params}, cell_carry, xs[:, i, :]) + np.testing.assert_allclose(y, ys[:, i, :], rtol=1e-5) + carries.append(cell_carry) + + for batch_idx, length in enumerate(seq_lengths): + t = int(length) - 1 + for carries_t_, carry_ in zip(carries[t], carry): + np.testing.assert_allclose(carries_t_[batch_idx], carry_[batch_idx], rtol=1e-5) + + @pytest.mark.skip(reason='TODO: possible bug with scan') + def test_numerical_equivalence_single_batch(self): + batch_size = 3 + seq_len = 4 + channels_in = 5 + channels_out = 6 + + rnn = nn.RNN(nn.LSTMCell(), channels_out, return_carry=True) + + xs = jnp.ones((batch_size, seq_len, channels_in)) + ys: jnp.ndarray + (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs) + + cell_params = variables['params']['cell'] + + for batch_idx in range(batch_size): + cell_carry = rnn.cell.initialize_carry(jax.random.PRNGKey(0), (1,), channels_out) + + for i in range(seq_len): + cell_carry, y = rnn.cell.apply({'params': cell_params}, cell_carry, xs[batch_idx, i, :][None]) + np.testing.assert_allclose(y[0], ys[batch_idx, i, :]) + + np.testing.assert_allclose(cell_carry, carry) + + def test_numerical_equivalence_single_batch_nn_scan(self): + batch_size = 3 + seq_len = 4 + channels_in = 5 + channels_out = 6 + + cell = nn.LSTMCell() + rnn = nn.scan(nn.LSTMCell, in_axes=1, out_axes=1, + variable_broadcast='params', + split_rngs={'params': False})() + + xs = jnp.ones((batch_size, seq_len, channels_in)) + carry = rnn.initialize_carry(jax.random.PRNGKey(0), (batch_size,), channels_out) + ys: jnp.ndarray + (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), carry, xs) + + cell_params = variables['params'] + + for batch_idx in range(batch_size): + cell_carry = cell.initialize_carry(jax.random.PRNGKey(0), (1,), channels_out) + + for i in range(seq_len): + cell_carry, y = cell.apply({'params': cell_params}, cell_carry, xs[batch_idx:batch_idx+1, i, :]) + np.testing.assert_allclose(y[0], ys[batch_idx, i, :], rtol=1e-5) + + carry_i = jax.tree_map(lambda x: x[batch_idx:batch_idx+1], carry) + np.testing.assert_allclose(cell_carry, carry_i, rtol=1e-5) + + def test_numerical_equivalence_single_batch_jax_scan(self): + batch_size = 3 + seq_len = 4 + channels_in = 5 + channels_out = 6 + + xs = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, seq_len, channels_in)) + cell = nn.LSTMCell() + carry = cell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), channels_out) + variables = cell.init(jax.random.PRNGKey(0), carry, xs[:, 0]) + cell_params = variables['params'] + + def scan_fn(carry, x): + return cell.apply({'params': cell_params}, carry, x) + + ys: jnp.ndarray + carry, ys = jax.lax.scan(scan_fn, carry, xs.swapaxes(0, 1)) + ys = ys.swapaxes(0, 1) + + cell_carry = cell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), channels_out) + + for i in range(seq_len): + cell_carry, y = cell.apply({'params': cell_params}, cell_carry, xs[:, i, :]) + np.testing.assert_allclose(y, ys[:, i, :], rtol=1e-4) + + np.testing.assert_allclose(cell_carry, carry, rtol=1e-4) \ No newline at end of file
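
Reviewer aside: below is a minimal, self-contained usage sketch of the API introduced by this patch, for reference only. It assumes the branch under review (``nn.RNN``, its ``cell_size`` constructor argument, and the ``segmentation_mask`` call argument are added here and are not available in released Flax); the ``lengths`` array and the mask construction are illustrative, not part of the diff.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    batch_size, seq_len, features, hidden = 4, 7, 3, 8
    x = jnp.ones((batch_size, seq_len, features))

    # Right-padded variable-length sequences: 1 marks a valid step, 0 marks padding.
    lengths = jnp.array([7, 3, 5, 1])
    segmentation_mask = (jnp.arange(seq_len)[None, :] < lengths[:, None]).astype(jnp.int32)

    rnn = nn.RNN(nn.LSTMCell(), cell_size=hidden, return_carry=True)
    variables = rnn.init(jax.random.PRNGKey(0), x)

    # With a mask, the returned carry comes from each sequence's last valid step,
    # not from the final (possibly padded) step; padded outputs are NOT zeroed.
    carry, y = rnn.apply(variables, x, segmentation_mask=segmentation_mask)
    print(jax.tree_map(jnp.shape, carry))  # ((4, 8), (4, 8)) -- the LSTM (c, h) pair
    print(y.shape)                         # (4, 7, 8)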