Skip to content

Commit

Permalink
MaskedComputationLayer, fix get_rec_initial_extra_outputs call
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Mar 8, 2022
1 parent e90762a commit 79846ef
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions returnn/tf/layers/rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7981,9 +7981,11 @@ def get_rec_initial_extra_outputs(cls, batch_dim, rec_layer, **kwargs):
assert isinstance(output, Data)
assert issubclass(layer_class, LayerBase)
with layer_class.cls_setup_scope(**kwargs):
d = layer_class.get_rec_initial_extra_outputs(batch_dim=batch_dim, rec_layer=rec_layer, **layer_desc)
output = output.copy_as_batch_major()
d = layer_class.get_rec_initial_extra_outputs(
batch_dim=batch_dim, rec_layer=rec_layer, output=output, **layer_desc)
initial_out = layer_class.get_rec_initial_output(
batch_dim=batch_dim, rec_layer=rec_layer, output=output.copy_as_batch_major(), **layer_desc)
batch_dim=batch_dim, rec_layer=rec_layer, output=output, **layer_desc)
assert "_output" not in d
d["_output"] = initial_out
return d
Expand Down

0 comments on commit 79846ef

Please sign in to comment.