diff --git a/keras/layers/rnn/lstm.py b/keras/layers/rnn/lstm.py index 73ed18aac8d..761d222d6dd 100644 --- a/keras/layers/rnn/lstm.py +++ b/keras/layers/rnn/lstm.py @@ -233,9 +233,9 @@ def call(self, inputs, states, training=False): rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1) if training and 0.0 < self.dropout < 1.0: - inputs *= dp_mask + inputs = inputs * dp_mask if training and 0.0 < self.recurrent_dropout < 1.0: - h_tm1 *= rec_dp_mask + h_tm1 = h_tm1 * rec_dp_mask if self.implementation == 1: inputs_i = inputs