test_MaskedComputationLayer_in_loop_auto_unmask
albertz committed Mar 8, 2022
1 parent f3a77e2 commit 0874a8b
Showing 1 changed file with 61 additions and 0 deletions.
tests/test_TFNetworkRecLayer.py: 61 additions & 0 deletions
@@ -6211,6 +6211,67 @@ def test_MaskedComputationLayer_UnmaskLayer_in_loop_opt():
x = y


def test_MaskedComputationLayer_in_loop_auto_unmask():
  # https://github.com/rwth-i6/returnn/issues/769
  from test_TFNetworkLayer import make_feed_dict
  from returnn.tf.layers.rec import _SubnetworkRecCell
  for opt in [False, True]:
    print("*** using rec optimization:", opt)
    with make_scope() as session:
      config = Config({"debug_print_layer_output_template": True})
      net = TFNetwork(
        extern_data=ExternData({"data": {"dim": 20, "sparse": True}}),
        config=config)
      net_dict = {
        "output": {
          "class": "rec",
          "from": "data",
          "optimize_move_layers_out": opt,  # test both variants
          "unit": {
            "const1": {"class": "constant", "value": 1, "with_batch_dim": True},  # just to broadcast mask
            "mask": {
              "class": "eval", "from": [":i", "const1"], "out_type": {"dtype": "bool"},
              "eval": "tf.equal(source(0) % 2, source(1))"},
            "in": {"class": "reinterpret_data", "from": "data:source", "set_sparse": False},
            "masked": {
              "class": "masked_computation", "from": "in", "mask": "mask",
              "unit": {"class": "cumsum", "from": "data", "initial_output": 1}},
            "masked_out": {"class": "copy", "from": "masked"},
            "output": {"class": "eval", "from": ["masked_out", "in"], "eval": "source(0) + source(1) ** 2"},
          }
        }
      }
      net.construct_from_dict(net_dict)
      rec_layer = net.get_layer("output")
      assert isinstance(rec_layer, RecLayer)
      rec_cell = rec_layer.cell
      assert isinstance(rec_cell, _SubnetworkRecCell)
      if opt:
        assert not rec_cell.layers_in_loop  # all moved out
      else:
        assert not rec_cell.input_layers_moved_out and not rec_cell.output_layers_moved_out  # none moved out
      in_data = net.get_layer("data").output
      out_data = net.get_layer("output").output.copy_as_batch_major()
      print("out:", out_data)
      assert in_data.get_time_dim_tag() == out_data.get_time_dim_tag()
      in_v, out_v, out_seq_lens_v = session.run(
        (in_data.placeholder, out_data.placeholder, out_data.get_sequence_lengths()),
        feed_dict=make_feed_dict(net.extern_data))
      print(in_v)
      print(out_v)
      print("seq lens:", out_seq_lens_v)
      assert_equal(in_v.shape, out_v.shape)
      for b in range(in_v.shape[0]):
        x = 1
        for t in range(in_v.shape[1]):
          if t >= out_seq_lens_v[b]:
            continue
          if t % 2 == 1:
            x = x + in_v[b, t]
          y = x + in_v[b, t] ** 2
          numpy.testing.assert_almost_equal(y, out_v[b, t])
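
As a side note for readers of the diff (not part of the commit): the automatic unmasking exercised by this test means the cumsum sub-unit only advances on frames where the mask is true, and its most recent output (starting from initial_output=1) is what every frame sees. A minimal standalone numpy sketch of that behavior, with a hypothetical helper name and toy values:

import numpy

def masked_cumsum_unmasked(xs, mask, initial=1):
  """Cumsum over masked frames only; the latest masked value is held on the other frames."""
  out = numpy.zeros_like(xs)
  state = initial
  for t in range(len(xs)):
    if mask[t]:
      state = state + xs[t]  # the masked computation only runs on these frames
    out[t] = state  # auto-unmask: the most recent masked output is visible at every frame
  return out

print(masked_cumsum_unmasked(numpy.array([3, 5, 2, 7]),
                             numpy.array([False, True, False, True])))  # -> [ 1  6  6 13]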


def test_att_train_search_loss_prev_beam():
beam_size = 1
num_ner_labels = 13
