From 7c7d73530c1ab1b47f9fb5f0612ec13fef1a26c6 Mon Sep 17 00:00:00 2001 From: Waleed Date: Thu, 3 Aug 2017 19:15:13 -0700 Subject: [PATCH] Fix TimeDistributed BatchNormalization (#7467) * Fix TimeDistributed BatchNormalization. * Fix PEP8 indentation warning. * Unit test for TimeDistributed(BatchNormalization) * Change Wrapper.input_map to _input_map. * Work around incorrect PEP8 W503 warning. The CI test fails with W503 error, but actually the testing tool is enforcing PEP8 incorrectly. https://github.com/PyCQA/pycodestyle/issues/498 --- keras/layers/wrappers.py | 25 ++++++++++++++++++++----- tests/keras/layers/wrappers_test.py | 24 +++++++++++++++++++++++- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/keras/layers/wrappers.py b/keras/layers/wrappers.py index 1f517c80db6..cf5b1876529 100644 --- a/keras/layers/wrappers.py +++ b/keras/layers/wrappers.py @@ -4,6 +4,7 @@ import copy from ..engine import Layer from ..engine import InputSpec +from ..engine.topology import _object_list_uid from ..utils.generic_utils import has_arg from .. import backend as K @@ -21,6 +22,10 @@ class Wrapper(Layer): def __init__(self, layer, **kwargs): self.layer = layer + # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when + # the inner layer has update ops that depend on its inputs (as opposed + # to the inputs to the Wrapper layer). + self._input_map = {} super(Wrapper, self).__init__(**kwargs) def build(self, input_shape=None): @@ -48,10 +53,17 @@ def updates(self): return [] def get_updates_for(self, inputs=None): - if inputs is None: - updates = self.layer.get_updates_for(None) - return updates + super(Wrapper, self).get_updates_for(None) - return super(Wrapper, self).get_updates_for(inputs) + # If the wrapper modifies the inputs, use the modified inputs to + # get the updates from the inner layer. 
+ inner_inputs = inputs + if inputs is not None: + uid = _object_list_uid(inputs) + if uid in self._input_map: + inner_inputs = self._input_map[uid] + + updates = self.layer.get_updates_for(inner_inputs) + updates += super(Wrapper, self).get_updates_for(inputs) + return updates @property def losses(self): @@ -182,8 +194,11 @@ def step(x, _): input_length = input_shape[1] if not input_length: input_length = K.shape(inputs)[1] - # Shape: (num_samples * timesteps, ...) + # Shape: (num_samples * timesteps, ...). And track the + # transformation in self._input_map. + input_uid = _object_list_uid(inputs) inputs = K.reshape(inputs, (-1,) + input_shape[2:]) + self._input_map[input_uid] = inputs # (num_samples * timesteps, ...) y = self.layer.call(inputs, **kwargs) if hasattr(y, '_uses_learning_phase'): diff --git a/tests/keras/layers/wrappers_test.py b/tests/keras/layers/wrappers_test.py index 1d410964d93..2e0cdd52024 100644 --- a/tests/keras/layers/wrappers_test.py +++ b/tests/keras/layers/wrappers_test.py @@ -3,9 +3,10 @@ from numpy.testing import assert_allclose from keras.utils.test_utils import keras_test from keras.layers import wrappers, Input -from keras.layers import core, convolutional, recurrent, embeddings +from keras.layers import core, convolutional, recurrent, embeddings, normalization from keras.models import Sequential, Model, model_from_json from keras import backend as K +from keras.engine.topology import _object_list_uid @keras_test @@ -87,6 +88,27 @@ def test_TimeDistributed(): outer_model.compile(optimizer='rmsprop', loss='mse') outer_model.fit(np.random.random((10, 3, 2)), np.random.random((10, 3, 3)), epochs=1, batch_size=10) + # test with BatchNormalization + model = Sequential() + model.add(wrappers.TimeDistributed(normalization.BatchNormalization(center=True, scale=True), + name='bn', input_shape=(10, 2))) + model.compile(optimizer='rmsprop', loss='mse') + # Assert that mean and variance are 0 and 1. 
+ td = model.layers[0] + assert np.array_equal(td.get_weights()[2], np.array([0, 0])) + assert np.array_equal(td.get_weights()[3], np.array([1, 1])) + # Train + model.train_on_batch(np.random.normal(loc=2, scale=2, size=(1, 10, 2)), + np.broadcast_to(np.array([0, 1]), (1, 10, 2))) + # Assert that mean and variance changed. + assert not np.array_equal(td.get_weights()[2], np.array([0, 0])) + assert not np.array_equal(td.get_weights()[3], np.array([1, 1])) + # Verify _input_map has one mapping from inputs to reshaped inputs. + uid = _object_list_uid(model.inputs) + assert len(td._input_map.keys()) == 1 + assert uid in td._input_map + assert K.int_shape(td._input_map[uid]) == (None, 2) + @keras_test @pytest.mark.skipif((K.backend() == 'cntk'),