diff --git a/data_iterator.py b/data_iterator.py index d72edc9..18bb691 100644 --- a/data_iterator.py +++ b/data_iterator.py @@ -53,7 +53,7 @@ def add_random_variables_to_batch(state, rng, batch, prev_batch, evaluate_mode): if not eos_indices[-1] == batch['x'].shape[0]: eos_indices = eos_indices + [batch['x'].shape[0]] else: - eos_indices = [0] + [batch['x'].shape[0]-1] + eos_indices = [0] + [batch['x'].shape[0]] # Sample random variables using NumPy ran_vectors = rng.normal(loc=0, scale=1, size=(len(eos_indices), state['latent_gaussian_per_utterance_dim'])) diff --git a/evaluate.py b/evaluate.py index 7d79101..f5ac2ba 100644 --- a/evaluate.py +++ b/evaluate.py @@ -140,7 +140,7 @@ def main(): batch['num_preds'] = numpy.sum(x_cost_mask) - c, c_list, _, _ = eval_batch(x_data, x_data_reversed, max_length, x_cost_mask, reset_mask, ran_cost_utterance, ran_decoder_drop_mask) + c, _, c_list, _, _ = eval_batch(x_data, x_data_reversed, max_length, x_cost_mask, reset_mask, ran_cost_utterance, ran_decoder_drop_mask) c_list = c_list.reshape((batch['x'].shape[1],max_length-1), order=(1,0)) c_list = numpy.sum(c_list, axis=1)