Attention visualization #121

Merged · 10 commits · Nov 8, 2016
neuralmonkey/decoders/decoder.py (8 additions, 1 deletion)
@@ -113,7 +113,14 @@ def __init__(self, encoders, vocabulary, data_id, **kwargs):

         self.runtime_rnn_outputs, _ = attention_decoder(
             runtime_inputs, state, attention_objects, cell,
-            attention_maxout_size, loop_function=loop_function)
+            attention_maxout_size, loop_function=loop_function,
+            summary_collections=["summary_val_plots"])
+
+        val_plots_collection = tf.get_collection("summary_val_plots")
+        self.summary_val_plots = (
+            tf.merge_summary(val_plots_collection)
+            if val_plots_collection else None
+        )
 
         _, train_logits = self._decode(self.train_rnn_outputs)
         self.decoded, runtime_logits = self._decode(self.runtime_rnn_outputs)
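For context: the hunk above registers image summaries under a named collection inside attention_decoder, then merges whatever landed there into a single fetchable op. A minimal sketch of that collect-and-merge flow, assuming the TF 0.x summary API this PR targets (the placeholder and tag names are illustrative, not part of the PR):

    import tensorflow as tf

    # Any op can register a summary under an arbitrary collection key.
    # A placeholder stands in here for the real alignment tensor.
    alignments = tf.placeholder(tf.float32, [None, 10, 20, 1], name="alignments")
    tf.image_summary("attention_demo", alignments,
                     collections=["summary_val_plots"], max_images=256)

    # Later, merge everything registered under that key into one summary op,
    # guarding against an empty collection exactly as the hunk does.
    val_plots_collection = tf.get_collection("summary_val_plots")
    summary_val_plots = (tf.merge_summary(val_plots_collection)
                         if val_plots_collection else None)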
neuralmonkey/decoding_function.py (13 additions, 2 deletions)
@@ -11,10 +11,11 @@
 from neuralmonkey.logging import debug
 from neuralmonkey.nn.projection import maxout, linear
 
-# pylint: disable=too-many-arguments
+# pylint: disable=too-many-arguments,too-many-locals
 # Great functions require a great number of parameters
 def attention_decoder(decoder_inputs, initial_state, attention_objects,
-                      cell, maxout_size, loop_function=None, scope=None):
+                      cell, maxout_size, loop_function=None, scope=None,
+                      summary_collections=None):
     outputs = []
     states = []

@@ -47,6 +48,16 @@ def attention_decoder(decoder_inputs, initial_state, attention_objects,
         outputs.append(output)
         states.append(state)
 
+    if summary_collections:
+        for i, a in enumerate(attention_objects):
+            attentions = a.attentions_in_time[-len(decoder_inputs):]
+            alignments = tf.expand_dims(tf.transpose(
+                tf.pack(attentions), perm=[1, 2, 0]), -1)
+
+            tf.image_summary("attention_{}".format(i), alignments,
+                             collections=summary_collections,
+                             max_images=256)
+
     return outputs, states
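The pack/transpose/expand_dims chain above turns a list of per-step attention vectors into one grayscale image per batch element. A shape walk-through with numpy stand-ins (toy sizes, not part of the PR; np.stack plays the role of tf.pack):

    import numpy as np

    batch, enc_len, dec_steps = 2, 5, 3

    # attentions_in_time holds one [batch, enc_len] distribution per decoder step
    attentions = [np.zeros((batch, enc_len)) for _ in range(dec_steps)]

    packed = np.stack(attentions)             # tf.pack        -> (3, 2, 5)
    transposed = packed.transpose(1, 2, 0)    # tf.transpose   -> (2, 5, 3)
    alignments = transposed[..., np.newaxis]  # tf.expand_dims -> (2, 5, 3, 1)

    # One single-channel image per batch element: rows index encoder positions,
    # columns index decoder steps -- the 4-D layout tf.image_summary expects.
    assert alignments.shape == (batch, enc_len, dec_steps, 1)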
neuralmonkey/learning_utils.py (16 additions, 5 deletions)
@@ -223,9 +223,10 @@ def training_loop(sess, saver,

             if step % validation_period == validation_period - 1:
                 decoded_val_sentences, decoded_raw_val_sentences, \
-                    val_evaluation = run_on_dataset(
+                    val_evaluation, val_plots = run_on_dataset(
                         sess, runner, all_coders, decoder, val_dataset,
-                        evaluators, postprocess, write_out=False)
+                        evaluators, postprocess, write_out=False,
+                        extra_fetches=decoder.summary_val_plots)
 
                 this_score = val_evaluation[evaluators[-1].name]
 
@@ -309,6 +310,8 @@ def argworst(scores, minimize):

log_print("")

tb_writer.add_summary(val_plots[0], step)

except KeyboardInterrupt:
log("Training interrupted by user.")

Expand All @@ -329,7 +332,8 @@ def argworst(scores, minimize):


 def run_on_dataset(sess, runner, all_coders, decoder, dataset,
-                   evaluators, postprocess, write_out=False):
+                   evaluators, postprocess, write_out=False,
+                   extra_fetches=None):
"""
Applies the model on a dataset and optionally writes outpus into a file.

@@ -353,13 +357,17 @@ def run_on_dataset(sess, runner, all_coders, decoder, dataset,
         write_out: Flag whether the outputs should be printed to a file
             defined in the dataset object.
 
+        extra_fetches: Extra tensors to evaluate for each batch.
+
     Returns:
 
         Tuple of resulting sentences/numpy arrays, and evaluation results,
         if available, as a dictionary mapping evaluator name to value.
 
     """
-    result_raw, opt_loss, dec_loss = runner(sess, dataset, all_coders)
+    result_raw, opt_loss, dec_loss, evaluated_fetches = \
+        runner(sess, dataset, all_coders, extra_fetches)

     if postprocess is not None:
         result = postprocess(result_raw)
     else:
@@ -387,7 +395,10 @@ def run_on_dataset(sess, runner, all_coders, decoder, dataset,
     for func in evaluators:
         evaluation[func.name] = func(result, test_targets)
 
-    return result, result_raw, evaluation
+    if extra_fetches is not None:
+        return result, result_raw, evaluation, evaluated_fetches
+    else:
+        return result, result_raw, evaluation
 
 
 def process_evaluation(evaluators, tb_writer, eval_result,
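Taken together, the validation branch of training_loop now receives a fourth return value whenever extra_fetches is supplied, while existing three-element callers keep working. A hedged usage sketch of the new contract (variable names mirror the hunks above; the surrounding session and dataset setup is assumed):

    # With extra_fetches set, run_on_dataset returns a fourth element:
    # one evaluated fetch per validation batch.
    sentences, raw_sentences, evaluation, val_plots = run_on_dataset(
        sess, runner, all_coders, decoder, val_dataset,
        evaluators, postprocess, write_out=False,
        extra_fetches=decoder.summary_val_plots)

    # Only the first batch's serialized summary is written to TensorBoard.
    tb_writer.add_summary(val_plots[0], step)

    # Without extra_fetches, the old three-element contract still holds.
    sentences, raw_sentences, evaluation = run_on_dataset(
        sess, runner, all_coders, decoder, val_dataset,
        evaluators, postprocess)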
neuralmonkey/runners/runner.py (13 additions, 7 deletions)
@@ -10,9 +10,13 @@ def __init__(self, decoder, batch_size):
         self.batch_size = batch_size
         self.vocabulary = decoder.vocabulary
 
-    def __call__(self, sess, dataset, coders):
+    def __call__(self, sess, dataset, coders, extra_fetches=None):
         batched_dataset = dataset.batch_dataset(self.batch_size)
         decoded_sentences = []
+        evaluated_fetches = []
+
+        if extra_fetches is None:
+            extra_fetches = []
 
         loss_with_gt_ins = 0.0
         loss_with_decoded_ins = 0.0
@@ -29,14 +33,16 @@ def __call__(self, sess, dataset, coders):
             else:
                 losses = [tf.zeros([]), tf.zeros([])]
 
-            computation = sess.run(losses + self.decoder.decoded,
-                                   feed_dict=batch_feed_dict)
-            loss_with_gt_ins += computation[0]
-            loss_with_decoded_ins += computation[1]
+            (loss_with_gt_ins, loss_with_decoded_ins), \
+                decoded, fetches_batch = \
+                sess.run((losses, self.decoder.decoded, extra_fetches),
+                         feed_dict=batch_feed_dict)
             decoded_sentences_batch = \
-                self.vocabulary.vectors_to_sentences(computation[len(losses):])
+                self.vocabulary.vectors_to_sentences(decoded)
             decoded_sentences += decoded_sentences_batch
+            evaluated_fetches += [fetches_batch]
 
         return decoded_sentences, \
             loss_with_gt_ins / batch_count, \
-            loss_with_decoded_ins / batch_count
+            loss_with_decoded_ins / batch_count, \
+            evaluated_fetches
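The rewritten sess.run call leans on TensorFlow's structured fetches: passing a nested tuple of fetches returns results in the same nesting, so the losses, decoded outputs, and extra fetches unpack in one statement. A self-contained sketch of that behavior, assuming a TF 0.x version that accepts nested fetch structures, as this PR's code does (the constants are toy stand-ins):

    import tensorflow as tf

    losses = [tf.constant(0.1), tf.constant(0.2)]
    decoded = [tf.constant([1, 2, 3])]
    extra_fetches = [tf.constant(42)]

    with tf.Session() as sess:
        # sess.run mirrors the nesting of its fetch argument.
        (loss_gt, loss_dec), decoded_out, fetches_batch = sess.run(
            (losses, decoded, extra_fetches))
        print(loss_gt, loss_dec, decoded_out, fetches_batch)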