-
Notifications
You must be signed in to change notification settings - Fork 103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Separate training and runtime attention #174
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
35ddf28
Separate training and runtime attention
cifkao d93d9fe
Revert "Separate training and runtime attention"
cifkao 0ae8ce6
Separate attention objects for training and runtime
cifkao fe66c3d
fix after rebase
jlibovicky 1591c3a
attentive base class
jlibovicky 723e5cc
make encoders subclasses of attentive + a bit of style
jlibovicky 49f1631
fixes
jlibovicky 5893d08
don't call reuse for the attention, rely on higher places
jlibovicky e2223ae
fix after rebase
jlibovicky 5992dc5
attentive with abstract properties
jlibovicky 0aa02a6
nicer code
jlibovicky cb69c61
fix
jlibovicky File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from abc import ABCMeta, abstractproperty | ||
import tensorflow as tf | ||
|
||
# pylint: disable=too-few-public-methods | ||
class Attentive(metaclass=ABCMeta): | ||
"""A base class fro an attentive part of graph (typically encoder). | ||
|
||
Objects inheriting this class are able to generate an attention object that | ||
allows a decoder to perform attention over an attention_object provided by | ||
the encoder (e.g., input word representations in case of MT or | ||
convolutional maps in case of image captioning). | ||
""" | ||
def __init__(self, attention_type, **kwargs): | ||
self._attention_type = attention_type | ||
self._attention_kwargs = kwargs | ||
|
||
def get_attention_object(self, runtime: bool=False): | ||
"""Attention object that can be used in decoder.""" | ||
# pylint: disable=no-member | ||
if hasattr(self, "name") and self.name: | ||
name = self.name | ||
else: | ||
name = str(self) | ||
|
||
return self._attention_type( | ||
self._attention_tensor, | ||
scope="attention_{}".format(name), | ||
input_weights=self._attention_mask, | ||
runtime_mode=runtime, | ||
**self._attention_kwargs) if self._attention_type else None | ||
|
||
@abstractproperty | ||
def _attention_tensor(self): | ||
"""Tensor over which the attention is done.""" | ||
raise NotImplementedError( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Proč nepoužít |
||
"Attentive object is missing attention_tensor.") | ||
|
||
@property | ||
def _attention_mask(self): | ||
"""Zero/one masking the attention logits.""" | ||
return tf.ones(tf.shape(self._attention_tensor)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Chtělo by to ještě doplnit docstringy, ale už nechci zdržovat merge.