Skip to content

Commit

Permalink
Fix TimeDistributed BatchNormalization (keras-team#7467)
Browse files Browse the repository at this point in the history
* Fix TimeDistributed BatchNormalization.

* Fix PEP8 indentation warning.

* Unit test for TimeDistributed(BatchNormalization)

* Change Wrapper.input_map to _input_map.

* Work around incorrect PEP8 W503 warning.

The CI test fails with W503 error, but actually
the testing tool is enforcing PEP8 wrong.
PyCQA/pycodestyle#498
  • Loading branch information
waleedka authored and fchollet committed Aug 4, 2017
1 parent 1cb81f2 commit 7c7d735
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
25 changes: 20 additions & 5 deletions keras/layers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
from ..engine import Layer
from ..engine import InputSpec
from ..engine.topology import _object_list_uid
from ..utils.generic_utils import has_arg
from .. import backend as K

Expand All @@ -21,6 +22,10 @@ class Wrapper(Layer):

def __init__(self, layer, **kwargs):
self.layer = layer
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
# the inner layer has update ops that depend on it's inputs (as opposed
# to the inputs to the Wrapper layer).
self._input_map = {}
super(Wrapper, self).__init__(**kwargs)

def build(self, input_shape=None):
Expand Down Expand Up @@ -48,10 +53,17 @@ def updates(self):
return []

def get_updates_for(self, inputs=None):
if inputs is None:
updates = self.layer.get_updates_for(None)
return updates + super(Wrapper, self).get_updates_for(None)
return super(Wrapper, self).get_updates_for(inputs)
# If the wrapper modifies the inputs, use the modified inputs to
# get the updates from the inner layer.
inner_inputs = inputs
if inputs is not None:
uid = _object_list_uid(inputs)
if uid in self._input_map:
inner_inputs = self._input_map[uid]

updates = self.layer.get_updates_for(inner_inputs)
updates += super(Wrapper, self).get_updates_for(inputs)
return updates

@property
def losses(self):
Expand Down Expand Up @@ -182,8 +194,11 @@ def step(x, _):
input_length = input_shape[1]
if not input_length:
input_length = K.shape(inputs)[1]
# Shape: (num_samples * timesteps, ...)
# Shape: (num_samples * timesteps, ...). And track the
# transformation in self._input_map.
input_uid = _object_list_uid(inputs)
inputs = K.reshape(inputs, (-1,) + input_shape[2:])
self._input_map[input_uid] = inputs
# (num_samples * timesteps, ...)
y = self.layer.call(inputs, **kwargs)
if hasattr(y, '_uses_learning_phase'):
Expand Down
24 changes: 23 additions & 1 deletion tests/keras/layers/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from numpy.testing import assert_allclose
from keras.utils.test_utils import keras_test
from keras.layers import wrappers, Input
from keras.layers import core, convolutional, recurrent, embeddings
from keras.layers import core, convolutional, recurrent, embeddings, normalization
from keras.models import Sequential, Model, model_from_json
from keras import backend as K
from keras.engine.topology import _object_list_uid


@keras_test
Expand Down Expand Up @@ -87,6 +88,27 @@ def test_TimeDistributed():
outer_model.compile(optimizer='rmsprop', loss='mse')
outer_model.fit(np.random.random((10, 3, 2)), np.random.random((10, 3, 3)), epochs=1, batch_size=10)

# test with BatchNormalization
model = Sequential()
model.add(wrappers.TimeDistributed(normalization.BatchNormalization(center=True, scale=True),
name='bn', input_shape=(10, 2)))
model.compile(optimizer='rmsprop', loss='mse')
# Assert that mean and variance are 0 and 1.
td = model.layers[0]
assert np.array_equal(td.get_weights()[2], np.array([0, 0]))
assert np.array_equal(td.get_weights()[3], np.array([1, 1]))
# Train
model.train_on_batch(np.random.normal(loc=2, scale=2, size=(1, 10, 2)),
np.broadcast_to(np.array([0, 1]), (1, 10, 2)))
# Assert that mean and variance changed.
assert not np.array_equal(td.get_weights()[2], np.array([0, 0]))
assert not np.array_equal(td.get_weights()[3], np.array([1, 1]))
# Verify input_map has one mapping from inputs to reshaped inputs.
uid = _object_list_uid(model.inputs)
assert len(td._input_map.keys()) == 1
assert uid in td._input_map
assert K.int_shape(td._input_map[uid]) == (None, 2)


@keras_test
@pytest.mark.skipif((K.backend() == 'cntk'),
Expand Down

0 comments on commit 7c7d735

Please sign in to comment.