Skip to content

Commit

Permalink
Fixed minor bugs.
Browse files Browse the repository at this point in the history
  • Loading branch information
julianser committed Feb 28, 2016
1 parent 30639b8 commit e66f9e8
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion data_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def add_random_variables_to_batch(state, rng, batch, prev_batch, evaluate_mode):
if not eos_indices[-1] == batch['x'].shape[0]:
eos_indices = eos_indices + [batch['x'].shape[0]]
else:
eos_indices = [0] + [batch['x'].shape[0]-1]
eos_indices = [0] + [batch['x'].shape[0]]

# Sample random variables using NumPy
ran_vectors = rng.normal(loc=0, scale=1, size=(len(eos_indices), state['latent_gaussian_per_utterance_dim']))
Expand Down
2 changes: 1 addition & 1 deletion evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def main():

batch['num_preds'] = numpy.sum(x_cost_mask)

c, c_list, _, _ = eval_batch(x_data, x_data_reversed, max_length, x_cost_mask, reset_mask, ran_cost_utterance, ran_decoder_drop_mask)
c, _, c_list, _, _ = eval_batch(x_data, x_data_reversed, max_length, x_cost_mask, reset_mask, ran_cost_utterance, ran_decoder_drop_mask)

c_list = c_list.reshape((batch['x'].shape[1],max_length-1), order=(1,0))
c_list = numpy.sum(c_list, axis=1)
Expand Down

0 comments on commit e66f9e8

Please sign in to comment.