Skip to content

Commit

Permalink
Enable and refactor image summaries (resolves #108)
Browse files Browse the repository at this point in the history
  • Loading branch information
cifkao authored and jlibovicky committed Dec 1, 2016
1 parent 1ea4b9c commit 2e99b30
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
17 changes: 4 additions & 13 deletions neuralmonkey/decoders/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,7 @@ def __init__(self, encoders, vocabulary, data_id, name, **kwargs):

self.runtime_rnn_outputs, self.runtime_rnn_states, runtime_logits = \
self._attention_decoder(
runtime_inputs, state, runtime_mode=True,
summary_collections=["summary_val_plots"])

val_plots_collection = tf.get_collection("summary_val_plots")
self.summary_val_plots = (
tf.merge_summary(val_plots_collection)
if val_plots_collection else None
)
runtime_inputs, state, runtime_mode=True)

self.train_logprobs = [tf.nn.log_softmax(l) for l in train_logits]
self.decoded = [tf.argmax(l[:, 1:], 1) + 1 for l in runtime_logits]
Expand Down Expand Up @@ -318,7 +311,7 @@ def _logit_function(self, rnn_output):
#pylint: disable=too-many-arguments
# TODO reduce the number of arguments
def _attention_decoder(self, inputs, initial_state, runtime_mode=False,
summary_collections=None, scope="attention_decoder"):
scope="attention_decoder"):
"""Run the decoder RNN.
Arguments:
Expand All @@ -327,8 +320,6 @@ def _attention_decoder(self, inputs, initial_state, runtime_mode=False,
initial_state: The initial state of the decoder.
runtime_mode: Boolean flag whether the decoder is running in
runtime mode (with loop function).
summary_collections: The list of summary collections to which
the alignments are logged.
scope: The variable scope to use with this function.
"""
cell = self._get_rnn_cell()
Expand Down Expand Up @@ -371,14 +362,14 @@ def _attention_decoder(self, inputs, initial_state, runtime_mode=False,
rnn_outputs.append(output)
rnn_states.append(state)

if summary_collections:
if runtime_mode:
for i, a in enumerate(att_objects):
attentions = a.attentions_in_time[-len(inputs):]
alignments = tf.expand_dims(tf.transpose(
tf.pack(attentions), perm=[1, 2, 0]), -1)

tf.image_summary("attention_{}".format(i), alignments,
collections=summary_collections,
collections=["summary_val_plots"],
max_images=256)

return rnn_outputs, rnn_states, output_logits
Expand Down
12 changes: 11 additions & 1 deletion neuralmonkey/runners/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def __init__(self,
super(GreedyRunner, self).__init__(output_series, decoder)
self._postprocess = postprocess

self.image_summaries = tf.merge_summary(
tf.get_collection("summary_val_plots"))

def get_executable(self, train=False, summaries=True):
if train:
fecthes = {"train_xent": self._decoder.train_loss,
Expand All @@ -26,6 +29,10 @@ def get_executable(self, train=False, summaries=True):
fecthes = {"train_xent": tf.zeros([]),
"runtime_xent": tf.zeros([])}
fecthes["decoded_logprobs"] = self._decoder.runtime_logprobs

if summaries:
fecthes['image_summaries'] = self.image_summaries

return GreedyRunExecutable(self.all_coders, fecthes,
self._decoder.vocabulary,
self._postprocess)
Expand Down Expand Up @@ -69,10 +76,13 @@ def collect_results(self, results: List[Dict]) -> None:
if self._postprocess is not None:
decoded_tokens = [self._postprocess(seq) for seq in decoded_tokens]


image_summaries = results[0].get('image_summaries')

self.result = ExecutionResult(
outputs=decoded_tokens,
losses=[train_loss, runtime_loss],
scalar_summaries=None,
histogram_summaries=None,
image_summaries=None
image_summaries=image_summaries
)

0 comments on commit 2e99b30

Please sign in to comment.