RF lstm cleanup weight, bias
albertz committed Apr 18, 2023
1 parent 6d88d6c commit 8462ffe
Showing 3 changed files with 27 additions and 32 deletions.
returnn/frontend/_backend.py (14 changes: 6 additions & 8 deletions)
@@ -912,10 +912,9 @@ def lstm(
         *,
         state_c: Tensor,
         state_h: Tensor,
-        ff_weights: Tensor,
-        ff_biases: Tensor,
-        rec_weights: Tensor,
-        rec_biases: Tensor,
+        ff_weight: Tensor,
+        rec_weight: Tensor,
+        bias: Tensor,
         spatial_dim: Dim,
         in_dim: Dim,
         out_dim: Dim,
@@ -926,10 +925,9 @@ def lstm(
         :param source: Tensor of shape [*, in_dim].
         :param state_c:
         :param state_h:
-        :param ff_weights: Parameters for the weights of the feed-forward part.
-        :param ff_biases: Parameters for the biases of the feed-forward part.
-        :param rec_weights: Parameters for the weights of the recurrent part.
-        :param rec_biases: Parameters for the biases of the recurrent part.
+        :param ff_weight: Parameters for the weights of the feed-forward part.
+        :param rec_weight: Parameters for the weights of the recurrent part.
+        :param bias: Parameters for the bias.
         :param spatial_dim: Dimension in which the LSTM operates.
         :param in_dim:
         :param out_dim:
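Why a single bias suffices: in the LSTM gate pre-activation, the feed-forward and recurrent biases only ever enter as a sum, ff_weight @ x + rec_weight @ h + b_ff + b_rec, so merging them into one bias parameter loses no expressiveness. A minimal standalone NumPy sketch of that identity (hypothetical names, not code from this commit):

import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim = 3, 5
x = rng.normal(size=in_dim)       # input frame
h = rng.normal(size=out_dim)      # previous hidden state
ff_weight = rng.normal(size=(4 * out_dim, in_dim))
rec_weight = rng.normal(size=(4 * out_dim, out_dim))
ff_bias = rng.normal(size=4 * out_dim)   # old API: separate feed-forward bias
rec_bias = rng.normal(size=4 * out_dim)  # old API: separate recurrent bias

bias = ff_bias + rec_bias  # new API: one merged bias parameter

old_preact = (ff_weight @ x + ff_bias) + (rec_weight @ h + rec_bias)
new_preact = ff_weight @ x + rec_weight @ h + bias
assert np.allclose(old_preact, new_preact)  # identical gate pre-activations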
returnn/frontend/rec.py (24 changes: 10 additions & 14 deletions)
@@ -35,18 +35,15 @@ def __init__(
         self.in_dim = in_dim
         self.out_dim = out_dim
 
-        self.ff_weights = rf.Parameter((4 * self.out_dim, self.in_dim))  # type: Tensor[T]
-        self.ff_weights.initial = rf.init.Glorot()
-        self.recurrent_weights = rf.Parameter((4 * self.out_dim, self.out_dim))  # type: Tensor[T]
-        self.recurrent_weights.initial = rf.init.Glorot()
+        self.ff_weight = rf.Parameter((4 * self.out_dim, self.in_dim))  # type: Tensor[T]
+        self.ff_weight.initial = rf.init.Glorot()
+        self.rec_weight = rf.Parameter((4 * self.out_dim, self.out_dim))  # type: Tensor[T]
+        self.rec_weight.initial = rf.init.Glorot()
 
-        self.ff_biases = None
-        self.recurrent_biases = None
+        self.bias = None
         if with_bias:
-            self.ff_biases = rf.Parameter((4 * self.out_dim,))  # type: Tensor[T]
-            self.ff_biases.initial = 0.0
-            self.recurrent_biases = rf.Parameter((4 * self.out_dim,))  # type: Tensor[T]
-            self.recurrent_biases.initial = 0.0
+            self.bias = rf.Parameter((4 * self.out_dim,))  # type: Tensor[T]
+            self.bias.initial = 0.0
 
     def __call__(self, source: Tensor[T], *, state: LstmState, spatial_dim: Dim) -> Tuple[Tensor, LstmState]:
         """
@@ -67,10 +64,9 @@ def __call__(self, source: Tensor[T], *, state: LstmState, spatial_dim: Dim) ->
             source=source,
            state_c=state.c,
             state_h=state.h,
-            ff_weights=self.ff_weights,
-            ff_biases=self.ff_biases,
-            rec_weights=self.recurrent_weights,
-            rec_biases=self.recurrent_biases,
+            ff_weight=self.ff_weight,
+            rec_weight=self.rec_weight,
+            bias=self.bias,
             spatial_dim=spatial_dim,
             in_dim=self.in_dim,
             out_dim=self.out_dim,
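For context, a usage sketch of the module after the rename. This assumes the class in returnn/frontend/rec.py is exposed as rf.LSTM and that the Dim setup below is valid; none of this appears in the commit itself:

import returnn.frontend as rf
from returnn.tensor import Dim

in_dim = Dim(10, name="in")
out_dim = Dim(20, name="out")
lstm = rf.LSTM(in_dim, out_dim)  # with_bias=True assumed as default

# Attribute names after this commit:
print(lstm.ff_weight.dims)   # (4 * out_dim, in_dim)
print(lstm.rec_weight.dims)  # (4 * out_dim, out_dim)
print(lstm.bias.dims)        # (4 * out_dim,)

Note the old attribute names (ff_weights, recurrent_weights, ff_biases, recurrent_biases) are gone, so external code referring to them would need updating.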
returnn/torch/frontend/_backend.py (21 changes: 11 additions & 10 deletions)
@@ -1145,10 +1145,9 @@ def lstm(
         *,
         state_h: _TT,
         state_c: _TT,
-        ff_weights: _TT,
-        ff_biases: Optional[_TT],
-        rec_weights: _TT,
-        rec_biases: Optional[_TT],
+        ff_weight: _TT,
+        rec_weight: _TT,
+        bias: Optional[_TT],
         spatial_dim: Dim,
         in_dim: Dim,
         out_dim: Dim,
@@ -1159,16 +1158,18 @@ def lstm(
         :return: Tuple consisting of two elements: the result as a :class:`Tensor`
             and the new state as a :class:`State` (different from the previous one).
         """
-        if ff_biases is None and rec_biases is None:
-            lstm_params = (ff_weights.raw_tensor, rec_weights.raw_tensor)
+        if bias is None:
+            lstm_params = (ff_weight.raw_tensor, rec_weight.raw_tensor)
             has_biases = False
         else:
-            assert (
-                ff_biases is not None and rec_biases is not None
-            ), "A bias in the LSTM (feed-forward or recurrent) is set while the other is unset."
             # Feed-forward has priority over recurrent, and weights have priority over biases. See the torch docstring
             # or torch LSTMCell: https://github.com/pytorch/pytorch/blob/4bead64/aten/src/ATen/native/RNN.cpp#L1458
-            lstm_params = (ff_weights.raw_tensor, rec_weights.raw_tensor, ff_biases.raw_tensor, rec_biases.raw_tensor)
+            lstm_params = (
+                ff_weight.raw_tensor,
+                rec_weight.raw_tensor,
+                bias.raw_tensor * 0.5,
+                bias.raw_tensor * 0.5,
+            )
             has_biases = True
 
         batch_dims = [d for d in source.dims if d != spatial_dim and d != in_dim]
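The bias.raw_tensor * 0.5 pair mirrors how PyTorch sums b_ih and b_hh inside its LSTM kernel: putting half of the merged bias in each slot reproduces the full bias exactly. A standalone PyTorch check of that equivalence (illustration only, not RETURNN code):

import torch

torch.manual_seed(0)
in_dim, out_dim = 3, 5
cell_a = torch.nn.LSTMCell(in_dim, out_dim)
cell_b = torch.nn.LSTMCell(in_dim, out_dim)
bias = torch.randn(4 * out_dim)

with torch.no_grad():
    # Same weights in both cells, biases distributed differently.
    cell_b.weight_ih.copy_(cell_a.weight_ih)
    cell_b.weight_hh.copy_(cell_a.weight_hh)
    cell_a.bias_ih.copy_(bias)        # whole bias in one slot, other zeroed
    cell_a.bias_hh.zero_()
    cell_b.bias_ih.copy_(bias * 0.5)  # half of the bias in each slot
    cell_b.bias_hh.copy_(bias * 0.5)

x = torch.randn(2, in_dim)
state = (torch.randn(2, out_dim), torch.randn(2, out_dim))
h_a, c_a = cell_a(x, state)
h_b, c_b = cell_b(x, state)
assert torch.allclose(h_a, h_b, atol=1e-6)  # identical outputs either way
assert torch.allclose(c_a, c_b, atol=1e-6)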
