Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attention visualization #121

Merged
merged 10 commits into from
Nov 8, 2016
5 changes: 4 additions & 1 deletion neuralmonkey/decoders/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def __init__(self, encoders, vocabulary, data_id, **kwargs):

self.runtime_rnn_outputs, _ = attention_decoder(
runtime_inputs, state, attention_objects, cell,
attention_maxout_size, loop_function=loop_function)
attention_maxout_size, loop_function=loop_function,
summary_collections=["summary_val_img"])

self.summary_val_img = tf.merge_summary(tf.get_collection("summary_val_img"))

_, train_logits = self._decode(self.train_rnn_outputs)
self.decoded, runtime_logits = self._decode(self.runtime_rnn_outputs)
Expand Down
9 changes: 8 additions & 1 deletion neuralmonkey/decoding_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# pylint: disable=too-many-arguments
# Great functions require great number of parameters
def attention_decoder(decoder_inputs, initial_state, attention_objects,
cell, maxout_size, loop_function=None, scope=None):
cell, maxout_size, loop_function=None, scope=None,
summary_collections=None):
outputs = []
states = []

Expand Down Expand Up @@ -47,6 +48,12 @@ def attention_decoder(decoder_inputs, initial_state, attention_objects,
outputs.append(output)
states.append(state)

if summary_collections:
for i, a in enumerate(attention_objects):
alignments = tf.expand_dims(tf.transpose(tf.pack(a.attentions_in_time), perm=[1, 2, 0]), -1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is too long. We check the style of the files using pylint. If you run tests/lint_run.sh, it will check all Python files and give you a detailed report.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this is general enough. If you look, e.g., at the Recurrent Neural Machine Translation paper, they use a GRU net to do the attention instead of the softmax weighting and come up with a clever way of visualizing this attention. I think that reserving the attentions_in_time field for visualization purposes should work even for the recurrent attention, but please check that this is the case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just FYI, the line length limit I use is 80. I will add it to the documentation for developers.

tf.image_summary("attention_{}".format(i), alignments,
collections=summary_collections, max_images=128)

return outputs, states


Expand Down
21 changes: 16 additions & 5 deletions neuralmonkey/learning_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,10 @@ def training_loop(sess, saver,

if step % validation_period == validation_period - 1:
decoded_val_sentences, decoded_raw_val_sentences, \
val_evaluation = run_on_dataset(
val_evaluation, val_images = run_on_dataset(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considering that the inputs of NM can also be images (we did image captioning with it previously), the name images may be confusing. What about plots, visualizations, ...? I don't know — maybe even images is OK.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about out_images?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd stick with val_plots.

sess, runner, all_coders, decoder, val_dataset,
evaluators, postprocess, write_out=False)
evaluators, postprocess, write_out=False,
extra_fetches=decoder.summary_val_img)

this_score = val_evaluation[evaluators[-1].name]

Expand Down Expand Up @@ -309,6 +310,8 @@ def argworst(scores, minimize):

log_print("")

tb_writer.add_summary(val_images[0], step)

except KeyboardInterrupt:
log("Training interrupted by user.")

Expand All @@ -329,7 +332,8 @@ def argworst(scores, minimize):


def run_on_dataset(sess, runner, all_coders, decoder, dataset,
evaluators, postprocess, write_out=False):
evaluators, postprocess, write_out=False,
extra_fetches=None):
"""
Applies the model on a dataset and optionally writes outpus into a file.

Expand All @@ -353,13 +357,17 @@ def run_on_dataset(sess, runner, all_coders, decoder, dataset,
write_out: Flag whether the outputs should be printed to a file defined
in the dataset object.

extra_fetches: Extra tensors to evaluate for each batch.

Returns:

Tuple of resulting sentences/numpy arrays, and evaluation results if
they are available which are dictionary function -> value.

"""
result_raw, opt_loss, dec_loss = runner(sess, dataset, all_coders)
result_raw, opt_loss, dec_loss, evaluated_fetches = \
runner(sess, dataset, all_coders, extra_fetches)

if postprocess is not None:
result = postprocess(result_raw)
else:
Expand Down Expand Up @@ -387,7 +395,10 @@ def run_on_dataset(sess, runner, all_coders, decoder, dataset,
for func in evaluators:
evaluation[func.name] = func(result, test_targets)

return result, result_raw, evaluation
if extra_fetches is not None:
return result, result_raw, evaluation, evaluated_fetches
else:
return result, result_raw, evaluation


def process_evaluation(evaluators, tb_writer, eval_result,
Expand Down
20 changes: 13 additions & 7 deletions neuralmonkey/runners/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ def __init__(self, decoder, batch_size):
self.batch_size = batch_size
self.vocabulary = decoder.vocabulary

def __call__(self, sess, dataset, coders):
def __call__(self, sess, dataset, coders, extra_fetches=None):
batched_dataset = dataset.batch_dataset(self.batch_size)
decoded_sentences = []
evaluated_fetches = []

if extra_fetches is None:
extra_fetches = []

loss_with_gt_ins = 0.0
loss_with_decoded_ins = 0.0
Expand All @@ -29,14 +33,16 @@ def __call__(self, sess, dataset, coders):
else:
losses = [tf.zeros([]), tf.zeros([])]

computation = sess.run(losses + self.decoder.decoded,
feed_dict=batch_feed_dict)
loss_with_gt_ins += computation[0]
loss_with_decoded_ins += computation[1]
(loss_with_gt_ins, loss_with_decoded_ins), \
decoded, fetches_batch = \
sess.run((losses, self.decoder.decoded, extra_fetches),
feed_dict=batch_feed_dict)
decoded_sentences_batch = \
self.vocabulary.vectors_to_sentences(computation[len(losses):])
self.vocabulary.vectors_to_sentences(decoded)
decoded_sentences += decoded_sentences_batch
evaluated_fetches += [fetches_batch]

return decoded_sentences, \
loss_with_gt_ins / batch_count, \
loss_with_decoded_ins / batch_count
loss_with_decoded_ins / batch_count, \
evaluated_fetches
2 changes: 1 addition & 1 deletion tests/small.ini
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ name=translation
output=tests/tmp-test-output
overwrite_output_dir=True
batch_size=16
epochs=2
epochs=5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 5?

encoders=[<encoder>]
decoder=<decoder>
train_dataset=<train_data>
Expand Down