Attention visualization #121
Changes from 6 commits: cf343f2, 39c822a, 90e77ce, 37468af, 0ab1be7, 90c2d9a, d0fcbd2, 9e1f9b2, 06e6f7b, 1c2ecbe
@@ -223,9 +223,10 @@ def training_loop(sess, saver,
             if step % validation_period == validation_period - 1:
                 decoded_val_sentences, decoded_raw_val_sentences, \
-                    val_evaluation = run_on_dataset(
+                    val_evaluation, val_images = run_on_dataset(
                         sess, runner, all_coders, decoder, val_dataset,
-                        evaluators, postprocess, write_out=False)
+                        evaluators, postprocess, write_out=False,
+                        extra_fetches=decoder.summary_val_img)

                 this_score = val_evaluation[evaluators[-1].name]

Review comments on the name val_images:
- "Considering that the inputs of NM can also be images (we did image captioning with it previously), name …"
- "What about …"
- "I'd stick with …"
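The validation trigger in the hunk above fires once every validation_period steps, on steps validation_period - 1, 2 * validation_period - 1, and so on. A standalone sketch of just that arithmetic, with an illustrative period of 500 (the period value is an assumption, not taken from this PR):

```python
def is_validation_step(step, validation_period):
    """Mirrors the condition used in training_loop: true on steps
    validation_period - 1, 2 * validation_period - 1, and so on."""
    return step % validation_period == validation_period - 1

# With validation_period = 500, validation runs on steps 499, 999, 1499, 1999.
fired = [step for step in range(2000) if is_validation_step(step, 500)]
```

This phrasing runs validation at the end of each block of validation_period steps rather than at step 0, so no validation happens before any training has occurred.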
@@ -309,6 +310,8 @@ def argworst(scores, minimize):
             log_print("")

+            tb_writer.add_summary(val_images[0], step)
+
     except KeyboardInterrupt:
         log("Training interrupted by user.")
@@ -329,7 +332,8 @@ def argworst(scores, minimize):
 def run_on_dataset(sess, runner, all_coders, decoder, dataset,
-                   evaluators, postprocess, write_out=False):
+                   evaluators, postprocess, write_out=False,
+                   extra_fetches=None):
     """
     Applies the model on a dataset and optionally writes outpus into a file.
@@ -353,13 +357,17 @@ def run_on_dataset(sess, runner, all_coders, decoder, dataset,
         write_out: Flag whether the outputs should be printed to a file defined
             in the dataset object.

+        extra_fetches: Extra tensors to evaluate for each batch.
+
     Returns:

         Tuple of resulting sentences/numpy arrays, and evaluation results if
         they are available which are dictionary function -> value.

     """
-    result_raw, opt_loss, dec_loss = runner(sess, dataset, all_coders)
+    result_raw, opt_loss, dec_loss, evaluated_fetches = \
+        runner(sess, dataset, all_coders, extra_fetches)

     if postprocess is not None:
         result = postprocess(result_raw)
     else:
@@ -387,7 +395,10 @@ def run_on_dataset(sess, runner, all_coders, decoder, dataset,
         for func in evaluators:
             evaluation[func.name] = func(result, test_targets)

-    return result, result_raw, evaluation
+    if extra_fetches is not None:
+        return result, result_raw, evaluation, evaluated_fetches
+    else:
+        return result, result_raw, evaluation


 def process_evaluation(evaluators, tb_writer, eval_result,
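The changed return statement gives run_on_dataset two possible arities: a 3-tuple normally, a 4-tuple when extra_fetches is passed. A minimal sketch of that calling convention, with stub values standing in for the real session, runner, and dataset (all names and values below are illustrative, not Neural Monkey's API):

```python
def run_on_dataset_sketch(result, result_raw, evaluation,
                          extra_fetches=None, evaluated_fetches=None):
    """Stub reproducing only the return logic added in this PR:
    a fourth element appears iff extra_fetches was supplied."""
    if extra_fetches is not None:
        return result, result_raw, evaluation, evaluated_fetches
    else:
        return result, result_raw, evaluation

# Without extra_fetches the caller unpacks three values;
# with extra_fetches it must unpack four.
plain = run_on_dataset_sketch(["hi"], ["hi"], {"bleu": 0.3})
with_imgs = run_on_dataset_sketch(["hi"], ["hi"], {"bleu": 0.3},
                                  extra_fetches=["img_summary_tensor"],
                                  evaluated_fetches=[b"png-bytes"])
```

One consequence of this design is that every call site must know whether it passed extra_fetches in order to unpack the right number of values; always returning a 4-tuple (with None in the last slot) would avoid that coupling.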
@@ -29,7 +29,7 @@ name=translation
 output=tests/tmp-test-output
 overwrite_output_dir=True
 batch_size=16
-epochs=2
+epochs=5
 encoders=[<encoder>]
 decoder=<decoder>
 train_dataset=<train_data>

Review comment on the epochs change:
- "Why 5?"
Review comments:
- "This line is too long. We check the style of the files using pylint. If you run tests/lint_run.sh, it will check all Python files and get you a detailed report."
- "I wonder if this is general enough. If you look, e.g., at the Recurrent Neural Machine Translation paper, they use a GRU net to do the attention instead of the softmax weighting and come up with a clever way of visualizing this attention. I think that reserving the attention_in_time field for the visualization purpose should work even for the recurrent attention, but please check if it is so."
- "Just FYI, the line length used by me is 80. I will add it to the documentation for developers."
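For reference, the "softmax weighting" that the comment above contrasts with recurrent (GRU-based) attention produces, for each decoder step, a probability distribution over encoder states; stacking those rows gives a (decoder steps x encoder steps) matrix of the kind a field like attention_in_time could hold for visualization. A minimal sketch, assuming unnormalized alignment scores are already computed (the function names here are illustrative, not Neural Monkey's actual API):

```python
import math

def softmax(scores):
    """Numerically stable softmax over one decoder step's alignment scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(score_matrix):
    """Build the (decoder steps x encoder steps) weight matrix that an
    attention visualization would plot: one distribution per decoder step."""
    return [softmax(row) for row in score_matrix]

# Illustrative scores: 2 decoder steps attending over 3 encoder states.
# Row sums are 1; equal scores (second row) yield a uniform distribution.
weights = attention_weights([[1.0, 2.0, 3.0],
                             [0.0, 0.0, 0.0]])
```

A recurrent attention mechanism would replace the per-step softmax with a learned network over the score sequence, but as the comment notes, the resulting per-step weight rows could still be stored in the same matrix shape for visualization.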