RF PT lstm, use PackedSequence
albertz committed Apr 18, 2023
1 parent 82745f4 · commit 259e223
Showing 2 changed files with 40 additions and 2 deletions.
returnn/tensor/_dim_extra.py (13 additions, 0 deletions)
@@ -1612,6 +1612,19 @@ def get_uniq_collection(cls, tags, is_equal_opts=None):
             res.append(tag)
         return res

+    def get_size_tensor(self) -> _t.Tensor:
+        """
+        :return: size tensor, or dyn_size_ext if defined
+        :rtype: _t.Tensor
+        """
+        if self.dyn_size_ext:
+            return self.dyn_size_ext
+
+        import returnn.frontend as rf
+
+        assert self.size is not None
+        return rf.convert_to_tensor(self.size, name="%s:size" % self.description)
+
     def get_dim_value(self) -> Union[int, _t.RawTensorType]:
         """
         Infers the dim this axis should have if unbroadcasted.
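
For context, a rough usage sketch of the new helper, not part of the commit; the Dim construction below and the resulting tensor name are assumptions about the RETURNN API:

# Hypothetical sketch, assuming the usual returnn.tensor.Dim API.
from returnn.tensor import Dim

win_dim = Dim(7, name="window")    # static dim: size == 7, dyn_size_ext is None
sizes = win_dim.get_size_tensor()  # scalar Tensor wrapping 7 (via rf.convert_to_tensor)

# For a dynamic dim such as a time axis, dyn_size_ext already holds the
# per-batch-entry sequence lengths, and get_size_tensor() returns it unchanged.
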
returnn/torch/frontend/_backend.py (27 additions, 2 deletions)
@@ -1191,18 +1191,43 @@ def lstm(
         state_h_raw = torch.reshape(state_h_raw, [1, batch_dim, out_dim.get_dim_value()])
         state_c_raw = torch.reshape(state_c_raw, [1, batch_dim, out_dim.get_dim_value()])

+        sizes = spatial_dim.get_size_tensor()
+        sizes = sizes.copy_compatible_to(
+            Tensor("batch_dims", batch_dims, dtype=sizes.dtype), unbroadcast=True, check_sparse=False
+        )
+        sizes_raw = torch.reshape(sizes.raw_tensor, [batch_dim])
+
+        # See the code of torch.nn.LSTM for sorting the batch dims.
+        # We need pack_padded_sequence because otherwise the LSTM would run over the padded frames
+        # as well, and we would get incorrect final states.
+        source_packed = torch.nn.utils.rnn.pack_padded_sequence(source_raw, sizes_raw, enforce_sorted=False)
+        state_h_raw = state_h_raw.index_select(dim=1, index=source_packed.sorted_indices)
+        state_c_raw = state_c_raw.index_select(dim=1, index=source_packed.sorted_indices)
+
         out_raw, new_state_h_raw, new_state_c_raw = torch.lstm(
-            source_raw,
+            source_packed.data,
+            source_packed.batch_sizes,
             (state_h_raw, state_c_raw),
             lstm_params,
             has_biases=has_biases,
             num_layers=1,
             dropout=0.0,
             train=rf.get_run_ctx().train_flag,
             bidirectional=False,
-            batch_first=False,
         )
+
+        # Unsort the batch dims.
+        new_state_h_raw = new_state_h_raw.index_select(dim=1, index=source_packed.unsorted_indices)
+        new_state_c_raw = new_state_c_raw.index_select(dim=1, index=source_packed.unsorted_indices)
+        # Unpack the sequence.
+        output_packed = torch.nn.utils.rnn.PackedSequence(
+            out_raw,
+            batch_sizes=source_packed.batch_sizes,
+            sorted_indices=source_packed.sorted_indices,
+            unsorted_indices=source_packed.unsorted_indices,
+        )
+        out_raw = torch.nn.utils.rnn.pad_packed_sequence(output_packed)[0]

         if len(batch_dims) != 1:
             out_raw = torch.reshape(
                 out_raw,
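
For reference, the pack, run, unpack pattern that this change implements can be sketched with the public torch.nn.LSTM API. This is a minimal standalone example; shapes and names are made up for illustration, and nn.LSTM is used here instead of the low-level torch.lstm that the backend calls:

import torch

lstm = torch.nn.LSTM(input_size=4, hidden_size=8)
source = torch.randn(10, 3, 4)    # [time, batch, feature], zero-padded
sizes = torch.tensor([10, 7, 5])  # real sequence lengths per batch entry

packed = torch.nn.utils.rnn.pack_padded_sequence(source, sizes, enforce_sorted=False)
out_packed, (h, c) = lstm(packed)  # nn.LSTM sorts/unsorts the batch internally
out, _ = torch.nn.utils.rnn.pad_packed_sequence(out_packed)

# With packing, (h, c) are the states after frames 10, 7 and 5 respectively;
# without it, they would all come from frame 10, i.e. from padding for the
# two shorter sequences. That is the bug this commit fixes.
assert out.shape == (10, 3, 8) and h.shape == (1, 3, 8)

The backend cannot use torch.nn.LSTM here, presumably because the RETURNN frontend manages its own flat parameter list (lstm_params), so it drives torch.lstm directly and therefore has to apply sorted_indices and unsorted_indices to the states itself, much as torch.nn.LSTM does internally.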
