Skip to content

Commit

Permalink
Add merge tests
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoonlee committed Aug 5, 2017
1 parent 281bc58 commit ef28bbd
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/keras/layers/merge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ def test_merge_add():
assert out.shape == (2, 4, 5)
assert_allclose(out, x1 + x2 + x3, atol=1e-4)

assert add_layer.compute_mask([i1, i2, i3], [None, None, None]) is None
assert np.all(K.eval(add_layer.compute_mask(
[i1, i2, i3], [K.variable(x1), K.variable(x2), K.variable(x3)])))

# Test invalid use case
with pytest.raises(ValueError):
add_layer.compute_mask([i1, i2, i3], x1)
with pytest.raises(ValueError):
add_layer.compute_mask(i1, [None, None, None])
with pytest.raises(ValueError):
add_layer.compute_mask([i1, i2, i3], [None, None])


@keras_test
def test_merge_multiply():
Expand Down Expand Up @@ -91,6 +103,12 @@ def test_merge_maximum():

@keras_test
def test_merge_concatenate():
i1 = layers.Input(shape=(None, 5))
i2 = layers.Input(shape=(None, 5))
o = layers.concatenate([i1, i2], axis=1)
assert o._keras_shape == (None, None, 5)
model = models.Model([i1, i2], o)

i1 = layers.Input(shape=(4, 5))
i2 = layers.Input(shape=(4, 5))
o = layers.concatenate([i1, i2], axis=1)
Expand Down Expand Up @@ -121,6 +139,18 @@ def test_merge_concatenate():
assert concat_out.shape == (1, 16, 1)
assert_allclose(concat_out, x3)

assert concat_layer.compute_mask([i1, i2], [None, None]) is None
assert np.all(K.eval(concat_layer.compute_mask(
[i1, i2], [K.variable(x1), K.variable(x2)])).reshape(-1))

# Test invalid use case
with pytest.raises(ValueError):
concat_layer.compute_mask([i1, i2], x1)
with pytest.raises(ValueError):
concat_layer.compute_mask(i1, [None, None])
with pytest.raises(ValueError):
concat_layer.compute_mask([i1, i2], [None])


@keras_test
def test_merge_dot():
Expand Down

0 comments on commit ef28bbd

Please sign in to comment.