From 3537381fa9487559b47f4cf3711265b2d5b049db Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Fri, 11 Aug 2017 10:16:15 -0700 Subject: [PATCH] cntk backend: fix the reversed rnn bug (#7593) * fix the reversed rnn bug * udpate error message. * Fix error msg --- keras/backend/cntk_backend.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py index 6116e373f23..02d927ce585 100644 --- a/keras/backend/cntk_backend.py +++ b/keras/backend/cntk_backend.py @@ -1310,7 +1310,15 @@ def rnn(step_function, inputs, initial_states, initial.append(s) need_convert = not has_seq_axis(inputs) + if go_backwards and need_convert is False: + raise NotImplementedError('CNTK Backend: `go_backwards` is not supported with ' + 'variable-length sequences. Please specify a ' + 'static length for your sequences.') + if need_convert: + if go_backwards: + inputs = reverse(inputs, 1) + inputs = C.to_sequence(inputs) j = 0 @@ -1327,6 +1335,8 @@ def rnn(step_function, inputs, initial_states, j += 1 if mask is not None and not has_seq_axis(mask): + if go_backwards: + mask = reverse(mask, 1) if len(int_shape(mask)) == 2: mask = expand_dims(mask) mask = C.to_sequence_like(mask, inputs) @@ -1339,10 +1349,7 @@ def _recurrence(x, states, m): place_holders = [C.placeholder(dynamic_axes=x.dynamic_axes) for _ in states] past_values = [] for s, p in zip(states, place_holders): - past_values.append( - C.sequence.past_value( - p, s) if go_backwards is False else C.sequence.future_value( - p, s)) + past_values.append(C.sequence.past_value(p, s)) new_output, new_states = step_function( x, tuple(past_values) + tuple(constants)) if m is not None: