Skip to content

Commit

Permalink
test_MaskedComputationLayer_in_loop_auto_unmask
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Mar 4, 2022
1 parent 02e3119 commit 585604f
Showing 1 changed file with 66 additions and 0 deletions.
66 changes: 66 additions & 0 deletions tests/test_TFNetworkRecLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6200,6 +6200,72 @@ def test_MaskedComputationLayer_UnmaskLayer_in_loop_opt():
x = y


def test_MaskedComputationLayer_in_loop_auto_unmask():
  """
  Tests a :class:`MaskedComputationLayer` inside a rec loop with automatic unmasking,
  both with and without the rec-layer move-out optimization.
  See https://github.com/rwth-i6/returnn/issues/769.
  """
  from test_TFNetworkLayer import make_feed_dict
  from returnn.tf.layers.rec import _SubnetworkRecCell
  for optimized in (False, True):
    print("*** using rec optimization:", optimized)
    with make_scope() as session:
      config = Config({"debug_print_layer_output_template": True})
      network = TFNetwork(
        extern_data=ExternData({"data": {"dim": 20, "sparse": True}}),
        config=config)
      # The mask is True on odd frames (i % 2 == 1). The masked subnet accumulates
      # its input via cumsum; on unmasked frames, the last masked output is kept
      # (automatic unmasking).
      net_dict = {
        "output": {
          "class": "rec",
          "from": "data",
          "optimize_move_layers_out": optimized,  # test both variants
          "unit": {
            "const1": {"class": "constant", "value": 1, "with_batch_dim": True},  # just to broadcast mask
            "mask": {
              "class": "eval", "from": [":i", "const1"], "out_type": {"dtype": "bool"},
              "eval": "tf.equal(source(0) % 2, source(1))"},
            "in": {"class": "cast", "from": "data:source", "dtype": "float32"},
            "masked": {
              "class": "masked_computation", "from": "data:source", "mask": "mask",
              "unit": {
                "class": "subnetwork", "from": "in",
                "subnetwork": {
                  "input1": {"class": "expand_dims", "axis": "f", "from": "data"},
                  "output": {"class": "rec", "unit": "cumsum", "n_out": 1, "from": "input1"},
                },
              }},
            "masked_out": {"class": "squeeze", "from": "masked", "axis": "f"},
            "output": {"class": "eval", "from": ["masked_out", "in"], "eval": "source(0) + source(1) ** 2"},
          }
        }
      }
      network.construct_from_dict(net_dict)
      rec_layer = network.get_layer("output")
      assert isinstance(rec_layer, RecLayer)
      cell = rec_layer.cell
      assert isinstance(cell, _SubnetworkRecCell)
      # Check that the optimization flag had the expected effect on layer placement.
      if optimized:
        assert not cell.layers_in_loop  # all moved out
      else:
        assert not cell.input_layers_moved_out and not cell.output_layers_moved_out  # none moved out
      in_data = network.get_layer("data").output
      out_data = network.get_layer("output").output.copy_as_batch_major()
      assert in_data.get_time_dim_tag() == out_data.get_time_dim_tag()
      in_np, out_np, seq_lens_np = session.run(
        (in_data.placeholder, out_data.placeholder, out_data.get_sequence_lengths()),
        feed_dict=make_feed_dict(network.extern_data))
      print(in_np)
      print(out_np)
      print("seq lens:", seq_lens_np)
      assert_equal(in_np.shape, out_np.shape)
      num_batch, num_time = in_np.shape
      for b in range(num_batch):
        cumsum = 0.0  # reference state of the masked cumsum subnet
        for t in range(min(num_time, seq_lens_np[b])):
          if t % 2 == 1:  # mask active: accumulate
            cumsum += in_np[b, t]
          expected = cumsum + in_np[b, t] ** 2
          numpy.testing.assert_almost_equal(expected, out_np[b, t])


def test_att_train_search_loss_prev_beam():
beam_size = 1
num_ner_labels = 13
Expand Down

0 comments on commit 585604f

Please sign in to comment.