Skip to content

Commit

Permalink
Fix mask for multi output --> multi inputs (#7591)
Browse files Browse the repository at this point in the history
* Fix 7589

* Style
  • Loading branch information
Frédéric Branchaud-Charron authored and fchollet committed Aug 10, 2017
1 parent 552727e commit c2b844b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
4 changes: 4 additions & 0 deletions keras/engine/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,10 @@ def __call__(self, inputs, **kwargs):
else:
output_shape = None

if not isinstance(output_mask, (list, tuple)) and len(output_ls) > 1:
# Augment the mask to match the length of the output.
output_mask = [output_mask] * len(output_ls)

# Add an inbound node to the layer, so that it keeps track
# of the call and of all new variables created during the call.
# This also updates the layer history of the output tensor(s).
Expand Down
31 changes: 30 additions & 1 deletion tests/keras/engine/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from keras.layers import Dense, Dropout, InputLayer
from keras import layers
from keras.engine import Input, get_source_inputs
from keras.engine import Input, Layer, get_source_inputs
from keras.models import Model, Sequential
from keras import backend as K
from keras.models import model_from_json, model_from_yaml
Expand Down Expand Up @@ -626,5 +626,34 @@ def test_layer_sharing_at_heterogenous_depth():
np.testing.assert_allclose(output_val, output_val_2, atol=1e-6)


@keras_test
def test_multi_output_mask():
"""Fixes #7589"""
class ArbitraryMultiOutputLayer(Layer):
def __init__(self, **kwargs):
super(ArbitraryMultiOutputLayer, self).__init__(**kwargs)

def call(self, inputs, **kwargs):
return [K.abs(inputs), K.abs(inputs)]

def compute_output_shape(self, input_shape):
out_shape = super(ArbitraryMultiOutputLayer, self).compute_output_shape(input_shape)
return [out_shape, out_shape]

class ArbitraryMultiInputLayer(Layer):
def __init__(self, **kwargs):
super(ArbitraryMultiInputLayer, self).__init__(**kwargs)

def call(self, inputs, **kwargs):
negative, positive = inputs
return negative + positive

input_layer = Input(shape=(16, 16, 3))
x, y = ArbitraryMultiOutputLayer()(input_layer)
z = ArbitraryMultiInputLayer()([x, y])
_ = Model(inputs=input_layer, outputs=z)
assert K.int_shape(z)[1:] == (16, 16, 3)


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit c2b844b

Please sign in to comment.